Skip to main content

behaviortree_rs/nodes/control/
parallel_all.rs

1use std::collections::HashSet;
2
3use behaviortree_rs_derive::bt_node;
4
5use crate::{
6    basic_types::NodeStatus,
7    macros::{define_ports, input_port},
8    nodes::{NodeError, NodeResult},
9};
10
11/// The ParallelAllNode execute all its children
12/// __concurrently__, but not in separate threads!
13///
14/// It differs in the way ParallelNode works because the latter may stop
15/// and halt other children if a certain number of SUCCESS/FAILURES is reached,
16/// whilst this one will always complete the execution of ALL its children.
17///
18/// Note that threshold indexes work as in Python:
19/// https://www.i2tutorials.com/what-are-negative-indexes-and-why-are-they-used/
20///
21/// Therefore -1 is equivalent to the number of children.
22#[bt_node(ControlNode)]
23pub struct ParallelAllNode {
24    #[bt(default = "-1")]
25    failure_threshold: i32,
26    #[bt(default)]
27    completed_list: HashSet<usize>,
28    #[bt(default = "0")]
29    failure_count: usize,
30}
31
32#[bt_node(ControlNode)]
33impl ParallelAllNode {
34    fn failure_threshold(&self, n_children: i32) -> usize {
35        if self.failure_threshold < 0 {
36            (n_children + self.failure_threshold + 1).max(0) as usize
37        } else {
38            self.failure_threshold as usize
39        }
40    }
41
42    async fn tick(&mut self) -> NodeResult {
43        self.failure_threshold = node_.config.get_input("max_failures")?;
44
45        let children_count = node_.children.len();
46
47        if (children_count as i32) < self.failure_threshold {
48            return Err(NodeError::NodeStructureError(
49                "Number of children is less than the threshold. Can never fail.".to_string(),
50            ));
51        }
52
53        let mut skipped_count = 0;
54
55        for i in 0..children_count {
56            // Skip completed node
57            if self.completed_list.contains(&i) {
58                continue;
59            }
60
61            let status = node_.children[i].execute_tick().await?;
62            match status {
63                NodeStatus::Success => {
64                    self.completed_list.insert(i);
65                }
66                NodeStatus::Failure => {
67                    self.completed_list.insert(i);
68                    self.failure_count += 1;
69                }
70                NodeStatus::Skipped => skipped_count += 1,
71                NodeStatus::Running => {}
72                // Throw error, should never happen
73                NodeStatus::Idle => {
74                    return Err(NodeError::StatusError(
75                        "ParallelAllNode".to_string(),
76                        "Idle".to_string(),
77                    ))
78                }
79            }
80        }
81
82        if skipped_count == children_count {
83            return Ok(NodeStatus::Skipped);
84        }
85
86        if skipped_count + self.completed_list.len() >= children_count {
87            // Done!
88            node_.reset_children().await;
89            self.completed_list.clear();
90
91            let status =
92                if self.failure_count >= self.failure_threshold(node_.children.len() as i32) {
93                    NodeStatus::Failure
94                } else {
95                    NodeStatus::Success
96                };
97
98            // Reset failure_count after using it
99            self.failure_count = 0;
100
101            return Ok(status);
102        }
103
104        Ok(NodeStatus::Running)
105    }
106
107    fn ports() -> crate::basic_types::PortsList {
108        define_ports!(input_port!("max_failures", 1))
109    }
110
111    async fn halt(&mut self) {
112        node_.reset_children().await;
113    }
114}