// term_guard/analyzers/runner.rs

use datafusion::prelude::*;
use std::sync::Arc;
use tracing::{debug, error, info, instrument};

use super::{AnalyzerContext, AnalyzerError, AnalyzerResult, MetricValue};
/// Callback invoked after each analyzer finishes; receives overall progress
/// as a fraction in `0.0..=1.0`.
pub type ProgressCallback = Arc<dyn Fn(f64) + Send + Sync>;

/// Type-erased analyzer execution: given a DataFusion `SessionContext`, it
/// returns a boxed future resolving to a `(metric_key, metric_value)` pair.
/// Erasure lets heterogeneous analyzers live in a single `Vec`.
pub type AnalyzerExecution = Box<
    dyn Fn(&SessionContext) -> futures::future::BoxFuture<'_, AnalyzerResult<(String, MetricValue)>>
        + Send
        + Sync,
>;
/// Runs a batch of type-erased analyzers against a DataFusion
/// `SessionContext`, collecting their metrics (and any failures) into an
/// `AnalyzerContext`.
pub struct AnalysisRunner {
    // Type-erased execution closures, one per added analyzer.
    executions: Vec<AnalyzerExecution>,
    // Analyzer names, parallel to `executions` (same index = same analyzer).
    analyzer_names: Vec<String>,
    // Optional progress callback invoked after each analyzer completes.
    on_progress: Option<ProgressCallback>,
    // When true (the default), a failing analyzer is recorded in the context
    // and the run continues; when false the run aborts on the first failure.
    continue_on_error: bool,
}
/// `Default` simply delegates to [`AnalysisRunner::new`].
impl Default for AnalysisRunner {
    fn default() -> Self {
        Self::new()
    }
}
64impl AnalysisRunner {
65 pub fn new() -> Self {
67 Self {
68 executions: Vec::new(),
69 analyzer_names: Vec::new(),
70 on_progress: None,
71 continue_on_error: true,
72 }
73 }
74
75 #[allow(clippy::should_implement_trait)]
85 pub fn add<A>(mut self, analyzer: A) -> Self
86 where
87 A: crate::analyzers::Analyzer + 'static,
88 A::Metric: Into<MetricValue> + 'static,
89 {
90 use futures::FutureExt;
91
92 let name = analyzer.name().to_string();
93 self.analyzer_names.push(name.clone());
94
95 let analyzer = Arc::new(analyzer);
97
98 let execution: AnalyzerExecution = Box::new(move |ctx| {
100 let analyzer = analyzer.clone();
101 async move {
102 let state = analyzer.compute_state_from_data(ctx).await?;
104
105 let metric = analyzer.compute_metric_from_state(&state)?;
107
108 Ok((analyzer.metric_key(), metric.into()))
109 }
110 .boxed()
111 });
112
113 self.executions.push(execution);
114 self
115 }
116
117 pub fn on_progress<F>(mut self, callback: F) -> Self
121 where
122 F: Fn(f64) + Send + Sync + 'static,
123 {
124 self.on_progress = Some(Arc::new(callback));
125 self
126 }
127
128 pub fn continue_on_error(mut self, continue_on_error: bool) -> Self {
132 self.continue_on_error = continue_on_error;
133 self
134 }
135
136 #[instrument(skip(self, ctx), fields(analyzer_count = self.executions.len()))]
149 pub async fn run(&self, ctx: &SessionContext) -> AnalyzerResult<AnalyzerContext> {
150 info!("Starting analysis with {} analyzers", self.executions.len());
151
152 let mut context = AnalyzerContext::new();
153 context.metadata_mut().record_start();
154
155 let total_analyzers = self.executions.len() as f64;
156 let mut completed = 0.0;
157
158 for (idx, execution) in self.executions.iter().enumerate() {
161 let analyzer_name = &self.analyzer_names[idx];
162 debug!("Executing analyzer: {}", analyzer_name);
163
164 let result = execution(ctx).await;
166
167 match result {
168 Ok((name, metric)) => {
169 context.store_metric(&name, metric);
171 debug!("Stored metric for analyzer: {}", name);
172 }
173 Err(e) => {
174 error!("Analyzer {} failed: {}", analyzer_name, e);
175 context.record_error(analyzer_name, e);
176
177 if !self.continue_on_error {
178 return Err(AnalyzerError::execution(format!(
179 "Analyzer {analyzer_name} failed"
180 )));
181 }
182 }
183 }
184
185 completed += 1.0;
187 if let Some(ref callback) = self.on_progress {
188 callback(completed / total_analyzers);
189 }
190 }
191
192 context.metadata_mut().record_end();
193
194 if let Some(duration) = context.metadata().duration() {
195 info!(
196 "Analysis completed in {:.2}s",
197 duration.num_milliseconds() as f64 / 1000.0
198 );
199 }
200
201 Ok(context)
202 }
203
204 pub fn analyzer_count(&self) -> usize {
206 self.executions.len()
207 }
208}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyzers::basic::{CompletenessAnalyzer, SizeAnalyzer};
    use crate::analyzers::MetricValue;
    use datafusion::arrow::array::{Float64Array, Int64Array};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use std::sync::Arc;

    /// Registers a 5-row "data" table on a fresh context; `value` holds one null.
    async fn create_test_context() -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let id_col = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let value_col =
            Float64Array::from(vec![Some(10.0), None, Some(30.0), Some(40.0), Some(50.0)]);

        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(id_col), Arc::new(value_col)]).unwrap();

        ctx.register_batch("data", batch).unwrap();
        ctx
    }

    #[tokio::test]
    async fn test_analysis_runner_basic() {
        let ctx = create_test_context().await;

        let context = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .run(&ctx)
            .await
            .unwrap();

        let size_metric = context.get_metric("size").expect("Size metric not found");
        match size_metric {
            MetricValue::Long(size) => assert_eq!(*size, 5),
            _ => panic!("Expected Long metric for size"),
        }
    }

    #[tokio::test]
    async fn test_analysis_runner_multiple_analyzers() {
        let ctx = create_test_context().await;

        let context = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("value"))
            .run(&ctx)
            .await
            .unwrap();

        let size_metric = context.get_metric("size").expect("Size metric not found");
        if let MetricValue::Long(size) = size_metric {
            assert_eq!(*size, 5);
        }

        let completeness_metric = context
            .get_metric("completeness.value")
            .expect("Completeness metric not found");
        if let MetricValue::Double(completeness) = completeness_metric {
            // 4 of 5 values are non-null.
            assert!((completeness - 0.8).abs() < 0.001);
        }
    }

    #[tokio::test]
    async fn test_progress_callback() {
        let ctx = create_test_context().await;

        let observed = Arc::new(std::sync::Mutex::new(Vec::new()));
        let sink = Arc::clone(&observed);

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("value"))
            .on_progress(move |progress| sink.lock().unwrap().push(progress));

        let _context = runner.run(&ctx).await.unwrap();

        // Two analyzers should report progress at least once, ending at 100%.
        let recorded = observed.lock().unwrap();
        assert!(!recorded.is_empty());
        assert_eq!(*recorded.last().unwrap(), 1.0);
    }

    #[tokio::test]
    async fn test_error_handling() {
        // No "data" table registered, so the analyzer must fail.
        let ctx = SessionContext::new();

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .continue_on_error(true);

        let context = runner.run(&ctx).await.unwrap();

        assert!(context.has_errors());
        assert_eq!(context.errors().len(), 1);
    }

    #[tokio::test]
    async fn test_fail_fast() {
        // No "data" table registered, so the analyzer must fail.
        let ctx = SessionContext::new();

        let outcome = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .continue_on_error(false)
            .run(&ctx)
            .await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_many_analyzers() {
        use crate::analyzers::basic::*;

        let ctx = create_test_context().await;

        let runner = AnalysisRunner::new()
            .add(SizeAnalyzer::new())
            .add(CompletenessAnalyzer::new("id"))
            .add(CompletenessAnalyzer::new("value"))
            .add(DistinctnessAnalyzer::new("id"))
            .add(DistinctnessAnalyzer::new("value"))
            .add(MeanAnalyzer::new("value"))
            .add(MinAnalyzer::new("value"))
            .add(MaxAnalyzer::new("value"))
            .add(SumAnalyzer::new("value"))
            .add(MinAnalyzer::new("id"))
            .add(MaxAnalyzer::new("id"))
            .add(SumAnalyzer::new("id"));

        assert_eq!(runner.analyzer_count(), 12);

        let context = runner.run(&ctx).await.unwrap();

        // Every analyzer should have produced a metric under its key.
        let expected_keys = [
            "size",
            "completeness.id",
            "completeness.value",
            "distinctness.id",
            "distinctness.value",
            "mean.value",
            "min.value",
            "max.value",
            "sum.value",
            "min.id",
            "max.id",
            "sum.id",
        ];
        for key in expected_keys {
            assert!(context.get_metric(key).is_some());
        }

        if let MetricValue::Long(size) = context.get_metric("size").unwrap() {
            assert_eq!(*size, 5);
        }

        if let MetricValue::Double(completeness) = context.get_metric("completeness.id").unwrap() {
            assert_eq!(*completeness, 1.0);
        }

        if let MetricValue::Double(completeness) = context.get_metric("completeness.value").unwrap()
        {
            assert!((completeness - 0.8).abs() < 0.001);
        }

        assert!(!context.has_errors());
    }
}