term_guard/analyzers/runner.rs

//! Orchestration layer for efficient analyzer execution.

use datafusion::prelude::*;
use std::sync::Arc;
use tracing::{debug, error, info, instrument};

use super::{AnalyzerContext, AnalyzerError, AnalyzerResult, MetricValue};

/// Type alias for a progress callback function.
pub type ProgressCallback = Arc<dyn Fn(f64) + Send + Sync>;

/// Type alias for a boxed analyzer execution function.
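///
/// The closure borrows the `SessionContext`, runs a single analyzer against
/// it, and resolves to that analyzer's metric key together with its computed
/// `MetricValue`. `AnalysisRunner::add` builds these closures internally.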
pub type AnalyzerExecution = Box<
    dyn Fn(&SessionContext) -> futures::future::BoxFuture<'_, AnalyzerResult<(String, MetricValue)>>
        + Send
        + Sync,
>;

/// Orchestrates the execution of multiple analyzers on a dataset.
///
/// Analyzers run sequentially in the order they are added; grouping
/// compatible analyzers so they can share computation and minimize the
/// number of DataFrame scans is a planned optimization.
///
/// # Example
///
/// ```rust,ignore
/// use term_guard::analyzers::{AnalysisRunner, basic::*};
/// use datafusion::prelude::*;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let ctx = SessionContext::new();
/// // Register your data table
///
/// let runner = AnalysisRunner::new()
///     .add(SizeAnalyzer::new())
///     .add(CompletenessAnalyzer::new("user_id"))
///     .add(DistinctnessAnalyzer::new("user_id"))
///     .on_progress(|progress| {
///         println!("Analysis progress: {:.1}%", progress * 100.0);
///     });
///
/// let context = runner.run(&ctx).await?;
/// println!("Computed {} metrics", context.all_metrics().len());
/// # Ok(())
/// # }
/// ```
pub struct AnalysisRunner {
    /// Analyzer executions to run.
    executions: Vec<AnalyzerExecution>,
    /// Names of the analyzers for debugging.
    analyzer_names: Vec<String>,
    /// Optional progress callback.
    on_progress: Option<ProgressCallback>,
    /// Whether to continue on analyzer failures.
    continue_on_error: bool,
}

impl Default for AnalysisRunner {
    fn default() -> Self {
        Self::new()
    }
}

impl AnalysisRunner {
    /// Creates a new empty AnalysisRunner.
    pub fn new() -> Self {
        Self {
            executions: Vec::new(),
            analyzer_names: Vec::new(),
            on_progress: None,
            continue_on_error: true,
        }
    }

    /// Adds an analyzer to the runner.
    ///
    /// # Arguments
    ///
    /// * `analyzer` - The analyzer to add
    ///
    /// # Type Parameters
    ///
    /// * `A` - The analyzer type that implements the Analyzer trait
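    ///
    /// # Example
    ///
    /// A minimal sketch:
    ///
    /// ```rust,ignore
    /// let runner = AnalysisRunner::new().add(SizeAnalyzer::new());
    /// assert_eq!(runner.analyzer_count(), 1);
    /// ```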
    #[allow(clippy::should_implement_trait)]
    pub fn add<A>(mut self, analyzer: A) -> Self
    where
        A: crate::analyzers::Analyzer + 'static,
        A::Metric: Into<MetricValue> + 'static,
    {
        use futures::FutureExt;

        let name = analyzer.name().to_string();
        self.analyzer_names.push(name.clone());

        // Wrap the analyzer in an Arc so the `Fn` closure can be invoked repeatedly
        let analyzer = Arc::new(analyzer);

        // Create an execution closure that captures the analyzer
        let execution: AnalyzerExecution = Box::new(move |ctx| {
            let analyzer = analyzer.clone();
            async move {
                // Compute state from data
                let state = analyzer.compute_state_from_data(ctx).await?;

                // Compute metric from state
                let metric = analyzer.compute_metric_from_state(&state)?;

                Ok((analyzer.metric_key(), metric.into()))
            }
            .boxed()
        });

        self.executions.push(execution);
        self
    }

    /// Sets a progress callback that will be called during execution.
    ///
    /// The callback receives a value between 0.0 and 1.0: the fraction of
    /// analyzers completed so far.
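    ///
    /// # Example
    ///
    /// Reporting progress as a percentage:
    ///
    /// ```rust,ignore
    /// let runner = AnalysisRunner::new()
    ///     .add(SizeAnalyzer::new())
    ///     .on_progress(|p| println!("{:.0}% complete", p * 100.0));
    /// ```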
    pub fn on_progress<F>(mut self, callback: F) -> Self
    where
        F: Fn(f64) + Send + Sync + 'static,
    {
        self.on_progress = Some(Arc::new(callback));
        self
    }

    /// Sets whether to continue execution when individual analyzers fail.
    ///
    /// The default is `true` (continue on error).
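    ///
    /// # Example
    ///
    /// Failing fast instead of collecting errors:
    ///
    /// ```rust,ignore
    /// let runner = AnalysisRunner::new()
    ///     .add(SizeAnalyzer::new())
    ///     .continue_on_error(false); // `run` returns Err on the first failure
    /// ```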
    pub fn continue_on_error(mut self, continue_on_error: bool) -> Self {
        self.continue_on_error = continue_on_error;
        self
    }

    /// Executes all analyzers on the given data context.
    ///
    /// Analyzers are executed one after another in the order they were added;
    /// grouping compatible analyzers for shared execution is planned.
    ///
    /// # Arguments
    ///
    /// * `ctx` - The DataFusion session context with registered data
    ///
    /// # Returns
    ///
    /// An AnalyzerContext containing all computed metrics and any errors
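    ///
    /// # Errors
    ///
    /// Fails only when `continue_on_error` is `false` and an analyzer returns
    /// an error; otherwise failures are recorded in the returned context and
    /// execution continues.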
    #[instrument(skip(self, ctx), fields(analyzer_count = self.executions.len()))]
    pub async fn run(&self, ctx: &SessionContext) -> AnalyzerResult<AnalyzerContext> {
        info!("Starting analysis with {} analyzers", self.executions.len());

        let mut context = AnalyzerContext::new();
        context.metadata_mut().record_start();

        let total_analyzers = self.executions.len() as f64;
        let mut completed = 0.0;

        // Execute each analyzer
        // TODO: In the future, group compatible analyzers for shared execution
        for (idx, execution) in self.executions.iter().enumerate() {
            let analyzer_name = &self.analyzer_names[idx];
            debug!("Executing analyzer: {}", analyzer_name);

            // Execute the analyzer
            let result = execution(ctx).await;

            match result {
                Ok((name, metric)) => {
                    // Store the metric in the context
                    context.store_metric(&name, metric);
                    debug!("Stored metric for analyzer: {}", name);
                }
                Err(e) => {
                    error!("Analyzer {} failed: {}", analyzer_name, e);
                    // Format the message before `e` is moved into the context,
                    // so the fail-fast error below still carries the cause.
                    let message = format!("Analyzer {analyzer_name} failed: {e}");
                    context.record_error(analyzer_name, e);

                    if !self.continue_on_error {
                        return Err(AnalyzerError::execution(message));
                    }
                }
            }

            // Update progress
            completed += 1.0;
            if let Some(ref callback) = self.on_progress {
                callback(completed / total_analyzers);
            }
        }

        context.metadata_mut().record_end();

        if let Some(duration) = context.metadata().duration() {
            info!(
                "Analysis completed in {:.2}s",
                duration.num_milliseconds() as f64 / 1000.0
            );
        }

        Ok(context)
    }

    /// Returns the number of analyzers configured.
    pub fn analyzer_count(&self) -> usize {
        self.executions.len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzers::basic::{CompletenessAnalyzer, SizeAnalyzer};
    use crate::analyzers::MetricValue;
    use datafusion::arrow::array::{Float64Array, Int64Array};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    async fn create_test_context() -> SessionContext {
        let ctx = SessionContext::new();

        // Create test data
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(Float64Array::from(vec![
                    Some(10.0),
                    None,
                    Some(30.0),
                    Some(40.0),
                    Some(50.0),
                ])),
            ],
        )
        .unwrap();

        ctx.register_batch("data", batch).unwrap();
        ctx
    }

    #[tokio::test]
    async fn test_analysis_runner_basic() {
        let ctx = create_test_context().await;

        let runner = AnalysisRunner::new().add(SizeAnalyzer::new());

        let context = runner.run(&ctx).await.unwrap();

        // Check that we got the size metric
        let size_metric = context.get_metric("size").expect("Size metric not found");
        if let MetricValue::Long(size) = size_metric {
            assert_eq!(*size, 5);
        } else {
            panic!("Expected Long metric for size");
        }
    }

    #[tokio::test]
    async fn test_analysis_runner_multiple_analyzers() {
        let ctx = create_test_context().await;

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("value"));

        let context = runner.run(&ctx).await.unwrap();

        // Check size metric
        let size_metric = context.get_metric("size").expect("Size metric not found");
        if let MetricValue::Long(size) = size_metric {
            assert_eq!(*size, 5);
        }

        // Check completeness metric
        let completeness_metric = context
            .get_metric("completeness.value")
            .expect("Completeness metric not found");
        if let MetricValue::Double(completeness) = completeness_metric {
            assert!((completeness - 0.8).abs() < 0.001); // 4/5 = 0.8
        }
    }

    #[tokio::test]
    async fn test_progress_callback() {
        let ctx = create_test_context().await;

        let progress_values = Arc::new(std::sync::Mutex::new(Vec::new()));
        let progress_clone = progress_values.clone();

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("value"))
            .on_progress(move |progress| {
                progress_clone.lock().unwrap().push(progress);
            });

        let _context = runner.run(&ctx).await.unwrap();

        let progress = progress_values.lock().unwrap();
        assert!(!progress.is_empty());
        assert_eq!(*progress.last().unwrap(), 1.0);
    }

    #[tokio::test]
    async fn test_error_handling() {
        let ctx = SessionContext::new(); // No data registered

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .continue_on_error(true);

        let context = runner.run(&ctx).await.unwrap();

        // Should have recorded an error
        assert!(context.has_errors());
        assert_eq!(context.errors().len(), 1);
    }

    #[tokio::test]
    async fn test_fail_fast() {
        let ctx = SessionContext::new(); // No data registered

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .continue_on_error(false);

        let result = runner.run(&ctx).await;

        // Should fail immediately
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_many_analyzers() {
        use crate::analyzers::basic::*;

        let ctx = create_test_context().await;

        // Add 10+ analyzers
        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("id"))
            .add(CompletenessAnalyzer::new("value"))
            .add(DistinctnessAnalyzer::new("id"))
            .add(DistinctnessAnalyzer::new("value"))
            .add(MeanAnalyzer::new("value"))
            .add(MinAnalyzer::new("value"))
            .add(MaxAnalyzer::new("value"))
            .add(SumAnalyzer::new("value"))
            .add(MinAnalyzer::new("id"))
            .add(MaxAnalyzer::new("id"))
            .add(SumAnalyzer::new("id"));

        assert_eq!(runner.analyzer_count(), 12);

        let context = runner.run(&ctx).await.unwrap();

        // Verify we got metrics from all analyzers
        assert!(context.get_metric("size").is_some());
        assert!(context.get_metric("completeness.id").is_some());
        assert!(context.get_metric("completeness.value").is_some());
        assert!(context.get_metric("distinctness.id").is_some());
        assert!(context.get_metric("distinctness.value").is_some());
        assert!(context.get_metric("mean.value").is_some());
        assert!(context.get_metric("min.value").is_some());
        assert!(context.get_metric("max.value").is_some());
        assert!(context.get_metric("sum.value").is_some());
        assert!(context.get_metric("min.id").is_some());
        assert!(context.get_metric("max.id").is_some());
        assert!(context.get_metric("sum.id").is_some());

        // Verify specific metric values
        if let MetricValue::Long(size) = context.get_metric("size").unwrap() {
            assert_eq!(*size, 5);
        }

        if let MetricValue::Double(completeness) = context.get_metric("completeness.id").unwrap() {
            assert_eq!(*completeness, 1.0); // All IDs are non-null
        }

        if let MetricValue::Double(completeness) = context.get_metric("completeness.value").unwrap()
        {
            assert!((completeness - 0.8).abs() < 0.001); // 4/5 values are non-null
        }

        // Check that analysis completed without errors
        assert!(!context.has_errors());
    }
}