1use std::collections::HashMap;
22use std::time::Duration;
23
/// Result of ticking a behavior-tree node once.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// The node finished successfully.
    Success,
    /// The node finished unsuccessfully.
    Failure,
    /// The node has not finished yet and expects to be ticked again.
    Running,
}

impl Status {
    /// `true` only for [`Status::Success`].
    pub fn is_success(self) -> bool {
        self == Status::Success
    }

    /// `true` only for [`Status::Failure`].
    pub fn is_failure(self) -> bool {
        self == Status::Failure
    }

    /// `true` only for [`Status::Running`].
    pub fn is_running(self) -> bool {
        self == Status::Running
    }
}
42
/// Typed key/value store that nodes read and write while a tree is ticked.
///
/// Each value type has its own map, so one key can hold, say, a float and an
/// int at the same time. The plain getters fall back to a default (`0.0`, `0`,
/// `false`, `""`, `[0.0; 3]`) for missing keys; the `_opt` getters return
/// `None` instead so callers can tell "missing" apart from "zero".
#[derive(Debug, Clone, Default)]
pub struct Blackboard {
    floats: HashMap<String, f32>,
    ints: HashMap<String, i32>,
    bools: HashMap<String, bool>,
    strings: HashMap<String, String>,
    vecs: HashMap<String, [f32; 3]>,
}

impl Blackboard {
    /// Creates an empty blackboard.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores a float under `key`, replacing any previous float there.
    pub fn set_float(&mut self, key: &str, val: f32) {
        self.floats.insert(key.to_string(), val);
    }

    /// Float under `key`, or `0.0` when absent.
    pub fn get_float(&self, key: &str) -> f32 {
        self.get_float_opt(key).unwrap_or(0.0)
    }

    /// Float under `key`, or `None` when absent.
    pub fn get_float_opt(&self, key: &str) -> Option<f32> {
        self.floats.get(key).copied()
    }

    /// Stores an int under `key`, replacing any previous int there.
    pub fn set_int(&mut self, key: &str, val: i32) {
        self.ints.insert(key.to_string(), val);
    }

    /// Int under `key`, or `0` when absent.
    pub fn get_int(&self, key: &str) -> i32 {
        self.get_int_opt(key).unwrap_or(0)
    }

    /// Int under `key`, or `None` when absent.
    pub fn get_int_opt(&self, key: &str) -> Option<i32> {
        self.ints.get(key).copied()
    }

    /// Stores a bool under `key`, replacing any previous bool there.
    pub fn set_bool(&mut self, key: &str, val: bool) {
        self.bools.insert(key.to_string(), val);
    }

    /// Bool under `key`, or `false` when absent.
    pub fn get_bool(&self, key: &str) -> bool {
        self.get_bool_opt(key).unwrap_or(false)
    }

    /// Bool under `key`, or `None` when absent.
    pub fn get_bool_opt(&self, key: &str) -> Option<bool> {
        self.bools.get(key).copied()
    }

    /// Stores a copy of `val` under `key`, replacing any previous string there.
    pub fn set_string(&mut self, key: &str, val: &str) {
        self.strings.insert(key.to_string(), val.to_string());
    }

    /// String under `key`, or `""` when absent.
    pub fn get_string(&self, key: &str) -> &str {
        self.get_string_opt(key).unwrap_or("")
    }

    /// String under `key`, or `None` when absent.
    pub fn get_string_opt(&self, key: &str) -> Option<&str> {
        self.strings.get(key).map(String::as_str)
    }

    /// Stores a 3-component vector under `key`, replacing any previous vector there.
    pub fn set_vec3(&mut self, key: &str, val: [f32; 3]) {
        self.vecs.insert(key.to_string(), val);
    }

    /// Vector under `key`, or `[0.0; 3]` when absent.
    pub fn get_vec3(&self, key: &str) -> [f32; 3] {
        self.get_vec3_opt(key).unwrap_or([0.0; 3])
    }

    /// Vector under `key`, or `None` when absent.
    pub fn get_vec3_opt(&self, key: &str) -> Option<[f32; 3]> {
        self.vecs.get(key).copied()
    }

    /// `true` if `key` holds a value of any type.
    pub fn has(&self, key: &str) -> bool {
        self.floats.contains_key(key)
            || self.ints.contains_key(key)
            || self.bools.contains_key(key)
            || self.strings.contains_key(key)
            || self.vecs.contains_key(key)
    }

    /// Removes `key` from every typed map.
    pub fn clear_key(&mut self, key: &str) {
        self.floats.remove(key);
        self.ints.remove(key);
        self.bools.remove(key);
        self.strings.remove(key);
        self.vecs.remove(key);
    }

