// beetry_node/control/parallel.rs

use beetry_core::{Node, NonEmptyNodes, TickStatus};
use bon::Builder;

use crate::{Indices, control::RunningNodesAborter};

6/// Threshold configuration for [`Parallel`].
7#[derive(Debug, Clone, Copy, Builder)]
8#[cfg_attr(feature = "plugin", derive(serde::Deserialize))]
9pub struct ParallelParams {
10    /// Number of children that must succeed for the node to succeed.
11    pub success_count: u16,
12    /// Number of children that must fail for the node to fail.
13    pub failure_count: u16,
14}
15
16/// Ticks all children and resolves once the configured success or failure
17/// threshold is reached.
18pub struct Parallel {
19    nodes: NonEmptyNodes,
20    aborter: RunningNodesAborter,
21    params: ParallelParams,
22}
23
24impl Parallel {
25    #[must_use]
26    pub fn new(nodes: impl Into<NonEmptyNodes>, params: ParallelParams) -> Self {
27        Self {
28            nodes: nodes.into(),
29            aborter: RunningNodesAborter::new(),
30            params,
31        }
32    }
33}
34
35impl Node for Parallel {
36    fn tick(&mut self) -> TickStatus {
37        let aborter: &mut RunningNodesAborter = &mut self.aborter;
38        let mut success_count = 0_u16;
39        let mut failure_count = 0_u16;
40
41        for idx in self.nodes.indices() {
42            let node = &mut self.nodes[idx];
43            match node.tick() {
44                TickStatus::Success => {
45                    success_count += 1;
46                    aborter.untrack(idx);
47                }
48                TickStatus::Running => {
49                    aborter.track(idx);
50                }
51                TickStatus::Failure => {
52                    failure_count += 1;
53                    aborter.untrack(idx);
54                }
55            }
56        }
57
58        if failure_count >= self.params.failure_count {
59            aborter.abort_all(&mut self.nodes);
60            return TickStatus::Failure;
61        }
62
63        if success_count >= self.params.success_count {
64            aborter.abort_all(&mut self.nodes);
65            return TickStatus::Success;
66        }
67
68        if aborter.is_any_tracked() {
69            TickStatus::Running
70        } else {
71            TickStatus::Failure
72        }
73    }
74
75    fn abort(&mut self) {
76        self.aborter.clear();
77        for node in &mut self.nodes {
78            node.abort();
79        }
80    }
81
82    fn reset(&mut self) {
83        self.aborter.clear();
84        for node in &mut self.nodes {
85            node.reset();
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock_test::{boxed, mock_returns};

    /// Shorthand for building thresholds through the generated builder.
    fn params(success_count: u16, failure_count: u16) -> ParallelParams {
        ParallelParams::builder()
            .success_count(success_count)
            .failure_count(failure_count)
            .build()
    }

    #[test]
    fn succeeds_on_success_threshold() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(2, 3));

        assert_eq!(pl.tick(), TickStatus::Success);
    }

    #[test]
    fn fails_on_failure_threshold() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Success])),
        ]);
        let mut pl = Parallel::new(nodes, params(3, 2));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn failure_threshold_takes_precedence_over_success_threshold() {
        // Both thresholds are met on the same tick; failure is checked first.
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Success])),
        ]);
        let mut pl = Parallel::new(nodes, params(1, 1));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn returns_running_when_no_threshold_reached_and_any_running() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Running])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(2, 2));

        assert_eq!(pl.tick(), TickStatus::Running);
    }

    #[test]
    fn fails_when_all_terminal_and_no_threshold_reached() {
        let nodes = NonEmptyNodes::from([
            boxed(mock_returns([TickStatus::Success])),
            boxed(mock_returns([TickStatus::Failure])),
            boxed(mock_returns([TickStatus::Failure])),
        ]);
        let mut pl = Parallel::new(nodes, params(3, 3));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn aborts_tracked_nodes_when_threshold_reached() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let m2 = mock_returns([TickStatus::Success]);
        let m3 = mock_returns([TickStatus::Success]);

        m1.expect_abort().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(2, 3));

        assert_eq!(pl.tick(), TickStatus::Success);
    }

    #[test]
    fn aborts_tracked_nodes_when_failure_threshold_reached() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let m2 = mock_returns([TickStatus::Failure]);
        let m3 = mock_returns([TickStatus::Failure]);

        m1.expect_abort().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(3, 2));

        assert_eq!(pl.tick(), TickStatus::Failure);
    }

    #[test]
    fn abort_aborts_every_child() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let mut m2 = mock_returns([TickStatus::Running]);
        let mut m3 = mock_returns([TickStatus::Success]);

        // `abort` propagates to all children, not only the running ones.
        m1.expect_abort().once().return_const(());
        m2.expect_abort().once().return_const(());
        m3.expect_abort().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(3, 3));

        assert_eq!(pl.tick(), TickStatus::Running);
        pl.abort();
    }

    #[test]
    fn reset_resets_every_child() {
        let mut m1 = mock_returns([TickStatus::Running]);
        let mut m2 = mock_returns([TickStatus::Success]);
        let mut m3 = mock_returns([TickStatus::Failure]);

        m1.expect_reset().once().return_const(());
        m2.expect_reset().once().return_const(());
        m3.expect_reset().once().return_const(());

        let nodes = NonEmptyNodes::from([boxed(m1), boxed(m2), boxed(m3)]);
        let mut pl = Parallel::new(nodes, params(3, 3));

        assert_eq!(pl.tick(), TickStatus::Running);
        pl.reset();
    }
}