use super::ir::{Hyperedge, HypergraphRule, VertexId};
use super::var_order::VariableOrder;
use super::{analyze, Eligibility, ExecutorContext};
use crate::ast::{Atom, BodyLiteral, CompOp, Comparison, Rule, Term};
use std::collections::BTreeMap;
use xlog_core::ScalarType;
/// A single scalar value stored in a reference relation or produced in a result row.
///
/// Each variant corresponds to one `ScalarType` column type. There is no float
/// variant, so any row in an `F32`/`F64` column fails row validation in this module.
/// The derived `Ord` is used to sort and deduplicate variable domains and result
/// rows, so the declaration order of the variants is significant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum RefValue {
    U32(u32),
    U64(u64),
    I32(i32),
    I64(i64),
    Bool(bool),
    /// Value of a `Symbol` column, held as its text (string literals map here).
    Symbol(String),
}
/// An in-memory relation: a column schema plus its rows.
///
/// Before `evaluate_rule` uses a relation, every row is checked to have exactly
/// `schema.len()` cells, each of the `RefValue` variant matching its column type.
#[derive(Debug, Clone, PartialEq)]
pub struct RefRelation {
    /// Declared type of each column, in column order.
    pub schema: Vec<ScalarType>,
    /// Stored tuples; `rows[r][c]` is the value of column `c` in row `r`.
    pub rows: Vec<Vec<RefValue>>,
}
/// Relations available to `evaluate_rule`, keyed by predicate name (each body
/// atom's `predicate` is looked up here).
pub type RefRelationStore = BTreeMap<String, RefRelation>;
/// Errors reported by the reference evaluator.
#[derive(Debug, Clone, PartialEq)]
pub enum RefEvalError {
    /// `analyze` judged the rule ineligible; carries the boundaries it reported.
    Ineligible(Vec<super::Boundary>),
    /// A body atom names a predicate that is absent from the relation store.
    MissingRelation(String),
    /// A body atom's term count differs from its relation's schema length.
    RelationArityMismatch {
        predicate: String,
        atom_arity: usize,
        relation_arity: usize,
    },
    /// A stored row's length differs from its relation's schema length.
    RelationRowArityMismatch {
        predicate: String,
        row_index: usize,
        row_len: usize,
        schema_len: usize,
    },
    /// A stored cell's `RefValue` variant does not match its column's type;
    /// `got` is the `Debug` rendering of the offending value.
    RelationValueTypeMismatch {
        predicate: String,
        row_index: usize,
        column: usize,
        expected: ScalarType,
        got: String,
    },
    /// A term in a body atom cannot be matched against its column: a literal of
    /// the wrong type or out of range, or an unsupported term kind. `got` is a
    /// human-readable description.
    ConstantTypeMismatch {
        predicate: String,
        position: usize,
        expected: ScalarType,
        got: String,
    },
    /// Comparison operands that cannot be compared: incompatible value kinds, or
    /// an operand term kind that comparisons do not support.
    ComparisonTypeMismatch {
        left: String,
        right: String,
        op: CompOp,
    },
    /// A head or comparison variable has no vertex or no binding. Also reused,
    /// with a descriptive message, for unsupported head term kinds.
    UnboundHeadVariable(String),
    /// The body contains an `IsExpr` or `Univ` literal, which this evaluator
    /// does not execute.
    IsExprNotSupported,
    /// NOTE(review): not constructed in this file section; presumably reported
    /// by type inference when one variable is used at two column types. Confirm.
    ConflictingVariableType {
        var: String,
        first_predicate: String,
        first_position: usize,
        first_type: ScalarType,
        second_predicate: String,
        second_position: usize,
        second_type: ScalarType,
    },
    /// NOTE(review): not constructed in this file section; presumably reported
    /// when two rules infer different types for the same output column. Confirm.
    InferenceConflict {
        predicate: String,
        column: usize,
        first_rule_index: usize,
        first_type: ScalarType,
        second_rule_index: usize,
        second_type: ScalarType,
    },
}
/// Evaluate `rule` against `relations` by exhaustive enumeration.
///
/// This is a brute-force reference implementation:
/// 1. Build the rule's hypergraph and require `analyze(.., HashFallback)` to
///    report it eligible.
/// 2. Reject bodies containing an `IsExpr` or `Univ` literal.
/// 3. Compile each positive atom (paired positionally with `hg.hyperedges`)
///    against its relation, validating every stored row first.
/// 4. Give each vertex a candidate domain: every value found in a column it
///    occupies. Then enumerate assignments in the order chosen by `order`, and
///    keep those that satisfy every atom and every comparison.
/// 5. Project each kept assignment onto the head terms.
///
/// This function consults no other body literal kinds.
/// NOTE(review): presumably `analyze` rejects any that would change the result;
/// confirm.
///
/// # Returns
/// Head rows, sorted by `RefValue`'s derived ordering and deduplicated.
///
/// # Errors
/// `Ineligible`, `IsExprNotSupported`, `MissingRelation`, any relation-validation
/// or atom-compilation error, and any error raised while enumerating (comparison
/// type errors and unbound head variables).
pub fn evaluate_rule(
    rule: &Rule,
    relations: &RefRelationStore,
    order: &dyn VariableOrder,
) -> Result<Vec<Vec<RefValue>>, RefEvalError> {
    let hg = HypergraphRule::from_rule(rule);
    if let Eligibility::Ineligible(boundaries) = analyze(&hg, ExecutorContext::HashFallback) {
        return Err(RefEvalError::Ineligible(boundaries));
    }

    // Split the body in a single pass; `is`/univ literals abort before any
    // relation is looked up.
    let mut positive_atoms: Vec<&Atom> = Vec::new();
    let mut comparisons: Vec<&Comparison> = Vec::new();
    for literal in rule.body.iter() {
        match literal {
            BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => {
                return Err(RefEvalError::IsExprNotSupported);
            }
            BodyLiteral::Positive(atom) => positive_atoms.push(atom),
            BodyLiteral::Comparison(cmp) => comparisons.push(cmp),
            _ => {}
        }
    }

    // Compile atoms in body order; the first failing atom aborts evaluation.
    let atom_specs: Vec<AtomSpec<'_>> = positive_atoms
        .iter()
        .zip(hg.hyperedges.iter())
        .map(|(atom, edge)| {
            relations
                .get(&atom.predicate)
                .ok_or_else(|| RefEvalError::MissingRelation(atom.predicate.clone()))
                .and_then(|rel| {
                    validate_relation_rows(&atom.predicate, rel)
                        .and_then(|()| AtomSpec::build(atom, edge, rel))
                })
        })
        .collect::<Result<_, _>>()?;

    let domains = build_domains(&hg, &atom_specs);
    let var_order: Vec<VertexId> = order.order(&hg);
    debug_assert_eq!(var_order.len(), hg.vertex_count());

    let mut binding = BTreeMap::new();
    let mut results = Vec::new();
    enumerate(
        &hg,
        &atom_specs,
        &comparisons,
        &domains,
        &var_order,
        0,
        &mut binding,
        &rule.head,
        &mut results,
    )?;

    // Rows that compare equal are identical, so an unstable sort gives the same
    // result as a stable one once duplicates are removed.
    results.sort_unstable();
    results.dedup();
    Ok(results)
}
/// A positive body atom compiled against the relation it reads from.
struct AtomSpec<'a> {
    /// Relation whose rows the atom is matched against.
    relation: &'a RefRelation,
    /// How each column is matched; `positions[i]` governs column `i`.
    positions: Vec<PositionSpec>,
}
/// Matching rule for one column of a compiled atom.
enum PositionSpec {
    /// The cell must equal the value currently bound to this hypergraph vertex.
    Vertex(VertexId),
    /// The cell must equal this literal, already coerced to the column's type.
    Constant(RefValue),
    /// Any cell value is accepted.
    Wildcard,
}
impl<'a> AtomSpec<'a> {
fn build(
atom: &Atom,
edge: &Hyperedge,
relation: &'a RefRelation,
) -> Result<Self, RefEvalError> {
if atom.terms.len() != relation.schema.len() {
return Err(RefEvalError::RelationArityMismatch {
predicate: atom.predicate.clone(),
atom_arity: atom.terms.len(),
relation_arity: relation.schema.len(),
});
}
let mut positions = Vec::with_capacity(atom.terms.len());
for (i, term) in atom.terms.iter().enumerate() {
let p = match term {
Term::Variable(_) => {
match edge.vertex_positions.get(i).and_then(|p| *p) {
Some(vid) => PositionSpec::Vertex(vid),
None => {
PositionSpec::Wildcard
}
}
}
Term::Anonymous => PositionSpec::Wildcard,
Term::Integer(n) => {
let coerced =
coerce_integer_constant(*n, relation.schema[i], &atom.predicate, i)?;
PositionSpec::Constant(coerced)
}
Term::String(s) => match relation.schema[i] {
ScalarType::Symbol => PositionSpec::Constant(RefValue::Symbol(s.clone())),
other => {
return Err(RefEvalError::ConstantTypeMismatch {
predicate: atom.predicate.clone(),
position: i,
expected: other,
got: format!("string {s:?}"),
});
}
},
Term::Symbol(_id) => {
return Err(RefEvalError::ConstantTypeMismatch {
predicate: atom.predicate.clone(),
position: i,
expected: relation.schema[i],
got: "interned Symbol(u32) — use Term::String against a Symbol column"
.to_string(),
});
}
Term::Float(f) => {
return Err(RefEvalError::ConstantTypeMismatch {
predicate: atom.predicate.clone(),
position: i,
expected: relation.schema[i],
got: format!("float {f}"),
});
}
Term::Aggregate(_) => {
return Err(RefEvalError::ConstantTypeMismatch {
predicate: atom.predicate.clone(),
position: i,
expected: relation.schema[i],
got: "aggregate in body atom".to_string(),
});
}
Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
return Err(RefEvalError::ConstantTypeMismatch {
predicate: atom.predicate.clone(),
position: i,
expected: relation.schema[i],
got: format!("unsupported v0.8.5 term in body atom: {term:?}"),
});
}
};
positions.push(p);
}
Ok(AtomSpec {
relation,
positions,
})
}
}
/// Convert an integer literal into the `RefValue` for a column of type `target`.
///
/// `U32` and `I32` require the value to fit the narrower type, `U64` requires it
/// to be non-negative, and `I64` accepts every value. Boolean, float and symbol
/// columns reject integer literals.
///
/// # Errors
/// `ConstantTypeMismatch` for `predicate`/`position`, with a message describing
/// why the literal was rejected.
fn coerce_integer_constant(
    n: i64,
    target: ScalarType,
    predicate: &str,
    position: usize,
) -> Result<RefValue, RefEvalError> {
    let reject = |got: String| -> Result<RefValue, RefEvalError> {
        Err(RefEvalError::ConstantTypeMismatch {
            predicate: predicate.to_string(),
            position,
            expected: target,
            got,
        })
    };
    match target {
        // Guards establish the range, so the `as` casts below cannot truncate.
        ScalarType::U32 if (0..=i64::from(u32::MAX)).contains(&n) => Ok(RefValue::U32(n as u32)),
        ScalarType::U32 => reject(format!("integer {n} out of range for U32")),
        ScalarType::U64 if n >= 0 => Ok(RefValue::U64(n as u64)),
        ScalarType::U64 => reject(format!("negative integer {n} for U64")),
        ScalarType::I32 if (i64::from(i32::MIN)..=i64::from(i32::MAX)).contains(&n) => {
            Ok(RefValue::I32(n as i32))
        }
        ScalarType::I32 => reject(format!("integer {n} out of range for I32")),
        ScalarType::I64 => Ok(RefValue::I64(n)),
        ScalarType::Bool | ScalarType::F32 | ScalarType::F64 | ScalarType::Symbol => {
            reject(format!("integer {n} for {target:?}"))
        }
    }
}
/// Check that every row of `relation` conforms to its schema.
///
/// Rows are checked in order and the first problem found is reported:
/// * `RelationRowArityMismatch` for a row whose length differs from the schema;
/// * `RelationValueTypeMismatch` for the left-most cell whose `RefValue` variant
///   does not correspond to its column's `ScalarType`.
fn validate_relation_rows(predicate: &str, relation: &RefRelation) -> Result<(), RefEvalError> {
    let schema = &relation.schema;
    relation
        .rows
        .iter()
        .enumerate()
        .try_for_each(|(row_index, row)| {
            if row.len() != schema.len() {
                return Err(RefEvalError::RelationRowArityMismatch {
                    predicate: predicate.to_string(),
                    row_index,
                    row_len: row.len(),
                    schema_len: schema.len(),
                });
            }
            let bad_column = row
                .iter()
                .zip(schema.iter())
                .position(|(value, &expected)| !ref_value_matches_scalar_type(value, expected));
            match bad_column {
                None => Ok(()),
                Some(column) => Err(RefEvalError::RelationValueTypeMismatch {
                    predicate: predicate.to_string(),
                    row_index,
                    column,
                    expected: schema[column],
                    got: format!("{:?}", row[column]),
                }),
            }
        })
}
/// Whether `value`'s variant represents a cell of column type `expected`.
///
/// `F32` and `F64` have no `RefValue` counterpart, so they never match.
fn ref_value_matches_scalar_type(value: &RefValue, expected: ScalarType) -> bool {
    match value {
        RefValue::U32(_) => matches!(expected, ScalarType::U32),
        RefValue::U64(_) => matches!(expected, ScalarType::U64),
        RefValue::I32(_) => matches!(expected, ScalarType::I32),
        RefValue::I64(_) => matches!(expected, ScalarType::I64),
        RefValue::Bool(_) => matches!(expected, ScalarType::Bool),
        RefValue::Symbol(_) => matches!(expected, ScalarType::Symbol),
    }
}
/// Compute the candidate values for every hypergraph vertex.
///
/// A vertex's domain is the sorted, deduplicated set of values found in any
/// column it occupies across all compiled atoms. A vertex that occupies no
/// column gets an empty domain.
fn build_domains(hg: &HypergraphRule, atom_specs: &[AtomSpec<'_>]) -> Vec<Vec<RefValue>> {
    let mut domains = vec![Vec::<RefValue>::new(); hg.vertex_count()];

    // (atom, column, vertex index) for every column bound to a vertex.
    let vertex_columns = atom_specs.iter().flat_map(|spec| {
        spec.positions
            .iter()
            .enumerate()
            .filter_map(move |(column, pos)| match pos {
                PositionSpec::Vertex(VertexId(idx)) => Some((spec, column, *idx)),
                _ => None,
            })
    });

    for (spec, column, idx) in vertex_columns {
        domains[idx].extend(spec.relation.rows.iter().map(|row| row[column].clone()));
    }

    // Values that compare equal are identical, so an unstable sort followed by
    // dedup yields the same set as a stable sort would.
    for domain in domains.iter_mut() {
        domain.sort_unstable();
        domain.dedup();
    }
    domains
}
/// Depth-first enumeration of vertex assignments.
///
/// While `depth` has not reached `var_order.len()`, vertex `var_order[depth]` is
/// bound to each value of its domain in turn and the search recurses one level
/// deeper. The binding is removed after each successful recursive call. Once
/// every vertex in `var_order` is bound, the assignment is accepted only if every
/// atom has a matching row and every comparison holds. Accepted assignments are
/// projected onto the head terms and appended to `results`, duplicates included.
///
/// # Errors
/// Propagates comparison errors. Reports `UnboundHeadVariable` for a head
/// variable without a vertex or binding, and for an unsupported head term.
// The recursion state is threaded through explicitly, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn enumerate(
    hg: &HypergraphRule,
    atom_specs: &[AtomSpec<'_>],
    comparisons: &[&Comparison],
    domains: &[Vec<RefValue>],
    var_order: &[VertexId],
    depth: usize,
    binding: &mut BTreeMap<VertexId, RefValue>,
    head: &Atom,
    results: &mut Vec<Vec<RefValue>>,
) -> Result<(), RefEvalError> {
    if depth != var_order.len() {
        let vid = var_order[depth];
        let VertexId(idx) = vid;
        for candidate in &domains[idx] {
            binding.insert(vid, candidate.clone());
            enumerate(
                hg,
                atom_specs,
                comparisons,
                domains,
                var_order,
                depth + 1,
                binding,
                head,
                results,
            )?;
            binding.remove(&vid);
        }
        return Ok(());
    }

    // Every vertex is bound: test the complete assignment.
    let bound: &BTreeMap<VertexId, RefValue> = binding;
    if !atom_specs.iter().all(|spec| atom_has_matching_row(spec, bound)) {
        return Ok(());
    }
    for cmp in comparisons {
        if !evaluate_comparison(cmp, bound, hg)? {
            return Ok(());
        }
    }

    let row = head
        .terms
        .iter()
        .map(|term| -> Result<RefValue, RefEvalError> {
            match term {
                Term::Variable(name) => {
                    let unbound = || RefEvalError::UnboundHeadVariable(name.clone());
                    let idx = hg
                        .vertices
                        .iter()
                        .position(|v| &v.name == name)
                        .ok_or_else(unbound)?;
                    bound.get(&VertexId(idx)).cloned().ok_or_else(unbound)
                }
                Term::Integer(n) => Ok(RefValue::I64(*n)),
                Term::String(s) => Ok(RefValue::Symbol(s.clone())),
                Term::Anonymous
                | Term::Symbol(_)
                | Term::Float(_)
                | Term::Aggregate(_)
                | Term::List(_)
                | Term::Cons { .. }
                | Term::Compound { .. }
                | Term::PredRef(_) => Err(RefEvalError::UnboundHeadVariable(format!(
                    "unsupported head term: {term:?}"
                ))),
            }
        })
        .collect::<Result<Vec<_>, _>>()?;
    results.push(row);
    Ok(())
}
/// Whether some row of the atom's relation agrees with the current binding.
///
/// A row matches when each `Vertex` column equals that vertex's bound value and
/// each `Constant` column equals the literal. `Wildcard` columns accept anything.
/// If a `Vertex` column's vertex is unbound, no row can match and the result is
/// `false`.
fn atom_has_matching_row(spec: &AtomSpec<'_>, binding: &BTreeMap<VertexId, RefValue>) -> bool {
    spec.relation.rows.iter().any(|row| {
        spec.positions.iter().enumerate().all(|(column, pos)| {
            let cell = &row[column];
            match pos {
                PositionSpec::Vertex(vid) => binding.get(vid) == Some(cell),
                PositionSpec::Constant(expected) => expected == cell,
                PositionSpec::Wildcard => true,
            }
        })
    })
}
fn evaluate_comparison(
cmp: &Comparison,
binding: &BTreeMap<VertexId, RefValue>,
hg: &HypergraphRule,
) -> Result<bool, RefEvalError> {
let lhs = resolve_term_for_comparison(&cmp.left, binding, hg)?;
let rhs = resolve_term_for_comparison(&cmp.right, binding, hg)?;
apply_comparison_op(&lhs, &rhs, cmp.op)
}
fn resolve_term_for_comparison(
term: &Term,
binding: &BTreeMap<VertexId, RefValue>,
hg: &HypergraphRule,
) -> Result<RefValue, RefEvalError> {
match term {
Term::Variable(name) => {
let vid = hg
.vertices
.iter()
.position(|v| &v.name == name)
.map(VertexId)
.ok_or_else(|| RefEvalError::UnboundHeadVariable(name.clone()))?;
binding
.get(&vid)
.cloned()
.ok_or_else(|| RefEvalError::UnboundHeadVariable(name.clone()))
}
Term::Integer(n) => Ok(RefValue::I64(*n)),
Term::String(s) => Ok(RefValue::Symbol(s.clone())),
Term::Anonymous
| Term::Symbol(_)
| Term::Float(_)
| Term::Aggregate(_)
| Term::List(_)
| Term::Cons { .. }
| Term::Compound { .. }
| Term::PredRef(_) => Err(RefEvalError::ComparisonTypeMismatch {
left: format!("{term:?}"),
right: "<n/a>".to_string(),
op: CompOp::Eq,
}),
}
}
/// Compare two resolved values with `op`.
///
/// Supported pairings:
/// * symbols with symbols, lexicographically;
/// * booleans with booleans (`false < true`);
/// * any two integer values (`U32`/`U64`/`I32`/`I64`, in any mix), numerically,
///   after widening to `i128`, which represents all four types exactly.
///
/// # Errors
/// `ComparisonTypeMismatch`, carrying both values' `Debug` text and `op`, for any
/// other pairing (e.g. symbol vs integer, or boolean vs integer).
fn apply_comparison_op(lhs: &RefValue, rhs: &RefValue, op: CompOp) -> Result<bool, RefEvalError> {
    fn widen(value: &RefValue) -> Option<i128> {
        match *value {
            RefValue::U32(n) => Some(i128::from(n)),
            RefValue::U64(n) => Some(i128::from(n)),
            RefValue::I32(n) => Some(i128::from(n)),
            RefValue::I64(n) => Some(i128::from(n)),
            RefValue::Bool(_) | RefValue::Symbol(_) => None,
        }
    }

    let ordering = match (lhs, rhs) {
        (RefValue::Symbol(a), RefValue::Symbol(b)) => Some(a.cmp(b)),
        (RefValue::Bool(a), RefValue::Bool(b)) => Some(a.cmp(b)),
        _ => widen(lhs).zip(widen(rhs)).map(|(l, r)| l.cmp(&r)),
    };

    match ordering {
        Some(ord) => Ok(match op {
            CompOp::Eq => ord.is_eq(),
            CompOp::Ne => ord.is_ne(),
            CompOp::Lt => ord.is_lt(),
            CompOp::Le => ord.is_le(),
            CompOp::Gt => ord.is_gt(),
            CompOp::Ge => ord.is_ge(),
        }),
        None => Err(RefEvalError::ComparisonTypeMismatch {
            left: format!("{lhs:?}"),
            right: format!("{rhs:?}"),
            op,
        }),
    }
}