use crate::constraints::Assertion;
use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult};
use crate::prelude::*;
use crate::security::SqlSecurity;
use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, instrument};
/// How quantiles are computed when evaluating a `QuantileConstraint`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QuantileMethod {
    /// Always use the approximate (sketch-based) computation.
    Approximate,
    /// Always use exact computation. NOTE(review): the exact SQL path
    /// currently delegates to the approximate one (see `exact_quantile_sql`) —
    /// confirm before relying on exactness guarantees.
    Exact,
    /// Use exact computation when the table's row count is `<= threshold`,
    /// approximate otherwise (decided by `should_use_exact`).
    Auto { threshold: usize },
}
impl Default for QuantileMethod {
fn default() -> Self {
Self::Auto { threshold: 10000 }
}
}
/// A single quantile assertion: which quantile to compute and the
/// assertion its value must satisfy.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QuantileCheck {
    /// Target quantile; `QuantileCheck::new` rejects values outside [0.0, 1.0].
    /// (The field itself is public, so constructing the struct directly
    /// bypasses that validation.)
    pub quantile: f64,
    /// Assertion evaluated against the computed quantile value.
    pub assertion: Assertion,
}
impl QuantileCheck {
    /// Builds a check for `quantile`, rejecting anything outside `[0.0, 1.0]`.
    ///
    /// NaN also fails the range test and is therefore rejected.
    ///
    /// # Errors
    /// Returns `TermError::Configuration` when `quantile` is out of range.
    pub fn new(quantile: f64, assertion: Assertion) -> Result<Self> {
        match quantile {
            q if (0.0..=1.0).contains(&q) => Ok(Self {
                quantile: q,
                assertion,
            }),
            _ => Err(TermError::Configuration(
                "Quantile must be between 0.0 and 1.0".to_string(),
            )),
        }
    }
}
/// Configuration for distribution-style validation
/// (`QuantileValidation::Distribution`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DistributionConfig {
    /// Quantiles to compute (defaults to the quartiles 0.25 / 0.5 / 0.75).
    pub quantiles: Vec<f64>,
    /// Presumably whether min/max bounds are included — the Distribution
    /// path is not yet implemented in `evaluate`, so confirm once it lands.
    pub include_bounds: bool,
    /// Presumably whether the interquartile range is computed — same caveat
    /// as `include_bounds`.
    pub compute_iqr: bool,
}
impl Default for DistributionConfig {
fn default() -> Self {
Self {
quantiles: vec![0.25, 0.5, 0.75],
include_bounds: true,
compute_iqr: true,
}
}
}
/// The kind of quantile validation a `QuantileConstraint` performs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum QuantileValidation {
    /// Assert on a single quantile.
    Single(QuantileCheck),
    /// Assert on several quantiles, computed in one query.
    Multiple(Vec<QuantileCheck>),
    /// Distribution-level checks. NOTE(review): currently skipped by
    /// `evaluate` ("Validation type not yet implemented").
    Distribution {
        config: DistributionConfig,
        iqr_assertion: Option<Assertion>,
        quantile_assertions: HashMap<String, Assertion>,
    },
    /// Verify that the listed quantiles are non-decreasing, or strictly
    /// increasing when `strict` is true.
    Monotonic {
        quantiles: Vec<f64>,
        strict: bool,
    },
    /// User-supplied SQL expression plus an assertion. NOTE(review):
    /// currently skipped by `evaluate`, and `sql_expression` is not run
    /// through `SqlSecurity` anywhere in this file — review for injection
    /// risk before implementing.
    Custom {
        sql_expression: String,
        assertion: Assertion,
    },
}
/// Constraint that validates quantile properties of a single column.
#[derive(Debug, Clone)]
pub struct QuantileConstraint {
    /// Column name; validated as a SQL identifier in `with_method`.
    column: String,
    /// What to validate (single quantile, multiple, monotonicity, ...).
    validation: QuantileValidation,
    /// Exact vs. approximate computation strategy.
    method: QuantileMethod,
}
impl QuantileConstraint {
    /// Creates a constraint using the default quantile method
    /// (`QuantileMethod::Auto`).
    pub fn new(column: impl Into<String>, validation: QuantileValidation) -> Result<Self> {
        Self::with_method(column, validation, QuantileMethod::default())
    }

    /// Creates a constraint with an explicit computation method.
    ///
    /// # Errors
    /// Fails when `column` is not a valid SQL identifier.
    pub fn with_method(
        column: impl Into<String>,
        validation: QuantileValidation,
        method: QuantileMethod,
    ) -> Result<Self> {
        let column = column.into();
        SqlSecurity::validate_identifier(&column)?;
        Ok(Self {
            column,
            validation,
            method,
        })
    }

