use xlog_core::ScalarType;
/// A term of the language: the values that fill the argument positions of an
/// [`Atom`] and the operands of a [`Comparison`] or [`Univ`].
///
/// Only `PartialEq` (not `Eq`) is derived because the `Float` variant holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum Term {
/// Named variable; the payload is the variable name.
Variable(String),
/// Unnamed variable. It carries no name, so `Term::variables` never reports it.
Anonymous,
/// Integer literal.
Integer(i64),
/// Floating-point literal.
Float(f64),
/// String literal.
String(String),
/// Symbol identified by a `u32`; presumably an id into a symbol table kept
/// elsewhere — TODO confirm.
Symbol(u32),
/// List given by its element terms.
List(Vec<Term>),
/// List cell split into a first element and a remainder.
Cons {
/// First element of the cell.
head: Box<Term>,
/// Remainder of the list.
tail: Box<Term>,
},
/// Compound term: a functor name applied to argument terms.
Compound {
/// Name of the functor.
functor: String,
/// Argument terms, in order.
args: Vec<Term>,
},
/// Reference to a predicate by name.
PredRef(String),
/// Aggregate expression; `Rule::has_aggregation` looks for this variant in
/// the head of a rule.
Aggregate(AggExpr),
}
impl Term {
    /// Returns `true` when this term is a named variable.
    pub fn is_variable(&self) -> bool {
        self.variable_name().is_some()
    }

    /// Returns `true` when this term is the anonymous variable.
    pub fn is_anonymous(&self) -> bool {
        matches!(self, Term::Anonymous)
    }

    /// Returns `true` for either kind of variable, named or anonymous.
    pub fn is_any_variable(&self) -> bool {
        self.is_variable() || self.is_anonymous()
    }

