//! aqueducts_core — crate root (lib.rs)

1use std::{collections::HashMap, sync::Arc, time::Instant};
2
3use aqueducts_schemas::{Aqueduct, ProgressEvent, Stage};
4use datafusion::execution::context::SessionContext;
5use regex::Regex;
6use tokio::task::JoinHandle;
7use tracing::{debug, instrument, warn};
8
9#[cfg(feature = "custom_udfs")]
10pub mod custom_udfs;
11pub mod error;
12pub mod progress_tracker;
13pub mod templating;
14
15mod destinations;
16mod schema_transform;
17mod sources;
18mod stages;
19mod store;
20
21use destinations::{register_destination, write_to_destination};
22use progress_tracker::*;
23use sources::register_source;
24use stages::process_stage;
25
26/// Execute an Aqueducts data pipeline.
27///
28/// This is the main entry point for running data pipelines defined in aqueduct files.
29/// The pipeline will execute all sources, stages, and destinations in sequential order
30///
31/// # Arguments
32///
33/// * `ctx` - A DataFusion SessionContext for SQL execution
34/// * `aqueduct` - The pipeline configuration loaded from a file
35/// * `progress_tracker` - Optional tracker for monitoring execution progress
36///
37/// # Returns
38///
39/// Returns the SessionContext after successful execution, which can be used
40/// for further operations or inspection of registered tables.
41///
42/// # Example
43///
44/// ```rust,no_run
45/// use aqueducts_core::{run_pipeline, progress_tracker::LoggingProgressTracker, templating::TemplateLoader};
46/// use aqueducts_schemas::Aqueduct;
47/// use datafusion::prelude::SessionContext;
48/// use std::sync::Arc;
49///
50/// #[tokio::main]
51/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
52///     // Load pipeline configuration
53///     let pipeline = Aqueduct::from_file("pipeline.yml", Default::default())?;
54///     
55///     // Create DataFusion context
56///     let ctx = Arc::new(SessionContext::new());
57///     
58///     // Create progress tracker
59///     let tracker = Arc::new(LoggingProgressTracker);
60///     
61///     // Execute pipeline
62///     let result_ctx = run_pipeline(ctx, pipeline, Some(tracker)).await?;
63///     
64///     Ok(())
65/// }
66/// ```
#[instrument(skip_all, err)]
pub async fn run_pipeline(
    ctx: Arc<SessionContext>,
    aqueduct: Aqueduct,
    progress_tracker: Option<Arc<dyn ProgressTracker>>,
) -> error::Result<Arc<SessionContext>> {
    // Stage name -> last pipeline position that still references the stage's
    // output table; consulted by `deregister_stages` to free tables early.
    let mut stage_ttls: HashMap<String, usize> = HashMap::new();
    let start_time = Instant::now();

    debug!("Running Aqueduct ...");

    if let Some(tracker) = &progress_tracker {
        tracker.on_progress(ProgressEvent::Started);
    }

    // Register the destination up-front so configuration failures surface
    // before any source or stage work is done.
    if let Some(destination) = &aqueduct.destination {
        let time = Instant::now();

        register_destination(ctx.clone(), destination).await?;

        debug!(
            "Created destination ... Elapsed time: {:.2?}",
            time.elapsed()
        );
    }

    // Register all sources concurrently; each spawn yields a JoinHandle that
    // is awaited (in declaration order) in the loop below.
    let handles = aqueduct
        .sources
        .iter()
        .map(|source| {
            // NOTE(review): `time` starts at spawn, so the elapsed value logged
            // after the join also includes time spent waiting on earlier handles.
            let time = Instant::now();
            let source_ = source.clone();
            let ctx_ = ctx.clone();
            let source_name = source.name();

            let handle = tokio::spawn(async move {
                register_source(ctx_, source_).await?;

                Ok(())
            });

            (source_name, time, handle)
        })
        .collect::<Vec<(String, Instant, JoinHandle<error::Result<()>>)>>();

    for (source_name, time, handle) in handles {
        // `expect` only fires if the spawned task panicked or was aborted;
        // an actual registration failure is propagated via `?`.
        handle.await.expect("failed to join task")?;

        debug!(
            "Registered source {source_name} ... Elapsed time: {:.2?}",
            time.elapsed()
        );

        if let Some(tracker) = &progress_tracker {
            tracker.on_progress(ProgressEvent::SourceRegistered { name: source_name });
        }
    }

    // Stages are grouped: every stage in `parallel` (position `pos`) runs
    // concurrently; the groups themselves run sequentially.
    for (pos, parallel) in aqueduct.stages.iter().enumerate() {
        let mut handles: Vec<JoinHandle<error::Result<()>>> = Vec::new();

        for (sub, stage) in parallel.iter().enumerate() {
            let stage_ = stage.clone();
            let ctx_ = ctx.clone();
            let name = stage.name.clone();
            let tracker = progress_tracker.clone();

            let handle = tokio::spawn(async move {
                let time = Instant::now();
                debug!("Running stage {} #{pos}:{sub}", name);

                if let Some(tracker_ref) = &tracker {
                    tracker_ref.on_progress(ProgressEvent::StageStarted {
                        name: name.clone(),
                        position: pos,
                        sub_position: sub,
                    });
                }

                process_stage(ctx_, stage_, tracker.clone()).await?;

                let elapsed = time.elapsed();
                debug!(
                    "Finished processing stage {name} #{pos}:{sub} ... Elapsed time: {:.2?}",
                    elapsed
                );

                if let Some(tracker) = &tracker {
                    tracker.on_progress(ProgressEvent::StageCompleted {
                        name: name.clone(),
                        position: pos,
                        sub_position: sub,
                        duration_ms: elapsed.as_millis() as u64,
                    });
                }

                Ok(())
            });

            // Record until which position this stage's output table must stay
            // registered (based on later queries referencing its name).
            calculate_ttl(&mut stage_ttls, stage.name.as_str(), pos, &aqueduct.stages)?;
            handles.push(handle);
        }

        // Wait for the whole parallel group before moving to the next position.
        for handle in handles {
            handle.await.expect("failed to join task")?;
        }

        // Drop any stage tables whose TTL expired at this position.
        deregister_stages(ctx.clone(), &stage_ttls, pos)?;
    }

    // Write the output of the final stage to the destination; requires both a
    // last stage and a configured destination.
    if let (Some(last_stage), Some(destination)) = (
        aqueduct.stages.last().and_then(|s| s.last()),
        &aqueduct.destination,
    ) {
        let time = Instant::now();

        let df = ctx.table(last_stage.name.as_str()).await?;
        write_to_destination(ctx.clone(), destination, df).await?;

        ctx.deregister_table(last_stage.name.as_str())?;

        let elapsed = time.elapsed();
        debug!(
            "Finished writing to destination ... Elapsed time: {:.2?}",
            elapsed
        );

        // Emit destination completed event
        if let Some(tracker) = &progress_tracker {
            tracker.on_progress(ProgressEvent::DestinationCompleted);
        }
    } else {
        // NOTE(review): this branch is also taken when a destination exists but
        // there are no stages — the log message only mentions the destination.
        warn!("No destination defined ... skipping write");
    }

    let total_duration = start_time.elapsed();
    debug!(
        "Finished processing pipeline ... Total time: {:.2?}",
        total_duration
    );

    // Emit completed event
    if let Some(tracker) = &progress_tracker {
        tracker.on_progress(ProgressEvent::Completed {
            duration_ms: total_duration.as_millis() as u64,
        });
    }

    Ok(ctx)
}
217
218// calculate time to live for a stage based on the position of the stage
219fn calculate_ttl<'a>(
220    stage_ttls: &'a mut HashMap<String, usize>,
221    stage_name: &'a str,
222    stage_pos: usize,
223    stages: &[Vec<Stage>],
224) -> error::Result<()> {
225    let stage_name_r = format!("\\s{stage_name}(\\s|\\;|\\n|\\)|\\.|$)");
226    let regex = Regex::new(stage_name_r.as_str())?;
227
228    let ttl = stages
229        .iter()
230        .enumerate()
231        .skip(stage_pos + 1)
232        .flat_map(|(forward_pos, parallel)| parallel.iter().map(move |stage| (forward_pos, stage)))
233        .filter_map(|(forward_pos, stage)| {
234            if regex.is_match(stage.query.as_str()) {
235                debug!("Registering TTL for {stage_name}. STAGE_POS={stage_pos} TTL={forward_pos}");
236                Some(forward_pos)
237            } else {
238                None
239            }
240        })
241        .next_back()
242        .unwrap_or(stage_pos + 1);
243
244    stage_ttls
245        .entry(stage_name.to_string())
246        .and_modify(|e| *e = ttl)
247        .or_insert_with(|| ttl);
248
249    Ok(())
250}
251
252// deregister stages from context if the current position matches the ttl of the stages
253fn deregister_stages(
254    ctx: Arc<SessionContext>,
255    ttls: &HashMap<String, usize>,
256    current_pos: usize,
257) -> error::Result<()> {
258    ttls.iter().try_for_each(|(table, ttl)| {
259        if *ttl == current_pos {
260            debug!("Deregistering table {table}, current_pos {current_pos}, ttl {ttl}");
261            ctx.deregister_table(table).map(|_| ())
262        } else {
263            Ok(())
264        }
265    })?;
266
267    Ok(())
268}