use std::collections::HashMap;
use std::time::Duration;
use uni_common::{Properties, Value};
use crate::types::{RuntimeWarning, RuntimeWarningCode};
/// A single result row, mapping column name to value.
///
/// `LocyResult::columns` reads column names from the keys of the first row.
pub type FactRow = HashMap<String, Value>;
/// Aggregate output of a Locy evaluation: derived facts, per-command results,
/// statistics and diagnostics.
#[derive(Debug, Clone)]
pub struct LocyResult {
    /// Derived fact rows keyed by rule name (looked up by `derived_facts`).
    pub derived: HashMap<String, Vec<FactRow>>,
    /// Evaluation statistics (exposed via `stats` and `iterations`).
    pub stats: LocyStats,
    /// One entry per executed command; `rows` returns the first `Query` entry.
    pub command_results: Vec<CommandResult>,
    /// Warnings raised during evaluation; `has_warning` searches these by code.
    pub warnings: Vec<RuntimeWarning>,
    /// Warnings emitted by the compiler before evaluation.
    pub compile_warnings: Vec<crate::types::CompilerWarning>,
    // NOTE(review): semantics not visible here — presumably rule name -> keys of
    // groups whose results are approximate; confirm against the evaluator.
    pub approximate_groups: HashMap<String, Vec<String>>,
    /// Derived vertices/edges packaged as a `DerivedFactSet`, when one was produced.
    pub derived_fact_set: Option<DerivedFactSet>,
    /// `Some` when evaluation did not finish; `timed_out` reports `is_some()`.
    pub incomplete: Option<uni_common::LocyIncomplete>,
}
/// The outcome of one command, tagged by command kind.
///
/// `as_query`, `as_explain` and `as_abduce` give typed access to some variants.
#[derive(Debug, Clone)]
pub enum CommandResult {
    /// Rows produced by a query command.
    Query(Vec<FactRow>),
    /// Rows produced by an assume command.
    Assume(Vec<FactRow>),
    /// Derivation tree produced by an explain command.
    Explain(DerivationNode),
    /// Candidate graph modifications produced by an abduce command.
    Abduce(AbductionResult),
    /// Outcome of a derive command.
    Derive {
        // NOTE(review): presumably the number of facts/elements written by the
        // derive; confirm against the executor.
        affected: usize,
    },
    /// Rows produced by a Cypher statement.
    Cypher(Vec<FactRow>),
    /// Fitted calibrator together with before/after quality metrics.
    Calibrate(CalibrationResult),
    /// Metrics computed by a validate command.
    Validate(ValidationResult),
}
/// Metrics computed when validating a rule's probability column.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Name of the rule that was validated.
    pub rule_name: String,
    /// Column holding the probabilities being scored.
    pub prob_column: String,
    /// Number of samples the metrics were computed over.
    pub n_samples: usize,
    /// `(metric, value)` pairs; use `metric()` to look one up.
    pub metrics: Vec<(uni_cypher::locy_ast::ValidationMetric, f64)>,
}
impl ValidationResult {
    /// Returns the value recorded for metric `m`.
    ///
    /// Scans `metrics` in order and yields the first matching entry, or `None`
    /// if `m` was not computed.
    pub fn metric(&self, m: uni_cypher::locy_ast::ValidationMetric) -> Option<f64> {
        for (name, value) in &self.metrics {
            if *name == m {
                return Some(*value);
            }
        }
        None
    }
}
/// An interval `[lower, upper]` around a predicted value, together with the
/// method that produced it.
#[derive(Debug, Clone, Copy)]
pub struct ConfidenceBand {
    /// Lower bound of the band.
    pub lower: f64,
    /// Upper bound of the band.
    pub upper: f64,
    /// Technique the band was derived from.
    pub source: ConfidenceSource,
}
/// The technique used to compute a `ConfidenceBand`, with its parameters.
#[derive(Debug, Clone, Copy)]
pub enum ConfidenceSource {
    // NOTE(review): `alpha` is presumably the conformal miscoverage level
    // (band targets 1 - alpha coverage) — confirm.
    Conformal { alpha: f64 },
    /// Spread across an ensemble of `n_estimators` models.
    EnsembleVariance { n_estimators: usize },
    // NOTE(review): presumably the bounds of a credal set of priors — confirm.
    Credal { lower_prior: f64, upper_prior: f64 },
}
#[cfg(test)]
mod confidence_source_tests {
    use super::ConfidenceSource;

