use std::collections::HashMap;
use gollum_ast::{
BinOpKind, BodyGoal, Expr, Fact, Interval as AstInterval, Item, PlainClause, Rule, Term,
};
use crate::action::IrAction;
use crate::clause::IrClause;
use crate::metadata::IrMetadata;
use crate::program::IrProgram;
use crate::query::IrQuery;
use crate::term::IrTerm;
use crate::timestamp::{Interval as IrInterval, Timestamp};
/// Converts a parsed AST interval (given as `start_ns`/`end_ns` nanosecond
/// bounds) into an IR [`IrInterval`].
///
/// # Panics
///
/// Panics with a descriptive message if either bound is rejected by
/// `Timestamp::from_ns`, or if `IrInterval::new` rejects the pair of bounds.
/// Intervals originate from user-written source, so the messages say which
/// step failed instead of surfacing a bare `unwrap` panic.
// NOTE(review): ideally this would return a `Result` and the error would be
// reported as a diagnostic; that requires changing `lower`'s signature.
fn ast_interval_to_ir(interval: AstInterval) -> IrInterval {
    let start = Timestamp::from_ns(interval.start_ns)
        .expect("temporal interval start is not a valid timestamp");
    let end = Timestamp::from_ns(interval.end_ns)
        .expect("temporal interval end is not a valid timestamp");
    IrInterval::new(start, end)
        .expect("temporal interval bounds do not form a valid interval")
}
/// Lowers a parsed program (a flat list of AST items) into an [`IrProgram`].
///
/// * Facts and rules become IR clauses; queries become IR queries.
/// * `Table`, `DiffNeural`, `NeuralGen`, `Neural`, `NeuralModel` and
///   `NeuralUnify` directives populate the corresponding program tables;
///   any other directive is ignored.
/// * Probabilistic clauses are lowered like plain facts/rules, with the
///   probability recorded in the clause metadata.
///
/// After all items are processed, STRIPS actions are extracted from the
/// resulting `precond/2` and `effect/2` facts via [`extract_strip_actions`].
///
/// # Panics
///
/// Panics if a clause or query carries a temporal interval that cannot be
/// converted into an IR interval (see `ast_interval_to_ir`).
pub fn lower(items: &[Item]) -> IrProgram {
    let mut program = IrProgram::new();
    for item in items {
        match item {
            Item::Fact(f) => program.clauses.push(lower_fact(f, None)),
            Item::Rule(r) => program.clauses.push(lower_rule(r, None)),
            Item::Query(q) => program.queries.push(lower_query(q)),
            Item::Directive(gollum_ast::Directive::Table { functor, arity }) => {
                program
                    .tabled_predicates
                    .insert((functor.clone(), *arity as usize));
            }
            Item::Directive(gollum_ast::Directive::DiffNeural {
                functor,
                arity,
                model_name,
            }) => {
                program.diff_neural_predicates.insert((
                    functor.clone(),
                    *arity as usize,
                    model_name.clone(),
                ));
            }
            Item::Directive(gollum_ast::Directive::NeuralGen { functor, arity }) => {
                program
                    .neural_gen_predicates
                    .insert((functor.clone(), *arity as usize));
            }
            Item::Directive(gollum_ast::Directive::Neural {
                functor, options, ..
            }) => {
                // Only the `model` option is understood; anything else is
                // reported on stderr and skipped rather than rejected.
                for (k, v) in options {
                    if k == "model" {
                        match v {
                            // Both quoted strings and bare atoms name a model.
                            Term::Str(s) | Term::Atom(s) => {
                                program.neural_models.insert(functor.clone(), s.clone());
                            }
                            other => {
                                eprintln!("Warning: neural model value for {} must be a string, got {}", functor, other);
                            }
                        }
                    } else {
                        eprintln!("Warning: unknown neural option '{}' for predicate {}", k, functor);
                    }
                }
            }
            Item::Directive(gollum_ast::Directive::NeuralModel {
                predicate,
                model_name,
            }) => {
                program
                    .neural_models
                    .insert(predicate.clone(), model_name.clone());
            }
            Item::Directive(gollum_ast::Directive::NeuralUnify { threshold }) => {
                program.neural_unify_threshold = *threshold;
            }
            Item::Directive(_) => {}
            Item::Probabilistic(p) => {
                // `lower_fact` / `lower_rule` already copy the clause's own
                // temporal interval and modality into the metadata they are
                // given, so only the probability needs to be supplied here.
                let meta = IrMetadata {
                    probability: Some(p.probability),
                    ..Default::default()
                };
                let clause = match &p.clause {
                    PlainClause::Fact(f) => lower_fact(f, Some(meta)),
                    PlainClause::Rule(r) => lower_rule(r, Some(meta)),
                };
                program.clauses.push(clause);
            }
        }
    }
    program.actions = extract_strip_actions(&program);
    program
}
fn lower_fact(f: &Fact, metadata: Option<IrMetadata>) -> IrClause {
let mut anon = 0u32;
let mut meta = metadata.unwrap_or_default();
if let Some(interval) = f.temporal_interval {
meta.temporal_interval = Some(ast_interval_to_ir(interval));
}
if let Some(modal) = &f.modality {
meta.modality = Some(modal.clone());
}
if let Some((ref model_id, ref grad_id)) = f.diff_neural_ref {
meta.diff_neural_ref = Some((model_id.clone(), grad_id.clone()));
}
IrClause {
head: IrTerm::Structure {
name: f.name.clone(),
args: lower_terms_a(&f.args, &mut anon),
},
body: vec![],
metadata: if meta == IrMetadata::default() {
None
} else {
Some(meta)
},
}
}
/// Lowers an AST rule into an IR clause (head plus body goals).
///
/// `metadata` is an optional starting point. The rule's temporal interval and
/// modality are written on top of it when present. If the final metadata
/// equals `IrMetadata::default()`, the clause stores `None`.
fn lower_rule(r: &Rule, metadata: Option<IrMetadata>) -> IrClause {
    let mut meta = metadata.unwrap_or_default();
    if let Some(iv) = r.temporal_interval {
        meta.temporal_interval = Some(ast_interval_to_ir(iv));
    }
    if let Some(modality) = r.modality.as_ref() {
        meta.modality = Some(modality.clone());
    }

    // One counter for head and body, so every `_` in the clause gets a
    // distinct `_G<n>` name.
    let mut anon_counter = 0u32;
    let head_args = lower_terms_a(&r.head.args, &mut anon_counter);
    let mut body = Vec::with_capacity(r.body.len());
    for goal in &r.body {
        body.push(lower_body_goal_a(goal, &mut anon_counter));
    }

    IrClause {
        head: IrTerm::Structure {
            name: r.head.name.clone(),
            args: head_args,
        },
        body,
        metadata: (meta != IrMetadata::default()).then_some(meta),
    }
}
fn lower_query(q: &gollum_ast::Query) -> IrQuery {
let goal = match q.goals.len() {
0 => IrTerm::Atom("true".into()),
1 => lower_body_goal(&q.goals[0]),
_ => {
let goals: Vec<IrTerm> = q.goals.iter().map(lower_body_goal).collect();
build_conjunction(goals)
}
};
let metadata = q.temporal_interval.map(|interval| IrMetadata {
temporal_interval: Some(ast_interval_to_ir(interval)),
..Default::default()
});
IrQuery { goal, metadata }
}
/// Folds a non-empty list of goals into a left-nested `','/2` term.
///
/// `[a, b, c]` becomes `','(','(a, b), c)`; a single goal is returned as-is.
///
/// # Panics
///
/// Panics if `goals` is empty (callers only pass two or more goals).
fn build_conjunction(goals: Vec<IrTerm>) -> IrTerm {
    debug_assert!(!goals.is_empty());
    goals
        .into_iter()
        .reduce(|lhs, rhs| IrTerm::Structure {
            name: ",".into(),
            args: vec![lhs, rhs],
        })
        .expect("build_conjunction requires at least one goal")
}
/// Lowers each term in order, threading the anonymous-variable counter
/// through all of them.
fn lower_terms_a(terms: &[Term], anon: &mut u32) -> Vec<IrTerm> {
    let mut lowered = Vec::with_capacity(terms.len());
    for term in terms {
        lowered.push(lower_term_a(term, anon));
    }
    lowered
}
/// Lowers one AST term into its IR form.
///
/// Each `Term::Anon` becomes a variable named `_G<n>`, where `n` is the
/// current value of `anon`; the counter is then incremented. `Term::Str` is
/// lowered to `IrTerm::Atom`, exactly like `Term::Atom`. Lists are encoded as
/// nested `'.'/2` structures terminated by `[]` (or by the explicit tail).
fn lower_term_a(term: &Term, anon: &mut u32) -> IrTerm {
    match term {
        Term::Atom(text) | Term::Str(text) => IrTerm::Atom(text.clone()),
        Term::Variable(var_name) => IrTerm::Var(var_name.clone()),
        Term::Anon => {
            let id = *anon;
            *anon = id + 1;
            IrTerm::Var(format!("_G{}", id))
        }
        Term::Integer(value) => IrTerm::Number(*value),
        Term::Float(value) => IrTerm::Float(*value),
        Term::Tensor(values) => IrTerm::Tensor(values.clone()),
        Term::Compound(functor, sub_terms) => {
            let lowered_args = lower_terms_a(sub_terms, anon);
            IrTerm::Structure {
                name: functor.clone(),
                args: lowered_args,
            }
        }
        Term::List(elements) => lower_list_a(elements, anon),
        Term::ListCons(heads, tail) => lower_list_cons_a(heads, tail, anon),
        Term::TypeAnnotated {
            term: inner,
            type_name,
        } => {
            let lowered = lower_term_a(inner, anon);
            IrTerm::Typed {
                term: Box::new(lowered),
                ty: type_name.clone(),
            }
        }
        Term::NeuralGradient {
            term: inner,
            model_id,
            grad_id,
        } => {
            let lowered = lower_term_a(inner, anon);
            IrTerm::DiffNeural {
                term: Box::new(lowered),
                model_id: model_id.clone(),
                grad_id: grad_id.clone(),
            }
        }
    }
}
/// Lowers a proper list `[t1, ..., tn]` to `'.'(t1, '.'(..., '.'(tn, [])))`.
///
/// Elements are lowered from last to first, so anonymous variables are
/// numbered right-to-left.
fn lower_list_a(terms: &[Term], anon: &mut u32) -> IrTerm {
    let mut list = IrTerm::Atom("[]".into());
    for element in terms.iter().rev() {
        let lowered = lower_term_a(element, anon);
        list = IrTerm::Structure {
            name: ".".into(),
            args: vec![lowered, list],
        };
    }
    list
}
/// Lowers `[h1, ..., hn | Tail]` to `'.'(h1, '.'(..., '.'(hn, Tail)))`.
///
/// The tail is lowered first, then the heads from last to first, so
/// anonymous variables are numbered in that order.
fn lower_list_cons_a(heads: &[Term], tail: &Term, anon: &mut u32) -> IrTerm {
    let mut list = lower_term_a(tail, anon);
    for head in heads.iter().rev() {
        let lowered = lower_term_a(head, anon);
        list = IrTerm::Structure {
            name: ".".into(),
            args: vec![lowered, list],
        };
    }
    list
}
/// Lowers a single body goal with its own anonymous-variable counter
/// starting at zero.
fn lower_body_goal(goal: &BodyGoal) -> IrTerm {
    lower_body_goal_a(goal, &mut 0)
}
/// Lowers a body goal, threading the anonymous-variable counter.
///
/// Calls become structures, expressions are lowered via `lower_expr_a`,
/// negation becomes `not/1`, and cut becomes `IrTerm::Cut`.
fn lower_body_goal_a(goal: &BodyGoal, anon: &mut u32) -> IrTerm {
    match goal {
        BodyGoal::Cut => IrTerm::Cut,
        BodyGoal::Expr(expr) => lower_expr_a(expr, anon),
        BodyGoal::Not(negated) => {
            let inner = lower_body_goal_a(negated, anon);
            IrTerm::Structure {
                name: "not".into(),
                args: vec![inner],
            }
        }
        BodyGoal::Call(functor, call_args) => {
            let lowered_args = lower_terms_a(call_args, anon);
            IrTerm::Structure {
                name: functor.clone(),
                args: lowered_args,
            }
        }
    }
}
/// Lowers an expression tree into IR.
///
/// A binary operation becomes a two-argument structure named by
/// `binop_name`. The left operand is lowered before the right one.
fn lower_expr_a(expr: &Expr, anon: &mut u32) -> IrTerm {
    match expr {
        Expr::BinOp(op, lhs, rhs) => {
            let operator = binop_name(op);
            let left = lower_expr_a(lhs, anon);
            let right = lower_expr_a(rhs, anon);
            IrTerm::Structure {
                name: operator.into(),
                args: vec![left, right],
            }
        }
        Expr::Term(leaf) => lower_term_a(leaf, anon),
    }
}
/// Collects STRIPS-style action schemas from `precond/2` and `effect/2` facts.
///
/// Every body-less clause `precond(Action, Payload)` or `effect(Action,
/// Payload)` adds `Payload` to the preconditions or effects of `Action`.
/// Clauses are grouped by the action term's functor name and arity (an atom
/// action has arity 0). For each group:
///
/// * the parameter list comes from the variable arguments of the first
///   clause seen;
/// * the metadata comes from that same first clause.
///
/// Clauses that have a body (rules) are ignored.
///
/// The result is sorted by action name, then arity, so the output order does
/// not depend on `HashMap` iteration order.
///
/// NOTE(review): payloads from different clauses are not renamed apart. If
/// `precond(move(X, Y), ..)` and `effect(move(A, B), ..)` are both present,
/// the effect still mentions `A`/`B` while `parameters` is `["X", "Y"]`.
pub fn extract_strip_actions(program: &IrProgram) -> Vec<IrAction> {
    // (parameters, preconditions, effects, metadata) accumulated per action.
    type ActionData = (Vec<String>, Vec<IrTerm>, Vec<IrTerm>, Option<IrMetadata>);
    let mut map: HashMap<(String, usize), ActionData> = HashMap::new();
    for clause in &program.clauses {
        if !clause.body.is_empty() {
            continue;
        }
        let IrTerm::Structure { name, args } = &clause.head else {
            continue;
        };
        let is_precond = name == "precond";
        if !is_precond && name != "effect" {
            continue;
        }
        let [action_term, payload] = args.as_slice() else {
            continue;
        };
        let (action_name, params) = action_name_and_params(action_term);
        // Key on the action's real arity, not on how many of its arguments
        // happen to be variables: `move(X, home)` and `move(X)` each have one
        // variable but are different actions and must not be merged.
        let arity = match action_term {
            IrTerm::Structure { args, .. } => args.len(),
            _ => 0,
        };
        let entry = map
            .entry((action_name, arity))
            .or_insert_with(|| (params, Vec::new(), Vec::new(), clause.metadata.clone()));
        let bucket = if is_precond { &mut entry.1 } else { &mut entry.2 };
        bucket.push(payload.clone());
    }
    // HashMap iteration order is unspecified; sorting on the full
    // (name, arity) key keeps same-named actions in a stable order.
    let mut grouped: Vec<_> = map.into_iter().collect();
    grouped.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    grouped
        .into_iter()
        .map(
            |((name, _), (parameters, preconditions, effects, metadata))| IrAction {
                name,
                parameters,
                preconditions,
                effects,
                metadata,
            },
        )
        .collect()
}
/// Splits an action term into its name and the names of its variable
/// arguments.
///
/// * For a structure, non-variable arguments are skipped.
/// * An atom yields its text and no parameters.
/// * Any other term yields `"_unknown"` and no parameters.
fn action_name_and_params(term: &IrTerm) -> (String, Vec<String>) {
    if let IrTerm::Structure { name, args } = term {
        let mut vars = Vec::new();
        for arg in args {
            if let IrTerm::Var(var_name) = arg {
                vars.push(var_name.clone());
            }
        }
        return (name.clone(), vars);
    }
    let name = match term {
        IrTerm::Atom(atom) => atom.clone(),
        _ => String::from("_unknown"),
    };
    (name, Vec::new())
}
fn binop_name(op: &BinOpKind) -> &'static str {
match op {
BinOpKind::Add => "+",
BinOpKind::Sub => "-",
BinOpKind::Mul => "*",
BinOpKind::Div => "/",
BinOpKind::Mod => "mod",
BinOpKind::Is => "is",
BinOpKind::Eq => "==",
BinOpKind::Neq => "\\==",
BinOpKind::Lt => "<",
BinOpKind::Gt => ">",
BinOpKind::Lte => "=<",
BinOpKind::Gte => ">=",
BinOpKind::Unify => "=",
BinOpKind::NotUnify => "\\=",
BinOpKind::ArithEq => "=:=",
BinOpKind::ArithNeq => "=\\=",
BinOpKind::And => ",",
BinOpKind::Or => ";",
BinOpKind::ClpEq => "#=",
BinOpKind::ClpNeq => "#\\=",
BinOpKind::ClpLt => "#<",
BinOpKind::ClpLte => "#=<",
BinOpKind::ClpGt => "#>",
BinOpKind::ClpGte => "#>=",
BinOpKind::ClpIn => "in",
BinOpKind::NeuralUnify => "~=",
}
}
// Unit tests for AST -> IR lowering and STRIPS action extraction.
#[cfg(test)]
mod tests {
    use super::*;
    use gollum_ast::{BinOpKind, Predicate, Query, Term};
    use gollum_ast::{BodyGoal, Expr, Fact, Item, PlainClause, Probabilistic, Rule};

