use std::collections::HashMap;
use std::sync::Arc;
use uni_common::Value;
use uni_cypher::ast::{Clause, Expr, MatchClause, ReturnClause, ReturnItem, Statement};
use uni_cypher::locy_ast::CalibrationMethod;
use uni_locy::{
BetaFitter, CalibrationMethodKind, CalibrationResult, CalibratorFitter, ClassifierRegistry,
ClassifyInput, CompiledCalibrate, CompiledModel, FactRow, FeatureValue, IdentityCalibrator,
IsotonicFitter, NeuralClassifier, PlattFitter, TemperatureFitter, brier_score,
expected_calibration_error,
};
/// Value passed as the last argument of `expected_calibration_error` for both
/// the raw and the calibrated holdout scores (presumably the number of bins).
const ECE_BINS: usize = 10;
/// Failures that can be reported while executing a `CALIBRATE` statement.
///
/// The `Display` output of every variant is prefixed with `CALIBRATE:` and
/// names the model involved.
#[derive(Debug)]
pub enum CalibrateRuntimeError {
    /// No classifier is registered under the model's name.
    ClassifierMissing {
        /// Name of the model whose classifier was looked up.
        model_name: String,
    },
    /// The model name is absent from the compiled model catalog.
    UnknownModelInCatalog {
        /// Name of the model that was looked up.
        model_name: String,
    },
    /// The collected data set contained no rows.
    EmptyDataset {
        /// Name of the model being calibrated.
        model_name: String,
    },
    /// One of the train/holdout splits ended up empty.
    InsufficientData {
        /// Name of the model being calibrated.
        model_name: String,
        /// Number of samples in the training split.
        train: usize,
        /// Number of samples in the holdout split.
        holdout: usize,
    },
    /// A lower-level failure, carried as its rendered message.
    FitFailure {
        /// Name of the model being calibrated.
        model_name: String,
        /// Text of the underlying error.
        message: String,
    },
}

impl std::fmt::Display for CalibrateRuntimeError {
    /// Renders the user-facing message for each variant.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ClassifierMissing { model_name } => write!(
                out,
                "CALIBRATE: classifier '{model_name}' not registered; add it to \
                 LocyConfig::classifier_registry before evaluating"
            ),
            Self::UnknownModelInCatalog { model_name } => write!(
                out,
                "CALIBRATE: model '{model_name}' not in CompiledProgram.model_catalog \
                 (compiler should have rejected this earlier)"
            ),
            Self::EmptyDataset { model_name } => write!(
                out,
                "CALIBRATE: model '{model_name}' MATCH pattern produced zero rows"
            ),
            Self::InsufficientData {
                model_name,
                train,
                holdout,
            } => write!(
                out,
                "CALIBRATE: model '{}' needs at least 1 sample in each split \
                 (got train={}, holdout={}); increase the data set or pick a \
                 different HOLDOUT fraction",
                model_name, train, holdout
            ),
            Self::FitFailure {
                model_name,
                message,
            } => write!(
                out,
                "CALIBRATE: model '{}' fitter error: {}",
                model_name, message
            ),
        }
    }
}