    /// Returns `true` for the scalar literal variants: integer, float, string
    /// and symbol. Variables, lists, cons cells, compounds, predicate
    /// references and aggregates are not constants.
    pub fn is_constant(&self) -> bool {
        matches!(
            self,
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_)
        )
    }

    /// Returns the name of a named variable, or `None` for any other term.
    pub fn variable_name(&self) -> Option<&str> {
        if let Term::Variable(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }

    /// Lists the named variables occurring in this term, descending into
    /// lists, cons cells and compound arguments.
    ///
    /// Names come out in left-to-right order and are not de-duplicated.
    /// Anonymous variables, literals, predicate references and aggregates
    /// contribute nothing.
    pub fn variables(&self) -> Vec<&str> {
        // Appends the variables of `term` to `out`, depth-first.
        fn walk<'a>(term: &'a Term, out: &mut Vec<&'a str>) {
            match term {
                Term::Variable(name) => out.push(name.as_str()),
                Term::List(items) => {
                    for item in items {
                        walk(item, out);
                    }
                }
                Term::Cons { head, tail } => {
                    walk(head, out);
                    walk(tail, out);
                }
                Term::Compound { args, .. } => {
                    for arg in args {
                        walk(arg, out);
                    }
                }
                Term::Anonymous
                | Term::Integer(_)
                | Term::Float(_)
                | Term::String(_)
                | Term::Symbol(_)
                | Term::PredRef(_)
                | Term::Aggregate(_) => {}
            }
        }

        let mut found = Vec::new();
        walk(self, &mut found);
        found
    }
}
/// An aggregate applied to a single variable, carried by [`Term::Aggregate`].
#[derive(Debug, Clone, PartialEq)]
pub struct AggExpr {
/// Which aggregate operator is applied.
pub op: AggOp,
/// Name of the variable being aggregated.
pub variable: String,
}
/// The aggregate operators available to an [`AggExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggOp {
/// Count aggregate.
Count,
/// Sum aggregate.
Sum,
/// Minimum aggregate.
Min,
/// Maximum aggregate.
Max,
/// Log-sum-exp aggregate; presumably `ln(sum(exp(x)))`, but the evaluator
/// is not visible here — TODO confirm.
LogSumExp,
}
/// Arithmetic expression tree, used by [`IsExpr`] and by function bodies
/// ([`FuncBody`], [`CondExpr`]).
///
/// This type only describes structure; numeric semantics (for example integer
/// versus floating-point division) are decided by whatever evaluates the tree.
#[derive(Debug, Clone, PartialEq)]
pub enum ArithExpr {
/// Reference to a variable by name.
Variable(String),
/// Integer literal.
Integer(i64),
/// Floating-point literal.
Float(f64),
/// Addition of two sub-expressions.
Add(Box<ArithExpr>, Box<ArithExpr>),
/// Subtraction: first operand minus second.
Sub(Box<ArithExpr>, Box<ArithExpr>),
/// Multiplication of two sub-expressions.
Mul(Box<ArithExpr>, Box<ArithExpr>),
/// Division: first operand divided by second.
Div(Box<ArithExpr>, Box<ArithExpr>),
/// Modulo: first operand modulo second.
Mod(Box<ArithExpr>, Box<ArithExpr>),
/// Absolute value of the sub-expression.
Abs(Box<ArithExpr>),
/// Minimum of two sub-expressions.
Min(Box<ArithExpr>, Box<ArithExpr>),
/// Maximum of two sub-expressions.
Max(Box<ArithExpr>, Box<ArithExpr>),
/// Power; presumably first operand is the base and second the exponent —
/// TODO confirm against the evaluator.
Pow(Box<ArithExpr>, Box<ArithExpr>),
/// Conversion of the sub-expression to the given scalar type.
Cast(Box<ArithExpr>, ScalarType),
/// Call of a named function with argument expressions.
FuncCall {
/// Name of the function being called.
name: String,
/// Argument expressions, in order.
args: Vec<ArithExpr>,
},
/// Conditional expression: the comparison `cond_left cond_op cond_right`
/// selects between `then_expr` and `else_expr`.
Conditional {
/// Left operand of the condition.
cond_left: Box<ArithExpr>,
/// Comparison operator of the condition.
cond_op: CompOp,
/// Right operand of the condition.
cond_right: Box<ArithExpr>,
/// Expression used when the condition holds.
then_expr: Box<ArithExpr>,
/// Expression used when the condition does not hold.
else_expr: Box<ArithExpr>,
},
}
impl ArithExpr {
    /// Lists every variable name mentioned anywhere in this expression.
    ///
    /// Names are reported in left-to-right traversal order and are not
    /// de-duplicated. For a conditional the order is: condition left operand,
    /// condition right operand, `then` branch, `else` branch.
    pub fn variables(&self) -> Vec<&str> {
        // Appends the variables of `expr` to `out`, depth-first.
        fn gather<'a>(expr: &'a ArithExpr, out: &mut Vec<&'a str>) {
            match expr {
                ArithExpr::Variable(name) => out.push(name.as_str()),
                ArithExpr::Integer(_) | ArithExpr::Float(_) => {}
                ArithExpr::Add(lhs, rhs)
                | ArithExpr::Sub(lhs, rhs)
                | ArithExpr::Mul(lhs, rhs)
                | ArithExpr::Div(lhs, rhs)
                | ArithExpr::Mod(lhs, rhs)
                | ArithExpr::Min(lhs, rhs)
                | ArithExpr::Max(lhs, rhs)
                | ArithExpr::Pow(lhs, rhs) => {
                    gather(lhs, out);
                    gather(rhs, out);
                }
                ArithExpr::Abs(inner) | ArithExpr::Cast(inner, _) => gather(inner, out),
                ArithExpr::FuncCall { args, .. } => {
                    for arg in args {
                        gather(arg, out);
                    }
                }
                ArithExpr::Conditional {
                    cond_left,
                    cond_right,
                    then_expr,
                    else_expr,
                    ..
                } => {
                    gather(cond_left, out);
                    gather(cond_right, out);
                    gather(then_expr, out);
                    gather(else_expr, out);
                }
            }
        }

