use crate::core::{
current_validation_context, Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus,
};
use crate::prelude::*;
use arrow::array::{Array, LargeStringArray, StringViewArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use std::fmt;
use std::sync::Arc;
use tracing::instrument;
/// One entry of a categorical histogram: a distinct value, its occurrence
/// count, and its share of the non-null rows.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramBucket {
    pub value: String,
    pub count: i64,
    pub ratio: f64,
}

/// A categorical value distribution for a single column.
///
/// Buckets are expected to be ordered from most to least frequent; the
/// `most_common_ratio` / `least_common_ratio` accessors rely on that order.
#[derive(Debug, Clone)]
pub struct Histogram {
    pub buckets: Vec<HistogramBucket>,
    pub total_count: i64,
    pub distinct_count: usize,
    pub null_count: i64,
}

impl Histogram {
    /// Builds a histogram; `distinct_count` is derived from the bucket list.
    pub fn new(buckets: Vec<HistogramBucket>, total_count: i64, null_count: i64) -> Self {
        let distinct_count = buckets.len();
        Self {
            buckets,
            total_count,
            distinct_count,
            null_count,
        }
    }

    /// Ratio of the most frequent value, or 0.0 for an empty histogram.
    pub fn most_common_ratio(&self) -> f64 {
        self.buckets.first().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Ratio of the least frequent value, or 0.0 for an empty histogram.
    pub fn least_common_ratio(&self) -> f64 {
        self.buckets.last().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Number of distinct (non-null) values observed.
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// The `n` most frequent values with their ratios, in descending order
    /// of frequency (fewer than `n` entries if the histogram is smaller).
    pub fn top_n(&self, n: usize) -> Vec<(&str, f64)> {
        let mut top = Vec::with_capacity(n.min(self.buckets.len()));
        for bucket in self.buckets.iter().take(n) {
            top.push((bucket.value.as_str(), bucket.ratio));
        }
        top
    }

    /// True when the ratio between the most and least common value stays
    /// within `threshold`. An empty histogram counts as uniform; a zero
    /// minimum ratio never does.
    pub fn is_roughly_uniform(&self, threshold: f64) -> bool {
        if self.buckets.is_empty() {
            return true;
        }
        let min_ratio = self.least_common_ratio();
        if min_ratio == 0.0 {
            return false;
        }
        self.most_common_ratio() / min_ratio <= threshold
    }

    /// Ratio for a specific value, or `None` if the value never occurred.
    pub fn get_value_ratio(&self, value: &str) -> Option<f64> {
        for bucket in &self.buckets {
            if bucket.value == value {
                return Some(bucket.ratio);
            }
        }
        None
    }

    /// Shannon entropy (natural log) of the distribution; zero-ratio
    /// buckets contribute nothing.
    pub fn entropy(&self) -> f64 {
        self.buckets.iter().fold(0.0, |acc, bucket| {
            if bucket.ratio > 0.0 {
                acc - bucket.ratio * bucket.ratio.ln()
            } else {
                acc
            }
        })
    }

    /// True when the `top_n` most frequent values jointly cover at least
    /// `threshold` of the distribution.
    pub fn follows_power_law(&self, top_n: usize, threshold: f64) -> bool {
        let mut top_sum = 0.0;
        for bucket in self.buckets.iter().take(top_n) {
            top_sum += bucket.ratio;
        }
        top_sum >= threshold
    }

    /// Fraction of all rows that are null; 0.0 when there are no rows.
    pub fn null_ratio(&self) -> f64 {
        match self.total_count {
            0 => 0.0,
            total => self.null_count as f64 / total as f64,
        }
    }
}
/// Predicate applied to a column's [`Histogram`]; returning `true` means the
/// constraint passes. Wrapped in `Arc` so the owning constraint stays
/// `Clone` + `Send` + `Sync`.
pub type HistogramAssertion = Arc<dyn Fn(&Histogram) -> bool + Send + Sync>;
/// Constraint that computes the value distribution of one column and
/// evaluates a caller-provided assertion against it.
#[derive(Clone)]
pub struct HistogramConstraint {
    // Column whose values are bucketed.
    column: String,
    // Predicate the resulting histogram must satisfy.
    assertion: HistogramAssertion,
    // Human-readable description, used in failure messages and metadata.
    assertion_description: String,
}
// Manual `Debug` impl: the boxed assertion closure has no `Debug` impl, so
// only the plain identifying fields are printed.
impl fmt::Debug for HistogramConstraint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HistogramConstraint")
            .field("column", &self.column)
            .field("assertion_description", &self.assertion_description)
            .finish()
    }
}
impl HistogramConstraint {
    /// Builds a constraint with the default "custom assertion" description.
    pub fn new(column: impl Into<String>, assertion: HistogramAssertion) -> Self {
        Self::new_with_description(column, assertion, "custom assertion")
    }

