//! `bamboo_agent/agent/core/composition/parallel.rs`
//!
//! Types controlling how parallel branches are waited for.
use serde::{Deserialize, Serialize};

3/// Controls how parallel branches should be waited for
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5#[serde(rename_all = "snake_case")]
6#[derive(Default)]
7pub enum ParallelWait {
8    /// Wait for all branches to complete
9    #[default]
10    All,
11    /// Wait for any branch to complete (first to finish)
12    Any,
13    /// Wait for at least N branches to complete
14    N(usize),
15}
16
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_wait_default_is_all() {
        assert_eq!(ParallelWait::default(), ParallelWait::All);
    }

    #[test]
    fn test_parallel_wait_serialization() {
        // `rename_all = "snake_case"` lowercases every variant name, including `N` -> `n`.
        // Compare exact output so a regression in the rename (e.g. emitting `"N"`) fails here
        // instead of only surfacing when the JSON is read back.
        assert_eq!(serde_json::to_string(&ParallelWait::All).unwrap(), "\"all\"");
        assert_eq!(serde_json::to_string(&ParallelWait::Any).unwrap(), "\"any\"");
        assert_eq!(serde_json::to_string(&ParallelWait::N(3)).unwrap(), "{\"n\":3}");
    }

    #[test]
    fn test_parallel_wait_deserialization() {
        let all: ParallelWait = serde_json::from_str("\"all\"").unwrap();
        assert_eq!(all, ParallelWait::All);

        let any: ParallelWait = serde_json::from_str("\"any\"").unwrap();
        assert_eq!(any, ParallelWait::Any);

        // The N variant is an externally tagged object in JSON.
        let n: ParallelWait = serde_json::from_str("{\"n\": 5}").unwrap();
        assert_eq!(n, ParallelWait::N(5));

        // The un-renamed (uppercase) tag is not accepted.
        assert!(serde_json::from_str::<ParallelWait>("{\"N\": 5}").is_err());
    }

    #[test]
    fn test_parallel_wait_roundtrip() {
        let variants = [
            ParallelWait::All,
            ParallelWait::Any,
            ParallelWait::N(0),
            ParallelWait::N(3),
            ParallelWait::N(usize::MAX),
        ];

        for original in variants {
            let json = serde_json::to_string(&original).unwrap();
            let deserialized: ParallelWait = serde_json::from_str(&json).unwrap();
            assert_eq!(original, deserialized);
        }
    }
}