        let mut found = Vec::new();
        gather(self, &mut found);
        found
    }
}
/// Body literal pairing a target variable with an arithmetic expression
/// (presumably the surface form `Target is Expr` — TODO confirm).
#[derive(Debug, Clone, PartialEq)]
pub struct IsExpr {
/// Name of the target variable; `BodyLiteral::variables` reports it after
/// the variables of `expr`.
pub target: String,
/// The arithmetic expression.
pub expr: ArithExpr,
}
/// A predicate applied to a sequence of terms, e.g. the head of a [`Rule`] or
/// a positive/negated [`BodyLiteral`].
#[derive(Debug, Clone, PartialEq)]
pub struct Atom {
/// Name of the predicate.
pub predicate: String,
/// Argument terms, in order; their count is the atom's arity.
pub terms: Vec<Term>,
}
impl Atom {
pub fn arity(&self) -> usize {
self.terms.len()
}
pub fn variables(&self) -> Vec<&str> {
self.terms.iter().flat_map(Term::variables).collect()
}
}
/// Operator of an [`EpistemicLiteral`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpistemicOp {
/// The "know" operator (presumably: the atom holds in every considered
/// model — TODO confirm against the solver).
Know,
/// The "possible" operator (presumably: the atom holds in at least one
/// considered model — TODO confirm against the solver).
Possible,
}
/// An atom wrapped in an epistemic operator, optionally negated.
#[derive(Debug, Clone, PartialEq)]
pub struct EpistemicLiteral {
/// The epistemic operator applied.
pub op: EpistemicOp,
/// Whether the literal is negated. Whether the negation applies to the
/// operator or to the atom is not visible here — TODO confirm.
pub negated: bool,
/// The atom under the operator.
pub atom: Atom,
}
/// Comparison operators, used by [`Comparison`] and by the conditions of
/// `ArithExpr::Conditional` and [`CondExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompOp {
/// Equal.
Eq,
/// Not equal.
Ne,
/// Less than.
Lt,
/// Less than or equal.
Le,
/// Greater than.
Gt,
/// Greater than or equal.
Ge,
}
/// Body literal comparing two terms: `left op right`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
/// Left operand.
pub left: Term,
/// Comparison operator.
pub op: CompOp,
/// Right operand.
pub right: Term,
}
/// Body literal relating a term to its parts (presumably Prolog-style
/// `Term =.. Parts`, where the parts are the functor followed by the
/// arguments — TODO confirm).
#[derive(Debug, Clone, PartialEq)]
pub struct Univ {
/// The whole term.
pub term: Term,
/// The term describing its parts.
pub parts: Term,
}
/// One element of a rule, constraint, learnable-rule or predicate-function body.
#[derive(Debug, Clone, PartialEq)]
pub enum BodyLiteral {
/// An atom.
Positive(Atom),
/// A negated atom.
Negated(Atom),
/// An atom under an epistemic operator.
Epistemic(EpistemicLiteral),
/// A comparison between two terms.
Comparison(Comparison),
/// An arithmetic expression paired with a target variable.
IsExpr(IsExpr),
/// A term/parts relation.
Univ(Univ),
}
impl BodyLiteral {
pub fn is_positive(&self) -> bool {
matches!(self, BodyLiteral::Positive(_))
}
pub fn is_negated(&self) -> bool {
matches!(self, BodyLiteral::Negated(_))
}
pub fn atom(&self) -> Option<&Atom> {
match self {
BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => Some(a),
BodyLiteral::Epistemic(lit) => Some(&lit.atom),
BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => None,
}
}
pub fn variables(&self) -> Vec<&str> {
match self {
BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a.variables(),
BodyLiteral::Epistemic(lit) => lit.atom.variables(),
BodyLiteral::Comparison(c) => {
let mut vars = vec![];
vars.extend(c.left.variables());
vars.extend(c.right.variables());
vars
}
BodyLiteral::IsExpr(is_expr) => {
let mut vars = is_expr.expr.variables();
vars.push(is_expr.target.as_str());
vars
}
BodyLiteral::Univ(univ) => {
let mut vars = univ.term.variables();
vars.extend(univ.parts.variables());
vars
}
}
}
}
/// A rule: a head atom derived from a body of literals. A rule with an empty
/// body is a fact (see `Rule::is_fact`).
#[derive(Debug, Clone, PartialEq)]
pub struct Rule {
/// The atom the rule defines.
pub head: Atom,
/// Body literals; empty for facts.
pub body: Vec<BodyLiteral>,
}
impl Rule {
pub fn is_fact(&self) -> bool {
self.body.is_empty()
}
pub fn has_negation(&self) -> bool {
self.body.iter().any(|l| l.is_negated())
}
pub fn has_aggregation(&self) -> bool {
self.head
.terms
.iter()
.any(|t| matches!(t, Term::Aggregate(_)))
}
pub fn body_predicates(&self) -> Vec<&str> {
self.body
.iter()
.filter_map(|l| l.atom().map(|a| a.predicate.as_str()))
.collect()
}
pub fn head_variables(&self) -> Vec<&str> {
self.head.variables()
}
pub fn body_variables(&self) -> Vec<&str> {
self.body.iter().flat_map(|l| l.variables()).collect()
}
}
/// A headless clause: a body of literals with no head atom (presumably an
/// integrity constraint whose body must not be satisfiable — TODO confirm).
#[derive(Debug, Clone, PartialEq)]
pub struct Constraint {
/// The literals of the constraint.
pub body: Vec<BodyLiteral>,
}
/// A query over a single atom.
#[derive(Debug, Clone, PartialEq)]
pub struct Query {
/// The atom being queried.
pub atom: Atom,
}
/// Choice of probabilistic inference engine (see `Directives::prob_engine`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbEngine {
/// Exact engine; the name suggests d-DNNF knowledge compilation — TODO
/// confirm. This is the fallback returned by
/// `Directives::prob_engine_or_default`.
ExactDdnnf,
/// Presumably Monte Carlo sampling — TODO confirm.
Mc,
}
/// On/off switch for the probabilistic cache directive (see
/// `Directives::prob_cache`). What is cached is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbCache {
/// Cache enabled.
On,
/// Cache disabled.
Off,
}
/// Choice of semantics for epistemic literals (see `Directives::epistemic_mode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpistemicMode {
/// Presumably the Gelfond 1991 (G91) world-view semantics — TODO confirm.
G91,
/// Presumably the FAEEL semantics — TODO confirm. This is the fallback
/// returned by `Directives::epistemic_mode_or_default`.
Faeel,
}
/// Choice of probabilistic method (see `Directives::prob_method`); the names
/// suggest alternative ways of conditioning on evidence — TODO confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbMethod {
/// Presumably rejection sampling — TODO confirm.
Rejection,
/// Presumably fixes evidence atoms to their observed values — TODO confirm.
EvidenceClamping,
}
/// Setting for the magic-sets directive (see `Directives::magic_sets`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MagicSetsMode {
/// Presumably lets the engine decide whether to apply the transformation —
/// TODO confirm.
Auto,
/// Enabled.
On,
/// Disabled.
Off,
}
/// Program-level directive settings.
///
/// Every field is optional and the derived `Default` leaves all of them
/// `None`. Fields that have an `*_or_default` accessor on this type get their
/// fallback value there; the remaining fields have no fallback defined in
/// this file.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Directives {
/// Probabilistic engine; see `prob_engine_or_default`.
pub prob_engine: Option<ProbEngine>,
/// Probabilistic cache switch (no fallback accessor in this file).
pub prob_cache: Option<ProbCache>,
/// Sample count; see `prob_samples_or_default`.
pub prob_samples: Option<usize>,
/// Random seed; see `prob_seed_or_default`.
pub prob_seed: Option<u64>,
/// Confidence level; see `prob_confidence_or_default`. The valid range is
/// not checked here.
pub prob_confidence: Option<f64>,
/// Probabilistic method (no fallback accessor in this file).
pub prob_method: Option<ProbMethod>,
/// Iteration cap for non-monotone evaluation; see
/// `prob_max_nonmonotone_iterations_or_default`.
pub prob_max_nonmonotone_iterations: Option<usize>,
/// Recursion depth limit; see `max_recursion_depth_or_default`.
pub max_recursion_depth: Option<u32>,
/// Epistemic semantics; see `epistemic_mode_or_default`.
pub epistemic_mode: Option<EpistemicMode>,
/// Magic-sets setting (no fallback accessor in this file).
pub magic_sets: Option<MagicSetsMode>,
}
impl Directives {
    // Fallback values used by the `*_or_default` accessors below.
    const DEFAULT_PROB_ENGINE: ProbEngine = ProbEngine::ExactDdnnf;
    const DEFAULT_MAX_RECURSION_DEPTH: u32 = 1000;
    const DEFAULT_EPISTEMIC_MODE: EpistemicMode = EpistemicMode::Faeel;
    const DEFAULT_PROB_SAMPLES: usize = 10_000;
    const DEFAULT_PROB_SEED: u64 = 0;
    const DEFAULT_PROB_CONFIDENCE: f64 = 0.95;
    const DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS: usize = 1024;

