1use std::collections::HashMap;
2use std::time::Instant;
3
4use crate::context::TreeContext;
5use crate::error::{BonsaiError, Result};
6use crate::nodes::{Guard, Node, NodeResult, NodeState, NodeType};
7use crate::parser::parse_mdsl;
8
9pub struct BehaviorTree {
11 root: Node,
12 _named_trees: HashMap<String, Node>,
13 start_time: Option<Instant>,
14}
15
16impl BehaviorTree {
17 pub fn from_mdsl(mdsl: &str) -> Result<Self> {
19 let root = parse_mdsl(mdsl)?;
20 Ok(Self {
21 root,
22 _named_trees: HashMap::new(),
23 start_time: None,
24 })
25 }
26
27 pub fn from_node(root: Node) -> Self {
29 Self {
30 root,
31 _named_trees: HashMap::new(),
32 start_time: None,
33 }
34 }
35
36 pub fn tick(&mut self, context: &TreeContext) -> Result<NodeResult> {
38 self.tick_with_delta(context, 0.0)
39 }
40
41 pub fn tick_with_delta(
43 &mut self,
44 context: &TreeContext,
45 delta_time: f64,
46 ) -> Result<NodeResult> {
47 if self.start_time.is_none() {
48 self.start_time = Some(Instant::now());
49 }
50
51 let result = execute_node(&mut self.root, context, delta_time)?;
52
53 Ok(result)
59 }
60
61 pub fn reset(&mut self) {
63 self.root.reset();
64 self.start_time = None;
65 }
66
67 pub fn get_state(&self) -> NodeState {
69 self.root.state
70 }
71
72 pub fn is_running(&self) -> bool {
74 matches!(self.root.state, NodeState::Running)
75 }
76}
77
78fn execute_node(node: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
80 if let Some(guard) = &node.guard {
82 if !evaluate_guard(guard, context)? {
83 return Ok(NodeResult::Failure);
84 }
85 }
86
87 if node.state == NodeState::Ready {
89 if let Some(callbacks) = &node.callbacks {
90 if let Some((callback_name, args)) = &callbacks.entry {
91 context.execute_callback(callback_name, args)?;
92 }
93 }
94 }
95
96 if let Some(callbacks) = &node.callbacks {
98 if let Some((callback_name, args)) = &callbacks.step {
99 context.execute_callback(callback_name, args)?;
100 }
101 }
102
103 let result = match &mut node.node_type {
105 NodeType::Root { child } => execute_node(child, context, delta_time)?,
106 NodeType::Sequence { children } => execute_sequence(children, context, delta_time)?,
107 NodeType::Selector { children } => execute_selector(children, context, delta_time)?,
108 NodeType::Parallel { children } => execute_parallel(children, context, delta_time)?,
109 NodeType::Race { children } => execute_race(children, context, delta_time)?,
110 NodeType::All { children } => execute_all(children, context, delta_time)?,
111 NodeType::Lotto { children, weights } => {
112 execute_lotto(children, weights.as_ref(), context, delta_time)?
113 }
114 NodeType::While {
115 condition,
116 args,
117 children,
118 } => execute_while(condition, args, children, context, delta_time)?,
119 NodeType::Until {
120 condition,
121 args,
122 children,
123 } => execute_until(condition, args, children, context, delta_time)?,
124 NodeType::WhileAll {
125 condition,
126 args,
127 children,
128 } => execute_while_all(condition, args, children, context, delta_time)?,
129 NodeType::Repeat { child, iterations } => {
130 execute_repeat(child, *iterations, context, delta_time)?
131 }
132 NodeType::Retry { child, attempts } => {
133 execute_retry(child, *attempts, context, delta_time)?
134 }
135 NodeType::Flip { child } => execute_flip(child, context, delta_time)?,
136 NodeType::Succeed { child } => execute_succeed(child, context, delta_time)?,
137 NodeType::Fail { child } => execute_fail(child, context, delta_time)?,
138 NodeType::Action { name, args } => context.execute_action(name, args)?,
139 NodeType::Condition { name, args } => {
140 if context.evaluate_condition(name, args)? {
141 NodeResult::Success
142 } else {
143 NodeResult::Failure
144 }
145 }
146 NodeType::Wait { duration } => {
147 if let Some(duration_ms) = duration {
148 node.elapsed_time += delta_time * 1000.0; if node.elapsed_time >= *duration_ms as f64 {
152 node.elapsed_time = 0.0;
154 NodeResult::Success
155 } else {
156 NodeResult::Running
158 }
159 } else {
160 NodeResult::Running
162 }
163 }
164 NodeType::Branch {
165 reference: _reference,
166 } => {
167 return Err(BonsaiError::NodeExecutionError(
170 "Branch nodes are not yet implemented - this is a planned feature for referencing named trees".to_string(),
171 ));
172 }
173 };
174
175 node.state = result.into();
177
178 if matches!(result, NodeResult::Success | NodeResult::Failure) {
180 if let Some(callbacks) = &node.callbacks {
181 if let Some((callback_name, args)) = &callbacks.exit {
182 context.execute_callback(callback_name, args)?;
183 }
184 }
185 }
186
187 Ok(result)
188}
189
190fn execute_sequence(
191 children: &mut [Node],
192 context: &TreeContext,
193 delta_time: f64,
194) -> Result<NodeResult> {
195 let mut processed_wait_in_this_tick = false;
196
197 for child in children {
198 if child.state == NodeState::Success {
200 continue;
201 }
202
203 let is_wait_node = matches!(child.node_type, NodeType::Wait { .. });
204
205 if processed_wait_in_this_tick && is_wait_node {
207 return Ok(NodeResult::Running);
208 }
209
210 let result = execute_node(child, context, delta_time)?;
212
213 if is_wait_node && matches!(result, NodeResult::Success | NodeResult::Running) {
215 processed_wait_in_this_tick = true;
216 }
217
218 match result {
219 NodeResult::Success => {
220 if !is_wait_node {
223 continue;
224 }
225 continue;
227 }
228 NodeResult::Running => return Ok(NodeResult::Running),
229 NodeResult::Failure => return Ok(NodeResult::Failure),
230 NodeResult::Ready => return Ok(NodeResult::Running),
231 }
232 }
233 Ok(NodeResult::Success)
235}
236
237fn execute_selector(
238 children: &mut [Node],
239 context: &TreeContext,
240 delta_time: f64,
241) -> Result<NodeResult> {
242 for child in children {
243 let result = execute_node(child, context, delta_time)?;
244 match result {
245 NodeResult::Success => return Ok(NodeResult::Success),
246 NodeResult::Running => return Ok(NodeResult::Running),
247 NodeResult::Failure => continue,
248 NodeResult::Ready => continue,
249 }
250 }
251 Ok(NodeResult::Failure)
252}
253
254fn execute_parallel(
255 children: &mut [Node],
256 context: &TreeContext,
257 delta_time: f64,
258) -> Result<NodeResult> {
259 let mut has_running = false;
260 let mut has_failure = false;
261
262 for child in children {
263 let result = execute_node(child, context, delta_time)?;
264 match result {
265 NodeResult::Running => has_running = true,
266 NodeResult::Failure => has_failure = true,
267 _ => {}
268 }
269 }
270
271 if has_running {
272 Ok(NodeResult::Running)
273 } else if has_failure {
274 Ok(NodeResult::Failure)
275 } else {
276 Ok(NodeResult::Success)
277 }
278}
279
280fn execute_race(
281 children: &mut [Node],
282 context: &TreeContext,
283 delta_time: f64,
284) -> Result<NodeResult> {
285 for child in children.iter_mut() {
286 let result = execute_node(child, context, delta_time)?;
287 match result {
288 NodeResult::Success => return Ok(NodeResult::Success),
289 NodeResult::Running => continue,
290 NodeResult::Failure => continue,
291 NodeResult::Ready => continue,
292 }
293 }
294
295 for child in children {
297 if matches!(child.state, NodeState::Running) {
298 return Ok(NodeResult::Running);
299 }
300 }
301
302 Ok(NodeResult::Failure)
303}
304
305fn execute_all(
306 children: &mut [Node],
307 context: &TreeContext,
308 delta_time: f64,
309) -> Result<NodeResult> {
310 let mut has_running = false;
311 let mut has_success = false;
312
313 for child in children {
314 let result = execute_node(child, context, delta_time)?;
315 match result {
316 NodeResult::Running => has_running = true,
317 NodeResult::Success => has_success = true,
318 _ => {}
319 }
320 }
321
322 if has_running {
323 Ok(NodeResult::Running)
324 } else if has_success {
325 Ok(NodeResult::Success)
326 } else {
327 Ok(NodeResult::Failure)
328 }
329}
330
331fn execute_lotto(
332 children: &mut [Node],
333 weights: Option<&Vec<u32>>,
334 context: &TreeContext,
335 delta_time: f64,
336) -> Result<NodeResult> {
337 if children.is_empty() {
338 return Ok(NodeResult::Failure);
339 }
340
341 let index = if let Some(weights) = weights {
342 if weights.len() != children.len() {
344 (std::time::SystemTime::now()
346 .duration_since(std::time::UNIX_EPOCH)
347 .unwrap()
348 .as_nanos()
349 % children.len() as u128) as usize
350 } else {
351 let total_weight: u32 = weights.iter().sum();
353 if total_weight == 0 {
354 (std::time::SystemTime::now()
356 .duration_since(std::time::UNIX_EPOCH)
357 .unwrap()
358 .as_nanos()
359 % children.len() as u128) as usize
360 } else {
361 let mut random_value = (std::time::SystemTime::now()
363 .duration_since(std::time::UNIX_EPOCH)
364 .unwrap()
365 .as_nanos()
366 % total_weight as u128) as u32;
367
368 let mut selected_index = 0;
369 for (i, &weight) in weights.iter().enumerate() {
370 if random_value < weight {
371 selected_index = i;
372 break;
373 }
374 random_value -= weight;
375 }
376 selected_index
377 }
378 }
379 } else {
380 (std::time::SystemTime::now()
382 .duration_since(std::time::UNIX_EPOCH)
383 .unwrap()
384 .as_nanos()
385 % children.len() as u128) as usize
386 };
387
388 execute_node(&mut children[index], context, delta_time)
389}
390
391fn execute_repeat(
392 child: &mut Node,
393 iterations: Option<u32>,
394 context: &TreeContext,
395 delta_time: f64,
396) -> Result<NodeResult> {
397 let max_iterations = iterations.unwrap_or(1);
398
399 let current_iteration = child.elapsed_time as u32;
401
402 if current_iteration >= max_iterations {
403 return Ok(NodeResult::Success);
404 }
405
406 loop {
408 let result = execute_node(child, context, delta_time)?;
409
410 match result {
411 NodeResult::Success => {
412 child.elapsed_time = (child.elapsed_time as u32 + 1) as f64;
414
415 if child.elapsed_time as u32 >= max_iterations {
417 return Ok(NodeResult::Success);
418 } else {
419 child.state = NodeState::Ready;
421 continue; }
423 }
424 NodeResult::Running => {
425 return Ok(NodeResult::Running);
427 }
428 NodeResult::Failure => {
429 return Ok(NodeResult::Failure);
431 }
432 NodeResult::Ready => {
433 return Ok(NodeResult::Running);
435 }
436 }
437 }
438}
439
440fn execute_retry(
441 child: &mut Node,
442 attempts: Option<u32>,
443 context: &TreeContext,
444 delta_time: f64,
445) -> Result<NodeResult> {
446 let max_attempts = attempts.unwrap_or(1);
447
448 let current_attempt = child.elapsed_time as u32;
450
451 if current_attempt >= max_attempts {
452 return Ok(NodeResult::Failure); }
454
455 loop {
457 let result = execute_node(child, context, delta_time)?;
458
459 match result {
460 NodeResult::Success => {
461 return Ok(NodeResult::Success);
463 }
464 NodeResult::Running => {
465 return Ok(NodeResult::Running);
467 }
468 NodeResult::Failure => {
469 child.elapsed_time = (child.elapsed_time as u32 + 1) as f64;
471
472 if child.elapsed_time as u32 >= max_attempts {
474 return Ok(NodeResult::Failure); } else {
476 child.state = NodeState::Ready;
478 continue; }
480 }
481 NodeResult::Ready => {
482 return Ok(NodeResult::Running);
484 }
485 }
486 }
487}
488
489fn execute_flip(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
490 let result = execute_node(child, context, delta_time)?;
491 Ok(match result {
492 NodeResult::Success => NodeResult::Failure,
493 NodeResult::Failure => NodeResult::Success,
494 other => other,
495 })
496}
497
498fn execute_succeed(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
499 let result = execute_node(child, context, delta_time)?;
500 Ok(match result {
501 NodeResult::Running => NodeResult::Running,
502 _ => NodeResult::Success,
503 })
504}
505
506fn execute_fail(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
507 let result = execute_node(child, context, delta_time)?;
508 Ok(match result {
509 NodeResult::Running => NodeResult::Running,
510 _ => NodeResult::Failure,
511 })
512}
513
514fn execute_while(
515 condition: &str,
516 args: &[serde_json::Value],
517 children: &mut [Node],
518 context: &TreeContext,
519 delta_time: f64,
520) -> Result<NodeResult> {
521 if !context.evaluate_condition(condition, args)? {
523 return Ok(NodeResult::Success);
525 }
526
527 let result = execute_sequence(children, context, delta_time)?;
529 match result {
530 NodeResult::Running => Ok(NodeResult::Running),
531 NodeResult::Failure => Ok(NodeResult::Failure),
532 NodeResult::Success => {
533 for child in children.iter_mut() {
535 child.reset();
536 }
537 Ok(NodeResult::Running)
540 }
541 NodeResult::Ready => Ok(NodeResult::Running),
542 }
543}
544
545fn execute_until(
546 condition: &str,
547 args: &[serde_json::Value],
548 children: &mut [Node],
549 context: &TreeContext,
550 delta_time: f64,
551) -> Result<NodeResult> {
552 if context.evaluate_condition(condition, args)? {
554 return Ok(NodeResult::Success);
556 }
557
558 let result = execute_sequence(children, context, delta_time)?;
560 match result {
561 NodeResult::Running => Ok(NodeResult::Running),
562 NodeResult::Failure => Ok(NodeResult::Failure),
563 NodeResult::Success => {
564 for child in children.iter_mut() {
566 child.reset();
567 }
568 Ok(NodeResult::Running)
571 }
572 NodeResult::Ready => Ok(NodeResult::Running),
573 }
574}
575
576fn execute_while_all(
577 condition: &str,
578 args: &[serde_json::Value],
579 children: &mut [Node],
580 context: &TreeContext,
581 delta_time: f64,
582) -> Result<NodeResult> {
583 let result = execute_sequence(children, context, delta_time)?;
585 match result {
586 NodeResult::Running => Ok(NodeResult::Running),
587 NodeResult::Failure => Ok(NodeResult::Failure),
588 NodeResult::Success => {
589 if context.evaluate_condition(condition, args)? {
591 for child in children.iter_mut() {
593 child.reset();
594 }
595 Ok(NodeResult::Running)
597 } else {
598 Ok(NodeResult::Success)
600 }
601 }
602 NodeResult::Ready => Ok(NodeResult::Running),
603 }
604}
605
606fn evaluate_guard(guard: &Guard, context: &TreeContext) -> Result<bool> {
607 match guard {
608 Guard::While { condition, args } => context.evaluate_condition(condition, args),
609 Guard::Until { condition, args } => {
610 let result = context.evaluate_condition(condition, args)?;
611 Ok(!result)
612 }
613 }
614}