//! Batch execution of analyzers against a DataFusion [`SessionContext`].

use datafusion::prelude::*;
use std::sync::Arc;
use tracing::{debug, error, info, instrument};

use super::{AnalyzerContext, AnalyzerError, AnalyzerResult, MetricValue};
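
/// Shared callback invoked after each analyzer completes, with overall progress in `0.0..=1.0`.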
pub type ProgressCallback = Arc<dyn Fn(f64) + Send + Sync>;
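
/// A type-erased analyzer invocation: borrows a [`SessionContext`] and resolves to the
/// metric key paired with its computed [`MetricValue`].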
pub type AnalyzerExecution = Box<
dyn Fn(&SessionContext) -> futures::future::BoxFuture<'_, AnalyzerResult<(String, MetricValue)>>
+ Send
+ Sync,
>;
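
/// Runs a batch of analyzers against a single [`SessionContext`], collecting their
/// metrics (and any per-analyzer errors) into an [`AnalyzerContext`].
///
/// Illustrative usage (assumes a table has already been registered on `ctx`):
///
/// ```ignore
/// let context = AnalysisRunner::new()
///     .add(SizeAnalyzer::new())
///     .add(CompletenessAnalyzer::new("value"))
///     .on_progress(|p| println!("{:.0}% done", p * 100.0))
///     .run(&ctx)
///     .await?;
/// let size = context.get_metric("size");
/// ```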
pub struct AnalysisRunner {
    /// Type-erased analyzer executions, run in insertion order.
    executions: Vec<AnalyzerExecution>,
    /// Analyzer names, parallel to `executions`, used for logging and error reporting.
    analyzer_names: Vec<String>,
    /// Optional callback invoked with fractional progress after each analyzer.
    on_progress: Option<ProgressCallback>,
    /// Whether a failing analyzer is recorded and skipped (`true`) or aborts the run (`false`).
    continue_on_error: bool,
}

impl Default for AnalysisRunner {
fn default() -> Self {
Self::new()
}
}

impl AnalysisRunner {
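    /// Creates an empty runner that continues past analyzer failures by default.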
pub fn new() -> Self {
Self {
executions: Vec::new(),
analyzer_names: Vec::new(),
on_progress: None,
continue_on_error: true,
}
}
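
    /// Registers an analyzer, erasing its concrete type behind a boxed async closure
    /// so heterogeneous analyzers can share one runner.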
#[allow(clippy::should_implement_trait)]
pub fn add<A>(mut self, analyzer: A) -> Self
where
A: crate::analyzers::Analyzer + 'static,
A::Metric: Into<MetricValue> + 'static,
{
use futures::FutureExt;
        self.analyzer_names.push(analyzer.name().to_string());
let analyzer = Arc::new(analyzer);
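        // Clone the `Arc` into each invocation's `async` block; the boxed closure
        // can then be stored alongside executions of other analyzer types.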
let execution: AnalyzerExecution = Box::new(move |ctx| {
let analyzer = analyzer.clone();
async move {
let state = analyzer.compute_state_from_data(ctx).await?;
let metric = analyzer.compute_metric_from_state(&state)?;
Ok((analyzer.metric_key(), metric.into()))
}
.boxed()
});
self.executions.push(execution);
self
}
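
    /// Registers a callback invoked with overall progress (`0.0..=1.0`) after each
    /// analyzer completes.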
pub fn on_progress<F>(mut self, callback: F) -> Self
where
F: Fn(f64) + Send + Sync + 'static,
{
self.on_progress = Some(Arc::new(callback));
self
}
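
    /// Sets whether a failing analyzer is recorded and skipped (`true`, the default)
    /// or aborts the run immediately (`false`).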
pub fn continue_on_error(mut self, continue_on_error: bool) -> Self {
self.continue_on_error = continue_on_error;
self
}
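
    /// Runs every registered analyzer against `ctx`, storing metrics and per-analyzer
    /// errors in the returned [`AnalyzerContext`].
    ///
    /// Returns `Err` only when an analyzer fails and `continue_on_error` is `false`.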
#[instrument(skip(self, ctx), fields(analyzer_count = self.executions.len()))]
pub async fn run(&self, ctx: &SessionContext) -> AnalyzerResult<AnalyzerContext> {
info!("Starting analysis with {} analyzers", self.executions.len());
let mut context = AnalyzerContext::new();
context.metadata_mut().record_start();
let total_analyzers = self.executions.len() as f64;
let mut completed = 0.0;
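        // Execute analyzers sequentially; each failure is either recorded or
        // aborts the run, depending on `continue_on_error`.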
for (idx, execution) in self.executions.iter().enumerate() {
let analyzer_name = &self.analyzer_names[idx];
debug!("Executing analyzer: {}", analyzer_name);
let result = execution(ctx).await;
match result {
Ok((name, metric)) => {
context.store_metric(&name, metric);
debug!("Stored metric for analyzer: {}", name);
}
                Err(e) => {
                    error!("Analyzer {} failed: {}", analyzer_name, e);
                    // Format the cause before `e` is moved into the context so the
                    // fail-fast error below can still report it.
                    let message = format!("Analyzer {analyzer_name} failed: {e}");
                    context.record_error(analyzer_name, e);
                    if !self.continue_on_error {
                        return Err(AnalyzerError::execution(message));
                    }
}
}
completed += 1.0;
if let Some(ref callback) = self.on_progress {
callback(completed / total_analyzers);
}
}
context.metadata_mut().record_end();
if let Some(duration) = context.metadata().duration() {
info!(
"Analysis completed in {:.2}s",
duration.num_milliseconds() as f64 / 1000.0
);
}
Ok(context)
}
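
    /// Returns the number of registered analyzers.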
pub fn analyzer_count(&self) -> usize {
self.executions.len()
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::analyzers::basic::{CompletenessAnalyzer, SizeAnalyzer};
use crate::analyzers::MetricValue;
use datafusion::arrow::array::{Float64Array, Int64Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use std::sync::Arc;

    /// Builds a `SessionContext` with a five-row table `data`: a non-null `id`
    /// column and a `value` column containing one null.
    async fn create_test_context() -> SessionContext {
let ctx = SessionContext::new();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("value", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
Arc::new(Float64Array::from(vec![
Some(10.0),
None,
Some(30.0),
Some(40.0),
Some(50.0),
])),
],
)
.unwrap();
ctx.register_batch("data", batch).unwrap();
ctx
}

    #[tokio::test]
async fn test_analysis_runner_basic() {
let ctx = create_test_context().await;
let runner = AnalysisRunner::new().add(SizeAnalyzer::new());
let context = runner.run(&ctx).await.unwrap();
let size_metric = context.get_metric("size").expect("Size metric not found");
if let MetricValue::Long(size) = size_metric {
assert_eq!(*size, 5);
} else {
panic!("Expected Long metric for size");
}
}

    #[tokio::test]
async fn test_analysis_runner_multiple_analyzers() {
let ctx = create_test_context().await;
let runner = AnalysisRunner::new()
.add(SizeAnalyzer::new())
.add(CompletenessAnalyzer::new("value"));
let context = runner.run(&ctx).await.unwrap();
        let size_metric = context.get_metric("size").expect("Size metric not found");
        if let MetricValue::Long(size) = size_metric {
            assert_eq!(*size, 5);
        } else {
            panic!("Expected Long metric for size");
        }
        let completeness_metric = context
            .get_metric("completeness.value")
            .expect("Completeness metric not found");
        if let MetricValue::Double(completeness) = completeness_metric {
            assert!((completeness - 0.8).abs() < 0.001);
        } else {
            panic!("Expected Double metric for completeness");
        }
    }

    #[tokio::test]
async fn test_progress_callback() {
let ctx = create_test_context().await;
let progress_values = Arc::new(std::sync::Mutex::new(Vec::new()));
let progress_clone = progress_values.clone();
let runner = AnalysisRunner::new()
.add(SizeAnalyzer::new())
.add(CompletenessAnalyzer::new("value"))
.on_progress(move |progress| {
progress_clone.lock().unwrap().push(progress);
});
let _context = runner.run(&ctx).await.unwrap();
let progress = progress_values.lock().unwrap();
assert!(!progress.is_empty());
assert_eq!(*progress.last().unwrap(), 1.0);
}

    #[tokio::test]
async fn test_error_handling() {
let ctx = SessionContext::new();
let runner = AnalysisRunner::new()
.add(SizeAnalyzer::new())
.continue_on_error(true);
let context = runner.run(&ctx).await.unwrap();
assert!(context.has_errors());
assert_eq!(context.errors().len(), 1);
}

    #[tokio::test]
async fn test_fail_fast() {
let ctx = SessionContext::new();
let runner = AnalysisRunner::new()
.add(SizeAnalyzer::new())
.continue_on_error(false);
let result = runner.run(&ctx).await;
assert!(result.is_err());
}

    #[tokio::test]
async fn test_many_analyzers() {
use crate::analyzers::basic::*;
let ctx = create_test_context().await;
let runner = AnalysisRunner::new()
.add(SizeAnalyzer::new())
.add(CompletenessAnalyzer::new("id"))
.add(CompletenessAnalyzer::new("value"))
.add(DistinctnessAnalyzer::new("id"))
.add(DistinctnessAnalyzer::new("value"))
.add(MeanAnalyzer::new("value"))
.add(MinAnalyzer::new("value"))
.add(MaxAnalyzer::new("value"))
.add(SumAnalyzer::new("value"))
.add(MinAnalyzer::new("id"))
.add(MaxAnalyzer::new("id"))
.add(SumAnalyzer::new("id"));
assert_eq!(runner.analyzer_count(), 12);
let context = runner.run(&ctx).await.unwrap();
assert!(context.get_metric("size").is_some());
assert!(context.get_metric("completeness.id").is_some());
assert!(context.get_metric("completeness.value").is_some());
assert!(context.get_metric("distinctness.id").is_some());
assert!(context.get_metric("distinctness.value").is_some());
assert!(context.get_metric("mean.value").is_some());
assert!(context.get_metric("min.value").is_some());
assert!(context.get_metric("max.value").is_some());
assert!(context.get_metric("sum.value").is_some());
assert!(context.get_metric("min.id").is_some());
assert!(context.get_metric("max.id").is_some());
assert!(context.get_metric("sum.id").is_some());
        if let MetricValue::Long(size) = context.get_metric("size").unwrap() {
            assert_eq!(*size, 5);
        } else {
            panic!("Expected Long metric for size");
        }
        if let MetricValue::Double(completeness) = context.get_metric("completeness.id").unwrap() {
            assert_eq!(*completeness, 1.0);
        } else {
            panic!("Expected Double metric for completeness.id");
        }
        if let MetricValue::Double(completeness) =
            context.get_metric("completeness.value").unwrap()
        {
            assert!((completeness - 0.8).abs() < 0.001);
        } else {
            panic!("Expected Double metric for completeness.value");
        }
        assert!(!context.has_errors());
    }
}