use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use xlog_core::{Result, XlogError};
use xlog_logic::ast::{
AggExpr, AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Evidence, ProbQuery, Program, Rule, Term,
};
use xlog_logic::stratify::{
analyze_stratification, build_dependency_graph, find_sccs_for_lowering, stratify,
};
use crate::wfs::{evaluate_wfs_rules, WfsAtom, WfsConfig, WfsLiteral, WfsRule};
use crate::aggregates::AggState;
use crate::pir::{ChoiceVarId, LeafId, PirGraph, PirNodeId};
/// A ground constant usable as an argument of a ground atom.
///
/// The derived ordering compares the variant first (in declaration order) and
/// the payload second, so the order of the variants below is significant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Value {
    /// Signed 64-bit integer constant.
    I64(i64),
    /// Floating-point constant carried as a `u64` payload so the type can be
    /// `Eq`, `Hash` and `Ord`; comparison is on the integer payload.
    F64(u64),
    /// Symbol identifier.
    Symbol(u32),
    /// Owned string constant.
    String(String),
}

/// Wraps an integer as [`Value::I64`].
impl From<i64> for Value {
    fn from(int: i64) -> Self {
        Value::I64(int)
    }
}

/// Wraps a symbol identifier as [`Value::Symbol`].
impl From<u32> for Value {
    fn from(symbol: u32) -> Self {
        Value::Symbol(symbol)
    }
}

/// Wraps an owned string as [`Value::String`].
impl From<String> for Value {
    fn from(text: String) -> Self {
        Value::String(text)
    }
}
/// A fully ground atom: a predicate name together with constant arguments.
///
/// The derived ordering compares `predicate` first and `args` second.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct GroundAtom {
    /// Name of the predicate.
    pub predicate: String,
    /// Constant arguments, in positional order.
    pub args: Vec<Value>,
}

impl GroundAtom {
    /// Builds a ground atom from anything convertible into a predicate name
    /// and a list of argument values.
    pub fn new(predicate: impl Into<String>, args: Vec<Value>) -> Self {
        let predicate: String = predicate.into();
        GroundAtom { predicate, args }
    }
}
/// Records which annotated disjunction a choice variable was created for.
#[derive(Debug, Clone, PartialEq)]
pub struct ChoiceSource {
/// The explicit choices of the originating disjunction, each paired with its
/// declared probability.
pub choices: Vec<(GroundAtom, f64)>,
/// Position of the choice variable in the chain of decisions built for the
/// disjunction.
pub choice_index: usize,
/// Optional identifier of the originating source; the code in this file always
/// stores `None` here.
pub source_id: Option<usize>,
}
/// Outcome of the aggregate-lifting decision for one aggregate group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateLiftStatus {
    /// The lifted implementation was used.
    Fired,
    /// Exact enumeration of outcomes was used instead of lifting.
    FallbackExactEnumeration,
    /// Lifting was declined.
    Declined,
}

impl AggregateLiftStatus {
    /// Returns the stable snake_case label for this status.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Fired => "fired",
            Self::FallbackExactEnumeration => "fallback_exact_enumeration",
            Self::Declined => "declined",
        }
    }
}
/// Diagnostic record describing how one aggregate group was handled during
/// provenance extraction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateLiftReport {
/// Predicate of the aggregate rule head.
pub predicate: String,
/// Values of the grouping key for the reported group.
pub group_key: Vec<Value>,
/// Aggregate operator, as text.
pub operator: String,
/// Textual description of where the finite domain came from.
// NOTE(review): the exact vocabulary is produced by the report-recording helper
// outside this view — confirm there.
pub finite_domain_source: String,
/// Number of rows whose provenance is constant-true.
pub deterministic_rows: usize,
/// Number of rows whose provenance is not a constant.
pub uncertain_rows: usize,
/// Size of the finite domain considered (presumably derived from the row counts
/// — TODO confirm against the recording helper).
pub domain_size: usize,
/// Cap on uncertain rows that applied to this group.
pub cap: usize,
/// Whether lifting fired, fell back, or was declined.
pub status: AggregateLiftStatus,
/// Human-readable explanation for `status`.
pub reason: String,
/// Number of outcomes a naive enumeration would visit.
pub naive_outcomes: u128,
/// Number of dynamic-programming states used by the lifted evaluation.
pub dynamic_programming_states: usize,
}
/// The tuples currently derived for one predicate, each paired with the PIR
/// node that encodes its provenance formula.
#[derive(Debug, Clone)]
struct Relation {
    tuples: BTreeMap<Vec<Value>, PirNodeId>,
}

impl Relation {
    /// Creates a relation with no tuples.
    fn new() -> Self {
        let tuples = BTreeMap::new();
        Relation { tuples }
    }

    /// Looks up the provenance node stored for `tuple`, if any.
    fn get(&self, tuple: &[Value]) -> Option<PirNodeId> {
        self.tuples.get(tuple).map(|node| *node)
    }

    /// Returns `true` when the relation holds no tuples.
    fn is_empty(&self) -> bool {
        self.tuples.is_empty()
    }

    /// Disjoins `formula` into the provenance stored for `tuple`; a tuple not
    /// yet present starts from constant false.
    fn insert_or(&mut self, tuple: Vec<Value>, formula: PirNodeId, builder: &mut PirBuilder) {
        let previous = match self.tuples.get(&tuple) {
            Some(&node) => node,
            None => builder.const_false(),
        };
        let merged = builder.or(vec![previous, formula]);
        self.tuples.insert(tuple, merged);
    }
}
/// Structural key used to intern PIR nodes so that identical sub-formulas share
/// a single node id.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PirKey {
    /// Boolean constant.
    Const(bool),
    /// Positive literal over a leaf.
    Lit(LeafId),
    /// Negated literal over a leaf.
    NegLit(LeafId),
    /// Conjunction over the listed children.
    And(Vec<PirNodeId>),
    /// Disjunction over the listed children.
    Or(Vec<PirNodeId>),
    /// Decision on a choice variable with one child per branch.
    Decision {
        var: ChoiceVarId,
        child_false: PirNodeId,
        child_true: PirNodeId,
    },
}

impl Hash for PirKey {
    /// Feeds a one-byte variant tag followed by the variant's payload into the
    /// hasher. Tuples hash their elements in order, so each arm emits the tag
    /// first and then the fields.
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Self::Const(flag) => (0u8, flag).hash(state),
            Self::Lit(leaf) => (1u8, leaf).hash(state),
            Self::NegLit(leaf) => (5u8, leaf).hash(state),
            Self::And(operands) => (2u8, operands).hash(state),
            Self::Or(operands) => (3u8, operands).hash(state),
            Self::Decision {
                var,
                child_false,
                child_true,
            } => (4u8, var, child_false, child_true).hash(state),
        }
    }
}
/// Hash-consing front end for a [`PirGraph`]: every constructor first consults
/// the intern table so structurally equal nodes are created only once.
#[derive(Debug)]
struct PirBuilder {
    pir: PirGraph,
    intern: HashMap<PirKey, PirNodeId>,
    const_true: PirNodeId,
    const_false: PirNodeId,
}

impl PirBuilder {
    /// Creates a builder over a fresh graph with both constants pre-interned.
    fn new() -> Self {
        let mut pir = PirGraph::new();
        let const_true = pir.const_true();
        let const_false = pir.const_false();
        let intern = HashMap::from([
            (PirKey::Const(true), const_true),
            (PirKey::Const(false), const_false),
        ]);
        PirBuilder {
            pir,
            intern,
            const_true,
            const_false,
        }
    }

    /// Consumes the builder and hands back the underlying graph.
    fn finish(self) -> PirGraph {
        let PirBuilder { pir, .. } = self;
        pir
    }

    /// Node id of the constant-true formula.
    fn const_true(&self) -> PirNodeId {
        self.const_true
    }

    /// Node id of the constant-false formula.
    fn const_false(&self) -> PirNodeId {
        self.const_false
    }

    /// Returns the interned positive literal for `leaf`, creating it on first use.
    fn lit(&mut self, leaf: LeafId) -> PirNodeId {
        let key = PirKey::Lit(leaf);
        if let Some(existing) = self.intern.get(&key).copied() {
            return existing;
        }
        let fresh = self.pir.lit(leaf);
        self.intern.insert(key, fresh);
        fresh
    }

    /// Returns the interned negated literal for `leaf`, creating it on first use.
    fn neg_lit(&mut self, leaf: LeafId) -> PirNodeId {
        let key = PirKey::NegLit(leaf);
        if let Some(existing) = self.intern.get(&key).copied() {
            return existing;
        }
        let fresh = self.pir.neg_lit(leaf);
        self.intern.insert(key, fresh);
        fresh
    }

    /// Builds a conjunction with light simplification: constant-true children
    /// are dropped, a constant-false child collapses the result to false, the
    /// remaining children are sorted by id and deduplicated, and a single
    /// surviving child is returned as-is.
    fn and(&mut self, mut children: Vec<PirNodeId>) -> PirNodeId {
        let (top, bottom) = (self.const_true, self.const_false);
        children.retain(|child| *child != top);
        if children.iter().any(|child| *child == bottom) {
            return bottom;
        }
        match children.len() {
            0 => return top,
            1 => return children[0],
            _ => {}
        }
        children.sort_by_key(|id| id.as_u32());
        children.dedup();
        if let [only] = children.as_slice() {
            return *only;
        }
        let key = PirKey::And(children.clone());
        if let Some(existing) = self.intern.get(&key).copied() {
            return existing;
        }
        let fresh = self.pir.and(children);
        self.intern.insert(key, fresh);
        fresh
    }

