gollum-ir 0.4.0

Intermediate Representation for the Gollum language
Documentation
//! IR program type.

use std::collections::HashSet;

use crate::action::{ActionValidationError, IrAction};
use crate::clause::IrClause;
use crate::query::IrQuery;

/// A complete IR program: a collection of clauses, queries, and planning actions,
/// plus the program-level declarations (tabling, neural predicates, the
/// neural-unification threshold, and the PDDL type hierarchy) collected with them.
///
/// Only `clauses` and `queries` are required when deserializing. Every other field
/// carries a `#[serde(default)]` attribute, so serialized programs that omit them
/// still load. `neural_unify_threshold` uses a custom default of `0.85`
/// rather than `f64`'s `0.0`.
// Only `PartialEq` (not `Eq`) can be derived because of the `f64` field.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct IrProgram {
    /// All clauses (facts and rules) in the program.
    pub clauses: Vec<IrClause>,
    /// All queries in the program.
    pub queries: Vec<IrQuery>,
    /// Predicates declared with `:- table functor/arity.`, stored as
    /// `(functor, arity)` pairs.
    #[serde(default)]
    pub tabled_predicates: HashSet<(String, usize)>,
    /// Predicates declared with `:- differentiable neural functor/arity using "model".`
    ///
    /// Presumably stored as `(functor, arity, model)`, following the order of the
    /// declaration syntax. NOTE(review): confirm against the lowering pass.
    #[serde(default)]
    pub diff_neural_predicates: HashSet<(String, usize, String)>,
    /// Generative neural predicates declared with `:- neural_gen functor/arity.`,
    /// stored as `(functor, arity)` pairs.
    #[serde(default)]
    pub neural_gen_predicates: HashSet<(String, usize)>,
    /// Mapping from neural predicate (name/arity) to a named model (from `:- neural_model`).
    ///
    /// Keys appear to be `"name/arity"` strings. NOTE(review): confirm the exact
    /// key format produced by the lowering pass.
    #[serde(default)]
    pub neural_models: std::collections::HashMap<String, String>,
    /// STRIPS-style planning actions registered in this program.
    ///
    /// Lowering populates this from `precond/2` and `effect/2` facts. Those
    /// facts also remain in `clauses`.
    #[serde(default)]
    pub actions: Vec<IrAction>,
    /// Cosine-similarity threshold used by the `~=` (neural unification) operator.
    ///
    /// Set via `:- neural_unify threshold(<value>).`; defaults to `0.85`, both in
    /// [`IrProgram::new`] and when the field is missing during deserialization.
    #[serde(default = "default_neural_unify_threshold")]
    pub neural_unify_threshold: f64,
    /// PDDL type hierarchy: each entry is `(child_types, parent)`.
    ///
    /// Preserves the full type hierarchy from a PDDL domain so that
    /// round-tripping through IR does not flatten the type tree.
    /// For example, `city location - object` and `truck car - vehicle`
    /// become:
    ///
    /// ```text
    /// [ (["city", "location"], Some("object")),
    ///   (["truck", "car"],     Some("vehicle")) ]
    /// ```
    ///
    /// A `None` parent presumably marks types declared without a `- parent`
    /// suffix. NOTE(review): confirm against the PDDL importer.
    ///
    /// Union (`either`) types in parameter annotations are stored
    /// with `|`-separated names (e.g. `"brick|window"`).
    #[serde(default)]
    pub type_hierarchy: Vec<(Vec<String>, Option<String>)>,
}

impl IrProgram {
    /// Build an empty `IrProgram`.
    ///
    /// Every collection starts empty. `neural_unify_threshold` is set to its default
    /// (`0.85`), the same value deserialization fills in when that field is absent.
    pub fn new() -> Self {
        Self {
            clauses: Vec::new(),
            queries: Vec::new(),
            tabled_predicates: HashSet::default(),
            diff_neural_predicates: HashSet::default(),
            neural_gen_predicates: HashSet::default(),
            neural_models: Default::default(),
            actions: Vec::new(),
            // Shared with the serde attribute so both code paths agree.
            neural_unify_threshold: default_neural_unify_threshold(),
            type_hierarchy: Vec::new(),
        }
    }
}

