use std::collections::HashMap;
use std::sync::Arc;
use uni_common::Value;
use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
use uni_cypher::locy_ast::ValidationMetric;
use uni_locy::{
CompiledValidate, FactRow, ValidationResult, accuracy, auc, brier_score, debiased_ece,
expected_calibration_error, log_loss,
};
/// Bin count handed to `expected_calibration_error` and `debiased_ece` when
/// those metrics are requested.
const ECE_BINS: usize = 10;
/// Failures reported while evaluating a VALIDATE command at runtime.
#[derive(Debug)]
pub enum ValidateRuntimeError {
    /// The named rule had no derived facts to validate.
    RuleNotDerived { rule_name: String },
    /// Joining predictions with labels yielded no pairs at all.
    EmptyDataset { rule_name: String },
    /// A KEY column could not be found on a row; `key` carries the text that
    /// is shown as the column name in the message.
    JoinKeysMissing { rule_name: String, key: String },
}

impl std::fmt::Display for ValidateRuntimeError {
    /// Writes the user-facing, `VALIDATE:`-prefixed description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the full message first, then emit it in a single write.
        let message = match self {
            Self::RuleNotDerived { rule_name } => format!(
                "VALIDATE: rule '{rule_name}' has no derived facts; \
                 ensure it appears in a stratum before VALIDATE"
            ),
            Self::EmptyDataset { rule_name } => format!(
                "VALIDATE: rule '{rule_name}' produced no \
                 (prediction, label) pairs (empty join)"
            ),
            Self::JoinKeysMissing { rule_name, key } => format!(
                "VALIDATE: rule '{rule_name}' KEY column '{key}' missing \
                 from either the rule's derived facts or the TARGET query rows"
            ),
        };
        f.write_str(&message)
    }
}

