#![forbid(unsafe_code)]
#![warn(missing_docs)]
mod system;
use core::fmt;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
pub use system::{AiComponent, AiInputs, AiOutputs, AiSystem, BehaviorState, YamlAiBridge};
/// Errors produced by the AI code in this module.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum AiError {
    /// No sequence of actions reaching the goal was found.
    ///
    /// This is the error returned by [`Planner::plan`].
    #[error("No valid plan found to achieve goal")]
    NoPlanFound,
    /// Reports that an action's preconditions were not met.
    ///
    /// The payload is interpolated into the display message after the colon.
    /// NOTE(review): nothing in this file constructs this variant outside the
    /// tests, so what the string identifies (presumably the action) should be
    /// confirmed against the callers.
    #[error("Action preconditions not met: {0}")]
    PreconditionsNotMet(String),
}
/// Convenience alias for results whose error type is [`AiError`].
pub type Result<T> = core::result::Result<T, AiError>;
/// A set of named boolean facts describing the world.
///
/// A fact that was never set reads as `false` through [`WorldState::get`].
/// The derived `PartialEq` compares the underlying maps, so a fact stored
/// explicitly as `false` and a fact that is absent are *not* equal under `==`
/// even though they read the same.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct WorldState {
    /// Fact name mapped to its current truth value.
    facts: HashMap<String, bool>,
}
impl WorldState {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn set(&mut self, key: impl Into<String>, value: bool) {
let _ = self.facts.insert(key.into(), value);
}
#[must_use]
pub fn get(&self, key: &str) -> bool {
self.facts.get(key).copied().unwrap_or(false)
}
#[must_use]
pub fn satisfies(&self, conditions: &Self) -> bool {
conditions.facts.iter().all(|(k, v)| self.get(k) == *v)
}
#[cfg(test)]
#[must_use]
pub fn test() -> Self {
let mut state = Self::new();
state.set("has_weapon", false);
state.set("enemy_visible", true);
state
}
}
/// A named step that can be taken when its preconditions hold and that
/// changes the world state through its effects.
///
/// Equality between actions is implemented by hand elsewhere in this file
/// and compares the `name` only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Action {
    /// Identifier of the action; the only field that takes part in equality.
    pub name: String,
    /// Cost of the action; [`Action::new`] initialises it to `1.0`.
    pub cost: f32,
    /// Facts that must hold for [`Action::can_run`] to return `true`.
    pub preconditions: WorldState,
    /// Facts written onto the state by [`Action::apply`].
    pub effects: WorldState,
}
impl Action {
    /// Creates an action called `name` with a cost of `1.0` and neither
    /// preconditions nor effects.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            cost: 1.0,
            preconditions: WorldState::new(),
            effects: WorldState::new(),
        }
    }

    /// Builder step: replaces the cost of the action.
    #[must_use]
    pub const fn with_cost(mut self, cost: f32) -> Self {
        self.cost = cost;
        self
    }

    /// Builder step: requires the fact `key` to read `value` before the
    /// action may run.
    #[must_use]
    pub fn with_precondition(self, key: impl Into<String>, value: bool) -> Self {
        let mut updated = self;
        updated.preconditions.set(key, value);
        updated
    }

    /// Builder step: makes the action set the fact `key` to `value` when it
    /// is applied.
    #[must_use]
    pub fn with_effect(self, key: impl Into<String>, value: bool) -> Self {
        let mut updated = self;
        updated.effects.set(key, value);
        updated
    }

    /// Returns `true` when `state` satisfies every precondition of the action.
    #[must_use]
    pub fn can_run(&self, state: &WorldState) -> bool {
        WorldState::satisfies(state, &self.preconditions)
    }

    /// Returns a copy of `state` with the effects of the action written on
    /// top of it; `state` itself is left untouched.
    ///
    /// Preconditions are not checked here.
    #[must_use]
    pub fn apply(&self, state: &WorldState) -> WorldState {
        let mut next = state.clone();
        let overrides = self
            .effects
            .facts
            .iter()
            .map(|(name, value)| (name.clone(), *value));
        // `extend` overwrites entries whose key is already present.
        next.facts.extend(overrides);
        next
    }
}
/// Two actions are equal when they share a name; cost, preconditions and
/// effects are not compared.
impl PartialEq for Action {
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (self.name.as_str(), other.name.as_str());
        mine == theirs
    }
}