    /// Builds a disjunction with light simplification: constant-false children
    /// are dropped, a constant-true child collapses the result to true, the
    /// remaining children are sorted by id and deduplicated, and a single
    /// surviving child is returned as-is.
    fn or(&mut self, mut children: Vec<PirNodeId>) -> PirNodeId {
        let (top, bottom) = (self.const_true, self.const_false);
        children.retain(|child| *child != bottom);
        if children.iter().any(|child| *child == top) {
            return top;
        }
        match children.len() {
            0 => return bottom,
            1 => return children[0],
            _ => {}
        }
        children.sort_by_key(|id| id.as_u32());
        children.dedup();
        if let [only] = children.as_slice() {
            return *only;
        }
        let key = PirKey::Or(children.clone());
        if let Some(existing) = self.intern.get(&key).copied() {
            return existing;
        }
        let fresh = self.pir.or(children);
        self.intern.insert(key, fresh);
        fresh
    }

    /// Builds a decision node on `var`; when both branches are the same node
    /// that node is returned directly.
    fn decision(
        &mut self,
        var: ChoiceVarId,
        child_false: PirNodeId,
        child_true: PirNodeId,
    ) -> PirNodeId {
        if child_true == child_false {
            return child_true;
        }
        let key = PirKey::Decision {
            var,
            child_false,
            child_true,
        };
        if let Some(existing) = self.intern.get(&key).copied() {
            return existing;
        }
        let fresh = self.pir.decision(var, child_false, child_true);
        self.intern.insert(key, fresh);
        fresh
    }

    /// Returns the formula that holds exactly when `var` takes the value
    /// `is_true`, expressed as a decision between the two constants.
    fn choice_lit(&mut self, var: ChoiceVarId, is_true: bool) -> PirNodeId {
        let (top, bottom) = (self.const_true, self.const_false);
        let (when_false, when_true) = if is_true { (bottom, top) } else { (top, bottom) };
        self.decision(var, when_false, when_true)
    }
}
/// Result of provenance extraction: the PIR graph plus the tables needed to
/// interpret it.
#[derive(Debug)]
pub struct Provenance {
    /// Graph holding every provenance formula.
    pub pir: PirGraph,
    /// Probability attached to each probabilistic-fact leaf.
    pub leaf_probs: BTreeMap<LeafId, f64>,
    /// For each choice variable, the pair `(p_true, p_false)`.
    pub choice_probs: BTreeMap<ChoiceVarId, (f64, f64)>,
    /// Provenance node for every derived ground atom.
    tuple_formulas: BTreeMap<GroundAtom, PirNodeId>,
    /// Ground atoms named by the program's probabilistic queries.
    pub queries: Vec<GroundAtom>,
    /// Evidence atoms together with their observed truth value.
    pub evidence: Vec<(GroundAtom, bool)>,
    /// The ground atom each leaf was created for.
    pub leaf_atoms: BTreeMap<LeafId, GroundAtom>,
    /// The annotated-disjunction origin of each choice variable.
    pub choice_sources: BTreeMap<ChoiceVarId, ChoiceSource>,
    /// Reports collected while evaluating aggregate heads.
    pub aggregate_lifting: Vec<AggregateLiftReport>,
}

impl Provenance {
    /// Returns the provenance node of the atom `predicate(args...)`, or `None`
    /// when no formula was recorded for it.
    pub fn query_formula(&self, predicate: &str, args: &[Value]) -> Option<PirNodeId> {
        let wanted = GroundAtom::new(predicate, args.to_vec());
        self.tuple_formulas.get(&wanted).copied()
    }

    /// Returns the ground atom a leaf stands for, if the leaf is known.
    pub fn leaf_atom(&self, leaf: LeafId) -> Option<&GroundAtom> {
        self.leaf_atoms.get(&leaf)
    }

    /// Returns the recorded origin of a choice variable, if known.
    pub fn choice_source(&self, var: ChoiceVarId) -> Option<&ChoiceSource> {
        self.choice_sources.get(&var)
    }