    /// The configured probabilistic engine, or `ProbEngine::ExactDdnnf` when unset.
    pub fn prob_engine_or_default(&self) -> ProbEngine {
        match self.prob_engine {
            Some(engine) => engine,
            None => Self::DEFAULT_PROB_ENGINE,
        }
    }

    /// The configured recursion depth limit, or `1000` when unset.
    pub fn max_recursion_depth_or_default(&self) -> u32 {
        match self.max_recursion_depth {
            Some(depth) => depth,
            None => Self::DEFAULT_MAX_RECURSION_DEPTH,
        }
    }

    /// The configured epistemic mode, or `EpistemicMode::Faeel` when unset.
    pub fn epistemic_mode_or_default(&self) -> EpistemicMode {
        match self.epistemic_mode {
            Some(mode) => mode,
            None => Self::DEFAULT_EPISTEMIC_MODE,
        }
    }

    /// The configured sample count, or `10000` when unset.
    pub fn prob_samples_or_default(&self) -> usize {
        match self.prob_samples {
            Some(samples) => samples,
            None => Self::DEFAULT_PROB_SAMPLES,
        }
    }

    /// The configured random seed, or `0` when unset.
    pub fn prob_seed_or_default(&self) -> u64 {
        match self.prob_seed {
            Some(seed) => seed,
            None => Self::DEFAULT_PROB_SEED,
        }
    }

