1use std::{collections::HashMap, sync::Arc, time::Instant};
2
3use aqueducts_schemas::{Aqueduct, ProgressEvent, Stage};
4use datafusion::execution::context::SessionContext;
5use regex::Regex;
6use tokio::task::JoinHandle;
7use tracing::{debug, instrument, warn};
8
9#[cfg(feature = "custom_udfs")]
10pub mod custom_udfs;
11pub mod error;
12pub mod progress_tracker;
13pub mod templating;
14
15mod destinations;
16mod schema_transform;
17mod sources;
18mod stages;
19mod store;
20
21use destinations::{register_destination, write_to_destination};
22use progress_tracker::*;
23use sources::register_source;
24use stages::process_stage;
25
/// Runs an [`Aqueduct`] pipeline to completion on the provided `SessionContext`.
///
/// Execution order:
/// 1. Register the destination, when one is configured.
/// 2. Register every source concurrently (one spawned task per source).
/// 3. Execute each stage group in sequence; stages *within* a group run
///    concurrently as spawned tasks.
/// 4. Write the final stage's table to the destination, if both exist.
///
/// Intermediate stage tables are deregistered once no later stage group
/// references them (tracked via the per-stage TTL map).
///
/// Progress is reported through the optional `progress_tracker` callback.
/// Returns the session context so callers can inspect any remaining tables.
///
/// # Panics
///
/// Panics if a spawned registration/stage task panics or is cancelled
/// (join failure); errors returned *by* the tasks are propagated normally.
#[instrument(skip_all, err)]
pub async fn run_pipeline(
    ctx: Arc<SessionContext>,
    aqueduct: Aqueduct,
    progress_tracker: Option<Arc<dyn ProgressTracker>>,
) -> error::Result<Arc<SessionContext>> {
    // Stage table name -> index of the last stage group that reads it.
    let mut stage_ttls: HashMap<String, usize> = HashMap::new();
    let start_time = Instant::now();

    debug!("Running Aqueduct ...");

    if let Some(tracker) = &progress_tracker {
        tracker.on_progress(ProgressEvent::Started);
    }

    // Destination is registered up front so stages can assume it exists.
    if let Some(destination) = &aqueduct.destination {
        let time = Instant::now();

        register_destination(ctx.clone(), destination).await?;

        debug!(
            "Created destination ... Elapsed time: {:.2?}",
            time.elapsed()
        );
    }

    // Spawn one task per source so registrations proceed concurrently;
    // the name and start time travel with each handle for logging below.
    let handles = aqueduct
        .sources
        .iter()
        .map(|source| {
            let time = Instant::now();
            let source_ = source.clone();
            let ctx_ = ctx.clone();
            let source_name = source.name();

            let handle = tokio::spawn(async move {
                register_source(ctx_, source_).await?;

                Ok(())
            });

            (source_name, time, handle)
        })
        .collect::<Vec<(String, Instant, JoinHandle<error::Result<()>>)>>();

    // Await all source registrations; the first task error aborts the run.
    for (source_name, time, handle) in handles {
        // `expect` covers task panic/cancellation; `?` propagates task errors.
        handle.await.expect("failed to join task")?;

        debug!(
            "Registered source {source_name} ... Elapsed time: {:.2?}",
            time.elapsed()
        );

        if let Some(tracker) = &progress_tracker {
            tracker.on_progress(ProgressEvent::SourceRegistered { name: source_name });
        }
    }

    // Stage groups run sequentially; `pos` is the group index used for TTLs.
    for (pos, parallel) in aqueduct.stages.iter().enumerate() {
        let mut handles: Vec<JoinHandle<error::Result<()>>> = Vec::new();

        // All stages inside one group execute concurrently.
        for (sub, stage) in parallel.iter().enumerate() {
            let stage_ = stage.clone();
            let ctx_ = ctx.clone();
            let name = stage.name.clone();
            let tracker = progress_tracker.clone();

            let handle = tokio::spawn(async move {
                let time = Instant::now();
                debug!("Running stage {} #{pos}:{sub}", name);

                if let Some(tracker_ref) = &tracker {
                    tracker_ref.on_progress(ProgressEvent::StageStarted {
                        name: name.clone(),
                        position: pos,
                        sub_position: sub,
                    });
                }

                process_stage(ctx_, stage_, tracker.clone()).await?;

                let elapsed = time.elapsed();
                debug!(
                    "Finished processing stage {name} #{pos}:{sub} ... Elapsed time: {:.2?}",
                    elapsed
                );

                if let Some(tracker) = &tracker {
                    tracker.on_progress(ProgressEvent::StageCompleted {
                        name: name.clone(),
                        position: pos,
                        sub_position: sub,
                        duration_ms: elapsed.as_millis() as u64,
                    });
                }

                Ok(())
            });

            // Record how long this stage's output table must stay registered.
            calculate_ttl(&mut stage_ttls, stage.name.as_str(), pos, &aqueduct.stages)?;
            handles.push(handle);
        }

        // Barrier: the next group only starts after the whole group finishes.
        for handle in handles {
            handle.await.expect("failed to join task")?;
        }

        // Drop any stage tables whose TTL expired at this group index.
        deregister_stages(ctx.clone(), &stage_ttls, pos)?;
    }

    // Final write: only when there is both a last stage and a destination.
    if let (Some(last_stage), Some(destination)) = (
        aqueduct.stages.last().and_then(|s| s.last()),
        &aqueduct.destination,
    ) {
        let time = Instant::now();

        let df = ctx.table(last_stage.name.as_str()).await?;
        write_to_destination(ctx.clone(), destination, df).await?;

        // The last stage's table is no longer needed once written out.
        ctx.deregister_table(last_stage.name.as_str())?;

        let elapsed = time.elapsed();
        debug!(
            "Finished writing to destination ... Elapsed time: {:.2?}",
            elapsed
        );

        if let Some(tracker) = &progress_tracker {
            tracker.on_progress(ProgressEvent::DestinationCompleted);
        }
    } else {
        warn!("No destination defined ... skipping write");
    }

    let total_duration = start_time.elapsed();
    debug!(
        "Finished processing pipeline ... Total time: {:.2?}",
        total_duration
    );

    if let Some(tracker) = &progress_tracker {
        tracker.on_progress(ProgressEvent::Completed {
            duration_ms: total_duration.as_millis() as u64,
        });
    }

    Ok(ctx)
}
217
218fn calculate_ttl<'a>(
220 stage_ttls: &'a mut HashMap<String, usize>,
221 stage_name: &'a str,
222 stage_pos: usize,
223 stages: &[Vec<Stage>],
224) -> error::Result<()> {
225 let stage_name_r = format!("\\s{stage_name}(\\s|\\;|\\n|\\)|\\.|$)");
226 let regex = Regex::new(stage_name_r.as_str())?;
227
228 let ttl = stages
229 .iter()
230 .enumerate()
231 .skip(stage_pos + 1)
232 .flat_map(|(forward_pos, parallel)| parallel.iter().map(move |stage| (forward_pos, stage)))
233 .filter_map(|(forward_pos, stage)| {
234 if regex.is_match(stage.query.as_str()) {
235 debug!("Registering TTL for {stage_name}. STAGE_POS={stage_pos} TTL={forward_pos}");
236 Some(forward_pos)
237 } else {
238 None
239 }
240 })
241 .next_back()
242 .unwrap_or(stage_pos + 1);
243
244 stage_ttls
245 .entry(stage_name.to_string())
246 .and_modify(|e| *e = ttl)
247 .or_insert_with(|| ttl);
248
249 Ok(())
250}
251
252fn deregister_stages(
254 ctx: Arc<SessionContext>,
255 ttls: &HashMap<String, usize>,
256 current_pos: usize,
257) -> error::Result<()> {
258 ttls.iter().try_for_each(|(table, ttl)| {
259 if *ttl == current_pos {
260 debug!("Deregistering table {table}, current_pos {current_pos}, ttl {ttl}");
261 ctx.deregister_table(table).map(|_| ())
262 } else {
263 Ok(())
264 }
265 })?;
266
267 Ok(())
268}