use crate::constraints::Assertion;
use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
use crate::prelude::*;
use crate::security::SqlSecurity;
use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::fmt;
use tracing::instrument;
/// A column-level statistic that a [`StatisticalConstraint`] or
/// [`MultiStatisticalConstraint`] computes and checks.
///
/// Each variant is translated into a DataFusion SQL aggregate expression by
/// `StatisticType::sql_expression`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StatisticType {
    /// Smallest value (SQL `MIN`).
    Min,
    /// Largest value (SQL `MAX`).
    Max,
    /// Arithmetic mean (SQL `AVG`).
    Mean,
    /// Sum of values (SQL `SUM`).
    Sum,
    /// Standard deviation (SQL `STDDEV`).
    StandardDeviation,
    /// Variance (SQL `VARIANCE`).
    Variance,
    /// Median, computed as `APPROX_PERCENTILE_CONT(column, 0.5)`.
    Median,
    /// Percentile at the given fraction, computed with `APPROX_PERCENTILE_CONT`.
    /// The constraint constructors reject fractions outside `[0.0, 1.0]`.
    Percentile(f64),
}
impl StatisticType {
    /// Returns the SQL aggregate function name used to compute this statistic.
    ///
    /// `Median` and `Percentile` share `APPROX_PERCENTILE_CONT`; they differ only in
    /// the percentile argument supplied by [`Self::sql_expression`].
    fn sql_function(&self) -> String {
        let func = match self {
            StatisticType::Min => "MIN",
            StatisticType::Max => "MAX",
            StatisticType::Mean => "AVG",
            StatisticType::Sum => "SUM",
            StatisticType::StandardDeviation => "STDDEV",
            StatisticType::Variance => "VARIANCE",
            StatisticType::Median | StatisticType::Percentile(_) => "APPROX_PERCENTILE_CONT",
        };
        func.to_string()
    }

    /// Renders `value` as a SQL floating-point literal.
    ///
    /// `f64`'s `Display` drops the fractional part of whole numbers (`1.0` prints as
    /// `1`), which would hand the percentile aggregate an integer literal. Appending
    /// `.0` keeps the literal unambiguously typed as a float. Non-finite values are
    /// left untouched; the constructors never let them reach this point.
    fn float_literal(value: f64) -> String {
        let text = value.to_string();
        if value.is_finite() && !text.contains('.') {
            format!("{text}.0")
        } else {
            text
        }
    }

    /// Builds the SQL aggregate expression for this statistic over `column`.
    ///
    /// `column` is inserted verbatim, so callers must pass an already-escaped
    /// identifier.
    fn sql_expression(&self, column: &str) -> String {
        let func = self.sql_function();
        match self {
            StatisticType::Median => format!("{func}({column}, 0.5)"),
            StatisticType::Percentile(p) => {
                let fraction = Self::float_literal(*p);
                format!("{func}({column}, {fraction})")
            }
            _ => format!("{func}({column})"),
        }
    }

    /// Human-readable name used in result messages and descriptions.
    ///
    /// A `Percentile` whose fraction equals 0.5 (within `f64::EPSILON`) is reported
    /// as `"median"`.
    fn name(&self) -> &str {
        match self {
            StatisticType::Min => "minimum",
            StatisticType::Max => "maximum",
            StatisticType::Mean => "mean",
            StatisticType::Sum => "sum",
            StatisticType::StandardDeviation => "standard deviation",
            StatisticType::Variance => "variance",
            StatisticType::Median => "median",
            StatisticType::Percentile(p) => {
                if (*p - 0.5).abs() < f64::EPSILON {
                    "median"
                } else {
                    "percentile"
                }
            }
        }
    }

