use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use crate::metadata::IrMetadata;
use crate::term::IrTerm;
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum ActionValidationError {
#[error("action name must not be empty")]
EmptyName,
#[error("duplicate parameter name: {0:?}")]
DuplicateParameter(String),
#[error("undeclared variable {0:?} in {1}")]
UndeclaredVariable(String, &'static str),
}
/// An action definition in the IR.
///
/// It consists of a name, the names of its parameter variables, lists of
/// precondition and effect terms, and optional metadata. Call
/// [`IrAction::validate`] to check that the action is well-formed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct IrAction {
/// Action name; `validate` rejects an empty string.
pub name: String,
/// Parameter variable names. `validate` requires them to be unique. They are
/// the only variables (apart from `_`-prefixed ones) allowed in
/// `preconditions` and `effects`.
pub parameters: Vec<String>,
/// Terms that must hold for the action to apply.
pub preconditions: Vec<IrTerm>,
/// Terms describing the action's outcome.
/// NOTE(review): the tests encode negated effects as a `not(...)` structure;
/// confirm this is the canonical form.
pub effects: Vec<IrTerm>,
/// Optional extra metadata, e.g. a `probability` (as used in the tests).
pub metadata: Option<IrMetadata>,
}
impl IrAction {
    /// Checks that this action is well-formed.
    ///
    /// The checks run in this order, and the first failure is returned:
    ///
    /// 1. the action name is non-empty;
    /// 2. no parameter name appears more than once;
    /// 3. every variable in the preconditions, then every variable in the
    ///    effects, is a declared parameter. Variables whose name starts with
    ///    `_` are exempt from this check.
    ///
    /// # Errors
    ///
    /// - [`ActionValidationError::EmptyName`] if `name` is empty.
    /// - [`ActionValidationError::DuplicateParameter`] with the first parameter
    ///   name that repeats an earlier one.
    /// - [`ActionValidationError::UndeclaredVariable`] with the first offending
    ///   variable (terms scanned in order, depth-first) and the section name
    ///   `"preconditions"` or `"effects"`.
    pub fn validate(&self) -> Result<(), ActionValidationError> {
        if self.name.is_empty() {
            return Err(ActionValidationError::EmptyName);
        }

        // One set serves two purposes. It detects duplicates while it is
        // filled. Once every parameter has been inserted, it is exactly the
        // set of declared variables, so no second set needs to be built.
        let mut declared: HashSet<&str> = HashSet::with_capacity(self.parameters.len());
        for p in &self.parameters {
            if !declared.insert(p.as_str()) {
                return Err(ActionValidationError::DuplicateParameter(p.clone()));
            }
        }

        if let Some(v) = self
            .preconditions
            .iter()
            .find_map(|term| find_undeclared_var(term, &declared))
        {
            return Err(ActionValidationError::UndeclaredVariable(v, "preconditions"));
        }
        if let Some(v) = self
            .effects
            .iter()
            .find_map(|term| find_undeclared_var(term, &declared))
        {
            return Err(ActionValidationError::UndeclaredVariable(v, "effects"));
        }
        Ok(())
    }
}
/// Searches `term` for a variable that is not allowed in an action body.
///
/// Returns the first variable name (depth-first, arguments left to right)
/// that neither starts with `_` nor appears in `declared`. Returns `None` if
/// every variable is allowed.
///
/// The search descends into `Structure` arguments and into the single wrapped
/// `term` of the `Typed`, `Neural`, `Differentiable` and `DiffNeural`
/// variants.
fn find_undeclared_var(term: &IrTerm, declared: &HashSet<&str>) -> Option<String> {
    match term {
        IrTerm::Var(var) => {
            let exempt = var.starts_with('_') || declared.contains(var.as_str());
            (!exempt).then(|| var.clone())
        }
        IrTerm::Structure { args, .. } => {
            for arg in args {
                if let Some(found) = find_undeclared_var(arg, declared) {
                    return Some(found);
                }
            }
            None
        }
        IrTerm::Typed { term: wrapped, .. }
        | IrTerm::Neural { term: wrapped, .. }
        | IrTerm::Differentiable { term: wrapped, .. } => find_undeclared_var(wrapped, declared),
        // NOTE(review): kept out of the or-pattern above. Merging it requires
        // its `term` field to have exactly the same type as the others; confirm
        // against `IrTerm` before combining.
        IrTerm::DiffNeural { term: wrapped, .. } => find_undeclared_var(wrapped, declared),
        // NOTE(review): every other variant (including ones added later) is
        // treated as variable-free. Confirm none of them can contain a `Var`.
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `IrAction` construction, RON (de)serialization,
    // `validate`, and the `ActionValidationError` messages.
    use super::*;

    /// `at(<atom>)`: a ground structure with no variables.
    fn at(loc: &str) -> IrTerm {
        IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Atom(loc.into())] }
    }

    /// `at(<Var>)`: a structure with a single variable argument.
    fn at_var(v: &str) -> IrTerm {
        IrTerm::Structure { name: "at".into(), args: vec![IrTerm::Var(v.into())] }
    }

    /// `connected(<X>, <Y>)`: a structure with two variable arguments.
    fn connected(x: &str, y: &str) -> IrTerm {
        IrTerm::Structure {
            name: "connected".into(),
            args: vec![IrTerm::Var(x.into()), IrTerm::Var(y.into())],
        }
    }

    /// `not(<inner>)`: negation encoded as an ordinary structure.
    fn not_term(inner: IrTerm) -> IrTerm {
        IrTerm::Structure { name: "not".into(), args: vec![inner] }
    }

    #[test]
    fn test_move_action_fields() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into()],
            preconditions: vec![at_var("X"), connected("X", "Y")],
            effects: vec![at_var("Y"), not_term(at_var("X"))],
            metadata: None,
        };
        assert_eq!(action.name, "move");
        assert_eq!(action.parameters, vec!["X", "Y"]);
        assert_eq!(action.preconditions.len(), 2);
        assert_eq!(action.effects.len(), 2);
        assert!(action.metadata.is_none());
    }

    #[test]
    fn test_ground_action_no_parameters() {
        let action = IrAction {
            name: "open_door".into(),
            parameters: vec![],
            preconditions: vec![at("locked")],
            effects: vec![at("unlocked"), not_term(at("locked"))],
            metadata: None,
        };
        assert!(action.parameters.is_empty());
        assert_eq!(action.preconditions.len(), 1);
        assert_eq!(action.effects.len(), 2);
    }

    #[test]
    fn test_action_with_metadata() {
        let action = IrAction {
            name: "risky_move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![at_var("X")],
            effects: vec![],
            metadata: Some(IrMetadata { probability: Some(0.8), ..IrMetadata::default() }),
        };
        assert_eq!(
            action.metadata.as_ref().and_then(|m| m.probability),
            Some(0.8)
        );
    }

    #[test]
    fn test_serde_roundtrip() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into()],
            preconditions: vec![at_var("X"), connected("X", "Y")],
            effects: vec![at_var("Y"), not_term(at_var("X"))],
            metadata: None,
        };
        let s = ron::to_string(&action).expect("serialize failed");
        let back: IrAction = ron::from_str(&s).expect("deserialize failed");
        assert_eq!(action, back);
    }

    #[test]
    fn test_serde_roundtrip_with_metadata() {
        let action = IrAction {
            name: "probabilistic_move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![],
            effects: vec![at_var("X")],
            metadata: Some(IrMetadata { probability: Some(0.5), ..IrMetadata::default() }),
        };
        let s = ron::to_string(&action).expect("serialize failed");
        let back: IrAction = ron::from_str(&s).expect("deserialize failed");
        assert_eq!(action, back);
    }

    #[test]
    fn test_clone_and_eq() {
        let a = IrAction {
            name: "jump".into(),
            parameters: vec!["X".into()],
            preconditions: vec![at_var("X")],
            effects: vec![],
            metadata: None,
        };
        let b = a.clone();
        assert_eq!(a, b);
    }

    #[test]
    fn test_validate_well_formed_parametric_action() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into()],
            preconditions: vec![at_var("X"), connected("X", "Y")],
            effects: vec![at_var("Y"), not_term(at_var("X"))],
            metadata: None,
        };
        assert!(action.validate().is_ok());
    }

    #[test]
    fn test_validate_well_formed_ground_action() {
        let action = IrAction {
            name: "open_door".into(),
            parameters: vec![],
            preconditions: vec![at("locked")],
            effects: vec![at("unlocked"), not_term(at("locked"))],
            metadata: None,
        };
        assert!(action.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_name_is_error() {
        let action = IrAction {
            name: "".into(),
            parameters: vec![],
            preconditions: vec![],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(action.validate(), Err(ActionValidationError::EmptyName));
    }

    #[test]
    fn test_validate_duplicate_parameter_is_error() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "X".into()],
            preconditions: vec![at_var("X")],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::DuplicateParameter("X".into()))
        );
    }

    #[test]
    fn test_validate_undeclared_variable_in_preconditions() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![connected("X", "Y")],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::UndeclaredVariable("Y".into(), "preconditions"))
        );
    }

    #[test]
    fn test_validate_undeclared_variable_in_effects() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![at_var("X")],
            effects: vec![IrTerm::Structure {
                name: "at".into(),
                args: vec![IrTerm::Var("Z".into())],
            }],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::UndeclaredVariable("Z".into(), "effects"))
        );
    }

    #[test]
    fn test_validate_anonymous_variable_is_allowed() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![IrTerm::Structure {
                name: "edge".into(),
                args: vec![IrTerm::Var("X".into()), IrTerm::Var("_".into())],
            }],
            effects: vec![],
            metadata: None,
        };
        assert!(action.validate().is_ok());
    }

    #[test]
    fn test_validate_nested_undeclared_variable_is_detected() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![],
            effects: vec![not_term(IrTerm::Structure {
                name: "at".into(),
                args: vec![IrTerm::Var("Z".into())],
            })],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::UndeclaredVariable("Z".into(), "effects"))
        );
    }

    #[test]
    fn test_validate_empty_preconditions_and_effects_is_ok() {
        let action = IrAction {
            name: "noop".into(),
            parameters: vec![],
            preconditions: vec![],
            effects: vec![],
            metadata: None,
        };
        assert!(action.validate().is_ok());
    }

    #[test]
    fn test_validate_underscore_prefixed_named_variable_is_allowed() {
        // Any `_`-prefixed variable is exempt, not only the bare `_`.
        let action = IrAction {
            name: "move".into(),
            parameters: vec![],
            preconditions: vec![at_var("_Anything")],
            effects: vec![at_var("_Other")],
            metadata: None,
        };
        assert!(action.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_name_reported_before_duplicate_parameter() {
        let action = IrAction {
            name: "".into(),
            parameters: vec!["X".into(), "X".into()],
            preconditions: vec![],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(action.validate(), Err(ActionValidationError::EmptyName));
    }

    #[test]
    fn test_validate_duplicate_parameter_reported_before_undeclared_variable() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into(), "Y".into(), "X".into()],
            preconditions: vec![at_var("Q")],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::DuplicateParameter("X".into()))
        );
    }

    #[test]
    fn test_validate_preconditions_checked_before_effects() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec![],
            preconditions: vec![at_var("A")],
            effects: vec![at_var("B")],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::UndeclaredVariable("A".into(), "preconditions"))
        );
    }

    #[test]
    fn test_validate_reports_first_undeclared_variable() {
        let action = IrAction {
            name: "move".into(),
            parameters: vec!["X".into()],
            preconditions: vec![at_var("X"), connected("Y", "Z")],
            effects: vec![],
            metadata: None,
        };
        assert_eq!(
            action.validate(),
            Err(ActionValidationError::UndeclaredVariable("Y".into(), "preconditions"))
        );
    }

    #[test]
    fn test_validation_error_messages() {
        assert_eq!(
            ActionValidationError::EmptyName.to_string(),
            "action name must not be empty"
        );
        assert_eq!(
            ActionValidationError::DuplicateParameter("X".into()).to_string(),
            "duplicate parameter name: \"X\""
        );
        assert_eq!(
            ActionValidationError::UndeclaredVariable("Y".into(), "effects").to_string(),
            "undeclared variable \"Y\" in effects"
        );
    }
}