/// Default cosine-similarity threshold for the `~=` operator.
///
/// Referenced by the `#[serde(default = "...")]` attribute on
/// `IrProgram::neural_unify_threshold` and by [`IrProgram::new`]. Keeping a
/// single source means serialized and freshly built programs agree.
fn default_neural_unify_threshold() -> f64 {
    const DEFAULT_THRESHOLD: f64 = 0.85;
    DEFAULT_THRESHOLD
}

impl IrProgram {
    /// Run [`IrAction::validate`] over every planning action in the program,
    /// stopping at the first failure.
    ///
    /// This is a convenience for gating the planning pipeline on structural
    /// correctness.
    ///
    /// # Errors
    ///
    /// Returns `Err((index, error))` for the first action, in vector order, whose
    /// validation fails. `index` is that action's position in `self.actions`.
    /// Actions after the first failure are not validated.
    ///
    /// # Example
    ///
    /// ```rust
    /// use gollum_ir::{IrAction, IrProgram, IrTerm};
    ///
    /// let mut prog = IrProgram::new();
    /// prog.actions.push(IrAction {
    ///     name: "move".into(),
    ///     parameters: vec!["X".into(), "Y".into()],
    ///     preconditions: vec![
    ///         IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Var("X".into())] },
    ///     ],
    ///     effects: vec![
    ///         IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Var("Y".into())] },
    ///     ],
    ///     metadata: None,
    /// });
    /// assert!(prog.validate_actions().is_ok());
    /// ```
    pub fn validate_actions(&self) -> Result<(), (usize, ActionValidationError)> {
        // `find_map` is lazy, so validation short-circuits on the first error.
        let first_failure = self
            .actions
            .iter()
            .enumerate()
            .find_map(|(index, action)| action.validate().err().map(|err| (index, err)));

        match first_failure {
            Some(failure) => Err(failure),
            None => Ok(()),
        }
    }
}

impl Default for IrProgram {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    // Unit tests for `IrProgram`: construction, defaults, serde round-trips,
    // action validation, and integration with the lowering pipeline.
    use super::*;
    use crate::action::IrAction;
    use crate::clause::IrClause;
    use crate::metadata::IrMetadata;
    use crate::query::IrQuery;
    use crate::term::IrTerm;

    // ── term / clause helpers ─────────────────────────────────────────────────

    fn atom(s: &str) -> IrTerm {
        IrTerm::Atom(s.into())
    }

    fn var(s: &str) -> IrTerm {
        IrTerm::Var(s.into())
    }

    fn structure(name: &str, args: Vec<IrTerm>) -> IrTerm {
        IrTerm::Structure { name: name.into(), args }
    }

    fn fact_clause(name: &str, args: Vec<IrTerm>) -> IrClause {
        IrClause { head: structure(name, args), body: vec![], metadata: None }
    }

    fn query(name: &str, args: Vec<IrTerm>) -> IrQuery {
        IrQuery { goal: structure(name, args), metadata: None }
    }

