use std::collections::{BTreeMap, BTreeSet};
use xlog_core::symbol;
use crate::ast::{AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Program, Rule, Term};
/// Where a rule recorded in a [`RuleProvenance`] came from.
///
/// The builders in this module only emit `Source` and `Generated`; the other
/// variants are not produced here (presumably set by other components —
/// confirm with callers).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleSourceKind {
    /// A rule of the input program.
    Source,
    /// A rule that is not part of the input program but was produced from it.
    Generated,
    /// Not produced by this module.
    Mined,
    /// Not produced by this module.
    Imported,
    /// Not produced by this module.
    RuntimeInjected,
}

impl RuleSourceKind {
    /// Returns the lowercase label used as the kind segment of rule ids
    /// (e.g. the `source` in `rule:source:0:<hash>`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Source => "source",
            Self::Generated => "generated",
            Self::Mined => "mined",
            Self::Imported => "imported",
            Self::RuntimeInjected => "runtime_injected",
        }
    }
}
/// Provenance record for one rule, as built by `rule_provenance` /
/// `build_rule_provenance`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuleProvenance {
    /// Identifier of the record. Rule-derived records use
    /// `rule:<kind>:<index>:<hash of the rule text>`; records built from a list
    /// of generated predicate names use `rule:generated:<index>:<predicate>`.
    pub rule_id: String,
    /// Rendered head atom (e.g. `path(X, Y)`) for rule-derived records, or the
    /// bare predicate name for records built from a generated predicate list.
    pub head: String,
    /// Origin of the rule.
    pub source_kind: RuleSourceKind,
    /// `rule_index:<index>` for rule-derived records; `None` otherwise.
    pub source_span: Option<String>,
    /// Hash of the canonical rule text, or of `generated:<predicate>` for
    /// records built from a generated predicate list.
    pub generation_trace_hash: Option<String>,
    /// Sorted, de-duplicated body predicates of the rule (or the predicate
    /// itself for records built from a generated predicate list).
    pub support_relation_ids: Vec<String>,
    /// Always left empty by the builders in this module.
    pub counterexample_relation_ids: Vec<String>,
}
/// Explanation of how one query of a program can be answered, as built by
/// `query_proof_traces`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryProofTrace {
    /// `query:source:<query index>:<rendered query atom>`.
    pub query_id: String,
    /// The rendered query atom.
    pub query: String,
    /// Name of the relation holding this query's answers: `__xlog_query_<index>`.
    pub answer_relation: String,
    /// Ids of provenance records whose head predicate equals the query predicate.
    pub rule_ids: Vec<String>,
    /// Facts (rendered as `atom.`) whose predicate appears in the body of a
    /// non-fact rule deriving the query predicate; sorted and de-duplicated.
    pub source_facts: Vec<String>,
    /// Negated body literals (`not atom`) of the rules deriving the query predicate.
    pub rejected_alternatives: Vec<String>,
}
/// Builds provenance records for every rule of `program` and, optionally,
/// for the rules that `generated_program` adds on top of it.
///
/// Every rule of `program` (facts included) gets a `Source` record keyed by its
/// position. A rule of `generated_program` gets a `Generated` record only if its
/// canonical text has not been seen yet. That skips rules carried over from the
/// source program as well as repeats inside the generated program itself.
/// Generated records are numbered consecutively in the order they are emitted.
pub fn rule_provenance(
    program: &Program,
    generated_program: Option<&Program>,
) -> Vec<RuleProvenance> {
    let mut out = Vec::new();
    let mut seen_keys = BTreeSet::new();
    for (idx, rule) in program.rules.iter().enumerate() {
        seen_keys.insert(rule_key(rule));
        out.push(rule_record(idx, rule, RuleSourceKind::Source));
    }
    if let Some(generated) = generated_program {
        let mut generated_idx = 0usize;
        for rule in &generated.rules {
            // `insert` returns false when the key is already present. Recording the
            // generated keys as well keeps a rule emitted twice by the generator
            // from producing two records.
            if !seen_keys.insert(rule_key(rule)) {
                continue;
            }
            out.push(rule_record(generated_idx, rule, RuleSourceKind::Generated));
            generated_idx += 1;
        }
    }
    out
}
/// Returns the source-rule provenance of `program`, followed by one `Generated`
/// record per name in `generated_predicates`.
///
/// The generated records carry no source span. Their trace hash is the hash of
/// `generated:<predicate>`, and their single support relation is the predicate
/// itself.
pub fn build_rule_provenance(
    program: &Program,
    generated_predicates: &[String],
) -> Vec<RuleProvenance> {
    let mut records = rule_provenance(program, None);
    records.extend(
        generated_predicates
            .iter()
            .enumerate()
            .map(|(position, name)| RuleProvenance {
                rule_id: format!("rule:generated:{}:{}", position, name),
                head: name.clone(),
                source_kind: RuleSourceKind::Generated,
                source_span: None,
                generation_trace_hash: Some(stable_hash(&format!("generated:{}", name))),
                support_relation_ids: vec![name.clone()],
                counterexample_relation_ids: Vec::new(),
            }),
    );
    records
}
/// Builds one [`QueryProofTrace`] per query of `program`, in query order.
///
/// Rule ids are taken from `provenance` by matching the predicate name of each
/// record's head against the query predicate. Source facts and rejected
/// alternatives come from the non-fact rules of `program` whose head predicate
/// is the query predicate.
pub fn query_proof_traces(
    program: &Program,
    provenance: &[RuleProvenance],
) -> Vec<QueryProofTrace> {
    // Index provenance rule ids by the bare predicate name of their head.
    let mut ids_by_predicate: BTreeMap<String, Vec<String>> = BTreeMap::new();
    for record in provenance {
        let predicate = head_predicate(&record.head).to_string();
        ids_by_predicate
            .entry(predicate)
            .or_default()
            .push(record.rule_id.clone());
    }

    let mut traces = Vec::new();
    for (position, query) in program.queries.iter().enumerate() {
        let predicate = &query.atom.predicate;
        let rendered_query = format_atom(&query.atom);

        let deriving: Vec<&Rule> = program
            .rules
            .iter()
            .filter(|rule| !rule.is_fact() && &rule.head.predicate == predicate)
            .collect();

        // Negated body literals of the deriving rules, in rule/body order.
        let mut rejected = Vec::new();
        for rule in &deriving {
            for lit in rule.body.iter() {
                if let BodyLiteral::Negated(atom) = lit {
                    rejected.push(format!("not {}", format_atom(atom)));
                }
            }
        }

        traces.push(QueryProofTrace {
            query_id: format!("query:source:{}:{}", position, rendered_query),
            query: rendered_query,
            answer_relation: format!("__xlog_query_{}", position),
            rule_ids: ids_by_predicate.get(predicate).cloned().unwrap_or_default(),
            source_facts: source_facts_for_rules(program, &deriving),
            rejected_alternatives: rejected,
        });
    }
    traces
}
/// Convenience wrapper: computes source-only rule provenance for `program` and
/// builds the query proof traces from it.
pub fn build_query_proof_traces(program: &Program) -> Vec<QueryProofTrace> {
    query_proof_traces(program, &rule_provenance(program, None))
}
/// Builds the provenance record for `rule` at position `idx` with the given origin.
///
/// The id is `rule:<kind>:<idx>:<hash>`, where `<hash>` is the stable hash of
/// the rule's canonical text. The same hash is stored as the generation trace
/// hash.
fn rule_record(idx: usize, rule: &Rule, source_kind: RuleSourceKind) -> RuleProvenance {
    // Rendering the whole rule is the expensive part, and both the id and the
    // trace hash need it, so render and hash once.
    let key_hash = stable_hash(&rule_key(rule));
    RuleProvenance {
        rule_id: format!("rule:{}:{}:{}", source_kind.as_str(), idx, key_hash),
        head: format_atom(&rule.head),
        source_kind,
        source_span: Some(format!("rule_index:{}", idx)),
        generation_trace_hash: Some(key_hash),
        support_relation_ids: support_relation_ids(rule),
        counterexample_relation_ids: Vec::new(),
    }
}
fn support_relation_ids(rule: &Rule) -> Vec<String> {
rule.body_predicates()
.into_iter()
.map(str::to_string)
.collect::<BTreeSet<_>>()
.into_iter()
.collect()
}
/// Returns the facts of `program` whose predicate appears in the body of any of
/// `rules`, rendered as `atom.`, sorted and de-duplicated.
fn source_facts_for_rules(program: &Program, rules: &[&Rule]) -> Vec<String> {
    // Predicates of every body literal that carries an atom.
    let mut wanted: BTreeSet<String> = BTreeSet::new();
    for rule in rules {
        for lit in rule.body.iter() {
            if let Some(atom) = lit.atom() {
                wanted.insert(atom.predicate.clone());
            }
        }
    }

    let rendered: BTreeSet<String> = program
        .facts()
        .into_iter()
        .filter(|fact| wanted.contains(&fact.head.predicate))
        .map(|fact| format!("{}.", format_atom(&fact.head)))
        .collect();
    rendered.into_iter().collect()
}
/// Renders `rule` as canonical text: `head` for facts, `head :- lit, lit, ...`
/// otherwise. Used both for de-duplication and as the input of `stable_hash`.
fn rule_key(rule: &Rule) -> String {
    let head = format_atom(&rule.head);
    if rule.body.is_empty() {
        return head;
    }
    let literals: Vec<String> = rule.body.iter().map(format_body_literal).collect();
    format!("{} :- {}", head, literals.join(", "))
}
/// Returns the predicate name of a rendered head: everything before the first
/// `(`, or the whole string when there is no `(`.
fn head_predicate(head: &str) -> &str {
    match head.find('(') {
        Some(open) => &head[..open],
        None => head,
    }
}
pub fn format_atom(atom: &Atom) -> String {
let args = atom
.terms
.iter()
.map(format_term)
.collect::<Vec<_>>()
.join(", ");
format!("{}({})", atom.predicate, args)
}
/// Renders a single body literal as text.
fn format_body_literal(lit: &BodyLiteral) -> String {
    match lit {
        BodyLiteral::Positive(atom) => format_atom(atom),
        BodyLiteral::Negated(atom) => {
            let inner = format_atom(atom);
            format!("not {}", inner)
        }
        BodyLiteral::Epistemic(epistemic) => format_epistemic_literal(epistemic),
        BodyLiteral::Comparison(cmp) => {
            let left = format_term(&cmp.left);
            let right = format_term(&cmp.right);
            format!("{} {} {}", left, format_comp_op(cmp.op), right)
        }
        BodyLiteral::IsExpr(is_expr) => {
            let rhs = format_arith_expr(&is_expr.expr);
            format!("{} is {}", is_expr.target, rhs)
        }
        BodyLiteral::Univ(univ) => {
            let term = format_term(&univ.term);
            let parts = format_term(&univ.parts);
            format!("{} =.. {}", term, parts)
        }
    }
}
/// Renders an epistemic literal as `[not ]know atom` or `[not ]possible atom`.
fn format_epistemic_literal(lit: &crate::ast::EpistemicLiteral) -> String {
    let keyword = match lit.op {
        crate::ast::EpistemicOp::Know => "know",
        crate::ast::EpistemicOp::Possible => "possible",
    };
    let body = format!("{} {}", keyword, format_atom(&lit.atom));
    if lit.negated {
        format!("not {}", body)
    } else {
        body
    }
}
/// Renders a term as text.
///
/// Float and string literals are rendered so they cannot be confused with
/// integer literals or with neighbouring arguments. That keeps `rule_key`
/// (and the hashes derived from it) from merging distinct rules.
fn format_term(term: &Term) -> String {
    match term {
        Term::Variable(name) => name.clone(),
        Term::Anonymous => "_".to_string(),
        Term::Integer(value) => value.to_string(),
        Term::Float(value) => format_float(value),
        Term::String(value) => format_string_literal(value),
        Term::Symbol(id) => symbol::resolve(*id),
        Term::List(items) => {
            let values = items.iter().map(format_term).collect::<Vec<_>>().join(", ");
            format!("[{}]", values)
        }
        Term::Cons { head, tail } => {
            format!("[{} | {}]", format_term(head), format_term(tail))
        }
        Term::Compound { functor, args } => {
            let values = args.iter().map(format_term).collect::<Vec<_>>().join(", ");
            format!("{}({})", functor, values)
        }
        Term::PredRef(name) => name.clone(),
        Term::Aggregate(agg) => format!("{}({})", format_agg_op(agg.op), agg.variable),
    }
}

