use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
use crate::error::{Result, TermError};
use crate::security::SqlSecurity;
use arrow::array::{Array, Int64Array, StringArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument, warn};
/// Data-quality constraint checking that every value of a child column exists in a
/// parent column (referential integrity between two tables queried through a
/// DataFusion `SessionContext`).
///
/// Both columns are given as qualified `table.column` strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForeignKeyConstraint {
    /// Qualified child column (`table.column`) whose values must reference the parent.
    child_column: String,
    /// Qualified parent column (`table.column`) holding the referenced key values.
    parent_column: String,
    /// When `true`, NULL child values are excluded from the violation count;
    /// when `false` (the default from `new`), they count as violations.
    allow_nulls: bool,
    /// Builder-configurable flag, `true` by default.
    /// NOTE(review): the visible query-generation code always builds a LEFT JOIN
    /// anti-join and never reads this field — confirm the intended alternative.
    use_left_join: bool,
    /// Maximum number of distinct violating values fetched as examples for the
    /// failure message (used as the SQL `LIMIT`); `0` disables example collection.
    max_violations_reported: usize,
}
impl ForeignKeyConstraint {
pub fn new(child_column: impl Into<String>, parent_column: impl Into<String>) -> Self {
Self {
child_column: child_column.into(),
parent_column: parent_column.into(),
allow_nulls: false,
use_left_join: true,
max_violations_reported: 100,
}
}
pub fn allow_nulls(mut self, allow: bool) -> Self {
self.allow_nulls = allow;
self
}
pub fn use_left_join(mut self, use_left_join: bool) -> Self {
self.use_left_join = use_left_join;
self
}
pub fn max_violations_reported(mut self, max_violations: usize) -> Self {
self.max_violations_reported = max_violations;
self
}
pub fn child_column(&self) -> &str {
&self.child_column
}
pub fn parent_column(&self) -> &str {
&self.parent_column
}
fn parse_qualified_column(&self, qualified_column: &str) -> Result<(String, String)> {
let parts: Vec<&str> = qualified_column.split('.').collect();
if parts.len() != 2 {
return Err(TermError::constraint_evaluation(
"foreign_key",
format!(
"Foreign key column must be qualified (table.column): '{qualified_column}'"
),
));
}
let table = parts[0].to_string();
let column = parts[1].to_string();
SqlSecurity::validate_identifier(&table)?;
SqlSecurity::validate_identifier(&column)?;
Ok((table, column))
}
fn generate_left_join_query(
&self,
child_table: &str,
child_col: &str,
parent_table: &str,
parent_col: &str,
) -> Result<String> {
let null_condition = if self.allow_nulls {
format!("AND {child_table}.{child_col} IS NOT NULL")
} else {
String::new()
};
let sql = format!(
"SELECT
COUNT(*) as total_violations,
COUNT(DISTINCT {child_table}.{child_col}) as unique_violations
FROM {child_table}
LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
WHERE {parent_table}.{parent_col} IS NULL {null_condition}"
);
debug!("Generated foreign key validation query: {}", sql);
Ok(sql)
}
fn generate_violations_query(
&self,
child_table: &str,
child_col: &str,
parent_table: &str,
parent_col: &str,
) -> Result<String> {
if self.max_violations_reported == 0 {
return Ok(String::new());
}
let null_condition = if self.allow_nulls {
format!("AND {child_table}.{child_col} IS NOT NULL")
} else {
String::new()
};
let limit = self.max_violations_reported;
let sql = format!(
"SELECT DISTINCT {child_table}.{child_col} as violating_value
FROM {child_table}
LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
WHERE {parent_table}.{parent_col} IS NULL {null_condition}
LIMIT {limit}"
);
debug!("Generated violations query: {}", sql);
Ok(sql)
}
async fn collect_violation_examples_efficiently(
&self,
ctx: &SessionContext,
child_table: &str,
child_col: &str,
parent_table: &str,
parent_col: &str,
) -> Result<Vec<String>> {
if self.max_violations_reported == 0 {
return Ok(Vec::new());
}
let violations_sql =
self.generate_violations_query(child_table, child_col, parent_table, parent_col)?;
if violations_sql.is_empty() {
return Ok(Vec::new());
}
debug!("Executing foreign key violations query with memory-efficient collection");
let violations_df = ctx.sql(&violations_sql).await.map_err(|e| {
TermError::constraint_evaluation(
"foreign_key",
format!("Failed to execute violations query: {e}"),
)
})?;
let batches = violations_df.collect().await.map_err(|e| {
TermError::constraint_evaluation(
"foreign_key",
format!("Failed to collect violation examples: {e}"),
)
})?;
let mut violation_examples = Vec::with_capacity(self.max_violations_reported);
for batch in batches {
for i in 0..batch.num_rows() {
if violation_examples.len() >= self.max_violations_reported {
debug!(
"Reached max violations limit ({}), stopping collection",
self.max_violations_reported
);
return Ok(violation_examples);
}
if let Some(string_array) = batch.column(0).as_any().downcast_ref::<StringArray>() {
if !string_array.is_null(i) {
violation_examples.push(string_array.value(i).to_string());
}
} else if let Some(int64_array) =
batch.column(0).as_any().downcast_ref::<Int64Array>()
{
if !int64_array.is_null(i) {
violation_examples.push(int64_array.value(i).to_string());
}
} else if let Some(float64_array) = batch
.column(0)
.as_any()
.downcast_ref::<arrow::array::Float64Array>()
{
if !float64_array.is_null(i) {
violation_examples.push(float64_array.value(i).to_string());
}
} else if let Some(int32_array) = batch
.column(0)
.as_any()
.downcast_ref::<arrow::array::Int32Array>()
{
if !int32_array.is_null(i) {
violation_examples.push(int32_array.value(i).to_string());
}
}
}
}
debug!(
"Collected {} foreign key violation examples",
violation_examples.len()
);
Ok(violation_examples)
}
}
#[async_trait]
impl Constraint for ForeignKeyConstraint {
#[instrument(skip(self, ctx), fields(constraint = "foreign_key"))]
async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
debug!(
"Evaluating foreign key constraint: {} -> {}",
self.child_column, self.parent_column
);
let (child_table, child_col) = self.parse_qualified_column(&self.child_column)?;
let (parent_table, parent_col) = self.parse_qualified_column(&self.parent_column)?;
let sql =
self.generate_left_join_query(&child_table, &child_col, &parent_table, &parent_col)?;
let df = ctx.sql(&sql).await.map_err(|e| {
TermError::constraint_evaluation(
"foreign_key",
format!("Foreign key validation query failed: {e}"),
)
})?;
let batches = df.collect().await.map_err(|e| {
TermError::constraint_evaluation(
"foreign_key",
format!("Failed to collect foreign key results: {e}"),
)
})?;
if batches.is_empty() || batches[0].num_rows() == 0 {
return Ok(ConstraintResult::success());
}
let batch = &batches[0];
let total_violations = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"foreign_key",
"Invalid total violations column type",
)
})?
.value(0);
let unique_violations = batch
.column(1)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| {
TermError::constraint_evaluation(
"foreign_key",
"Invalid unique violations column type",
)
})?
.value(0);
if total_violations == 0 {
debug!("Foreign key constraint passed: no violations found");
return Ok(ConstraintResult::success());
}
let violation_examples = self
.collect_violation_examples_efficiently(
ctx,
&child_table,
&child_col,
&parent_table,
&parent_col,
)
.await?;
let message = if violation_examples.is_empty() {
format!(
"Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations})",
self.child_column, self.parent_column
)
} else {
let examples_str = if violation_examples.len() <= 5 {
violation_examples.join(", ")
} else {
format!(
"{}, ... ({} more)",
violation_examples[..5].join(", "),
violation_examples.len() - 5
)
};
format!(
"Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations}). Examples: [{examples_str}]",
self.child_column, self.parent_column
)
};
warn!("{}", message);
Ok(ConstraintResult {
status: ConstraintStatus::Failure,
metric: Some(total_violations as f64),
message: Some(message),
})
}
fn name(&self) -> &str {
"foreign_key"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_test_context;

    /// Executes each SQL statement in order, collecting every result so that DDL
    /// and `INSERT` statements are actually applied to `ctx`.
    async fn run_all(ctx: &SessionContext, statements: &[&str]) -> Result<()> {
        for statement in statements {
            ctx.sql(statement).await?.collect().await?;
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_foreign_key_constraint_success() -> Result<()> {
        let ctx = create_test_context().await?;
        run_all(
            &ctx,
            &[
                "CREATE TABLE customers_success (id BIGINT, name STRING)",
                "INSERT INTO customers_success VALUES (1, 'Alice'), (2, 'Bob')",
                "CREATE TABLE orders_success (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_success VALUES (1, 1, 100.0), (2, 2, 200.0)",
            ],
        )
        .await?;

        let outcome =
            ForeignKeyConstraint::new("orders_success.customer_id", "customers_success.id")
                .evaluate(&ctx)
                .await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert!(outcome.message.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_foreign_key_constraint_violation() -> Result<()> {
        let ctx = create_test_context().await?;
        run_all(
            &ctx,
            &[
                "CREATE TABLE customers_violation (id BIGINT, name STRING)",
                "INSERT INTO customers_violation VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')",
                "CREATE TABLE orders_violation (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_violation VALUES (1, 1, 100.0), (2, 2, 200.0), (3, 999, 300.0), (4, 998, 400.0)",
            ],
        )
        .await?;

        let outcome =
            ForeignKeyConstraint::new("orders_violation.customer_id", "customers_violation.id")
                .evaluate(&ctx)
                .await?;

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.is_some());
        assert_eq!(outcome.metric, Some(2.0));
        let text = outcome.message.expect("failed result carries a message");
        for expected in [
            "Foreign key constraint violation",
            "2 values",
            "orders_violation.customer_id",
            "customers_violation.id",
        ] {
            assert!(text.contains(expected));
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_foreign_key_with_nulls_disallowed() -> Result<()> {
        let ctx = create_test_context().await?;
        run_all(
            &ctx,
            &[
                "CREATE TABLE customers_nulls_disallowed (id BIGINT, name STRING)",
                "INSERT INTO customers_nulls_disallowed VALUES (1, 'Alice')",
                "CREATE TABLE orders_nulls_disallowed (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_nulls_disallowed VALUES (1, 1, 100.0), (2, NULL, 200.0)",
            ],
        )
        .await?;

        let outcome = ForeignKeyConstraint::new(
            "orders_nulls_disallowed.customer_id",
            "customers_nulls_disallowed.id",
        )
        .allow_nulls(false)
        .evaluate(&ctx)
        .await?;

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        Ok(())
    }

    #[tokio::test]
    async fn test_foreign_key_with_nulls_allowed() -> Result<()> {
        let ctx = create_test_context().await?;
        run_all(
            &ctx,
            &[
                "CREATE TABLE customers_nulls_allowed (id BIGINT, name STRING)",
                "INSERT INTO customers_nulls_allowed VALUES (1, 'Alice')",
                "CREATE TABLE orders_nulls_allowed (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_nulls_allowed VALUES (1, 1, 100.0), (2, NULL, 200.0)",
            ],
        )
        .await?;

        let outcome = ForeignKeyConstraint::new(
            "orders_nulls_allowed.customer_id",
            "customers_nulls_allowed.id",
        )
        .allow_nulls(true)
        .evaluate(&ctx)
        .await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);
        Ok(())
    }

    #[test]
    fn test_parse_qualified_column() {
        let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id");

        let parsed = constraint
            .parse_qualified_column("orders.customer_id")
            .unwrap();
        assert_eq!(parsed, ("orders".to_string(), "customer_id".to_string()));

        for malformed in ["invalid_column", "too.many.parts"] {
            assert!(constraint.parse_qualified_column(malformed).is_err());
        }
    }

    #[test]
    fn test_constraint_configuration() {
        let configured = ForeignKeyConstraint::new("orders.customer_id", "customers.id")
            .allow_nulls(true)
            .use_left_join(false)
            .max_violations_reported(50);

        assert_eq!(configured.child_column(), "orders.customer_id");
        assert_eq!(configured.parent_column(), "customers.id");
        assert!(configured.allow_nulls);
        assert!(!configured.use_left_join);
        assert_eq!(configured.max_violations_reported, 50);
    }

    #[test]
    fn test_constraint_name() {
        assert_eq!(
            ForeignKeyConstraint::new("orders.customer_id", "customers.id").name(),
            "foreign_key"
        );
    }

    #[test]
    fn test_sql_generation() -> Result<()> {
        let sql = ForeignKeyConstraint::new("orders.customer_id", "customers.id")
            .generate_left_join_query("orders", "customer_id", "customers", "id")?;

        for fragment in [
            "LEFT JOIN",
            "orders.customer_id = customers.id",
            "customers.id IS NULL",
            "COUNT(*) as total_violations",
        ] {
            assert!(sql.contains(fragment), "missing `{fragment}` in: {sql}");
        }
        Ok(())
    }

    #[test]
    fn test_sql_generation_with_nulls_allowed() -> Result<()> {
        let sql = ForeignKeyConstraint::new("orders.customer_id", "customers.id")
            .allow_nulls(true)
            .generate_left_join_query("orders", "customer_id", "customers", "id")?;

        assert!(sql.contains("AND orders.customer_id IS NOT NULL"));
        Ok(())
    }
}