    /// Iterates over every derived atom with its provenance node, in atom order.
    pub fn atoms_with_formulas(&self) -> impl Iterator<Item = (&GroundAtom, PirNodeId)> + '_ {
        self.tuple_formulas
            .iter()
            .map(|(atom, node)| (atom, *node))
    }
}
/// Parses `source` as a program and extracts its provenance.
///
/// # Errors
/// Propagates any parse error and any error from [`extract_from_program`].
pub fn extract_from_source(source: &str) -> Result<Provenance> {
    let parsed = xlog_logic::parse_program(source)?;
    let provenance = extract_from_program(&parsed)?;
    Ok(provenance)
}
/// Builds the provenance of every derivable ground atom of `program`.
///
/// The function seeds a per-predicate store from plain facts, probabilistic
/// facts and annotated disjunctions, then evaluates the rules one strongly
/// connected component at a time, and finally flattens the store into a
/// [`Provenance`] together with the program's queries and evidence.
///
/// # Errors
/// Returns an error when stratification fails, when a probability is invalid,
/// when a fact/query/evidence atom is not ground, when an annotated disjunction
/// has no choices, when a leaf id overflows, or when rule evaluation fails.
pub fn extract_from_program(program: &Program) -> Result<Provenance> {
// Only the error from stratification matters here; its result is discarded.
let _ = stratify(program)?;
let mut builder = PirBuilder::new();
let mut leaf_probs: BTreeMap<LeafId, f64> = BTreeMap::new();
let mut choice_probs: BTreeMap<ChoiceVarId, (f64, f64)> = BTreeMap::new();
let mut leaf_atoms: BTreeMap<LeafId, GroundAtom> = BTreeMap::new();
let mut choice_sources: BTreeMap<ChoiceVarId, ChoiceSource> = BTreeMap::new();
let mut aggregate_lifting: Vec<AggregateLiftReport> = Vec::new();
// Predicate name -> relation of (tuple, provenance node).
let mut store: BTreeMap<String, Relation> = BTreeMap::new();
// Plain facts are unconditionally true.
for fact in program.facts() {
let key = atom_key_from_ground_atom(&fact.head)?;
let rel = store
.entry(key.predicate.clone())
.or_insert_with(Relation::new);
rel.insert_or(key.args.clone(), builder.const_true(), &mut builder);
}
// Each probabilistic fact gets its own fresh leaf, even when the same atom is
// declared more than once; repeated declarations are disjoined by `insert_or`.
let mut next_leaf: u32 = 0;
for pf in &program.prob_facts {
validate_prob(pf.prob, "probabilistic fact")?;
let key = atom_key_from_ground_atom(&pf.atom)?;
let leaf = LeafId::new(next_leaf);
next_leaf = next_leaf.checked_add(1).ok_or_else(|| {
XlogError::Compilation("probabilistic fact leaf id overflow".to_string())
})?;
leaf_probs.insert(leaf, pf.prob);
leaf_atoms.insert(leaf, key.clone());
let rel = store
.entry(key.predicate.clone())
.or_insert_with(Relation::new);
rel.insert_or(key.args.clone(), builder.lit(leaf), &mut builder);
}
// Annotated disjunctions are compiled into chains of choice variables; each
// explicit choice receives the formula selecting exactly that outcome.
let mut next_choice: u32 = 0;
for ad in &program.annotated_disjunctions {
if ad.choices.is_empty() {
return Err(XlogError::Compilation(
"Annotated disjunction must contain at least one choice".to_string(),
));
}
let (vars, outcome_formulas) = compile_annotated_disjunction(
ad,
&mut next_choice,
&mut choice_probs,
&mut choice_sources,
&mut builder,
)?;
// The variable list is not needed here; it is already recorded in the maps.
let _ = vars;
for (pf, formula) in ad.choices.iter().zip(outcome_formulas) {
let key = atom_key_from_ground_atom(&pf.atom)?;
let rel = store
.entry(key.predicate.clone())
.or_insert_with(Relation::new);
rel.insert_or(key.args.clone(), formula, &mut builder);
}
}
// Make sure every predicate in the dependency graph has a (possibly empty)
// relation before rules are evaluated.
let graph = build_dependency_graph(program);
for pred in &graph.predicates {
store.entry(pred.clone()).or_insert_with(Relation::new);
}
let strat_result = analyze_stratification(program);
let sccs = find_sccs_for_lowering(&graph);
// Names of all predicates that belong to a component flagged as non-monotone.
let non_monotone_scc_preds: std::collections::HashSet<String> = strat_result
.sccs
.iter()
.enumerate()
.filter(|(i, _)| strat_result.non_monotone_sccs.contains(i))
.flat_map(|(_, scc)| scc.iter().cloned())
.collect();
// Index the proper rules by head predicate.
let mut rules_by_head: BTreeMap<String, Vec<Rule>> = BTreeMap::new();
for rule in program.proper_rules() {
rules_by_head
.entry(rule.head.predicate.clone())
.or_default()
.push(rule.clone());
}
// Evaluate components in the order produced by `find_sccs_for_lowering`
// (presumably dependency order — TODO confirm against that function).
for scc in sccs {
let mut scc_rules: Vec<Rule> = Vec::new();
for pred in &scc {
if let Some(rules) = rules_by_head.get(pred) {
scc_rules.extend(rules.iter().cloned());
}
}
if scc_rules.is_empty() {
continue;
}
// Dispatch: non-monotone components go through the WFS evaluator, recursive
// ones through the fixpoint loop, everything else through a single pass.
let is_non_monotone = scc.iter().any(|p| non_monotone_scc_preds.contains(p));
if is_non_monotone {
eval_non_monotone_scc_with_wfs(&scc, &scc_rules, &mut store, &mut builder)?;
} else {
let recursive = is_recursive_scc(&scc, &scc_rules);
if recursive {
eval_recursive_scc(
&scc,
&scc_rules,
&mut store,
&mut builder,
&mut aggregate_lifting,
)?;
} else {
eval_non_recursive_scc(
&scc_rules,
&mut store,
&mut builder,
&mut aggregate_lifting,
)?;
}
}
}
// Flatten the per-predicate store into a single atom -> formula map.
let mut tuple_formulas: BTreeMap<GroundAtom, PirNodeId> = BTreeMap::new();
for (pred, rel) in &store {
for (tuple, formula) in &rel.tuples {
tuple_formulas.insert(GroundAtom::new(pred.clone(), tuple.clone()), *formula);
}
}
// Queries and evidence must be ground; they are recorded but not evaluated here.
let mut queries: Vec<GroundAtom> = Vec::new();
for ProbQuery { atom } in &program.prob_queries {
queries.push(atom_key_from_ground_atom(atom)?);
}
let mut evidence: Vec<(GroundAtom, bool)> = Vec::new();
for Evidence { atom, value } in &program.evidence {
evidence.push((atom_key_from_ground_atom(atom)?, *value));
}
Ok(Provenance {
pir: builder.finish(),
leaf_probs,
choice_probs,
tuple_formulas,
queries,
evidence,
leaf_atoms,
choice_sources,
aggregate_lifting,
})
}
/// Checks that `p` is a probability, i.e. a non-NaN number in `[0, 1]`.
///
/// `what` names the thing being validated and is embedded in the error message.
///
/// # Errors
/// Returns `XlogError::Compilation` when `p` is NaN or outside `[0, 1]`.
pub(crate) fn validate_prob(p: f64, what: &str) -> Result<()> {
    let in_unit_interval = !p.is_nan() && (0.0..=1.0).contains(&p);
    if in_unit_interval {
        return Ok(());
    }
    let message = format!(
        "Invalid probability {} for {} (expected 0<=p<=1)",
        p, what
    );
    Err(XlogError::Compilation(message))
}
pub(crate) fn atom_key_from_ground_atom(atom: &Atom) -> Result<GroundAtom> {
let mut args = Vec::with_capacity(atom.terms.len());
for term in &atom.terms {
if !term.is_constant() {
return Err(XlogError::Compilation(format!(
"Expected ground atom, found non-constant term in {}",
atom.predicate
)));
}
args.push(value_from_term(term)?);
}
Ok(GroundAtom::new(atom.predicate.clone(), args))
}
/// Converts a constant term into a [`Value`].
///
/// Integers, floats (stored as their bit pattern), strings and symbols are
/// supported.
///
/// # Errors
/// Returns `XlogError::Compilation` for variables, anonymous terms and
/// aggregates, and a dedicated "not supported" error for list, cons, compound
/// and predicate-reference terms.
pub(crate) fn value_from_term(term: &Term) -> Result<Value> {
    // Supported and non-constant forms return immediately; the remaining arms
    // only pick the label used in the shared "unsupported term form" error.
    let unsupported_kind = match term {
        Term::Integer(i) => return Ok(Value::I64(*i)),
        Term::Float(f) => return Ok(Value::F64(f.to_bits())),
        Term::String(s) => return Ok(Value::String(s.clone())),
        Term::Symbol(id) => return Ok(Value::Symbol(*id)),
        Term::Variable(_) | Term::Anonymous | Term::Aggregate(_) => {
            return Err(XlogError::Compilation(
                "Non-constant term cannot be converted to a value".to_string(),
            ));
        }
        Term::List(_) => "list",
        Term::Cons { .. } => "cons",
        Term::Compound { .. } => "compound",
        Term::PredRef(_) => "predref",
    };
    Err(v085_prob_term_error("value conversion", unsupported_kind))
}
/// Builds the compilation error reported when a term form named `kind` is
/// encountered in the probabilistic `context` where it is not supported.
fn v085_prob_term_error(context: &str, kind: &str) -> XlogError {
    let message = format!(
        "v0.8.5 term form '{}' is parsed but not supported in probabilistic {} before its G085 implementation node",
        kind, context
    );
    XlogError::Compilation(message)
}
/// Compiles one annotated disjunction into a chain of binary choice variables.
///
/// The outcomes are the explicit choices of `ad`, plus an implicit "none"
/// outcome when the declared probabilities leave more than `eps` of mass
/// unassigned. For `m` outcomes, `m - 1` fresh choice variables are allocated
/// from `next_choice`; variable `i` is true with the probability of outcome
/// `i` conditioned on no earlier outcome having been selected. Each variable's
/// `(p_true, p_false)` pair is stored in `choice_probs` and its origin in
/// `choice_sources`.
///
/// Returns the allocated variables and, for every explicit choice (in
/// declaration order), the formula that selects exactly that choice.
///
/// # Errors
/// Returns `XlogError::Compilation` when a probability is invalid, when a choice
/// atom is not ground, when the probabilities sum to more than `1.0 + eps`, or
/// when the choice-variable counter overflows.
fn compile_annotated_disjunction(
    ad: &xlog_logic::ast::AnnotatedDisjunction,
    next_choice: &mut u32,
    choice_probs: &mut BTreeMap<ChoiceVarId, (f64, f64)>,
    choice_sources: &mut BTreeMap<ChoiceVarId, ChoiceSource>,
    builder: &mut PirBuilder,
) -> Result<(Vec<ChoiceVarId>, Vec<PirNodeId>)> {
    // Validate and convert every choice in a single pass, so no atom is parsed
    // twice and no `unwrap` is needed.
    let mut explicit_choices: Vec<(GroundAtom, f64)> = Vec::with_capacity(ad.choices.len());
    for pf in &ad.choices {
        validate_prob(pf.prob, "annotated disjunction choice")?;
        let atom = atom_key_from_ground_atom(&pf.atom)?;
        explicit_choices.push((atom, pf.prob));
    }

    let mut probs: Vec<f64> = explicit_choices.iter().map(|(_, prob)| *prob).collect();
    let sum: f64 = probs.iter().copied().sum();
    let eps = 1e-12;
    if sum > 1.0 + eps {
        return Err(XlogError::Compilation(format!(
            "Annotated disjunction probabilities sum to {} (> 1.0)",
            sum
        )));
    }

    // Leftover mass becomes an implicit trailing "none" outcome.
    let none_prob = (1.0 - sum).max(0.0);
    if none_prob > eps {
        probs.push(none_prob);
    }

    // `probs` is never empty here: with no explicit choices the "none" outcome
    // carries all of the mass.
    let m = probs.len();
    if m == 1 {
        // A single outcome needs no choice variable. Emit one constant-true
        // formula per explicit choice (none when the only outcome is "none").
        return Ok((Vec::new(), vec![builder.const_true(); ad.choices.len()]));
    }

    let mut vars: Vec<ChoiceVarId> = Vec::with_capacity(m - 1);
    let mut remaining = 1.0f64;
    for (i, &p_i) in probs.iter().enumerate().take(m - 1) {
        // The sum check above tolerates up to `eps` of excess mass, so rounding
        // can push the ratio marginally above 1; clamp it instead of rejecting
        // a disjunction that was already accepted.
        let cond_true = if remaining <= 0.0 {
            0.0
        } else {
            (p_i / remaining).min(1.0)
        };
        validate_prob(cond_true, "annotated disjunction conditional")?;
        let cond_false = 1.0 - cond_true;
        let var = ChoiceVarId::new(*next_choice);
        *next_choice = (*next_choice).checked_add(1).ok_or_else(|| {
            XlogError::Compilation("annotated disjunction choice id overflow".to_string())
        })?;
        vars.push(var);
        choice_probs.insert(var, (cond_true, cond_false));
        choice_sources.insert(
            var,
            ChoiceSource {
                choices: explicit_choices.clone(),
                choice_index: i,
                source_id: None,
            },
        );
        remaining -= p_i;
    }

    // Outcome `i` is selected when variables `0..i` are false and variable `i`
    // (if it exists) is true. The last outcome has no variable of its own.
    let mut outcome_formulas: Vec<PirNodeId> = Vec::with_capacity(ad.choices.len());
    for i in 0..ad.choices.len() {
        let mut conds: Vec<PirNodeId> = Vec::new();
        for (j, &var) in vars.iter().enumerate() {
            if j < i {
                conds.push(builder.choice_lit(var, false));
            } else if j == i {
                conds.push(builder.choice_lit(var, true));
                break;
            }
        }
        outcome_formulas.push(builder.and(conds));
    }
    Ok((vars, outcome_formulas))
}
/// Decides whether a component needs fixpoint evaluation.
///
/// A component with more than one predicate is always recursive. A
/// single-predicate component is recursive when some rule has a positive body
/// literal over that predicate. An empty component is not recursive.
fn is_recursive_scc(scc: &[String], rules: &[Rule]) -> bool {
    match scc {
        [] => false,
        [only] => rules
            .iter()
            .flat_map(|rule| rule.body.iter())
            .any(|lit| matches!(lit, BodyLiteral::Positive(atom) if atom.predicate == *only)),
        _ => true,
    }
}
/// Evaluates the rules of a non-recursive component once, in order, merging
/// every derived tuple into the head predicate's relation in `store`.
///
/// # Errors
/// Propagates any error from `eval_rule`.
fn eval_non_recursive_scc(
    rules: &[Rule],
    store: &mut BTreeMap<String, Relation>,
    builder: &mut PirBuilder,
    aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<()> {
    for rule in rules {
        // There are no component-local relations outside a fixpoint loop.
        let no_scc_relations: BTreeMap<String, Relation> = BTreeMap::new();
        let derived = eval_rule(
            rule,
            store,
            &no_scc_relations,
            None,
            builder,
            aggregate_lifting,
        )?;
        let target = store
            .entry(rule.head.predicate.clone())
            .or_insert_with(Relation::new);
        derived
            .into_iter()
            .for_each(|(tuple, formula)| target.insert_or(tuple, formula, builder));
    }
    Ok(())
}
/// Upper bound on the number of delta rounds `eval_recursive_scc` runs before
/// giving up with an error.
const MAX_PROVENANCE_ITERATIONS: usize = 1024;
/// Evaluates a recursive (monotone) component to a fixpoint.
///
/// `full` holds the accumulated relations of the component's predicates and
/// `delta` the formulas that were new in the previous round. After an initial
/// pass over all rules, each round re-evaluates only rules having a positive
/// body literal over a component predicate with a non-empty delta, once per
/// such literal, with that literal reading from the delta. The loop stops when
/// a round produces no change; the final relations are written back to `store`.
///
/// # Errors
/// Propagates errors from `eval_rule`, and returns `XlogError::Compilation` when
/// no fixpoint is reached within `MAX_PROVENANCE_ITERATIONS` rounds.
fn eval_recursive_scc(
scc: &[String],
rules: &[Rule],
store: &mut BTreeMap<String, Relation>,
builder: &mut PirBuilder,
aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<()> {
let scc_set: std::collections::HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
// Start from whatever the store already holds for the component's predicates.
let mut full: BTreeMap<String, Relation> = BTreeMap::new();
for pred in scc {
let rel = store.get(pred).cloned().unwrap_or_else(Relation::new);
full.insert(pred.clone(), rel);
}
// Initial pass: evaluate every rule against the current `full` relations.
let mut delta: BTreeMap<String, Relation> = BTreeMap::new();
for rule in rules {
let derived = eval_rule(rule, store, &full, None, builder, aggregate_lifting)?;
if derived.is_empty() {
continue;
}
let head = rule.head.predicate.clone();
let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
let full_rel = full.entry(head).or_insert_with(Relation::new);
for (tuple, proof) in derived {
let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
let combined = builder.or(vec![old, proof]);
// Node ids are interned, so an unchanged id means the formula did not change.
if combined != old {
full_rel.tuples.insert(tuple.clone(), combined);
delta_rel.insert_or(tuple, proof, builder);
}
}
}
let mut reached_fixpoint = false;
for _ in 0..MAX_PROVENANCE_ITERATIONS {
let any_delta = delta.values().any(|r| !r.is_empty());
if !any_delta {
reached_fixpoint = true;
break;
}
// Rules in this round read snapshots taken before the round started.
let full_prev = full.clone();
let delta_prev = delta.clone();
delta.clear();
for rule in rules {
// Positions of positive body literals over component predicates whose delta
// is non-empty; only those can contribute something new.
let body_indices: Vec<usize> = rule
.body
.iter()
.enumerate()
.filter_map(|(i, lit)| match lit {
BodyLiteral::Positive(atom) if scc_set.contains(atom.predicate.as_str()) => {
let pred = &atom.predicate;
let non_empty =
delta_prev.get(pred).map(|r| !r.is_empty()).unwrap_or(false);
non_empty.then_some(i)
}
_ => None,
})
.collect();
if body_indices.is_empty() {
continue;
}
// Union of the derivations obtained with each candidate literal on the delta.
let mut derived_all: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
for idx in body_indices {
let derived = eval_rule(
rule,
store,
&full_prev,
Some((idx, &delta_prev)),
builder,
aggregate_lifting,
)?;
for (tuple, proof) in derived {
let entry = derived_all
.entry(tuple)
.or_insert_with(|| builder.const_false());
*entry = builder.or(vec![*entry, proof]);
}
}
if derived_all.is_empty() {
continue;
}
let head = rule.head.predicate.clone();
let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
let full_rel = full.entry(head).or_insert_with(Relation::new);
for (tuple, proof) in derived_all {
let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
let combined = builder.or(vec![old, proof]);
if combined != old {
full_rel.tuples.insert(tuple.clone(), combined);
delta_rel.insert_or(tuple, proof, builder);
}
}
}
}
if !reached_fixpoint {
return Err(XlogError::Compilation(format!(
"Provenance iteration limit ({}) exceeded for SCC {:?}",
MAX_PROVENANCE_ITERATIONS, scc
)));
}
// Publish the component's final relations.
for (pred, rel) in full {
store.insert(pred, rel);
}
Ok(())
}
/// Evaluates a non-monotone component through the WFS rule evaluator.
///
/// Every rule is grounded against the relations already in `store`; the
/// resulting ground rules are handed to `evaluate_wfs_rules`, and each atom in
/// the returned true set is merged into `store` with its provenance. Only the
/// true set of the result is consumed here.
///
/// # Errors
/// Propagates errors from grounding and from the WFS evaluator.
fn eval_non_monotone_scc_with_wfs(
    scc: &[String],
    rules: &[Rule],
    store: &mut BTreeMap<String, Relation>,
    builder: &mut PirBuilder,
) -> Result<()> {
    let members: std::collections::HashSet<&str> = scc.iter().map(String::as_str).collect();

    let mut grounded_rules: Vec<WfsRule> = Vec::new();
    for rule in rules {
        let instances = ground_rule_for_wfs(rule, store, &members, builder)?;
        grounded_rules.extend(instances);
    }
    if grounded_rules.is_empty() {
        return Ok(());
    }

    let outcome = evaluate_wfs_rules(&grounded_rules, &mut builder.pir, &WfsConfig::default())?;
    for (atom, formula) in outcome.true_set {
        let predicate = atom.predicate.clone();
        store
            .entry(predicate)
            .or_insert_with(Relation::new)
            .insert_or(atom.args, formula, builder);
    }
    Ok(())
}
/// Grounds one rule of a non-monotone component into WFS rules.
///
/// Body literals over predicates outside the component ("external" literals)
/// are evaluated against `store` and folded into an external provenance
/// formula per variable binding. Literals over component predicates are not
/// evaluated; their positions and polarities are recorded and instantiated with
/// each surviving binding at the end. One `WfsRule` is produced per binding.
///
/// # Errors
/// Returns an error for epistemic and univ literals, for an is-expression whose
/// target is already bound, and propagates errors from unification, comparison,
/// arithmetic evaluation and term conversion.
fn ground_rule_for_wfs(
rule: &Rule,
store: &BTreeMap<String, Relation>,
scc_set: &std::collections::HashSet<&str>,
builder: &mut PirBuilder,
) -> Result<Vec<WfsRule>> {
// Candidate variable bindings with the provenance of the external literals
// matched so far; starts with the single empty binding.
let mut bindings: Vec<(HashMap<String, Value>, PirNodeId)> =
vec![(HashMap::new(), builder.const_true())];
// (body index, is_positive) for every literal over a component predicate.
let mut wfs_body_template: Vec<(usize, bool)> = Vec::new();
for (idx, lit) in rule.body.iter().enumerate() {
match lit {
BodyLiteral::Positive(atom) => {
if scc_set.contains(atom.predicate.as_str()) {
wfs_body_template.push((idx, true));
} else {
// External positive literal: join with the stored relation.
let rel = store.get(&atom.predicate);
let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
for (binding, prov) in bindings {
if let Some(rel) = rel {
for (tuple, tuple_prov) in &rel.tuples {
let mut new_binding = binding.clone();
if unify_atom(atom, tuple, &mut new_binding)? {
let new_prov = builder.and(vec![prov, *tuple_prov]);
next_bindings.push((new_binding, new_prov));
}
}
}
}
bindings = next_bindings;
if bindings.is_empty() {
return Ok(Vec::new());
}
}
}
BodyLiteral::Negated(atom) => {
if scc_set.contains(atom.predicate.as_str()) {
wfs_body_template.push((idx, false));
} else {
// External negated literal: bindings that leave a variable of the atom
// unbound are dropped; otherwise the negation of the disjunction of all
// matching tuples' provenance is conjoined.
let rel = store.get(&atom.predicate);
let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
for (binding, prov) in bindings {
let all_bound = atom.terms.iter().all(|t| match t {
Term::Variable(v) => binding.contains_key(v),
_ => true,
});
if !all_bound {
continue;
}
if let Some(rel) = rel {
let mut matching_provs: Vec<PirNodeId> = Vec::new();
for (tuple, tuple_prov) in &rel.tuples {
let mut test_binding = binding.clone();
if unify_atom(atom, tuple, &mut test_binding)? {
matching_provs.push(*tuple_prov);
}
}
if matching_provs.is_empty() {
// No matching tuple: the negation holds unconditionally.
next_bindings.push((binding, prov));
} else {
let combined = builder.or(matching_provs);
let neg_prov = negate_provenance(combined, builder);
let new_prov = builder.and(vec![prov, neg_prov]);
next_bindings.push((binding, new_prov));
}
} else {
// Unknown predicate: treated like an empty relation.
next_bindings.push((binding, prov));
}
}
bindings = next_bindings;
if bindings.is_empty() {
return Ok(Vec::new());
}
}
}
BodyLiteral::Epistemic(lit) => {
return Err(XlogError::UnsupportedEpistemicConstruct {
construct: "probabilistic WFS grounding".to_string(),
context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
});
}
BodyLiteral::Comparison(cmp) => {
// Keep only the bindings that satisfy the comparison.
let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
for (binding, prov) in bindings {
if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
next_bindings.push((binding, prov));
}
}
bindings = next_bindings;
if bindings.is_empty() {
return Ok(Vec::new());
}
}
BodyLiteral::IsExpr(is_expr) => {
// Bind the target variable to the value of the expression; rebinding is an error.
let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
for (mut binding, prov) in bindings {
if binding.contains_key(&is_expr.target) {
return Err(XlogError::Compilation(format!(
"Is-expression target {} is already bound",
is_expr.target
)));
}
let v = eval_arith_expr(&is_expr.expr, &binding)?;
binding.insert(is_expr.target.clone(), v);
next_bindings.push((binding, prov));
}
bindings = next_bindings;
if bindings.is_empty() {
return Ok(Vec::new());
}
}
BodyLiteral::Univ(_) => {
return Err(XlogError::Compilation(
"v0.8.5 meta error: univ literal was not normalized before provenance extraction"
.to_string(),
));
}
}
}
// Instantiate the component-internal literals and the head for every binding.
let mut result: Vec<WfsRule> = Vec::new();
for (binding, external_prov) in bindings {
let mut wfs_body: Vec<WfsLiteral> = Vec::new();
for &(idx, is_positive) in &wfs_body_template {
let atom = match &rule.body[idx] {
BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a,
_ => continue,
};
let mut args: Vec<Value> = Vec::new();
for term in &atom.terms {
match term {
Term::Variable(name) => {
if let Some(v) = binding.get(name) {
args.push(v.clone());
} else {
// NOTE(review): an unbound variable is skipped here, so the resulting atom
// has fewer arguments than the literal. This happens for variables bound
// only by component-internal literals — confirm this is intended.
continue;
}
}
_ => {
args.push(value_from_term(term)?);
}
}
}
let wfs_atom = WfsAtom::new(atom.predicate.clone(), args);
if is_positive {
wfs_body.push(WfsLiteral::Positive(wfs_atom));
} else {
wfs_body.push(WfsLiteral::Negative(wfs_atom));
}
}
let mut head_args: Vec<Value> = Vec::new();
for term in &rule.head.terms {
match term {
Term::Variable(name) => {
if let Some(v) = binding.get(name) {
head_args.push(v.clone());
} else {
// NOTE(review): same silent skip as for body atoms; an unbound head
// variable shortens the head atom — confirm this is intended.
continue;
}
}
_ => {
head_args.push(value_from_term(term)?);
}
}
}
let wfs_head = WfsAtom::new(rule.head.predicate.clone(), head_args);
result.push(WfsRule::new(wfs_head, wfs_body, external_prov));
}
Ok(result)
}
/// Returns a node encoding the logical negation of `prov`.
///
/// Negation is pushed down to the leaves: constants are flipped, literals swap
/// polarity, conjunctions and disjunctions are exchanged over their negated
/// children, and a decision keeps its variable while both branches are negated.
/// When the graph has no node for `prov`, `prov` itself is returned unchanged.
fn negate_provenance(prov: PirNodeId, builder: &mut PirBuilder) -> PirNodeId {
    use crate::pir::PirNode;

    let Some(node) = builder.pir.node(prov).cloned() else {
        return prov;
    };
    match node {
        PirNode::Const(true) => builder.const_false(),
        PirNode::Const(false) => builder.const_true(),
        PirNode::Lit { leaf } => builder.neg_lit(leaf),
        PirNode::NegLit { leaf } => builder.lit(leaf),
        PirNode::And { children } => {
            let mut flipped: Vec<PirNodeId> = Vec::with_capacity(children.len());
            for &child in children.iter() {
                flipped.push(negate_provenance(child, builder));
            }
            builder.or(flipped)
        }
        PirNode::Or { children } => {
            let mut flipped: Vec<PirNodeId> = Vec::with_capacity(children.len());
            for &child in children.iter() {
                flipped.push(negate_provenance(child, builder));
            }
            builder.and(flipped)
        }
        PirNode::Decision {
            var,
            child_false,
            child_true,
        } => {
            let flipped_false = negate_provenance(child_false, builder);
            let flipped_true = negate_provenance(child_true, builder);
            builder.decision(var, flipped_false, flipped_true)
        }
    }
}
/// Evaluates one rule and returns the derived head tuples with their provenance.
///
/// The body is processed left to right over a list of states, each a variable
/// binding paired with the conjunction of the provenance of the literals
/// matched so far. `global` holds the relations of all predicates, `full_scc`
/// the relations of the component currently being evaluated, and `delta_scc`
/// optionally names one body position together with the delta relations;
/// `select_relation` picks the relation a positive literal reads from.
///
/// Rules with an aggregate head are finished by
/// `eval_aggregate_head_provenance`; for all other rules the provenance of
/// states producing the same head tuple is disjoined.
///
/// # Errors
/// Returns an error for epistemic and univ literals, for an is-expression whose
/// target is already bound, and propagates errors from the helpers it calls.
fn eval_rule(
rule: &Rule,
global: &BTreeMap<String, Relation>,
full_scc: &BTreeMap<String, Relation>,
delta_scc: Option<(usize, &BTreeMap<String, Relation>)>,
builder: &mut PirBuilder,
aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
// Start with the empty binding, which holds unconditionally.
let mut states: Vec<(HashMap<String, Value>, PirNodeId)> =
vec![(HashMap::new(), builder.const_true())];
for (idx, lit) in rule.body.iter().enumerate() {
let mut next_states: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
match lit {
BodyLiteral::Positive(atom) => {
// Join every state with every tuple of the selected relation that unifies.
let rel = select_relation(atom, idx, global, full_scc, delta_scc)?;
for (binding, prov) in states {
for (tuple, tuple_prov) in &rel.tuples {
let mut binding2 = binding.clone();
if unify_atom(atom, tuple, &mut binding2)? {
let prov2 = builder.and(vec![prov, *tuple_prov]);
next_states.push((binding2, prov2));
}
}
}
}
BodyLiteral::Comparison(cmp) => {
for (binding, prov) in states {
if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
next_states.push((binding, prov));
}
}
}
BodyLiteral::IsExpr(is_expr) => {
for (mut binding, prov) in states {
if binding.contains_key(&is_expr.target) {
return Err(XlogError::Compilation(format!(
"Is-expression target {} is already bound",
is_expr.target
)));
}
let v = eval_arith_expr(&is_expr.expr, &binding)?;
binding.insert(is_expr.target.clone(), v);
next_states.push((binding, prov));
}
}
BodyLiteral::Negated(atom) => {
// The component-local relation takes precedence over the global one.
let rel = if let Some(r) = full_scc.get(&atom.predicate) {
r
} else if let Some(r) = global.get(&atom.predicate) {
r
} else {
// Unknown predicate: the negation holds for every state whose binding
// covers all variables of the atom; other states are dropped.
for (binding, prov) in states {
let all_bound = atom.terms.iter().all(|t| match t {
Term::Variable(v) => binding.contains_key(v),
_ => true,
});
if all_bound {
next_states.push((binding, prov));
}
}
states = next_states;
if states.is_empty() {
break;
}
continue;
};
for (binding, prov) in states {
let all_bound = atom.terms.iter().all(|t| match t {
Term::Variable(v) => binding.contains_key(v),
_ => true,
});
if !all_bound {
continue;
}
// Collect the provenance of every tuple the bound atom matches.
let mut matching_provs: Vec<PirNodeId> = Vec::new();
for (tuple, tuple_prov) in &rel.tuples {
let mut binding2 = binding.clone();
if unify_atom(atom, tuple, &mut binding2)? {
matching_provs.push(*tuple_prov);
}
}
if matching_provs.is_empty() {
next_states.push((binding, prov));
} else {
// The literal holds when none of the matching tuples holds.
let combined_tuple_prov = builder.or(matching_provs);
let neg_prov = negate_provenance(combined_tuple_prov, builder);
let new_prov = builder.and(vec![prov, neg_prov]);
next_states.push((binding, new_prov));
}
}
}
BodyLiteral::Epistemic(lit) => {
return Err(XlogError::UnsupportedEpistemicConstruct {
construct: "probabilistic provenance evaluation".to_string(),
context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
});
}
BodyLiteral::Univ(_) => {
return Err(XlogError::Compilation(
"v0.8.5 meta error: univ literal was not normalized before provenance extraction"
.to_string(),
));
}
}
states = next_states;
// No surviving state means the rule derives nothing; stop early.
if states.is_empty() {
break;
}
}
if rule.has_aggregation() {
eval_aggregate_head_provenance(&rule.head, states, builder, aggregate_lifting)
} else {
let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
for (binding, prov) in states {
let head_tuple = materialize_head(&rule.head, &binding)?;
let entry = out
.entry(head_tuple)
.or_insert_with(|| builder.const_false());
*entry = builder.or(vec![*entry, prov]);
}
Ok(out)
}
}
/// Maximum number of uncertain rows per group for which aggregate outcomes are
/// enumerated exactly (one outcome per subset of the uncertain rows).
const MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS: usize = 16;
/// Maximum number of uncertain rows per group accepted by the count lift.
const MAX_EXACT_PROB_COUNT_LIFT_ROWS: usize = 64;
/// One body solution feeding an aggregate: its variable binding and the
/// provenance under which the solution holds.
#[derive(Debug, Clone)]
struct AggregateProvRow {
binding: HashMap<String, Value>,
prov: PirNodeId,
}
/// Derives the head tuples of an aggregate rule from its body solutions.
///
/// Solutions with identical bindings are merged by disjoining their
/// provenance, then grouped by the values of the head's key variables. Within a
/// group, rows with constant-true provenance always contribute, rows with
/// constant-false provenance are ignored, and all other rows are uncertain.
///
/// When every aggregate in the head is a count, the possible counts are
/// obtained from `count_lift_formulas`. Otherwise every subset of the uncertain
/// rows is enumerated and the aggregate is computed for each subset. The
/// outcome in which no row is present at all produces no tuple. A report is
/// appended to `aggregate_lifting` for each processed group.
///
/// # Errors
/// Returns `XlogError::Compilation` when a group exceeds the applicable cap on
/// uncertain rows, `XlogError::UnsafeVariable` when a key variable is missing
/// from a binding, and propagates errors from the helpers it calls.
fn eval_aggregate_head_provenance(
head: &Atom,
states: Vec<(HashMap<String, Value>, PirNodeId)>,
builder: &mut PirBuilder,
aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
let (key_vars, key_var_to_pos, agg_specs, agg_to_pos) = aggregate_head_plan(head)?;
// Merge solutions that have the same binding so each distinct row is counted once.
let mut deduped_states: BTreeMap<Vec<(String, Value)>, AggregateProvRow> = BTreeMap::new();
for (binding, prov) in states {
let key = canonical_binding_key(&binding);
match deduped_states.get_mut(&key) {
Some(row) => {
row.prov = builder.or(vec![row.prov, prov]);
}
None => {
deduped_states.insert(key, AggregateProvRow { binding, prov });
}
}
}
// Rows sharing one grouping key.
#[derive(Debug)]
struct GroupRows {
key: Vec<Value>,
rows: Vec<AggregateProvRow>,
}
let mut groups: BTreeMap<Vec<Value>, GroupRows> = BTreeMap::new();
for row in deduped_states.into_values() {
let mut key: Vec<Value> = Vec::with_capacity(key_vars.len());
for name in &key_vars {
let v = row
.binding
.get(name)
.ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
key.push(v.clone());
}
groups
.entry(key.clone())
.or_insert_with(|| GroupRows {
key,
rows: Vec::new(),
})
.rows
.push(row);
}
let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
// The count lift applies only when every aggregate in the head is a count.
let count_only = agg_specs.iter().all(|(op, _)| *op == AggOp::Count);
for group in groups.into_values() {
// Split rows by whether their provenance is a constant.
let mut always_rows: Vec<AggregateProvRow> = Vec::new();
let mut uncertain_rows: Vec<AggregateProvRow> = Vec::new();
for row in group.rows {
match pir_const_value(builder, row.prov) {
Some(true) => always_rows.push(row),
Some(false) => {}
None => uncertain_rows.push(row),
}
}
if always_rows.is_empty() && uncertain_rows.is_empty() {
continue;
}
if count_only {
if uncertain_rows.len() > MAX_EXACT_PROB_COUNT_LIFT_ROWS {
return Err(XlogError::Compilation(format!(
"v0.8.5 agg_lift error: count lift finite domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
head.predicate,
group.key,
uncertain_rows.len(),
MAX_EXACT_PROB_COUNT_LIFT_ROWS
)));
}
validate_count_lift_rows(&agg_specs, &always_rows, &uncertain_rows)?;
record_aggregate_lift_reports(
aggregate_lifting,
head,
&group.key,
&agg_specs,
always_rows.len(),
uncertain_rows.len(),
AggregateLiftStatus::Fired,
"finite count domain lifted with exact cardinality dynamic programming",
MAX_EXACT_PROB_COUNT_LIFT_ROWS,
count_lift_dp_states(uncertain_rows.len()),
);
// `count_formulas[k]` is the formula for "exactly k uncertain rows hold".
let count_formulas = count_lift_formulas(&uncertain_rows, builder);
for (selected_uncertain_rows, proof) in count_formulas.into_iter().enumerate() {
// With no always-present row, the empty selection yields no tuple.
if always_rows.is_empty() && selected_uncertain_rows == 0 {
continue;
}
let count_value = always_rows.len() + selected_uncertain_rows;
let tuple =
materialize_count_lift_tuple(head, &group.key, &key_var_to_pos, count_value)?;
let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
*entry = builder.or(vec![*entry, proof]);
}
continue;
}
if uncertain_rows.len() > MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS {
return Err(XlogError::Compilation(format!(
"v0.8.5 prob_aggregate error: exact aggregate domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
head.predicate,
group.key,
uncertain_rows.len(),
MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS
)));
}
record_aggregate_lift_reports(
aggregate_lifting,
head,
&group.key,
&agg_specs,
always_rows.len(),
uncertain_rows.len(),
AggregateLiftStatus::FallbackExactEnumeration,
"operator uses exact finite outcome enumeration; lifted implementation is not selected for this aggregate head",
MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS,
0,
);
// Bit `idx` of `mask` says whether uncertain row `idx` is present. The cap
// checked above keeps the shift well within `usize`.
let mask_count = 1usize << uncertain_rows.len();
for mask in 0..mask_count {
if always_rows.is_empty() && mask == 0 {
continue;
}
let mut agg_states: Vec<AggState> =
agg_specs.iter().map(|(op, _)| AggState::new(*op)).collect();
for row in &always_rows {
update_aggregate_states(&mut agg_states, &agg_specs, row)?;
}
// The proof of this outcome asserts each uncertain row's presence or absence.
let mut proof_terms: Vec<PirNodeId> = Vec::with_capacity(uncertain_rows.len());
for (idx, row) in uncertain_rows.iter().enumerate() {
if (mask & (1usize << idx)) != 0 {
proof_terms.push(row.prov);
update_aggregate_states(&mut agg_states, &agg_specs, row)?;
} else {
proof_terms.push(negate_provenance(row.prov, builder));
}
}
let tuple = materialize_aggregate_tuple(
head,
&group.key,
&key_var_to_pos,
&agg_specs,
&agg_to_pos,
&agg_states,
)?;
let proof = builder.and(proof_terms);
// Different subsets can yield the same aggregate value; disjoin their proofs.
let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
*entry = builder.or(vec![*entry, proof]);
}
}
Ok(out)
}
/// Verifies that every aggregated variable is bound in every row of a group.
///
/// Both the always-present rows and the uncertain rows are inspected.
///
/// # Errors
/// Returns [`XlogError::UnsafeVariable`] for the first aggregate variable (in
/// `agg_specs` order) that is absent from at least one row's binding.
fn validate_count_lift_rows(
    agg_specs: &[(AggOp, String)],
    always_rows: &[AggregateProvRow],
    uncertain_rows: &[AggregateProvRow],
) -> Result<()> {
    for (_, var) in agg_specs {
        let bound_in_every_row = always_rows
            .iter()
            .chain(uncertain_rows)
            .all(|row| row.binding.contains_key(var));
        if !bound_in_every_row {
            return Err(XlogError::UnsafeVariable(var.clone()));
        }
    }
    Ok(())
}
/// Builds one provenance formula per possible count of present uncertain rows.
///
/// The result has `uncertain_rows.len() + 1` entries; entry `k` is the formula
/// for "exactly `k` of the uncertain rows are present". It is computed with a
/// row-by-row dynamic program: after processing a row, the formula for count
/// `k` is `(previous[k] AND row absent) OR (previous[k - 1] AND row present)`.
///
/// The sequence of `builder` calls is kept in a fixed order (constants first,
/// then per row: negation, and for each count the absent case followed by the
/// present case).
fn count_lift_formulas(
    uncertain_rows: &[AggregateProvRow],
    builder: &mut PirBuilder,
) -> Vec<PirNodeId> {
    let slots = uncertain_rows.len() + 1;
    let initial_false = builder.const_false();
    let mut current = vec![initial_false; slots];
    // Before any row is processed, only "zero rows present" is satisfiable.
    current[0] = builder.const_true();
    for (processed, row) in uncertain_rows.iter().enumerate() {
        let row_false = builder.const_false();
        let mut following = vec![row_false; slots];
        let row_present = row.prov;
        let row_absent = negate_provenance(row.prov, builder);
        // Only counts 0..=processed can be reachable at this point.
        for count in 0..=processed {
            let prior = current[count];
            let stays = builder.and(vec![prior, row_absent]);
            following[count] = builder.or(vec![following[count], stays]);
            let grows = builder.and(vec![prior, row_present]);
            following[count + 1] = builder.or(vec![following[count + 1], grows]);
        }
        current = following;
    }
    current
}
/// Produces the head tuple for a lifted `count` aggregate with a known count.
///
/// Head variables are filled from `group_key` via `key_var_to_pos`, `count`
/// aggregates receive `count_value` as an `I64`, and scalar constants are
/// converted with `value_from_term`.
///
/// # Errors
/// Fails when `count_value` does not fit in `i64`, when a head variable is not
/// a group key, when the head holds a non-`count` aggregate, or when it holds a
/// list, cons, compound or predicate-reference term.
///
/// # Panics
/// Panics on an anonymous head term, or when a key position recorded in
/// `key_var_to_pos` is out of range for `group_key`.
fn materialize_count_lift_tuple(
    head: &Atom,
    group_key: &[Value],
    key_var_to_pos: &HashMap<String, usize>,
    count_value: usize,
) -> Result<Vec<Value>> {
    const CONTEXT: &str = "aggregate head materialization";
    let count_value = i64::try_from(count_value)
        .map_err(|_| XlogError::Compilation("count() overflowed i64".to_string()))?;
    let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
    for term in &head.terms {
        let value = match term {
            Term::Variable(name) => match key_var_to_pos.get(name) {
                Some(&pos) => group_key[pos].clone(),
                None => {
                    return Err(XlogError::Compilation(format!(
                        "Aggregate head variable {} is not a group key",
                        name
                    )));
                }
            },
            Term::Aggregate(AggExpr {
                op: AggOp::Count, ..
            }) => Value::I64(count_value),
            // Only `count` heads are routed to the lifted path.
            Term::Aggregate(AggExpr { op, .. }) => {
                return Err(XlogError::Compilation(format!(
                    "Internal aggregate lift state mismatch for {}",
                    agg_op_label(*op)
                )));
            }
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
                value_from_term(term)?
            }
            Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
            Term::List(_) => return Err(v085_prob_term_error(CONTEXT, "list")),
            Term::Cons { .. } => return Err(v085_prob_term_error(CONTEXT, "cons")),
            Term::Compound { .. } => return Err(v085_prob_term_error(CONTEXT, "compound")),
            Term::PredRef(_) => return Err(v085_prob_term_error(CONTEXT, "predref")),
        };
        tuple.push(value);
    }
    Ok(tuple)
}
/// Appends one [`AggregateLiftReport`] per aggregate spec of a group.
///
/// All reports for the group share the same row counts, cap, status and
/// reason; they differ only in the operator label. `domain_size` is the sum of
/// deterministic and uncertain rows, and `naive_outcomes` is derived from the
/// number of uncertain rows via `naive_outcome_count`.
// Each argument maps directly onto a report field, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
fn record_aggregate_lift_reports(
    aggregate_lifting: &mut Vec<AggregateLiftReport>,
    head: &Atom,
    group_key: &[Value],
    agg_specs: &[(AggOp, String)],
    deterministic_rows: usize,
    uncertain_rows: usize,
    status: AggregateLiftStatus,
    reason: &str,
    cap: usize,
    dynamic_programming_states: usize,
) {
    let naive_outcomes = naive_outcome_count(uncertain_rows);
    let domain_size = deterministic_rows + uncertain_rows;
    aggregate_lifting.extend(agg_specs.iter().map(|(op, _)| AggregateLiftReport {
        predicate: head.predicate.clone(),
        group_key: group_key.to_vec(),
        operator: agg_op_label(*op).to_string(),
        finite_domain_source: "grounded body rows".to_string(),
        deterministic_rows,
        uncertain_rows,
        domain_size,
        cap,
        status,
        reason: reason.to_string(),
        naive_outcomes,
        dynamic_programming_states,
    }));
}
/// Returns the lowercase label used for an aggregate operator in reports and
/// error messages.
fn agg_op_label(op: AggOp) -> &'static str {
    match op {
        AggOp::Min => "min",
        AggOp::Max => "max",
        AggOp::Sum => "sum",
        AggOp::Count => "count",
        AggOp::LogSumExp => "logsumexp",
    }
}
/// Returns `2^uncertain_rows`, the number of present/absent combinations of
/// the uncertain rows, saturating at `u128::MAX` once the power no longer fits
/// in a `u128`.
fn naive_outcome_count(uncertain_rows: usize) -> u128 {
    match uncertain_rows {
        shift if shift < u128::BITS as usize => 1u128 << shift,
        _ => u128::MAX,
    }
}
/// Returns the triangular number `(n + 1) * (n + 2) / 2` for
/// `n = uncertain_rows`, reported as the number of dynamic-programming states
/// of the count lift.
///
/// The result saturates at `usize::MAX` instead of overflowing, mirroring the
/// saturating behavior of `naive_outcome_count`.
fn count_lift_dp_states(uncertain_rows: usize) -> usize {
    let levels = uncertain_rows.saturating_add(1);
    let next = levels.saturating_add(1);
    // Of two consecutive integers exactly one is even; halve that one before
    // multiplying so the intermediate product cannot overflow unless the final
    // result itself does.
    let (lhs, rhs) = if levels % 2 == 0 {
        (levels / 2, next)
    } else {
        (levels, next / 2)
    };
    lhs.saturating_mul(rhs)
}
/// Layout of an aggregate rule head as returned by `aggregate_head_plan`:
///
/// 1. group-key variable names, in order of first occurrence in the head;
/// 2. map from each group-key variable name to its position in (1);
/// 3. distinct `(operator, variable)` aggregate specs, in order of first
///    occurrence in the head;
/// 4. map from each aggregate spec to its position in (3).
type AggregatePlan = (
Vec<String>,
HashMap<String, usize>,
Vec<(AggOp, String)>,
HashMap<(AggOp, String), usize>,
);
/// Scans an aggregate rule head and derives its [`AggregatePlan`].
///
/// Plain variables become group keys and aggregate terms become aggregate
/// specs; each is recorded once, at the position of its first occurrence.
/// Integer, float, string and symbol constants contribute nothing to the plan.
///
/// # Errors
/// Fails on an anonymous variable, or on a list, cons, compound or
/// predicate-reference term in the head.
fn aggregate_head_plan(head: &Atom) -> Result<AggregatePlan> {
    const CONTEXT: &str = "aggregate head planning";
    let mut key_vars: Vec<String> = Vec::new();
    let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
    let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
    let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
    for term in &head.terms {
        match term {
            Term::Variable(name) => {
                // Repeated key variables keep the position of their first use.
                if key_var_to_pos.contains_key(name) {
                    continue;
                }
                key_var_to_pos.insert(name.clone(), key_vars.len());
                key_vars.push(name.clone());
            }
            Term::Aggregate(agg) => {
                let spec = (agg.op, agg.variable.clone());
                // Identical aggregates share a single spec slot.
                if agg_to_pos.contains_key(&spec) {
                    continue;
                }
                agg_to_pos.insert(spec.clone(), agg_specs.len());
                agg_specs.push(spec);
            }
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {}
            Term::Anonymous => {
                return Err(XlogError::Compilation(format!(
                    "Anonymous variable in aggregate head of {} is not supported",
                    head.predicate
                )));
            }
            Term::List(_) => return Err(v085_prob_term_error(CONTEXT, "list")),
            Term::Cons { .. } => return Err(v085_prob_term_error(CONTEXT, "cons")),
            Term::Compound { .. } => return Err(v085_prob_term_error(CONTEXT, "compound")),
            Term::PredRef(_) => return Err(v085_prob_term_error(CONTEXT, "predref")),
        }
    }
    Ok((key_vars, key_var_to_pos, agg_specs, agg_to_pos))
}
/// Converts a binding into a sorted list of `(variable, value)` pairs, so that
/// equal bindings yield equal vectors regardless of hash-map iteration order.
fn canonical_binding_key(binding: &HashMap<String, Value>) -> Vec<(String, Value)> {
    let mut pairs: Vec<(String, Value)> = Vec::with_capacity(binding.len());
    for (name, value) in binding {
        pairs.push((name.clone(), value.clone()));
    }
    pairs.sort();
    pairs
}
/// Returns the boolean held by `node` when it is a constant node in the
/// builder's graph, and `None` for any other node or when the lookup fails.
fn pir_const_value(builder: &PirBuilder, node: PirNodeId) -> Option<bool> {
    if let Some(crate::pir::PirNode::Const(value)) = builder.pir.node(node) {
        Some(*value)
    } else {
        None
    }
}
/// Feeds one row into every aggregate state of a group.
///
/// `states[i]` is updated with the row's value for the variable of
/// `agg_specs[i]`.
///
/// # Errors
/// Returns [`XlogError::UnsafeVariable`] when the row does not bind an
/// aggregated variable, and propagates any error from the state update.
///
/// # Panics
/// Panics if `states` is shorter than `agg_specs`.
fn update_aggregate_states(
    states: &mut [AggState],
    agg_specs: &[(AggOp, String)],
    row: &AggregateProvRow,
) -> Result<()> {
    for (slot, spec) in agg_specs.iter().enumerate() {
        let (op, var) = spec;
        match row.binding.get(var) {
            Some(value) => {
                states[slot].update(*op, value)?;
            }
            None => return Err(XlogError::UnsafeVariable(var.clone())),
        }
    }
    Ok(())
}
/// Produces the head tuple for an aggregate rule from finished aggregate
/// states.
///
/// Head variables are filled from `group_key` via `key_var_to_pos`, aggregate
/// terms from the matching entry of `agg_states` (located through
/// `agg_to_pos`), and scalar constants through `value_from_term`.
///
/// # Errors
/// Fails when a head variable is not a group key, when finishing an aggregate
/// state fails, or when the head holds a list, cons, compound or
/// predicate-reference term.
///
/// # Panics
/// Panics on an anonymous head term, when an aggregate term has no entry in
/// `agg_to_pos` / `agg_specs`, or when a recorded position is out of range for
/// `group_key` or `agg_states`.
fn materialize_aggregate_tuple(
    head: &Atom,
    group_key: &[Value],
    key_var_to_pos: &HashMap<String, usize>,
    agg_specs: &[(AggOp, String)],
    agg_to_pos: &HashMap<(AggOp, String), usize>,
    agg_states: &[AggState],
) -> Result<Vec<Value>> {
    const CONTEXT: &str = "aggregate head materialization";
    let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
    for term in &head.terms {
        let value = match term {
            Term::Variable(name) => match key_var_to_pos.get(name) {
                Some(&pos) => group_key[pos].clone(),
                None => {
                    return Err(XlogError::Compilation(format!(
                        "Aggregate head variable {} is not a group key",
                        name
                    )));
                }
            },
            Term::Aggregate(AggExpr { op, variable }) => {
                let slot = *agg_to_pos
                    .get(&(*op, variable.clone()))
                    .expect("agg_to_pos missing");
                let (spec_op, _) = agg_specs
                    .get(slot)
                    .expect("aggregate state index should have a spec");
                agg_states[slot].finish(*spec_op)?
            }
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
                value_from_term(term)?
            }
            Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
            Term::List(_) => return Err(v085_prob_term_error(CONTEXT, "list")),
            Term::Cons { .. } => return Err(v085_prob_term_error(CONTEXT, "cons")),
            Term::Compound { .. } => return Err(v085_prob_term_error(CONTEXT, "compound")),
            Term::PredRef(_) => return Err(v085_prob_term_error(CONTEXT, "predref")),
        };
        tuple.push(value);
    }
    Ok(tuple)
}
/// Picks the relation a body atom should be evaluated against.
///
/// Lookup order:
/// 1. the delta map, when `delta_scc` designates exactly this `body_index`;
/// 2. the relations of the current SCC (`full_scc`);
/// 3. the `global` relations.
///
/// # Errors
/// Fails when the designated delta map has no relation for the predicate, or
/// when the predicate is found in neither `full_scc` nor `global`.
fn select_relation<'a>(
    atom: &Atom,
    body_index: usize,
    global: &'a BTreeMap<String, Relation>,
    full_scc: &'a BTreeMap<String, Relation>,
    delta_scc: Option<(usize, &'a BTreeMap<String, Relation>)>,
) -> Result<&'a Relation> {
    let predicate = &atom.predicate;
    match delta_scc {
        Some((delta_index, delta_map)) if delta_index == body_index => {
            delta_map.get(predicate).ok_or_else(|| {
                XlogError::Compilation(format!(
                    "Missing delta relation for predicate {}",
                    atom.predicate
                ))
            })
        }
        _ => full_scc
            .get(predicate)
            .or_else(|| global.get(predicate))
            .ok_or_else(|| {
                XlogError::Compilation(format!("Unknown predicate {}", atom.predicate))
            }),
    }
}
/// Unifies an atom's terms with a ground tuple, extending `binding` in place.
///
/// Returns `Ok(true)` when every term matches: unbound variables are bound to
/// the tuple value, bound variables and constants must equal it, and anonymous
/// terms match anything. Returns `Ok(false)` on the first mismatch; variables
/// bound by earlier terms of the same atom remain in `binding` in that case.
///
/// # Errors
/// Fails when the atom and tuple differ in arity, or when the atom holds an
/// aggregate, list, cons, compound or predicate-reference term.
pub(crate) fn unify_atom(
    atom: &Atom,
    tuple: &[Value],
    binding: &mut HashMap<String, Value>,
) -> Result<bool> {
    let (atom_arity, tuple_arity) = (atom.terms.len(), tuple.len());
    if atom_arity != tuple_arity {
        return Err(XlogError::Compilation(format!(
            "Arity mismatch for {}: atom has {}, tuple has {}",
            atom.predicate, atom_arity, tuple_arity
        )));
    }
    for (term, value) in atom.terms.iter().zip(tuple) {
        match term {
            Term::Variable(name) => {
                if let Some(bound) = binding.get(name) {
                    if bound != value {
                        return Ok(false);
                    }
                } else {
                    binding.insert(name.clone(), value.clone());
                }
            }
            Term::Anonymous => {}
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
                let constant = value_from_term(term)?;
                if constant != *value {
                    return Ok(false);
                }
            }
            Term::Aggregate(_) => {
                return Err(XlogError::Compilation(
                    "Aggregation not supported in provenance extraction".to_string(),
                ));
            }
            Term::List(_) => return Err(v085_prob_term_error("unification", "list")),
            Term::Cons { .. } => return Err(v085_prob_term_error("unification", "cons")),
            Term::Compound { .. } => return Err(v085_prob_term_error("unification", "compound")),
            Term::PredRef(_) => return Err(v085_prob_term_error("unification", "predref")),
        }
    }
    Ok(true)
}
fn materialize_head(head: &Atom, binding: &HashMap<String, Value>) -> Result<Vec<Value>> {
let mut out = Vec::with_capacity(head.terms.len());
for term in &head.terms {
match term {
Term::Variable(name) => {
let v = binding.get(name).ok_or_else(|| {
XlogError::Compilation(format!(
"Unbound head variable {} in {}",
name, head.predicate
))
})?;
out.push(v.clone());
}
Term::Anonymous => {
return Err(XlogError::Compilation(format!(
"Anonymous variable in head of {} is not supported",
head.predicate
)));
}
Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
out.push(value_from_term(term)?);
}
Term::Aggregate(AggExpr {
op: AggOp::Count,
variable: _,
})
| Term::Aggregate(AggExpr {
op: AggOp::Sum,
variable: _,
})
| Term::Aggregate(AggExpr {
op: AggOp::Min,
variable: _,
})
| Term::Aggregate(AggExpr {
op: AggOp::Max,
variable: _,
})
| Term::Aggregate(AggExpr {
op: AggOp::LogSumExp,
variable: _,
}) => {
return Err(XlogError::Compilation(
"Aggregation not supported in provenance extraction".to_string(),
));
}
Term::List(_) => return Err(v085_prob_term_error("head materialization", "list")),
Term::Cons { .. } => return Err(v085_prob_term_error("head materialization", "cons")),
Term::Compound { .. } => {
return Err(v085_prob_term_error("head materialization", "compound"));
}
Term::PredRef(_) => {
return Err(v085_prob_term_error("head materialization", "predref"));
}
}
}
Ok(out)
}
/// Evaluates `left op right` under `binding`.
///
/// Both sides are resolved with `resolve_term` and must have the same value
/// kind. Integers, symbols and strings use their total order; floats use IEEE
/// comparison semantics on the decoded `f64` values.
///
/// # Errors
/// Propagates resolution errors and fails when the two sides resolve to
/// different value kinds.
pub(crate) fn eval_comparison(
    op: CompOp,
    left: &Term,
    right: &Term,
    binding: &HashMap<String, Value>,
) -> Result<bool> {
    let lhs = resolve_term(left, binding)?;
    let rhs = resolve_term(right, binding)?;
    let holds = match (lhs, rhs) {
        (Value::I64(a), Value::I64(b)) => compare_ord(op, a.cmp(&b)),
        (Value::F64(a_bits), Value::F64(b_bits)) => {
            let (a, b) = (f64::from_bits(a_bits), f64::from_bits(b_bits));
            match a.partial_cmp(&b) {
                Some(ord) => compare_ord(op, ord),
                // Unordered operands (a NaN is involved): only `!=` holds.
                None => matches!(op, CompOp::Ne),
            }
        }
        (Value::Symbol(a), Value::Symbol(b)) => compare_ord(op, a.cmp(&b)),
        (Value::String(a), Value::String(b)) => compare_ord(op, a.cmp(&b)),
        _ => {
            return Err(XlogError::Compilation(
                "Comparison between differing types is not supported".to_string(),
            ));
        }
    };
    Ok(holds)
}
/// Tells whether the comparison operator `op` is satisfied by an ordering
/// result `ord` (the ordering of the left operand relative to the right one).
pub(crate) fn compare_ord(op: CompOp, ord: std::cmp::Ordering) -> bool {
    match op {
        CompOp::Eq => ord.is_eq(),
        CompOp::Ne => ord.is_ne(),
        CompOp::Lt => ord.is_lt(),
        CompOp::Le => ord.is_le(),
        CompOp::Gt => ord.is_gt(),
        CompOp::Ge => ord.is_ge(),
    }
}
/// Resolves a comparison operand to a concrete value.
///
/// Variables are looked up in `binding`; integer, float, string and symbol
/// constants are converted with `value_from_term`.
///
/// # Errors
/// Fails on an unbound variable, an anonymous variable, an aggregate, or a
/// list, cons, compound or predicate-reference term.
pub(crate) fn resolve_term(term: &Term, binding: &HashMap<String, Value>) -> Result<Value> {
    match term {
        Term::Variable(name) => match binding.get(name) {
            Some(value) => Ok(value.clone()),
            None => Err(XlogError::Compilation(format!(
                "Unbound variable {} in comparison",
                name
            ))),
        },
        Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
            value_from_term(term)
        }
        Term::Anonymous => Err(XlogError::Compilation(
            "Anonymous variable not allowed in comparison".to_string(),
        )),
        Term::Aggregate(_) => Err(XlogError::Compilation(
            "Aggregation not supported in provenance extraction".to_string(),
        )),
        Term::List(_) => Err(v085_prob_term_error("comparison", "list")),
        Term::Cons { .. } => Err(v085_prob_term_error("comparison", "cons")),
        Term::Compound { .. } => Err(v085_prob_term_error("comparison", "compound")),
        Term::PredRef(_) => Err(v085_prob_term_error("comparison", "predref")),
    }
}
pub(crate) fn eval_arith_expr(expr: &ArithExpr, binding: &HashMap<String, Value>) -> Result<Value> {
match expr {
ArithExpr::Variable(name) => binding.get(name).cloned().ok_or_else(|| {
XlogError::Compilation(format!("Unbound variable {} in arithmetic", name))
}),
ArithExpr::Integer(i) => Ok(Value::I64(*i)),
ArithExpr::Float(f) => Ok(Value::F64(f.to_bits())),
ArithExpr::Add(l, r) => eval_bin_op(l, r, binding, |a, b| a + b, |a, b| a + b),
ArithExpr::Sub(l, r) => eval_bin_op(l, r, binding, |a, b| a - b, |a, b| a - b),
ArithExpr::Mul(l, r) => eval_bin_op(l, r, binding, |a, b| a * b, |a, b| a * b),
ArithExpr::Div(l, r) => eval_bin_op(l, r, binding, |a, b| a / b, |a, b| a / b),
ArithExpr::Mod(l, r) => eval_bin_op(l, r, binding, |a, b| a % b, |a, b| a % b),
ArithExpr::Abs(e) => match eval_arith_expr(e, binding)? {
Value::I64(i) => Ok(Value::I64(i.abs())),
Value::F64(bits) => {
let f = f64::from_bits(bits).abs();
Ok(Value::F64(f.to_bits()))
}
_ => Err(XlogError::Compilation(
"abs() requires numeric input".to_string(),
)),
},
ArithExpr::Min(l, r) => eval_bin_op(l, r, binding, |a, b| a.min(b), |a, b| a.min(b)),
ArithExpr::Max(l, r) => eval_bin_op(l, r, binding, |a, b| a.max(b), |a, b| a.max(b)),
ArithExpr::Pow(l, r) => {
let a = eval_arith_expr(l, binding)?;
let b = eval_arith_expr(r, binding)?;
match (a, b) {
(Value::I64(a), Value::I64(b)) => {
Ok(Value::I64(a.pow(u32::try_from(b).map_err(|_| {
XlogError::Compilation("pow exponent must fit in u32".to_string())
})?)))
}
(Value::F64(a), Value::F64(b)) => Ok(Value::F64(
f64::from_bits(a).powf(f64::from_bits(b)).to_bits(),
)),
_ => Err(XlogError::Compilation(
"pow requires numeric inputs of same type".to_string(),
)),
}
}
ArithExpr::Cast(e, _ty) => {
eval_arith_expr(e, binding)
}
ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
"Function call `{}` must be expanded before provenance extraction",
name
))),
ArithExpr::Conditional { .. } => Err(XlogError::Compilation(
"Conditional expressions must be expanded before provenance extraction".to_string(),
)),
}
}
/// Evaluates both operands and applies a binary operator to them.
///
/// `op_int` is used when both operands are integers and `op_float` when both
/// are floats (floats are decoded from and re-encoded to their bit patterns).
///
/// # Errors
/// Propagates operand evaluation errors and fails when the operands are not
/// both integers or both floats.
pub(crate) fn eval_bin_op<FInt, FFloat>(
    l: &ArithExpr,
    r: &ArithExpr,
    binding: &HashMap<String, Value>,
    op_int: FInt,
    op_float: FFloat,
) -> Result<Value>
where
    FInt: FnOnce(i64, i64) -> i64,
    FFloat: FnOnce(f64, f64) -> f64,
{
    let lhs = eval_arith_expr(l, binding)?;
    let rhs = eval_arith_expr(r, binding)?;
    let result = match (lhs, rhs) {
        (Value::I64(x), Value::I64(y)) => Value::I64(op_int(x, y)),
        (Value::F64(x_bits), Value::F64(y_bits)) => {
            let x = f64::from_bits(x_bits);
            let y = f64::from_bits(y_bits);
            Value::F64(op_float(x, y).to_bits())
        }
        _ => {
            return Err(XlogError::Compilation(
                "Arithmetic operation requires matching numeric types".to_string(),
            ));
        }
    };
    Ok(result)
}