    /// Formats `source` with `{:?}` and asserts every fragment occurs in the text.
    fn assert_debug_mentions(source: ConfidenceSource, fragments: &[&str]) {
        let rendered = format!("{:?}", source);
        for &fragment in fragments {
            assert!(rendered.contains(fragment));
        }
    }

    #[test]
    fn conformal_debug_format() {
        assert_debug_mentions(
            ConfidenceSource::Conformal { alpha: 0.1 },
            &["Conformal", "0.1"],
        );
    }

    #[test]
    fn ensemble_variance_debug_format() {
        assert_debug_mentions(
            ConfidenceSource::EnsembleVariance { n_estimators: 50 },
            &["EnsembleVariance", "50"],
        );
    }

    #[test]
    fn credal_debug_format() {
        let credal = ConfidenceSource::Credal {
            lower_prior: 0.1,
            upper_prior: 0.9,
        };
        assert_debug_mentions(credal, &["Credal", "0.1", "0.9"]);
    }
}
/// Result of fitting a probability calibrator for a model, with quality
/// metrics measured before and after calibration.
#[derive(Debug, Clone)]
pub struct CalibrationResult {
    /// Model the calibrator was fitted for.
    pub model_name: String,
    /// Calibration method that was applied.
    pub method: crate::calibration::CalibrationMethodKind,
    /// Number of samples used.
    pub n_samples: usize,
    // NOTE(review): presumably the number of samples held out from fitting for
    // evaluation — confirm.
    pub holdout_size: usize,
    /// The fitted calibrator; shared via `Arc` so clones of this result reuse it.
    pub calibrator: std::sync::Arc<dyn crate::calibration::Calibrator>,
    /// Brier score of the uncalibrated probabilities (per the field name).
    pub raw_brier: f64,
    /// Expected calibration error of the uncalibrated probabilities (per the field name).
    pub raw_ece: f64,
    /// Brier score after calibration.
    pub calibrated_brier: f64,
    /// Expected calibration error after calibration.
    pub calibrated_ece: f64,
    // NOTE(review): presumably the quantile used to size conformal confidence
    // bands, when the method produces one — confirm.
    pub confidence_band_quantile: Option<f64>,
}
/// Record of one model invocation that contributed to a derivation.
#[derive(Debug, Clone)]
pub struct NeuralProvenance {
    /// Name of the model that was called.
    pub model_name: String,
    /// Probability as returned by the model.
    pub raw_probability: f64,
    /// Probability after calibration, when a calibrator was applied.
    pub calibrated_probability: Option<f64>,
    /// Confidence band around the prediction, when one was computed.
    pub confidence_band: Option<ConfidenceBand>,
}
/// One node of a derivation (proof) tree, as returned by `CommandResult::Explain`.
#[derive(Debug, Clone)]
pub struct DerivationNode {
    /// Rule that produced this fact.
    pub rule: String,
    /// Index of the rule clause that fired.
    pub clause_index: usize,
    /// Clause priority, if the rule declares one.
    pub priority: Option<i64>,
    /// Variable bindings for this derivation step.
    pub bindings: HashMap<String, Value>,
    // NOTE(review): presumably values accumulated along a path (ALONG clause) —
    // confirm against the evaluator.
    pub along_values: HashMap<String, Value>,
    /// Sub-derivations this step depends on.
    pub children: Vec<DerivationNode>,
    // NOTE(review): presumably set when the node is a base graph fact rather
    // than a derived one — confirm.
    pub graph_fact: Option<String>,
    /// Whether this derivation is marked approximate.
    pub approximate: bool,
    /// Probability attached to this proof, when computed.
    pub proof_probability: Option<f64>,
    /// Model invocations that contributed to this step.
    pub neural_calls: Vec<NeuralProvenance>,
}
/// Candidate graph modifications produced by an abduce command.
///
/// Serializes as `{ "modifications": [...] }`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct AbductionResult {
    /// Candidates, each annotated with validation status and cost.
    pub modifications: Vec<ValidatedModification>,
}
/// A candidate modification plus the outcome of checking it.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ValidatedModification {
    /// The proposed change.
    pub modification: Modification,
    // NOTE(review): presumably true when applying the change was confirmed to
    // achieve the abduction goal — confirm.
    pub validated: bool,
    /// Cost assigned to this change.
    pub cost: f64,
}
/// A single proposed change to the graph.
///
/// Uses serde's default (externally tagged) representation, e.g.
/// `{ "RemoveEdge": { "source_var": ..., ... } }`.
#[derive(Debug, Clone, serde::Serialize)]
pub enum Modification {
    /// Delete an edge of `edge_type` between the nodes bound to `source_var` and
    /// `target_var`, identified by `match_properties`.
    RemoveEdge {
        source_var: String,
        target_var: String,
        edge_var: String,
        edge_type: String,
        match_properties: HashMap<String, Value>,
    },
    /// Change `property` on the element bound to `element_var` from `old_value`
    /// to `new_value`. Values are boxed to keep the enum variant small.
    ChangeProperty {
        element_var: String,
        property: String,
        old_value: Box<Value>,
        new_value: Box<Value>,
    },
    /// Create an edge of `edge_type` with `properties` between the nodes bound to
    /// `source_var` and `target_var`.
    AddEdge {
        source_var: String,
        target_var: String,
        edge_type: String,
        properties: HashMap<String, Value>,
    },
}
/// A derived edge whose endpoints are described by label plus properties
/// rather than by element id.
// NOTE(review): presumably the endpoint properties are used to locate the
// existing vertices when materializing — confirm.
#[derive(Debug, Clone)]
pub struct DerivedEdge {
    /// Relationship type of the edge.
    pub edge_type: String,
    /// Label of the source vertex.
    pub source_label: String,
    /// Properties of the source vertex.
    pub source_properties: Properties,
    /// Label of the target vertex.
    pub target_label: String,
    /// Properties of the target vertex.
    pub target_properties: Properties,
    /// Properties stored on the edge itself.
    pub edge_properties: Properties,
}
/// Derived vertices and edges collected from an evaluation, together with
/// its statistics.
#[derive(Debug, Clone)]
pub struct DerivedFactSet {
    // NOTE(review): keys are presumably vertex labels — confirm.
    pub vertices: HashMap<String, Vec<Properties>>,
    /// Derived edges.
    pub edges: Vec<DerivedEdge>,
    /// Statistics of the evaluation that produced this set.
    pub stats: LocyStats,
    // NOTE(review): presumably the storage version the evaluation read from —
    // confirm.
    pub evaluated_at_version: u64,
    // Hidden from rustdoc: internal detail. Presumably the mutation queries that
    // would materialize this set — confirm.
    #[doc(hidden)]
    pub mutation_queries: Vec<uni_cypher::ast::Query>,
}
impl DerivedFactSet {
    /// Total number of derived facts: every vertex entry across all keys of
    /// `vertices`, plus every derived edge.
    pub fn fact_count(&self) -> usize {
        let vertex_count: usize = self.vertices.values().map(Vec::len).sum();
        vertex_count + self.edges.len()
    }

