use std::collections::HashMap;
use std::time::{Duration, Instant};
use crate::coordinates::Distance;
/// Outcome of ticking a [`BehaviorNode`] once.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BehaviorStatus {
    /// The node finished and achieved its goal.
    Success,
    /// The node finished without achieving its goal.
    Failure,
    /// The node has not finished yet; stateful composites resume it on a later tick.
    Running,
}
/// State shared by every node while a behavior tree is ticked.
///
/// `blackboard` and `properties` are two independent string-keyed maps of
/// [`BehaviorValue`]s; nodes in this module read and write the blackboard.
#[derive(Debug)]
pub struct BehaviorContext {
    /// Entity this context is associated with, if any.
    pub entity_id: Option<u32>,
    /// Timestamp that time-based nodes (e.g. waits and cooldowns) measure against.
    pub current_time: Instant,
    /// Tick length most recently passed to [`BehaviorContext::update`]
    /// (1/60 s until the first update).
    pub delta_time: Duration,
    /// Scratch values written and read by nodes.
    pub blackboard: HashMap<String, BehaviorValue>,
    /// Additional key/value storage with the same value type as the blackboard.
    pub properties: HashMap<String, BehaviorValue>,
}

impl BehaviorContext {
    /// Creates a context with no entity, `current_time` set to now, a 1/60 s
    /// `delta_time`, and empty maps.
    pub fn new() -> Self {
        let default_delta = Duration::from_secs_f32(1.0 / 60.0);
        Self {
            entity_id: None,
            current_time: Instant::now(),
            delta_time: default_delta,
            blackboard: HashMap::new(),
            properties: HashMap::new(),
        }
    }

    /// Creates a fresh context (see [`BehaviorContext::new`]) bound to `entity_id`.
    pub fn for_entity(entity_id: u32) -> Self {
        Self {
            entity_id: Some(entity_id),
            ..Self::new()
        }
    }

    /// Stores `delta_time` and re-samples `current_time` from the wall clock.
    ///
    /// Note: `current_time` is taken from `Instant::now()`, not advanced by
    /// `delta_time`, so the two values are not required to agree.
    pub fn update(&mut self, delta_time: Duration) {
        self.current_time = Instant::now();
        self.delta_time = delta_time;
    }

    /// Inserts (or overwrites) a blackboard entry.
    pub fn set_blackboard<T: Into<BehaviorValue>>(&mut self, key: &str, value: T) {
        let converted: BehaviorValue = value.into();
        self.blackboard.insert(key.to_owned(), converted);
    }

    /// Looks up a blackboard entry; `None` if `key` was never set.
    pub fn get_blackboard(&self, key: &str) -> Option<&BehaviorValue> {
        self.blackboard.get(key)
    }

    /// Inserts (or overwrites) a property entry.
    pub fn set_property<T: Into<BehaviorValue>>(&mut self, key: &str, value: T) {
        let converted: BehaviorValue = value.into();
        self.properties.insert(key.to_owned(), converted);
    }

    /// Looks up a property entry; `None` if `key` was never set.
    pub fn get_property(&self, key: &str) -> Option<&BehaviorValue> {
        self.properties.get(key)
    }
}

impl Default for BehaviorContext {
    /// Same as [`BehaviorContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Dynamically typed value stored in a context's blackboard or property map.
#[derive(Debug, Clone, PartialEq)]
pub enum BehaviorValue {
    Bool(bool),
    Int(i32),
    Float(f32),
    String(String),
    Position2D { x: i32, y: i32 },
    EntityId(u32),
}

impl From<bool> for BehaviorValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

impl From<i32> for BehaviorValue {
    fn from(number: i32) -> Self {
        Self::Int(number)
    }
}

impl From<f32> for BehaviorValue {
    fn from(number: f32) -> Self {
        Self::Float(number)
    }
}

impl From<String> for BehaviorValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}

impl From<&str> for BehaviorValue {
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}

