use std::collections::HashMap;
use uni_common::Value;
use uni_cypher::ast::Expr;
use uni_cypher::locy_ast::{IsReference, RuleCondition};
use uni_locy::types::{CompiledClause, CompiledRule};
use uni_locy::{FactRow, LocyError};
use super::locy_ast_builder::expr_references_var;
use super::locy_eval::{
eval_expr, eval_locy_expr, record_batches_to_locy_rows, values_equal_for_join,
};
use super::locy_traits::DerivedFactSource;
pub type KeyTuple = Vec<Value>;
#[derive(Debug, Clone)]
pub struct RowRelation {
pub columns: Vec<String>,
pub rows: Vec<FactRow>,
}
impl RowRelation {
pub fn new(columns: Vec<String>, rows: Vec<FactRow>) -> Self {
Self { columns, rows }
}
pub fn is_empty(&self) -> bool {
self.rows.is_empty()
}
pub fn len(&self) -> usize {
self.rows.len()
}
}
pub type RowStore = HashMap<String, RowRelation>;
pub fn extract_key(row: &FactRow, key_cols: &[String]) -> KeyTuple {
key_cols
.iter()
.map(|k| row.get(k).cloned().unwrap_or(Value::Null))
.collect()
}
pub fn extract_cypher_conditions(conditions: &[RuleCondition]) -> Vec<Expr> {
conditions
.iter()
.filter_map(|c| match c {
RuleCondition::Expression(expr) => Some(expr.clone()),
_ => None,
})
.collect()
}
fn is_ref_matches(
base_row: &FactRow,
derived_fact: &FactRow,
is_ref: &IsReference,
schema: &[String],
) -> bool {
for (i, subject) in is_ref.subjects.iter().enumerate() {
if i >= schema.len() {
break;
}
let base_val = base_row.get(subject).cloned().unwrap_or(Value::Null);
let fact_val = derived_fact.get(&schema[i]).cloned().unwrap_or(Value::Null);
if !values_equal_for_join(&base_val, &fact_val) {
return false;
}
}
if let Some(target) = &is_ref.target {
let target_idx = is_ref.subjects.len();
if target_idx < schema.len() {
let base_val = base_row.get(target).cloned().unwrap_or(Value::Null);
let fact_val = derived_fact
.get(&schema[target_idx])
.cloned()
.unwrap_or(Value::Null);
if base_val != Value::Null
&& fact_val != Value::Null
&& !values_equal_for_join(&base_val, &fact_val)
{
return false;
}
}
}
true
}
fn semi_join_is_ref(
base_rows: &[FactRow],
derived_facts: &[FactRow],
is_ref: &IsReference,
schema: &[String],
) -> Vec<FactRow> {
let mut result = Vec::new();
for base_row in base_rows {
for derived_fact in derived_facts {
if is_ref_matches(base_row, derived_fact, is_ref, schema) {
let mut row = base_row.clone();
if let Some(target) = &is_ref.target {
let target_col_idx = is_ref.subjects.len();
if target_col_idx < schema.len() {
let val = derived_fact
.get(&schema[target_col_idx])
.cloned()
.unwrap_or(Value::Null);
row.insert(target.clone(), val);
}
}
let bound_count =
is_ref.subjects.len() + if is_ref.target.is_some() { 1 } else { 0 };
for col in schema.iter().skip(bound_count) {
if let Some(val) = derived_fact.get(col) {
row.entry(col.clone()).or_insert_with(|| val.clone());
}
}
for (col, val) in derived_fact {
row.insert(format!("__prev_{}", col), val.clone());
}
result.push(row);
}
}
}
result
}
fn anti_join_is_ref(
base_rows: &[FactRow],
derived_facts: &[FactRow],
is_ref: &IsReference,
schema: &[String],
) -> Vec<FactRow> {
base_rows
.iter()
.filter(|base_row| {
!derived_facts
.iter()
.any(|df| is_ref_matches(base_row, df, is_ref, schema))
})
.cloned()
.collect()
}
fn prob_complement_is_ref(
base_rows: &[FactRow],
derived_facts: &[FactRow],
is_ref: &IsReference,
schema: &[String],
target_prob_col: &str,
complement_col_name: &str,
) -> Vec<FactRow> {
base_rows
.iter()
.map(|base_row| {
let mut combined_p = 0.0_f64;
for df in derived_facts {
if is_ref_matches(base_row, df, is_ref, schema) {
let p = df
.get(target_prob_col)
.and_then(|v| v.as_f64())
.unwrap_or(0.0)
.clamp(0.0, 1.0);
combined_p = 1.0 - (1.0 - combined_p) * (1.0 - p);
}
}
let complement = 1.0 - combined_p;
let mut row = base_row.clone();
row.insert(complement_col_name.to_string(), Value::Float(complement));
row
})
.collect()
}
pub fn multiply_prob_factors_rows(
rows: &mut [FactRow],
prob_col: &str,
complement_cols: &[String],
) {
for row in rows.iter_mut() {
let mut factor = 1.0_f64;
for col in complement_cols {
if let Some(v) = row.get(col).and_then(|v| v.as_f64()) {
factor *= v;
}
row.remove(col);
}
let current = row.get(prob_col).and_then(|v| v.as_f64()).unwrap_or(1.0);
row.insert(prob_col.to_string(), Value::Float(current * factor));
}
}
pub async fn resolve_clause_with_is_refs(
clause: &CompiledClause,
fact_source: &dyn DerivedFactSource,
derived_store: &RowStore,
rule_catalog: &HashMap<String, CompiledRule>,
calling_rule_prob_col: Option<&str>,
) -> Result<Vec<FactRow>, LocyError> {
let mut is_ref_vars: Vec<String> = Vec::new();
for cond in &clause.where_conditions {
if let RuleCondition::IsReference(is_ref) = cond
&& !is_ref.negated
{
if let Some(target) = &is_ref.target {
is_ref_vars.push(target.clone());
}
let rule_name = is_ref.rule_name.to_string();
if let Some(target_rule) = rule_catalog.get(&rule_name) {
for col in &target_rule.yield_schema {
if !col.is_key {
is_ref_vars.push(col.name.clone());
}
}
}
}
}
let all_cypher = extract_cypher_conditions(&clause.where_conditions);
let mut match_safe = Vec::new();
let mut target_dependent = Vec::new();
for expr in all_cypher {
if is_ref_vars.iter().any(|v| expr_references_var(&expr, v)) {
target_dependent.push(expr);
} else {
match_safe.push(expr);
}
}
let raw_batches = fact_source
.execute_pattern(&clause.match_pattern, &match_safe)
.await?;
let mut rows = record_batches_to_locy_rows(&raw_batches);
for cond in &clause.where_conditions {
if let RuleCondition::IsReference(is_ref) = cond {
let rule_name = is_ref.rule_name.to_string();
let (schema, derived_facts) = if let Some(relation) = derived_store.get(&rule_name) {
(relation.columns.clone(), relation.rows.clone())
} else {
(Vec::new(), Vec::new())
};
if is_ref.negated {
let target_prob_col = rule_catalog.get(&rule_name).and_then(|r| {
r.yield_schema
.iter()
.find(|c| c.is_prob)
.map(|c| c.name.clone())
});
if let (Some(tpc), Some(_)) = (&target_prob_col, &calling_rule_prob_col) {
let complement_col = format!("__prob_complement_{}", rule_name);
rows = prob_complement_is_ref(
&rows,
&derived_facts,
is_ref,
&schema,
tpc,
&complement_col,
);
} else {
rows = anti_join_is_ref(&rows, &derived_facts, is_ref, &schema);
}
} else {
rows = semi_join_is_ref(&rows, &derived_facts, is_ref, &schema);
}
}
}
if !target_dependent.is_empty() {
rows.retain(|row| {
target_dependent.iter().all(|expr| {
eval_expr(expr, row)
.map(|v| v.as_bool().unwrap_or(false))
.unwrap_or(false)
})
});
}
if !clause.along.is_empty() {
let mut new_rows = Vec::with_capacity(rows.len());
for mut row in rows {
let prev_values: FactRow = row
.iter()
.filter_map(|(k, v)| {
k.strip_prefix("__prev_")
.map(|col| (col.to_string(), v.clone()))
})
.collect();
for along in &clause.along {
let result = eval_locy_expr(&along.expr, &row, Some(&prev_values))?;
row.insert(along.name.clone(), result);
}
new_rows.push(row);
}
rows = new_rows;
}
for row in &mut rows {
row.retain(|k, _| !k.starts_with("__prev_"));
}
Ok(rows)
}
#[cfg(test)]
mod tests {
use super::*;
use uni_common::Value;
use uni_cypher::locy_ast::{IsReference, QualifiedName};
fn make_is_ref(subject: &str, rule: &str) -> IsReference {
IsReference {
subjects: vec![subject.to_string()],
rule_name: QualifiedName {
parts: vec![rule.to_string()],
},
target: None,
negated: true,
}
}
fn row_int(pairs: &[(&str, i64)]) -> FactRow {
pairs
.iter()
.map(|(k, v)| (k.to_string(), Value::Int(*v)))
.collect()
}
#[test]
fn test_prob_complement_basic() {
let base = vec![
row_int(&[("n", 1)]),
row_int(&[("n", 2)]),
row_int(&[("n", 3)]),
];
let derived = vec![
{
let mut r = row_int(&[("n", 1)]);
r.insert("risk_score".into(), Value::Float(0.7));
r
},
{
let mut r = row_int(&[("n", 2)]);
r.insert("risk_score".into(), Value::Float(0.3));
r
},
];
let is_ref = make_is_ref("n", "risky");
let schema = vec!["n".to_string(), "risk_score".to_string()];
let result = prob_complement_is_ref(&base, &derived, &is_ref, &schema, "risk_score", "__c");
assert_eq!(result.len(), 3);
let c0 = result[0].get("__c").unwrap().as_f64().unwrap();
let c1 = result[1].get("__c").unwrap().as_f64().unwrap();
let c2 = result[2].get("__c").unwrap().as_f64().unwrap();
assert!((c0 - 0.3).abs() < 1e-10, "Alice: expected 0.3, got {c0}");
assert!((c1 - 0.7).abs() < 1e-10, "Bob: expected 0.7, got {c1}");
assert!((c2 - 1.0).abs() < 1e-10, "Charlie: expected 1.0, got {c2}");
}
#[test]
fn test_prob_complement_noisy_or_duplicates() {
let base = vec![row_int(&[("n", 1)])];
let derived = vec![
{
let mut r = row_int(&[("n", 1)]);
r.insert("prob".into(), Value::Float(0.3));
r
},
{
let mut r = row_int(&[("n", 1)]);
r.insert("prob".into(), Value::Float(0.5));
r
},
];
let is_ref = make_is_ref("n", "rule");
let schema = vec!["n".to_string(), "prob".to_string()];
let result = prob_complement_is_ref(&base, &derived, &is_ref, &schema, "prob", "__c");
let c = result[0].get("__c").unwrap().as_f64().unwrap();
assert!((c - 0.35).abs() < 1e-10, "expected 0.35, got {c}");
}
#[test]
fn test_multiply_prob_factors_rows_basic() {
let mut rows = vec![{
let mut r = FactRow::new();
r.insert("safety".into(), Value::Float(1.0));
r.insert("__prob_complement_risky".into(), Value::Float(0.3));
r
}];
multiply_prob_factors_rows(&mut rows, "safety", &["__prob_complement_risky".into()]);
assert!(
!rows[0].contains_key("__prob_complement_risky"),
"complement col should be removed"
);
let safety = rows[0].get("safety").unwrap().as_f64().unwrap();
assert!((safety - 0.3).abs() < 1e-10, "expected 0.3, got {safety}");
}
#[test]
fn test_anti_join_unchanged_without_prob() {
let base = vec![row_int(&[("n", 1)]), row_int(&[("n", 2)])];
let derived = vec![row_int(&[("n", 1)])];
let is_ref = make_is_ref("n", "flagged");
let schema = vec!["n".to_string()];
let result = anti_join_is_ref(&base, &derived, &is_ref, &schema);
assert_eq!(result.len(), 1);
assert_eq!(result[0].get("n").unwrap().as_i64(), Some(2));
}
}