use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
use crate::error::{Result, TermError};
use crate::security::SqlSecurity;
use arrow::array::{Array, Float64Array, StringArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossTableSumConstraint {
left_column: String,
right_column: String,
group_by_columns: Vec<String>,
tolerance: f64,
max_violations_reported: usize,
}
impl CrossTableSumConstraint {
pub fn new(left_column: impl Into<String>, right_column: impl Into<String>) -> Self {
Self {
left_column: left_column.into(),
right_column: right_column.into(),
group_by_columns: Vec::new(),
tolerance: 0.0,
max_violations_reported: 100,
}
}
pub fn group_by(mut self, columns: Vec<impl Into<String>>) -> Self {
self.group_by_columns = columns.into_iter().map(Into::into).collect();
self
}
pub fn tolerance(mut self, tolerance: f64) -> Self {
self.tolerance = tolerance.abs(); self
}
pub fn max_violations_reported(mut self, max_violations: usize) -> Self {
self.max_violations_reported = max_violations;
self
}
pub fn left_column(&self) -> &str {
&self.left_column
}
pub fn right_column(&self) -> &str {
&self.right_column
}
pub fn group_by_columns(&self) -> &[String] {
&self.group_by_columns
}
fn parse_qualified_column(&self, qualified_column: &str) -> Result<(String, String)> {
let parts: Vec<&str> = qualified_column.split('.').collect();
if parts.len() != 2 {
return Err(TermError::constraint_evaluation(
"cross_table_sum",
format!("Column must be qualified (table.column): '{qualified_column}'"),
));
}
let table = parts[0].to_string();
let column = parts[1].to_string();
SqlSecurity::validate_identifier(&table)?;
SqlSecurity::validate_identifier(&column)?;
Ok((table, column))
}
fn validate_group_by_columns(&self) -> Result<()> {
for column in &self.group_by_columns {
SqlSecurity::validate_identifier(column)?;
}
Ok(())
}
fn generate_validation_query(
&self,
left_table: &str,
left_col: &str,
right_table: &str,
right_col: &str,
) -> Result<String> {
if self.group_by_columns.is_empty() {
let tolerance = self.tolerance;
let sql = format!(
"SELECT
1 as total_groups,
CASE WHEN ABS(left_total - right_total) > {tolerance}
THEN 1 ELSE 0 END as violating_groups,
left_total as total_left_sum,
right_total as total_right_sum,
ABS(left_total - right_total) as max_difference
FROM (
SELECT
COALESCE((SELECT SUM({left_col}) FROM {left_table}), 0.0) as left_total,
COALESCE((SELECT SUM({right_col}) FROM {right_table}), 0.0) as right_total
) totals"
);
debug!("Generated optimized non-grouped cross-table sum query: {sql}");
Ok(sql)
} else {
let group_columns = self
.group_by_columns
.iter()
.map(|col| col.to_string())
.collect::<Vec<_>>();
let left_group_select = group_columns
.iter()
.map(|col| format!("{left_table}.{col}"))
.collect::<Vec<_>>()
.join(", ");
let right_group_select = group_columns
.iter()
.map(|col| format!("{right_table}.{col}"))
.collect::<Vec<_>>()
.join(", ");
let _group_by_clause = group_columns
.iter()
.map(|col| col.to_string())
.collect::<Vec<_>>()
.join(", ");
let tolerance = self.tolerance;
let join_condition = group_columns
.iter()
.map(|col| format!("l.{col} = r.{col}"))
.collect::<Vec<_>>()
.join(" AND ");
let sql = format!(
"WITH left_sums AS (
SELECT {left_group_select},
COALESCE(SUM({left_table}.{left_col}), 0.0) as left_sum
FROM {left_table}
GROUP BY {left_group_select}
),
right_sums AS (
SELECT {right_group_select},
COALESCE(SUM({right_table}.{right_col}), 0.0) as right_sum
FROM {right_table}
GROUP BY {right_group_select}
),
combined_data AS (
SELECT
COALESCE(l.left_sum, 0.0) as total_left_sum,
COALESCE(r.right_sum, 0.0) as total_right_sum,
ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) as difference,
CASE WHEN ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) > {tolerance}
THEN 1 ELSE 0 END as is_violation
FROM left_sums l
FULL OUTER JOIN right_sums r ON {join_condition}
)
SELECT
COUNT(*) as total_groups,
SUM(is_violation) as violating_groups,
SUM(total_left_sum) as total_left_sum,
SUM(total_right_sum) as total_right_sum,
MAX(difference) as max_difference
FROM combined_data"
);
debug!("Generated optimized grouped cross-table sum query: {sql}");
Ok(sql)
}
}
fn generate_violations_query(
&self,
left_table: &str,
left_col: &str,
right_table: &str,
right_col: &str,
) -> Result<String> {
if self.max_violations_reported == 0 {
return Ok(String::new());
}
if self.group_by_columns.is_empty() {
let tolerance = self.tolerance;
let limit = self.max_violations_reported;
let sql = format!(
"SELECT
'ALL' as group_key,
left_total as left_sum,
right_total as right_sum,
ABS(left_total - right_total) as difference
FROM (
SELECT
COALESCE((SELECT SUM({left_col}) FROM {left_table}), 0.0) as left_total,
COALESCE((SELECT SUM({right_col}) FROM {right_table}), 0.0) as right_total
) totals
WHERE ABS(left_total - right_total) > {tolerance}
LIMIT {limit}"
);
debug!("Generated optimized non-grouped violations query: {sql}");
Ok(sql)
} else {
let group_columns = self
.group_by_columns
.iter()
.map(|col| col.to_string())
.collect::<Vec<_>>();
let left_group_select = group_columns
.iter()
.map(|col| format!("{left_table}.{col}"))
.collect::<Vec<_>>()
.join(", ");
let right_group_select = group_columns
.iter()
.map(|col| format!("{right_table}.{col}"))
.collect::<Vec<_>>()
.join(", ");
let group_key_concat = if group_columns.len() == 1 {
format!(
"CAST(COALESCE(l.{}, r.{}) AS STRING)",
group_columns[0], group_columns[0]
)
} else {
format!(
"CONCAT({})",
group_columns
.iter()
.map(|col| format!("CAST(COALESCE(l.{col}, r.{col}) AS STRING)"))
.collect::<Vec<_>>()
.join(", '|', ")
)
};
let _group_by_clause = group_columns.join(", ");
let tolerance = self.tolerance;
let limit = self.max_violations_reported;
let join_condition = group_columns
.iter()
.map(|col| format!("l.{col} = r.{col}"))
.collect::<Vec<_>>()
.join(" AND ");
let sql = format!(
"WITH left_sums AS (
SELECT {left_group_select},
COALESCE(SUM({left_table}.{left_col}), 0.0) as left_sum
FROM {left_table}
GROUP BY {left_group_select}
),
right_sums AS (
SELECT {right_group_select},
COALESCE(SUM({right_table}.{right_col}), 0.0) as right_sum
FROM {right_table}
GROUP BY {right_group_select}
)
SELECT
{group_key_concat} as group_key,
COALESCE(l.left_sum, 0.0) as left_sum,
COALESCE(r.right_sum, 0.0) as right_sum,
ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) as difference
FROM left_sums l
FULL OUTER JOIN right_sums r ON {join_condition}
WHERE ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) > {tolerance}
ORDER BY ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) DESC
LIMIT {limit}"
);
debug!("Generated optimized grouped violations query: {sql}");
Ok(sql)
}
}
async fn collect_violation_examples_simple(
&self,
ctx: &SessionContext,
left_table: &str,
left_col: &str,
right_table: &str,
right_col: &str,
) -> Result<Vec<String>> {
if !self.group_by_columns.is_empty() {
debug!("Skipping violation example collection for grouped constraint due to DataFusion limitations");
return Ok(Vec::new());
}
let violations_sql =
self.generate_violations_query(left_table, left_col, right_table, right_col)?;
if violations_sql.is_empty() {
return Ok(Vec::new());
}
debug!("Executing simple violations query");
let violations_df = ctx.sql(&violations_sql).await.map_err(|e| {
TermError::constraint_evaluation(
"cross_table_sum",
format!("Failed to execute violations query: {e}"),
)
})?;
let batches = violations_df.collect().await.map_err(|e| {
TermError::constraint_evaluation(
"cross_table_sum",
format!("Failed to collect violation examples: {e}"),
)
})?;
let mut violation_examples = Vec::with_capacity(self.max_violations_reported);
for batch in batches {
for i in 0..batch.num_rows() {
if violation_examples.len() >= self.max_violations_reported {
break;
}
if let (Some(group_key), Some(left_sum), Some(right_sum), Some(diff)) = (
batch.column(0).as_any().downcast_ref::<StringArray>(),
batch.column(1).as_any().downcast_ref::<Float64Array>(),
batch.column(2).as_any().downcast_ref::<Float64Array>(),
batch.column(3).as_any().downcast_ref::<Float64Array>(),
) {
if !group_key.is_null(i) {
violation_examples.push(format!(
"Group '{}': {} = {:.4}, {} = {:.4} (diff: {:.4})",
group_key.value(i),
self.left_column,
left_sum.value(i),
self.right_column,
right_sum.value(i),
diff.value(i)
));
}
}
}
}
debug!("Collected {} violation examples", violation_examples.len());
Ok(violation_examples)
}
}
#[async_trait]
impl Constraint for CrossTableSumConstraint {
#[instrument(skip(self, ctx), fields(constraint = "cross_table_sum"))]
async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
debug!(
"Evaluating cross-table sum constraint: {} vs {}",
self.left_column, self.right_column
);
let (left_table, left_col) = self.parse_qualified_column(&self.left_column)?;
let (right_table, right_col) = self.parse_qualified_column(&self.right_column)?;
self.validate_group_by_columns()?;
let sql =
self.generate_validation_query(&left_table, &left_col, &right_table, &right_col)?;
let df = ctx.sql(&sql).await.map_err(|e| {
TermError::constraint_evaluation(
"cross_table_sum",
format!("Cross-table sum validation query failed: {e}"),
)
})?;
let batches = df.collect().await.map_err(|e| {
TermError::constraint_evaluation(
"cross_table_sum",
format!("Failed to collect cross-table sum results: {e}"),
)
})?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult::skipped(
"No data found for cross-table sum comparison",
));
}
let batch = &batches[0];
let total_groups = batch
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"cross_table_sum",
"Invalid total_groups column type",
)
})?
.value(0);
let violating_groups = batch
.column(1)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"cross_table_sum",
"Invalid violating_groups column type",
)
})?
.value(0);
let total_left_sum = batch
.column(2)
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"cross_table_sum",
"Invalid total_left_sum column type",
)
})?
.value(0);
let total_right_sum = batch
.column(3)
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"cross_table_sum",
"Invalid total_right_sum column type",
)
})?
.value(0);
let max_difference = batch
.column(4)
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"cross_table_sum",
"Invalid max_difference column type",
)
})?
.value(0);
if violating_groups == 0 {
debug!("Cross-table sum constraint passed: all groups match within tolerance");
return Ok(ConstraintResult::success_with_metric(max_difference));
}
let mut violation_examples = Vec::new();
if self.max_violations_reported > 0 {
violation_examples = self
.collect_violation_examples_simple(
ctx,
&left_table,
&left_col,
&right_table,
&right_col,
)
.await?;
}
let tolerance_text = if self.tolerance > 0.0 {
format!(" (tolerance: {:.4})", self.tolerance)
} else {
" (exact match required)".to_string()
};
let grouping_text = if self.group_by_columns.is_empty() {
"overall totals".to_string()
} else {
format!("groups by [{}]", self.group_by_columns.join(", "))
};
let message = if violation_examples.is_empty() {
format!(
"Cross-table sum mismatch: {violating_groups}/{total_groups} {grouping_text} failed validation{tolerance_text}, total sums: {total_left_sum} vs {total_right_sum} (max diff: {max_difference:.4})"
)
} else {
let examples_str = if violation_examples.len() <= 3 {
violation_examples.join("; ")
} else {
format!(
"{}; ... ({} more)",
violation_examples[..3].join("; "),
violation_examples.len() - 3
)
};
format!(
"Cross-table sum mismatch: {violating_groups}/{total_groups} {grouping_text} failed validation{tolerance_text}. Examples: [{examples_str}]"
)
};
warn!("{}", message);
Ok(ConstraintResult {
status: ConstraintStatus::Failure,
metric: Some(max_difference),
message: Some(message),
})
}
fn name(&self) -> &str {
"cross_table_sum"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::create_test_context;
async fn create_test_tables(ctx: &SessionContext, table_suffix: &str) -> Result<()> {
let orders_table = format!("orders_{table_suffix}");
let payments_table = format!("payments_{table_suffix}");
ctx.sql(&format!(
"CREATE TABLE {orders_table} (id BIGINT, customer_id BIGINT, total DOUBLE)"
))
.await?
.collect()
.await?;
ctx.sql(&format!(
"INSERT INTO {orders_table} VALUES (1, 1, 100.0), (2, 1, 200.0), (3, 2, 150.0), (4, 2, 300.0)"
))
.await?
.collect()
.await?;
ctx.sql(&format!(
"CREATE TABLE {payments_table} (id BIGINT, customer_id BIGINT, amount DOUBLE)"
))
.await?
.collect()
.await?;
ctx.sql(&format!(
"INSERT INTO {payments_table} VALUES (1, 1, 300.0), (2, 2, 450.0)"
))
.await?
.collect()
.await?;
Ok(())
}
#[tokio::test]
async fn test_cross_table_sum_success() -> Result<()> {
let ctx = create_test_context().await?;
create_test_tables(&ctx, "success").await?;
let constraint =
CrossTableSumConstraint::new("orders_success.total", "payments_success.amount")
.group_by(vec!["customer_id"]);
let result = constraint.evaluate(&ctx).await?;
assert_eq!(result.status, ConstraintStatus::Success);
assert!(result.metric.is_some());
Ok(())
}
#[tokio::test]
async fn test_cross_table_sum_violation() -> Result<()> {
let ctx = create_test_context().await?;
ctx.sql("CREATE TABLE orders_violation (id BIGINT, customer_id BIGINT, total DOUBLE)")
.await?
.collect()
.await?;
ctx.sql("INSERT INTO orders_violation VALUES (1, 1, 100.0), (2, 1, 200.0)")
.await?
.collect()
.await?;
ctx.sql("CREATE TABLE payments_violation (id BIGINT, customer_id BIGINT, amount DOUBLE)")
.await?
.collect()
.await?;
ctx.sql("INSERT INTO payments_violation VALUES (1, 1, 250.0)")
.await?
.collect()
.await?;
let constraint =
CrossTableSumConstraint::new("orders_violation.total", "payments_violation.amount")
.group_by(vec!["customer_id"]);
let result = constraint.evaluate(&ctx).await?;
assert_eq!(result.status, ConstraintStatus::Failure);
assert!(result.message.is_some());
assert!(result.metric.is_some());
let message = result.message.unwrap();
assert!(message.contains("Cross-table sum mismatch"));
assert!(message.contains("customer_id"));
Ok(())
}
#[tokio::test]
async fn test_cross_table_sum_with_tolerance() -> Result<()> {
let ctx = create_test_context().await?;
ctx.sql("CREATE TABLE orders_tolerance (id BIGINT, total DOUBLE)")
.await?
.collect()
.await?;
ctx.sql("INSERT INTO orders_tolerance VALUES (1, 100.005)")
.await?
.collect()
.await?;
ctx.sql("CREATE TABLE payments_tolerance (id BIGINT, amount DOUBLE)")
.await?
.collect()
.await?;
ctx.sql("INSERT INTO payments_tolerance VALUES (1, 100.001)")
.await?
.collect()
.await?;
let constraint_no_tolerance =
CrossTableSumConstraint::new("orders_tolerance.total", "payments_tolerance.amount");
let result = constraint_no_tolerance.evaluate(&ctx).await?;
assert_eq!(result.status, ConstraintStatus::Failure);
let constraint_with_tolerance =
CrossTableSumConstraint::new("orders_tolerance.total", "payments_tolerance.amount")
.tolerance(0.01);
let result = constraint_with_tolerance.evaluate(&ctx).await?;
assert_eq!(result.status, ConstraintStatus::Success);
Ok(())
}
#[tokio::test]
async fn test_cross_table_sum_no_grouping() -> Result<()> {
let ctx = create_test_context().await?;
create_test_tables(&ctx, "no_grouping").await?;
let constraint =
CrossTableSumConstraint::new("orders_no_grouping.total", "payments_no_grouping.amount");
let result = constraint.evaluate(&ctx).await?;
assert_eq!(result.status, ConstraintStatus::Success);
Ok(())
}
#[test]
fn test_parse_qualified_column() {
let constraint = CrossTableSumConstraint::new("orders.total", "payments.amount");
let (table, column) = constraint.parse_qualified_column("orders.total").unwrap();
assert_eq!(table, "orders");
assert_eq!(column, "total");
assert!(constraint.parse_qualified_column("invalid_column").is_err());
assert!(constraint.parse_qualified_column("too.many.parts").is_err());
}
#[test]
fn test_constraint_configuration() {
let constraint = CrossTableSumConstraint::new("orders.total", "payments.amount")
.group_by(vec!["customer_id", "order_date"])
.tolerance(0.01)
.max_violations_reported(50);
assert_eq!(constraint.left_column(), "orders.total");
assert_eq!(constraint.right_column(), "payments.amount");
assert_eq!(
constraint.group_by_columns(),
&["customer_id", "order_date"]
);
assert_eq!(constraint.tolerance, 0.01);
assert_eq!(constraint.max_violations_reported, 50);
}
#[test]
fn test_constraint_name() {
let constraint = CrossTableSumConstraint::new("orders.total", "payments.amount");
assert_eq!(constraint.name(), "cross_table_sum");
}
}