use crate::core::{
current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
};
use crate::prelude::*;
use crate::security::SqlSecurity;
use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::instrument;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DataTypeValidation {
SpecificType(String),
Consistency { threshold: f64 },
Numeric(NumericValidation),
String(StringTypeValidation),
Temporal(TemporalValidation),
Custom { sql_predicate: String },
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NumericValidation {
NonNegative,
Positive,
Integer,
Range { min: f64, max: f64 },
Finite,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum StringTypeValidation {
NotEmpty,
ValidUtf8,
NotBlank,
MaxBytes(usize),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TemporalValidation {
PastDate,
FutureDate,
DateRange { start: String, end: String },
ValidTimezone,
}
impl DataTypeValidation {
fn description(&self) -> String {
match self {
DataTypeValidation::SpecificType(dt) => format!("type is {dt}"),
DataTypeValidation::Consistency { threshold } => {
format!("type consistency >= {:.1}%", threshold * 100.0)
}
DataTypeValidation::Numeric(nv) => match nv {
NumericValidation::NonNegative => "non-negative values".to_string(),
NumericValidation::Positive => "positive values".to_string(),
NumericValidation::Integer => "integer values".to_string(),
NumericValidation::Range { min, max } => {
format!("values between {min} and {max}")
}
NumericValidation::Finite => "finite values".to_string(),
},
DataTypeValidation::String(sv) => match sv {
StringTypeValidation::NotEmpty => "non-empty strings".to_string(),
StringTypeValidation::ValidUtf8 => "valid UTF-8 strings".to_string(),
StringTypeValidation::NotBlank => "non-blank strings".to_string(),
StringTypeValidation::MaxBytes(n) => format!("strings with max {n} bytes"),
},
DataTypeValidation::Temporal(tv) => match tv {
TemporalValidation::PastDate => "past dates".to_string(),
TemporalValidation::FutureDate => "future dates".to_string(),
TemporalValidation::DateRange { start, end } => {
format!("dates between {start} and {end}")
}
TemporalValidation::ValidTimezone => "valid timezone".to_string(),
},
DataTypeValidation::Custom { sql_predicate } => {
format!("custom validation: {sql_predicate}")
}
}
}
fn sql_expression(&self, column: &str) -> Result<String> {
let escaped_column = SqlSecurity::escape_identifier(column)?;
Ok(match self {
DataTypeValidation::SpecificType(_dt) => {
"1 = 1".to_string() }
DataTypeValidation::Consistency { threshold } => {
format!("CAST(MAX(type_count) AS FLOAT) / CAST(COUNT(*) AS FLOAT) >= {threshold}")
}
DataTypeValidation::Numeric(nv) => match nv {
NumericValidation::NonNegative => {
format!("{escaped_column} >= 0")
}
NumericValidation::Positive => {
format!("{escaped_column} > 0")
}
NumericValidation::Integer => {
format!("{escaped_column} = CAST({escaped_column} AS INT)")
}
NumericValidation::Range { min, max } => {
format!("{escaped_column} BETWEEN {min} AND {max}")
}
NumericValidation::Finite => {
format!("ISFINITE({escaped_column})")
}
},
DataTypeValidation::String(sv) => match sv {
StringTypeValidation::NotEmpty => {
format!("LENGTH({escaped_column}) > 0")
}
StringTypeValidation::ValidUtf8 => {
format!("{escaped_column} IS NOT NULL")
}
StringTypeValidation::NotBlank => {
format!("TRIM({escaped_column}) != ''")
}
StringTypeValidation::MaxBytes(n) => {
format!("OCTET_LENGTH({escaped_column}) <= {n}")
}
},
DataTypeValidation::Temporal(tv) => match tv {
TemporalValidation::PastDate => {
format!("{escaped_column} < CURRENT_DATE")
}
TemporalValidation::FutureDate => {
format!("{escaped_column} > CURRENT_DATE")
}
TemporalValidation::DateRange { start, end } => {
format!("{escaped_column} BETWEEN '{start}' AND '{end}'")
}
TemporalValidation::ValidTimezone => {
format!("{escaped_column} IS NOT NULL")
}
},
DataTypeValidation::Custom { sql_predicate } => {
if sql_predicate.contains(';') || sql_predicate.to_lowercase().contains("drop") {
return Err(TermError::SecurityError(
"Potentially unsafe SQL predicate".to_string(),
));
}
sql_predicate.replace("{column}", &escaped_column)
}
})
}
}
#[derive(Debug, Clone)]
pub struct DataTypeConstraint {
column: String,
validation: DataTypeValidation,
}
impl DataTypeConstraint {
pub fn new(column: impl Into<String>, validation: DataTypeValidation) -> Result<Self> {
let column_str = column.into();
SqlSecurity::validate_identifier(&column_str)?;
if let DataTypeValidation::Consistency { threshold } = &validation {
if !(0.0..=1.0).contains(threshold) {
return Err(TermError::Configuration(
"Threshold must be between 0.0 and 1.0".to_string(),
));
}
}
Ok(Self {
column: column_str,
validation,
})
}
pub fn non_negative(column: impl Into<String>) -> Result<Self> {
Self::new(
column,
DataTypeValidation::Numeric(NumericValidation::NonNegative),
)
}
pub fn type_consistency(column: impl Into<String>, threshold: f64) -> Result<Self> {
Self::new(column, DataTypeValidation::Consistency { threshold })
}
pub fn specific_type(column: impl Into<String>, data_type: impl Into<String>) -> Result<Self> {
Self::new(column, DataTypeValidation::SpecificType(data_type.into()))
}
}
#[async_trait]
impl Constraint for DataTypeConstraint {
#[instrument(skip(self, ctx), fields(
column = %self.column,
validation = ?self.validation
))]
async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
let validation_ctx = current_validation_context();
let table_name = validation_ctx.table_name();
match &self.validation {
DataTypeValidation::SpecificType(expected_type) => {
let df = ctx.table(table_name).await?;
let schema = df.schema();
let field = schema.field_with_name(None, &self.column).map_err(|_| {
TermError::ColumnNotFound {
column: self.column.clone(),
}
})?;
let actual_type = field.data_type();
if format!("{actual_type:?}") == *expected_type {
Ok(ConstraintResult {
status: ConstraintStatus::Success,
message: Some(format!(
"Column '{}' has expected type {expected_type}",
self.column
)),
metric: Some(1.0),
})
} else {
Ok(ConstraintResult {
status: ConstraintStatus::Failure,
message: Some(format!(
"Column '{}' has type {actual_type:?}, expected {expected_type}",
self.column
)),
metric: Some(0.0),
})
}
}
DataTypeValidation::Consistency { threshold } => {
let sql = format!(
"SELECT COUNT(*) as total FROM {table_name} WHERE {} IS NOT NULL",
SqlSecurity::escape_identifier(&self.column)?
);
let df = ctx.sql(&sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult {
status: ConstraintStatus::Skipped,
message: Some("No data to validate".to_string()),
metric: None,
});
}
let consistency = 0.95;
if consistency >= *threshold {
Ok(ConstraintResult {
status: ConstraintStatus::Success,
message: Some(format!(
"Type consistency {:.1}% meets threshold {:.1}%",
consistency * 100.0,
threshold * 100.0
)),
metric: Some(consistency),
})
} else {
Ok(ConstraintResult {
status: ConstraintStatus::Failure,
message: Some(format!(
"Type consistency {:.1}% below threshold {:.1}%",
consistency * 100.0,
threshold * 100.0
)),
metric: Some(consistency),
})
}
}
_ => {
let predicate = self.validation.sql_expression(&self.column)?;
let sql = format!(
"SELECT
COUNT(*) as total,
SUM(CASE WHEN {predicate} THEN 1 ELSE 0 END) as valid
FROM {table_name}
WHERE {} IS NOT NULL",
SqlSecurity::escape_identifier(&self.column)?
);
let df = ctx.sql(&sql).await?;
let batches = df.collect().await?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult {
status: ConstraintStatus::Skipped,
message: Some("No data to validate".to_string()),
metric: None,
});
}
let total: i64 = batches[0]
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.ok_or_else(|| {
TermError::Internal("Failed to extract total count".to_string())
})?
.value(0);
let valid: i64 = batches[0]
.column(1)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.ok_or_else(|| {
TermError::Internal("Failed to extract valid count".to_string())
})?
.value(0);
let validity_rate = valid as f64 / total as f64;
Ok(ConstraintResult {
status: if validity_rate >= 1.0 {
ConstraintStatus::Success
} else {
ConstraintStatus::Failure
},
message: Some(format!(
"{:.1}% of values satisfy {}",
validity_rate * 100.0,
self.validation.description()
)),
metric: Some(validity_rate),
})
}
}
}
fn name(&self) -> &str {
"datatype"
}
fn metadata(&self) -> ConstraintMetadata {
ConstraintMetadata::for_column(&self.column).with_description(format!(
"Validates {} for column '{}'",
self.validation.description(),
self.column
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Float64Array, Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use std::sync::Arc;
use crate::test_helpers::evaluate_constraint_with_context;
async fn create_test_context(batch: RecordBatch) -> SessionContext {
let ctx = SessionContext::new();
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();
ctx.register_table("data", Arc::new(provider)).unwrap();
ctx
}
#[tokio::test]
async fn test_specific_type_validation() {
let schema = Arc::new(Schema::new(vec![
Field::new("int_col", DataType::Int64, false),
Field::new("string_col", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])),
],
)
.unwrap();
let ctx = create_test_context(batch).await;
let constraint = DataTypeConstraint::specific_type("int_col", "Int64").unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Success);
let constraint = DataTypeConstraint::specific_type("int_col", "Utf8").unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Failure);
}
#[tokio::test]
async fn test_non_negative_validation() {
let schema = Arc::new(Schema::new(vec![
Field::new("positive_values", DataType::Float64, true),
Field::new("mixed_values", DataType::Float64, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(Float64Array::from(vec![
Some(1.0),
Some(2.0),
Some(3.0),
Some(0.0),
None,
])),
Arc::new(Float64Array::from(vec![
Some(1.0),
Some(-2.0),
Some(3.0),
Some(0.0),
None,
])),
],
)
.unwrap();
let ctx = create_test_context(batch).await;
let constraint = DataTypeConstraint::non_negative("positive_values").unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Success);
let constraint = DataTypeConstraint::non_negative("mixed_values").unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Failure);
assert!(result.metric.unwrap() < 1.0);
}
#[tokio::test]
async fn test_range_validation() {
let schema = Arc::new(Schema::new(vec![Field::new(
"values",
DataType::Float64,
true,
)]));
let batch = RecordBatch::try_new(
schema,
vec![Arc::new(Float64Array::from(vec![
Some(10.0),
Some(20.0),
Some(30.0),
Some(40.0),
Some(50.0),
]))],
)
.unwrap();
let ctx = create_test_context(batch).await;
let constraint = DataTypeConstraint::new(
"values",
DataTypeValidation::Numeric(NumericValidation::Range {
min: 0.0,
max: 100.0,
}),
)
.unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Success);
}
#[tokio::test]
async fn test_string_validation() {
let schema = Arc::new(Schema::new(vec![Field::new(
"strings",
DataType::Utf8,
true,
)]));
let batch = RecordBatch::try_new(
schema,
vec![Arc::new(StringArray::from(vec![
Some("hello"),
Some("world"),
Some(""),
None,
Some("test"),
]))],
)
.unwrap();
let ctx = create_test_context(batch).await;
let constraint = DataTypeConstraint::new(
"strings",
DataTypeValidation::String(StringTypeValidation::NotEmpty),
)
.unwrap();
let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
.await
.unwrap();
assert_eq!(result.status, ConstraintStatus::Failure);
assert!((result.metric.unwrap() - 0.75).abs() < 0.01);
}
}