    /// Builds a constraint carrying a human-readable description of the
    /// assertion, surfaced in failure messages and metadata.
    pub fn new_with_description(
        column: impl Into<String>,
        assertion: HistogramAssertion,
        description: impl Into<String>,
    ) -> Self {
        let column = column.into();
        let assertion_description = description.into();
        Self {
            column,
            assertion,
            assertion_description,
        }
    }
}
#[async_trait]
impl Constraint for HistogramConstraint {
    /// Computes a value histogram for `self.column` via SQL and applies the
    /// stored assertion to it.
    ///
    /// Returns `Skipped` when the table yields no rows, otherwise
    /// `Success`/`Failure` according to the assertion. The reported metric
    /// is the distribution's Shannon entropy.
    #[instrument(skip(self, ctx), fields(column = %self.column))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        // The target table name comes from the ambient validation context
        // (set by the validation runner rather than passed explicitly).
        let validation_ctx = current_validation_context();
        let table_name = validation_ctx.table_name();
        // The query counts each distinct non-null value, computes its ratio
        // over the non-null row count, and carries the overall total/null
        // counts on every row via a cross join. Rows come back ordered most
        // frequent first, which the Histogram accessors rely on.
        //
        // NOTE(review): the column and table names are interpolated directly
        // into the SQL text; they must be trusted/validated identifiers or
        // this is injectable — confirm upstream validation.
        let sql = format!(
            r#"
            WITH value_counts AS (
                SELECT
                    CAST({} AS VARCHAR) as value,
                    COUNT(*) as count
                FROM {table_name}
                WHERE {} IS NOT NULL
                GROUP BY {}
            ),
            totals AS (
                SELECT
                    COUNT(*) as total_cnt,
                    SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) as null_cnt
                FROM {table_name}
            )
            SELECT
                vc.value,
                vc.count,
                vc.count * 1.0 / (t.total_cnt - t.null_cnt) as ratio,
                t.total_cnt as total_count,
                t.null_cnt as null_count
            FROM value_counts vc
            CROSS JOIN totals t
            ORDER BY vc.count DESC, vc.value
            "#,
            self.column, self.column, self.column, self.column
        );
        let df = ctx.sql(&sql).await.map_err(|e| {
            TermError::constraint_evaluation(
                self.name(),
                format!("Failed to execute histogram query: {e}"),
            )
        })?;
        let batches = df.collect().await?;
        // Empty result (no rows / all-null column): nothing to assert on.
        // NOTE(review): only the first batch is checked for emptiness; this
        // assumes the engine never emits a leading empty batch followed by
        // non-empty ones — confirm against DataFusion's behavior.
        if batches.is_empty() || batches[0].num_rows() == 0 {
            return Ok(ConstraintResult::skipped("No data to analyze"));
        }
        let mut buckets = Vec::new();
        let mut total_count = 0i64;
        let mut null_count = 0i64;
        for batch in &batches {
            // Column 0 holds the stringified value; depending on the engine
            // it may arrive as Utf8, Utf8View, or LargeUtf8, so each variant
            // gets its own downcast arm.
            let values_col = batch.column(0);
            let value_strings: Vec<String> = match values_col.data_type() {
                arrow::datatypes::DataType::Utf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<arrow::array::StringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::Utf8View => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<StringViewArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string view values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::LargeUtf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<LargeStringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract large string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                _ => {
                    return Err(TermError::constraint_evaluation(
                        self.name(),
                        format!("Unexpected value column type: {:?}", values_col.data_type()),
                    ));
                }
            };
            // Columns 1-4: per-value count, per-value ratio, overall row
            // count, and overall null count (the last two repeated per row).
            let count_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract counts")
                })?;
            let ratio_array = batch
                .column(2)
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract ratios")
                })?;
            let total_array = batch
                .column(3)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract total count")
                })?;
            let null_array = batch
                .column(4)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract null count")
                })?;
            // Totals are identical on every row, so reading the first row of
            // the batch is sufficient.
            if batch.num_rows() > 0 {
                total_count = total_array.value(0);
                null_count = null_array.value(0);
            }
            for (i, value) in value_strings.into_iter().enumerate() {
                let count = count_array.value(i);
                let ratio = ratio_array.value(i);
                buckets.push(HistogramBucket {
                    value,
                    count,
                    ratio,
                });
            }
        }
        let histogram = Histogram::new(buckets, total_count, null_count);
        let assertion_result = (self.assertion)(&histogram);
        let status = if assertion_result {
            ConstraintStatus::Success
        } else {
            ConstraintStatus::Failure
        };
        // Only failures get a message, summarizing the distribution so the
        // user can see why the assertion did not hold.
        let message = if status == ConstraintStatus::Failure {
            let most_common_pct = histogram.most_common_ratio() * 100.0;
            let null_pct = histogram.null_ratio() * 100.0;
            Some(format!(
                "Histogram assertion '{}' failed for column '{}'. Distribution: {} distinct values, most common ratio: {most_common_pct:.2}%, null ratio: {null_pct:.2}%",
                self.assertion_description,
                self.column,
                histogram.distinct_count
            ))
        } else {
            None
        };
        // The entropy doubles as the recorded metric regardless of outcome.
        Ok(ConstraintResult {
            status,
            metric: Some(histogram.entropy()),
            message,
        })
    }
    /// Constraint type identifier used in error messages and reporting.
    fn name(&self) -> &str {
        "histogram"
    }
    /// The single column this constraint inspects.
    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }
    /// Metadata describing the constraint for reporting/introspection.
    fn metadata(&self) -> ConstraintMetadata {
        ConstraintMetadata::for_column(&self.column)
            .with_description(format!(
                "Analyzes value distribution in column '{}' and applies assertion: {}",
                self.column, self.assertion_description
            ))
            .with_custom("assertion", &self.assertion_description)
            .with_custom("constraint_type", "histogram")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;
    use crate::test_helpers::evaluate_constraint_with_context;
    /// Builds a SessionContext with a single in-memory table "data" holding
    /// one nullable Utf8 column "test_col" filled with `values`.
    async fn create_test_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new(
            "test_col",
            DataType::Utf8,
            true,
        )]));
        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }
    // Pure unit test of the Histogram accessors (no SQL involved).
    #[test]
    fn test_histogram_basic() {
        let buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 50,
                ratio: 0.5,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 30,
                ratio: 0.3,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 20,
                ratio: 0.2,
            },
        ];
        let histogram = Histogram::new(buckets, 100, 0);
        assert_eq!(histogram.most_common_ratio(), 0.5);
        assert_eq!(histogram.least_common_ratio(), 0.2);
        assert_eq!(histogram.bucket_count(), 3);
        assert_eq!(histogram.null_ratio(), 0.0);
    }
    // A uniform distribution must have strictly higher entropy than a
    // heavily skewed one.
    #[test]
    fn test_histogram_entropy() {
        let uniform_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "D".to_string(),
                count: 25,
                ratio: 0.25,
            },
        ];
        let uniform_hist = Histogram::new(uniform_buckets, 100, 0);
        let skewed_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 90,
                ratio: 0.9,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 10,
                ratio: 0.1,
            },
        ];
        let skewed_hist = Histogram::new(skewed_buckets, 100, 0);
        assert!(uniform_hist.entropy() > skewed_hist.entropy());
    }
    // End-to-end: "A" is 60% of the data, so a <50% cap fails and a <70%
    // cap passes.
    #[tokio::test]
    async fn test_most_common_ratio_constraint() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.most_common_ratio() < 0.5),
            "most common value appears less than 50%",
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.message.is_some());
        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.most_common_ratio() < 0.7));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // Four distinct colors fall inside the [3, 5] bucket-count window.
    #[tokio::test]
    async fn test_bucket_count_constraint() {
        let values = vec![
            Some("RED"),
            Some("BLUE"),
            Some("GREEN"),
            Some("YELLOW"),
            Some("RED"),
            Some("BLUE"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.bucket_count() >= 3 && hist.bucket_count() <= 5),
            "has between 3 and 5 distinct values",
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // Perfectly even distribution (2 of each) passes the uniformity check.
    #[tokio::test]
    async fn test_uniform_distribution_check() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.is_roughly_uniform(1.5)));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // Top-2 values cover 7 of 10 rows, exactly meeting the 0.7 threshold.
    #[tokio::test]
    async fn test_power_law_distribution() {
        let values = vec![
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Rare1"),
            Some("Rare2"),
            Some("Rare3"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.follows_power_law(2, 0.7)),
            "top 2 values account for 70% of distribution",
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // 3 nulls out of 8 rows => null_ratio 0.375, inside (0.3, 0.4).
    #[tokio::test]
    async fn test_with_nulls() {
        let values = vec![
            Some("A"),
            Some("A"),
            None,
            None,
            None,
            Some("B"),
            Some("B"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| hist.null_ratio() > 0.3 && hist.null_ratio() < 0.4),
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // No rows at all => the constraint is skipped, not failed.
    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(vec![]).await;
        let constraint = HistogramConstraint::new("test_col", Arc::new(|_| true));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Skipped);
    }
    // Looking up one particular value's ratio: APPROVED is 3/6 = 0.5 > 0.4.
    #[tokio::test]
    async fn test_specific_value_check() {
        let values = vec![
            Some("PENDING"),
            Some("PENDING"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("REJECTED"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| {
                hist.get_value_ratio("APPROVED").unwrap_or(0.0) > 0.4
            }),
            "APPROVED status is most common",
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // top_n must return the two most frequent values with exact ratios
    // (A: 4/10 = 0.4, B: 3/10 = 0.3).
    #[tokio::test]
    async fn test_top_n_values() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;
        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| {
                let top_2 = hist.top_n(2);
                top_2.len() == 2 && top_2[0].1 == 0.4 && top_2[1].1 == 0.3
            }),
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
    // Non-string columns work too: the SQL CASTs values to VARCHAR before
    // bucketing, so an Int64 column produces a valid histogram.
    #[tokio::test]
    async fn test_numeric_data_histogram() {
        use arrow::array::Int64Array;
        use arrow::datatypes::{DataType, Field, Schema};
        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));
        let values = vec![
            Some(25),
            Some(25),
            Some(30),
            Some(30),
            Some(30),
            Some(35),
            Some(35),
            Some(40),
            Some(45),
            Some(50),
        ];
        let array = Int64Array::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        let constraint = HistogramConstraint::new_with_description(
            "age",
            Arc::new(|hist| {
                hist.bucket_count() >= 5 && hist.most_common_ratio() < 0.4
            }),
            "age distribution is reasonable",
        );
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
}