use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::analyzers::{Analyzer, AnalyzerError, AnalyzerResult, AnalyzerState, MetricValue};
use crate::core::current_validation_context;
/// Analyzer bound to a single column, identified by name.
#[derive(Debug, Clone)]
pub struct StandardDeviationAnalyzer {
    /// Name of the column this analyzer operates on.
    column: String,
}

impl StandardDeviationAnalyzer {
    /// Builds an analyzer for `column`; anything convertible into a `String` is accepted.
    pub fn new(column: impl Into<String>) -> Self {
        let column = column.into();
        Self { column }
    }

    /// Returns the name of the column this analyzer was created for.
    pub fn column(&self) -> &str {
        self.column.as_str()
    }
}
/// Intermediate aggregates from which variance and standard deviation are derived.
///
/// States are additive: `merge` combines several of them by adding `count`, `sum`
/// and `sum_squared` and then recomputing `mean` from the totals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardDeviationState {
/// Number of values aggregated (the analyzer fills this from `COUNT(column)` over non-null rows).
pub count: u64,
/// Sum of the aggregated values.
pub sum: f64,
/// Sum of the squares of the aggregated values.
pub sum_squared: f64,
/// Arithmetic mean of the aggregated values; the producers in this file store `0.0` when `count` is zero.
pub mean: f64,
}
impl StandardDeviationState {
    /// Population standard deviation: the square root of [`Self::population_variance`].
    ///
    /// Returns `None` when no values were counted.
    pub fn population_std_dev(&self) -> Option<f64> {
        // `population_variance` already yields `None` for an empty state.
        self.population_variance().map(f64::sqrt)
    }

    /// Sample standard deviation: the square root of [`Self::sample_variance`].
    ///
    /// Returns `None` when fewer than two values were counted.
    pub fn sample_std_dev(&self) -> Option<f64> {
        // `sample_variance` already yields `None` when `count <= 1`.
        self.sample_variance().map(f64::sqrt)
    }

    /// Population variance, computed as `sum_squared / count - mean * mean`.
    ///
    /// Returns `None` when `count` is zero. Negative results are clamped to `0.0`.
    pub fn population_variance(&self) -> Option<f64> {
        if self.count == 0 {
            return None;
        }
        let n = self.count as f64;
        let mean_square = self.sum_squared / n;
        let square_of_mean = self.mean * self.mean;
        Some((mean_square - square_of_mean).max(0.0))
    }

    /// Sample variance, computed as `(sum_squared - sum * sum / count) / (count - 1)`.
    ///
    /// Returns `None` when `count` is zero or one. Negative results are clamped to `0.0`.
    pub fn sample_variance(&self) -> Option<f64> {
        if self.count <= 1 {
            return None;
        }
        let n = self.count as f64;
        let correction = self.sum * self.sum / n;
        let squared_deviations = self.sum_squared - correction;
        let degrees_of_freedom = (self.count - 1) as f64;
        Some((squared_deviations / degrees_of_freedom).max(0.0))
    }

    /// Population standard deviation divided by the absolute value of the mean.
    ///
    /// Returns `None` when the state is empty or when `|mean|` is below `f64::EPSILON`.
    pub fn coefficient_of_variation(&self) -> Option<f64> {
        let std_dev = self.population_std_dev()?;
        let magnitude = self.mean.abs();
        if magnitude < f64::EPSILON {
            return None;
        }
        Some(std_dev / magnitude)
    }
}
impl AnalyzerState for StandardDeviationState {
    /// Combines partial states into one.
    ///
    /// Counts, sums and sums of squares are added across all inputs; the mean is
    /// then recomputed as `sum / count`, or set to `0.0` when the combined count
    /// is zero. Merging an empty list is reported as a state-merge error.
    fn merge(states: Vec<Self>) -> AnalyzerResult<Self> {
        if states.is_empty() {
            return Err(AnalyzerError::state_merge("No states to merge"));
        }

        let count = states.iter().map(|part| part.count).sum::<u64>();
        let sum = states.iter().map(|part| part.sum).sum::<f64>();
        let sum_squared = states.iter().map(|part| part.sum_squared).sum::<f64>();

        // Avoid dividing by zero when every partial state was empty.
        let mean = match count {
            0 => 0.0,
            total => sum / total as f64,
        };

        Ok(Self {
            count,
            sum,
            sum_squared,
            mean,
        })
    }

    /// A state is empty when it has counted no values.
    fn is_empty(&self) -> bool {
        self.count == 0
    }
}
#[async_trait]
impl Analyzer for StandardDeviationAnalyzer {
    type State = StandardDeviationState;
    type Metric = MetricValue;