impl std::error::Error for CalibrateRuntimeError {}
/// Builds the single-statement `MATCH ... RETURN ...` query that collects the
/// rows a `CALIBRATE` command operates on.
///
/// The `RETURN` clause projects every model input variable under its own name,
/// followed by the command's target expression aliased as
/// `__calibrate_target`. The `MATCH` clause reuses the command's pattern and
/// optional `WHERE` expression.
fn build_collection_query(
    cmd: &CompiledCalibrate,
    model: &CompiledModel,
) -> uni_cypher::ast::Query {
    // One projected column per model input, aliased to the variable itself.
    let input_columns = model.inputs.iter().map(|input| ReturnItem::Expr {
        expr: Expr::Variable(input.variable.clone()),
        alias: Some(input.variable.clone()),
        source_text: None,
    });

    // The label column always comes last.
    let target_column = ReturnItem::Expr {
        expr: cmd.target_expr.clone(),
        alias: Some(String::from("__calibrate_target")),
        source_text: None,
    };

    let projection: Vec<ReturnItem> = input_columns
        .chain(std::iter::once(target_column))
        .collect();

    let match_clause = Clause::Match(MatchClause {
        optional: false,
        pattern: cmd.pattern.clone(),
        where_clause: cmd.where_expr.clone(),
        for_update: false,
    });
    let return_clause = Clause::Return(ReturnClause {
        distinct: false,
        items: projection,
        order_by: None,
        skip: None,
        limit: None,
    });

    uni_cypher::ast::Query::Single(Statement {
        clauses: vec![match_clause, return_clause],
    })
}
/// Converts the column `name` of `row` into a classifier feature.
///
/// Float, integer, string and boolean values are carried over as the matching
/// `FeatureValue` variant. A missing column, an explicit null, and every other
/// `Value` variant all become `FeatureValue::Null`.
fn row_to_feature(row: &FactRow, name: &str) -> FeatureValue {
    row.get(name).map_or(FeatureValue::Null, |value| match value {
        Value::Float(number) => FeatureValue::Float(*number),
        Value::Int(number) => FeatureValue::Int(*number),
        Value::String(text) => FeatureValue::String(text.clone()),
        Value::Bool(flag) => FeatureValue::Bool(*flag),
        _ => FeatureValue::Null,
    })
}
/// Interprets the target column value as a binary label.
///
/// Booleans map to themselves, integers and floats are `true` when they differ
/// from zero, and strings are `true` when non-empty. A missing value, a null,
/// and any other `Value` variant yield `false`.
fn target_to_label(v: Option<&Value>) -> bool {
    v.map_or(false, |value| match value {
        Value::Bool(flag) => *flag,
        Value::Int(number) => *number != 0,
        Value::Float(number) => *number != 0.0,
        Value::String(text) => !text.is_empty(),
        _ => false,
    })
}
/// Fits the calibrator selected by `method` on the given predictions and
/// labels.
///
/// `CalibrationMethod::None` yields an `IdentityCalibrator` without fitting.
/// `CalibrationMethod::Dirichlet` is always rejected because this entry point
/// is binary-only.
///
/// # Errors
///
/// Any fitter error (including the Dirichlet rejection) is rendered to text
/// and returned as `CalibrateRuntimeError::FitFailure` tagged with
/// `model_name`.
fn fit_method(
    method: CalibrationMethod,
    preds: &[f64],
    labels: &[bool],
    model_name: &str,
) -> Result<Arc<dyn uni_locy::Calibrator>, CalibrateRuntimeError> {
    const DIRICHLET_UNSUPPORTED: &str =
        "Dirichlet is multi-class; the binary CALIBRATE statement \
         cannot fit it. Use `uni_locy::calibration::DirichletFitter` \
         directly until the multi-class CALIBRATE surface form ships.";

    let fitted = match method {
        CalibrationMethod::PlattScaling => PlattFitter.fit(preds, labels),
        CalibrationMethod::IsotonicRegression => IsotonicFitter.fit(preds, labels),
        CalibrationMethod::TemperatureScaling => TemperatureFitter.fit(preds, labels),
        CalibrationMethod::BetaCalibration => BetaFitter.fit(preds, labels),
        CalibrationMethod::Conformal { alpha } => {
            let fitter = uni_locy::calibration::ConformalFitter { alpha };
            fitter.fit(preds, labels)
        }
        CalibrationMethod::None => {
            Ok(Arc::new(IdentityCalibrator) as Arc<dyn uni_locy::Calibrator>)
        }
        CalibrationMethod::Dirichlet => Err(
            uni_locy::calibration::CalibrationError::NumericIssue(DIRICHLET_UNSUPPORTED),
        ),
    };

    match fitted {
        Ok(calibrator) => Ok(calibrator),
        Err(cause) => Err(CalibrateRuntimeError::FitFailure {
            model_name: model_name.to_owned(),
            message: cause.to_string(),
        }),
    }
}
/// Maps the AST-level calibration method onto the `CalibrationMethodKind`
/// recorded in a `CalibrationResult`.
///
/// Any parameters carried by the AST variant (such as the conformal `alpha`)
/// are dropped; `CalibrationMethod::None` is reported as `Identity`.
fn method_kind(method: CalibrationMethod) -> CalibrationMethodKind {
    match method {
        CalibrationMethod::None => CalibrationMethodKind::Identity,
        CalibrationMethod::PlattScaling => CalibrationMethodKind::Platt,
        CalibrationMethod::TemperatureScaling => CalibrationMethodKind::Temperature,
        CalibrationMethod::IsotonicRegression => CalibrationMethodKind::Isotonic,
        CalibrationMethod::BetaCalibration => CalibrationMethodKind::Beta,
        CalibrationMethod::Dirichlet => CalibrationMethodKind::Dirichlet,
        CalibrationMethod::Conformal { .. } => CalibrationMethodKind::Conformal,
    }
}
/// Executes a compiled `CALIBRATE` command over already-collected `rows`.
///
/// The function:
/// 1. resolves the model in `model_catalog` and its classifier in
///    `classifier_registry`, both keyed by `cmd.model_name`;
/// 2. turns each row into a `ClassifyInput` (one feature per model input
///    variable) and a boolean label read from the `__calibrate_target` column;
/// 3. scores all inputs with the classifier in a single `classify` call;
/// 4. splits the scores in row order: the first
///    `ceil(n * cmd.holdout)` samples (at least one, at most `n`) form the
///    holdout set, the remainder the training set;
/// 5. fits the requested calibrator on the training split and reports Brier
///    score and expected calibration error on the holdout split, before and
///    after calibration.
///
/// # Errors
///
/// * `UnknownModelInCatalog` – the model name is not in `model_catalog`.
/// * `ClassifierMissing` – no classifier is registered for the model.
/// * `EmptyDataset` – `rows` is empty.
/// * `FitFailure` – the classifier call failed, returned a different number of
///   predictions than inputs, or the calibrator fitter failed.
/// * `InsufficientData` – either split is empty after partitioning.
pub async fn run_calibrate(
    cmd: &CompiledCalibrate,
    model_catalog: &HashMap<String, CompiledModel>,
    classifier_registry: &Arc<ClassifierRegistry>,
    rows: Vec<FactRow>,
) -> Result<CalibrationResult, CalibrateRuntimeError> {
    let model = match model_catalog.get(&cmd.model_name) {
        Some(found) => found,
        None => {
            return Err(CalibrateRuntimeError::UnknownModelInCatalog {
                model_name: cmd.model_name.clone(),
            });
        }
    };

    let classifier: Arc<dyn NeuralClassifier> = match classifier_registry.get(&cmd.model_name) {
        Some(registered) => registered.clone(),
        None => {
            return Err(CalibrateRuntimeError::ClassifierMissing {
                model_name: cmd.model_name.clone(),
            });
        }
    };

    if rows.is_empty() {
        return Err(CalibrateRuntimeError::EmptyDataset {
            model_name: cmd.model_name.clone(),
        });
    }

    // One classifier input per row, keyed by the model's input variables.
    let inputs: Vec<ClassifyInput> = rows
        .iter()
        .map(|row| {
            let features = model
                .inputs
                .iter()
                .map(|input| {
                    (
                        input.variable.clone(),
                        row_to_feature(row, &input.variable),
                    )
                })
                .collect();
            ClassifyInput { features }
        })
        .collect();

    // Ground-truth labels, in the same order as `inputs`.
    let labels: Vec<bool> = rows
        .iter()
        .map(|row| target_to_label(row.get("__calibrate_target")))
        .collect();

    let predictions = match classifier.classify(&inputs).await {
        Ok(scores) => scores,
        Err(cause) => {
            return Err(CalibrateRuntimeError::FitFailure {
                model_name: cmd.model_name.clone(),
                message: cause.to_string(),
            });
        }
    };

    if predictions.len() != labels.len() {
        return Err(CalibrateRuntimeError::FitFailure {
            model_name: cmd.model_name.clone(),
            message: format!(
                "classifier returned {} predictions for {} inputs",
                predictions.len(),
                labels.len()
            ),
        });
    }

    // Holdout takes the leading samples; it is never smaller than one sample
    // and never larger than the whole data set.
    let n = predictions.len();
    let holdout_size = (((n as f64) * cmd.holdout).ceil().max(1.0) as usize).min(n);

    let holdout_preds: Vec<f64> = predictions.iter().take(holdout_size).copied().collect();
    let holdout_labels: Vec<bool> = labels.iter().take(holdout_size).copied().collect();
    let train_preds: Vec<f64> = predictions.iter().skip(holdout_size).copied().collect();
    let train_labels: Vec<bool> = labels.iter().skip(holdout_size).copied().collect();

    if train_preds.is_empty() || holdout_preds.is_empty() {
        return Err(CalibrateRuntimeError::InsufficientData {
            model_name: cmd.model_name.clone(),
            train: train_preds.len(),
            holdout: holdout_preds.len(),
        });
    }

    let calibrator = fit_method(cmd.method, &train_preds, &train_labels, &cmd.model_name)?;

    // Metrics on the untouched classifier output.
    let raw_brier = brier_score(&holdout_preds, &holdout_labels);
    let raw_ece = expected_calibration_error(&holdout_preds, &holdout_labels, ECE_BINS);

    // Metrics after passing the holdout scores through the fitted calibrator.
    let calibrated: Vec<f64> = calibrator.apply_batch(&holdout_preds);
    let calibrated_brier = brier_score(&calibrated, &holdout_labels);
    let calibrated_ece = expected_calibration_error(&calibrated, &holdout_labels, ECE_BINS);

    // Half-width of the band the calibrator reports at 0.5, when it has one.
    let confidence_band_quantile = calibrator
        .confidence_band(0.5)
        .map(|band| (band.upper - band.lower) / 2.0);

    Ok(CalibrationResult {
        model_name: cmd.model_name.clone(),
        method: method_kind(cmd.method),
        n_samples: n,
        holdout_size: holdout_preds.len(),
        calibrator,
        raw_brier,
        raw_ece,
        calibrated_brier,
        calibrated_ece,
        confidence_band_quantile,
    })
}
/// Public entry point for obtaining the row-collection query of a `CALIBRATE`
/// command; forwards directly to `build_collection_query` with the same
/// arguments and returns its result unchanged.
pub fn calibrate_collection_query(
cmd: &CompiledCalibrate,
model: &CompiledModel,
) -> uni_cypher::ast::Query {
build_collection_query(cmd, model)
}
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::locy_ast::{CalibrationMethod as AstCalibration, OutputType};
    use uni_locy::{CompiledInputBinding, MockClassifier};

    /// Builds a `FactRow` from `(column, value)` pairs.
    fn fact_row(pairs: &[(&str, Value)]) -> FactRow {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.clone()))
            .collect()
    }

    /// A model named `scorer` with a single input variable `s`.
    fn model_with_one_input() -> CompiledModel {
        CompiledModel {
            name: "scorer".into(),
            inputs: vec![CompiledInputBinding {
                variable: "s".into(),
                label: Some("Supplier".into()),
            }],
            embedder_alias: None,
            features: vec![],
            path_context: None,
            output_type: OutputType::Prob,
            output_name: "risk".into(),
            xervo_alias: "classify/test".into(),
            calibration: None,
            version: None,
            annotations: Default::default(),
        }
    }

    /// An empty pattern; `run_calibrate` receives rows directly and never
    /// evaluates the pattern.
    fn dummy_pattern() -> uni_cypher::ast::Pattern {
        uni_cypher::ast::Pattern { paths: vec![] }
    }

    /// A `CALIBRATE` command for `scorer` with a 25% holdout fraction.
    fn cmd(method: AstCalibration) -> CompiledCalibrate {
        CompiledCalibrate {
            model_name: "scorer".into(),
            pattern: dummy_pattern(),
            where_expr: None,
            target_expr: Expr::Variable("label".into()),
            method,
            holdout: 0.25,
        }
    }

    #[tokio::test]
    async fn calibrate_constant_classifier_improves_ece() {
        let mut catalog = HashMap::new();
        catalog.insert("scorer".to_string(), model_with_one_input());
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> =
            Arc::new(MockClassifier::constant("classify/test", 0.95));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        // Alternating labels against a constant 0.95 score: badly calibrated.
        let rows: Vec<FactRow> = (0..100)
            .map(|i| {
                fact_row(&[
                    ("s", Value::Int(i as i64)),
                    ("__calibrate_target", Value::Bool(i % 2 == 0)),
                ])
            })
            .collect();
        let result = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            rows,
        )
        .await
        .unwrap();
        assert_eq!(result.model_name, "scorer");
        assert_eq!(result.method, CalibrationMethodKind::Platt);
        assert!(
            result.calibrated_ece < result.raw_ece * 0.5,
            "Platt should reduce ECE by ≥50%: raw={} cal={}",
            result.raw_ece,
            result.calibrated_ece
        );
    }

    #[tokio::test]
    async fn calibrate_missing_classifier_errors() {
        let mut catalog = HashMap::new();
        catalog.insert("scorer".to_string(), model_with_one_input());
        // The registry is deliberately left empty.
        let registry = Arc::new(ClassifierRegistry::new());
        let rows = vec![fact_row(&[
            ("s", Value::Int(1)),
            ("__calibrate_target", Value::Bool(true)),
        ])];
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            rows,
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            CalibrateRuntimeError::ClassifierMissing { .. }
        ));
    }

    #[tokio::test]
    async fn calibrate_empty_dataset_errors() {
        let mut catalog = HashMap::new();
        catalog.insert("scorer".to_string(), model_with_one_input());
        let mut registry = ClassifierRegistry::new();
        let c: Arc<dyn NeuralClassifier> = Arc::new(MockClassifier::constant("classify/test", 0.5));
        registry.insert("scorer".into(), c);
        let registry = Arc::new(registry);
        let err = run_calibrate(
            &cmd(AstCalibration::PlattScaling),
            &catalog,
            &registry,
            vec![],
        )
        .await
        .unwrap_err();
        assert!(matches!(err, CalibrateRuntimeError::EmptyDataset { .. }));
    }
}