beetry_node/control/
parallel.rs1use beetry_core::{Node, NonEmptyNodes, TickStatus};
2use bon::Builder;
3
4use crate::{Indices, control::RunningNodesAborter};
5
6#[derive(Debug, Clone, Copy, Builder)]
8#[cfg_attr(feature = "plugin", derive(serde::Deserialize))]
9pub struct ParallelParams {
10 pub success_count: u16,
12 pub failure_count: u16,
14}
15
16pub struct Parallel {
19 nodes: NonEmptyNodes,
20 aborter: RunningNodesAborter,
21 params: ParallelParams,
22}
23
24impl Parallel {
25 #[must_use]
26 pub fn new(nodes: impl Into<NonEmptyNodes>, params: ParallelParams) -> Self {
27 Self {
28 nodes: nodes.into(),
29 aborter: RunningNodesAborter::new(),
30 params,
31 }
32 }
33}
34
35impl Node for Parallel {
36 fn tick(&mut self) -> TickStatus {
37 let aborter: &mut RunningNodesAborter = &mut self.aborter;
38 let mut success_count = 0_u16;
39 let mut failure_count = 0_u16;
40
41 for idx in self.nodes.indices() {
42 let node = &mut self.nodes[idx];
43 match node.tick() {
44 TickStatus::Success => {
45 success_count += 1;
46 aborter.untrack(idx);
47 }
48 TickStatus::Running => {
49 aborter.track(idx);
50 }
51 TickStatus::Failure => {
52 failure_count += 1;
53 aborter.untrack(idx);
54 }
55 }
56 }
57
58 if failure_count >= self.params.failure_count {
59 aborter.abort_all(&mut self.nodes);
60 return TickStatus::Failure;
61 }
62
63 if success_count >= self.params.success_count {
64 aborter.abort_all(&mut self.nodes);
65 return TickStatus::Success;
66 }
67
68 if aborter.is_any_tracked() {
69 TickStatus::Running
70 } else {
71 TickStatus::Failure
72 }
73 }
74
75 fn abort(&mut self) {
76 self.aborter.clear();
77 for node in &mut self.nodes {
78 node.abort();
79 }
80 }
81
82 fn reset(&mut self) {
83 self.aborter.clear();
84 for node in &mut self.nodes {
85 node.reset();
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_test::{boxed, mock_returns};

    /// Shorthand for building `ParallelParams` through its builder.
    fn params(success_count: u16, failure_count: u16) -> ParallelParams {
        ParallelParams::builder()
            .success_count(success_count)
            .failure_count(failure_count)
            .build()
    }

    #[test]
    fn succeeds_on_success_threshold() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(2, 3));

        assert_eq!(pl.tick(), TickStatus::Success);
    }

    #[test]
    fn fails_on_failure_threshold() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Success])),
        ]);
        let mut pl = Parallel::new(nodes, params(3, 2));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn failure_takes_precedence_when_both_thresholds_reached() {
        // One success and two failures meet both thresholds of (1, 1) in the
        // same tick; the failure threshold is evaluated first.
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(1, 1));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn returns_running_when_no_threshold_reached_and_any_running() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Running])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(2, 2));

        assert_eq!(pl.tick(), TickStatus::Running);
    }

    #[test]
    fn fails_when_all_terminal_and_no_threshold_reached() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(3, 3));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn aborts_tracked_nodes_when_threshold_reached() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let m2 = mock_returns([TickStatus::Success]);
        let m3 = mock_returns([TickStatus::Success]);

        m1.expect_abort().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(2, 3));

        assert_eq!(pl.tick(), TickStatus::Success);
    }

    #[test]
    fn aborts_tracked_nodes_when_failure_threshold_reached() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let m2 = mock_returns([TickStatus::Failure]);
        let m3 = mock_returns([TickStatus::Failure]);

        // Only the still-running child is expected to be aborted.
        m1.expect_abort().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(3, 2));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }
}