    /// A well-formed `move(X, Y)` action used by several tests.
    fn move_action() -> IrAction {
        IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into()],
            preconditions: vec![
                structure("at", vec![var("X")]),
                structure("connected", vec![var("X"), var("Y")]),
            ],
            effects: vec![
                structure("at", vec![var("Y")]),
                structure("not", vec![structure("at", vec![var("X")])]),
            ],
            metadata: None,
        }
    }

    // ── construction and defaults ─────────────────────────────────────────────

    #[test]
    fn test_new_program_is_empty() {
        let prog = IrProgram::new();
        assert!(prog.clauses.is_empty());
        assert!(prog.queries.is_empty());
        assert!(prog.actions.is_empty());
        assert!(prog.tabled_predicates.is_empty());
        assert!(prog.diff_neural_predicates.is_empty());
        assert!(prog.neural_gen_predicates.is_empty());
        assert!(prog.neural_models.is_empty());
        assert!(prog.type_hierarchy.is_empty());
        // The documented default, not `f64::default()`.
        assert_eq!(prog.neural_unify_threshold, 0.85);
    }

    #[test]
    fn test_default_matches_new() {
        // `Default` is hand-written; guard against it drifting from `new()`
        // (e.g. a derived impl would yield a 0.0 threshold).
        assert_eq!(IrProgram::default(), IrProgram::new());
    }

    #[test]
    fn test_push_action_to_program() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        assert_eq!(prog.actions.len(), 1);
        assert_eq!(prog.actions[0].name, "move");
        assert_eq!(prog.actions[0].parameters, vec!["X", "Y"]);
    }

    #[test]
    fn test_program_with_clauses_queries_and_actions() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.clauses.push(fact_clause("connected", vec![atom("a"), atom("b")]));
        prog.queries.push(query("at", vec![var("X")]));
        prog.actions.push(move_action());

        assert_eq!(prog.clauses.len(), 2);
        assert_eq!(prog.queries.len(), 1);
        assert_eq!(prog.actions.len(), 1);
    }

    #[test]
    fn test_multiple_actions_in_program() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        prog.actions.push(IrAction {
            name: "pickup".into(),
            parameters: vec!["X".into()],
            preconditions: vec![structure("clear", vec![var("X")])],
            effects: vec![structure("holding", vec![var("X")])],
            metadata: None,
        });
        assert_eq!(prog.actions.len(), 2);
        assert_eq!(prog.actions[1].name, "pickup");
    }

    #[test]
    fn test_action_with_probability_metadata() {
        let mut prog = IrProgram::new();
        prog.actions.push(IrAction {
            name: "risky_move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![structure("at", vec![var("X")])],
            effects: vec![],
            metadata: Some(IrMetadata {
                probability: Some(0.75),
                ..IrMetadata::default()
            }),
        });
        let prob = prog.actions[0]
            .metadata
            .as_ref()
            .and_then(|m| m.probability);
        assert_eq!(prob, Some(0.75));
    }

    // ── serde ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_serde_roundtrip_empty() {
        let prog = IrProgram::new();
        let s = ron::to_string(&prog).unwrap();
        let back: IrProgram = ron::from_str(&s).unwrap();
        assert_eq!(prog, back);
    }

    #[test]
    fn test_serde_roundtrip_with_actions() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.actions.push(move_action());

        let s = ron::to_string(&prog).unwrap();
        let back: IrProgram = ron::from_str(&s).unwrap();
        assert_eq!(prog.clauses, back.clauses);
        assert_eq!(prog.actions, back.actions);
        // Nothing else (sets, maps, threshold, type hierarchy) may be lost either.
        assert_eq!(prog, back);
    }

    #[test]
    fn test_serde_missing_optional_fields_use_defaults() {
        // Only `clauses` and `queries` are required. Every other field has a
        // serde default, and the threshold must come back as 0.85, not 0.0.
        let back: IrProgram = ron::from_str("(clauses: [], queries: [])").unwrap();
        assert_eq!(back, IrProgram::new());
    }

    #[test]
    fn test_clone_and_eq() {
        let mut prog = IrProgram::new();
        prog.clauses.push(fact_clause("at", vec![atom("a")]));
        prog.actions.push(move_action());
        let copy = prog.clone();
        assert_eq!(prog, copy);
    }

    // ── validate_actions() tests ──────────────────────────────────────────────

    #[test]
    fn test_validate_actions_empty_program_is_ok() {
        assert!(IrProgram::new().validate_actions().is_ok());
    }

    #[test]
    fn test_validate_actions_all_valid_is_ok() {
        let mut prog = IrProgram::new();
        prog.actions.push(move_action());
        prog.actions.push(IrAction {
            name: "pickup".into(),
            parameters: vec!["X".into()],
            preconditions: vec![structure("clear", vec![var("X")])],
            effects: vec![structure("holding", vec![var("X")])],
            metadata: None,
        });
        assert!(prog.validate_actions().is_ok());
    }

    #[test]
    fn test_validate_actions_first_invalid_is_reported() {
        use crate::action::ActionValidationError;
        let mut prog = IrProgram::new();
        // index 0 — valid
        prog.actions.push(move_action());
        // index 1 — invalid: empty name
        prog.actions.push(IrAction {
            name: "".into(),
            parameters: vec![],
            preconditions: vec![],
            effects: vec![],
            metadata: None,
        });
        // index 2 — invalid too, but we only report the first
        prog.actions.push(IrAction {
            name: "bad".into(),
            parameters: vec![],
            preconditions: vec![var("Z")], // undeclared var
            effects: vec![],
            metadata: None,
        });
        let result = prog.validate_actions();
        assert!(result.is_err());
        let (idx, err) = result.unwrap_err();
        assert_eq!(idx, 1);
        assert_eq!(err, ActionValidationError::EmptyName);
    }

    #[test]
    fn test_validate_actions_second_invalid_index_correct() {
        use crate::action::ActionValidationError;
        let mut prog = IrProgram::new();
        prog.actions.push(move_action()); // index 0 — valid
        prog.actions.push(IrAction {      // index 1 — undeclared variable
            name: "bad_effect".into(),
            parameters: vec!["X".into()],
            preconditions: vec![],
            effects: vec![structure("at", vec![var("Z")])], // Z not in params
            metadata: None,
        });
        let (idx, err) = prog.validate_actions().unwrap_err();
        assert_eq!(idx, 1);
        assert_eq!(
            err,
            ActionValidationError::UndeclaredVariable("Z".into(), "effects")
        );
    }

    // ── lowering pipeline integration tests ───────────────────────────────────

    #[test]
    fn test_lower_populates_actions_automatically() {
        // Verify that lower() auto-populates IrProgram.actions from
        // precond/2 and effect/2 facts in the source.
        let src = "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert_eq!(prog.actions.len(), 1);
        assert_eq!(prog.actions[0].name, "move");
        assert_eq!(prog.actions[0].parameters, vec!["X", "Y"]);
        assert_eq!(prog.actions[0].preconditions.len(), 1);
        assert_eq!(prog.actions[0].effects.len(), 1);
    }

    #[test]
    fn test_lower_actions_and_clauses_coexist() {
        // Ordinary facts and STRIPS encoding live together in one program.
        let src = "at(room1). precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        // Three clauses (at/1, precond/2, effect/2) plus one extracted action.
        assert_eq!(prog.clauses.len(), 3);
        assert_eq!(prog.actions.len(), 1);
    }

    #[test]
    fn test_lower_no_precond_effect_gives_empty_actions() {
        let src = "parent(alice, bob). mortal(X) :- human(X).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert!(prog.actions.is_empty());
    }

    #[test]
    fn test_lower_multiple_actions_populated() {
        let src = concat!(
            "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)). ",
            "precond(pickup(X), clear(X)). effect(pickup(X), holding(X))."
        );
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert_eq!(prog.actions.len(), 2);
        let names: Vec<&str> = prog.actions.iter().map(|a| a.name.as_str()).collect();
        assert!(names.contains(&"move"));
        assert!(names.contains(&"pickup"));
    }

    #[test]
    fn test_validate_actions_after_lower() {
        let src = "precond(move(X,Y), at(X)). effect(move(X,Y), at(Y)).";
        let items = gollum_parser::parse(src).expect("parse failed");
        let prog = crate::lower::lower(&items);
        assert!(prog.validate_actions().is_ok());
    }
}