    /// The configured confidence level, or `0.95` when unset.
    pub fn prob_confidence_or_default(&self) -> f64 {
        match self.prob_confidence {
            Some(confidence) => confidence,
            None => Self::DEFAULT_PROB_CONFIDENCE,
        }
    }

    /// The configured non-monotone iteration cap, or `1024` when unset.
    pub fn prob_max_nonmonotone_iterations_or_default(&self) -> usize {
        match self.prob_max_nonmonotone_iterations {
            Some(limit) => limit,
            None => Self::DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS,
        }
    }
}
/// An atom annotated with a probability.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbFact {
/// The probability attached to `atom`. No range check is performed by this
/// type; presumably expected to lie in `[0, 1]` — TODO confirm where it is
/// validated.
pub prob: f64,
/// The annotated atom.
pub atom: Atom,
}
/// Declaration of a predicate backed by a neural network.
///
/// NOTE(review): how `inputs` and `output` relate to the arguments of
/// `predicate` is not visible in this file — confirm against the parser.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralPredDecl {
/// Name of the network.
pub network: String,
/// Names of the inputs (presumably variable names — TODO confirm).
pub inputs: Vec<String>,
/// Name of the output (presumably a variable name — TODO confirm).
pub output: String,
/// Optional list of output labels.
pub labels: Option<Vec<NeuralLabel>>,
/// The atom form of the declared predicate.
pub predicate: Atom,
}
/// A label listed in a [`NeuralPredDecl`].
#[derive(Debug, Clone, PartialEq)]
pub enum NeuralLabel {
/// Integer label.
Integer(i64),
/// Symbolic label, stored by name.
Symbol(String),
}
/// A rule associated with a named mask (presumably a learnable weight or gate
/// for the rule — TODO confirm).
///
/// NOTE(review): unlike the other AST types in this file, this one does not
/// derive `PartialEq` — confirm whether that is intentional.
#[derive(Debug, Clone)]
pub struct LearnableRule {
/// Name of the mask attached to this rule.
pub mask_name: String,
/// Head atom of the rule.
pub head: Atom,
/// Body literals of the rule.
pub body: Vec<BodyLiteral>,
}
/// A group of probability-annotated atoms (presumably mutually exclusive
/// alternatives — TODO confirm).
#[derive(Debug, Clone, PartialEq)]
pub struct AnnotatedDisjunction {
/// The alternatives, each with its own probability.
pub choices: Vec<ProbFact>,
}
/// An observation: an atom together with its observed truth value.
#[derive(Debug, Clone, PartialEq)]
pub struct Evidence {
/// The observed atom.
pub atom: Atom,
/// The truth value recorded for `atom`.
pub value: bool,
}
/// A probabilistic query over a single atom.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbQuery {
/// The atom whose probability is queried.
pub atom: Atom,
}
/// An import declaration naming a module and, optionally, specific items.
#[derive(Debug, Clone, PartialEq)]
pub struct UseDecl {
/// Segments of the module path, in order.
pub module_path: Vec<String>,
/// Names of the imported items. `None` presumably means "import everything"
/// (`Program::merge_from` treats a `None` item set that way) — TODO confirm
/// that the resolver maps this field onto that argument.
pub imports: Option<Vec<String>>,
}
/// Declaration of a named domain over a scalar type.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainDecl {
/// Name of the domain.
pub name: String,
/// Scalar type underlying the domain.
pub typ: ScalarType,
}
/// A type as written in a predicate declaration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeRef {
/// A built-in scalar type.
Scalar(ScalarType),
/// A domain referenced by name (see [`DomainDecl`]).
Domain(String),
/// A list whose elements have the given type.
List(Box<TypeRef>),
/// The generic term type.
Term,
/// The compound-term type.
Compound,
/// The predicate-reference type.
PredRef,
}
/// One column of a predicate declaration: an optional name and a type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PredColumn {
/// Column name, if one was given.
pub name: Option<String>,
/// Column type.
pub typ: TypeRef,
}
/// Declaration of a predicate and the types of its arguments.
#[derive(Debug, Clone, PartialEq)]
pub struct PredDecl {
/// Name of the predicate.
pub name: String,
/// Argument types, in order.
pub types: Vec<TypeRef>,
/// Argument columns (type plus optional name). Presumably parallel to
/// `types` — TODO confirm which of the two is authoritative.
pub columns: Vec<PredColumn>,
/// Whether the predicate is private. `Program::merge_from` does not copy
/// private declarations, nor rules whose head predicate is declared private.
pub is_private: bool,
}
/// A parameter of a user-defined function.
#[derive(Debug, Clone, PartialEq)]
pub struct FuncParam {
/// Parameter name.
pub name: String,
/// Declared scalar type, if any.
pub typ: Option<ScalarType>,
}
/// Conditional function body: the comparison `cond_left cond_op cond_right`
/// selects between two nested function bodies.
#[derive(Debug, Clone, PartialEq)]
pub struct CondExpr {
/// Left operand of the condition.
pub cond_left: ArithExpr,
/// Comparison operator of the condition.
pub cond_op: CompOp,
/// Right operand of the condition.
pub cond_right: ArithExpr,
/// Body used when the condition holds.
pub then_branch: Box<FuncBody>,
/// Body used when the condition does not hold.
pub else_branch: Box<FuncBody>,
}
/// The body of a user-defined function ([`FuncDef`]).
#[derive(Debug, Clone, PartialEq)]
pub enum FuncBody {
/// A single arithmetic expression.
Arithmetic(ArithExpr),
/// A conditional choosing between two nested bodies.
Conditional(CondExpr),
/// A body made of literals with a designated result.
Predicate {
/// Name of the result (presumably the variable whose binding is the
/// function's value — TODO confirm).
result: String,
/// The literals making up the body.
body: Vec<BodyLiteral>,
},
}
/// Definition of a user-defined function.
#[derive(Debug, Clone, PartialEq)]
pub struct FuncDef {
/// Function name.
pub name: String,
/// Parameters, in order.
pub params: Vec<FuncParam>,
/// Declared return type, if any.
pub return_type: Option<ScalarType>,
/// The function body.
pub body: FuncBody,
/// Whether the function is private. `Program::merge_from` does not copy
/// private functions.
pub is_private: bool,
}
/// A whole program: every kind of top-level item, grouped by kind.
///
/// The derived `Default` yields a program with every list empty and no
/// directive set.
#[derive(Debug, Clone, Default)]
pub struct Program {
/// Import declarations.
pub imports: Vec<UseDecl>,
/// User-defined functions.
pub functions: Vec<FuncDef>,
/// Domain declarations.
pub domains: Vec<DomainDecl>,
/// Predicate declarations.
pub predicates: Vec<PredDecl>,
/// Rules and facts together; a fact is a rule with an empty body.
pub rules: Vec<Rule>,
/// Headless constraints.
pub constraints: Vec<Constraint>,
/// Queries.
pub queries: Vec<Query>,
/// Probability-annotated facts.
pub prob_facts: Vec<ProbFact>,
/// Annotated disjunctions.
pub annotated_disjunctions: Vec<AnnotatedDisjunction>,
/// Evidence statements.
pub evidence: Vec<Evidence>,
/// Probabilistic queries.
pub prob_queries: Vec<ProbQuery>,
/// Neural predicate declarations.
pub neural_predicates: Vec<NeuralPredDecl>,
/// Learnable rules.
pub learnable_rules: Vec<LearnableRule>,
/// Directive settings.
pub directives: Directives,
}
impl Program {
    /// Creates an empty program: every item list empty, no directive set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Iterates over the rules that have an empty body.
    pub fn facts(&self) -> impl Iterator<Item = &Rule> {
        self.rules.iter().filter(|r| r.is_fact())
    }