    /// Builds a plain fact item with no temporal interval, modality or
    /// diff-neural reference.
    fn fact(name: &str, args: Vec<Term>) -> Item {
        Item::Fact(Fact {
            name: name.into(),
            args,
            temporal_interval: None,
            modality: None,
            diff_neural_ref: None,
        })
    }

    /// Builds a plain rule item with no temporal interval or modality.
    fn rule(head_name: &str, head_args: Vec<Term>, body: Vec<BodyGoal>) -> Item {
        Item::Rule(Rule {
            head: Predicate {
                name: head_name.into(),
                args: head_args,
            },
            body,
            temporal_interval: None,
            modality: None,
        })
    }

    /// Builds a query item without a temporal interval.
    fn query(goals: Vec<BodyGoal>) -> Item {
        Item::Query(Query {
            goals,
            temporal_interval: None,
        })
    }

    /// Builds a call body goal `name(args...)`.
    fn call(name: &str, args: Vec<Term>) -> BodyGoal {
        BodyGoal::Call(name.into(), args)
    }

    // A ground fact lowers to a body-less clause with atom arguments.
    #[test]
    fn test_lower_fact() {
        let items = vec![fact(
            "parent",
            vec![Term::Atom("alice".into()), Term::Atom("bob".into())],
        )];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 1);
        assert_eq!(prog.queries.len(), 0);
        let c = &prog.clauses[0];
        assert!(c.body.is_empty());
        assert_eq!(
            c.head,
            IrTerm::Structure {
                name: "parent".into(),
                args: vec![IrTerm::Atom("alice".into()), IrTerm::Atom("bob".into())]
            }
        );
    }

    // A zero-arity fact still lowers to a Structure, not an Atom.
    #[test]
    fn test_lower_arity0_fact() {
        let items = vec![fact("rain", vec![])];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 1);
        assert_eq!(
            prog.clauses[0].head,
            IrTerm::Structure {
                name: "rain".into(),
                args: vec![]
            }
        );
    }

    // A rule keeps its head variables and one IR term per body goal.
    #[test]
    fn test_lower_rule() {
        let items = vec![rule(
            "grandparent",
            vec![Term::Variable("X".into()), Term::Variable("Y".into())],
            vec![
                call(
                    "parent",
                    vec![Term::Variable("X".into()), Term::Variable("Z".into())],
                ),
                call(
                    "parent",
                    vec![Term::Variable("Z".into()), Term::Variable("Y".into())],
                ),
            ],
        )];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 1);
        let c = &prog.clauses[0];
        assert_eq!(c.body.len(), 2);
        assert_eq!(
            c.head,
            IrTerm::Structure {
                name: "grandparent".into(),
                args: vec![IrTerm::Var("X".into()), IrTerm::Var("Y".into())]
            }
        );
    }

    // A probabilistic fact carries its probability in the clause metadata.
    #[test]
    fn test_lower_prob_fact() {
        let items = vec![Item::Probabilistic(Probabilistic {
            probability: 0.8,
            clause: PlainClause::Fact(Fact {
                name: "rain".into(),
                args: vec![],
                temporal_interval: None,
                modality: None,
                diff_neural_ref: None,
            }),
        })];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 1);
        let meta = prog.clauses[0].metadata.as_ref().unwrap();
        assert_eq!(meta.probability, Some(0.8));
    }

    // A probabilistic rule carries its probability and keeps its body.
    #[test]
    fn test_lower_prob_rule() {
        let items = vec![Item::Probabilistic(Probabilistic {
            probability: 0.9,
            clause: PlainClause::Rule(Rule {
                head: Predicate {
                    name: "wet_grass".into(),
                    args: vec![],
                },
                body: vec![call("rain", vec![])],
                temporal_interval: None,
                modality: None,
            }),
        })];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 1);
        let meta = prog.clauses[0].metadata.as_ref().unwrap();
        assert_eq!(meta.probability, Some(0.9));
        assert_eq!(prog.clauses[0].body.len(), 1);
    }

    // A single-goal query is lowered to that goal directly (no conjunction).
    #[test]
    fn test_lower_query_single() {
        let items = vec![query(vec![call(
            "grandparent",
            vec![Term::Atom("alice".into()), Term::Variable("Y".into())],
        )])];
        let prog = lower(&items);
        assert_eq!(prog.queries.len(), 1);
        assert_eq!(
            prog.queries[0].goal,
            IrTerm::Structure {
                name: "grandparent".into(),
                args: vec![IrTerm::Atom("alice".into()), IrTerm::Var("Y".into())]
            }
        );
    }

    // A multi-goal query is lowered to a `','/2` conjunction.
    #[test]
    fn test_lower_query_multi() {
        let items = vec![query(vec![
            call(
                "parent",
                vec![Term::Atom("alice".into()), Term::Variable("X".into())],
            ),
            call(
                "parent",
                vec![Term::Variable("X".into()), Term::Variable("Y".into())],
            ),
        ])];
        let prog = lower(&items);
        assert_eq!(prog.queries.len(), 1);
        if let IrTerm::Structure { name, args } = &prog.queries[0].goal {
            assert_eq!(name, ",");
            assert_eq!(args.len(), 2);
        } else {
            panic!("expected conjunction structure");
        }
    }

    // Negation lowers to a `not/1` structure wrapping the inner goal.
    #[test]
    fn test_lower_negation() {
        let items = vec![rule(
            "foo",
            vec![],
            vec![BodyGoal::Not(Box::new(call("bar", vec![])))],
        )];
        let prog = lower(&items);
        let body_goal = &prog.clauses[0].body[0];
        assert_eq!(
            body_goal,
            &IrTerm::Structure {
                name: "not".into(),
                args: vec![IrTerm::Structure {
                    name: "bar".into(),
                    args: vec![]
                }]
            }
        );
    }

    // `X is 1 + 2` lowers to `is(X, +(1, 2))`.
    #[test]
    fn test_lower_arithmetic() {
        let expr = Expr::BinOp(
            BinOpKind::Is,
            Box::new(Expr::Term(Term::Variable("X".into()))),
            Box::new(Expr::BinOp(
                BinOpKind::Add,
                Box::new(Expr::Term(Term::Integer(1))),
                Box::new(Expr::Term(Term::Integer(2))),
            )),
        );
        let items = vec![rule(
            "result",
            vec![Term::Variable("X".into())],
            vec![BodyGoal::Expr(Box::new(expr))],
        )];
        let prog = lower(&items);
        let body_goal = &prog.clauses[0].body[0];
        if let IrTerm::Structure { name, args } = body_goal {
            assert_eq!(name, "is");
            assert_eq!(args.len(), 2);
            assert_eq!(args[0], IrTerm::Var("X".into()));
            if let IrTerm::Structure {
                name: inner_name, ..
            } = &args[1]
            {
                assert_eq!(inner_name, "+");
            } else {
                panic!("expected + structure");
            }
        } else {
            panic!("expected is structure");
        }
    }

    // Facts and rules go to `clauses`, queries go to `queries`.
    #[test]
    fn test_lower_program() {
        let items = vec![
            fact(
                "parent",
                vec![Term::Atom("alice".into()), Term::Atom("bob".into())],
            ),
            rule(
                "grandparent",
                vec![Term::Variable("X".into()), Term::Variable("Y".into())],
                vec![
                    call(
                        "parent",
                        vec![Term::Variable("X".into()), Term::Variable("Z".into())],
                    ),
                    call(
                        "parent",
                        vec![Term::Variable("Z".into()), Term::Variable("Y".into())],
                    ),
                ],
            ),
            query(vec![call(
                "grandparent",
                vec![Term::Atom("alice".into()), Term::Variable("Y".into())],
            )]),
        ];
        let prog = lower(&items);
        assert_eq!(prog.clauses.len(), 2);
        assert_eq!(prog.queries.len(), 1);
    }

    /// Builds an IR fact clause (no body, no metadata).
    fn ir_fact(name: &str, args: Vec<IrTerm>) -> IrClause {
        IrClause {
            head: IrTerm::Structure { name: name.into(), args },
            body: vec![],
            metadata: None,
        }
    }

    /// Builds an IR structure `name(args...)`.
    fn ir_structure(name: &str, args: Vec<IrTerm>) -> IrTerm {
        IrTerm::Structure { name: name.into(), args }
    }

    /// Builds an IR variable.
    fn ir_var(s: &str) -> IrTerm {
        IrTerm::Var(s.into())
    }

    /// Builds an IR atom.
    fn ir_atom(s: &str) -> IrTerm {
        IrTerm::Atom(s.into())
    }

    // No clauses means no actions.
    #[test]
    fn test_extract_strip_actions_empty_program() {
        let prog = IrProgram::new();
        assert!(extract_strip_actions(&prog).is_empty());
    }

    // Facts other than precond/effect do not produce actions.
    #[test]
    fn test_extract_strip_actions_no_precond_effect() {
        let mut prog = IrProgram::new();
        prog.clauses.push(ir_fact("at", vec![ir_atom("a")]));
        assert!(extract_strip_actions(&prog).is_empty());
    }

    // Several precond/effect facts for `move(X, Y)` merge into one action
    // whose parameters are the action term's variables.
    #[test]
    fn test_extract_strip_actions_single_action() {
        let move_xy =
            ir_structure("move", vec![ir_var("X"), ir_var("Y")]);
        let mut prog = IrProgram::new();
        prog.clauses.push(ir_fact(
            "precond",
            vec![move_xy.clone(), ir_structure("at", vec![ir_var("X")])],
        ));
        prog.clauses.push(ir_fact(
            "precond",
            vec![
                move_xy.clone(),
                ir_structure("connected", vec![ir_var("X"), ir_var("Y")]),
            ],
        ));
        prog.clauses.push(ir_fact(
            "effect",
            vec![move_xy.clone(), ir_structure("at", vec![ir_var("Y")])],
        ));
        prog.clauses.push(ir_fact(
            "effect",
            vec![
                move_xy.clone(),
                ir_structure("not", vec![ir_structure("at", vec![ir_var("X")])]),
            ],
        ));
        let actions = extract_strip_actions(&prog);
        assert_eq!(actions.len(), 1);
        let action = &actions[0];
        assert_eq!(action.name, "move");
        assert_eq!(action.parameters, vec!["X", "Y"]);
        assert_eq!(action.preconditions.len(), 2);
        assert_eq!(action.effects.len(), 2);
    }

    // An atom action term yields an action with no parameters.
    #[test]
    fn test_extract_strip_actions_ground_atom_action() {
        let mut prog = IrProgram::new();
        prog.clauses.push(ir_fact(
            "precond",
            vec![
                ir_atom("unlock"),
                ir_structure("state", vec![ir_atom("locked")]),
            ],
        ));
        prog.clauses.push(ir_fact(
            "effect",
            vec![
                ir_atom("unlock"),
                ir_structure("state", vec![ir_atom("unlocked")]),
            ],
        ));
        prog.clauses.push(ir_fact(
            "effect",
            vec![
                ir_atom("unlock"),
                ir_structure(
                    "not",
                    vec![ir_structure("state", vec![ir_atom("locked")])],
                ),
            ],
        ));
        let actions = extract_strip_actions(&prog);
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].name, "unlock");
        assert!(actions[0].parameters.is_empty());
        assert_eq!(actions[0].preconditions.len(), 1);
        assert_eq!(actions[0].effects.len(), 2);
    }

    // Distinct actions come back sorted by name.
    #[test]
    fn test_extract_strip_actions_multiple_actions() {
        let move_xy = ir_structure("move", vec![ir_var("X"), ir_var("Y")]);
        let pickup_x = ir_structure("pickup", vec![ir_var("X")]);
        let mut prog = IrProgram::new();
        prog.clauses
            .push(ir_fact("precond", vec![move_xy.clone(), ir_structure("at", vec![ir_var("X")])]));
        prog.clauses
            .push(ir_fact("effect", vec![move_xy, ir_structure("at", vec![ir_var("Y")])]));
        prog.clauses
            .push(ir_fact("precond", vec![pickup_x.clone(), ir_structure("clear", vec![ir_var("X")])]));
        prog.clauses
            .push(ir_fact("effect", vec![pickup_x, ir_structure("holding", vec![ir_var("X")])]));
        let actions = extract_strip_actions(&prog);
        assert_eq!(actions.len(), 2);
        assert_eq!(actions[0].name, "move");
        assert_eq!(actions[1].name, "pickup");
    }

    // A precond clause with a non-empty body (a rule) is not treated as an
    // action description.
    #[test]
    fn test_extract_strip_actions_ignores_rules() {
        let move_xy = ir_structure("move", vec![ir_var("X"), ir_var("Y")]);
        let mut prog = IrProgram::new();
        prog.clauses.push(IrClause {
            head: ir_structure(
                "precond",
                vec![move_xy, ir_structure("at", vec![ir_var("X")])],
            ),
            body: vec![ir_structure("edge", vec![ir_var("X"), ir_var("Y")])],
            metadata: None,
        });
        assert!(extract_strip_actions(&prog).is_empty());
    }

    // Extraction reads the clauses without removing the unrelated ones.
    #[test]
    fn test_extract_strip_actions_coexists_with_other_clauses() {
        let move_xy = ir_structure("move", vec![ir_var("X"), ir_var("Y")]);
        let mut prog = IrProgram::new();
        prog.clauses.push(ir_fact("at", vec![ir_atom("a")]));
        prog.clauses
            .push(ir_fact("precond", vec![move_xy.clone(), ir_structure("at", vec![ir_var("X")])]));
        prog.clauses
            .push(ir_fact("effect", vec![move_xy, ir_structure("at", vec![ir_var("Y")])]));
        let actions = extract_strip_actions(&prog);
        assert_eq!(actions.len(), 1);
        assert_eq!(prog.clauses.len(), 3);
    }
}