autd3_driver/datagram/
group.rs

1use std::{collections::HashMap, fmt::Debug, hash::Hash, time::Duration};
2
3use crate::datagram::*;
4
5use autd3_core::{
6    datagram::DatagramOption,
7    derive::{Inspectable, InspectionResult},
8    gain::BitVec,
9};
10use derive_more::Debug as DeriveDebug;
11use itertools::Itertools;
12
13/// [`Datagram`] that divide the devices into groups by given function and send different data to each group.
14///
15/// If the key is `None`, nothing is done for the devices corresponding to the key.
16///
17/// # Example
18///
19/// ```
20/// use std::collections::HashMap;
21/// # use autd3_driver::datagram::*;
22/// use autd3_driver::datagram::IntoBoxedDatagram;
23///
24/// Group {
25///     key_map: |dev| match dev.idx() {
26///         0 => Some("clear"),
27///         2 => Some("force-fan"),
28///         _ => None,
29///     },
30///     datagram_map: HashMap::from([
31///         ("clear", Clear::default().into_boxed()),
32///         ("force-fan", ForceFan { f: |_| false }.into_boxed()),
33///     ]),
34/// };
35/// ```
36#[derive(Default, DeriveDebug)]
37pub struct Group<K, D, F>
38where
39    K: Hash + Eq + Debug,
40    D: Datagram,
41    F: Fn(&Device) -> Option<K>,
42    D::G: OperationGenerator,
43    AUTDDriverError: From<<D as Datagram>::Error>,
44{
45    /// Mapping function from device to group key.
46    #[debug(ignore)]
47    pub key_map: F,
48    /// Map from group key to [`Datagram`].
49    #[debug(ignore)]
50    pub datagram_map: HashMap<K, D>,
51}
52
53impl<K, D, F> Group<K, D, F>
54where
55    K: Hash + Eq + Debug,
56    D: Datagram,
57    F: Fn(&Device) -> Option<K>,
58    D::G: OperationGenerator,
59    AUTDDriverError: From<<D as Datagram>::Error>,
60{
61    /// Creates a new [`Group`].
62    #[must_use]
63    pub const fn new(key_map: F, datagram_map: HashMap<K, D>) -> Self {
64        Self {
65            key_map,
66            datagram_map,
67        }
68    }
69
70    fn generate_filter(key_map: F, geometry: &Geometry) -> HashMap<K, BitVec> {
71        let num_devices = geometry.iter().len();
72        let mut filters: HashMap<K, BitVec> = HashMap::new();
73        geometry.devices().for_each(|dev| {
74            if let Some(key) = key_map(dev) {
75                if let Some(v) = filters.get_mut(&key) {
76                    v.set(dev.idx(), true);
77                } else {
78                    filters.insert(key, BitVec::from_fn(num_devices, |i| i == dev.idx()));
79                }
80            }
81        });
82        filters
83    }
84}
85
86pub struct GroupOpGenerator<D>
87where
88    D: Datagram,
89    D::G: OperationGenerator,
90{
91    #[allow(clippy::type_complexity)]
92    operations: Vec<
93        Option<(
94            <D::G as OperationGenerator>::O1,
95            <D::G as OperationGenerator>::O2,
96        )>,
97    >,
98}
99
100impl<D> OperationGenerator for GroupOpGenerator<D>
101where
102    D: Datagram,
103    D::G: OperationGenerator,
104{
105    type O1 = <D::G as OperationGenerator>::O1;
106    type O2 = <D::G as OperationGenerator>::O2;
107
108    fn generate(&mut self, dev: &Device) -> Option<(Self::O1, Self::O2)> {
109        self.operations[dev.idx()].take()
110    }
111}
112
113impl<K, D, F> Datagram for Group<K, D, F>
114where
115    K: Hash + Eq + Debug,
116    D: Datagram,
117    F: Fn(&Device) -> Option<K>,
118    D::G: OperationGenerator,
119    AUTDDriverError: From<<D as Datagram>::Error>,
120{
121    type G = GroupOpGenerator<D>;
122    type Error = AUTDDriverError;
123
124    fn operation_generator(self, geometry: &mut Geometry) -> Result<Self::G, Self::Error> {
125        let Self {
126            key_map,
127            mut datagram_map,
128        } = self;
129
130        let filters = Self::generate_filter(key_map, geometry);
131
132        let enable_store = geometry.iter().map(|dev| dev.enable).collect::<Vec<_>>();
133
134        let mut operations: Vec<_> = geometry.iter().map(|_| None).collect();
135
136        filters
137            .into_iter()
138            .try_for_each(|(k, filter)| -> Result<(), AUTDDriverError> {
139                {
140                    let datagram = datagram_map
141                        .remove(&k)
142                        .ok_or(AUTDDriverError::UnknownKey(format!("{:?}", k)))?;
143
144                    // set enable flag for each device
145                    // This is not required for the operation except `Gain`s which cannot be calculated independently for each device, such as `autd3-gain-holo`.
146                    geometry.devices_mut().for_each(|dev| {
147                        dev.enable = filter[dev.idx()];
148                    });
149
150                    let mut generator = datagram
151                        .operation_generator(geometry)
152                        .map_err(AUTDDriverError::from)?;
153
154                    // restore enable flag
155                    geometry
156                        .iter_mut()
157                        .zip(enable_store.iter())
158                        .for_each(|(dev, &enable)| {
159                            dev.enable = enable;
160                        });
161
162                    operations
163                        .iter_mut()
164                        .zip(geometry.iter())
165                        .filter(|(_, dev)| dev.enable && filter[dev.idx()])
166                        .for_each(|(op, dev)| {
167                            tracing::debug!("Generate operation for device {}", dev.idx());
168                            *op = generator.generate(dev);
169                        });
170                    Ok(())
171                }
172            })?;
173
174        if !datagram_map.is_empty() {
175            return Err(AUTDDriverError::UnusedKey(
176                datagram_map.keys().map(|k| format!("{:?}", k)).join(", "),
177            ));
178        }
179
180        Ok(GroupOpGenerator { operations })
181    }
182
183    fn option(&self) -> DatagramOption {
184        self.datagram_map.values().map(|d| d.option()).fold(
185            DatagramOption {
186                timeout: Duration::ZERO,
187                parallel_threshold: usize::MAX,
188            },
189            DatagramOption::merge,
190        )
191    }
192}
193
194impl<K, D, F> Inspectable for Group<K, D, F>
195where
196    K: Hash + Eq + Debug,
197    D: Datagram + Inspectable,
198    F: Fn(&Device) -> Option<K>,
199    D::G: OperationGenerator,
200    AUTDDriverError: From<<D as Datagram>::Error>,
201{
202    type Result = D::Result;
203
204    fn inspect(
205        self,
206        geometry: &mut Geometry,
207    ) -> Result<InspectionResult<Self::Result>, AUTDDriverError> {
208        let Self {
209            key_map,
210            mut datagram_map,
211        } = self;
212
213        let filters = Self::generate_filter(key_map, geometry);
214
215        let enable_store = geometry.iter().map(|dev| dev.enable).collect::<Vec<_>>();
216
217        let results = filters
218            .into_iter()
219            .map(
220                |(k, filter)| -> Result<Vec<Option<Self::Result>>, AUTDDriverError> {
221                    {
222                        let datagram = datagram_map
223                            .remove(&k)
224                            .ok_or(AUTDDriverError::UnknownKey(format!("{:?}", k)))?;
225
226                        geometry.devices_mut().for_each(|dev| {
227                            dev.enable = filter[dev.idx()];
228                        });
229
230                        let r = datagram.inspect(geometry).map_err(AUTDDriverError::from)?;
231
232                        // restore enable flag
233                        geometry
234                            .iter_mut()
235                            .zip(enable_store.iter())
236                            .for_each(|(dev, &enable)| {
237                                dev.enable = enable;
238                            });
239
240                        Ok(r.result)
241                    }
242                },
243            )
244            .collect::<Result<Vec<_>, _>>()?;
245
246        Ok(InspectionResult {
247            result: results
248                .into_iter()
249                .reduce(|a, b| a.into_iter().zip(b).map(|(a, b)| a.or(b)).collect())
250                .unwrap(),
251        })
252    }
253}
254
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Mutex},
    };

    use super::*;
    use crate::datagram::tests::create_geometry;

    /// Generator that returns a `(NullOp, NullOp)` pair for every device it is asked about.
    pub struct NullOperationGenerator;

    impl OperationGenerator for NullOperationGenerator {
        type O1 = NullOp;
        type O2 = NullOp;

        fn generate(&mut self, _: &Device) -> Option<(Self::O1, Self::O2)> {
            Some((NullOp, NullOp))
        }
    }

    /// Only devices whose key has a datagram receive operations; the disabled device 0 and
    /// the unmapped device 2 receive none.
    #[test]
    fn group() -> anyhow::Result<()> {
        #[derive(Debug)]
        pub struct TestDatagram;

        impl Datagram for TestDatagram {
            type G = NullOperationGenerator;
            type Error = Infallible;

            fn operation_generator(self, _: &mut Geometry) -> Result<Self::G, Self::Error> {
                Ok(NullOperationGenerator)
            }
        }

        let mut geometry = create_geometry(3, 1);
        geometry[0].enable = false;

        let mut generator = Group::new(
            |dev| match dev.idx() {
                0 => Some(0), // GRCOV_EXCL_LINE
                1 => Some(1),
                _ => None,
            },
            HashMap::from([(1, TestDatagram)]),
        )
        .operation_generator(&mut geometry)?;

        let expected = [false, true, false];
        for (idx, &has_op) in expected.iter().enumerate() {
            assert_eq!(has_op, generator.generate(&geometry[idx]).is_some());
        }

        Ok(())
    }

    /// `Group::option` equals the merge of the grouped datagrams' options.
    #[test]
    fn group_option() -> anyhow::Result<()> {
        #[derive(Debug)]
        pub struct TestDatagram {
            pub option: DatagramOption,
        }

        impl Datagram for TestDatagram {
            type G = NullOperationGenerator;
            type Error = Infallible;

            // GRCOV_EXCL_START
            fn operation_generator(self, _: &mut Geometry) -> Result<Self::G, Self::Error> {
                Ok(NullOperationGenerator)
            }
            // GRCOV_EXCL_STOP

            fn option(&self) -> DatagramOption {
                self.option
            }
        }

        let option1 = DatagramOption {
            timeout: Duration::from_secs(1),
            parallel_threshold: 10,
        };
        let option2 = DatagramOption {
            timeout: Duration::from_secs(2),
            parallel_threshold: 5,
        };
        let expected = option1.merge(option2);

        let group = Group::new(
            |dev| Some(dev.idx()),
            HashMap::from([
                (0, TestDatagram { option: option1 }),
                (1, TestDatagram { option: option2 }),
            ]),
        );
        assert_eq!(expected, group.option());

        Ok(())
    }

    /// `key_map` is never consulted for disabled devices.
    #[test]
    fn test_group_only_for_enabled() -> anyhow::Result<()> {
        #[derive(Debug)]
        pub struct TestDatagram;

        impl Datagram for TestDatagram {
            type G = NullOperationGenerator;
            type Error = Infallible;

            fn operation_generator(self, _: &mut Geometry) -> Result<Self::G, Self::Error> {
                Ok(NullOperationGenerator)
            }
        }

        let mut geometry = create_geometry(2, 1);
        geometry[0].enable = false;

        let visited = Arc::new(Mutex::new([false; 2]));
        Group::new(
            |dev| {
                visited.lock().unwrap()[dev.idx()] = true;
                Some(())
            },
            HashMap::from([((), TestDatagram)]),
        )
        .operation_generator(&mut geometry)?;

        let snapshot = *visited.lock().unwrap();
        assert!(!snapshot[0]);
        assert!(snapshot[1]);

        Ok(())
    }

    /// While a group's datagram runs, only the devices of that group are enabled.
    #[test]
    fn test_group_only_for_set() -> anyhow::Result<()> {
        #[derive(Debug)]
        pub struct TestDatagram {
            pub test: Arc<Mutex<Vec<bool>>>,
        }

        impl Datagram for TestDatagram {
            type G = NullOperationGenerator;
            type Error = Infallible;

            fn operation_generator(self, geometry: &mut Geometry) -> Result<Self::G, Self::Error> {
                let mut flags = self.test.lock().unwrap();
                for dev in geometry.iter() {
                    flags[dev.idx()] = dev.enable;
                }
                Ok(NullOperationGenerator)
            }
        }

        let mut geometry = create_geometry(3, 1);

        let seen = Arc::new(Mutex::new(vec![false; 3]));
        Group::new(
            |dev| match dev.idx() {
                0 | 2 => Some(()),
                _ => None,
            },
            HashMap::from([((), TestDatagram { test: seen.clone() })]),
        )
        .operation_generator(&mut geometry)?;

        assert_eq!(vec![true, false, true], *seen.lock().unwrap());

        Ok(())
    }

    /// A key produced by `key_map` but missing from the map is reported as `UnknownKey`.
    #[test]
    fn unknown_key() -> anyhow::Result<()> {
        let mut geometry = create_geometry(2, 1);

        let err = Group::new(|dev| Some(dev.idx()), HashMap::from([(0, Clear {})]))
            .operation_generator(&mut geometry)
            .err();
        assert_eq!(Some(AUTDDriverError::UnknownKey("1".to_owned())), err);

        Ok(())
    }

    /// A map entry whose key is never produced by `key_map` is reported as `UnusedKey`.
    #[test]
    fn unused_key() -> anyhow::Result<()> {
        let mut geometry = create_geometry(2, 1);

        let err = Group::new(
            |dev| Some(dev.idx()),
            HashMap::from([(0, Clear {}), (1, Clear {}), (2, Clear {})]),
        )
        .operation_generator(&mut geometry)
        .err();
        assert_eq!(Some(AUTDDriverError::UnusedKey("2".to_owned())), err);

        Ok(())
    }

    /// Inspection yields results only for enabled devices that belong to a group.
    #[test]
    fn inspect() -> anyhow::Result<()> {
        #[derive(Debug)]
        pub struct TestDatagram {}

        impl Datagram for TestDatagram {
            type G = NullOperationGenerator;
            type Error = Infallible;

            // GRCOV_EXCL_START
            fn operation_generator(self, _: &mut Geometry) -> Result<Self::G, Self::Error> {
                Ok(NullOperationGenerator)
            }
            // GRCOV_EXCL_STOP
        }

        impl Inspectable for TestDatagram {
            type Result = ();

            fn inspect(
                self,
                geometry: &mut Geometry,
            ) -> Result<InspectionResult<Self::Result>, Self::Error> {
                Ok(InspectionResult::new(geometry, |_| ()))
            }
        }

        let mut geometry = create_geometry(4, 1);
        geometry[3].enable = false;

        let r = Group::new(
            |dev| match dev.idx() {
                1 => None,
                _ => Some(()),
            },
            HashMap::from([((), TestDatagram {})]),
        )
        .inspect(&mut geometry)?;

        let expected = [true, false, true, false];
        for (idx, &present) in expected.iter().enumerate() {
            assert_eq!(present, r[idx].is_some());
        }

        Ok(())
    }
}