    /// Iterates over the rules that have at least one body literal.
    pub fn proper_rules(&self) -> impl Iterator<Item = &Rule> {
        self.rules.iter().filter(|r| !r.is_fact())
    }

    /// Lists the distinct predicate names that appear as the head of a rule
    /// or fact.
    ///
    /// Each name appears once, in order of first appearance in `rules`, so
    /// the result is the same on every call and every run.
    pub fn defined_predicates(&self) -> Vec<&str> {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        let mut names = Vec::new();
        for rule in &self.rules {
            let name = rule.head.predicate.as_str();
            // `insert` returns `true` only the first time a name is seen.
            if seen.insert(name) {
                names.push(name);
            }
        }
        names
    }

    /// Returns `true` when the program contains any probabilistic item
    /// (probabilistic fact, annotated disjunction, evidence, probabilistic
    /// query) or sets any `prob_*` directive.
    ///
    /// Neural predicates, learnable rules and the non-`prob_*` directives are
    /// not taken into account.
    pub fn is_probabilistic_profile(&self) -> bool {
        let has_prob_items = !self.prob_facts.is_empty()
            || !self.annotated_disjunctions.is_empty()
            || !self.evidence.is_empty()
            || !self.prob_queries.is_empty();

        let d = &self.directives;
        let has_prob_directive = d.prob_engine.is_some()
            || d.prob_cache.is_some()
            || d.prob_samples.is_some()
            || d.prob_seed.is_some()
            || d.prob_confidence.is_some()
            || d.prob_method.is_some()
            || d.prob_max_nonmonotone_iterations.is_some();

        has_prob_items || has_prob_directive
    }

