Skip to main content

bonsai_mdsl/
tree.rs

1use std::collections::HashMap;
2use std::time::Instant;
3
4use crate::context::TreeContext;
5use crate::error::{BonsaiError, Result};
6use crate::nodes::{Guard, Node, NodeResult, NodeState, NodeType};
7use crate::parser::parse_mdsl;
8
9/// A behavior tree that can be executed with a context
10pub struct BehaviorTree {
11    root: Node,
12    _named_trees: HashMap<String, Node>,
13    start_time: Option<Instant>,
14}
15
16impl BehaviorTree {
17    /// Create a new behavior tree from an MDSL string
18    pub fn from_mdsl(mdsl: &str) -> Result<Self> {
19        let root = parse_mdsl(mdsl)?;
20        Ok(Self {
21            root,
22            _named_trees: HashMap::new(),
23            start_time: None,
24        })
25    }
26
27    /// Create a behavior tree from a pre-built node
28    pub fn from_node(root: Node) -> Self {
29        Self {
30            root,
31            _named_trees: HashMap::new(),
32            start_time: None,
33        }
34    }
35
36    /// Execute one tick of the behavior tree
37    pub fn tick(&mut self, context: &TreeContext) -> Result<NodeResult> {
38        self.tick_with_delta(context, 0.0)
39    }
40
41    /// Execute one tick of the behavior tree with delta time
42    pub fn tick_with_delta(
43        &mut self,
44        context: &TreeContext,
45        delta_time: f64,
46    ) -> Result<NodeResult> {
47        if self.start_time.is_none() {
48            self.start_time = Some(Instant::now());
49        }
50
51        let result = execute_node(&mut self.root, context, delta_time)?;
52
53        // Don't auto-reset - let the user decide when to reset
54        // if matches!(result, NodeResult::Success | NodeResult::Failure) {
55        //     self.reset();
56        // }
57
58        Ok(result)
59    }
60
61    /// Reset the tree to its initial state
62    pub fn reset(&mut self) {
63        self.root.reset();
64        self.start_time = None;
65    }
66
67    /// Get the current state of the tree
68    pub fn get_state(&self) -> NodeState {
69        self.root.state
70    }
71
72    /// Check if the tree is currently running
73    pub fn is_running(&self) -> bool {
74        matches!(self.root.state, NodeState::Running)
75    }
76}
77
78/// Execute a single node - made into a free function to avoid borrowing issues
79fn execute_node(node: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
80    // Check guards first
81    if let Some(guard) = &node.guard {
82        if !evaluate_guard(guard, context)? {
83            return Ok(NodeResult::Failure);
84        }
85    }
86
87    // Execute entry callback if transitioning to running
88    if node.state == NodeState::Ready {
89        if let Some(callbacks) = &node.callbacks {
90            if let Some((callback_name, args)) = &callbacks.entry {
91                context.execute_callback(callback_name, args)?;
92            }
93        }
94    }
95
96    // Execute step callback
97    if let Some(callbacks) = &node.callbacks {
98        if let Some((callback_name, args)) = &callbacks.step {
99            context.execute_callback(callback_name, args)?;
100        }
101    }
102
103    // Execute the node based on its type
104    let result = match &mut node.node_type {
105        NodeType::Root { child } => execute_node(child, context, delta_time)?,
106        NodeType::Sequence { children } => execute_sequence(children, context, delta_time)?,
107        NodeType::Selector { children } => execute_selector(children, context, delta_time)?,
108        NodeType::Parallel { children } => execute_parallel(children, context, delta_time)?,
109        NodeType::Race { children } => execute_race(children, context, delta_time)?,
110        NodeType::All { children } => execute_all(children, context, delta_time)?,
111        NodeType::Lotto { children, weights } => {
112            execute_lotto(children, weights.as_ref(), context, delta_time)?
113        }
114        NodeType::While {
115            condition,
116            args,
117            children,
118        } => execute_while(condition, args, children, context, delta_time)?,
119        NodeType::Until {
120            condition,
121            args,
122            children,
123        } => execute_until(condition, args, children, context, delta_time)?,
124        NodeType::WhileAll {
125            condition,
126            args,
127            children,
128        } => execute_while_all(condition, args, children, context, delta_time)?,
129        NodeType::Repeat { child, iterations } => {
130            execute_repeat(child, *iterations, context, delta_time)?
131        }
132        NodeType::Retry { child, attempts } => {
133            execute_retry(child, *attempts, context, delta_time)?
134        }
135        NodeType::Flip { child } => execute_flip(child, context, delta_time)?,
136        NodeType::Succeed { child } => execute_succeed(child, context, delta_time)?,
137        NodeType::Fail { child } => execute_fail(child, context, delta_time)?,
138        NodeType::Action { name, args } => context.execute_action(name, args)?,
139        NodeType::Condition { name, args } => {
140            if context.evaluate_condition(name, args)? {
141                NodeResult::Success
142            } else {
143                NodeResult::Failure
144            }
145        }
146        NodeType::Wait { duration } => {
147            if let Some(duration_ms) = duration {
148                // Add delta time to elapsed time
149                node.elapsed_time += delta_time * 1000.0; // Convert seconds to milliseconds
150
151                if node.elapsed_time >= *duration_ms as f64 {
152                    // Wait duration completed, reset for next time
153                    node.elapsed_time = 0.0;
154                    NodeResult::Success
155                } else {
156                    // Still waiting
157                    NodeResult::Running
158                }
159            } else {
160                // Wait indefinitely
161                NodeResult::Running
162            }
163        }
164        NodeType::Branch {
165            reference: _reference,
166        } => {
167            // Branch nodes reference other named behavior trees
168            // This feature requires implementing a tree registry and would be a significant architectural change
169            return Err(BonsaiError::NodeExecutionError(
170                "Branch nodes are not yet implemented - this is a planned feature for referencing named trees".to_string(),
171            ));
172        }
173    };
174
175    // Update node state
176    node.state = result.into();
177
178    // Execute exit callback if transitioning to a terminal state
179    if matches!(result, NodeResult::Success | NodeResult::Failure) {
180        if let Some(callbacks) = &node.callbacks {
181            if let Some((callback_name, args)) = &callbacks.exit {
182                context.execute_callback(callback_name, args)?;
183            }
184        }
185    }
186
187    Ok(result)
188}
189
190fn execute_sequence(
191    children: &mut [Node],
192    context: &TreeContext,
193    delta_time: f64,
194) -> Result<NodeResult> {
195    let mut processed_wait_in_this_tick = false;
196
197    for child in children {
198        // Skip children that have already succeeded
199        if child.state == NodeState::Success {
200            continue;
201        }
202
203        let is_wait_node = matches!(child.node_type, NodeType::Wait { .. });
204
205        // If we already processed a wait node and this is another wait node, stop
206        if processed_wait_in_this_tick && is_wait_node {
207            return Ok(NodeResult::Running);
208        }
209
210        // Execute this child
211        let result = execute_node(child, context, delta_time)?;
212
213        // Track if we processed a wait node
214        if is_wait_node && matches!(result, NodeResult::Success | NodeResult::Running) {
215            processed_wait_in_this_tick = true;
216        }
217
218        match result {
219            NodeResult::Success => {
220                // For instant actions (non-wait), continue in same tick
221                // For wait nodes, we'll check in the next iteration
222                if !is_wait_node {
223                    continue;
224                }
225                // Wait completed, continue to next child if it's instant
226                continue;
227            }
228            NodeResult::Running => return Ok(NodeResult::Running),
229            NodeResult::Failure => return Ok(NodeResult::Failure),
230            NodeResult::Ready => return Ok(NodeResult::Running),
231        }
232    }
233    // All children have succeeded
234    Ok(NodeResult::Success)
235}
236
237fn execute_selector(
238    children: &mut [Node],
239    context: &TreeContext,
240    delta_time: f64,
241) -> Result<NodeResult> {
242    for child in children {
243        let result = execute_node(child, context, delta_time)?;
244        match result {
245            NodeResult::Success => return Ok(NodeResult::Success),
246            NodeResult::Running => return Ok(NodeResult::Running),
247            NodeResult::Failure => continue,
248            NodeResult::Ready => continue,
249        }
250    }
251    Ok(NodeResult::Failure)
252}
253
254fn execute_parallel(
255    children: &mut [Node],
256    context: &TreeContext,
257    delta_time: f64,
258) -> Result<NodeResult> {
259    let mut has_running = false;
260    let mut has_failure = false;
261
262    for child in children {
263        let result = execute_node(child, context, delta_time)?;
264        match result {
265            NodeResult::Running => has_running = true,
266            NodeResult::Failure => has_failure = true,
267            _ => {}
268        }
269    }
270
271    if has_running {
272        Ok(NodeResult::Running)
273    } else if has_failure {
274        Ok(NodeResult::Failure)
275    } else {
276        Ok(NodeResult::Success)
277    }
278}
279
280fn execute_race(
281    children: &mut [Node],
282    context: &TreeContext,
283    delta_time: f64,
284) -> Result<NodeResult> {
285    for child in children.iter_mut() {
286        let result = execute_node(child, context, delta_time)?;
287        match result {
288            NodeResult::Success => return Ok(NodeResult::Success),
289            NodeResult::Running => continue,
290            NodeResult::Failure => continue,
291            NodeResult::Ready => continue,
292        }
293    }
294
295    // If no child succeeded, check if any are still running
296    for child in children {
297        if matches!(child.state, NodeState::Running) {
298            return Ok(NodeResult::Running);
299        }
300    }
301
302    Ok(NodeResult::Failure)
303}
304
305fn execute_all(
306    children: &mut [Node],
307    context: &TreeContext,
308    delta_time: f64,
309) -> Result<NodeResult> {
310    let mut has_running = false;
311    let mut has_success = false;
312
313    for child in children {
314        let result = execute_node(child, context, delta_time)?;
315        match result {
316            NodeResult::Running => has_running = true,
317            NodeResult::Success => has_success = true,
318            _ => {}
319        }
320    }
321
322    if has_running {
323        Ok(NodeResult::Running)
324    } else if has_success {
325        Ok(NodeResult::Success)
326    } else {
327        Ok(NodeResult::Failure)
328    }
329}
330
331fn execute_lotto(
332    children: &mut [Node],
333    weights: Option<&Vec<u32>>,
334    context: &TreeContext,
335    delta_time: f64,
336) -> Result<NodeResult> {
337    if children.is_empty() {
338        return Ok(NodeResult::Failure);
339    }
340
341    let index = if let Some(weights) = weights {
342        // Weighted random selection
343        if weights.len() != children.len() {
344            // Fallback to uniform random if weights don't match children count
345            (std::time::SystemTime::now()
346                .duration_since(std::time::UNIX_EPOCH)
347                .unwrap()
348                .as_nanos()
349                % children.len() as u128) as usize
350        } else {
351            // Calculate total weight
352            let total_weight: u32 = weights.iter().sum();
353            if total_weight == 0 {
354                // Fallback to uniform random if all weights are zero
355                (std::time::SystemTime::now()
356                    .duration_since(std::time::UNIX_EPOCH)
357                    .unwrap()
358                    .as_nanos()
359                    % children.len() as u128) as usize
360            } else {
361                // Generate random number and find weighted index
362                let mut random_value = (std::time::SystemTime::now()
363                    .duration_since(std::time::UNIX_EPOCH)
364                    .unwrap()
365                    .as_nanos()
366                    % total_weight as u128) as u32;
367
368                let mut selected_index = 0;
369                for (i, &weight) in weights.iter().enumerate() {
370                    if random_value < weight {
371                        selected_index = i;
372                        break;
373                    }
374                    random_value -= weight;
375                }
376                selected_index
377            }
378        }
379    } else {
380        // Uniform random selection
381        (std::time::SystemTime::now()
382            .duration_since(std::time::UNIX_EPOCH)
383            .unwrap()
384            .as_nanos()
385            % children.len() as u128) as usize
386    };
387
388    execute_node(&mut children[index], context, delta_time)
389}
390
391fn execute_repeat(
392    child: &mut Node,
393    iterations: Option<u32>,
394    context: &TreeContext,
395    delta_time: f64,
396) -> Result<NodeResult> {
397    let max_iterations = iterations.unwrap_or(1);
398
399    // Track iterations in elapsed_time field
400    let current_iteration = child.elapsed_time as u32;
401
402    if current_iteration >= max_iterations {
403        return Ok(NodeResult::Success);
404    }
405
406    // For instant actions, execute all iterations in one tick
407    loop {
408        let result = execute_node(child, context, delta_time)?;
409
410        match result {
411            NodeResult::Success => {
412                // Increment iteration count
413                child.elapsed_time = (child.elapsed_time as u32 + 1) as f64;
414
415                // Check if we've completed all iterations
416                if child.elapsed_time as u32 >= max_iterations {
417                    return Ok(NodeResult::Success);
418                } else {
419                    // Reset child for next iteration and continue in same tick
420                    child.state = NodeState::Ready;
421                    continue; // Execute next iteration immediately
422                }
423            }
424            NodeResult::Running => {
425                // If child is still running, we need to wait for next tick
426                return Ok(NodeResult::Running);
427            }
428            NodeResult::Failure => {
429                // Repeat stops on failure
430                return Ok(NodeResult::Failure);
431            }
432            NodeResult::Ready => {
433                // This shouldn't happen, but treat as Running
434                return Ok(NodeResult::Running);
435            }
436        }
437    }
438}
439
440fn execute_retry(
441    child: &mut Node,
442    attempts: Option<u32>,
443    context: &TreeContext,
444    delta_time: f64,
445) -> Result<NodeResult> {
446    let max_attempts = attempts.unwrap_or(1);
447
448    // Track attempts in elapsed_time field (reusing the field for different purpose)
449    let current_attempt = child.elapsed_time as u32;
450
451    if current_attempt >= max_attempts {
452        return Ok(NodeResult::Failure); // Exhausted all attempts
453    }
454
455    // For instant actions, execute all attempts in one tick
456    loop {
457        let result = execute_node(child, context, delta_time)?;
458
459        match result {
460            NodeResult::Success => {
461                // Success on any attempt
462                return Ok(NodeResult::Success);
463            }
464            NodeResult::Running => {
465                // If child is still running, we need to wait for next tick
466                return Ok(NodeResult::Running);
467            }
468            NodeResult::Failure => {
469                // Increment attempt count
470                child.elapsed_time = (child.elapsed_time as u32 + 1) as f64;
471
472                // Check if we have more attempts
473                if child.elapsed_time as u32 >= max_attempts {
474                    return Ok(NodeResult::Failure); // Exhausted all attempts
475                } else {
476                    // Reset child for next attempt and continue in same tick
477                    child.state = NodeState::Ready;
478                    continue; // Try next attempt immediately
479                }
480            }
481            NodeResult::Ready => {
482                // This shouldn't happen, but treat as Running
483                return Ok(NodeResult::Running);
484            }
485        }
486    }
487}
488
489fn execute_flip(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
490    let result = execute_node(child, context, delta_time)?;
491    Ok(match result {
492        NodeResult::Success => NodeResult::Failure,
493        NodeResult::Failure => NodeResult::Success,
494        other => other,
495    })
496}
497
498fn execute_succeed(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
499    let result = execute_node(child, context, delta_time)?;
500    Ok(match result {
501        NodeResult::Running => NodeResult::Running,
502        _ => NodeResult::Success,
503    })
504}
505
506fn execute_fail(child: &mut Node, context: &TreeContext, delta_time: f64) -> Result<NodeResult> {
507    let result = execute_node(child, context, delta_time)?;
508    Ok(match result {
509        NodeResult::Running => NodeResult::Running,
510        _ => NodeResult::Failure,
511    })
512}
513
514fn execute_while(
515    condition: &str,
516    args: &[serde_json::Value],
517    children: &mut [Node],
518    context: &TreeContext,
519    delta_time: f64,
520) -> Result<NodeResult> {
521    // Check condition first
522    if !context.evaluate_condition(condition, args)? {
523        // Condition is false, while loop exits successfully
524        return Ok(NodeResult::Success);
525    }
526
527    // Execute children as a sequence
528    let result = execute_sequence(children, context, delta_time)?;
529    match result {
530        NodeResult::Running => Ok(NodeResult::Running),
531        NodeResult::Failure => Ok(NodeResult::Failure),
532        NodeResult::Success => {
533            // Reset children for next iteration and continue while loop
534            for child in children.iter_mut() {
535                child.reset();
536            }
537            // Continue the while loop by returning Running
538            // The condition will be checked again on the next tick
539            Ok(NodeResult::Running)
540        }
541        NodeResult::Ready => Ok(NodeResult::Running),
542    }
543}
544
545fn execute_until(
546    condition: &str,
547    args: &[serde_json::Value],
548    children: &mut [Node],
549    context: &TreeContext,
550    delta_time: f64,
551) -> Result<NodeResult> {
552    // Check condition first
553    if context.evaluate_condition(condition, args)? {
554        // Condition is true, until loop exits successfully
555        return Ok(NodeResult::Success);
556    }
557
558    // Execute children as a sequence
559    let result = execute_sequence(children, context, delta_time)?;
560    match result {
561        NodeResult::Running => Ok(NodeResult::Running),
562        NodeResult::Failure => Ok(NodeResult::Failure),
563        NodeResult::Success => {
564            // Reset children for next iteration and continue until loop
565            for child in children.iter_mut() {
566                child.reset();
567            }
568            // Continue the until loop by returning Running
569            // The condition will be checked again on the next tick
570            Ok(NodeResult::Running)
571        }
572        NodeResult::Ready => Ok(NodeResult::Running),
573    }
574}
575
576fn execute_while_all(
577    condition: &str,
578    args: &[serde_json::Value],
579    children: &mut [Node],
580    context: &TreeContext,
581    delta_time: f64,
582) -> Result<NodeResult> {
583    // WhileAll semantics: Execute ALL children as a sequence first
584    let result = execute_sequence(children, context, delta_time)?;
585    match result {
586        NodeResult::Running => Ok(NodeResult::Running),
587        NodeResult::Failure => Ok(NodeResult::Failure),
588        NodeResult::Success => {
589            // All children succeeded, now check condition
590            if context.evaluate_condition(condition, args)? {
591                // Condition is true, reset children and continue loop
592                for child in children.iter_mut() {
593                    child.reset();
594                }
595                // Continue the while loop by returning Running
596                Ok(NodeResult::Running)
597            } else {
598                // Condition is false, while loop exits successfully
599                Ok(NodeResult::Success)
600            }
601        }
602        NodeResult::Ready => Ok(NodeResult::Running),
603    }
604}
605
606fn evaluate_guard(guard: &Guard, context: &TreeContext) -> Result<bool> {
607    match guard {
608        Guard::While { condition, args } => context.evaluate_condition(condition, args),
609        Guard::Until { condition, args } => {
610            let result = context.evaluate_condition(condition, args)?;
611            Ok(!result)
612        }
613    }
614}