// No underlying cause is tracked, so the default `source()` is sufficient.
impl std::error::Error for ValidateRuntimeError {}
/// Builds the query that collects ground-truth rows for a VALIDATE command.
///
/// The resulting single-statement query is
/// `MATCH <cmd.pattern> [WHERE <cmd.where_expr>] RETURN <key columns...>, <cmd.target_expr>`,
/// where every key column is returned as a variable aliased to its own name
/// and the target expression is aliased to `__validate_target`.
pub fn validate_collection_query(
    cmd: &CompiledValidate,
    key_columns: &[String],
) -> uni_cypher::ast::Query {
    // One projection per key column, each aliased to the column's own name.
    let key_items = key_columns.iter().map(|name| ReturnItem::Expr {
        expr: Expr::Variable(name.clone()),
        alias: Some(name.clone()),
        source_text: None,
    });

    // The label expression always comes last, under a fixed alias.
    let target_item = ReturnItem::Expr {
        expr: cmd.target_expr.clone(),
        alias: Some("__validate_target".to_string()),
        source_text: None,
    };

    let items: Vec<ReturnItem> = key_items.chain(std::iter::once(target_item)).collect();

    let match_clause = Clause::Match(MatchClause {
        optional: false,
        pattern: cmd.pattern.clone(),
        where_clause: cmd.where_expr.clone(),
        for_update: false,
    });
    let return_clause = Clause::Return(ReturnClause {
        distinct: false,
        items,
        order_by: None,
        skip: None,
        limit: None,
    });

    uni_cypher::ast::Query::Single(Statement {
        clauses: vec![match_clause, return_clause],
    })
}
/// Coerces a TARGET value into a boolean label.
///
/// * `Bool` is used as-is.
/// * `Int` and `Float` are `true` when they compare unequal to zero (for
///   floats this includes NaN, because `NaN != 0.0` holds).
/// * `String` is `true` when non-empty.
/// * A missing value, `Null`, and every other variant are `false`.
fn target_to_label(v: Option<&Value>) -> bool {
    match v {
        None => false,
        Some(value) => match value {
            Value::Bool(flag) => *flag,
            Value::Int(n) => *n != 0,
            Value::Float(x) => *x != 0.0,
            Value::String(text) => !text.is_empty(),
            // `Null` and any non-scalar value count as a negative label.
            _ => false,
        },
    }
}
/// Renders a value as a tagged string so values can be compared as join keys.
///
/// Tags: `v:` for nodes (their `vid`) and integers, `e:` for edges (their
/// `eid`), `f:` for floats, `b:` for booleans, `s:` for strings. `Null`
/// becomes the literal `null`; any other variant uses its `Debug` rendering.
// NOTE(review): `Int` shares the `v:` prefix with `Node`, so an integer that
// prints the same as a node's `vid` produces the same key. Presumably this is
// intentional so id-valued columns join with node-valued ones; confirm.
fn canonical_key(v: &Value) -> String {
    match v {
        Value::Null => String::from("null"),
        Value::Bool(b) => format!("b:{b}"),
        Value::Int(i) => format!("v:{i}"),
        Value::Float(f) => format!("f:{f}"),
        Value::String(s) => format!("s:{s}"),
        Value::Node(node) => format!("v:{}", node.vid),
        Value::Edge(edge) => format!("e:{}", edge.eid),
        other => format!("{other:?}"),
    }
}
/// Builds the composite join key of `row` over `key_columns`.
///
/// Returns `None` as soon as one of the key columns is absent from the row.
/// With no key columns the key is the empty string.
///
/// Every component is written as `<byte length>:<canonical value>` and the
/// components are separated by `|`. The length prefix keeps the encoding
/// unambiguous: joining the raw components with `|` alone would let two
/// different multi-column keys collide whenever a component (for example a
/// string value) itself contains `|`.
fn join_key(row: &FactRow, key_columns: &[String]) -> Option<String> {
    let parts = key_columns
        .iter()
        .map(|col| {
            let part = canonical_key(row.get(col)?);
            Some(format!("{}:{part}", part.len()))
        })
        .collect::<Option<Vec<String>>>()?;
    Some(parts.join("|"))
}
/// Joins a rule's derived facts with TARGET rows and computes the requested metrics.
///
/// 1. `rule_facts` are indexed by their join key over `rule_key_columns`. The
///    prediction is the value of `cmd.prob_column` (`Float`, or `Int` converted
///    to `f64`) clamped to `[0.0, 1.0]`. Facts whose probability column is
///    missing or of another type are skipped; when several facts share a key
///    the last one wins.
/// 2. Each target row whose key matches an indexed fact contributes one
///    `(prediction, label)` pair, the label being derived from the row's
///    `__validate_target` column. Target rows without a match are dropped.
/// 3. Every metric listed in `cmd.metrics` is evaluated over the pairs.
///
/// # Errors
///
/// * [`ValidateRuntimeError::RuleNotDerived`] if `rule_facts` is empty.
/// * [`ValidateRuntimeError::JoinKeysMissing`] if any rule fact or target row
///   lacks one of the key columns.
/// * [`ValidateRuntimeError::EmptyDataset`] if no pair survives the join; a
///   diagnostic warning is logged first.
pub fn run_validate(
    cmd: &CompiledValidate,
    rule_key_columns: &[String],
    rule_facts: &[FactRow],
    target_rows: Vec<FactRow>,
) -> Result<ValidationResult, ValidateRuntimeError> {
    if rule_facts.is_empty() {
        return Err(ValidateRuntimeError::RuleNotDerived {
            rule_name: cmd.rule_name.clone(),
        });
    }

    // Shared by both join loops: the error raised when a row lacks a key column.
    let missing_keys = || ValidateRuntimeError::JoinKeysMissing {
        rule_name: cmd.rule_name.clone(),
        key: rule_key_columns.join(","),
    };

    // Index predictions by join key.
    let mut by_key: HashMap<String, f64> = HashMap::with_capacity(rule_facts.len());
    for row in rule_facts {
        // The key is resolved before the probability so that a missing key
        // column is reported even for rows that would be skipped below.
        let key = join_key(row, rule_key_columns).ok_or_else(&missing_keys)?;
        let prob = match row.get(&cmd.prob_column) {
            Some(Value::Float(f)) => *f,
            Some(Value::Int(i)) => *i as f64,
            _ => continue,
        };
        // NOTE(review): `f64::clamp` leaves NaN as NaN, so a NaN probability
        // reaches the metric functions unchanged.
        by_key.insert(key, prob.clamp(0.0, 1.0));
    }

    // Pair every joinable target row with its prediction.
    let mut preds: Vec<f64> = Vec::new();
    let mut labels: Vec<bool> = Vec::new();
    for row in &target_rows {
        let key = join_key(row, rule_key_columns).ok_or_else(&missing_keys)?;
        if let Some(&pred) = by_key.get(&key) {
            preds.push(pred);
            labels.push(target_to_label(row.get("__validate_target")));
        }
    }

    if preds.is_empty() {
        let rule_sample = rule_facts
            .first()
            .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
        let target_sample = target_rows
            .first()
            .map(|r| r.keys().cloned().collect::<Vec<_>>().join(","));
        // `first()` instead of `[0]`: the key-column slice may be empty, and a
        // diagnostic log line must never panic.
        let first_key = rule_key_columns.first();
        let rule_key_sample = first_key.and_then(|k| rule_facts.first()?.get(k).cloned());
        let target_key_sample = first_key.and_then(|k| target_rows.first()?.get(k).cloned());
        tracing::warn!(
            "VALIDATE empty join for rule '{}'. rule_facts={}, target_rows={}, \
             rule_cols={:?}, target_cols={:?}, key_columns={:?}, \
             rule_key_sample={:?}, target_key_sample={:?}",
            cmd.rule_name,
            rule_facts.len(),
            target_rows.len(),
            rule_sample,
            target_sample,
            rule_key_columns,
            rule_key_sample,
            target_key_sample,
        );
        return Err(ValidateRuntimeError::EmptyDataset {
            rule_name: cmd.rule_name.clone(),
        });
    }

    let metrics: Vec<(ValidationMetric, f64)> = cmd
        .metrics
        .iter()
        .map(|metric| {
            let score = match metric {
                ValidationMetric::BrierScore => brier_score(&preds, &labels),
                ValidationMetric::LogLoss => log_loss(&preds, &labels),
                ValidationMetric::Ece => expected_calibration_error(&preds, &labels, ECE_BINS),
                ValidationMetric::DebiasedEce => debiased_ece(&preds, &labels, ECE_BINS),
                ValidationMetric::Accuracy => accuracy(&preds, &labels),
                ValidationMetric::Auc => auc(&preds, &labels),
            };
            (*metric, score)
        })
        .collect();

    Ok(ValidationResult {
        rule_name: cmd.rule_name.clone(),
        prob_column: cmd.prob_column.clone(),
        n_samples: preds.len(),
        metrics,
    })
}
/// Wraps a [`ValidateRuntimeError`] in an `Arc`, erased to a
/// `dyn Error + Send + Sync` trait object.
pub fn into_arc_error(e: ValidateRuntimeError) -> Arc<dyn std::error::Error + Send + Sync> {
    Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::Pattern;

    /// Builds a `FactRow` from `(column, value)` pairs.
    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        let mut entries = Vec::with_capacity(pairs.len());
        for (name, value) in pairs {
            entries.push((name.to_string(), value.clone()));
        }
        entries.into_iter().collect()
    }

    /// A VALIDATE command for rule `risky` scoring the `risk` column with
    /// Brier score, accuracy and AUC.
    fn risky_cmd() -> CompiledValidate {
        let metrics = vec![
            ValidationMetric::BrierScore,
            ValidationMetric::Accuracy,
            ValidationMetric::Auc,
        ];
        CompiledValidate {
            rule_name: "risky".into(),
            pattern: Pattern { paths: vec![] },
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            metrics,
            prob_column: "risk".into(),
        }
    }

    #[test]
    fn validate_joins_facts_with_target_rows() {
        let cmd = risky_cmd();
        let keys = vec!["s".to_string()];

        // Confident predictions that all agree with the labels below.
        let rule_facts: Vec<FactRow> = vec![(1, 0.9), (2, 0.1), (3, 0.8), (4, 0.2)]
            .into_iter()
            .map(|(id, risk)| fact_row(&[("s", Value::Int(id)), ("risk", Value::Float(risk))]))
            .collect();
        let target_rows: Vec<FactRow> = vec![(1, true), (2, false), (3, true), (4, false)]
            .into_iter()
            .map(|(id, label)| {
                fact_row(&[
                    ("s", Value::Int(id)),
                    ("__validate_target", Value::Bool(label)),
                ])
            })
            .collect();

        let res = run_validate(&cmd, &keys, &rule_facts, target_rows).unwrap();

        assert_eq!(res.n_samples, 4);
        let brier = res.metric(ValidationMetric::BrierScore).unwrap();
        assert!(brier < 0.05, "expected small Brier, got {brier}");
        let acc = res.metric(ValidationMetric::Accuracy).unwrap();
        assert_eq!(acc, 1.0);
        let a = res.metric(ValidationMetric::Auc).unwrap();
        assert!((a - 1.0).abs() < 1e-12);
    }

    #[test]
    fn validate_drops_unjoinable_rows() {
        let cmd = risky_cmd();
        let keys = vec!["s".to_string()];
        let prediction = fact_row(&[("s", Value::Int(1)), ("risk", Value::Float(0.9))]);
        // Key 99 has no matching prediction, so the join comes out empty.
        let unmatched = fact_row(&[
            ("s", Value::Int(99)),
            ("__validate_target", Value::Bool(true)),
        ]);

        let outcome = run_validate(&cmd, &keys, &[prediction], vec![unmatched]);

        let err = outcome.unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::EmptyDataset { .. }));
    }

    #[test]
    fn validate_errors_on_no_rule_facts() {
        let cmd = risky_cmd();
        let keys = vec!["s".to_string()];

        let outcome = run_validate(&cmd, &keys, &[], Vec::new());

        let err = outcome.unwrap_err();
        assert!(matches!(err, ValidateRuntimeError::RuleNotDerived { .. }));
    }
}