autd3_driver/datagram/gain/
boxed.rs

1use std::mem::MaybeUninit;
2
3use super::{Gain, GainCalculatorGenerator};
4
5pub use crate::geometry::{Device, Geometry};
6
7use autd3_core::{derive::*, gain::TransducerMask};
8
9pub trait DGainCalculatorGenerator<'a> {
10    #[must_use]
11    fn dyn_generate(&mut self, device: &'a Device) -> Box<dyn GainCalculator<'a>>;
12}
13
14pub struct DynGainCalculatorGenerator<'a> {
15    g: Box<dyn DGainCalculatorGenerator<'a>>,
16}
17
18impl<'a> GainCalculatorGenerator<'a> for DynGainCalculatorGenerator<'a> {
19    type Calculator = Box<dyn GainCalculator<'a>>;
20
21    fn generate(&mut self, device: &'a Device) -> Box<dyn GainCalculator<'a>> {
22        self.g.dyn_generate(device)
23    }
24}
25
26impl<
27    'a,
28    Calculator: GainCalculator<'a> + 'static,
29    G: GainCalculatorGenerator<'a, Calculator = Calculator>,
30> DGainCalculatorGenerator<'a> for G
31{
32    fn dyn_generate(&mut self, device: &'a Device) -> Box<dyn GainCalculator<'a>> {
33        Box::new(GainCalculatorGenerator::generate(self, device))
34    }
35}
36
37/// A dyn-compatible version of [`Gain`].
38trait DGain<'a> {
39    fn dyn_init(
40        &mut self,
41        geometry: &'a Geometry,
42        env: &Environment,
43        filter: &TransducerMask,
44    ) -> Result<Box<dyn DGainCalculatorGenerator<'a>>, GainError>;
45}
46
47impl<'a, G: DGainCalculatorGenerator<'a> + 'static, T: Gain<'a, G = G>> DGain<'a>
48    for MaybeUninit<T>
49{
50    fn dyn_init(
51        &mut self,
52        geometry: &'a Geometry,
53        env: &Environment,
54        filter: &TransducerMask,
55    ) -> Result<Box<dyn DGainCalculatorGenerator<'a>>, GainError> {
56        let mut tmp: MaybeUninit<T> = MaybeUninit::uninit();
57        std::mem::swap(&mut tmp, self);
58        // SAFETY: This function is called only once from `BoxedGain::init`.
59        let g = unsafe { tmp.assume_init() };
60        Ok(Box::new(g.init(geometry, env, filter)?) as _)
61    }
62}
63
64/// Boxed [`Gain`].
65///
66/// Because [`Gain`] traits can have different associated types, it cannot simply be wrapped in a [`Box`] like `Box<dyn Gain>`.
67/// [`BoxedGain`] provides the ability to wrap any [`Gain`] in a common type.
68#[derive(Gain)]
69pub struct BoxedGain<'geo> {
70    g: Box<dyn DGain<'geo>>,
71}
72
73impl<'a> BoxedGain<'a> {
74    /// Creates a new [`BoxedGain`].
75    #[must_use]
76    pub fn new<
77        C: GainCalculator<'a> + 'static,
78        GG: GainCalculatorGenerator<'a, Calculator = C> + 'static,
79        G: Gain<'a, G = GG> + 'static,
80    >(
81        g: G,
82    ) -> Self {
83        Self {
84            g: Box::new(MaybeUninit::new(g)),
85        }
86    }
87}
88
89impl<'a> Gain<'a> for BoxedGain<'a> {
90    type G = DynGainCalculatorGenerator<'a>;
91
92    fn init(
93        self,
94        geometry: &'a Geometry,
95        env: &Environment,
96        filter: &TransducerMask,
97    ) -> Result<Self::G, GainError> {
98        let Self { mut g, .. } = self;
99        Ok(DynGainCalculatorGenerator {
100            g: g.dyn_init(geometry, env, filter)?,
101        })
102    }
103}
104
#[cfg(test)]
pub mod tests {
    use super::*;

    use std::collections::HashMap;

    use crate::datagram::gain::tests::TestGain;

    use autd3_core::{
        firmware::Drive,
        geometry::{Point3, UnitQuaternion},
    };

    const NUM_TRANSDUCERS: usize = 2;

    /// A boxed gain must produce exactly the drives of the gain it wraps, for every device.
    #[rstest::rstest]
    #[case::serial(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: Intensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: Intensity(0x02) }; NUM_TRANSDUCERS])
        ].into_iter().collect(),
        2)]
    #[case::parallel(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: Intensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: Intensity(0x02) }; NUM_TRANSDUCERS]),
            (2, vec![Drive { phase: Phase(0x03), intensity: Intensity(0x03) }; NUM_TRANSDUCERS]),
            (3, vec![Drive { phase: Phase(0x04), intensity: Intensity(0x04) }; NUM_TRANSDUCERS]),
            (4, vec![Drive { phase: Phase(0x05), intensity: Intensity(0x05) }; NUM_TRANSDUCERS]),
        ].into_iter().collect(),
        5)]
    fn new(
        #[case] expect: HashMap<usize, Vec<Drive>>,
        #[case] n: u16,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Each device has `NUM_TRANSDUCERS` transducers, all placed at the origin.
        let make_device = |_: u16| {
            Device::new(
                UnitQuaternion::identity(),
                std::iter::repeat_with(|| Transducer::new(Point3::origin()))
                    .take(NUM_TRANSDUCERS)
                    .collect(),
            )
        };
        let geometry = Geometry::new((0..n).map(make_device).collect());

        let gain = BoxedGain::new(TestGain::new(
            |dev| {
                move |_| Drive {
                    phase: Phase(dev.idx() as u8 + 1),
                    intensity: Intensity(dev.idx() as u8 + 1),
                }
            },
            &geometry,
        ));

        let mut generator =
            gain.init(&geometry, &Environment::new(), &TransducerMask::AllEnabled)?;

        let actual: HashMap<usize, Vec<Drive>> = geometry
            .iter()
            .map(|dev| {
                let calc = GainCalculatorGenerator::generate(&mut generator, dev);
                let drives: Vec<Drive> = dev.iter().map(|tr| calc.calc(tr)).collect();
                (dev.idx(), drives)
            })
            .collect();
        assert_eq!(expect, actual);

        Ok(())
    }
}