    /// The probabilistic engine selected by the directives, or the fallback
    /// engine when none is set.
    pub fn prob_engine(&self) -> ProbEngine {
        self.directives.prob_engine_or_default()
    }

    /// Merges the importable items of `other` into this program.
    ///
    /// `imported_items` restricts the merge to the listed names; `None`
    /// places no restriction. What is copied:
    ///
    /// * predicate declarations and functions that are not private, pass the
    ///   name filter, and whose name is not already declared here;
    /// * rules whose head predicate is not declared private in `other` and
    ///   passes the name filter — these are appended without any duplicate
    ///   check;
    /// * domains whose name is not already present, regardless of the name
    ///   filter.
    ///
    /// Nothing else (imports, constraints, queries, probabilistic items,
    /// neural predicates, learnable rules, directives) is copied.
    pub fn merge_from(
        &mut self,
        other: &Program,
        imported_items: Option<&std::collections::HashSet<String>>,
    ) {
        use std::collections::HashSet;

        // A name is selected when there is no import list, or the list names it.
        let is_selected =
            |name: &str| imported_items.map_or(true, |items| items.contains(name));

        // Rules carry no visibility flag of their own, so the visibility of
        // the head predicate's declaration decides whether they are exported.
        let private_preds: HashSet<&str> = other
            .predicates
            .iter()
            .filter(|p| p.is_private)
            .map(|p| p.name.as_str())
            .collect();

        for pred in &other.predicates {
            if pred.is_private || !is_selected(&pred.name) {
                continue;
            }
            if !self.predicates.iter().any(|p| p.name == pred.name) {
                self.predicates.push(pred.clone());
            }
        }

        for func in &other.functions {
            if func.is_private || !is_selected(&func.name) {
                continue;
            }
            if !self.functions.iter().any(|f| f.name == func.name) {
                self.functions.push(func.clone());
            }
        }

        for rule in &other.rules {
            let head = rule.head.predicate.as_str();
            if private_preds.contains(head) || !is_selected(head) {
                continue;
            }
            self.rules.push(rule.clone());
        }

        for domain in &other.domains {
            if !self.domains.iter().any(|d| d.name == domain.name) {
                self.domains.push(domain.clone());
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a named variable term.
    fn var(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    /// Builds an atom for `predicate` over `terms`.
    fn atom(predicate: &str, terms: Vec<Term>) -> Atom {
        Atom {
            predicate: predicate.to_string(),
            terms,
        }
    }

    /// Builds the ground atom `edge(1, 2)`.
    fn edge_1_2() -> Atom {
        atom("edge", vec![Term::Integer(1), Term::Integer(2)])
    }

    #[test]
    fn test_term_variable() {
        let term = var("X");
        assert!(term.is_variable());
        assert!(!term.is_constant());
    }

    #[test]
    fn test_term_constant() {
        let term = Term::Integer(42);
        assert!(!term.is_variable());
        assert!(term.is_constant());
    }

    #[test]
    fn test_atom_arity() {
        assert_eq!(edge_1_2().arity(), 2);
    }

    #[test]
    fn test_atom_variables() {
        let edge = atom("edge", vec![var("X"), Term::Integer(2)]);
        let vars = edge.variables();
        assert_eq!(vars, vec!["X"]);
    }

    #[test]
    fn test_rule_is_fact() {
        let fact = Rule {
            head: edge_1_2(),
            body: vec![],
        };
        assert!(fact.is_fact());
    }

    #[test]
    fn test_rule_has_negation() {
        let node_x = BodyLiteral::Positive(atom("node", vec![var("X")]));
        let no_edge_xy = BodyLiteral::Negated(atom("edge", vec![var("X"), var("Y")]));
        let rule = Rule {
            head: atom("isolated", vec![var("X")]),
            body: vec![node_x, no_edge_xy],
        };
        assert!(rule.has_negation());
    }

    #[test]
    fn test_program_facts() {
        let mut program = Program::new();
        program.rules.push(Rule {
            head: edge_1_2(),
            body: vec![],
        });
        program.rules.push(Rule {
            head: atom("reach", vec![var("X"), var("Y")]),
            body: vec![BodyLiteral::Positive(atom(
                "edge",
                vec![var("X"), var("Y")],
            ))],
        });
        assert_eq!(program.facts().count(), 1);
        assert_eq!(program.proper_rules().count(), 1);
    }

    #[test]
    fn test_arith_expr_structure() {
        let lhs = Box::new(ArithExpr::Variable("X".to_string()));
        let rhs = Box::new(ArithExpr::Integer(1));
        let expr = ArithExpr::Add(lhs, rhs);
        assert!(matches!(expr, ArithExpr::Add(_, _)));
    }

    #[test]
    fn test_is_expr_structure() {
        let is_expr = IsExpr {
            target: "Z".to_string(),
            expr: ArithExpr::Variable("Y".to_string()),
        };
        assert_eq!(is_expr.target, "Z");
    }
}