use crate::core::{current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
use crate::prelude::*;
use arrow::array::Array;
use async_trait::async_trait;
use datafusion::prelude::*;
use tracing::instrument;
/// Constraint that checks whether one column's values mostly share a single
/// detected data type.
///
/// `threshold` is the minimum fraction of non-null values that must belong to
/// the most common detected type for the constraint to pass.
#[derive(Debug, Clone)]
pub struct DataTypeConsistencyConstraint {
    /// Name of the column whose values are classified.
    column: String,
    /// Required share of the dominant type among non-null values, in `[0.0, 1.0]`.
    threshold: f64,
}

impl DataTypeConsistencyConstraint {
    /// Creates a constraint on `column` that passes when at least `threshold`
    /// of its non-null values share one detected type.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is outside the inclusive range `[0.0, 1.0]`
    /// (a NaN threshold is also rejected, since it is not contained in the range).
    pub fn new(column: impl Into<String>, threshold: f64) -> Self {
        let threshold_is_valid = (0.0..=1.0).contains(&threshold);
        if !threshold_is_valid {
            panic!("Threshold must be between 0.0 and 1.0");
        }
        let column = column.into();
        Self { column, threshold }
    }
}
#[async_trait]
impl Constraint for DataTypeConsistencyConstraint {
    /// Classifies every value of the column into a coarse type bucket and
    /// checks that the most common non-null bucket covers at least
    /// `threshold` of the non-null values.
    ///
    /// Buckets are assigned by the first matching rule, in this order:
    /// `null`, `integer`, `float`, `boolean`, `date`, `timestamp`, otherwise
    /// `string`. Because `integer` is tested before `boolean`, the values
    /// `"0"` and `"1"` are counted as integers.
    ///
    /// The metric is `dominant_count / non_null_count`. The percentages in the
    /// reported type distribution are relative to *all* rows, nulls included.
    /// When several types tie for the highest count, the alphabetically first
    /// one is reported as dominant.
    ///
    /// Returns a skipped result when the query yields no rows, or when every
    /// value is null.
    ///
    /// # Errors
    ///
    /// Propagates errors from planning or executing the SQL query. Returns
    /// `TermError::Internal` if the result columns are not a string array and
    /// an `Int64` array respectively.
    #[instrument(skip(self, ctx), fields(column = %self.column, threshold = %self.threshold))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();
        // NOTE(review): the column name is interpolated unquoted, so it is
        // subject to SQL identifier normalization and must come from trusted
        // configuration. The `~` regex operator presumably needs a string
        // column; non-string columns likely fail at planning time — confirm.
        let column = &self.column;

        // The secondary sort key makes the row order (and therefore the
        // dominant type and the report) deterministic when counts tie.
        let type_dist_sql = format!(
            "WITH type_analysis AS (
                SELECT
                    CASE
                        WHEN {column} IS NULL THEN 'null'
                        WHEN {column} ~ '^-?\\d+$' THEN 'integer'
                        WHEN {column} ~ '^-?\\d*\\.?\\d+([eE][+-]?\\d+)?$' THEN 'float'
                        WHEN {column} ~ '^(true|false|TRUE|FALSE|True|False|0|1)$' THEN 'boolean'
                        WHEN {column} ~ '^\\d{{4}}-\\d{{2}}-\\d{{2}}$' THEN 'date'
                        WHEN {column} ~ '^\\d{{4}}-\\d{{2}}-\\d{{2}}[ T]\\d{{2}}:\\d{{2}}:\\d{{2}}' THEN 'timestamp'
                        ELSE 'string'
                    END as detected_type
                FROM {table_name}
            )
            SELECT
                detected_type,
                COUNT(*) as type_count
            FROM type_analysis
            GROUP BY detected_type
            ORDER BY type_count DESC, detected_type ASC"
        );

        let batches = ctx.sql(&type_dist_sql).await?.collect().await?;

        // The result may be split across several batches (or start with an
        // empty one), so every batch is inspected rather than only the first.
        if batches.iter().all(|batch| batch.num_rows() == 0) {
            return Ok(ConstraintResult::skipped("No data to validate"));
        }

        // (detected_type, count) pairs in query order: count DESC, name ASC.
        let mut distribution: Vec<(&str, i64)> = Vec::new();
        for batch in &batches {
            let types = batch
                .column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .ok_or_else(|| TermError::Internal("Failed to extract type names".to_string()))?;
            let counts = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| TermError::Internal("Failed to extract type counts".to_string()))?;
            for i in 0..batch.num_rows() {
                if !types.is_null(i) {
                    distribution.push((types.value(i), counts.value(i)));
                }
            }
        }

        let total_count: i64 = distribution.iter().map(|&(_, count)| count).sum();

        // Statistics over the non-null buckets only. Rows arrive sorted by
        // descending count, so the first strictly-larger count wins ties.
        let mut non_null_count = 0i64;
        let mut max_type_count = 0i64;
        let mut dominant_type = "";
        let mut distinct_type_count = 0usize;
        for &(type_name, count) in &distribution {
            if type_name == "null" {
                continue;
            }
            non_null_count += count;
            distinct_type_count += 1;
            if count > max_type_count {
                max_type_count = count;
                dominant_type = type_name;
            }
        }

        if non_null_count == 0 {
            return Ok(ConstraintResult::skipped("No non-null data to validate"));
        }

        let consistency_ratio = max_type_count as f64 / non_null_count as f64;
        // total_count > 0 here, because non_null_count > 0 is a subset of it.
        let type_report = distribution
            .iter()
            .map(|&(type_name, count)| {
                let percentage = (count as f64 / total_count as f64) * 100.0;
                format!("{type_name}: {percentage:.1}%")
            })
            .collect::<Vec<_>>()
            .join(", ");

        if consistency_ratio >= self.threshold {
            Ok(ConstraintResult {
                status: ConstraintStatus::Success,
                metric: Some(consistency_ratio),
                message: Some(format!(
                    "Column has consistent type '{}' ({:.1}% of non-null values). Type distribution: [{}]",
                    dominant_type,
                    consistency_ratio * 100.0,
                    type_report
                )),
            })
        } else {
            Ok(ConstraintResult::failure_with_metric(
                consistency_ratio,
                format!(
                    "Type consistency {:.1}% is below threshold {:.1}%. Found {} distinct types. Type distribution: [{}]",
                    consistency_ratio * 100.0,
                    self.threshold * 100.0,
                    distinct_type_count,
                    type_report
                ),
            ))
        }
    }

    /// Stable identifier of this constraint kind.
    fn name(&self) -> &str {
        "data_type_consistency"
    }

    /// The single column this constraint inspects.
    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }

    /// Column-scoped metadata with a human-readable description and the
    /// configured threshold recorded as a custom attribute.
    fn metadata(&self) -> ConstraintMetadata {
        ConstraintMetadata::for_column(&self.column)
            .with_description(format!(
                "Checks that at least {:.1}% of values in '{}' have the same data type",
                self.threshold * 100.0,
                self.column
            ))
            .with_custom("threshold", self.threshold.to_string())
            .with_custom("constraint_type", "consistency")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Registers `values` as the nullable Utf8 column `mixed_col` of an
    /// in-memory table named `data`.
    fn session_with(values: Vec<Option<&str>>) -> SessionContext {
        let field = Field::new("mixed_col", DataType::Utf8, true);
        let schema = Arc::new(Schema::new(vec![field]));
        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap();
        let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(table)).unwrap();
        ctx
    }

    /// Evaluates a `mixed_col` consistency constraint with `threshold`
    /// against a `data` table holding `values`.
    async fn run(values: Vec<Option<&str>>, threshold: f64) -> ConstraintResult {
        let ctx = session_with(values);
        let constraint = DataTypeConsistencyConstraint::new("mixed_col", threshold);
        evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap()
    }

    /// Returns the result message, failing the test when it is absent.
    fn message_of(result: &ConstraintResult) -> &str {
        result
            .message
            .as_deref()
            .expect("result should carry a message")
    }

    #[tokio::test]
    async fn test_consistent_integers() {
        let result = run(vec![Some("123"), Some("456"), Some("789"), Some("0")], 0.95).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
        assert!(message_of(&result).contains("integer"));
    }

    #[tokio::test]
    async fn test_mixed_types() {
        // One value each of integer, float, string and boolean.
        let values = vec![Some("123"), Some("456.78"), Some("hello"), Some("true")];
        let result = run(values, 0.8).await;
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(0.25));
        assert!(message_of(&result).contains("4 distinct types"));
    }

    #[tokio::test]
    async fn test_mostly_consistent() {
        // Three integers and one string: 75% integers.
        let values = vec![Some("123"), Some("456"), Some("789"), Some("hello")];
        let result = run(values, 0.7).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.75));
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // Nulls are excluded from the ratio but still appear in the distribution.
        let values = vec![Some("123"), None, Some("456"), None, Some("789")];
        let result = run(values, 0.95).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
        assert!(message_of(&result).contains("null: 40.0%"));
    }

    #[tokio::test]
    async fn test_date_consistency() {
        let values = vec![
            Some("2024-01-01"),
            Some("2024-01-02"),
            Some("2024-01-03"),
            Some("not-a-date"),
        ];
        let result = run(values, 0.7).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(0.75));
        assert!(message_of(&result).contains("date"));
    }

    #[test]
    #[should_panic(expected = "Threshold must be between 0.0 and 1.0")]
    fn test_invalid_threshold() {
        let _ = DataTypeConsistencyConstraint::new("col", 1.5);
    }
}