// String comparison is a total equivalence, so the marker trait is sound.
impl Eq for Action {}
/// A named target: the set of facts an agent wants to be true (or false).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Goal {
    /// Identifier of the goal.
    pub name: String,
    /// Priority of the goal; [`Goal::new`] initialises it to `1.0`.
    /// NOTE(review): nothing in this file reads the priority — confirm how
    /// callers order goals with it.
    pub priority: f32,
    /// Facts that must hold for [`Goal::is_satisfied`] to return `true`.
    pub desired_state: WorldState,
}
impl Goal {
    /// Creates a goal called `name` with a priority of `1.0` and no
    /// conditions; such a goal is satisfied by every state.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            priority: 1.0,
            desired_state: WorldState::new(),
        }
    }

    /// Builder step: replaces the priority of the goal.
    #[must_use]
    pub const fn with_priority(mut self, priority: f32) -> Self {
        self.priority = priority;
        self
    }

    /// Builder step: additionally requires the fact `key` to read `value`.
    #[must_use]
    pub fn with_condition(self, key: impl Into<String>, value: bool) -> Self {
        let mut updated = self;
        updated.desired_state.set(key, value);
        updated
    }

    /// Returns `true` when `state` satisfies every condition of the goal.
    #[must_use]
    pub fn is_satisfied(&self, state: &WorldState) -> bool {
        WorldState::satisfies(state, &self.desired_state)
    }
}
/// Greedy forward planner over a fixed list of [`Action`]s.
pub struct Planner {
    /// Actions the planner may use, in insertion order. When several actions
    /// are equally useful the one added first is chosen.
    actions: Vec<Action>,
}

impl Planner {
    /// Upper bound on the number of actions a single plan may contain.
    const MAX_PLAN_STEPS: usize = 100;

    /// Creates a planner that knows no actions.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            actions: Vec::new(),
        }
    }

    /// Makes `action` available to later calls of [`Planner::plan`].
    pub fn add_action(&mut self, action: Action) {
        self.actions.push(action);
    }

    /// Builds a sequence of actions that leads from `current_state` to a
    /// state satisfying `goal`.
    ///
    /// The search is greedy. At every step it looks at the actions whose
    /// preconditions hold *and* whose effects would change the working state,
    /// and takes the one that newly satisfies the most goal conditions; the
    /// action added first wins ties. Actions that would leave the state as it
    /// is are skipped because repeating them can never unlock anything.
    ///
    /// Being greedy, the planner may miss a plan that needs a temporarily
    /// worse step ordering, and it does not consult [`Action::cost`].
    ///
    /// Returns an empty plan when `current_state` already satisfies `goal`.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::NoPlanFound`] when no applicable action changes the
    /// state, or when the goal is still unmet after 100 actions.
    pub fn plan(&self, current_state: &WorldState, goal: &Goal) -> Result<Vec<Action>> {
        let mut plan: Vec<Action> = Vec::new();
        let mut working_state = current_state.clone();

        // The goal is tested before the step cap so that a plan completed by
        // the last permitted action is still returned.
        while !goal.is_satisfied(&working_state) {
            if plan.len() >= Self::MAX_PLAN_STEPS {
                return Err(AiError::NoPlanFound);
            }

            let satisfied_now = count_satisfied(&working_state, &goal.desired_state);
            // Best candidate so far: the action, the state it leads to, and
            // how many additional goal conditions that state satisfies.
            let mut best: Option<(&Action, WorldState, i32)> = None;

            for action in &self.actions {
                if !action.can_run(&working_state)
                    || !Self::changes_state(action, &working_state)
                {
                    continue;
                }

                let next_state = action.apply(&working_state);
                let progress = count_satisfied(&next_state, &goal.desired_state) - satisfied_now;
                // Strictly greater keeps the earliest action on ties.
                let improves = match best {
                    Some((_, _, best_progress)) => progress > best_progress,
                    None => true,
                };
                if improves {
                    best = Some((action, next_state, progress));
                }
            }

            match best {
                Some((action, next_state, _)) => {
                    working_state = next_state;
                    plan.push(action.clone());
                }
                // Nothing applicable can alter the state: a dead end.
                None => return Err(AiError::NoPlanFound),
            }
        }

        Ok(plan)
    }

    /// Returns `true` when applying `action` to `state` would change the
    /// value read for at least one fact.
    fn changes_state(action: &Action, state: &WorldState) -> bool {
        action
            .effects
            .facts
            .iter()
            .any(|(name, value)| state.get(name) != *value)
    }
}