    /// Removes every entry of every type.
    pub fn clear(&mut self) {
        self.floats.clear();
        self.ints.clear();
        self.bools.clear();
        self.strings.clear();
        self.vecs.clear();
    }
}
104
105pub struct TickContext<'a, T> {
109 pub blackboard: &'a mut Blackboard,
110 pub entity: &'a mut T,
111 pub dt: f32,
112 pub elapsed: f32,
113 pub fired_events: Vec<String>,
114}
115
116impl<'a, T> TickContext<'a, T> {
117 pub fn new(bb: &'a mut Blackboard, entity: &'a mut T, dt: f32, elapsed: f32) -> Self {
118 Self { blackboard: bb, entity, dt, elapsed, fired_events: Vec::new() }
119 }
120
121 pub fn fire_event(&mut self, name: &str) {
122 self.fired_events.push(name.to_string());
123 }
124}
125
126pub trait Node<T>: std::fmt::Debug + Send + Sync {
130 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status;
131 fn abort(&mut self) {}
133 fn name(&self) -> &str { "Node" }
135}
136
137#[derive(Debug)]
142pub struct Sequence<T> {
143 pub name: String,
144 children: Vec<Box<dyn Node<T>>>,
145 current_idx: usize,
146}
147
148impl<T> Sequence<T> {
149 pub fn new(name: &str, children: Vec<Box<dyn Node<T>>>) -> Self {
150 Self { name: name.to_string(), children, current_idx: 0 }
151 }
152}
153
154impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Sequence<T> {
155 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
156 while self.current_idx < self.children.len() {
157 match self.children[self.current_idx].tick(ctx) {
158 Status::Success => self.current_idx += 1,
159 Status::Failure => {
160 self.current_idx = 0;
161 return Status::Failure;
162 }
163 Status::Running => return Status::Running,
164 }
165 }
166 self.current_idx = 0;
167 Status::Success
168 }
169
170 fn abort(&mut self) {
171 for child in &mut self.children {
172 child.abort();
173 }
174 self.current_idx = 0;
175 }
176
177 fn name(&self) -> &str { &self.name }
178}
179
180#[derive(Debug)]
185pub struct Selector<T> {
186 pub name: String,
187 children: Vec<Box<dyn Node<T>>>,
188 current_idx: usize,
189}
190
191impl<T> Selector<T> {
192 pub fn new(name: &str, children: Vec<Box<dyn Node<T>>>) -> Self {
193 Self { name: name.to_string(), children, current_idx: 0 }
194 }
195}
196
197impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Selector<T> {
198 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
199 while self.current_idx < self.children.len() {
200 match self.children[self.current_idx].tick(ctx) {
201 Status::Success => {
202 self.current_idx = 0;
203 return Status::Success;
204 }
205 Status::Failure => self.current_idx += 1,
206 Status::Running => return Status::Running,
207 }
208 }
209 self.current_idx = 0;
210 Status::Failure
211 }
212
213 fn abort(&mut self) {
214 for child in &mut self.children {
215 child.abort();
216 }
217 self.current_idx = 0;
218 }
219
220 fn name(&self) -> &str { &self.name }
221}
222
/// Threshold rule used by a [`Parallel`] node.
///
/// `Parallel` takes its success threshold from a `SucceedOnN` in its success
/// policy and its failure threshold from a `FailOnN` in its failure policy;
/// a variant placed in the other slot disables that outcome.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParallelPolicy {
    /// Succeed once at least this many children have succeeded.
    SucceedOnN(usize),
    /// Fail once at least this many children have failed.
    FailOnN(usize),
}
234
235#[derive(Debug)]
236pub struct Parallel<T> {
237 pub name: String,
238 children: Vec<Box<dyn Node<T>>>,
239 pub success_policy: ParallelPolicy,
240 pub failure_policy: ParallelPolicy,
241 statuses: Vec<Option<Status>>,
242}
243
244impl<T> Parallel<T> {
245 pub fn new(name: &str, children: Vec<Box<dyn Node<T>>>) -> Self {
246 let n = children.len();
247 Self {
248 name: name.to_string(),
249 children,
250 success_policy: ParallelPolicy::SucceedOnN(1),
251 failure_policy: ParallelPolicy::FailOnN(1),
252 statuses: vec![None; n],
253 }
254 }
255}
256
257impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Parallel<T> {
258 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
259 let mut successes = 0usize;
260 let mut failures = 0usize;
261
262 for (i, child) in self.children.iter_mut().enumerate() {
263 if self.statuses[i].map(|s| s != Status::Running).unwrap_or(false) {
264 match self.statuses[i] {
266 Some(Status::Success) => successes += 1,
267 Some(Status::Failure) => failures += 1,
268 _ => {}
269 }
270 continue;
271 }
272 let s = child.tick(ctx);
273 self.statuses[i] = Some(s);
274 match s {
275 Status::Success => successes += 1,
276 Status::Failure => failures += 1,
277 Status::Running => {}
278 }
279 }
280
281 let succeed_n = match self.success_policy { ParallelPolicy::SucceedOnN(n) => n, ParallelPolicy::FailOnN(_) => usize::MAX };
282 let fail_n = match self.failure_policy { ParallelPolicy::FailOnN(n) => n, ParallelPolicy::SucceedOnN(_) => usize::MAX };
283
284 if successes >= succeed_n {
285 self.statuses.iter_mut().for_each(|s| *s = None);
286 Status::Success
287 } else if failures >= fail_n {
288 self.statuses.iter_mut().for_each(|s| *s = None);
289 Status::Failure
290 } else {
291 Status::Running
292 }
293 }
294
295 fn abort(&mut self) {
296 for child in &mut self.children {
297 child.abort();
298 }
299 self.statuses.iter_mut().for_each(|s| *s = None);
300 }
301
302 fn name(&self) -> &str { &self.name }
303}
304
305#[derive(Debug)]
309pub struct RandomSelector<T> {
310 pub name: String,
311 children: Vec<Box<dyn Node<T>>>,
312 order: Vec<usize>,
313 current_idx: usize,
314 seed: u32,
315}
316
317impl<T> RandomSelector<T> {
318 pub fn new(name: &str, children: Vec<Box<dyn Node<T>>>) -> Self {
319 let n = children.len();
320 let order: Vec<usize> = (0..n).collect();
321 Self { name: name.to_string(), children, order, current_idx: 0, seed: 12345 }
322 }
323
324 fn shuffle(&mut self) {
325 let n = self.order.len();
327 for i in (1..n).rev() {
328 self.seed = self.seed.wrapping_mul(1664525).wrapping_add(1013904223);
329 let j = (self.seed as usize) % (i + 1);
330 self.order.swap(i, j);
331 }
332 }
333}
334
335impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for RandomSelector<T> {
336 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
337 if self.current_idx == 0 {
338 self.shuffle();
339 }
340 while self.current_idx < self.order.len() {
341 let child_idx = self.order[self.current_idx];
342 match self.children[child_idx].tick(ctx) {
343 Status::Success => {
344 self.current_idx = 0;
345 return Status::Success;
346 }
347 Status::Failure => self.current_idx += 1,
348 Status::Running => return Status::Running,
349 }
350 }
351 self.current_idx = 0;
352 Status::Failure
353 }
354
355 fn abort(&mut self) {
356 for child in &mut self.children {
357 child.abort();
358 }
359 self.current_idx = 0;
360 }
361
362 fn name(&self) -> &str { &self.name }
363}
364
365#[derive(Debug)]
369pub struct Inverter<T> { pub child: Box<dyn Node<T>> }
370
371impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Inverter<T> {
372 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
373 match self.child.tick(ctx) {
374 Status::Success => Status::Failure,
375 Status::Failure => Status::Success,
376 Status::Running => Status::Running,
377 }
378 }
379 fn abort(&mut self) { self.child.abort(); }
380 fn name(&self) -> &str { "Inverter" }
381}
382
383#[derive(Debug)]
385pub struct Succeeder<T> { pub child: Box<dyn Node<T>> }
386
387impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Succeeder<T> {
388 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
389 self.child.tick(ctx);
390 Status::Success
391 }
392 fn abort(&mut self) { self.child.abort(); }
393 fn name(&self) -> &str { "Succeeder" }
394}
395
396#[derive(Debug)]
398pub struct Failer<T> { pub child: Box<dyn Node<T>> }
399
400impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Failer<T> {
401 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
402 self.child.tick(ctx);
403 Status::Failure
404 }
405 fn abort(&mut self) { self.child.abort(); }
406 fn name(&self) -> &str { "Failer" }
407}
408
409#[derive(Debug)]
411pub struct Repeater<T> {
412 pub child: Box<dyn Node<T>>,
413 pub max_repeats: Option<u32>,
414 pub stop_on_fail: bool,
415 count: u32,
416}
417
418impl<T> Repeater<T> {
419 pub fn infinite(child: Box<dyn Node<T>>) -> Self {
420 Self { child, max_repeats: None, stop_on_fail: false, count: 0 }
421 }
422
423 pub fn n_times(child: Box<dyn Node<T>>, n: u32) -> Self {
424 Self { child, max_repeats: Some(n), stop_on_fail: false, count: 0 }
425 }
426}
427
428impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Repeater<T> {
429 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
430 loop {
431 if let Some(max) = self.max_repeats {
432 if self.count >= max {
433 self.count = 0;
434 return Status::Success;
435 }
436 }
437 match self.child.tick(ctx) {
438 Status::Running => return Status::Running,
439 Status::Failure if self.stop_on_fail => {
440 self.count = 0;
441 return Status::Failure;
442 }
443 _ => {
444 self.count += 1;
445 self.child.abort();
446 }
447 }
448 }
449 }
450 fn abort(&mut self) { self.child.abort(); self.count = 0; }
451 fn name(&self) -> &str { "Repeater" }
452}
453
454#[derive(Debug)]
456pub struct RetryUntilFail<T> {
457 pub child: Box<dyn Node<T>>,
458 count: u32,
459}
460
461impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for RetryUntilFail<T> {
462 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
463 loop {
464 match self.child.tick(ctx) {
465 Status::Failure => { self.count = 0; return Status::Success; }
466 Status::Running => return Status::Running,
467 Status::Success => { self.count += 1; self.child.abort(); }
468 }
469 }
470 }
471 fn abort(&mut self) { self.child.abort(); self.count = 0; }
472 fn name(&self) -> &str { "RetryUntilFail" }
473}
474
475#[derive(Debug)]
477pub struct TimeLimit<T> {
478 pub child: Box<dyn Node<T>>,
479 pub limit_secs: f32,
480 pub elapsed: f32,
481 pub fail_on_limit: bool,
482}
483
484impl<T> TimeLimit<T> {
485 pub fn new(child: Box<dyn Node<T>>, limit_secs: f32) -> Self {
486 Self { child, limit_secs, elapsed: 0.0, fail_on_limit: true }
487 }
488}
489
490impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for TimeLimit<T> {
491 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
492 self.elapsed += ctx.dt;
493 if self.elapsed >= self.limit_secs {
494 self.child.abort();
495 self.elapsed = 0.0;
496 return if self.fail_on_limit { Status::Failure } else { Status::Success };
497 }
498 let s = self.child.tick(ctx);
499 if s != Status::Running { self.elapsed = 0.0; }
500 s
501 }
502 fn abort(&mut self) { self.child.abort(); self.elapsed = 0.0; }
503 fn name(&self) -> &str { "TimeLimit" }
504}
505
506#[derive(Debug)]
508pub struct Cooldown<T> {
509 pub child: Box<dyn Node<T>>,
510 pub cooldown_secs: f32,
511 remaining: f32,
512 pub fail_during: bool,
513}
514
515impl<T> Cooldown<T> {
516 pub fn new(child: Box<dyn Node<T>>, cooldown_secs: f32) -> Self {
517 Self { child, cooldown_secs, remaining: 0.0, fail_during: true }
518 }
519}
520
521impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Cooldown<T> {
522 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
523 self.remaining = (self.remaining - ctx.dt).max(0.0);
524 if self.remaining > 0.0 {
525 return if self.fail_during { Status::Failure } else { Status::Running };
526 }
527 let s = self.child.tick(ctx);
528 if s == Status::Success {
529 self.remaining = self.cooldown_secs;
530 }
531 s
532 }
533 fn abort(&mut self) { self.child.abort(); }
534 fn name(&self) -> &str { "Cooldown" }
535}
536
537pub struct Guard<T> {
539 pub name: String,
540 pub condition: Box<dyn Fn(&Blackboard) -> bool + Send + Sync>,
541 pub child: Box<dyn Node<T>>,
542}
543
544impl<T: std::fmt::Debug> std::fmt::Debug for Guard<T> {
545 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
546 f.debug_struct("Guard").field("name", &self.name).finish()
547 }
548}
549
550impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Guard<T> {
551 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
552 if (self.condition)(ctx.blackboard) {
553 self.child.tick(ctx)
554 } else {
555 self.child.abort();
556 Status::Failure
557 }
558 }
559 fn abort(&mut self) { self.child.abort(); }
560 fn name(&self) -> &str { &self.name }
561}
562
563pub struct ConditionNode<T> {
567 pub name: String,
568 pub func: Box<dyn Fn(&TickContext<T>) -> bool + Send + Sync>,
569}
570
571impl<T> std::fmt::Debug for ConditionNode<T> {
572 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
573 write!(f, "ConditionNode({})", self.name)
574 }
575}
576
577impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for ConditionNode<T> {
578 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
579 if (self.func)(ctx) { Status::Success } else { Status::Failure }
580 }
581 fn name(&self) -> &str { &self.name }
582}
583
584pub struct ActionNode<T> {
588 pub name: String,
589 pub func: Box<dyn FnMut(&mut TickContext<T>) -> Status + Send + Sync>,
590}
591
592impl<T> std::fmt::Debug for ActionNode<T> {
593 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
594 write!(f, "ActionNode({})", self.name)
595 }
596}
597
598impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for ActionNode<T> {
599 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
600 (self.func)(ctx)
601 }
602 fn name(&self) -> &str { &self.name }
603}
604
605#[derive(Debug)]
609pub struct Wait<T: std::fmt::Debug> {
610 pub duration: f32,
611 elapsed: f32,
612 _phantom: std::marker::PhantomData<T>,
613}
614
615impl<T: std::fmt::Debug> Wait<T> {
616 pub fn new(duration: f32) -> Self {
617 Self { duration, elapsed: 0.0, _phantom: std::marker::PhantomData }
618 }
619}
620
621impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for Wait<T> {
622 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
623 self.elapsed += ctx.dt;
624 if self.elapsed >= self.duration {
625 self.elapsed = 0.0;
626 Status::Success
627 } else {
628 Status::Running
629 }
630 }
631 fn abort(&mut self) { self.elapsed = 0.0; }
632 fn name(&self) -> &str { "Wait" }
633}
634
635#[derive(Debug)]
637pub struct WaitRandom<T: std::fmt::Debug> {
638 pub min_secs: f32,
639 pub max_secs: f32,
640 elapsed: f32,
641 target: f32,
642 seed: u32,
643 _phantom: std::marker::PhantomData<T>,
644}
645
646impl<T: std::fmt::Debug> WaitRandom<T> {
647 pub fn new(min_secs: f32, max_secs: f32) -> Self {
648 let mut s = Self {
649 min_secs, max_secs, elapsed: 0.0, target: 0.0, seed: 54321,
650 _phantom: std::marker::PhantomData,
651 };
652 s.reset_target();
653 s
654 }
655
656 fn reset_target(&mut self) {
657 self.seed = self.seed.wrapping_mul(1664525_u32).wrapping_add(1013904223_u32);
658 let t01 = (self.seed >> 16) as f32 / u16::MAX as f32;
659 self.target = self.min_secs + t01 * (self.max_secs - self.min_secs);
660 }
661}
662
663impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for WaitRandom<T> {
664 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
665 self.elapsed += ctx.dt;
666 if self.elapsed >= self.target {
667 self.elapsed = 0.0;
668 self.reset_target();
669 Status::Success
670 } else {
671 Status::Running
672 }
673 }
674 fn abort(&mut self) { self.elapsed = 0.0; self.reset_target(); }
675 fn name(&self) -> &str { "WaitRandom" }
676}
677
678#[derive(Debug)]
682pub struct SetFloat<T: std::fmt::Debug> {
683 pub key: String,
684 pub value: f32,
685 _phantom: std::marker::PhantomData<T>,
686}
687
688impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for SetFloat<T> {
689 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
690 ctx.blackboard.set_float(&self.key, self.value);
691 Status::Success
692 }
693 fn name(&self) -> &str { "SetFloat" }
694}
695
696#[derive(Debug)]
698pub struct SetBool<T: std::fmt::Debug> {
699 pub key: String,
700 pub value: bool,
701 _phantom: std::marker::PhantomData<T>,
702}
703
704impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for SetBool<T> {
705 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
706 ctx.blackboard.set_bool(&self.key, self.value);
707 Status::Success
708 }
709 fn name(&self) -> &str { "SetBool" }
710}
711
712#[derive(Debug)]
714pub struct BlackboardHas<T: std::fmt::Debug> {
715 pub key: String,
716 _phantom: std::marker::PhantomData<T>,
717}
718
719impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for BlackboardHas<T> {
720 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
721 if ctx.blackboard.has(&self.key) { Status::Success } else { Status::Failure }
722 }
723 fn name(&self) -> &str { "BlackboardHas" }
724}
725
/// Comparison applied by `CheckFloat` as `value <op> threshold`.
///
/// `Equal` / `NotEqual` use an absolute tolerance of `1e-5` in `CheckFloat`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FloatComparison {
    /// `value > threshold`
    Greater,
    /// `value >= threshold`
    GreaterEq,
    /// `value < threshold`
    Less,
    /// `value <= threshold`
    LessEq,
    /// `value` within tolerance of `threshold`
    Equal,
    /// `value` outside tolerance of `threshold`
    NotEqual,
}
729
730#[derive(Debug)]
731pub struct CheckFloat<T: std::fmt::Debug> {
732 pub key: String,
733 pub threshold: f32,
734 pub op: FloatComparison,
735 _phantom: std::marker::PhantomData<T>,
736}
737
738impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for CheckFloat<T> {
739 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
740 let v = ctx.blackboard.get_float(&self.key);
741 let pass = match self.op {
742 FloatComparison::Greater => v > self.threshold,
743 FloatComparison::GreaterEq => v >= self.threshold,
744 FloatComparison::Less => v < self.threshold,
745 FloatComparison::LessEq => v <= self.threshold,
746 FloatComparison::Equal => (v - self.threshold).abs() < 1e-5,
747 FloatComparison::NotEqual => (v - self.threshold).abs() >= 1e-5,
748 };
749 if pass { Status::Success } else { Status::Failure }
750 }
751 fn name(&self) -> &str { "CheckFloat" }
752}
753
754#[derive(Debug)]
756pub struct FireEvent<T: std::fmt::Debug> {
757 pub event_name: String,
758 _phantom: std::marker::PhantomData<T>,
759}
760
761impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for FireEvent<T> {
762 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
763 ctx.fire_event(&self.event_name);
764 Status::Success
765 }
766 fn name(&self) -> &str { "FireEvent" }
767}
768
769pub struct BehaviorTree<T> {
773 pub name: String,
774 root: Box<dyn Node<T>>,
775}
776
777impl<T: std::fmt::Debug + Send + Sync + 'static> BehaviorTree<T> {
778 pub fn new(name: &str, root: Box<dyn Node<T>>) -> Self {
779 Self { name: name.to_string(), root }
780 }
781
782 pub fn tick(&mut self, blackboard: &mut Blackboard, entity: &mut T, dt: f32, elapsed: f32) -> Status {
784 let mut ctx = TickContext::new(blackboard, entity, dt, elapsed);
785 self.root.tick(&mut ctx)
786 }
787
788 pub fn abort(&mut self) {
790 self.root.abort();
791 }
792}
793
794impl<T> std::fmt::Debug for BehaviorTree<T> {
795 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
796 write!(f, "BehaviorTree({})", self.name)
797 }
798}
799
800#[derive(Debug)]
806pub struct SubTree<T> {
807 pub name: String,
808 pub root: Box<dyn Node<T>>,
809}
810
811impl<T: std::fmt::Debug + Send + Sync + 'static> Node<T> for SubTree<T> {
812 fn tick(&mut self, ctx: &mut TickContext<T>) -> Status {
813 self.root.tick(ctx)
814 }
815 fn abort(&mut self) { self.root.abort(); }
816 fn name(&self) -> &str { &self.name }
817}
818
819pub struct TreeRunner<T> {
823 pub tree: BehaviorTree<T>,
824 pub bb: Blackboard,
825 elapsed: f32,
826 last_status: Status,
827 event_queue: Vec<String>,
828}
829
830impl<T: std::fmt::Debug + Send + Sync + 'static> TreeRunner<T> {
831 pub fn new(tree: BehaviorTree<T>) -> Self {
832 Self {
833 tree,
834 bb: Blackboard::new(),
835 elapsed: 0.0,
836 last_status: Status::Running,
837 event_queue: Vec::new(),
838 }
839 }
840
841 pub fn update(&mut self, entity: &mut T, dt: f32) -> Status {
843 self.elapsed += dt;
844 let mut ctx = TickContext::new(&mut self.bb, entity, dt, self.elapsed);
845 let status = self.tree.root.tick(&mut ctx);
846 self.event_queue.extend(ctx.fired_events);
847 self.last_status = status;
848 status
849 }
850
851 pub fn drain_events(&mut self) -> Vec<String> {
853 std::mem::take(&mut self.event_queue)
854 }
855
856 pub fn last_status(&self) -> Status { self.last_status }
857 pub fn elapsed(&self) -> f32 { self.elapsed }
858
859 pub fn reset(&mut self) {
861 self.tree.abort();
862 self.elapsed = 0.0;
863 self.last_status = Status::Running;
864 self.event_queue.clear();
865 }
866}
867
/// Minimal game entity driven by the prefab trees in `CommonBehaviors`.
#[derive(Debug, Default, Clone)]
pub struct AiEntity {
    pub position: [f32; 3],
    pub health: f32,
    pub is_dead: bool,
    pub target_id: Option<u32>,
}

