use super::{
evaluate_rule, FixpointConfig, RefEvalError, RefRelation, RefRelationStore, RefValue,
VariableOrder,
};
use crate::ast::Rule;
use std::collections::BTreeMap;
use xlog_core::ScalarType;
/// Errors reported by [`evaluate_scc_fixpoint`].
#[derive(Debug, Clone, PartialEq)]
pub enum SccFixpointError {
    /// A rule filed under `group_key` has a head predicate (`observed`) that differs from
    /// the key of the group it was filed under.
    RuleHeadPredicateMismatch {
        /// Predicate name used as the key of the rule group.
        group_key: String,
        /// Position of the offending rule within its group.
        rule_index: usize,
        /// Head predicate actually found on the rule.
        observed: String,
    },
    /// A rule for `predicate` has an arity (`observed_arity`) that disagrees with the
    /// arity established for the predicate (`expected_arity`). The established arity is
    /// taken from the first rule in the group whose head has at least one term.
    HeadArityMismatch {
        /// Predicate whose rules disagree.
        predicate: String,
        /// Position of the offending rule within its group.
        rule_index: usize,
        /// Arity observed for the offending rule.
        observed_arity: usize,
        /// Arity established for the predicate.
        expected_arity: usize,
    },
    /// A derived row for `predicate` holds a value whose type disagrees with the schema
    /// frozen for that predicate (inferred from the first row ever derived for it).
    InconsistentHeadValueTypes {
        /// Predicate whose rows are inconsistent.
        predicate: String,
        /// Zero-based column of the first mismatching value.
        column: usize,
        /// Type recorded in the frozen schema for `column`.
        expected: ScalarType,
        /// `Debug` rendering of the offending value.
        got: String,
    },
    /// A predicate defined by rules is also present among the base relations.
    PredicateInBaseRelations {
        /// The conflicting predicate name.
        name: String,
    },
    /// Evaluating a single rule failed.
    RuleEval {
        /// Predicate whose rule failed.
        predicate: String,
        /// Position of the failing rule within its group.
        rule_index: usize,
        /// Error returned by the rule evaluator.
        source: RefEvalError,
    },
    /// No fixpoint was reached within the configured number of iterations.
    MaxIterationsExceeded {
        /// The configured `max_iterations`.
        limit: usize,
        /// Number of rule groups (derived predicates) being evaluated.
        predicate_count: usize,
        /// Rows derived across all predicates when evaluation stopped.
        total_observed_rows: usize,
    },
    /// No rule for `predicate` has a head with at least one term (this includes an empty
    /// rule group), so its arity cannot be determined.
    SchemaIndeterminable {
        /// Predicate whose arity is unknown.
        predicate: String,
    },
    /// The configured `max_iterations` was zero.
    InvalidMaxIterations,
}
#[allow(clippy::result_large_err)]
pub fn evaluate_scc_fixpoint(
rules: &BTreeMap<String, Vec<Rule>>,
base_relations: &RefRelationStore,
order: &dyn VariableOrder,
config: &FixpointConfig,
) -> Result<RefRelationStore, SccFixpointError> {
if config.max_iterations == 0 {
return Err(SccFixpointError::InvalidMaxIterations);
}
let mut arities: BTreeMap<String, usize> = BTreeMap::new();
for (predicate, group) in rules.iter() {
if base_relations.contains_key(predicate) {
return Err(SccFixpointError::PredicateInBaseRelations {
name: predicate.clone(),
});
}
for (idx, rule) in group.iter().enumerate() {
if rule.head.predicate != *predicate {
return Err(SccFixpointError::RuleHeadPredicateMismatch {
group_key: predicate.clone(),
rule_index: idx,
observed: rule.head.predicate.clone(),
});
}
}
let arity = group
.iter()
.find(|r| !r.head.terms.is_empty())
.map(|r| r.head.terms.len())
.ok_or_else(|| SccFixpointError::SchemaIndeterminable {
predicate: predicate.clone(),
})?;
for (idx, rule) in group.iter().enumerate() {
if rule.head.terms.is_empty() {
continue;
}
if rule.head.terms.len() != arity {
return Err(SccFixpointError::HeadArityMismatch {
predicate: predicate.clone(),
rule_index: idx,
observed_arity: rule.head.terms.len(),
expected_arity: arity,
});
}
}
arities.insert(predicate.clone(), arity);
}
let mut frozen_schemas: BTreeMap<String, Option<Vec<ScalarType>>> = BTreeMap::new();
let mut derived: BTreeMap<String, Vec<Vec<RefValue>>> = BTreeMap::new();
for predicate in rules.keys() {
frozen_schemas.insert(predicate.clone(), None);
derived.insert(predicate.clone(), Vec::new());
}
for _iter in 0..config.max_iterations {
let mut store = base_relations.clone();
for (predicate, rows) in derived.iter() {
let schema = frozen_schemas
.get(predicate)
.and_then(|s| s.clone())
.unwrap_or_else(|| vec![ScalarType::U32; arities[predicate]]);
store.insert(
predicate.clone(),
RefRelation {
schema,
rows: rows.clone(),
},
);
}
let mut next: BTreeMap<String, Vec<Vec<RefValue>>> = derived.clone();
for (predicate, group) in rules.iter() {
let mut produced: Vec<Vec<RefValue>> = Vec::new();
for (rule_index, rule) in group.iter().enumerate() {
let rows =
evaluate_rule(rule, &store, order).map_err(|e| SccFixpointError::RuleEval {
predicate: predicate.clone(),
rule_index,
source: e,
})?;
produced.extend(rows);
}
let frozen_entry = frozen_schemas.get_mut(predicate).expect("inserted above");
if frozen_entry.is_none() {
if let Some(first) = produced.first() {
*frozen_entry = Some(infer_schema(first));
}
}
if let Some(schema) = frozen_entry.as_ref() {
for row in &produced {
if let Some((column, expected, got)) = first_type_mismatch(row, schema) {
return Err(SccFixpointError::InconsistentHeadValueTypes {
predicate: predicate.clone(),
column,
expected,
got,
});
}
}
}
let target = next.get_mut(predicate).expect("predicate present");
target.extend(produced);
target.sort();
target.dedup();
}
if next == derived {
let mut out: RefRelationStore = BTreeMap::new();
for (predicate, rows) in derived.into_iter() {
let schema = frozen_schemas
.get(&predicate)
.and_then(|s| s.clone())
.unwrap_or_else(|| vec![ScalarType::U32; arities[&predicate]]);
out.insert(predicate, RefRelation { schema, rows });
}
return Ok(out);
}
derived = next;
}
let total: usize = derived.values().map(|v| v.len()).sum();
Err(SccFixpointError::MaxIterationsExceeded {
limit: config.max_iterations,
predicate_count: rules.len(),
total_observed_rows: total,
})
}
/// Builds a schema for `row` by mapping each value to its scalar type, column by column.
///
/// The result always has exactly `row.len()` entries.
fn infer_schema(row: &[RefValue]) -> Vec<ScalarType> {
    let mut schema = Vec::with_capacity(row.len());
    for value in row {
        let column_type = match value {
            RefValue::U32(_) => ScalarType::U32,
            RefValue::U64(_) => ScalarType::U64,
            RefValue::I32(_) => ScalarType::I32,
            RefValue::I64(_) => ScalarType::I64,
            RefValue::Bool(_) => ScalarType::Bool,
            RefValue::Symbol(_) => ScalarType::Symbol,
        };
        schema.push(column_type);
    }
    schema
}
/// Finds the first column where `row`'s value does not have the type recorded in `schema`.
///
/// Returns `(column, expected_type, debug_rendering_of_value)` for the earliest mismatch,
/// or `None` if every compared column agrees.
///
/// Only the first `min(row.len(), schema.len())` columns are compared. Extra values or
/// extra schema entries are not reported.
fn first_type_mismatch(
    row: &[RefValue],
    schema: &[ScalarType],
) -> Option<(usize, ScalarType, String)> {
    row.iter()
        .zip(schema)
        .enumerate()
        .find_map(|(column, (value, &expected))| {
            let compatible = matches!(
                (expected, value),
                (ScalarType::U32, RefValue::U32(_))
                    | (ScalarType::U64, RefValue::U64(_))
                    | (ScalarType::I32, RefValue::I32(_))
                    | (ScalarType::I64, RefValue::I64(_))
                    | (ScalarType::Bool, RefValue::Bool(_))
                    | (ScalarType::Symbol, RefValue::Symbol(_))
            );
            if compatible {
                None
            } else {
                Some((column, expected, format!("{value:?}")))
            }
        })
}