use crate::analyzers::suggestions::{ConstraintParameter, SuggestionPriority};
use crate::core::Check;
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{info, instrument, warn};
/// Heuristic analyzer that suggests cross-table constraints (foreign keys, join coverage,
/// temporal ordering, sum consistency) for the tables registered in a DataFusion
/// [`SessionContext`].
///
/// Only table schemas are inspected: column names are matched against naming heuristics and
/// Arrow data types are checked; no table data is read. Suggestions should therefore be
/// reviewed before being enforced.
pub struct SchemaAnalyzer<'a> {
    /// Session whose catalog is enumerated to find the tables to analyze.
    ctx: &'a SessionContext,
    /// Name heuristics used to classify columns (foreign keys, timestamps, amounts, ...).
    naming_patterns: NamingPatterns,
}
/// Column-name heuristics used by `SchemaAnalyzer` to classify columns.
#[derive(Debug, Clone)]
struct NamingPatterns {
    /// Suffixes that mark a column as a foreign key; the part of the name before the suffix
    /// is matched against table names.
    foreign_key_suffixes: Vec<String>,
    /// Substrings that mark a column as temporal even when its Arrow type is not a
    /// date/time type (matched against the lowercased column name).
    temporal_patterns: Vec<String>,
    /// Substrings that mark a floating-point or decimal column as a monetary amount
    /// (matched against the lowercased column name).
    amount_patterns: Vec<String>,
    /// Substrings that mark an integer or floating-point column as a quantity
    /// (matched against the lowercased column name).
    quantity_patterns: Vec<String>,
}
impl Default for NamingPatterns {
    /// Builds the stock English naming heuristics: common foreign-key suffixes, lifecycle
    /// timestamp words, and typical money / quantity column names.
    fn default() -> Self {
        /// Converts a list of string literals into owned `String`s, preserving order.
        fn owned(words: &[&str]) -> Vec<String> {
            words.iter().map(|word| (*word).to_string()).collect()
        }

        Self {
            foreign_key_suffixes: owned(&["_id", "_key", "_fk", "_ref"]),
            temporal_patterns: owned(&[
                "_at",
                "_date",
                "_time",
                "_timestamp",
                "created",
                "updated",
                "modified",
                "processed",
                "completed",
            ]),
            amount_patterns: owned(&[
                "amount", "total", "price", "cost", "payment", "revenue", "balance",
            ]),
            quantity_patterns: owned(&["quantity", "qty", "count", "units", "items"]),
        }
    }
}
/// A constraint suggested by `SchemaAnalyzer`, possibly spanning several tables.
///
/// Derives `Serialize`/`Deserialize`, so suggestions can be persisted or loaded from outside
/// this module before being converted with `SchemaAnalyzer::suggestions_to_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossTableSuggestion {
    /// Kind of constraint. `SchemaAnalyzer` emits `foreign_key`, `join_coverage`,
    /// `temporal_ordering`, `business_hours` and `cross_table_sum`;
    /// `suggestions_to_check` converts all of these except `business_hours`.
    pub constraint_type: String,
    /// Tables involved. For two-table constraints the first entry is the referencing
    /// (child / left) table and the second the referenced (parent / right) table.
    pub tables: Vec<String>,
    /// Columns involved, keyed by table name. `suggestions_to_check` uses the first column
    /// listed for each table.
    pub columns: HashMap<String, Vec<String>>,
    /// Heuristic confidence; values produced by `SchemaAnalyzer` lie between 0.5 and 1.0.
    pub confidence: f64,
    /// Human-readable explanation of why the constraint was suggested.
    pub rationale: String,
    /// Suggested priority; `analyze_all_tables` lists `Critical` suggestions first.
    pub priority: SuggestionPriority,
    /// Constraint-specific settings, e.g. `tolerance` for `cross_table_sum` or
    /// `expected_coverage` for `join_coverage`.
    pub parameters: HashMap<String, ConstraintParameter>,
}
impl<'a> SchemaAnalyzer<'a> {
pub fn new(ctx: &'a SessionContext) -> Self {
Self {
ctx,
naming_patterns: NamingPatterns::default(),
}
}
#[instrument(skip(self))]
pub async fn analyze_all_tables(&self) -> crate::error::Result<Vec<CrossTableSuggestion>> {
let mut suggestions = Vec::new();
let catalog = self.ctx.catalog("datafusion").unwrap();
let schema = catalog.schema("public").unwrap();
let table_names: Vec<String> = schema.table_names();
info!(
"Analyzing {} tables for constraint suggestions",
table_names.len()
);
let mut table_schemas = HashMap::new();
for table_name in &table_names {
if let Ok(Some(table)) = schema.table(table_name).await {
let schema = table.schema();
table_schemas.insert(table_name.clone(), schema);
}
}
suggestions.extend(self.analyze_foreign_keys(&table_schemas));
suggestions.extend(self.analyze_temporal_constraints(&table_schemas));
suggestions.extend(self.analyze_financial_consistency(&table_schemas));
suggestions.extend(self.analyze_join_coverage(&table_schemas));
suggestions.sort_by(|a, b| match (&a.priority, &b.priority) {
(SuggestionPriority::Critical, SuggestionPriority::Critical) => {
b.confidence.partial_cmp(&a.confidence).unwrap()
}
(SuggestionPriority::Critical, _) => std::cmp::Ordering::Less,
(_, SuggestionPriority::Critical) => std::cmp::Ordering::Greater,
_ => b.confidence.partial_cmp(&a.confidence).unwrap(),
});
Ok(suggestions)
}
fn analyze_foreign_keys(
&self,
schemas: &HashMap<String, Arc<Schema>>,
) -> Vec<CrossTableSuggestion> {
let mut suggestions = Vec::new();
for (table_name, schema) in schemas {
for field in schema.fields() {
if let Some(referenced_table) = self.detect_foreign_key(field.name(), schemas) {
if let Some(ref_schema) = schemas.get(&referenced_table) {
let ref_column =
self.infer_primary_key_column(&referenced_table, ref_schema);
let mut columns = HashMap::new();
columns.insert(table_name.clone(), vec![field.name().to_string()]);
columns.insert(referenced_table.clone(), vec![ref_column.clone()]);
suggestions.push(CrossTableSuggestion {
constraint_type: "foreign_key".to_string(),
tables: vec![table_name.clone(), referenced_table.clone()],
columns,
confidence: self.calculate_fk_confidence(field.name(), &referenced_table),
rationale: format!(
"Column '{}' in '{table_name}' appears to reference '{referenced_table}' based on naming convention",
field.name()
),
priority: SuggestionPriority::High,
parameters: HashMap::new(),
});
}
}
}
}
suggestions
}
fn detect_foreign_key(
&self,
column_name: &str,
schemas: &HashMap<String, Arc<Schema>>,
) -> Option<String> {
for suffix in &self.naming_patterns.foreign_key_suffixes {
if column_name.ends_with(suffix) {
let base_name = &column_name[..column_name.len() - suffix.len()];
for table_name in schemas.keys() {
if self.matches_table_name(base_name, table_name) {
return Some(table_name.clone());
}
}
}
}
None
}
fn matches_table_name(&self, base_name: &str, table_name: &str) -> bool {
if base_name == table_name {
return true;
}
if format!("{base_name}s") == table_name {
return true;
}
if base_name == format!("{table_name}s") {
return true;
}
if base_name.ends_with('y')
&& table_name == format!("{}ies", &base_name[..base_name.len() - 1])
{
return true;
}
false
}
fn infer_primary_key_column(&self, table_name: &str, schema: &Arc<Schema>) -> String {
let table_id = format!("{table_name}_id");
let table_key = format!("{table_name}_key");
let common_pk_names = vec!["id", table_id.as_str(), "key", table_key.as_str()];
for field in schema.fields() {
for pk_name in &common_pk_names {
if field.name().to_lowercase() == pk_name.to_lowercase() {
return field.name().to_string();
}
}
}
"id".to_string()
}
fn calculate_fk_confidence(&self, column_name: &str, referenced_table: &str) -> f64 {
let mut confidence: f64 = 0.5;
if column_name.contains(referenced_table)
|| column_name.contains(&referenced_table[..referenced_table.len().saturating_sub(1)])
{
confidence += 0.3;
}
if column_name.ends_with("_id") {
confidence += 0.2;
}
confidence.min(1.0)
}
fn analyze_temporal_constraints(
&self,
schemas: &HashMap<String, Arc<Schema>>,
) -> Vec<CrossTableSuggestion> {
let mut suggestions = Vec::new();
for (table_name, schema) in schemas {
let temporal_columns = self.find_temporal_columns(schema);
if temporal_columns.len() >= 2 {
for i in 0..temporal_columns.len() {
for j in i + 1..temporal_columns.len() {
let col1 = &temporal_columns[i];
let col2 = &temporal_columns[j];
let (before, after) = self.infer_temporal_order(col1, col2);
let mut columns = HashMap::new();
columns.insert(table_name.clone(), vec![before.clone(), after.clone()]);
let mut parameters = HashMap::new();
parameters.insert(
"validation_type".to_string(),
ConstraintParameter::String("before_after".to_string()),
);
suggestions.push(CrossTableSuggestion {
constraint_type: "temporal_ordering".to_string(),
tables: vec![table_name.clone()],
columns,
confidence: 0.8,
rationale: format!(
"Columns '{before}' and '{after}' appear to have a temporal relationship"
),
priority: SuggestionPriority::Medium,
parameters,
});
}
}
}
for col in &temporal_columns {
if col.contains("transaction") || col.contains("order") || col.contains("payment") {
let mut columns = HashMap::new();
columns.insert(table_name.clone(), vec![col.clone()]);
let mut parameters = HashMap::new();
parameters.insert(
"start_time".to_string(),
ConstraintParameter::String("09:00".to_string()),
);
parameters.insert(
"end_time".to_string(),
ConstraintParameter::String("17:00".to_string()),
);
suggestions.push(CrossTableSuggestion {
constraint_type: "business_hours".to_string(),
tables: vec![table_name.clone()],
columns,
confidence: 0.6,
rationale: format!(
"Column '{col}' may benefit from business hours validation"
),
priority: SuggestionPriority::Low,
parameters,
});
}
}
}
suggestions
}
fn find_temporal_columns(&self, schema: &Arc<Schema>) -> Vec<String> {
let mut temporal_columns = Vec::new();
for field in schema.fields() {
let is_temporal_type = matches!(
field.data_type(),
DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
| DataType::Time32(_)
| DataType::Time64(_)
);
let matches_pattern = self
.naming_patterns
.temporal_patterns
.iter()
.any(|pattern| field.name().to_lowercase().contains(pattern));
if is_temporal_type || matches_pattern {
temporal_columns.push(field.name().to_string());
}
}
temporal_columns
}
fn infer_temporal_order(&self, col1: &str, col2: &str) -> (String, String) {
let order_keywords = vec![
("created", 0),
("started", 1),
("updated", 2),
("modified", 2),
("processed", 3),
("completed", 4),
("finished", 4),
("ended", 5),
];
let get_order = |col: &str| -> i32 {
for (keyword, order) in &order_keywords {
if col.to_lowercase().contains(keyword) {
return *order;
}
}
100 };
let order1 = get_order(col1);
let order2 = get_order(col2);
if order1 <= order2 {
(col1.to_string(), col2.to_string())
} else {
(col2.to_string(), col1.to_string())
}
}
fn analyze_financial_consistency(
&self,
schemas: &HashMap<String, Arc<Schema>>,
) -> Vec<CrossTableSuggestion> {
let mut suggestions = Vec::new();
let mut amount_columns: HashMap<String, Vec<String>> = HashMap::new();
let mut quantity_columns: HashMap<String, Vec<String>> = HashMap::new();
for (table_name, schema) in schemas {
for field in schema.fields() {
if self.is_amount_column(field.name(), field.data_type()) {
amount_columns
.entry(table_name.clone())
.or_default()
.push(field.name().to_string());
}
if self.is_quantity_column(field.name(), field.data_type()) {
quantity_columns
.entry(table_name.clone())
.or_default()
.push(field.name().to_string());
}
}
}
for (table1, cols1) in &amount_columns {
for (table2, cols2) in &amount_columns {
if table1 < table2 && self.are_tables_related(table1, table2, schemas) {
for col1 in cols1 {
for col2 in cols2 {
if self.are_columns_likely_related(col1, col2) {
let mut columns = HashMap::new();
columns.insert(table1.clone(), vec![col1.clone()]);
columns.insert(table2.clone(), vec![col2.clone()]);
let mut parameters = HashMap::new();
parameters.insert(
"tolerance".to_string(),
ConstraintParameter::Float(0.01),
);
suggestions.push(CrossTableSuggestion {
constraint_type: "cross_table_sum".to_string(),
tables: vec![table1.clone(), table2.clone()],
columns,
confidence: 0.7,
rationale: format!(
"Financial columns '{table1}.{col1}' and '{table2}.{col2}' may need sum consistency validation"
),
priority: SuggestionPriority::High,
parameters,
});
}
}
}
}
}
}
suggestions
}
fn is_amount_column(&self, name: &str, data_type: &DataType) -> bool {
let is_numeric = matches!(
data_type,
DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
);
if !is_numeric {
return false;
}
self.naming_patterns
.amount_patterns
.iter()
.any(|pattern| name.to_lowercase().contains(pattern))
}
fn is_quantity_column(&self, name: &str, data_type: &DataType) -> bool {
let is_numeric = matches!(
data_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Float32
| DataType::Float64
);
if !is_numeric {
return false;
}
self.naming_patterns
.quantity_patterns
.iter()
.any(|pattern| name.to_lowercase().contains(pattern))
}
fn are_tables_related(
&self,
table1: &str,
table2: &str,
schemas: &HashMap<String, Arc<Schema>>,
) -> bool {
if let Some(schema1) = schemas.get(table1) {
for field in schema1.fields() {
if let Some(ref_table) = self.detect_foreign_key(field.name(), schemas) {
if ref_table == table2 {
return true;
}
}
}
}
if let Some(schema2) = schemas.get(table2) {
for field in schema2.fields() {
if let Some(ref_table) = self.detect_foreign_key(field.name(), schemas) {
if ref_table == table1 {
return true;
}
}
}
}
table1.contains(table2) || table2.contains(table1)
}
fn are_columns_likely_related(&self, col1: &str, col2: &str) -> bool {
if col1 == col2 {
return true;
}
let keywords = vec!["total", "amount", "sum", "payment", "cost", "price"];
for keyword in keywords {
if col1.contains(keyword) && col2.contains(keyword) {
return true;
}
}
false
}
fn analyze_join_coverage(
&self,
schemas: &HashMap<String, Arc<Schema>>,
) -> Vec<CrossTableSuggestion> {
let mut suggestions = Vec::new();
for (table_name, schema) in schemas {
for field in schema.fields() {
if let Some(referenced_table) = self.detect_foreign_key(field.name(), schemas) {
let mut columns = HashMap::new();
columns.insert(table_name.clone(), vec![field.name().to_string()]);
columns.insert(referenced_table.clone(), vec!["id".to_string()]);
let mut parameters = HashMap::new();
parameters.insert(
"expected_coverage".to_string(),
ConstraintParameter::Float(0.95),
);
suggestions.push(CrossTableSuggestion {
constraint_type: "join_coverage".to_string(),
tables: vec![table_name.clone(), referenced_table.clone()],
columns,
confidence: 0.75,
rationale: format!(
"Join between '{table_name}' and '{referenced_table}' should have high coverage for data quality"
),
priority: SuggestionPriority::Medium,
parameters,
});
}
}
}
suggestions
}
pub fn suggestions_to_check(
&self,
suggestions: &[CrossTableSuggestion],
check_name: &str,
) -> Check {
let mut builder = Check::builder(check_name);
for suggestion in suggestions {
match suggestion.constraint_type.as_str() {
"foreign_key" => {
if suggestion.tables.len() == 2 {
let child_col = format!(
"{}.{}",
suggestion.tables[0], suggestion.columns[&suggestion.tables[0]][0]
);
let parent_col = format!(
"{}.{}",
suggestion.tables[1], suggestion.columns[&suggestion.tables[1]][0]
);
builder = builder.foreign_key(child_col, parent_col);
}
}
"cross_table_sum" => {
if suggestion.tables.len() == 2 {
let left_col = format!(
"{}.{}",
suggestion.tables[0], suggestion.columns[&suggestion.tables[0]][0]
);
let right_col = format!(
"{}.{}",
suggestion.tables[1], suggestion.columns[&suggestion.tables[1]][0]
);
builder = builder.cross_table_sum(left_col, right_col);
}
}
"join_coverage" => {
if suggestion.tables.len() == 2 {
builder =
builder.join_coverage(&suggestion.tables[0], &suggestion.tables[1]);
}
}
"temporal_ordering" => {
if !suggestion.tables.is_empty() {
builder = builder.temporal_ordering(&suggestion.tables[0]);
}
}
_ => {}
}
}
builder.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::{Field, Schema as ArrowSchema, TimeUnit};

    /// Non-nullable Arrow field with the given name and type.
    fn required(name: &str, data_type: DataType) -> Field {
        Field::new(name, data_type, false)
    }

    /// Non-nullable microsecond timestamp field without a time zone.
    fn timestamp_micros(name: &str) -> Field {
        required(name, DataType::Timestamp(TimeUnit::Microsecond, None))
    }

    #[test]
    fn test_foreign_key_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        let orders = ArrowSchema::new(vec![
            required("id", DataType::Int64),
            required("customer_id", DataType::Int64),
            required("total", DataType::Float64),
        ]);
        let customers = ArrowSchema::new(vec![
            required("id", DataType::Int64),
            required("name", DataType::Utf8),
        ]);
        let schemas = HashMap::from([
            ("orders".to_string(), Arc::new(orders)),
            ("customers".to_string(), Arc::new(customers)),
        ]);

        let suggestions = analyzer.analyze_foreign_keys(&schemas);
        let first = suggestions
            .first()
            .expect("orders.customer_id should yield a foreign-key suggestion");
        assert_eq!(first.constraint_type, "foreign_key");
        for table in ["orders", "customers"] {
            assert!(
                first.tables.iter().any(|t| t == table),
                "missing table {table}"
            );
        }
    }