    /// Returns `true` when the set contains no facts.
    ///
    /// Agrees with `fact_count() == 0`. A key in `vertices` whose list is empty
    /// contributes no facts, so a map such as `{"Person": []}` counts as empty.
    /// (Previously such a map made `is_empty` return `false` even though
    /// `fact_count` was 0.)
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty() && self.vertices.values().all(Vec::is_empty)
    }
}
/// Counters and timing recorded for an evaluation.
///
/// `Default` yields zero for every counter and a zero `Duration`.
#[derive(Debug, Clone, Default)]
pub struct LocyStats {
    /// Number of strata evaluated.
    pub strata_evaluated: usize,
    // Returned by `LocyResult::iterations`. NOTE(review): presumably summed
    // over all strata — confirm.
    pub total_iterations: usize,
    /// Number of derived nodes.
    pub derived_nodes: usize,
    /// Number of derived edges.
    pub derived_edges: usize,
    /// Wall-clock time spent evaluating.
    pub evaluation_time: Duration,
    /// Number of queries executed.
    pub queries_executed: usize,
    /// Number of mutations executed.
    pub mutations_executed: usize,
    /// Peak memory use in bytes, as recorded by the evaluator.
    pub peak_memory_bytes: usize,
}
impl LocyResult {
pub fn derived_facts(&self, rule: &str) -> Option<&Vec<FactRow>> {
self.derived.get(rule)
}
pub fn rows(&self) -> Option<&Vec<FactRow>> {
self.command_results.iter().find_map(|cr| cr.as_query())
}
pub fn columns(&self) -> Option<Vec<String>> {
self.rows().and_then(|rows| {
rows.first().map(|row| {
let mut cols: Vec<String> = row.keys().cloned().collect();
cols.sort();
cols
})
})
}
pub fn stats(&self) -> &LocyStats {
&self.stats
}
pub fn iterations(&self) -> usize {
self.stats.total_iterations
}
pub fn compile_warnings(&self) -> &[crate::types::CompilerWarning] {
&self.compile_warnings
}
pub fn command_results(&self) -> &[CommandResult] {
&self.command_results
}
pub fn warnings(&self) -> &[RuntimeWarning] {
&self.warnings
}
pub fn has_warning(&self, code: &RuntimeWarningCode) -> bool {
self.warnings.iter().any(|w| w.code == *code)
}
pub fn timed_out(&self) -> bool {
self.incomplete.is_some()
}
}
impl CommandResult {
pub fn as_explain(&self) -> Option<&DerivationNode> {
match self {
CommandResult::Explain(node) => Some(node),
_ => None,
}
}
pub fn as_query(&self) -> Option<&Vec<FactRow>> {
match self {
CommandResult::Query(rows) => Some(rows),
_ => None,
}
}
pub fn as_abduce(&self) -> Option<&AbductionResult> {
match self {
CommandResult::Abduce(result) => Some(result),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `LocyResult` whose only populated field is `command_results`.
    fn result_with(command_results: Vec<CommandResult>) -> LocyResult {
        LocyResult {
            derived: HashMap::new(),
            stats: LocyStats::default(),
            command_results,
            warnings: Vec::new(),
            compile_warnings: Vec::new(),
            approximate_groups: HashMap::new(),
            derived_fact_set: None,
            incomplete: None,
        }
    }

    #[test]
    fn columns_returned_in_sorted_order() {
        // Keys are not alphabetical as listed; `columns()` must sort them.
        let row: FactRow = [("zeta", 1), ("alpha", 2), ("mu", 3)]
            .iter()
            .map(|&(name, n)| (name.to_owned(), Value::Int(n)))
            .collect();
        let result = result_with(vec![CommandResult::Query(vec![row])]);

        let cols = result
            .columns()
            .expect("expected columns for non-empty result");
        assert_eq!(cols, ["alpha", "mu", "zeta"]);
    }

    #[test]
    fn abduce_result_serializes_to_json() {
        let change_flag = ValidatedModification {
            modification: Modification::ChangeProperty {
                element_var: "a".into(),
                property: "flagged".into(),
                old_value: Box::new(Value::String("false".into())),
                new_value: Box::new(Value::String("true".into())),
            },
            validated: true,
            cost: 0.5,
        };
        let drop_transfer = ValidatedModification {
            modification: Modification::RemoveEdge {
                source_var: "a".into(),
                target_var: "b".into(),
                edge_var: "e".into(),
                edge_type: "TRANSFERS_TO".into(),
                match_properties: HashMap::from([("amount".into(), Value::Float(1000.0))]),
            },
            validated: false,
            cost: 1.0,
        };
        let add_flag_edge = ValidatedModification {
            modification: Modification::AddEdge {
                source_var: "a".into(),
                target_var: "b".into(),
                edge_type: "FLAGGED_BY".into(),
                properties: HashMap::new(),
            },
            validated: true,
            cost: 1.5,
        };
        let result = AbductionResult {
            modifications: vec![change_flag, drop_transfer, add_flag_edge],
        };

        let json = serde_json::to_value(&result).expect("serialization failed");
        let mods = json["modifications"].as_array().unwrap();
        assert_eq!(mods.len(), 3);

        let first = &mods[0];
        assert_eq!(first["validated"], true);
        assert_eq!(first["cost"], 0.5);

        // Externally tagged enums serialize as `{ "<Variant>": { ...fields } }`.
        let expected_variants = ["ChangeProperty", "RemoveEdge", "AddEdge"];
        for (entry, variant) in mods.iter().zip(expected_variants) {
            assert!(entry["modification"][variant].is_object());
        }
    }
}