/// Renders a float literal, forcing a fractional part onto integral values.
///
/// Rust's `Display` for `f32`/`f64` prints `1.0` as `1`, which would make
/// `p(1.0)` indistinguishable from `p(1)`. Renderings that already contain a
/// `.` or non-digit characters (e.g. `NaN`, `inf`) are returned unchanged.
fn format_float(value: impl std::fmt::Display) -> String {
    let text = value.to_string();
    let digits = text.strip_prefix('-').unwrap_or(&text);
    if !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit()) {
        format!("{}.0", text)
    } else {
        text
    }
}

/// Renders a string literal in double quotes, escaping `"` and `\` so the
/// rendering is unambiguous (e.g. a value containing `", "` cannot look like
/// two separate arguments).
fn format_string_literal(value: &str) -> String {
    let mut out = String::with_capacity(value.len() + 2);
    out.push('"');
    for ch in value.chars() {
        if ch == '"' || ch == '\\' {
            out.push('\\');
        }
        out.push(ch);
    }
    out.push('"');
    out
}

/// Renders an arithmetic expression. Binary operators are fully parenthesised,
/// and float constants use `format_float` so `2.0` is not rendered as the
/// integer `2`. Integer and float arithmetic may differ, e.g. for `/`.
fn format_arith_expr(expr: &ArithExpr) -> String {
    match expr {
        ArithExpr::Variable(name) => name.clone(),
        ArithExpr::Integer(value) => value.to_string(),
        ArithExpr::Float(value) => format_float(value),
        ArithExpr::Add(left, right) => {
            format!(
                "({} + {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Sub(left, right) => {
            format!(
                "({} - {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Mul(left, right) => {
            format!(
                "({} * {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Div(left, right) => {
            format!(
                "({} / {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Mod(left, right) => {
            format!(
                "({} % {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Abs(value) => format!("abs({})", format_arith_expr(value)),
        ArithExpr::Min(left, right) => {
            format!(
                "min({}, {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Max(left, right) => {
            format!(
                "max({}, {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Pow(left, right) => {
            format!(
                "pow({}, {})",
                format_arith_expr(left),
                format_arith_expr(right)
            )
        }
        ArithExpr::Cast(value, ty) => format!("cast({}, {:?})", format_arith_expr(value), ty),
        ArithExpr::FuncCall { name, args } => {
            let values = args
                .iter()
                .map(format_arith_expr)
                .collect::<Vec<_>>()
                .join(", ");
            format!("{}({})", name, values)
        }
        ArithExpr::Conditional {
            cond_left,
            cond_op,
            cond_right,
            then_expr,
            else_expr,
        } => format!(
            "if {} {} {} then {} else {}",
            format_arith_expr(cond_left),
            format_comp_op(*cond_op),
            format_arith_expr(cond_right),
            format_arith_expr(then_expr),
            format_arith_expr(else_expr)
        ),
    }
}
/// Returns the infix spelling of a comparison operator.
fn format_comp_op(op: CompOp) -> &'static str {
    match op {
        CompOp::Lt => "<",
        CompOp::Le => "<=",
        CompOp::Gt => ">",
        CompOp::Ge => ">=",
        CompOp::Eq => "==",
        CompOp::Ne => "!=",
    }
}
/// Returns the lowercase function name used when rendering an aggregate.
fn format_agg_op(op: AggOp) -> &'static str {
    match op {
        AggOp::Min => "min",
        AggOp::Max => "max",
        AggOp::Sum => "sum",
        AggOp::Count => "count",
        AggOp::LogSumExp => "logsumexp",
    }
}
/// Hashes `value` with 64-bit FNV-1a and returns the digest as 16 lowercase
/// hex digits.
///
/// The result does not depend on process, platform or `std` hasher seeds,
/// which makes it suitable for persistent ids.
fn stable_hash(value: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let digest = value.bytes().fold(FNV_OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("{:016x}", digest)
}