use crate::ast::{ArithExpr, Atom, BodyLiteral, Comparison, FuncBody, FuncDef, IsExpr, Term};
use crate::function::{FunctionError, FunctionRegistry};
use std::collections::HashMap;
pub struct ExpansionContext<'a> {
registry: &'a FunctionRegistry,
depth: u32,
max_depth: u32,
}
impl<'a> ExpansionContext<'a> {
pub fn new(registry: &'a FunctionRegistry, max_depth: u32) -> Self {
Self {
registry,
depth: 0,
max_depth,
}
}
pub fn expand_call(
&mut self,
name: &str,
args: &[ArithExpr],
) -> Result<ArithExpr, FunctionError> {
if self.depth >= self.max_depth {
return Err(FunctionError::MaxRecursionDepth {
name: name.to_string(),
depth: self.max_depth,
});
}
let func = self
.registry
.get(name)
.ok_or_else(|| FunctionError::UndefinedFunction {
name: name.to_string(),
})?;
let mut subst: HashMap<String, ArithExpr> = HashMap::new();
for (param, arg) in func.params.iter().zip(args.iter()) {
subst.insert(param.name.clone(), arg.clone());
}
self.depth += 1;
let result = self.expand_body(&func.body, &subst)?;
self.depth -= 1;
Ok(result)
}
fn expand_body(
&mut self,
body: &FuncBody,
subst: &HashMap<String, ArithExpr>,
) -> Result<ArithExpr, FunctionError> {
match body {
FuncBody::Arithmetic(expr) => self.expand_expr(expr, subst),
FuncBody::Conditional(cond) => {
let cond_left = self.expand_expr(&cond.cond_left, subst)?;
let cond_right = self.expand_expr(&cond.cond_right, subst)?;
let then_expr = self.expand_body(&cond.then_branch, subst)?;
let else_expr = self.expand_body(&cond.else_branch, subst)?;
Ok(ArithExpr::Conditional {
cond_left: Box::new(cond_left),
cond_op: cond.cond_op,
cond_right: Box::new(cond_right),
then_expr: Box::new(then_expr),
else_expr: Box::new(else_expr),
})
}
FuncBody::Predicate { result, .. } => {
let result_var = subst
.get(result)
.cloned()
.unwrap_or_else(|| ArithExpr::Variable(result.clone()));
Ok(result_var)
}
}
}
fn expand_expr(
&mut self,
expr: &ArithExpr,
subst: &HashMap<String, ArithExpr>,
) -> Result<ArithExpr, FunctionError> {
match expr {
ArithExpr::Variable(name) => {
Ok(subst.get(name).cloned().unwrap_or_else(|| expr.clone()))
}
ArithExpr::Integer(_) | ArithExpr::Float(_) => Ok(expr.clone()),
ArithExpr::FuncCall { name, args } => {
let expanded_args: Result<Vec<_>, _> =
args.iter().map(|a| self.expand_expr(a, subst)).collect();
let expanded_args = expanded_args?;
if self.registry.contains(name) {
self.expand_call(name, &expanded_args)
} else {
Ok(ArithExpr::FuncCall {
name: name.clone(),
args: expanded_args,
})
}
}
ArithExpr::Add(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Add(Box::new(el), Box::new(er)))
}
ArithExpr::Sub(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Sub(Box::new(el), Box::new(er)))
}
ArithExpr::Mul(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Mul(Box::new(el), Box::new(er)))
}
ArithExpr::Div(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Div(Box::new(el), Box::new(er)))
}
ArithExpr::Mod(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Mod(Box::new(el), Box::new(er)))
}
ArithExpr::Abs(e) => {
let ee = self.expand_expr(e, subst)?;
Ok(ArithExpr::Abs(Box::new(ee)))
}
ArithExpr::Min(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Min(Box::new(el), Box::new(er)))
}
ArithExpr::Max(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Max(Box::new(el), Box::new(er)))
}
ArithExpr::Pow(l, r) => {
let el = self.expand_expr(l, subst)?;
let er = self.expand_expr(r, subst)?;
Ok(ArithExpr::Pow(Box::new(el), Box::new(er)))
}
ArithExpr::Cast(e, t) => {
let ee = self.expand_expr(e, subst)?;
Ok(ArithExpr::Cast(Box::new(ee), *t))
}
ArithExpr::Conditional {
cond_left,
cond_op,
cond_right,
then_expr,
else_expr,
} => {
let cl = self.expand_expr(cond_left, subst)?;
let cr = self.expand_expr(cond_right, subst)?;
let te = self.expand_expr(then_expr, subst)?;
let ee = self.expand_expr(else_expr, subst)?;
Ok(ArithExpr::Conditional {
cond_left: Box::new(cl),
cond_op: *cond_op,
cond_right: Box::new(cr),
then_expr: Box::new(te),
else_expr: Box::new(ee),
})
}
}
}
#[allow(dead_code)] pub(crate) fn expand_predicate_func(
&self,
func: &FuncDef,
args: &[ArithExpr],
) -> Result<(Vec<BodyLiteral>, String), FunctionError> {
match &func.body {
FuncBody::Predicate { result, body } => {
let mut subst: HashMap<String, ArithExpr> = HashMap::new();
for (param, arg) in func.params.iter().zip(args.iter()) {
subst.insert(param.name.clone(), arg.clone());
}
let expanded_body: Vec<BodyLiteral> = body
.iter()
.map(|lit| self.substitute_literal(lit, &subst))
.collect();
let result_var = self.substitute_var(result, &subst);
Ok((expanded_body, result_var))
}
_ => Err(FunctionError::UndefinedFunction {
name: func.name.clone(),
}),
}
}
fn substitute_literal(
&self,
lit: &BodyLiteral,
subst: &HashMap<String, ArithExpr>,
) -> BodyLiteral {
match lit {
BodyLiteral::Positive(atom) => BodyLiteral::Positive(self.substitute_atom(atom, subst)),
BodyLiteral::Negated(atom) => BodyLiteral::Negated(self.substitute_atom(atom, subst)),
BodyLiteral::Comparison(cmp) => BodyLiteral::Comparison(Comparison {
left: self.substitute_term(&cmp.left, subst),
op: cmp.op,
right: self.substitute_term(&cmp.right, subst),
}),
BodyLiteral::IsExpr(is_expr) => {
let target = self.substitute_var(&is_expr.target, subst);
let expr = self.substitute_arith_expr(&is_expr.expr, subst);
BodyLiteral::IsExpr(IsExpr { target, expr })
}
}
}
fn substitute_atom(&self, atom: &Atom, subst: &HashMap<String, ArithExpr>) -> Atom {
Atom {
predicate: atom.predicate.clone(),
terms: atom
.terms
.iter()
.map(|t| self.substitute_term(t, subst))
.collect(),
}
}
fn substitute_term(&self, term: &Term, subst: &HashMap<String, ArithExpr>) -> Term {
match term {
Term::Variable(name) => {
if let Some(replacement) = subst.get(name) {
match replacement {
ArithExpr::Variable(new_name) => Term::Variable(new_name.clone()),
ArithExpr::Integer(n) => Term::Integer(*n),
ArithExpr::Float(f) => Term::Float(*f),
_ => term.clone(),
}
} else {
term.clone()
}
}
_ => term.clone(),
}
}
fn substitute_arith_expr(
&self,
expr: &ArithExpr,
subst: &HashMap<String, ArithExpr>,
) -> ArithExpr {
match expr {
ArithExpr::Variable(name) => subst.get(name).cloned().unwrap_or_else(|| expr.clone()),
ArithExpr::Integer(_) | ArithExpr::Float(_) => expr.clone(),
ArithExpr::Add(l, r) => ArithExpr::Add(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Sub(l, r) => ArithExpr::Sub(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Mul(l, r) => ArithExpr::Mul(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Div(l, r) => ArithExpr::Div(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Mod(l, r) => ArithExpr::Mod(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Abs(e) => ArithExpr::Abs(Box::new(self.substitute_arith_expr(e, subst))),
ArithExpr::Min(l, r) => ArithExpr::Min(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Max(l, r) => ArithExpr::Max(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Pow(l, r) => ArithExpr::Pow(
Box::new(self.substitute_arith_expr(l, subst)),
Box::new(self.substitute_arith_expr(r, subst)),
),
ArithExpr::Cast(e, t) => {
ArithExpr::Cast(Box::new(self.substitute_arith_expr(e, subst)), *t)
}
ArithExpr::FuncCall { name, args } => ArithExpr::FuncCall {
name: name.clone(),
args: args
.iter()
.map(|a| self.substitute_arith_expr(a, subst))
.collect(),
},
ArithExpr::Conditional {
cond_left,
cond_op,
cond_right,
then_expr,
else_expr,
} => ArithExpr::Conditional {
cond_left: Box::new(self.substitute_arith_expr(cond_left, subst)),
cond_op: *cond_op,
cond_right: Box::new(self.substitute_arith_expr(cond_right, subst)),
then_expr: Box::new(self.substitute_arith_expr(then_expr, subst)),
else_expr: Box::new(self.substitute_arith_expr(else_expr, subst)),
},
}
}
fn substitute_var(&self, var: &str, subst: &HashMap<String, ArithExpr>) -> String {
if let Some(ArithExpr::Variable(new_name)) = subst.get(var) {
new_name.clone()
} else {
var.to_string()
}
}
#[allow(dead_code)] pub(crate) fn is_predicate_func(&self, name: &str) -> bool {
self.registry
.get(name)
.map(|f| matches!(f.body, FuncBody::Predicate { .. }))
.unwrap_or(false)
}
pub(crate) fn expand_expr_fully(
&mut self,
expr: &ArithExpr,
) -> Result<ArithExpr, FunctionError> {
match expr {
ArithExpr::Variable(_) | ArithExpr::Integer(_) | ArithExpr::Float(_) => {
Ok(expr.clone())
}
ArithExpr::FuncCall { name, args } => {
let expanded_args: Result<Vec<_>, _> =
args.iter().map(|a| self.expand_expr_fully(a)).collect();
let expanded_args = expanded_args?;
if self.registry.contains(name) {
self.expand_call(name, &expanded_args)
} else {
Ok(ArithExpr::FuncCall {
name: name.clone(),
args: expanded_args,
})
}
}
ArithExpr::Add(l, r) => Ok(ArithExpr::Add(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Sub(l, r) => Ok(ArithExpr::Sub(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Mul(l, r) => Ok(ArithExpr::Mul(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Div(l, r) => Ok(ArithExpr::Div(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Mod(l, r) => Ok(ArithExpr::Mod(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Abs(e) => Ok(ArithExpr::Abs(Box::new(self.expand_expr_fully(e)?))),
ArithExpr::Min(l, r) => Ok(ArithExpr::Min(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Max(l, r) => Ok(ArithExpr::Max(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Pow(l, r) => Ok(ArithExpr::Pow(
Box::new(self.expand_expr_fully(l)?),
Box::new(self.expand_expr_fully(r)?),
)),
ArithExpr::Cast(e, t) => Ok(ArithExpr::Cast(Box::new(self.expand_expr_fully(e)?), *t)),
ArithExpr::Conditional {
cond_left,
cond_op,
cond_right,
then_expr,
else_expr,
} => Ok(ArithExpr::Conditional {
cond_left: Box::new(self.expand_expr_fully(cond_left)?),
cond_op: *cond_op,
cond_right: Box::new(self.expand_expr_fully(cond_right)?),
then_expr: Box::new(self.expand_expr_fully(then_expr)?),
else_expr: Box::new(self.expand_expr_fully(else_expr)?),
}),
}
}
}
use crate::ast::{Program, Rule};
/// Inlines calls to the program's user-defined functions into its rules.
///
/// Behaviour:
/// * A program without functions is returned as a plain clone.
/// * Otherwise every definition in `program.functions` is registered first,
///   so registration errors are reported before any expansion starts.
/// * Only the expressions of `IsExpr` literals in rule bodies are rewritten.
/// * Every other part of the program is cloned unchanged, including the
///   function definitions themselves.
///
/// # Errors
///
/// Any `FunctionError` from registration or from expanding a call (e.g.
/// `MaxRecursionDepth` when nesting reaches `max_depth`).
pub fn expand_program_functions(
    program: &Program,
    max_depth: u32,
) -> Result<Program, FunctionError> {
    // Nothing to register or inline, so skip building a registry at all.
    if program.functions.is_empty() {
        return Ok(program.clone());
    }
    let mut registry = FunctionRegistry::new();
    for func in &program.functions {
        registry.register(func.clone())?;
    }
    let mut ctx = ExpansionContext::new(&registry, max_depth);
    let rules = program
        .rules
        .iter()
        .map(|rule| expand_rule_functions(&mut ctx, rule))
        .collect::<Result<Vec<Rule>, FunctionError>>()?;
    Ok(Program {
        rules,
        directives: program.directives.clone(),
        queries: program.queries.clone(),
        predicates: program.predicates.clone(),
        constraints: program.constraints.clone(),
        imports: program.imports.clone(),
        functions: program.functions.clone(),
        domains: program.domains.clone(),
        prob_facts: program.prob_facts.clone(),
        annotated_disjunctions: program.annotated_disjunctions.clone(),
        evidence: program.evidence.clone(),
        prob_queries: program.prob_queries.clone(),
        neural_predicates: program.neural_predicates.clone(),
        learnable_rules: program.learnable_rules.clone(),
    })
}
/// Builds a copy of `rule` in which every body literal has had its
/// user-defined function calls expanded; the head is copied as-is.
///
/// Returns the first expansion error encountered, if any.
fn expand_rule_functions(ctx: &mut ExpansionContext, rule: &Rule) -> Result<Rule, FunctionError> {
    let mut body = Vec::with_capacity(rule.body.len());
    for literal in &rule.body {
        body.push(expand_literal_functions(ctx, literal)?);
    }
    Ok(Rule {
        head: rule.head.clone(),
        body,
    })
}
/// Expands user-defined function calls inside a single body literal.
///
/// Only `IsExpr` literals carry arithmetic expressions. Their expression is
/// fully expanded while the target variable is kept. All other literal kinds
/// are returned as copies.
fn expand_literal_functions(
    ctx: &mut ExpansionContext,
    lit: &BodyLiteral,
) -> Result<BodyLiteral, FunctionError> {
    match lit {
        BodyLiteral::IsExpr(is_expr) => {
            let expr = ctx.expand_expr_fully(&is_expr.expr)?;
            let target = is_expr.target.clone();
            Ok(BodyLiteral::IsExpr(IsExpr { target, expr }))
        }
        BodyLiteral::Positive(_) | BodyLiteral::Negated(_) | BodyLiteral::Comparison(_) => {
            Ok(lit.clone())
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ExpansionContext` call expansion and predicate-function
    // instantiation.
    use super::*;
    use crate::ast::{FuncDef, FuncParam};

    /// Shorthand for `ArithExpr::Variable(name)`.
    fn var(name: &str) -> ArithExpr {
        ArithExpr::Variable(name.to_string())
    }

    /// Builds a public, untyped function named `name` with the given ordered
    /// parameter names and body.
    fn func(name: &str, params: &[&str], body: FuncBody) -> FuncDef {
        FuncDef {
            name: name.to_string(),
            params: params
                .iter()
                .map(|&p| FuncParam {
                    name: p.to_string(),
                    typ: None,
                })
                .collect(),
            return_type: None,
            body,
            is_private: false,
        }
    }

    /// `double(X) = X + X`.
    fn double_func() -> FuncDef {
        func(
            "double",
            &["X"],
            FuncBody::Arithmetic(ArithExpr::Add(Box::new(var("X")), Box::new(var("X")))),
        )
    }

    /// Positive literal `parent(first, second)` over two variables.
    fn parent_lit(first: &str, second: &str) -> BodyLiteral {
        BodyLiteral::Positive(Atom {
            predicate: "parent".to_string(),
            terms: vec![
                Term::Variable(first.to_string()),
                Term::Variable(second.to_string()),
            ],
        })
    }

    /// Predicate function `get_parent(X) = P :- parent(X, P)`.
    fn get_parent_func() -> FuncDef {
        func(
            "get_parent",
            &["X"],
            FuncBody::Predicate {
                result: "P".to_string(),
                body: vec![parent_lit("X", "P")],
            },
        )
    }

    #[test]
    fn test_simple_expansion() {
        let mut reg = FunctionRegistry::new();
        reg.register(double_func()).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx.expand_call("double", &[ArithExpr::Integer(5)]).unwrap();
        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(*l, ArithExpr::Integer(5)));
                assert!(matches!(*r, ArithExpr::Integer(5)));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_nested_expansion() {
        let mut reg = FunctionRegistry::new();
        let quadruple = func(
            "quadruple",
            &["X"],
            FuncBody::Arithmetic(ArithExpr::FuncCall {
                name: "double".to_string(),
                args: vec![ArithExpr::FuncCall {
                    name: "double".to_string(),
                    args: vec![var("X")],
                }],
            }),
        );
        reg.register(double_func()).unwrap();
        reg.register(quadruple).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx
            .expand_call("quadruple", &[ArithExpr::Integer(2)])
            .unwrap();
        match &result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(l.as_ref(), ArithExpr::Add(_, _)));
                assert!(matches!(r.as_ref(), ArithExpr::Add(_, _)));
            }
            _ => panic!("Expected nested Add expression, got {:?}", result),
        }
    }

    #[test]
    fn test_max_recursion_depth() {
        let mut reg = FunctionRegistry::new();
        let infinite = func(
            "infinite",
            &["X"],
            FuncBody::Arithmetic(ArithExpr::FuncCall {
                name: "infinite".to_string(),
                args: vec![var("X")],
            }),
        );
        reg.register(infinite).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 10);
        let result = ctx.expand_call("infinite", &[ArithExpr::Integer(1)]);
        assert!(matches!(
            result,
            Err(FunctionError::MaxRecursionDepth { .. })
        ));
    }

    #[test]
    fn test_undefined_function() {
        let reg = FunctionRegistry::new();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx.expand_call("undefined", &[ArithExpr::Integer(1)]);
        assert!(matches!(
            result,
            Err(FunctionError::UndefinedFunction { .. })
        ));
    }

    #[test]
    fn test_builtin_function_passthrough() {
        let mut reg = FunctionRegistry::new();
        let abs_x = func(
            "abs_x",
            &["X"],
            FuncBody::Arithmetic(ArithExpr::FuncCall {
                name: "abs".to_string(),
                args: vec![var("X")],
            }),
        );
        reg.register(abs_x).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx.expand_call("abs_x", &[ArithExpr::Integer(-5)]).unwrap();
        match result {
            ArithExpr::FuncCall { name, args } => {
                assert_eq!(name, "abs");
                assert_eq!(args.len(), 1);
                assert!(matches!(args[0], ArithExpr::Integer(-5)));
            }
            _ => panic!("Expected FuncCall for builtin"),
        }
    }

    #[test]
    fn test_variable_substitution() {
        let mut reg = FunctionRegistry::new();
        let add = func(
            "add",
            &["X", "Y"],
            FuncBody::Arithmetic(ArithExpr::Add(Box::new(var("X")), Box::new(var("Y")))),
        );
        reg.register(add).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx
            .expand_call("add", &[ArithExpr::Integer(3), ArithExpr::Integer(7)])
            .unwrap();
        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(*l, ArithExpr::Integer(3)));
                assert!(matches!(*r, ArithExpr::Integer(7)));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_expansion_with_variable_args() {
        let mut reg = FunctionRegistry::new();
        reg.register(double_func()).unwrap();
        let mut ctx = ExpansionContext::new(&reg, 100);
        let result = ctx.expand_call("double", &[var("Y")]).unwrap();
        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(l.as_ref(), ArithExpr::Variable(n) if n == "Y"));
                assert!(matches!(r.as_ref(), ArithExpr::Variable(n) if n == "Y"));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_predicate_func_expansion() {
        let mut reg = FunctionRegistry::new();
        reg.register(get_parent_func()).unwrap();
        let ctx = ExpansionContext::new(&reg, 100);
        let args = [var("alice")];
        let func_def = reg.get("get_parent").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
        assert_eq!(result, "P");
        assert_eq!(body.len(), 1);
        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
        } else {
            panic!("Expected Positive literal");
        }
    }

    #[test]
    fn test_predicate_func_with_constant_arg() {
        let get_child = func(
            "get_child",
            &["P"],
            FuncBody::Predicate {
                result: "C".to_string(),
                body: vec![parent_lit("C", "P")],
            },
        );
        let mut reg = FunctionRegistry::new();
        reg.register(get_child).unwrap();
        let ctx = ExpansionContext::new(&reg, 100);
        let args = [ArithExpr::Integer(42)];
        let func_def = reg.get("get_child").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
        assert_eq!(result, "C");
        assert_eq!(body.len(), 1);
        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "C"));
            assert!(matches!(&atom.terms[1], Term::Integer(42)));
        } else {
            panic!("Expected Positive literal");
        }
    }

    #[test]
    fn test_predicate_func_multiple_body_literals() {
        let get_grandparent = func(
            "get_grandparent",
            &["X"],
            FuncBody::Predicate {
                result: "G".to_string(),
                body: vec![parent_lit("X", "P"), parent_lit("P", "G")],
            },
        );
        let mut reg = FunctionRegistry::new();
        reg.register(get_grandparent).unwrap();
        let ctx = ExpansionContext::new(&reg, 100);
        let args = [var("alice")];
        let func_def = reg.get("get_grandparent").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
        assert_eq!(result, "G");
        assert_eq!(body.len(), 2);
        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
        } else {
            panic!("Expected Positive literal for first body");
        }
        if let BodyLiteral::Positive(atom) = &body[1] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "P"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "G"));
        } else {
            panic!("Expected Positive literal for second body");
        }
    }

    #[test]
    fn test_is_predicate_func() {
        let mut reg = FunctionRegistry::new();
        reg.register(double_func()).unwrap();
        reg.register(get_parent_func()).unwrap();
        let ctx = ExpansionContext::new(&reg, 100);
        assert!(!ctx.is_predicate_func("double"));
        assert!(ctx.is_predicate_func("get_parent"));
        assert!(!ctx.is_predicate_func("nonexistent"));
    }
}