use super::{evaluate_rule, RefEvalError, RefRelation, RefRelationStore, RefValue, VariableOrder};
use crate::ast::Rule;
use xlog_core::ScalarType;
/// Tuning knobs for the fixpoint loop in [`evaluate_fixpoint`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FixpointConfig {
    /// Maximum number of evaluation rounds before giving up.
    ///
    /// Must be non-zero: `evaluate_fixpoint` rejects `0` with
    /// `FixpointError::InvalidMaxIterations`.
    pub max_iterations: usize,
}

impl Default for FixpointConfig {
    /// Returns a configuration that allows up to 32 iterations.
    fn default() -> Self {
        let max_iterations = 32;
        FixpointConfig { max_iterations }
    }
}
/// Failures reported by `evaluate_fixpoint`.
#[derive(Clone, Debug, PartialEq)]
pub enum FixpointError {
    /// The rule at `rule_index` has a head predicate (`observed`) that differs from the
    /// requested target predicate (`expected`).
    RuleNotForTarget {
        rule_index: usize,
        observed: String,
        expected: String,
    },
    /// The rule at `rule_index` has a non-empty head whose arity differs from the arity
    /// taken from the first rule with a non-empty head.
    HeadArityMismatch {
        rule_index: usize,
        observed_arity: usize,
        expected_arity: usize,
    },
    /// The base relations already contain a relation with the target predicate's `name`.
    TargetPredicateInBaseRelations {
        name: String,
    },
    /// Evaluating the rule at `rule_index` failed with `source`.
    RuleEval {
        rule_index: usize,
        source: RefEvalError,
    },
    /// No fixpoint was reached within `limit` iterations; `observed_size` is the number
    /// of derived rows at the point evaluation stopped.
    MaxIterationsExceeded {
        limit: usize,
        observed_size: usize,
    },
    /// No rule has a non-empty head (this includes an empty rule set), so the target
    /// arity cannot be determined.
    TargetSchemaIndeterminable,
    /// The configured iteration limit was zero.
    InvalidMaxIterations,
}
/// Computes the fixpoint of `rules` for the relation named `target_predicate`.
///
/// Evaluation is naive rather than semi-naive. Every iteration re-evaluates every rule
/// against `base_relations` plus the target rows derived so far. The results are unioned
/// into the derived rows, which are kept sorted and deduplicated. Evaluation stops as soon
/// as an iteration adds no new distinct row.
///
/// The schema of the returned relation is inferred from the first derived row, in sorted
/// order, whose length equals the target arity. If no such row exists, a placeholder of
/// `ScalarType::U32` columns is used, one per column of the target arity.
///
/// # Errors
///
/// - [`FixpointError::InvalidMaxIterations`] if `config.max_iterations` is zero.
/// - [`FixpointError::TargetPredicateInBaseRelations`] if `base_relations` already
///   contains `target_predicate`.
/// - [`FixpointError::RuleNotForTarget`] if any rule's head predicate is not
///   `target_predicate`.
/// - [`FixpointError::TargetSchemaIndeterminable`] if no rule has a non-empty head
///   (including when `rules` is empty).
/// - [`FixpointError::HeadArityMismatch`] if a non-empty head disagrees with the arity of
///   the first non-empty head.
/// - [`FixpointError::RuleEval`] if evaluating any rule fails.
/// - [`FixpointError::MaxIterationsExceeded`] if no fixpoint is reached within
///   `config.max_iterations` iterations.
pub fn evaluate_fixpoint(
    rules: &[Rule],
    base_relations: &RefRelationStore,
    target_predicate: &str,
    order: &dyn VariableOrder,
    config: &FixpointConfig,
) -> Result<RefRelation, FixpointError> {
    if config.max_iterations == 0 {
        return Err(FixpointError::InvalidMaxIterations);
    }
    if base_relations.contains_key(target_predicate) {
        return Err(FixpointError::TargetPredicateInBaseRelations {
            name: target_predicate.to_string(),
        });
    }
    for (i, rule) in rules.iter().enumerate() {
        if rule.head.predicate != target_predicate {
            return Err(FixpointError::RuleNotForTarget {
                rule_index: i,
                observed: rule.head.predicate.clone(),
                expected: target_predicate.to_string(),
            });
        }
    }

    // Rules whose head has no terms are skipped both when choosing the target arity and
    // in the arity check below.
    let target_arity = rules
        .iter()
        .find(|r| !r.head.terms.is_empty())
        .map(|r| r.head.terms.len())
        .ok_or(FixpointError::TargetSchemaIndeterminable)?;
    for (i, rule) in rules.iter().enumerate() {
        if rule.head.terms.is_empty() {
            continue;
        }
        if rule.head.terms.len() != target_arity {
            return Err(FixpointError::HeadArityMismatch {
                rule_index: i,
                observed_arity: rule.head.terms.len(),
                expected_arity: target_arity,
            });
        }
    }

    let placeholder_schema = || vec![ScalarType::U32; target_arity];
    let mut derived_schema: Option<Vec<ScalarType>> = None;
    // Invariant: `derived_rows` is always sorted and deduplicated.
    let mut derived_rows: Vec<Vec<RefValue>> = Vec::new();

    // The base relations never change, and the target predicate is known not to be among
    // them. Clone them once; each iteration only overwrites the target entry.
    let mut store = base_relations.clone();

    for _ in 0..config.max_iterations {
        let previous_len = derived_rows.len();
        let schema = derived_schema.clone().unwrap_or_else(placeholder_schema);
        let mut next_rows = derived_rows.clone();
        // Move (rather than clone) the current rows into the store; `next_rows` keeps the copy.
        store.insert(
            target_predicate.to_string(),
            RefRelation {
                schema,
                rows: std::mem::take(&mut derived_rows),
            },
        );

        for (rule_index, rule) in rules.iter().enumerate() {
            let rows = evaluate_rule(rule, &store, order)
                .map_err(|source| FixpointError::RuleEval { rule_index, source })?;
            next_rows.extend(rows);
        }
        next_rows.sort_unstable();
        next_rows.dedup();

        if derived_schema.is_none() {
            // Only rows of the target arity are considered. A zero-length row (presumably
            // from an empty-head rule) sorts first and would otherwise yield a schema of
            // the wrong width.
            derived_schema = next_rows
                .iter()
                .find(|row| row.len() == target_arity)
                .map(|row| infer_schema(row));
        }

        // `next_rows` started as a copy of the previous sorted, deduplicated rows. After
        // dedup it is a superset of them, so an unchanged length means nothing new was derived.
        if next_rows.len() == previous_len {
            let schema = derived_schema.unwrap_or_else(placeholder_schema);
            return Ok(RefRelation {
                schema,
                rows: next_rows,
            });
        }
        derived_rows = next_rows;
    }

    Err(FixpointError::MaxIterationsExceeded {
        limit: config.max_iterations,
        observed_size: derived_rows.len(),
    })
}
/// Maps each value in `row` to the `ScalarType` of its variant, preserving column order.
fn infer_schema(row: &[RefValue]) -> Vec<ScalarType> {
    let mut columns = Vec::with_capacity(row.len());
    for value in row {
        let column_type = match value {
            RefValue::U32(_) => ScalarType::U32,
            RefValue::U64(_) => ScalarType::U64,
            RefValue::I32(_) => ScalarType::I32,
            RefValue::I64(_) => ScalarType::I64,
            RefValue::Bool(_) => ScalarType::Bool,
            RefValue::Symbol(_) => ScalarType::Symbol,
        };
        columns.push(column_type);
    }
    columns
}