impl From<u32> for BehaviorValue {
    /// Note: a bare `u32` is interpreted as an entity id, not as a number.
    fn from(id: u32) -> Self {
        Self::EntityId(id)
    }
}

impl From<(i32, i32)> for BehaviorValue {
    /// Converts an `(x, y)` pair into [`BehaviorValue::Position2D`].
    fn from(pair: (i32, i32)) -> Self {
        let (x, y) = pair;
        Self::Position2D { x, y }
    }
}
/// A node of a behavior tree.
///
/// Nodes are ticked through [`BehaviorNode::execute`] and report progress as a
/// [`BehaviorStatus`]. The `Debug` supertrait lets whole trees be printed.
pub trait BehaviorNode: std::fmt::Debug {
    /// Advances the node by one tick, reading and possibly mutating `context`.
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus;

    /// Returns the node to its initial state.
    ///
    /// The default does nothing, which is appropriate for stateless nodes.
    fn reset(&mut self) {
        // Stateless by default: nothing to clear.
    }

    /// Human-readable label for the node.
    fn name(&self) -> &str;
}
/// A named behavior tree: owns a single root node and forwards calls to it.
#[derive(Debug)]
pub struct BehaviorTree {
    root: Box<dyn BehaviorNode>,
    name: String,
}

impl BehaviorTree {
    /// Creates a tree from its root node and a display name.
    pub fn new(root: Box<dyn BehaviorNode>, name: String) -> Self {
        BehaviorTree { root, name }
    }

    /// Ticks the root node once and returns its status.
    pub fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        let root = &mut self.root;
        root.execute(context)
    }

    /// Resets the root node.
    pub fn reset(&mut self) {
        let root = &mut self.root;
        root.reset();
    }

    /// The name given at construction.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Composite that runs its children in order until one does not succeed.
///
/// Returns `Failure` as soon as a child fails and `Running` when a child is
/// still running; the next tick resumes at that same child. Returns `Success`
/// once every child has succeeded. The node resets itself after it finishes.
#[derive(Debug)]
pub struct SequenceNode {
    children: Vec<Box<dyn BehaviorNode>>,
    /// Index of the child to tick next.
    current_child: usize,
    name: String,
}

impl SequenceNode {
    /// Creates a sequence named `"Sequence"`.
    pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
        Self::named(children, "Sequence".to_string())
    }

    /// Creates a sequence with a custom name.
    pub fn named(children: Vec<Box<dyn BehaviorNode>>, name: String) -> Self {
        Self {
            children,
            current_child: 0,
            name,
        }
    }
}

impl BehaviorNode for SequenceNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        while let Some(child) = self.children.get_mut(self.current_child) {
            let status = child.execute(context);
            match status {
                BehaviorStatus::Running => return BehaviorStatus::Running,
                BehaviorStatus::Failure => {
                    self.reset();
                    return BehaviorStatus::Failure;
                }
                BehaviorStatus::Success => self.current_child += 1,
            }
        }
        // Every child succeeded: rewind so the next tick starts from the top.
        self.reset();
        BehaviorStatus::Success
    }

    fn reset(&mut self) {
        self.current_child = 0;
        self.children.iter_mut().for_each(|child| child.reset());
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Composite that tries its children in order until one does not fail.
///
/// Returns `Success` as soon as a child succeeds and `Running` when a child is
/// still running; the next tick resumes at that same child. Returns `Failure`
/// once every child has failed. The node resets itself after it finishes.
#[derive(Debug)]
pub struct SelectorNode {
    children: Vec<Box<dyn BehaviorNode>>,
    /// Index of the child to tick next.
    current_child: usize,
    name: String,
}

impl SelectorNode {
    /// Creates a selector named `"Selector"`.
    pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
        Self::named(children, "Selector".to_string())
    }

    /// Creates a selector with a custom name.
    pub fn named(children: Vec<Box<dyn BehaviorNode>>, name: String) -> Self {
        Self {
            children,
            current_child: 0,
            name,
        }
    }
}