impl Default for Planner {
    /// Equivalent to [`Planner::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Prints only the number of registered actions rather than every action.
impl fmt::Debug for Planner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Planner")
            .field("action_count", &self.actions.len())
            .finish()
    }
}
/// Counts how many facts listed in `goal` read the same value in `state`.
///
/// Facts missing from `state` read as `false`, so a goal condition of `false`
/// counts as met for a fact `state` has never seen. The count saturates at
/// `i32::MAX` instead of wrapping through a narrowing cast.
fn count_satisfied(state: &WorldState, goal: &WorldState) -> i32 {
    goal.facts
        .iter()
        .filter(|(name, wanted)| state.get(name) == **wanted)
        .fold(0_i32, |matched, _| matched.saturating_add(1))
}
/// Outcome reported by a [`BehaviorNode`] after one tick.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
    /// The node has not finished yet. The composite nodes in this file pass
    /// this status up unchanged and tick the same child again next time.
    Running,
    /// The node finished successfully.
    Success,
    /// The node finished unsuccessfully.
    Failure,
}
/// A node of a behavior tree that is advanced by repeated ticks.
///
/// `Debug` is a supertrait, which is what lets `Sequence` and `Selector`
/// derive `Debug` while holding boxed trait objects.
pub trait BehaviorNode: fmt::Debug {
    /// Advances the node by one step and reports its status.
    ///
    /// The composite nodes in this file forward `dt` to their children
    /// unchanged. NOTE(review): presumably the time elapsed since the previous
    /// tick, in seconds — confirm against the callers.
    fn tick(&mut self, dt: f32) -> NodeStatus;
    /// Returns the node to its initial state so that it can be run again.
    fn reset(&mut self);
}
/// Composite node that runs its children in order and succeeds only when all
/// of them succeed.
#[derive(Debug)]
pub struct Sequence {
    /// Children in execution order.
    children: Vec<Box<dyn BehaviorNode>>,
    /// Index of the child that is ticked next; children before it have
    /// already succeeded.
    current: usize,
}

impl Sequence {
    /// Creates a sequence that starts at its first child.
    #[must_use]
    pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
        Self {
            children,
            current: 0,
        }
    }
}

impl BehaviorNode for Sequence {
    /// Ticks children starting at the current one until a child reports
    /// something other than `Success`.
    ///
    /// `Running` and `Failure` are returned as they are and the position is
    /// kept, so the same child is ticked again on the next call. Once every
    /// child has succeeded (or when there are no children) the result is
    /// `Success`, and it stays `Success` until [`BehaviorNode::reset`].
    fn tick(&mut self, dt: f32) -> NodeStatus {
        while let Some(child) = self.children.get_mut(self.current) {
            let status = child.tick(dt);
            if status != NodeStatus::Success {
                return status;
            }
            self.current += 1;
        }
        NodeStatus::Success
    }

    /// Rewinds to the first child and resets every child.
    fn reset(&mut self) {
        self.current = 0;
        self.children.iter_mut().for_each(|child| child.reset());
    }
}
/// Composite node that tries its children in order and succeeds as soon as
/// one of them succeeds.
#[derive(Debug)]
pub struct Selector {
    /// Children in the order in which they are tried.
    children: Vec<Box<dyn BehaviorNode>>,
    /// Index of the child that is ticked next; children before it have
    /// already failed.
    current: usize,
}

impl Selector {
    /// Creates a selector that starts at its first child.
    #[must_use]
    pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
        Self {
            children,
            current: 0,
        }
    }
}

impl BehaviorNode for Selector {
    /// Ticks children starting at the current one until a child reports
    /// something other than `Failure`.
    ///
    /// `Running` and `Success` are returned as they are and the position is
    /// kept, so the same child is ticked again on the next call. Once every
    /// child has failed (or when there are no children) the result is
    /// `Failure`, and it stays `Failure` until [`BehaviorNode::reset`].
    fn tick(&mut self, dt: f32) -> NodeStatus {
        while let Some(child) = self.children.get_mut(self.current) {
            let status = child.tick(dt);
            if status != NodeStatus::Failure {
                return status;
            }
            self.current += 1;
        }
        NodeStatus::Failure
    }

