1use std::collections::HashSet;
2
3use crate::place::{Place, PlaceRef};
4
5#[derive(Debug, Clone)]
9pub enum Out {
10 Place(PlaceRef),
12 And(Vec<Out>),
14 Xor(Vec<Out>),
16 Timeout { after_ms: u64, child: Box<Out> },
18 ForwardInput { from: PlaceRef, to: PlaceRef },
20}
21
22pub fn out_place<T: 'static>(p: &Place<T>) -> Out {
26 Out::Place(p.as_ref())
27}
28
29pub fn and(children: Vec<Out>) -> Out {
34 assert!(!children.is_empty(), "AND requires at least 1 child");
35 Out::And(children)
36}
37
38pub fn and_places(places: &[&PlaceRef]) -> Out {
40 and(places.iter().map(|p| Out::Place((*p).clone())).collect())
41}
42
43pub fn xor(children: Vec<Out>) -> Out {
48 assert!(children.len() >= 2, "XOR requires at least 2 children");
49 Out::Xor(children)
50}
51
52pub fn xor_places<T: 'static>(places: &[&Place<T>]) -> Out {
54 xor(places.iter().map(|p| out_place(*p)).collect())
55}
56
57pub fn timeout(after_ms: u64, child: Out) -> Out {
62 assert!(after_ms > 0, "Timeout must be positive: {after_ms}");
63 Out::Timeout {
64 after_ms,
65 child: Box::new(child),
66 }
67}
68
69pub fn timeout_place<T: 'static>(after_ms: u64, p: &Place<T>) -> Out {
71 timeout(after_ms, out_place(p))
72}
73
74pub fn forward_input<I: 'static, O: 'static>(from: &Place<I>, to: &Place<O>) -> Out {
76 Out::ForwardInput {
77 from: from.as_ref(),
78 to: to.as_ref(),
79 }
80}
81
82pub fn all_places(out: &Out) -> HashSet<PlaceRef> {
86 let mut result = HashSet::new();
87 collect_places(out, &mut result);
88 result
89}
90
91fn collect_places(out: &Out, result: &mut HashSet<PlaceRef>) {
92 match out {
93 Out::Place(p) => {
94 result.insert(p.clone());
95 }
96 Out::ForwardInput { to, .. } => {
97 result.insert(to.clone());
98 }
99 Out::And(children) | Out::Xor(children) => {
100 for child in children {
101 collect_places(child, result);
102 }
103 }
104 Out::Timeout { child, .. } => {
105 collect_places(child, result);
106 }
107 }
108}
109
110pub fn enumerate_branches(out: &Out) -> Vec<HashSet<PlaceRef>> {
116 match out {
117 Out::Place(p) => {
118 let mut set = HashSet::new();
119 set.insert(p.clone());
120 vec![set]
121 }
122 Out::ForwardInput { to, .. } => {
123 let mut set = HashSet::new();
124 set.insert(to.clone());
125 vec![set]
126 }
127 Out::And(children) => {
128 let mut result = vec![HashSet::new()];
129 for child in children {
130 result = cross_product(&result, &enumerate_branches(child));
131 }
132 result
133 }
134 Out::Xor(children) => {
135 let mut result = Vec::new();
136 for child in children {
137 result.extend(enumerate_branches(child));
138 }
139 result
140 }
141 Out::Timeout { child, .. } => enumerate_branches(child),
142 }
143}
144
/// Pairwise union of two lists of sets.
///
/// For every `x` in `a` (outer loop) and every `y` in `b` (inner loop), pushes
/// `x ∪ y`. The result therefore has exactly `a.len() * b.len()` entries and is
/// empty if either input is empty.
///
/// Generic over the element type so it can serve any hashable item; callers in
/// this module use `PlaceRef`.
fn cross_product<T: Clone + Eq + std::hash::Hash>(
    a: &[HashSet<T>],
    b: &[HashSet<T>],
) -> Vec<HashSet<T>> {
    // The output size is known exactly; reserve once instead of regrowing.
    let mut result = Vec::with_capacity(a.len() * b.len());
    for set_a in a {
        for set_b in b {
            let mut merged = set_a.clone();
            merged.extend(set_b.iter().cloned());
            result.push(merged);
        }
    }
    result
}
156
157pub fn find_timeout(out: &Out) -> Option<(u64, &Out)> {
159 match out {
160 Out::Timeout { after_ms, child } => Some((*after_ms, child)),
161 Out::And(children) | Out::Xor(children) => {
162 for child in children {
163 if let Some(found) = find_timeout(child) {
164 return Some(found);
165 }
166 }
167 None
168 }
169 Out::Place(_) | Out::ForwardInput { .. } => None,
170 }
171}
172
173pub fn find_forward_inputs(out: &Out) -> Vec<(PlaceRef, PlaceRef)> {
175 match out {
176 Out::ForwardInput { from, to } => vec![(from.clone(), to.clone())],
177 Out::And(children) | Out::Xor(children) => {
178 children.iter().flat_map(find_forward_inputs).collect()
179 }
180 Out::Timeout { child, .. } => find_forward_inputs(child),
181 Out::Place(_) => vec![],
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::place::Place;

    #[test]
    fn out_place_creates_leaf() {
        let p = Place::<i32>::new("test");
        let out = out_place(&p);
        assert!(matches!(out, Out::Place(ref r) if r.name() == "test"));
    }

    #[test]
    fn all_places_from_and() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = and(vec![out_place(&a), out_place(&b)]);
        let places = all_places(&out);
        assert_eq!(places.len(), 2);
        assert!(places.iter().any(|p| p.name() == "a"));
        assert!(places.iter().any(|p| p.name() == "b"));
    }

    #[test]
    fn all_places_from_xor() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = xor(vec![out_place(&a), out_place(&b)]);
        let places = all_places(&out);
        assert_eq!(places.len(), 2);
    }

    #[test]
    fn all_places_sees_through_timeout() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = xor(vec![out_place(&a), timeout_place(100, &b)]);
        assert_eq!(all_places(&out).len(), 2);
    }

    #[test]
    fn all_places_of_forward_input_is_only_target() {
        let from = Place::<i32>::new("from");
        let to = Place::<i32>::new("to");
        let places = all_places(&forward_input(&from, &to));
        assert_eq!(places.len(), 1);
        assert!(places.iter().all(|p| p.name() == "to"));
    }

    #[test]
    fn enumerate_branches_place() {
        let a = Place::<i32>::new("a");
        let branches = enumerate_branches(&out_place(&a));
        assert_eq!(branches.len(), 1);
        assert_eq!(branches[0].len(), 1);
    }

    #[test]
    fn enumerate_branches_and() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = and(vec![out_place(&a), out_place(&b)]);
        let branches = enumerate_branches(&out);
        assert_eq!(branches.len(), 1);
        assert_eq!(branches[0].len(), 2);
    }

    #[test]
    fn enumerate_branches_xor() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = xor(vec![out_place(&a), out_place(&b)]);
        let branches = enumerate_branches(&out);
        assert_eq!(branches.len(), 2);
        assert!(branches.iter().all(|branch| branch.len() == 1));
    }

    #[test]
    fn enumerate_branches_and_of_xors() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let c = Place::<i32>::new("c");
        let d = Place::<i32>::new("d");
        let out = and(vec![
            xor(vec![out_place(&a), out_place(&b)]),
            xor(vec![out_place(&c), out_place(&d)]),
        ]);
        let branches = enumerate_branches(&out);
        assert_eq!(branches.len(), 4);
        // Each branch picks one place from each XOR.
        assert!(branches.iter().all(|branch| branch.len() == 2));
    }

    #[test]
    fn find_timeout_present() {
        let p = Place::<i32>::new("timeout");
        let out = timeout_place(5000, &p);
        let (after_ms, child) = find_timeout(&out).expect("timeout should be found");
        assert_eq!(after_ms, 5000);
        assert!(matches!(child, Out::Place(r) if r.name() == "timeout"));
    }

    #[test]
    fn find_timeout_nested_in_and() {
        let a = Place::<i32>::new("a");
        let b = Place::<i32>::new("b");
        let out = and(vec![out_place(&a), timeout_place(250, &b)]);
        let (after_ms, _) = find_timeout(&out).expect("nested timeout should be found");
        assert_eq!(after_ms, 250);
    }

    #[test]
    fn find_timeout_absent() {
        let p = Place::<i32>::new("a");
        assert!(find_timeout(&out_place(&p)).is_none());
    }

    #[test]
    #[should_panic(expected = "AND requires at least 1 child")]
    fn and_empty_panics() {
        and(vec![]);
    }

    #[test]
    #[should_panic(expected = "XOR requires at least 2 children")]
    fn xor_one_panics() {
        let p = Place::<i32>::new("a");
        xor(vec![out_place(&p)]);
    }

    #[test]
    #[should_panic(expected = "Timeout must be positive")]
    fn timeout_zero_panics() {
        let p = Place::<i32>::new("a");
        timeout_place(0, &p);
    }

    #[test]
    fn forward_input_spec() {
        let from = Place::<i32>::new("from");
        let to = Place::<i32>::new("to");
        let out = forward_input(&from, &to);
        let fis = find_forward_inputs(&out);
        assert_eq!(fis.len(), 1);
        assert_eq!(fis[0].0.name(), "from");
        assert_eq!(fis[0].1.name(), "to");
    }

    #[test]
    fn find_forward_inputs_nested() {
        let from = Place::<i32>::new("from");
        let to = Place::<i32>::new("to");
        let other = Place::<i32>::new("other");
        let out = timeout(10, xor(vec![forward_input(&from, &to), out_place(&other)]));
        let fis = find_forward_inputs(&out);
        assert_eq!(fis.len(), 1);
        assert_eq!(fis[0].0.name(), "from");
        assert_eq!(fis[0].1.name(), "to");
    }
}