    /// Computes the aggregate state for the configured column with a single SQL query.
    ///
    /// The query runs against the table named by the current validation context and
    /// only considers rows where the column is not null. When the result has no rows,
    /// a null count, or a count of zero, an all-zero state is returned.
    ///
    /// # Errors
    ///
    /// Propagates failures from planning or executing the query, returns
    /// `AnalyzerError::NoData` when the query yields no record batches, and returns
    /// an invalid-data error when a result column does not have the expected Arrow type.
    #[instrument(skip(ctx), fields(analyzer = "standard_deviation", column = %self.column))]
    async fn compute_state_from_data(&self, ctx: &SessionContext) -> AnalyzerResult<Self::State> {
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();

        // The column is cast to DOUBLE inside the aggregates so that integer columns
        // also produce Float64 results (the downcasts below require Float64) and so
        // that squares are accumulated in floating point instead of integer arithmetic.
        // NOTE(review): the column and table names are interpolated verbatim, so
        // callers are expected to pass trusted, valid SQL identifiers.
        let sql = format!(
            "SELECT
                COUNT({0}) as count,
                AVG(CAST({0} AS DOUBLE)) as mean,
                SUM(CAST({0} AS DOUBLE)) as sum,
                SUM(CAST({0} AS DOUBLE) * CAST({0} AS DOUBLE)) as sum_squared
            FROM {table_name}
            WHERE {0} IS NOT NULL",
            self.column
        );

        let batches = ctx.sql(&sql).await?.collect().await?;
        let batch = match batches.first() {
            Some(batch) => batch,
            None => return Err(AnalyzerError::NoData),
        };

        let empty_state = StandardDeviationState {
            count: 0,
            sum: 0.0,
            sum_squared: 0.0,
            mean: 0.0,
        };
        if batch.num_rows() == 0 || batch.column(0).is_null(0) {
            return Ok(empty_state);
        }

        let count_array = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow::array::Int64Array>()
            .ok_or_else(|| AnalyzerError::invalid_data("Expected Int64 for count"))?;
        // COUNT never produces a negative value, so the sign-dropping cast is lossless.
        let count = count_array.value(0) as u64;
        if count == 0 {
            // With no rows the remaining aggregates are null; report an empty state.
            return Ok(empty_state);
        }

        // Reads row 0 of a Float64 result column, or fails with the given message.
        let float_at = |index: usize, message: &'static str| -> AnalyzerResult<f64> {
            batch
                .column(index)
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>()
                .map(|values| values.value(0))
                .ok_or_else(|| AnalyzerError::invalid_data(message))
        };
        let mean = float_at(1, "Expected Float64 for mean")?;
        let sum = float_at(2, "Expected Float64 for sum")?;
        let sum_squared = float_at(3, "Expected Float64 for sum_squared")?;

        Ok(StandardDeviationState {
            count,
            sum,
            sum_squared,
            mean,
        })
    }

    /// Turns a state into a map metric.
    ///
    /// `count` and `mean` are always present. `std_dev`, `variance`,
    /// `sample_std_dev`, `sample_variance` and `coefficient_of_variation` are
    /// included only when the corresponding state accessor returns a value.
    fn compute_metric_from_state(&self, state: &Self::State) -> AnalyzerResult<Self::Metric> {
        use std::collections::HashMap;

        let mut stats = HashMap::new();
        stats.insert("count".to_string(), MetricValue::Long(state.count as i64));
        stats.insert("mean".to_string(), MetricValue::Double(state.mean));

        if let Some(pop_std_dev) = state.population_std_dev() {
            stats.insert("std_dev".to_string(), MetricValue::Double(pop_std_dev));
        }
        if let Some(sample_std_dev) = state.sample_std_dev() {
            stats.insert(
                "sample_std_dev".to_string(),
                MetricValue::Double(sample_std_dev),
            );
        }
        if let Some(pop_variance) = state.population_variance() {
            stats.insert("variance".to_string(), MetricValue::Double(pop_variance));
        }
        if let Some(sample_variance) = state.sample_variance() {
            stats.insert(
                "sample_variance".to_string(),
                MetricValue::Double(sample_variance),
            );
        }
        if let Some(cv) = state.coefficient_of_variation() {
            stats.insert(
                "coefficient_of_variation".to_string(),
                MetricValue::Double(cv),
            );
        }

        Ok(MetricValue::Map(stats))
    }

    /// Identifier of this analyzer.
    fn name(&self) -> &str {
        "standard_deviation"
    }

    /// Short human-readable summary of what this analyzer computes.
    fn description(&self) -> &str {
        "Computes standard deviation and variance metrics"
    }

    /// The single column this analyzer reads.
    fn columns(&self) -> Vec<&str> {
        vec![self.column.as_str()]
    }
}