impl AiEntity {
    /// A living entity at the origin with `health` and no target.
    pub fn new(health: f32) -> Self {
        AiEntity {
            health,
            ..Default::default()
        }
    }
}
884
885pub struct CommonBehaviors;
887
888impl CommonBehaviors {
889 pub fn combat_ai() -> BehaviorTree<AiEntity> {
891 let is_alive = Box::new(ConditionNode {
893 name: "IsAlive".to_string(),
894 func: Box::new(|ctx: &TickContext<AiEntity>| !ctx.entity.is_dead),
895 });
896
897 let has_target = Box::new(ConditionNode {
899 name: "HasTarget".to_string(),
900 func: Box::new(|ctx: &TickContext<AiEntity>| ctx.entity.target_id.is_some()),
901 });
902
903 let in_attack_range = Box::new(CheckFloat::<AiEntity> {
905 key: "target_dist".to_string(),
906 threshold: 2.0,
907 op: FloatComparison::Less,
908 _phantom: std::marker::PhantomData,
909 });
910
911 let attack = Box::new(ActionNode {
913 name: "Attack".to_string(),
914 func: Box::new(|ctx: &mut TickContext<AiEntity>| {
915 ctx.fire_event("attack");
916 ctx.blackboard.set_float("attack_timer",
917 ctx.blackboard.get_float("attack_timer") + ctx.dt);
918 if ctx.blackboard.get_float("attack_timer") >= 0.5 {
919 ctx.blackboard.set_float("attack_timer", 0.0);
920 Status::Success
921 } else {
922 Status::Running
923 }
924 }),
925 });
926
927 let chase = Box::new(ActionNode {
929 name: "Chase".to_string(),
930 func: Box::new(|ctx: &mut TickContext<AiEntity>| {
931 let target_pos = ctx.blackboard.get_vec3("target_pos");
932 let pos = &mut ctx.entity.position;
933 let dx = target_pos[0] - pos[0];
934 let dz = target_pos[2] - pos[2];
935 let len = (dx * dx + dz * dz).sqrt().max(1e-6);
936 let speed = 3.0 * ctx.dt;
937 pos[0] += dx / len * speed;
938 pos[2] += dz / len * speed;
939 let dist = (dx * dx + dz * dz).sqrt();
940 ctx.blackboard.set_float("target_dist", dist);
941 Status::Running
942 }),
943 });
944
945 let idle = Box::new(Wait::<AiEntity>::new(2.0));
947
948 let attack_seq = Box::new(Sequence::new("attack_seq", vec![in_attack_range, attack]));
950 let combat = Box::new(Selector::new("combat", vec![attack_seq, chase]));
952 let combat_seq = Box::new(Sequence::new("combat_seq", vec![has_target, combat]));
954 let root = Box::new(Selector::new("root", vec![
956 Box::new(Sequence::new("alive_gate", vec![is_alive, Box::new(Selector::new("ai", vec![combat_seq, idle]))])),
957 ]));
958
959 BehaviorTree::new("combat_ai", root)
960 }
961
962 pub fn patrol_ai(waypoints: Vec<[f32; 3]>) -> BehaviorTree<AiEntity> {
964 let wp_count = waypoints.len().max(1);
965 let waypoints_data = waypoints;
966
967 let patrol = Box::new(ActionNode {
968 name: "Patrol".to_string(),
969 func: Box::new(move |ctx: &mut TickContext<AiEntity>| {
970 let wp_idx = ctx.blackboard.get_int("wp_idx") as usize % wp_count;
971 let target = waypoints_data[wp_idx];
972 let pos = &mut ctx.entity.position;
973 let dx = target[0] - pos[0];
974 let dz = target[2] - pos[2];
975 let dist = (dx * dx + dz * dz).sqrt();
976 if dist < 0.3 {
977 ctx.blackboard.set_int("wp_idx", (wp_idx as i32 + 1) % wp_count as i32);
978 return Status::Success;
979 }
980 let speed = 2.0 * ctx.dt;
981 pos[0] += dx / dist * speed;
982 pos[2] += dz / dist * speed;
983 Status::Running
984 }),
985 });
986
987 let root = Box::new(Repeater::infinite(patrol));
988 BehaviorTree::new("patrol", root)
989 }
990
991 pub fn flee_ai(flee_threshold: f32) -> BehaviorTree<AiEntity> {
993 let health_low = Box::new(ConditionNode {
994 name: "HealthLow".to_string(),
995 func: Box::new(move |ctx: &TickContext<AiEntity>| {
996 ctx.entity.health < flee_threshold
997 }),
998 });
999
1000 let flee = Box::new(ActionNode {
1001 name: "Flee".to_string(),
1002 func: Box::new(|ctx: &mut TickContext<AiEntity>| {
1003 let threat = ctx.blackboard.get_vec3("threat_pos");
1004 let pos = &mut ctx.entity.position;
1005 let dx = pos[0] - threat[0];
1006 let dz = pos[2] - threat[2];
1007 let len = (dx * dx + dz * dz).sqrt().max(1e-6);
1008 pos[0] += dx / len * 5.0 * ctx.dt;
1009 pos[2] += dz / len * 5.0 * ctx.dt;
1010 ctx.fire_event("fleeing");
1011 Status::Running
1012 }),
1013 });
1014
1015 let idle = Box::new(Wait::<AiEntity>::new(1.0));
1016 let flee_seq = Box::new(Sequence::new("flee_seq", vec![health_low, flee]));
1017 let root = Box::new(Selector::new("root", vec![flee_seq, idle]));
1018 BehaviorTree::new("flee", root)
1019 }
1020
1021 pub fn guard_post(post: [f32; 3], alert_radius: f32) -> BehaviorTree<AiEntity> {
1023 let intruder_near = Box::new(CheckFloat::<AiEntity> {
1024 key: "target_dist".to_string(),
1025 threshold: alert_radius,
1026 op: FloatComparison::Less,
1027 _phantom: std::marker::PhantomData,
1028 });
1029
1030 let alert = Box::new(ActionNode {
1031 name: "Alert".to_string(),
1032 func: Box::new(|ctx: &mut TickContext<AiEntity>| {
1033 ctx.fire_event("intruder_detected");
1034 ctx.blackboard.set_bool("alerted", true);
1035 Status::Success
1036 }),
1037 });
1038
1039 let return_to_post = Box::new(ActionNode {
1040 name: "ReturnToPost".to_string(),
1041 func: Box::new(move |ctx: &mut TickContext<AiEntity>| {
1042 let pos = &mut ctx.entity.position;
1043 let dx = post[0] - pos[0];
1044 let dz = post[2] - pos[2];
1045 let dist = (dx * dx + dz * dz).sqrt();
1046 if dist < 0.2 { return Status::Success; }
1047 let speed = 2.0 * ctx.dt;
1048 pos[0] += dx / dist.max(1e-6) * speed;
1049 pos[2] += dz / dist.max(1e-6) * speed;
1050 Status::Running
1051 }),
1052 });
1053
1054 let idle_anim = Box::new(ActionNode {
1055 name: "IdleAnim".to_string(),
1056 func: Box::new(|ctx: &mut TickContext<AiEntity>| {
1057 ctx.blackboard.set_string("anim", "idle");
1058 Status::Success
1059 }),
1060 });
1061
1062 let alert_seq = Box::new(Sequence::new("alert_seq", vec![intruder_near, alert]));
1063 let idle_patrol = Box::new(Sequence::new("idle_patrol", vec![return_to_post, idle_anim]));
1064 let root = Box::new(Selector::new("root", vec![alert_seq, idle_patrol]));
1065 BehaviorTree::new("guard_post", root)
1066 }
1067}
1068
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    fn make_entity() -> AiEntity {
        AiEntity::new(100.0)
    }

    /// Boxes an action that returns `status` on every tick.
    fn constant(name: &str, status: Status) -> Box<dyn Node<AiEntity>> {
        Box::new(ActionNode {
            name: name.to_string(),
            func: Box::new(move |_: &mut TickContext<AiEntity>| status),
        })
    }

    #[test]
    fn test_sequence_all_success() {
        let s1: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "s1".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let s2: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "s2".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let mut seq = Sequence::new("test", vec![s1, s2]);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(seq.tick(&mut ctx), Status::Success);
    }

    #[test]
    fn test_sequence_early_failure() {
        let f: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "f".to_string(),
            func: Box::new(|_| Status::Failure),
        });
        let s: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "s".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let mut seq = Sequence::new("test", vec![f, s]);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(seq.tick(&mut ctx), Status::Failure);
    }

    #[test]
    fn test_sequence_resumes_running_child() {
        let first_calls = Arc::new(AtomicUsize::new(0));
        let first_calls_in_node = Arc::clone(&first_calls);
        let first: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "first".to_string(),
            func: Box::new(move |_: &mut TickContext<AiEntity>| {
                first_calls_in_node.fetch_add(1, Ordering::SeqCst);
                Status::Success
            }),
        });
        let mut second_calls = 0u32;
        let second: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "second".to_string(),
            func: Box::new(move |_: &mut TickContext<AiEntity>| {
                second_calls += 1;
                if second_calls == 1 {
                    Status::Running
                } else {
                    Status::Success
                }
            }),
        });
        let mut seq = Sequence::new("test", vec![first, second]);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(seq.tick(&mut ctx), Status::Running);
        assert_eq!(seq.tick(&mut ctx), Status::Success);
        // The resuming tick starts at the running child, not at the first one.
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_selector_first_success() {
        let f: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "f".to_string(),
            func: Box::new(|_| Status::Failure),
        });
        let s: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "s".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let mut sel = Selector::new("test", vec![f, s]);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(sel.tick(&mut ctx), Status::Success);
    }

    #[test]
    fn test_selector_all_fail() {
        let mut sel = Selector::new(
            "test",
            vec![constant("f1", Status::Failure), constant("f2", Status::Failure)],
        );
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(sel.tick(&mut ctx), Status::Failure);
    }

    #[test]
    fn test_inverter() {
        let f: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "f".to_string(),
            func: Box::new(|_| Status::Failure),
        });
        let mut inv = Inverter { child: f };
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(inv.tick(&mut ctx), Status::Success);
    }

    #[test]
    fn test_wait_completes() {
        let mut w = Wait::<AiEntity>::new(0.1);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.05, 0.05);
        assert_eq!(w.tick(&mut ctx), Status::Running);
        let mut ctx2 = TickContext::new(&mut bb, &mut e, 0.06, 0.11);
        assert_eq!(w.tick(&mut ctx2), Status::Success);
    }

    #[test]
    fn test_repeater_n_times() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_node = Arc::clone(&calls);
        let inner: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "inner".to_string(),
            func: Box::new(move |_: &mut TickContext<AiEntity>| {
                calls_in_node.fetch_add(1, Ordering::SeqCst);
                Status::Success
            }),
        });
        let mut rep = Repeater::n_times(inner, 3);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(rep.tick(&mut ctx), Status::Success);
        // The child must have run exactly `n` times.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_time_limit_expires() {
        let mut limited = TimeLimit::new(constant("forever", Status::Running), 0.1);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.06, 0.0);
        assert_eq!(limited.tick(&mut ctx), Status::Running);
        assert_eq!(limited.tick(&mut ctx), Status::Failure);
    }

    #[test]
    fn test_cooldown() {
        let inner: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "inner".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let mut cd = Cooldown::new(inner, 1.0);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        {
            let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.016);
            assert_eq!(cd.tick(&mut ctx), Status::Success);
        }
        {
            let mut ctx2 = TickContext::new(&mut bb, &mut e, 0.016, 0.032);
            assert_eq!(cd.tick(&mut ctx2), Status::Failure);
        }
    }

    #[test]
    fn test_blackboard_set_check() {
        let mut bb = Blackboard::new();
        bb.set_float("hp", 50.0);
        assert_eq!(bb.get_float("hp"), 50.0);
        bb.set_bool("alive", true);
        assert!(bb.get_bool("alive"));
        bb.set_string("state", "patrol");
        assert_eq!(bb.get_string("state"), "patrol");
    }

    #[test]
    fn test_blackboard_clear_key_removes_every_type() {
        let mut bb = Blackboard::new();
        bb.set_float("x", 1.0);
        bb.set_int("x", 2);
        assert!(bb.has("x"));
        bb.clear_key("x");
        assert!(!bb.has("x"));
        assert_eq!(bb.get_int("x"), 0);
    }

    #[test]
    fn test_check_float() {
        let mut chk = CheckFloat::<AiEntity> {
            key: "hp".to_string(),
            threshold: 50.0,
            op: FloatComparison::Less,
            _phantom: std::marker::PhantomData,
        };
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        bb.set_float("hp", 30.0);
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(chk.tick(&mut ctx), Status::Success);
        bb.set_float("hp", 80.0);
        let mut ctx2 = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(chk.tick(&mut ctx2), Status::Failure);
    }

    #[test]
    fn test_combat_ai_no_crash() {
        let mut tree = CommonBehaviors::combat_ai();
        let mut bb = Blackboard::new();
        let mut e = AiEntity::new(100.0);
        e.target_id = Some(1);
        bb.set_float("target_dist", 5.0);
        bb.set_vec3("target_pos", [10.0, 0.0, 10.0]);
        for _ in 0..10 {
            tree.tick(&mut bb, &mut e, 0.016, 0.0);
        }
    }

    #[test]
    fn test_patrol_ai_moves() {
        let waypoints = vec![[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]];
        let mut tree = CommonBehaviors::patrol_ai(waypoints);
        let mut bb = Blackboard::new();
        let mut e = AiEntity::new(100.0);
        for _ in 0..100 {
            tree.tick(&mut bb, &mut e, 0.016, 0.0);
        }
        assert!(e.position[0] > 0.1 || e.position[2] != 0.0);
    }

    #[test]
    fn test_fire_event() {
        let mut ev = FireEvent::<AiEntity> {
            event_name: "test_event".to_string(),
            _phantom: std::marker::PhantomData,
        };
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        ev.tick(&mut ctx);
        assert!(ctx.fired_events.contains(&"test_event".to_string()));
    }

    #[test]
    fn test_tree_runner_events() {
        let ev_node: Box<dyn Node<AiEntity>> = Box::new(FireEvent {
            event_name: "tick_event".to_string(),
            _phantom: std::marker::PhantomData,
        });
        let tree = BehaviorTree::new("ev_tree", ev_node);
        let mut runner = TreeRunner::new(tree);
        let mut e = make_entity();
        runner.update(&mut e, 0.016);
        let events = runner.drain_events();
        assert!(events.contains(&"tick_event".to_string()));
        assert!(runner.drain_events().is_empty());
    }

    #[test]
    fn test_parallel_succeed_on_one() {
        let s: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "s".to_string(),
            func: Box::new(|_| Status::Success),
        });
        let r: Box<dyn Node<AiEntity>> = Box::new(ActionNode {
            name: "r".to_string(),
            func: Box::new(|_| Status::Running),
        });
        let mut par = Parallel::new("test", vec![s, r]);
        par.success_policy = ParallelPolicy::SucceedOnN(1);
        par.failure_policy = ParallelPolicy::FailOnN(2);
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(par.tick(&mut ctx), Status::Success);
    }

    #[test]
    fn test_parallel_fail_on_one() {
        let mut par = Parallel::new(
            "test",
            vec![constant("f", Status::Failure), constant("r", Status::Running)],
        );
        let mut bb = Blackboard::new();
        let mut e = make_entity();
        let mut ctx = TickContext::new(&mut bb, &mut e, 0.016, 0.0);
        assert_eq!(par.tick(&mut ctx), Status::Failure);
    }
}