impl BehaviorNode for SelectorNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        while let Some(child) = self.children.get_mut(self.current_child) {
            let status = child.execute(context);
            match status {
                BehaviorStatus::Running => return BehaviorStatus::Running,
                BehaviorStatus::Success => {
                    self.reset();
                    return BehaviorStatus::Success;
                }
                BehaviorStatus::Failure => self.current_child += 1,
            }
        }
        // Every child failed: rewind so the next tick starts from the top.
        self.reset();
        BehaviorStatus::Failure
    }

    fn reset(&mut self) {
        self.current_child = 0;
        self.children.iter_mut().for_each(|child| child.reset());
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Composite that ticks all of its children during the same tick.
///
/// * Returns `Failure` as soon as any child fails. Remaining children are not
///   ticked, and every child, including ones still running, is reset.
/// * Returns `Success` once every child has succeeded. An empty parallel
///   succeeds immediately.
/// * Otherwise returns `Running`.
///
/// A child that has already succeeded is not ticked again until the node
/// finishes or is reset. This way a child that finishes early (for example a
/// wait) is not restarted while its siblings are still running.
#[derive(Debug)]
pub struct ParallelNode {
    children: Vec<Box<dyn BehaviorNode>>,
    /// `succeeded[i]` is true once `children[i]` has returned `Success` during
    /// the current run. Always the same length as `children`.
    succeeded: Vec<bool>,
    name: String,
}

impl ParallelNode {
    /// Creates a parallel node named `"Parallel"`.
    pub fn new(children: Vec<Box<dyn BehaviorNode>>) -> Self {
        Self::named(children, "Parallel".to_string())
    }

    /// Creates a parallel node with a custom name.
    pub fn named(children: Vec<Box<dyn BehaviorNode>>, name: String) -> Self {
        let succeeded = vec![false; children.len()];
        Self {
            children,
            succeeded,
            name,
        }
    }
}

impl BehaviorNode for ParallelNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        // Assume success; downgrade to Running/Failure based on child results.
        let mut outcome = BehaviorStatus::Success;
        for (child, finished) in self.children.iter_mut().zip(self.succeeded.iter_mut()) {
            if *finished {
                continue;
            }
            match child.execute(context) {
                BehaviorStatus::Success => *finished = true,
                BehaviorStatus::Running => outcome = BehaviorStatus::Running,
                BehaviorStatus::Failure => {
                    outcome = BehaviorStatus::Failure;
                    break;
                }
            }
        }
        // When the node completes (either way), clear per-run state so the
        // next tick starts a fresh run and no child keeps stale progress.
        if outcome != BehaviorStatus::Running {
            self.reset();
        }
        outcome
    }

    fn reset(&mut self) {
        self.succeeded.fill(false);
        for child in &mut self.children {
            child.reset();
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Decorator that re-runs its child each time the child completes.
///
/// The child's result (`Success` or `Failure`) is ignored; only completions
/// are counted.
///
/// * [`RepeatNode::times`]: runs the child to completion `count` times, then
///   returns `Success`. Completions in one tick are chained within that tick.
///   A `Running` child suspends the loop until the next tick. A count of `0`
///   succeeds immediately without ticking the child.
/// * [`RepeatNode::infinite`]: never finishes. After each completion it returns
///   `Running`, so a child that completes instantly cannot trap the caller in
///   an endless loop within a single tick.
#[derive(Debug)]
pub struct RepeatNode {
    child: Box<dyn BehaviorNode>,
    /// `None` means repeat forever.
    max_repeats: Option<u32>,
    /// Completions counted since the last reset.
    current_repeats: u32,
    name: String,
}

impl RepeatNode {
    /// Repeats `child` forever; named `"Repeat(∞)"`.
    pub fn infinite(child: Box<dyn BehaviorNode>) -> Self {
        Self {
            child,
            max_repeats: None,
            current_repeats: 0,
            name: "Repeat(∞)".to_string(),
        }
    }

    /// Repeats `child` `count` times; named `"Repeat(<count>)"`.
    pub fn times(child: Box<dyn BehaviorNode>, count: u32) -> Self {
        Self {
            child,
            max_repeats: Some(count),
            current_repeats: 0,
            name: format!("Repeat({})", count),
        }
    }
}

impl BehaviorNode for RepeatNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        loop {
            // Check the bound *before* ticking so `times(_, 0)` never runs the child.
            if let Some(max) = self.max_repeats {
                if self.current_repeats >= max {
                    self.reset();
                    return BehaviorStatus::Success;
                }
            }
            match self.child.execute(context) {
                BehaviorStatus::Running => return BehaviorStatus::Running,
                BehaviorStatus::Success | BehaviorStatus::Failure => {
                    // Saturate: an infinite repeat may complete more than u32::MAX times.
                    self.current_repeats = self.current_repeats.saturating_add(1);
                    self.child.reset();
                    if self.max_repeats.is_none() {
                        // Yield to the caller. Looping here would never return
                        // if the child completes on every tick.
                        return BehaviorStatus::Running;
                    }
                }
            }
        }
    }

    fn reset(&mut self) {
        self.current_repeats = 0;
        self.child.reset();
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Decorator that swaps its child's `Success` and `Failure`; `Running` passes through.
#[derive(Debug)]
pub struct InvertNode {
    child: Box<dyn BehaviorNode>,
    name: String,
}

impl InvertNode {
    /// Wraps `child` in an inverter named `"Invert"`.
    pub fn new(child: Box<dyn BehaviorNode>) -> Self {
        let name = String::from("Invert");
        Self { child, name }
    }
}

impl BehaviorNode for InvertNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        let child_status = self.child.execute(context);
        match child_status {
            BehaviorStatus::Running => BehaviorStatus::Running,
            BehaviorStatus::Failure => BehaviorStatus::Success,
            BehaviorStatus::Success => BehaviorStatus::Failure,
        }
    }

    fn reset(&mut self) {
        self.child.reset()
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Decorator that blocks its child for `cooldown_duration` after the child finishes.
///
/// The cooldown is measured with `BehaviorContext::current_time` and starts
/// when the child returns `Success` or `Failure`. While cooling down, this node
/// returns `Failure` without ticking the child. `reset` resets only the child;
/// the cooldown timestamp is kept.
#[derive(Debug)]
pub struct CooldownNode {
    child: Box<dyn BehaviorNode>,
    cooldown_duration: Duration,
    /// Context time at which the child last finished, if it ever has.
    last_execution: Option<Instant>,
    name: String,
}

impl CooldownNode {
    /// Wraps `child`, naming the node `"Cooldown(<secs>s)"` with one decimal.
    pub fn new(child: Box<dyn BehaviorNode>, cooldown_duration: Duration) -> Self {
        let name = format!("Cooldown({:.1}s)", cooldown_duration.as_secs_f32());
        Self {
            child,
            cooldown_duration,
            last_execution: None,
            name,
        }
    }

    /// True if the child finished less than `cooldown_duration` before `now`.
    fn is_cooling_down(&self, now: Instant) -> bool {
        match self.last_execution {
            Some(finished_at) => now.duration_since(finished_at) < self.cooldown_duration,
            None => false,
        }
    }
}

impl BehaviorNode for CooldownNode {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        if self.is_cooling_down(context.current_time) {
            return BehaviorStatus::Failure;
        }
        let status = self.child.execute(context);
        // Only a finished child starts the cooldown; a running one does not.
        if status != BehaviorStatus::Running {
            self.last_execution = Some(context.current_time);
        }
        status
    }

    fn reset(&mut self) {
        self.child.reset();
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Leaf that succeeds iff the blackboard entry `key` equals an expected value.
///
/// A missing key counts as `Failure`. Equality uses `BehaviorValue`'s
/// `PartialEq`, so `Float` values are compared exactly.
#[derive(Debug)]
pub struct BlackboardCondition {
    key: String,
    expected_value: BehaviorValue,
    name: String,
}

impl BlackboardCondition {
    /// Creates a condition named `"Check(<key>)"`.
    pub fn new<T: Into<BehaviorValue>>(key: &str, expected_value: T) -> Self {
        Self {
            key: key.to_string(),
            // Move the converted value in directly; no clone is needed.
            expected_value: expected_value.into(),
            name: format!("Check({})", key),
        }
    }
}

impl BehaviorNode for BlackboardCondition {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        if context.get_blackboard(&self.key) == Some(&self.expected_value) {
            BehaviorStatus::Success
        } else {
            BehaviorStatus::Failure
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Leaf that returns `Running` until `duration` has elapsed, then `Success`.
///
/// The timer starts on the first tick after construction or reset and is
/// measured against `BehaviorContext::current_time`. After succeeding, the
/// action resets itself, so the next tick starts a new wait.
#[derive(Debug)]
pub struct WaitAction {
    duration: Duration,
    /// Context time of the first tick of the current wait, if started.
    start_time: Option<Instant>,
    name: String,
}

impl WaitAction {
    /// Creates a wait of `seconds`, named `"Wait(<seconds>s)"` with one decimal.
    ///
    /// # Panics
    ///
    /// Panics if `seconds` is negative, not finite, or overflows `Duration`
    /// (see `Duration::from_secs_f32`).
    pub fn new(seconds: f32) -> Self {
        Self {
            duration: Duration::from_secs_f32(seconds),
            start_time: None,
            name: format!("Wait({:.1}s)", seconds),
        }
    }
}

impl BehaviorNode for WaitAction {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        // Start the timer on the first tick; later ticks reuse the stored start.
        let start = *self.start_time.get_or_insert(context.current_time);
        if context.current_time.duration_since(start) >= self.duration {
            self.reset();
            BehaviorStatus::Success
        } else {
            BehaviorStatus::Running
        }
    }

    fn reset(&mut self) {
        self.start_time = None;
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Leaf that writes a fixed value to the blackboard and always succeeds.
#[derive(Debug)]
pub struct SetBlackboardAction {
    key: String,
    value: BehaviorValue,
    name: String,
}

impl SetBlackboardAction {
    /// Creates an action named `"Set(<key>)"` that stores `value` under `key`.
    pub fn new<T: Into<BehaviorValue>>(key: &str, value: T) -> Self {
        let name = format!("Set({})", key);
        Self {
            key: key.to_owned(),
            value: value.into(),
            name,
        }
    }
}

impl BehaviorNode for SetBlackboardAction {
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        // Clone because the action may be ticked again and must keep its value.
        let stored = self.value.clone();
        context.set_blackboard(&self.key, stored);
        BehaviorStatus::Success
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Movement action toward `target` within `tolerance` (default `1.0`).
///
/// NOTE(review): this is currently a stub. `execute` does not move anything
/// and does not use `target`'s position. It unconditionally writes
/// `movement_target_reached = true` and `movement_tolerance` to the blackboard
/// and returns `Success`. Presumably a movement system consumes these keys or
/// will replace this implementation; confirm before relying on it.
#[derive(Debug)]
pub struct MoveToAction<C> {
    target: C,
    tolerance: f32,
    name: String,
}

impl<C> MoveToAction<C> {
    /// Creates a `"MoveTo"` action with the default tolerance of `1.0`.
    pub fn new(target: C) -> Self {
        Self {
            target,
            tolerance: 1.0,
            name: String::from("MoveTo"),
        }
    }

    /// Builder-style override of the arrival tolerance.
    pub fn with_tolerance(self, tolerance: f32) -> Self {
        Self { tolerance, ..self }
    }
}

impl<C> BehaviorNode for MoveToAction<C>
where
    C: Distance + Clone + std::fmt::Debug,
{
    fn execute(&mut self, context: &mut BehaviorContext) -> BehaviorStatus {
        context.set_blackboard("movement_target_reached", true);
        context.set_blackboard("movement_tolerance", self.tolerance);
        // Placeholder read of the target until real movement is implemented.
        // It also keeps the `target` field from being reported as never read.
        let _pending_target = &self.target;
        BehaviorStatus::Success
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Fluent builder for [`BehaviorTree`].
///
/// `sequence`, `selector`, `parallel` and `root` each replace any root set by
/// an earlier call; only the last one is used.
#[derive(Debug)]
pub struct BehaviorTreeBuilder {
    root: Option<Box<dyn BehaviorNode>>,
}

impl BehaviorTreeBuilder {
    /// Creates a builder with no root.
    pub fn new() -> Self {
        Self { root: None }
    }

    /// Sets the root to a [`SequenceNode`] over `children`.
    pub fn sequence(self, children: Vec<Box<dyn BehaviorNode>>) -> Self {
        self.root(Box::new(SequenceNode::new(children)))
    }

    /// Sets the root to a [`SelectorNode`] over `children`.
    pub fn selector(self, children: Vec<Box<dyn BehaviorNode>>) -> Self {
        self.root(Box::new(SelectorNode::new(children)))
    }

    /// Sets the root to a [`ParallelNode`] over `children`.
    pub fn parallel(self, children: Vec<Box<dyn BehaviorNode>>) -> Self {
        self.root(Box::new(ParallelNode::new(children)))
    }

    /// Sets an arbitrary root node.
    pub fn root(mut self, root: Box<dyn BehaviorNode>) -> Self {
        self.root = Some(root);
        self
    }

    /// Builds a tree named `"BehaviorTree"`.
    ///
    /// # Panics
    ///
    /// Panics if no root has been set.
    pub fn build(self) -> BehaviorTree {
        self.build_named(String::from("BehaviorTree"))
    }

    /// Builds a tree with the given name.
    ///
    /// # Panics
    ///
    /// Panics with "Root node must be set before building" if no root has been set.
    pub fn build_named(self, name: String) -> BehaviorTree {
        self.root
            .map(|root| BehaviorTree::new(root, name))
            .expect("Root node must be set before building")
    }
}

impl Default for BehaviorTreeBuilder {
    /// Same as [`BehaviorTreeBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Boxes a [`SequenceNode`] over `children`.
pub fn sequence(children: Vec<Box<dyn BehaviorNode>>) -> Box<dyn BehaviorNode> {
    let node = SequenceNode::new(children);
    Box::new(node)
}

/// Boxes a [`SelectorNode`] over `children`.
pub fn selector(children: Vec<Box<dyn BehaviorNode>>) -> Box<dyn BehaviorNode> {
    let node = SelectorNode::new(children);
    Box::new(node)
}

/// Boxes a [`ParallelNode`] over `children`.
pub fn parallel(children: Vec<Box<dyn BehaviorNode>>) -> Box<dyn BehaviorNode> {
    let node = ParallelNode::new(children);
    Box::new(node)
}

/// Boxes a [`RepeatNode`] that repeats `child` `count` times.
pub fn repeat(child: Box<dyn BehaviorNode>, count: u32) -> Box<dyn BehaviorNode> {
    let node = RepeatNode::times(child, count);
    Box::new(node)
}

/// Boxes a [`RepeatNode`] that repeats `child` forever.
pub fn repeat_forever(child: Box<dyn BehaviorNode>) -> Box<dyn BehaviorNode> {
    let node = RepeatNode::infinite(child);
    Box::new(node)
}

/// Boxes an [`InvertNode`] around `child`.
pub fn invert(child: Box<dyn BehaviorNode>) -> Box<dyn BehaviorNode> {
    let node = InvertNode::new(child);
    Box::new(node)
}

/// Boxes a [`CooldownNode`] around `child` with a cooldown of `seconds`.
///
/// # Panics
///
/// Panics if `seconds` is negative, not finite, or overflows `Duration`.
pub fn cooldown(child: Box<dyn BehaviorNode>, seconds: f32) -> Box<dyn BehaviorNode> {
    let duration = Duration::from_secs_f32(seconds);
    let node = CooldownNode::new(child, duration);
    Box::new(node)
}

/// Boxes a [`WaitAction`] of `seconds`.
///
/// # Panics
///
/// Panics if `seconds` is negative, not finite, or overflows `Duration`.
pub fn wait(seconds: f32) -> Box<dyn BehaviorNode> {
    let node = WaitAction::new(seconds);
    Box::new(node)
}

/// Boxes a [`BlackboardCondition`] checking `key == expected`.
pub fn condition<T: Into<BehaviorValue>>(key: &str, expected: T) -> Box<dyn BehaviorNode> {
    let node = BlackboardCondition::new(key, expected);
    Box::new(node)
}

/// Boxes a [`SetBlackboardAction`] that stores `value` under `key`.
pub fn set_blackboard<T: Into<BehaviorValue>>(key: &str, value: T) -> Box<dyn BehaviorNode> {
    let node = SetBlackboardAction::new(key, value);
    Box::new(node)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use crate::coordinates::square::{Coordinate as SquareCoord, FourConnected};

    #[test]
    fn test_behavior_context_creation() {
        let context = BehaviorContext::new();
        assert!(context.entity_id.is_none());
        assert!(context.blackboard.is_empty());
        assert!(context.properties.is_empty());
    }

    #[test]
    fn test_behavior_context_blackboard() {
        let mut context = BehaviorContext::new();
        context.set_blackboard("health", 100);
        context.set_blackboard("position", (5, 10));
        assert_eq!(context.get_blackboard("health"), Some(&BehaviorValue::Int(100)));
        assert_eq!(
            context.get_blackboard("position"),
            Some(&BehaviorValue::Position2D { x: 5, y: 10 })
        );
        assert_eq!(context.get_blackboard("missing"), None);
    }

    #[test]
    fn test_sequence_node_success() {
        let mut sequence = SequenceNode::new(vec![
            Box::new(SetBlackboardAction::new("step1", true)),
            Box::new(SetBlackboardAction::new("step2", true)),
        ]);
        let mut context = BehaviorContext::new();
        let status = sequence.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
        assert_eq!(context.get_blackboard("step1"), Some(&BehaviorValue::Bool(true)));
        assert_eq!(context.get_blackboard("step2"), Some(&BehaviorValue::Bool(true)));
    }

    #[test]
    fn test_sequence_node_running() {
        let mut sequence = SequenceNode::new(vec![
            Box::new(SetBlackboardAction::new("step1", true)),
            Box::new(WaitAction::new(1.0)),
        ]);
        let mut context = BehaviorContext::new();
        let status = sequence.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Running);
        assert_eq!(context.get_blackboard("step1"), Some(&BehaviorValue::Bool(true)));
    }

    #[test]
    fn test_selector_node() {
        let mut selector = SelectorNode::new(vec![
            Box::new(BlackboardCondition::new("should_fail", true)),
            Box::new(SetBlackboardAction::new("executed", true)),
        ]);
        let mut context = BehaviorContext::new();
        context.set_blackboard("should_fail", false);
        let status = selector.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
        assert_eq!(context.get_blackboard("executed"), Some(&BehaviorValue::Bool(true)));
    }

    #[test]
    fn test_parallel_node() {
        let mut parallel = ParallelNode::new(vec![
            Box::new(SetBlackboardAction::new("action1", true)),
            Box::new(SetBlackboardAction::new("action2", true)),
        ]);
        let mut context = BehaviorContext::new();
        let status = parallel.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
        assert_eq!(context.get_blackboard("action1"), Some(&BehaviorValue::Bool(true)));
        assert_eq!(context.get_blackboard("action2"), Some(&BehaviorValue::Bool(true)));
    }

    #[test]
    fn test_repeat_node() {
        let mut repeat = RepeatNode::times(Box::new(SetBlackboardAction::new("counter", 1)), 3);
        let mut context = BehaviorContext::new();
        let status = repeat.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
        // The child actually ran, so its write is visible.
        assert_eq!(context.get_blackboard("counter"), Some(&BehaviorValue::Int(1)));
    }

    #[test]
    fn test_invert_node() {
        let mut invert =
            InvertNode::new(Box::new(BlackboardCondition::new("should_succeed", true)));
        let mut context = BehaviorContext::new();
        context.set_blackboard("should_succeed", false);
        let status = invert.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
    }

    #[test]
    fn test_wait_action() {
        let mut wait = WaitAction::new(0.1);
        let mut context = BehaviorContext::new();
        assert_eq!(wait.execute(&mut context), BehaviorStatus::Running);
        // Advance the context clock directly instead of sleeping: fast and deterministic.
        context.current_time += Duration::from_millis(150);
        assert_eq!(wait.execute(&mut context), BehaviorStatus::Success);
    }

    #[test]
    fn test_blackboard_condition() {
        let mut condition = BlackboardCondition::new("health_low", true);
        let mut context = BehaviorContext::new();
        assert_eq!(condition.execute(&mut context), BehaviorStatus::Failure);
        context.set_blackboard("health_low", false);
        assert_eq!(condition.execute(&mut context), BehaviorStatus::Failure);
        context.set_blackboard("health_low", true);
        assert_eq!(condition.execute(&mut context), BehaviorStatus::Success);
    }

    #[test]
    fn test_behavior_tree_builder() {
        let mut tree = BehaviorTreeBuilder::new()
            .sequence(vec![
                Box::new(SetBlackboardAction::new("step1", true)),
                Box::new(SetBlackboardAction::new("step2", true)),
            ])
            .build_named("TestTree".to_string());
        assert_eq!(tree.name(), "TestTree");
        let mut context = BehaviorContext::new();
        assert_eq!(tree.execute(&mut context), BehaviorStatus::Success);
        assert_eq!(context.get_blackboard("step2"), Some(&BehaviorValue::Bool(true)));
    }

    #[test]
    fn test_convenience_functions() {
        let mut node = sequence(vec![
            set_blackboard("init", true),
            selector(vec![condition("enemy_near", true), wait(1.0)]),
            invert(condition("health_full", false)),
        ]);
        let mut context = BehaviorContext::new();
        context.set_blackboard("enemy_near", false);
        context.set_blackboard("health_full", false);
        assert_eq!(node.name(), "Sequence");

        // Tick 1: init runs, the enemy check fails, so the selector falls through to the wait.
        assert_eq!(node.execute(&mut context), BehaviorStatus::Running);
        assert_eq!(context.get_blackboard("init"), Some(&BehaviorValue::Bool(true)));

        // Tick 2: the wait has elapsed, so the selector succeeds. The inverted
        // condition (health_full == false holds) then fails the sequence.
        context.current_time += Duration::from_millis(1100);
        assert_eq!(node.execute(&mut context), BehaviorStatus::Failure);
    }

    #[test]
    fn test_move_to_action() {
        let target = SquareCoord::<FourConnected>::new(10, 10);
        let mut move_action = MoveToAction::new(target).with_tolerance(2.0);
        let mut context = BehaviorContext::new();
        let status = move_action.execute(&mut context);
        assert_eq!(status, BehaviorStatus::Success);
        assert_eq!(
            context.get_blackboard("movement_target_reached"),
            Some(&BehaviorValue::Bool(true))
        );
        assert_eq!(
            context.get_blackboard("movement_tolerance"),
            Some(&BehaviorValue::Float(2.0))
        );
    }

    #[test]
    fn test_cooldown_node() {
        let mut cooldown = CooldownNode::new(
            Box::new(SetBlackboardAction::new("executed", true)),
            Duration::from_millis(100),
        );
        let mut context = BehaviorContext::new();
        assert_eq!(cooldown.execute(&mut context), BehaviorStatus::Success);
        assert_eq!(context.get_blackboard("executed"), Some(&BehaviorValue::Bool(true)));
        // Same instant: still cooling down.
        assert_eq!(cooldown.execute(&mut context), BehaviorStatus::Failure);
        // Advance the context clock past the cooldown instead of sleeping.
        context.current_time += Duration::from_millis(150);
        assert_eq!(cooldown.execute(&mut context), BehaviorStatus::Success);
    }
}