    /// Convenience constructor asserting on the median (the 0.5 quantile).
    pub fn median(column: impl Into<String>, assertion: Assertion) -> Result<Self> {
        Self::percentile(column, 0.5, assertion)
    }

    /// Convenience constructor asserting on an arbitrary quantile.
    ///
    /// # Errors
    /// Fails when `quantile` is outside `[0.0, 1.0]` or `column` is invalid.
    pub fn percentile(
        column: impl Into<String>,
        quantile: f64,
        assertion: Assertion,
    ) -> Result<Self> {
        let check = QuantileCheck::new(quantile, assertion)?;
        Self::new(column, QuantileValidation::Single(check))
    }

    /// Validates several quantiles at once.
    ///
    /// # Errors
    /// Fails when `checks` is empty or `column` is invalid.
    pub fn multiple(column: impl Into<String>, checks: Vec<QuantileCheck>) -> Result<Self> {
        if checks.is_empty() {
            return Err(TermError::Configuration(
                "At least one quantile check is required".to_string(),
            ));
        }
        Self::new(column, QuantileValidation::Multiple(checks))
    }

    /// Builds a distribution validation with no IQR assertion and no
    /// per-quantile assertions.
    pub fn distribution(column: impl Into<String>, config: DistributionConfig) -> Result<Self> {
        let validation = QuantileValidation::Distribution {
            config,
            iqr_assertion: None,
            quantile_assertions: HashMap::new(),
        };
        Self::new(column, validation)
    }
    /// Renders the DataFusion SQL expression computing an approximate
    /// quantile over the (escaped) target column.
    fn approx_quantile_sql(&self, quantile: f64) -> Result<String> {
        let escaped_column = SqlSecurity::escape_identifier(&self.column)?;
        Ok(format!(
            "APPROX_PERCENTILE_CONT({quantile}) WITHIN GROUP (ORDER BY {escaped_column})"
        ))
    }
    // NOTE(review): despite the name, this currently delegates to the
    // approximate implementation — presumably a placeholder until an exact
    // percentile path is available. Confirm before relying on
    // `QuantileMethod::Exact` for exactness.
    #[allow(dead_code)]
    fn exact_quantile_sql(&self, quantile: f64) -> Result<String> {
        self.approx_quantile_sql(quantile)
    }
async fn should_use_exact(&self, ctx: &SessionContext) -> Result<bool> {
match &self.method {
QuantileMethod::Exact => Ok(true),
QuantileMethod::Approximate => Ok(false),
QuantileMethod::Auto { threshold } => {
let validation_ctx = current_validation_context();
let table_name = validation_ctx.table_name();
let count_sql = format!("SELECT COUNT(*) as cnt FROM {table_name}");
let df = ctx.sql(&count_sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(true);
}
let count = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.ok_or_else(|| {
TermError::Internal("Failed to downcast to Int64Array".to_string())
})?
.value(0) as usize;
Ok(count <= *threshold)
}
}
}
}
#[async_trait]
impl Constraint for QuantileConstraint {
#[instrument(skip(self, ctx), fields(
column = %self.column,
validation = ?self.validation
))]
async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
let _use_exact = self.should_use_exact(ctx).await?;
match &self.validation {
QuantileValidation::Single(check) => {
let validation_ctx = current_validation_context();
let table_name = validation_ctx.table_name();
let sql = format!(
"SELECT {} as q_value FROM {table_name}",
self.approx_quantile_sql(check.quantile)?
);
debug!("Quantile SQL: {}", sql);
let df = ctx.sql(&sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult::skipped("No data to validate"));
}
let column = batches[0].column(0);
let value = if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Float64Array>()
{
arr.value(0)
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int64Array>()
{
arr.value(0) as f64
} else if let Some(arr) = column.as_any().downcast_ref::<arrow::array::Int32Array>()
{
arr.value(0) as f64
} else {
return Err(TermError::TypeMismatch {
expected: "Float64, Int64, or Int32".to_string(),
found: format!("{:?}", column.data_type()),
});
};
if check.assertion.evaluate(value) {
Ok(ConstraintResult::success_with_metric(value))
} else {
Ok(ConstraintResult::failure_with_metric(
value,
format!(
"Quantile {} is {value} which does not {}",
check.quantile, check.assertion
),
))
}
}
QuantileValidation::Multiple(checks) => {
let sql_parts: Vec<String> = checks
.iter()
.enumerate()
.map(|(i, check)| {
self.approx_quantile_sql(check.quantile)
.map(|q_sql| format!("{q_sql} as q_{i}"))
})
.collect::<Result<Vec<_>>>()?;
let parts = sql_parts.join(", ");
let validation_ctx = current_validation_context();
let table_name = validation_ctx.table_name();
let sql = format!("SELECT {parts} FROM {table_name}");
let df = ctx.sql(&sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult::skipped("No data to validate"));
}
let mut failures = Vec::new();
let batch = &batches[0];
for (i, check) in checks.iter().enumerate() {
let column = batch.column(i);
let value = if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Float64Array>()
{
arr.value(0)
} else if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Int64Array>()
{
arr.value(0) as f64
} else if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Int32Array>()
{
arr.value(0) as f64
} else {
return Err(TermError::TypeMismatch {
expected: "Float64, Int64, or Int32".to_string(),
found: format!("{:?}", column.data_type()),
});
};
if !check.assertion.evaluate(value) {
let q_pct = (check.quantile * 100.0) as i32;
failures.push(format!(
"Q{q_pct} is {value} which does not {}",
check.assertion
));
}
}
if failures.is_empty() {
Ok(ConstraintResult::success())
} else {
Ok(ConstraintResult::failure(failures.join("; ")))
}
}
QuantileValidation::Monotonic { quantiles, strict } => {
let sql_parts: Vec<String> = quantiles
.iter()
.enumerate()
.map(|(i, q)| {
self.approx_quantile_sql(*q)
.map(|q_sql| format!("{q_sql} as q_{i}"))
})
.collect::<Result<Vec<_>>>()?;
let parts = sql_parts.join(", ");
let validation_ctx = current_validation_context();
let table_name = validation_ctx.table_name();
let sql = format!("SELECT {parts} FROM {table_name}");
let df = ctx.sql(&sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult::skipped("No data to validate"));
}
let batch = &batches[0];
let mut values = Vec::new();
for i in 0..quantiles.len() {
let column = batch.column(i);
let value = if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Float64Array>()
{
arr.value(0)
} else if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Int64Array>()
{
arr.value(0) as f64
} else if let Some(arr) =
column.as_any().downcast_ref::<arrow::array::Int32Array>()
{
arr.value(0) as f64
} else {
return Err(TermError::TypeMismatch {
expected: "Float64, Int64, or Int32".to_string(),
found: format!("{:?}", column.data_type()),
});
};
values.push(value);
}
let mut is_monotonic = true;
for i in 1..values.len() {
if *strict {
if values[i] <= values[i - 1] {
is_monotonic = false;
break;
}
} else if values[i] < values[i - 1] {
is_monotonic = false;
break;
}
}
if is_monotonic {
Ok(ConstraintResult::success())
} else {
let monotonic_type = if *strict { "strictly" } else { "" };
Ok(ConstraintResult::failure(format!(
"Quantiles are not {monotonic_type} monotonic: {values:?}"
)))
}
}
_ => {
Ok(ConstraintResult::skipped(
"Validation type not yet implemented",
))
}
}
}
fn name(&self) -> &str {
"quantile"
}
fn metadata(&self) -> ConstraintMetadata {
ConstraintMetadata::for_column(&self.column).with_description(format!(
"Validates quantile properties for column '{}'",
self.column
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::Float64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers an in-memory table named "data" with one nullable Float64
    /// column named "value", and returns the session context.
    async fn create_test_context(values: Vec<Option<f64>>) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            true,
        )]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(Float64Array::from(values))])
                .unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// The integers 1..=100 as optional floats — a uniform test sample.
    fn sample_1_to_100() -> Vec<Option<f64>> {
        (1..=100).map(|i| Some(f64::from(i))).collect()
    }

    #[tokio::test]
    async fn test_median_check() {
        let ctx = create_test_context(sample_1_to_100()).await;
        let constraint =
            QuantileConstraint::median("value", Assertion::Between(45.0, 55.0)).unwrap();
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_percentile_check() {
        let ctx = create_test_context(sample_1_to_100()).await;
        let constraint =
            QuantileConstraint::percentile("value", 0.95, Assertion::Between(94.0, 96.0)).unwrap();
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_multiple_quantiles() {
        let ctx = create_test_context(sample_1_to_100()).await;
        let checks = vec![
            QuantileCheck::new(0.25, Assertion::Between(24.0, 26.0)).unwrap(),
            QuantileCheck::new(0.75, Assertion::Between(74.0, 76.0)).unwrap(),
        ];
        let constraint = QuantileConstraint::multiple("value", checks).unwrap();
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_monotonic_check() {
        let ctx = create_test_context(sample_1_to_100()).await;
        let validation = QuantileValidation::Monotonic {
            quantiles: vec![0.1, 0.5, 0.9],
            strict: true,
        };
        let constraint = QuantileConstraint::new("value", validation).unwrap();
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[test]
    fn test_invalid_quantile() {
        // Out-of-range quantile must be rejected with the documented message.
        let err = QuantileCheck::new(1.5, Assertion::LessThan(100.0)).unwrap_err();
        assert!(err
            .to_string()
            .contains("Quantile must be between 0.0 and 1.0"));
    }
}