Skip to main content

behaviortree_rs/nodes/control/
parallel.rs

1use std::collections::HashSet;
2
3use behaviortree_rs_derive::bt_node;
4
5use crate::{
6    basic_types::NodeStatus,
7    macros::{define_ports, input_port},
8    nodes::{NodeError, NodeResult},
9};
10
11/// The ParallelNode execute all its children
12/// __concurrently__, but not in separate threads!
13///
14/// Even if this may look similar to ReactiveSequence,
15/// this Control Node is the __only__ one that can have
16/// multiple children RUNNING at the same time.
17///
18/// The Node is completed either when the THRESHOLD_SUCCESS
19/// or THRESHOLD_FAILURE number is reached (both configured using ports).
20///
21/// If any of the thresholds is reached, and other children are still running,
22/// they will be halted.
23///
24/// Note that threshold indexes work as in Python:
25/// https://www.i2tutorials.com/what-are-negative-indexes-and-why-are-they-used/
26///
27/// Therefore -1 is equivalent to the number of children.
28#[bt_node(ControlNode)]
29pub struct ParallelNode {
30    #[bt(default = "-1")]
31    success_threshold: i32,
32    #[bt(default = "-1")]
33    failure_threshold: i32,
34    #[bt(default)]
35    completed_list: HashSet<usize>,
36    #[bt(default = "0")]
37    success_count: usize,
38    #[bt(default = "0")]
39    failure_count: usize,
40}
41
42#[bt_node(ControlNode)]
43impl ParallelNode {
44    fn success_threshold(&self, n_children: i32) -> usize {
45        if self.success_threshold < 0 {
46            (n_children + self.success_threshold + 1).max(0) as usize
47        } else {
48            self.success_threshold as usize
49        }
50    }
51
52    fn failure_threshold(&self, n_children: i32) -> usize {
53        if self.failure_threshold < 0 {
54            (n_children + self.failure_threshold + 1).max(0) as usize
55        } else {
56            self.failure_threshold as usize
57        }
58    }
59
60    fn clear(&mut self) {
61        self.completed_list.clear();
62        self.success_count = 0;
63        self.failure_count = 0;
64    }
65
66    async fn tick(&mut self) -> NodeResult {
67        self.success_threshold = node_.config.get_input("success_count").unwrap();
68        self.failure_threshold = node_.config.get_input("failure_count").unwrap();
69
70        let children_count = node_.children.len();
71
72        if children_count < self.success_threshold(node_.children.len() as i32) {
73            return Err(NodeError::NodeStructureError(
74                "Number of children is less than the threshold. Can never succeed.".to_string(),
75            ));
76        }
77
78        if children_count < self.failure_threshold(node_.children.len() as i32) {
79            return Err(NodeError::NodeStructureError(
80                "Number of children is less than the threshold. Can never fail.".to_string(),
81            ));
82        }
83
84        let mut skipped_count = 0;
85
86        for i in 0..children_count {
87            if !self.completed_list.contains(&i) {
88                let child = &mut node_.children[i];
89                match child.execute_tick().await? {
90                    NodeStatus::Skipped => skipped_count += 1,
91                    NodeStatus::Success => {
92                        self.completed_list.insert(i);
93                        self.success_count += 1;
94                    }
95                    NodeStatus::Failure => {
96                        self.completed_list.insert(i);
97                        self.failure_count += 1;
98                    }
99                    NodeStatus::Running => {}
100                    // Throw error, should never happen
101                    NodeStatus::Idle => {}
102                }
103            }
104
105            let required_success_count = self.success_threshold(node_.children.len() as i32);
106
107            // Check if success condition has been met
108            if self.success_count >= required_success_count
109                || (self.success_threshold < 0
110                    && (self.success_count + skipped_count) >= required_success_count)
111            {
112                self.clear();
113                node_.reset_children().await;
114                return Ok(NodeStatus::Success);
115            }
116
117            if (children_count - self.failure_count) < required_success_count
118                || self.failure_count == self.failure_threshold(node_.children.len() as i32)
119            {
120                self.clear();
121                node_.reset_children().await;
122                return Ok(NodeStatus::Failure);
123            }
124        }
125
126        // If all children were skipped, return Skipped
127        // Otherwise return Running
128        match skipped_count == children_count {
129            true => Ok(NodeStatus::Skipped),
130            false => Ok(NodeStatus::Running),
131        }
132    }
133
134    fn ports() -> crate::basic_types::PortsList {
135        define_ports!(
136            input_port!("success_count", -1),
137            input_port!("failure_count", 1)
138        )
139    }
140
141    async fn halt(&mut self) {
142        node_.reset_children().await;
143    }
144}