    /// Machine-friendly identifier returned as the constraint's `name()`.
    fn constraint_name(&self) -> &str {
        match self {
            StatisticType::Min => "min",
            StatisticType::Max => "max",
            StatisticType::Mean => "mean",
            StatisticType::Sum => "sum",
            StatisticType::StandardDeviation => "standard_deviation",
            StatisticType::Variance => "variance",
            StatisticType::Median => "median",
            StatisticType::Percentile(_) => "percentile",
        }
    }
}
/// Formats the statistic by its human-readable name; percentiles also show their
/// fraction, e.g. `percentile(0.95)`.
impl fmt::Display for StatisticType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.name();
        if let StatisticType::Percentile(p) = self {
            write!(f, "{label}({p})")
        } else {
            f.write_str(label)
        }
    }
}
/// Checks a single statistic of one column against an [`Assertion`].
///
/// Construct it with [`StatisticalConstraint::new`] or one of the per-statistic
/// shorthands (`min`, `max`, `mean`, ...). These validate the column identifier and
/// any percentile fraction.
#[derive(Debug, Clone)]
pub struct StatisticalConstraint {
    // Column name; checked by `SqlSecurity::validate_identifier` in `new`.
    column: String,
    // Statistic computed over `column`.
    statistic: StatisticType,
    // Condition the computed statistic must satisfy.
    assertion: Assertion,
}
impl StatisticalConstraint {
    /// Creates a constraint that checks `statistic` of `column` against `assertion`.
    ///
    /// # Errors
    /// * Any error from `SqlSecurity::validate_identifier` when `column` is not an
    ///   acceptable identifier (checked first).
    /// * `TermError::SecurityError` when a `Percentile` fraction is not within
    ///   `[0.0, 1.0]` (NaN is rejected too, as the range does not contain it).
    pub fn new(
        column: impl Into<String>,
        statistic: StatisticType,
        assertion: Assertion,
    ) -> Result<Self> {
        let column: String = column.into();
        SqlSecurity::validate_identifier(&column)?;

        let percentile_out_of_range = matches!(
            &statistic,
            StatisticType::Percentile(p) if !(0.0..=1.0).contains(p)
        );
        if percentile_out_of_range {
            return Err(TermError::SecurityError(
                "Percentile must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self {
            column,
            statistic,
            assertion,
        })
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Min`].
    pub fn min(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Min, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Max`].
    pub fn max(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Max, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Mean`].
    pub fn mean(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Mean, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Sum`].
    pub fn sum(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Sum, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::StandardDeviation`].
    pub fn standard_deviation(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::StandardDeviation, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Variance`].
    pub fn variance(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Variance, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Median`].
    pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::new(column, StatisticType::Median, assertion)
    }

    /// Shorthand for [`Self::new`] with [`StatisticType::Percentile`].
    ///
    /// `percentile` is a fraction in `[0.0, 1.0]` (e.g. `0.95`); other values are
    /// rejected by [`Self::new`].
    pub fn percentile(
        column: impl Into<String>,
        percentile: f64,
        assertion: Assertion,
    ) -> Result<Self> {
        let statistic = StatisticType::Percentile(percentile);
        Self::new(column, statistic, assertion)
    }
}
#[async_trait]
impl Constraint for StatisticalConstraint {
    /// Computes the configured statistic of `column` in the current validation
    /// table and checks it against the assertion.
    ///
    /// The aggregate's result column is cast to `Float64` before comparison. Any
    /// numeric Arrow type is therefore accepted (signed/unsigned integers, floats,
    /// decimals), not just `Int64`/`Float64`. Aggregates such as `MIN`/`MAX` are not
    /// guaranteed to return one of those two types.
    ///
    /// # Returns
    /// * `Skipped` when the query returns no rows.
    /// * `Failure` when the statistic is null, or when the assertion does not hold
    ///   (the latter carries the computed metric).
    /// * `Success` with the computed metric otherwise.
    ///
    /// # Errors
    /// * Identifier-escaping and SQL planning/execution errors are propagated.
    /// * `TermError::Internal` when the statistic's result type is not numeric.
    #[instrument(skip(self, ctx), fields(
        column = %self.column,
        statistic = %self.statistic,
        assertion = %self.assertion
    ))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;
        let stat_expr = self.statistic.sql_expression(&column_identifier);
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();
        let sql = format!("SELECT {stat_expr} as stat_value FROM {table_name}");
        let df = ctx.sql(&sql).await?;
        let batches = df.collect().await?;

        let batch = match batches.first() {
            Some(batch) if batch.num_rows() > 0 => batch,
            _ => return Ok(ConstraintResult::skipped("No data to validate")),
        };

        let raw = batch.column(0);
        // Only numeric results can be compared against an f64 assertion. Casting a
        // non-numeric type (e.g. Utf8) would silently produce nulls, so reject it.
        if !raw.data_type().is_numeric() {
            return Err(TermError::Internal(
                "Failed to extract statistic value".to_string(),
            ));
        }
        // Normalise every numeric type to Float64 so a single code path handles them.
        let as_f64 = arrow::compute::cast(raw, &arrow::datatypes::DataType::Float64)
            .map_err(|e| TermError::Internal(format!("Failed to extract statistic value: {e}")))?;
        let array = as_f64
            .as_any()
            .downcast_ref::<arrow::array::Float64Array>()
            .ok_or_else(|| TermError::Internal("Failed to extract statistic value".to_string()))?;

        if array.is_null(0) {
            let stat_name = self.statistic.name();
            return Ok(ConstraintResult::failure(format!(
                "{stat_name} is null (no non-null values)"
            )));
        }
        let value = array.value(0);

        if self.assertion.evaluate(value) {
            Ok(ConstraintResult::success_with_metric(value))
        } else {
            Ok(ConstraintResult::failure_with_metric(
                value,
                format!(
                    "{} {value} does not {}",
                    self.statistic.name(),
                    self.assertion
                ),
            ))
        }
    }

    /// Returns the statistic's machine-friendly name (e.g. `"min"`, `"percentile"`).
    fn name(&self) -> &str {
        self.statistic.constraint_name()
    }

    /// Returns the single column this constraint inspects.
    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }

    /// Describes the constraint. Custom keys: `assertion`, `statistic_type` and
    /// `constraint_type`, plus `percentile` for percentile statistics.
    fn metadata(&self) -> ConstraintMetadata {
        let mut metadata = ConstraintMetadata::for_column(&self.column)
            .with_description(format!(
                "Checks that {} of {} {}",
                self.statistic.name(),
                self.column,
                self.assertion.description()
            ))
            .with_custom("assertion", self.assertion.to_string())
            .with_custom("statistic_type", self.statistic.to_string())
            .with_custom("constraint_type", "statistical");
        if let StatisticType::Percentile(p) = self.statistic {
            metadata = metadata.with_custom("percentile", p.to_string());
        }
        metadata
    }
}
/// Checks several statistics of one column in a single query, each against its own
/// [`Assertion`].
///
/// Construct it with [`MultiStatisticalConstraint::new`], which validates the column
/// identifier and every percentile fraction.
#[derive(Debug, Clone)]
pub struct MultiStatisticalConstraint {
    // Column name; checked by `SqlSecurity::validate_identifier` in `new`.
    column: String,
    // Statistics to compute, each paired with the assertion its value must satisfy.
    statistics: Vec<(StatisticType, Assertion)>,
}
impl MultiStatisticalConstraint {
    /// Creates a constraint that checks every `(statistic, assertion)` pair on
    /// `column`.
    ///
    /// # Errors
    /// * Any error from `SqlSecurity::validate_identifier` when `column` is not an
    ///   acceptable identifier (checked first).
    /// * `TermError::SecurityError` when any `Percentile` fraction is not within
    ///   `[0.0, 1.0]`.
    pub fn new(
        column: impl Into<String>,
        statistics: Vec<(StatisticType, Assertion)>,
    ) -> Result<Self> {
        let column: String = column.into();
        SqlSecurity::validate_identifier(&column)?;

        let has_invalid_percentile = statistics.iter().any(|(stat, _)| {
            matches!(stat, StatisticType::Percentile(p) if !(0.0..=1.0).contains(p))
        });
        if has_invalid_percentile {
            return Err(TermError::SecurityError(
                "Percentile must be between 0.0 and 1.0".to_string(),
            ));
        }

        Ok(Self { column, statistics })
    }
}
#[async_trait]
impl Constraint for MultiStatisticalConstraint {
    /// Computes every configured statistic of `column` in one query and checks each
    /// against its own assertion.
    ///
    /// Each aggregate result is cast to `Float64`, so any numeric Arrow type is
    /// accepted, not only `Int64`/`Float64`.
    ///
    /// # Returns
    /// * `Skipped` when no statistics are configured or the query returns no rows.
    /// * `Failure` listing every statistic that is null, non-numeric, or violates its
    ///   assertion (messages joined with `"; "`).
    /// * `Success` carrying the first statistic's value as the metric otherwise.
    ///
    /// # Errors
    /// Propagates identifier-escaping and SQL planning/execution errors.
    #[instrument(skip(self, ctx), fields(
        column = %self.column,
        num_statistics = %self.statistics.len()
    ))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        let column_identifier = SqlSecurity::escape_identifier(&self.column)?;

        // With nothing to compute the query would be `SELECT  FROM <table>`, which
        // is not valid SQL.
        if self.statistics.is_empty() {
            return Ok(ConstraintResult::skipped("No statistics to validate"));
        }

        let parts = self
            .statistics
            .iter()
            .enumerate()
            .map(|(i, (stat, _))| {
                let expr = stat.sql_expression(&column_identifier);
                format!("{expr} as stat_{i}")
            })
            .collect::<Vec<_>>()
            .join(", ");
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();
        let sql = format!("SELECT {parts} FROM {table_name}");
        let df = ctx.sql(&sql).await?;
        let batches = df.collect().await?;

        let batch = match batches.first() {
            Some(batch) if batch.num_rows() > 0 => batch,
            _ => return Ok(ConstraintResult::skipped("No data to validate")),
        };

        let mut failures = Vec::new();
        let mut first_metric: Option<f64> = None;

        // Result column `i` holds the aggregate for `self.statistics[i]`.
        for (i, (stat_type, assertion)) in self.statistics.iter().enumerate() {
            let name = stat_type.name();
            let raw = batch.column(i);

            // Normalise numeric results to Float64. Non-numeric results (which a
            // cast would silently turn into nulls) and failed casts are reported as
            // uncomputable.
            let converted = if raw.data_type().is_numeric() {
                arrow::compute::cast(raw, &arrow::datatypes::DataType::Float64).ok()
            } else {
                None
            };
            let array = match converted
                .as_ref()
                .and_then(|a| a.as_any().downcast_ref::<arrow::array::Float64Array>())
            {
                Some(array) => array,
                None => {
                    failures.push(format!("Failed to compute {name}"));
                    continue;
                }
            };

            if array.is_null(0) {
                failures.push(format!("{name} is null"));
                continue;
            }
            let value = array.value(0);
            first_metric = first_metric.or(Some(value));

            if !assertion.evaluate(value) {
                failures.push(format!("{name} is {value} which does not {assertion}"));
            }
        }

        if failures.is_empty() {
            // Every statistic produced a value here, so `first_metric` is set; the
            // fallback only guards the invariant.
            Ok(ConstraintResult::success_with_metric(
                first_metric.unwrap_or(0.0),
            ))
        } else {
            Ok(ConstraintResult::failure(failures.join("; ")))
        }
    }

    /// Always `"multi_statistical"`.
    fn name(&self) -> &str {
        "multi_statistical"
    }

    /// Returns the single column this constraint inspects.
    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }

    /// Describes the constraint, listing the statistic names. Custom keys:
    /// `statistics_count` and `constraint_type`.
    fn metadata(&self) -> ConstraintMetadata {
        let stat_names: Vec<String> = self
            .statistics
            .iter()
            .map(|(s, _)| s.name().to_string())
            .collect();
        ConstraintMetadata::for_column(&self.column)
            .with_description({
                let stats = stat_names.join(", ");
                format!(
                    "Checks multiple statistics ({stats}) for column {}",
                    self.column
                )
            })
            .with_custom("statistics_count", self.statistics.len().to_string())
            .with_custom("constraint_type", "multi_statistical")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session holding a single nullable Float64 column `value`,
    /// registered as the table `data`.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            true,
        )]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Float64Array::from(values))],
        )
        .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let session = SessionContext::new();
        session.register_table("data", Arc::new(table)).unwrap();
        session
    }

    #[tokio::test]
    async fn test_mean_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(20.0)).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(20.0));
    }

    #[tokio::test]
    async fn test_min_max_constraints() {
        let session = create_test_context(vec![Some(5.0), Some(10.0), Some(15.0)]).await;

        let min_check = StatisticalConstraint::min("value", Assertion::Equals(5.0)).unwrap();
        let min_outcome = evaluate_constraint_with_context(&min_check, &session, "data")
            .await
            .unwrap();
        assert_eq!(min_outcome.status, ConstraintStatus::Success);
        assert_eq!(min_outcome.metric, Some(5.0));

        let max_check = StatisticalConstraint::max("value", Assertion::Equals(15.0)).unwrap();
        let max_outcome = evaluate_constraint_with_context(&max_check, &session, "data")
            .await
            .unwrap();
        assert_eq!(max_outcome.status, ConstraintStatus::Success);
        assert_eq!(max_outcome.metric, Some(15.0));
    }

    #[tokio::test]
    async fn test_sum_constraint() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let check = StatisticalConstraint::sum("value", Assertion::Equals(60.0)).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(60.0));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Nulls are ignored by the aggregate: mean of {10, 20} is 15.
        let session = create_test_context(vec![Some(10.0), None, Some(20.0)]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(15.0)).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(15.0));
    }

    #[tokio::test]
    async fn test_all_nulls() {
        let session = create_test_context(vec![None, None, None]).await;
        let check = StatisticalConstraint::mean("value", Assertion::Equals(0.0)).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("null"));
    }

    #[test]
    fn test_statistic_type_display() {
        assert_eq!(StatisticType::Min.to_string(), "minimum");
        assert_eq!(StatisticType::Mean.to_string(), "mean");
        assert_eq!(
            StatisticType::Percentile(0.95).to_string(),
            "percentile(0.95)"
        );
        assert_eq!(StatisticType::Median.to_string(), "median");
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint() {
        let session =
            create_test_context(vec![Some(10.0), Some(20.0), Some(30.0), Some(40.0)]).await;
        let checks = vec![
            (StatisticType::Min, Assertion::GreaterThanOrEqual(10.0)),
            (StatisticType::Max, Assertion::LessThanOrEqual(40.0)),
            (StatisticType::Mean, Assertion::Equals(25.0)),
            (StatisticType::Sum, Assertion::Equals(100.0)),
        ];
        let check = MultiStatisticalConstraint::new("value", checks).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multi_statistical_constraint_failure() {
        let session = create_test_context(vec![Some(10.0), Some(20.0), Some(30.0)]).await;
        let checks = vec![
            (StatisticType::Min, Assertion::Equals(5.0)),
            (StatisticType::Max, Assertion::Equals(30.0)),
        ];
        let check = MultiStatisticalConstraint::new("value", checks).unwrap();
        let outcome = evaluate_constraint_with_context(&check, &session, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.unwrap().contains("minimum is 10"));
    }

    #[test]
    fn test_invalid_percentile() {
        let built = StatisticalConstraint::new(
            "value",
            StatisticType::Percentile(1.5),
            Assertion::LessThan(100.0),
        );
        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("Percentile must be between 0.0 and 1.0"));
    }
}