1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31mod system;
32
33use core::fmt;
34use std::collections::HashMap;
35
36use serde::{Deserialize, Serialize};
37use thiserror::Error;
38
39pub use system::{AiComponent, AiInputs, AiOutputs, AiSystem, BehaviorState, YamlAiBridge};
40
41#[derive(Error, Debug, Clone, PartialEq, Eq)]
43pub enum AiError {
44 #[error("No valid plan found to achieve goal")]
46 NoPlanFound,
47 #[error("Action preconditions not met: {0}")]
49 PreconditionsNotMet(String),
50}
51
52pub type Result<T> = core::result::Result<T, AiError>;
54
55#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
57pub struct WorldState {
58 facts: HashMap<String, bool>,
59}
60
61impl WorldState {
62 #[must_use]
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn set(&mut self, key: impl Into<String>, value: bool) {
70 let _ = self.facts.insert(key.into(), value);
71 }
72
73 #[must_use]
75 pub fn get(&self, key: &str) -> bool {
76 self.facts.get(key).copied().unwrap_or(false)
77 }
78
79 #[must_use]
81 pub fn satisfies(&self, conditions: &Self) -> bool {
82 conditions.facts.iter().all(|(k, v)| self.get(k) == *v)
83 }
84
85 #[cfg(test)]
87 #[must_use]
88 pub fn test() -> Self {
89 let mut state = Self::new();
90 state.set("has_weapon", false);
91 state.set("enemy_visible", true);
92 state
93 }
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct Action {
99 pub name: String,
101 pub cost: f32,
103 pub preconditions: WorldState,
105 pub effects: WorldState,
107}
108
109impl Action {
110 #[must_use]
112 pub fn new(name: impl Into<String>) -> Self {
113 Self {
114 name: name.into(),
115 cost: 1.0,
116 preconditions: WorldState::new(),
117 effects: WorldState::new(),
118 }
119 }
120
121 #[must_use]
123 pub const fn with_cost(mut self, cost: f32) -> Self {
124 self.cost = cost;
125 self
126 }
127
128 #[must_use]
130 pub fn with_precondition(mut self, key: impl Into<String>, value: bool) -> Self {
131 self.preconditions.set(key, value);
132 self
133 }
134
135 #[must_use]
137 pub fn with_effect(mut self, key: impl Into<String>, value: bool) -> Self {
138 self.effects.set(key, value);
139 self
140 }
141
142 #[must_use]
144 pub fn can_run(&self, state: &WorldState) -> bool {
145 state.satisfies(&self.preconditions)
146 }
147
148 #[must_use]
150 pub fn apply(&self, state: &WorldState) -> WorldState {
151 let mut new_state = state.clone();
152 for (k, v) in &self.effects.facts {
153 new_state.set(k.clone(), *v);
154 }
155 new_state
156 }
157}
158
159impl PartialEq for Action {
160 fn eq(&self, other: &Self) -> bool {
161 self.name == other.name
162 }
163}
164
165impl Eq for Action {}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct Goal {
170 pub name: String,
172 pub priority: f32,
174 pub desired_state: WorldState,
176}
177
178impl Goal {
179 #[must_use]
181 pub fn new(name: impl Into<String>) -> Self {
182 Self {
183 name: name.into(),
184 priority: 1.0,
185 desired_state: WorldState::new(),
186 }
187 }
188
189 #[must_use]
191 pub const fn with_priority(mut self, priority: f32) -> Self {
192 self.priority = priority;
193 self
194 }
195
196 #[must_use]
198 pub fn with_condition(mut self, key: impl Into<String>, value: bool) -> Self {
199 self.desired_state.set(key, value);
200 self
201 }
202
203 #[must_use]
205 pub fn is_satisfied(&self, state: &WorldState) -> bool {
206 state.satisfies(&self.desired_state)
207 }
208}
209
210pub struct Planner {
212 actions: Vec<Action>,
213}
214
215impl Planner {
216 #[must_use]
218 pub const fn new() -> Self {
219 Self {
220 actions: Vec::new(),
221 }
222 }
223
224 pub fn add_action(&mut self, action: Action) {
226 self.actions.push(action);
227 }
228
229 pub fn plan(&self, current_state: &WorldState, goal: &Goal) -> Result<Vec<Action>> {
237 if goal.is_satisfied(current_state) {
238 return Ok(Vec::new());
239 }
240
241 let mut plan = Vec::new();
243 let mut working_state = current_state.clone();
244
245 for _ in 0..100 {
246 if goal.is_satisfied(&working_state) {
248 return Ok(plan);
249 }
250
251 let mut best_action: Option<&Action> = None;
253 let mut best_progress = 0;
254
255 for action in &self.actions {
256 if !action.can_run(&working_state) {
257 continue;
258 }
259
260 let new_state = action.apply(&working_state);
261 let progress = count_satisfied(&new_state, &goal.desired_state)
262 - count_satisfied(&working_state, &goal.desired_state);
263
264 if progress > best_progress || best_action.is_none() {
265 best_progress = progress;
266 best_action = Some(action);
267 }
268 }
269
270 if let Some(action) = best_action {
271 working_state = action.apply(&working_state);
272 plan.push(action.clone());
273 } else {
274 return Err(AiError::NoPlanFound);
275 }
276 }
277
278 Err(AiError::NoPlanFound)
279 }
280}
281
282impl Default for Planner {
283 fn default() -> Self {
284 Self::new()
285 }
286}
287
288impl fmt::Debug for Planner {
289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290 f.debug_struct("Planner")
291 .field("action_count", &self.actions.len())
292 .finish()
293 }
294}
295
296fn count_satisfied(state: &WorldState, goal: &WorldState) -> i32 {
297 goal.facts
298 .iter()
299 .filter(|(k, v)| state.get(k) == **v)
300 .count() as i32
301}
302
/// Outcome of ticking a [`BehaviorNode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeStatus {
    /// The node has not finished and expects to be ticked again.
    Running,
    /// The node finished and succeeded.
    Success,
    /// The node finished and failed.
    Failure,
}
313
314pub trait BehaviorNode: fmt::Debug {
316 fn tick(&mut self, dt: f32) -> NodeStatus;
318
319 fn reset(&mut self);
321}
322
323#[derive(Debug)]
325pub struct Sequence {
326 children: Vec<Box<dyn BehaviorNode>>,
327 current: usize,
328}
329
330impl Sequence {
331 #[must_use]
333 pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
334 Self {
335 children,
336 current: 0,
337 }
338 }
339}
340
341impl BehaviorNode for Sequence {
342 fn tick(&mut self, dt: f32) -> NodeStatus {
343 while self.current < self.children.len() {
344 match self.children[self.current].tick(dt) {
345 NodeStatus::Running => return NodeStatus::Running,
346 NodeStatus::Success => self.current += 1,
347 NodeStatus::Failure => return NodeStatus::Failure,
348 }
349 }
350 NodeStatus::Success
351 }
352
353 fn reset(&mut self) {
354 self.current = 0;
355 for child in &mut self.children {
356 child.reset();
357 }
358 }
359}
360
361#[derive(Debug)]
363pub struct Selector {
364 children: Vec<Box<dyn BehaviorNode>>,
365 current: usize,
366}
367
368impl Selector {
369 #[must_use]
371 pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
372 Self {
373 children,
374 current: 0,
375 }
376 }
377}
378
379impl BehaviorNode for Selector {
380 fn tick(&mut self, dt: f32) -> NodeStatus {
381 while self.current < self.children.len() {
382 match self.children[self.current].tick(dt) {
383 NodeStatus::Running => return NodeStatus::Running,
384 NodeStatus::Failure => self.current += 1,
385 NodeStatus::Success => return NodeStatus::Success,
386 }
387 }
388 NodeStatus::Failure
389 }
390
391 fn reset(&mut self) {
392 self.current = 0;
393 for child in &mut self.children {
394 child.reset();
395 }
396 }
397}
398
// Unit tests for the world-state model, actions, goals, the planner and the
// behavior-tree composites.
#[cfg(test)]
// Panicking on an unexpected `None`/`Err` is the intended failure mode in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_world_state() {
        let mut state = WorldState::new();
        state.set("has_weapon", true);
        assert!(state.get("has_weapon"));
        assert!(!state.get("nonexistent"));
    }

    #[test]
    fn test_world_state_test_helper() {
        let state = WorldState::test();
        assert!(!state.get("has_weapon"));
        assert!(state.get("enemy_visible"));
    }

    #[test]
    fn test_world_state_satisfies() {
        let mut state = WorldState::new();
        state.set("has_weapon", true);
        state.set("has_ammo", true);

        let mut conditions = WorldState::new();
        conditions.set("has_weapon", true);

        assert!(state.satisfies(&conditions));

        conditions.set("has_ammo", false);
        assert!(!state.satisfies(&conditions));
    }

    #[test]
    fn test_world_state_unset_fact_matches_false() {
        let state = WorldState::new();
        let mut conditions = WorldState::new();
        conditions.set("is_wounded", false);

        // Unset facts read as `false`, and empty conditions are always met.
        assert!(state.satisfies(&conditions));
        assert!(state.satisfies(&WorldState::new()));
    }

    #[test]
    fn test_action_can_run() {
        let action = Action::new("attack").with_precondition("has_weapon", true);

        let mut state = WorldState::new();
        assert!(!action.can_run(&state));

        state.set("has_weapon", true);
        assert!(action.can_run(&state));
    }

    #[test]
    fn test_action_apply() {
        let action = Action::new("pickup_weapon").with_effect("has_weapon", true);

        let state = WorldState::new();
        let new_state = action.apply(&state);

        assert!(new_state.get("has_weapon"));
    }

    #[test]
    fn test_action_apply_leaves_input_untouched() {
        let action = Action::new("drop_weapon").with_effect("has_weapon", false);

        let mut state = WorldState::new();
        state.set("has_weapon", true);
        let after = action.apply(&state);

        assert!(state.get("has_weapon"));
        assert!(!after.get("has_weapon"));
    }

    #[test]
    fn test_action_with_cost() {
        let action = Action::new("expensive_action").with_cost(5.0);
        assert!((action.cost - 5.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_action_equality() {
        let action1 = Action::new("attack").with_cost(1.0);
        let action2 = Action::new("attack").with_cost(2.0);
        let action3 = Action::new("defend");

        // Equality is by name only, so differing costs still compare equal.
        assert_eq!(action1, action2);
        assert_ne!(action1, action3);
    }

    #[test]
    fn test_goal_satisfied() {
        let goal = Goal::new("be_armed").with_condition("has_weapon", true);

        let mut state = WorldState::new();
        assert!(!goal.is_satisfied(&state));

        state.set("has_weapon", true);
        assert!(goal.is_satisfied(&state));
    }

    #[test]
    fn test_goal_with_priority() {
        let goal = Goal::new("high_priority").with_priority(10.0);
        assert!((goal.priority - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_planner_simple_plan() {
        let mut planner = Planner::new();

        planner.add_action(Action::new("pickup_weapon").with_effect("has_weapon", true));

        let state = WorldState::new();
        let goal = Goal::new("be_armed").with_condition("has_weapon", true);

        let plan = planner.plan(&state, &goal).unwrap();
        assert_eq!(plan.len(), 1);
        assert_eq!(plan[0].name, "pickup_weapon");
    }

    #[test]
    fn test_planner_already_satisfied() {
        let planner = Planner::new();

        let mut state = WorldState::new();
        state.set("has_weapon", true);

        let goal = Goal::new("be_armed").with_condition("has_weapon", true);

        let plan = planner.plan(&state, &goal).unwrap();
        assert!(plan.is_empty());
    }

    #[test]
    fn test_planner_no_plan_found() {
        let planner = Planner::new();

        let state = WorldState::new();
        let goal = Goal::new("impossible").with_condition("has_magic", true);

        let result = planner.plan(&state, &goal);
        assert!(matches!(result, Err(AiError::NoPlanFound)));
    }

    #[test]
    fn test_planner_default() {
        let planner = Planner::default();
        assert!(format!("{planner:?}").contains("action_count"));
    }

    #[test]
    fn test_planner_multi_step_plan() {
        let mut planner = Planner::new();

        planner.add_action(Action::new("pickup_weapon").with_effect("has_weapon", true));
        planner.add_action(
            Action::new("attack")
                .with_precondition("has_weapon", true)
                .with_effect("enemy_dead", true),
        );

        let state = WorldState::new();
        let goal = Goal::new("win").with_condition("enemy_dead", true);

        let plan = planner.plan(&state, &goal).unwrap();
        assert_eq!(plan.len(), 2);
        assert_eq!(plan[0].name, "pickup_weapon");
        assert_eq!(plan[1].name, "attack");
    }

    #[test]
    fn test_node_status() {
        assert_ne!(NodeStatus::Running, NodeStatus::Success);
        assert_ne!(NodeStatus::Success, NodeStatus::Failure);
    }

    /// Leaf node stub: reports `Running` for `max_ticks` ticks, then `result`
    /// on every tick after that.
    #[derive(Debug)]
    struct TestNode {
        ticks: usize,
        max_ticks: usize,
        result: NodeStatus,
    }

    impl TestNode {
        fn new(max_ticks: usize, result: NodeStatus) -> Self {
            Self {
                ticks: 0,
                max_ticks,
                result,
            }
        }

        fn immediate(result: NodeStatus) -> Self {
            Self::new(0, result)
        }
    }

    impl BehaviorNode for TestNode {
        fn tick(&mut self, _dt: f32) -> NodeStatus {
            if self.ticks < self.max_ticks {
                self.ticks += 1;
                NodeStatus::Running
            } else {
                self.result
            }
        }

        fn reset(&mut self) {
            self.ticks = 0;
        }
    }

    #[test]
    fn test_sequence_all_success() {
        let mut seq = Sequence::new(vec![
            Box::new(TestNode::immediate(NodeStatus::Success)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(seq.tick(0.016), NodeStatus::Success);
    }

    #[test]
    fn test_sequence_with_failure() {
        let mut seq = Sequence::new(vec![
            Box::new(TestNode::immediate(NodeStatus::Success)),
            Box::new(TestNode::immediate(NodeStatus::Failure)),
        ]);

        assert_eq!(seq.tick(0.016), NodeStatus::Failure);
    }

    #[test]
    fn test_sequence_with_running() {
        let mut seq = Sequence::new(vec![
            Box::new(TestNode::new(2, NodeStatus::Success)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(seq.tick(0.016), NodeStatus::Running);
        assert_eq!(seq.tick(0.016), NodeStatus::Running);
        assert_eq!(seq.tick(0.016), NodeStatus::Success);
    }

    #[test]
    fn test_sequence_reset() {
        let mut seq = Sequence::new(vec![
            Box::new(TestNode::new(1, NodeStatus::Success)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(seq.tick(0.016), NodeStatus::Running);
        seq.reset();
        assert_eq!(seq.tick(0.016), NodeStatus::Running);
    }

    #[test]
    fn test_selector_first_success() {
        let mut sel = Selector::new(vec![
            Box::new(TestNode::immediate(NodeStatus::Success)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(sel.tick(0.016), NodeStatus::Success);
    }

    #[test]
    fn test_selector_fallback() {
        let mut sel = Selector::new(vec![
            Box::new(TestNode::immediate(NodeStatus::Failure)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(sel.tick(0.016), NodeStatus::Success);
    }

    #[test]
    fn test_selector_all_fail() {
        let mut sel = Selector::new(vec![
            Box::new(TestNode::immediate(NodeStatus::Failure)),
            Box::new(TestNode::immediate(NodeStatus::Failure)),
        ]);

        assert_eq!(sel.tick(0.016), NodeStatus::Failure);
    }

    #[test]
    fn test_selector_with_running() {
        let mut sel = Selector::new(vec![
            Box::new(TestNode::new(2, NodeStatus::Failure)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(sel.tick(0.016), NodeStatus::Running);
        assert_eq!(sel.tick(0.016), NodeStatus::Running);
        assert_eq!(sel.tick(0.016), NodeStatus::Success);
    }

    #[test]
    fn test_selector_reset() {
        let mut sel = Selector::new(vec![
            Box::new(TestNode::new(1, NodeStatus::Failure)),
            Box::new(TestNode::immediate(NodeStatus::Success)),
        ]);

        assert_eq!(sel.tick(0.016), NodeStatus::Running);
        sel.reset();
        assert_eq!(sel.tick(0.016), NodeStatus::Running);
    }

    #[test]
    fn test_ai_error_display() {
        let err1 = AiError::NoPlanFound;
        assert!(format!("{err1}").contains("No valid plan"));

        let err2 = AiError::PreconditionsNotMet("test".to_string());
        assert!(format!("{err2}").contains("test"));
    }
}