    /// Rewinds to the first child and resets every child.
    fn reset(&mut self) {
        self.current = 0;
        self.children.iter_mut().for_each(|child| child.reset());
    }
}
/// Unit tests for the world state, planner and behavior-tree nodes.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Time step handed to every `tick` call in these tests.
    const DT: f32 = 0.016;

    /// Boxes a test node as a trait object for use in composite nodes.
    fn boxed(node: TestNode) -> Box<dyn BehaviorNode> {
        Box::new(node)
    }

    /// Extracts the action names of a plan for compact comparisons.
    fn names(plan: &[Action]) -> Vec<&str> {
        plan.iter().map(|action| action.name.as_str()).collect()
    }

    #[test]
    fn test_world_state() {
        let mut facts = WorldState::new();
        facts.set("has_weapon", true);
        assert!(facts.get("has_weapon"));
        // Unknown facts read as false.
        assert!(!facts.get("nonexistent"));
    }

    #[test]
    fn test_world_state_test_helper() {
        let fixture = WorldState::test();
        assert!(!fixture.get("has_weapon"));
        assert!(fixture.get("enemy_visible"));
    }

    #[test]
    fn test_world_state_satisfies() {
        let mut facts = WorldState::new();
        facts.set("has_weapon", true);
        facts.set("has_ammo", true);

        let mut required = WorldState::new();
        required.set("has_weapon", true);
        assert!(facts.satisfies(&required));

        required.set("has_ammo", false);
        assert!(!facts.satisfies(&required));
    }

    #[test]
    fn test_action_can_run() {
        let attack = Action::new("attack").with_precondition("has_weapon", true);
        let mut facts = WorldState::new();
        assert!(!attack.can_run(&facts));

        facts.set("has_weapon", true);
        assert!(attack.can_run(&facts));
    }

    #[test]
    fn test_action_apply() {
        let pickup = Action::new("pickup_weapon").with_effect("has_weapon", true);
        let after = pickup.apply(&WorldState::new());
        assert!(after.get("has_weapon"));
    }

    #[test]
    fn test_action_with_cost() {
        let expensive = Action::new("expensive_action").with_cost(5.0);
        assert!((expensive.cost - 5.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_action_equality() {
        let cheap_attack = Action::new("attack").with_cost(1.0);
        let costly_attack = Action::new("attack").with_cost(2.0);
        let defend = Action::new("defend");
        // Equality looks at the name only.
        assert_eq!(cheap_attack, costly_attack);
        assert_ne!(cheap_attack, defend);
    }

    #[test]
    fn test_goal_satisfied() {
        let be_armed = Goal::new("be_armed").with_condition("has_weapon", true);
        let mut facts = WorldState::new();
        assert!(!be_armed.is_satisfied(&facts));

        facts.set("has_weapon", true);
        assert!(be_armed.is_satisfied(&facts));
    }

    #[test]
    fn test_goal_with_priority() {
        let urgent = Goal::new("high_priority").with_priority(10.0);
        assert!((urgent.priority - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_planner_simple_plan() {
        let mut planner = Planner::new();
        planner.add_action(Action::new("pickup_weapon").with_effect("has_weapon", true));

        let be_armed = Goal::new("be_armed").with_condition("has_weapon", true);
        let plan = planner
            .plan(&WorldState::new(), &be_armed)
            .expect("a single pickup reaches the goal");
        assert_eq!(names(&plan), ["pickup_weapon"]);
    }

    #[test]
    fn test_planner_already_satisfied() {
        let mut facts = WorldState::new();
        facts.set("has_weapon", true);

        let be_armed = Goal::new("be_armed").with_condition("has_weapon", true);
        let plan = Planner::new()
            .plan(&facts, &be_armed)
            .expect("a satisfied goal needs no actions");
        assert!(plan.is_empty());
    }

    #[test]
    fn test_planner_no_plan_found() {
        let impossible = Goal::new("impossible").with_condition("has_magic", true);
        let outcome = Planner::new().plan(&WorldState::new(), &impossible);
        assert_eq!(outcome, Err(AiError::NoPlanFound));
    }

    #[test]
    fn test_planner_default() {
        let rendered = format!("{:?}", Planner::default());
        assert!(rendered.contains("action_count"));
    }

    #[test]
    fn test_planner_multi_step_plan() {
        let mut planner = Planner::new();
        planner.add_action(Action::new("pickup_weapon").with_effect("has_weapon", true));
        planner.add_action(
            Action::new("attack")
                .with_precondition("has_weapon", true)
                .with_effect("enemy_dead", true),
        );

        let win = Goal::new("win").with_condition("enemy_dead", true);
        let plan = planner
            .plan(&WorldState::new(), &win)
            .expect("pickup followed by attack reaches the goal");
        assert_eq!(names(&plan), ["pickup_weapon", "attack"]);
    }

    #[test]
    fn test_node_status() {
        assert_ne!(NodeStatus::Running, NodeStatus::Success);
        assert_ne!(NodeStatus::Success, NodeStatus::Failure);
    }

    /// Leaf node that reports `Running` for `max_ticks` ticks and `result`
    /// on every tick after that.
    #[derive(Debug)]
    struct TestNode {
        ticks: usize,
        max_ticks: usize,
        result: NodeStatus,
    }

    impl TestNode {
        /// Node that finishes with `result` after `max_ticks` running ticks.
        fn new(max_ticks: usize, result: NodeStatus) -> Self {
            Self {
                ticks: 0,
                max_ticks,
                result,
            }
        }

        /// Node that reports `result` on its very first tick.
        fn immediate(result: NodeStatus) -> Self {
            Self::new(0, result)
        }
    }

    impl BehaviorNode for TestNode {
        fn tick(&mut self, _dt: f32) -> NodeStatus {
            if self.ticks >= self.max_ticks {
                return self.result;
            }
            self.ticks += 1;
            NodeStatus::Running
        }

        fn reset(&mut self) {
            self.ticks = 0;
        }
    }

    #[test]
    fn test_sequence_all_success() {
        let mut sequence = Sequence::new(vec![
            boxed(TestNode::immediate(NodeStatus::Success)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        assert_eq!(sequence.tick(DT), NodeStatus::Success);
    }

    #[test]
    fn test_sequence_with_failure() {
        let mut sequence = Sequence::new(vec![
            boxed(TestNode::immediate(NodeStatus::Success)),
            boxed(TestNode::immediate(NodeStatus::Failure)),
        ]);
        assert_eq!(sequence.tick(DT), NodeStatus::Failure);
    }

    #[test]
    fn test_sequence_with_running() {
        let mut sequence = Sequence::new(vec![
            boxed(TestNode::new(2, NodeStatus::Success)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        let expected = [
            NodeStatus::Running,
            NodeStatus::Running,
            NodeStatus::Success,
        ];
        for status in expected {
            assert_eq!(sequence.tick(DT), status);
        }
    }

    #[test]
    fn test_sequence_reset() {
        let mut sequence = Sequence::new(vec![
            boxed(TestNode::new(1, NodeStatus::Success)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        assert_eq!(sequence.tick(DT), NodeStatus::Running);
        sequence.reset();
        // After a reset the first child starts over and is running again.
        assert_eq!(sequence.tick(DT), NodeStatus::Running);
    }

    #[test]
    fn test_selector_first_success() {
        let mut selector = Selector::new(vec![
            boxed(TestNode::immediate(NodeStatus::Success)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        assert_eq!(selector.tick(DT), NodeStatus::Success);
    }

    #[test]
    fn test_selector_fallback() {
        let mut selector = Selector::new(vec![
            boxed(TestNode::immediate(NodeStatus::Failure)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        assert_eq!(selector.tick(DT), NodeStatus::Success);
    }

    #[test]
    fn test_selector_all_fail() {
        let mut selector = Selector::new(vec![
            boxed(TestNode::immediate(NodeStatus::Failure)),
            boxed(TestNode::immediate(NodeStatus::Failure)),
        ]);
        assert_eq!(selector.tick(DT), NodeStatus::Failure);
    }

    #[test]
    fn test_selector_with_running() {
        let mut selector = Selector::new(vec![
            boxed(TestNode::new(2, NodeStatus::Failure)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        let expected = [
            NodeStatus::Running,
            NodeStatus::Running,
            NodeStatus::Success,
        ];
        for status in expected {
            assert_eq!(selector.tick(DT), status);
        }
    }

    #[test]
    fn test_selector_reset() {
        let mut selector = Selector::new(vec![
            boxed(TestNode::new(1, NodeStatus::Failure)),
            boxed(TestNode::immediate(NodeStatus::Success)),
        ]);
        assert_eq!(selector.tick(DT), NodeStatus::Running);
        selector.reset();
        // After a reset the first child starts over and is running again.
        assert_eq!(selector.tick(DT), NodeStatus::Running);
    }

    #[test]
    fn test_ai_error_display() {
        let no_plan = AiError::NoPlanFound.to_string();
        assert!(no_plan.contains("No valid plan"));

        let unmet = AiError::PreconditionsNotMet("test".to_string()).to_string();
        assert!(unmet.contains("test"));
    }
}