use crate::constraints::{
CrossTableSumConstraint, ForeignKeyConstraint, JoinCoverageConstraint,
TemporalOrderingConstraint,
};
use crate::core::{Check, Level};
use std::sync::Arc;
/// Fluent builder for a [`Check`] whose constraints span one or more tables.
///
/// Context methods such as [`MultiTableCheck::validate_tables`] and
/// [`MultiTableCheck::validate_temporal`] select the "current" table context.
/// Constraint methods called afterwards read that context at the moment they
/// are invoked. [`MultiTableCheck::build`] turns the accumulated constraints
/// into a [`Check`].
#[derive(Debug)]
#[must_use = "a MultiTableCheck produces nothing until `build()` is called"]
pub struct MultiTableCheck {
    /// Name given to the built check.
    name: String,
    /// Level given to the built check.
    level: Level,
    /// Optional description given to the built check.
    description: Option<String>,
    /// Table context that subsequent constraint methods read, if one is set.
    current_context: Option<TableContext>,
    /// Constraints accumulated so far, in the order they were added.
    constraints: Vec<Arc<dyn crate::core::Constraint>>,
}
/// Table(s) and column settings that constraint methods read when invoked.
#[derive(Debug, Clone)]
struct TableContext {
    /// Left-hand table of a two-table context, or the only table of a
    /// single-table (temporal) context.
    left_table: String,
    /// Right-hand table; `None` for a single-table context.
    right_table: Option<String>,
    /// `(left_column, right_column)` join pairs, in the order they were added.
    join_columns: Vec<(String, String)>,
    /// Group-by column names, in the order they were added.
    group_by_columns: Vec<String>,
}
impl MultiTableCheck {
    /// Creates an empty builder for a check named `name`.
    ///
    /// Defaults: [`Level::Error`], no description, no table context and no
    /// constraints.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            level: Level::Error,
            description: None,
            current_context: None,
            constraints: Vec::new(),
        }
    }

    /// Sets the level of the built [`Check`].
    pub fn level(mut self, level: Level) -> Self {
        self.level = level;
        self
    }

    /// Sets the description of the built [`Check`].
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Switches to a two-table context comparing `left_table` with `right_table`.
    ///
    /// The previous context, including its join and group-by columns, is
    /// discarded. Constraints that were already added are kept.
    pub fn validate_tables(
        mut self,
        left_table: impl Into<String>,
        right_table: impl Into<String>,
    ) -> Self {
        self.current_context = Some(TableContext {
            left_table: left_table.into(),
            right_table: Some(right_table.into()),
            join_columns: Vec::new(),
            group_by_columns: Vec::new(),
        });
        self
    }

    /// Alias for [`Self::validate_tables`] that reads naturally when chaining
    /// another table pair.
    pub fn and_validate_tables(
        self,
        left_table: impl Into<String>,
        right_table: impl Into<String>,
    ) -> Self {
        self.validate_tables(left_table, right_table)
    }

    /// Switches to a single-table context on `table`.
    ///
    /// The previous context is discarded. While this context is active there is
    /// no right table, so [`Self::ensure_referential_integrity`],
    /// [`Self::expect_join_coverage`] and [`Self::ensure_sum_consistency`] add
    /// nothing.
    pub fn validate_temporal(mut self, table: impl Into<String>) -> Self {
        self.current_context = Some(TableContext {
            left_table: table.into(),
            right_table: None,
            join_columns: Vec::new(),
            group_by_columns: Vec::new(),
        });
        self
    }

    /// Alias for [`Self::validate_temporal`] that reads naturally when chaining.
    pub fn and_validate_temporal(self, table: impl Into<String>) -> Self {
        self.validate_temporal(table)
    }

    /// Appends a `(left_column, right_column)` join pair to the current context.
    ///
    /// Does nothing if no context has been selected yet.
    pub fn join_on(
        mut self,
        left_column: impl Into<String>,
        right_column: impl Into<String>,
    ) -> Self {
        if let Some(ctx) = self.current_context.as_mut() {
            ctx.join_columns
                .push((left_column.into(), right_column.into()));
        }
        self
    }

    /// Appends several join pairs, in order, to the current context.
    ///
    /// Does nothing if no context has been selected yet.
    pub fn join_on_multiple(mut self, columns: Vec<(&str, &str)>) -> Self {
        if let Some(ctx) = self.current_context.as_mut() {
            ctx.join_columns.extend(
                columns
                    .into_iter()
                    .map(|(left, right)| (left.to_string(), right.to_string())),
            );
        }
        self
    }

    /// Appends a group-by column to the current context.
    ///
    /// [`Self::ensure_sum_consistency`] copies the group-by list when it is
    /// called. Only constraints added *after* this call therefore see the
    /// column. Does nothing if no context has been selected yet.
    pub fn group_by(mut self, column: impl Into<String>) -> Self {
        if let Some(ctx) = self.current_context.as_mut() {
            ctx.group_by_columns.push(column.into());
        }
        self
    }

    /// Appends several group-by columns, in order, to the current context.
    ///
    /// Has the same ordering caveat as [`Self::group_by`]. Does nothing if no
    /// context has been selected yet.
    pub fn group_by_multiple(mut self, columns: Vec<impl Into<String>>) -> Self {
        if let Some(ctx) = self.current_context.as_mut() {
            ctx.group_by_columns
                .extend(columns.into_iter().map(Into::into));
        }
        self
    }

    /// Adds a [`ForeignKeyConstraint`] from `left_table.left_col` (child) to
    /// `right_table.right_col` (parent).
    ///
    /// Only the *first* join pair of the current context is used. Further
    /// pairs (e.g. from [`Self::join_on_multiple`]) are not included in the
    /// constraint. Does nothing unless a two-table context with at least one
    /// join pair is active.
    pub fn ensure_referential_integrity(mut self) -> Self {
        if let Some(ctx) = &self.current_context {
            if let (Some(right_table), Some((left_col, right_col))) =
                (&ctx.right_table, ctx.join_columns.first())
            {
                let child_column = format!("{}.{left_col}", ctx.left_table);
                let parent_column = format!("{right_table}.{right_col}");
                self.constraints.push(Arc::new(ForeignKeyConstraint::new(
                    child_column,
                    parent_column,
                )));
            }
        }
        self
    }

    /// Adds a [`JoinCoverageConstraint`] between the two tables of the current
    /// context, expecting a match rate of at least `min_coverage`.
    ///
    /// The constraint is configured as follows:
    /// - with one join pair: `on(left, right)`;
    /// - with several pairs: `on_multiple(pairs)`;
    /// - with none: no join columns are set on the constraint.
    ///
    /// Does nothing unless a two-table context is active.
    pub fn expect_join_coverage(mut self, min_coverage: f64) -> Self {
        if let Some(ctx) = &self.current_context {
            if let Some(right_table) = &ctx.right_table {
                let constraint = JoinCoverageConstraint::new(&ctx.left_table, right_table);
                let constraint = match ctx.join_columns.as_slice() {
                    [] => constraint,
                    [(left_col, right_col)] => constraint.on(left_col, right_col),
                    pairs => {
                        let cols: Vec<_> = pairs
                            .iter()
                            .map(|(l, r)| (l.as_str(), r.as_str()))
                            .collect();
                        constraint.on_multiple(cols)
                    }
                };
                self.constraints
                    .push(Arc::new(constraint.expect_match_rate(min_coverage)));
            }
        }
        self
    }

    /// Adds a [`CrossTableSumConstraint`] comparing `left_table.left_column`
    /// with `right_table.right_column`.
    ///
    /// The constraint receives the group-by columns configured on the context
    /// *so far*. Columns added later via [`Self::group_by`] do not affect it.
    /// Does nothing unless a two-table context is active.
    pub fn ensure_sum_consistency(
        mut self,
        left_column: impl Into<String>,
        right_column: impl Into<String>,
    ) -> Self {
        if let Some(ctx) = &self.current_context {
            if let Some(right_table) = &ctx.right_table {
                let left_col = format!("{}.{}", ctx.left_table, left_column.into());
                let right_col = format!("{right_table}.{}", right_column.into());
                let mut constraint = CrossTableSumConstraint::new(left_col, right_col);
                if !ctx.group_by_columns.is_empty() {
                    constraint = constraint.group_by(ctx.group_by_columns.clone());
                }
                self.constraints.push(Arc::new(constraint));
            }
        }
        self
    }

    /// Intended to set the tolerance of the most recently added sum constraint.
    ///
    /// NOTE: currently a no-op and the `tolerance` argument is ignored.
    /// Constraints are stored as `Arc<dyn Constraint>` once added, and this
    /// builder does not reconfigure them afterwards. Kept for API compatibility.
    pub fn with_tolerance(self, _tolerance: f64) -> Self {
        self
    }

    /// Adds a [`TemporalOrderingConstraint`] on the current context's left
    /// table, checking `before_column` against `after_column`.
    ///
    /// Works with both single- and two-table contexts; in the latter only the
    /// left table is used. Does nothing if no context has been selected yet.
    pub fn ensure_ordering(
        mut self,
        before_column: impl Into<String>,
        after_column: impl Into<String>,
    ) -> Self {
        if let Some(ctx) = &self.current_context {
            let constraint = TemporalOrderingConstraint::new(&ctx.left_table)
                .before_after(before_column, after_column);
            self.constraints.push(Arc::new(constraint));
        }
        self
    }

    /// Adds a business-hours [`TemporalOrderingConstraint`] on the `timestamp`
    /// column of the current context's left table.
    ///
    /// Equivalent to [`Self::within_business_hours_on`] with `"timestamp"`.
    /// Does nothing if no context has been selected yet.
    pub fn within_business_hours(
        self,
        start_time: impl Into<String>,
        end_time: impl Into<String>,
    ) -> Self {
        self.within_business_hours_on("timestamp", start_time, end_time)
    }

    /// Adds a business-hours [`TemporalOrderingConstraint`] on `column` of the
    /// current context's left table, bounded by `start_time` and `end_time`.
    ///
    /// Does nothing if no context has been selected yet.
    pub fn within_business_hours_on(
        mut self,
        column: &str,
        start_time: impl Into<String>,
        end_time: impl Into<String>,
    ) -> Self {
        if let Some(ctx) = &self.current_context {
            let constraint = TemporalOrderingConstraint::new(&ctx.left_table).business_hours(
                column,
                start_time,
                end_time,
            );
            self.constraints.push(Arc::new(constraint));
        }
        self
    }

    /// Consumes the builder and produces a [`Check`].
    ///
    /// The check carries the configured name, level, optional description and
    /// all accumulated constraints, in insertion order.
    pub fn build(self) -> Check {
        let mut builder = Check::builder(self.name).level(self.level);
        if let Some(desc) = self.description {
            builder = builder.description(desc);
        }
        for constraint in self.constraints {
            builder = builder.arc_constraint(constraint);
        }
        builder.build()
    }
}
/// Extension trait that offers a [`MultiTableCheck`] constructor on the
/// implementing type (implemented below for [`Check`]).
pub trait CheckMultiTableExt {
    /// Returns a new [`MultiTableCheck`] builder for a check named `name`.
    fn multi_table(name: impl Into<String>) -> MultiTableCheck;
}
/// Enables `Check::multi_table(name)` as shorthand for [`MultiTableCheck::new`].
impl CheckMultiTableExt for Check {
    fn multi_table(name: impl Into<String>) -> MultiTableCheck {
        let owned_name: String = name.into();
        MultiTableCheck::new(owned_name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fluent_builder_basic() {
        let check = MultiTableCheck::new("test_validation")
            .level(Level::Warning)
            .description("Test multi-table validation")
            .validate_tables("orders", "customers")
            .join_on("customer_id", "id")
            .ensure_referential_integrity()
            .build();
        assert_eq!(check.name(), "test_validation");
        assert_eq!(check.level(), Level::Warning);
        assert_eq!(check.description(), Some("Test multi-table validation"));
        assert_eq!(check.constraints().len(), 1);
    }

    #[test]
    fn test_fluent_builder_complex() {
        let check = MultiTableCheck::new("complex_validation")
            .validate_tables("orders", "customers")
            .join_on("customer_id", "id")
            .ensure_referential_integrity()
            .expect_join_coverage(0.95)
            .and_validate_tables("orders", "payments")
            .join_on("order_id", "order_id")
            // Group-by columns are copied when the sum constraint is created,
            // so they must be configured before `ensure_sum_consistency`.
            .group_by("customer_id")
            .ensure_sum_consistency("total", "amount")
            .and_validate_temporal("events")
            .ensure_ordering("created_at", "processed_at")
            .build();
        assert_eq!(check.name(), "complex_validation");
        assert_eq!(check.constraints().len(), 4);
    }

    #[test]
    fn test_composite_keys() {
        let check = MultiTableCheck::new("composite_key_validation")
            .validate_tables("order_items", "products")
            .join_on_multiple(vec![("product_id", "id"), ("variant", "variant_code")])
            .ensure_referential_integrity()
            .build();
        assert_eq!(check.constraints().len(), 1);
    }

    #[test]
    fn test_extension_trait_constructor() {
        let check = Check::multi_table("ext_validation")
            .validate_tables("orders", "customers")
            .join_on("customer_id", "id")
            .ensure_referential_integrity()
            .build();
        assert_eq!(check.name(), "ext_validation");
        assert_eq!(check.constraints().len(), 1);
    }

    #[test]
    fn test_constraints_skipped_without_required_context() {
        let check = MultiTableCheck::new("context_required")
            // No context yet: everything below is dropped.
            .join_on("customer_id", "id")
            .ensure_referential_integrity()
            .expect_join_coverage(0.95)
            .ensure_ordering("created_at", "processed_at")
            // Two-table context without join columns: no foreign key.
            .validate_tables("orders", "customers")
            .ensure_referential_integrity()
            // Single-table context: two-table constraints are skipped.
            .validate_temporal("events")
            .join_on("a", "b")
            .ensure_referential_integrity()
            .expect_join_coverage(0.95)
            .ensure_sum_consistency("total", "amount")
            // Only this one applies.
            .ensure_ordering("created_at", "processed_at")
            .build();
        assert_eq!(check.constraints().len(), 1);
    }

    #[test]
    fn test_with_tolerance_adds_no_constraint() {
        let check = MultiTableCheck::new("tolerance_validation")
            .validate_tables("orders", "payments")
            .join_on("order_id", "order_id")
            .ensure_sum_consistency("total", "amount")
            .with_tolerance(0.01)
            .build();
        assert_eq!(check.constraints().len(), 1);
    }
}