    #[test]
    fn test_temporal_column_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);
        let schema = Arc::new(ArrowSchema::new(vec![
            required("id", DataType::Int64),
            timestamp_micros("created_at"),
            timestamp_micros("updated_at"),
            required("name", DataType::Utf8),
        ]));

        let found = analyzer.find_temporal_columns(&schema);
        assert_eq!(found.len(), 2);
        for expected in ["created_at", "updated_at"] {
            assert!(found.iter().any(|c| c == expected), "missing {expected}");
        }
    }

    #[test]
    fn test_temporal_ordering() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);
        let cases = [
            (("created_at", "updated_at"), ("created_at", "updated_at")),
            (("processed_at", "created_at"), ("created_at", "processed_at")),
        ];
        for ((col1, col2), (before, after)) in cases {
            let (got_before, got_after) = analyzer.infer_temporal_order(col1, col2);
            assert_eq!(got_before, before);
            assert_eq!(got_after, after);
        }
    }

    #[test]
    fn test_amount_column_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);
        let cases = [
            ("total_amount", DataType::Float64, true),
            ("price", DataType::Decimal128(10, 2), true),
            ("customer_id", DataType::Int64, false),
            ("total", DataType::Utf8, false),
        ];
        for (name, data_type, expected) in cases {
            assert_eq!(
                analyzer.is_amount_column(name, &data_type),
                expected,
                "column {name}"
            );
        }
    }
}