// gollum_ir/program.rs

1//! IR program type.
2
3use std::collections::HashSet;
4
5use crate::action::{ActionValidationError, IrAction};
6use crate::clause::IrClause;
7use crate::query::IrQuery;
8
/// A complete IR program: a collection of clauses, queries, and planning actions.
///
/// Only `clauses` and `queries` are mandatory in serialized form; every other
/// field carries a serde default, so input that omits it still deserializes
/// (to an empty collection, or to `0.85` for `neural_unify_threshold`).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct IrProgram {
    /// All clauses (facts and rules) in the program.
    pub clauses: Vec<IrClause>,
    /// All queries in the program.
    pub queries: Vec<IrQuery>,
    /// Predicates declared with `:- table functor/arity.`
    ///
    /// Each entry is a `(functor, arity)` pair.
    #[serde(default)]
    pub tabled_predicates: HashSet<(String, usize)>,
    /// Predicates declared with `:- differentiable neural functor/arity using "model".`
    ///
    /// NOTE(review): entries are presumably `(functor, arity, model)` — confirm
    /// against the lowering code.
    #[serde(default)]
    pub diff_neural_predicates: HashSet<(String, usize, String)>,
    /// Generative neural predicates declared with `:- neural_gen functor/arity.`
    ///
    /// Each entry is a `(functor, arity)` pair.
    #[serde(default)]
    pub neural_gen_predicates: HashSet<(String, usize)>,
    /// Mapping from neural predicate (name/arity) to a named model (from `:- neural_model`).
    ///
    /// The key is a single string; the value is the model name.
    #[serde(default)]
    pub neural_models: std::collections::HashMap<String, String>,
    /// STRIPS-style planning actions registered in this program.
    ///
    /// Checked for structural correctness by `validate_actions`.
    #[serde(default)]
    pub actions: Vec<IrAction>,
    /// Cosine-similarity threshold used by the `~=` (neural unification) operator.
    ///
    /// Set via `:- neural_unify threshold(<value>).`; defaults to `0.85`.
    /// When the field is missing from serialized input,
    /// `default_neural_unify_threshold` supplies the value.
    #[serde(default = "default_neural_unify_threshold")]
    pub neural_unify_threshold: f64,
    /// PDDL type hierarchy: each entry is `(child_types, parent)`.
    ///
    /// Preserves the full type hierarchy from a PDDL domain so that
    /// round-tripping through IR does not flatten the type tree.
    /// For example, `city location - object` and `truck car - vehicle`
    /// become `[  (["city", "location"], Some("object")),
    ///             (["truck", "car"],    Some("vehicle")) ]`.
    ///
    /// Union (`either`) types in parameter annotations are stored
    /// with `|`-separated names (e.g. `"brick|window"`).
    #[serde(default)]
    pub type_hierarchy: Vec<(Vec<String>, Option<String>)>,
}
49
50impl IrProgram {
51    /// Create an empty `IrProgram`.
52    pub fn new() -> Self {
53        Self {
54            clauses: vec![],
55            queries: vec![],
56            tabled_predicates: HashSet::new(),
57            diff_neural_predicates: HashSet::new(),
58            neural_gen_predicates: HashSet::new(),
59            neural_models: std::collections::HashMap::new(),
60            actions: vec![],
61            neural_unify_threshold: default_neural_unify_threshold(),
62            type_hierarchy: vec![],
63        }
64    }
65}
66
/// Default cosine-similarity threshold for the `~=` operator.
///
/// Referenced by name from the serde `default` attribute on
/// `IrProgram::neural_unify_threshold` and called by `IrProgram::new`.
fn default_neural_unify_threshold() -> f64 {
    const DEFAULT_THRESHOLD: f64 = 0.85;
    DEFAULT_THRESHOLD
}
70
71impl IrProgram {
72    /// Validate all actions in the program, returning the index of the first
73    /// invalid action together with its [`ActionValidationError`], or `Ok(())`
74    /// if all actions are well-formed.
75    ///
76    /// This is a convenience wrapper around [`IrAction::validate`] that makes
77    /// it easy to gate the planning pipeline on structural correctness.
78    ///
79    /// # Example
80    ///
81    /// ```rust
82    /// use gollum_ir::{IrAction, IrProgram, IrTerm};
83    ///
84    /// let mut prog = IrProgram::new();
85    /// prog.actions.push(IrAction {
86    ///     name: "move".into(),
87    ///     parameters: vec!["X".into(), "Y".into()],
88    ///     preconditions: vec![
89    ///         IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Var("X".into())] },
90    ///     ],
91    ///     effects: vec![
92    ///         IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Var("Y".into())] },
93    ///     ],
94    ///     metadata: None,
95    /// });
96    /// assert!(prog.validate_actions().is_ok());
97    /// ```
98    pub fn validate_actions(&self) -> Result<(), (usize, ActionValidationError)> {
99        for (i, action) in self.actions.iter().enumerate() {
100            if let Err(e) = action.validate() {
101                return Err((i, e));
102            }
103        }
104        Ok(())
105    }
106}
107
108impl Default for IrProgram {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
#[cfg(test)]
mod tests {
    //! Unit tests for [`IrProgram`]: construction, serde round-trips,
    //! action validation, and integration with the lowering pipeline.

    use super::*;
    use crate::action::{ActionValidationError, IrAction};
    use crate::clause::IrClause;
    use crate::metadata::IrMetadata;
    use crate::query::IrQuery;
    use crate::term::IrTerm;

    /// Build an atom term.
    fn atom(s: &str) -> IrTerm {
        IrTerm::Atom(s.into())
    }

    /// Build a variable term.
    fn var(s: &str) -> IrTerm {
        IrTerm::Var(s.into())
    }

    /// Build a compound term `name(args...)`.
    fn structure(name: &str, args: Vec<IrTerm>) -> IrTerm {
        IrTerm::Structure { name: name.into(), args }
    }

    /// Build a body-less clause (a fact) with head `name(args...)`.
    fn fact_clause(name: &str, args: Vec<IrTerm>) -> IrClause {
        IrClause { head: structure(name, args), body: vec![], metadata: None }
    }

    /// Build a query whose goal is `name(args...)`.
    fn query(name: &str, args: Vec<IrTerm>) -> IrQuery {
        IrQuery { goal: structure(name, args), metadata: None }
    }

    /// A well-formed two-parameter action used as the "valid" fixture.
    fn move_action() -> IrAction {
        IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into()],
            preconditions: vec![
                structure("at", vec![var("X")]),
                structure("connected", vec![var("X"), var("Y")]),
            ],
            effects: vec![
                structure("at", vec![var("Y")]),
                structure("not", vec![structure("at", vec![var("X")])]),
            ],
            metadata: None,
        }
    }

    /// A second well-formed action, shared by the multi-action tests.
    fn pickup_action() -> IrAction {
        IrAction {
            name: "pickup".into(),
            parameters: vec!["X".into()],
            preconditions: vec![structure("clear", vec![var("X")])],
            effects: vec![structure("holding", vec![var("X")])],
            metadata: None,
        }
    }

    #[test]
    fn test_new_program_is_empty() {
        let prog = IrProgram::new();
        assert!(prog.clauses.is_empty());
        assert!(prog.queries.is_empty());
        assert!(prog.actions.is_empty());
        assert!(prog.tabled_predicates.is_empty());
        assert!(prog.diff_neural_predicates.is_empty());
        assert!(prog.neural_gen_predicates.is_empty());
        assert!(prog.neural_models.is_empty());
        assert!(prog.type_hierarchy.is_empty());
        // The documented default for the `~=` threshold.
        assert_eq!(prog.neural_unify_threshold, 0.85);
    }

    #[test]
    fn test_default_matches_new() {
        assert_eq!(IrProgram::default(), IrProgram::new());
    }

    #[test]
    fn test_push_action_to_program() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        assert_eq!(prog.actions.len(), 1);
        assert_eq!(prog.actions[0].name, "move");
        assert_eq!(prog.actions[0].parameters, vec!["X", "Y"]);
    }

    #[test]
    fn test_program_with_clauses_queries_and_actions() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.clauses.push(fact_clause("connected", vec![atom("a"), atom("b")]));
        prog.queries.push(query("at", vec![var("X")]));
        prog.actions.push(move_action());

        assert_eq!(prog.clauses.len(), 2);
        assert_eq!(prog.queries.len(), 1);
        assert_eq!(prog.actions.len(), 1);
    }

    #[test]
    fn test_multiple_actions_in_program() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        prog.actions.push(pickup_action());
        assert_eq!(prog.actions.len(), 2);
        assert_eq!(prog.actions[1].name, "pickup");
    }

    #[test]
    fn test_action_with_probability_metadata() {
        let mut prog = IrProgram::new();
        prog.actions.push(IrAction {
            name: "risky_move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![structure("at", vec![var("X")])],
            effects: vec![],
            metadata: Some(IrMetadata {
                probability: Some(0.75),
                ..IrMetadata::default()
            }),
        });
        let prob = prog.actions[0]
            .metadata
            .as_ref()
            .and_then(|m| m.probability);
        assert_eq!(prob, Some(0.75));
    }

    #[test]
    fn test_serde_roundtrip_empty() {
        let prog = IrProgram::new();
        let s = ron::to_string(&prog).unwrap();
        let back: IrProgram = ron::from_str(&s).unwrap();
        assert_eq!(prog, back);
    }

    #[test]
    fn test_serde_roundtrip_with_actions() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.actions.push(move_action());

        let s = ron::to_string(&prog).unwrap();
        let back: IrProgram = ron::from_str(&s).unwrap();
        // Field-level checks first so a failure points at the offending field.
        assert_eq!(prog.clauses, back.clauses);
        assert_eq!(prog.actions, back.actions);
        // Then the whole program, so no other field is silently lost.
        assert_eq!(prog, back);
    }

    #[test]
    fn test_clone_and_eq() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.actions.push(move_action());
        let copy = prog.clone();
        assert_eq!(prog, copy);
    }

    // ── validate_actions() tests ──────────────────────────────────────────────

    #[test]
    fn test_validate_actions_empty_program_is_ok() {
        assert!(IrProgram::new().validate_actions().is_ok());
    }

    #[test]
    fn test_validate_actions_all_valid_is_ok() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        prog.actions.push(pickup_action());
        assert!(prog.validate_actions().is_ok());
    }

    #[test]
    fn test_validate_actions_first_invalid_is_reported() {
        let mut prog = IrProgram::new();
        // index 0 — valid
        prog.actions.push(move_action());
        // index 1 — invalid: empty name
        prog.actions.push(IrAction {
            name: "".into(),
            parameters: vec![],
            preconditions: vec![],
            effects: vec![],
            metadata: None,
        });
        // index 2 — invalid too, but we only report the first
        prog.actions.push(IrAction {
            name: "bad".into(),
            parameters: vec![],
            preconditions: vec![var("Z")], // undeclared var
            effects: vec![],
            metadata: None,
        });
        // `unwrap_err` already panics on `Ok`, so no separate `is_err` check.
        let (idx, err) = prog.validate_actions().unwrap_err();
        assert_eq!(idx, 1);
        assert_eq!(err, ActionValidationError::EmptyName);
    }

    #[test]
    fn test_validate_actions_second_invalid_index_correct() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action()); // index 0 — valid
        prog.actions.push(IrAction {      // index 1 — undeclared variable
            name: "bad_effect".into(),
            parameters: vec!["X".into()],
            preconditions: vec![],
            effects: vec![structure("at", vec![var("Z")])], // Z not in params
            metadata: None,
        });
        let (idx, err) = prog.validate_actions().unwrap_err();
        assert_eq!(idx, 1);
        assert_eq!(
            err,
            ActionValidationError::UndeclaredVariable("Z".into(), "effects")
        );
    }

    // ── lowering pipeline integration tests ───────────────────────────────────

    #[test]
    fn test_lower_populates_actions_automatically() {
        // Verify that lower() auto-populates IrProgram.actions from
        // precond/2 and effect/2 facts in the source.
        let src = "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert_eq!(prog.actions.len(), 1);
        assert_eq!(prog.actions[0].name, "move");
        assert_eq!(prog.actions[0].parameters, vec!["X", "Y"]);
        assert_eq!(prog.actions[0].preconditions.len(), 1);
        assert_eq!(prog.actions[0].effects.len(), 1);
    }

    #[test]
    fn test_lower_actions_and_clauses_coexist() {
        // Ordinary facts and STRIPS encoding live together in one program.
        let src = "at(room1). precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        // Three clauses (at/1, precond/2, effect/2) plus one extracted action.
        assert_eq!(prog.clauses.len(), 3);
        assert_eq!(prog.actions.len(), 1);
    }

    #[test]
    fn test_lower_no_precond_effect_gives_empty_actions() {
        let src = "parent(alice, bob). mortal(X) :- human(X).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert!(prog.actions.is_empty());
    }

    #[test]
    fn test_lower_multiple_actions_populated() {
        let src = concat!(
            "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)). ",
            "precond(pickup(X), clear(X)). effect(pickup(X), holding(X))."
        );
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert_eq!(prog.actions.len(), 2);
        // Order of extracted actions is not asserted, only membership.
        let names: Vec<&str> = prog.actions.iter().map(|a| a.name.as_str()).collect();
        assert!(names.contains(&"move"));
        assert!(names.contains(&"pickup"));
    }

    #[test]
    fn test_validate_actions_after_lower() {
        let src = "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert!(prog.validate_actions().is_ok());
    }
}