autd3_driver/datagram/gain/boxed.rs

1use std::{collections::HashMap, mem::MaybeUninit};
2
3use super::{Gain, GainCalculatorGenerator};
4
5pub use crate::geometry::{Device, Geometry};
6
7use autd3_core::derive::*;
8
9#[cfg(not(feature = "lightweight"))]
10pub trait DGainCalculatorGenerator {
11    #[must_use]
12    fn dyn_generate(&mut self, device: &Device) -> Box<dyn GainCalculator>;
13}
14#[cfg(feature = "lightweight")]
15pub trait DGainCalculatorGenerator: Send + Sync {
16    #[must_use]
17    fn dyn_generate(&mut self, device: &Device) -> Box<dyn GainCalculator>;
18}
19
20pub struct DynGainCalculatorGenerator {
21    g: Box<dyn DGainCalculatorGenerator>,
22}
23
24impl GainCalculatorGenerator for DynGainCalculatorGenerator {
25    type Calculator = Box<dyn GainCalculator>;
26
27    fn generate(&mut self, device: &Device) -> Box<dyn GainCalculator> {
28        self.g.dyn_generate(device)
29    }
30}
31
32impl<
33    Calculator: GainCalculator + 'static,
34    #[cfg(not(feature = "lightweight"))] G: GainCalculatorGenerator<Calculator = Calculator>,
35    #[cfg(feature = "lightweight")] G: GainCalculatorGenerator<Calculator = Calculator> + Send + Sync,
36> DGainCalculatorGenerator for G
37{
38    fn dyn_generate(&mut self, device: &Device) -> Box<dyn GainCalculator> {
39        Box::new(GainCalculatorGenerator::generate(self, device))
40    }
41}
42
43/// A dyn-compatible version of [`Gain`].
44trait DGain {
45    fn dyn_init(
46        &mut self,
47        geometry: &Geometry,
48        filter: Option<&HashMap<usize, BitVec>>,
49        parallel: bool,
50    ) -> Result<Box<dyn DGainCalculatorGenerator>, GainError>;
51    fn dyn_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
52}
53
54impl<
55    G: DGainCalculatorGenerator + 'static,
56    #[cfg(not(feature = "lightweight"))] T: Gain<G = G>,
57    #[cfg(feature = "lightweight")] T: Gain<G = G> + Send + Sync,
58> DGain for MaybeUninit<T>
59{
60    fn dyn_init(
61        &mut self,
62        geometry: &Geometry,
63        filter: Option<&HashMap<usize, BitVec>>,
64        parallel: bool,
65    ) -> Result<Box<dyn DGainCalculatorGenerator>, GainError> {
66        let mut tmp: MaybeUninit<T> = MaybeUninit::uninit();
67        std::mem::swap(&mut tmp, self);
68        // SAFETY: This function is called only once from `Gain::init`.
69        let g = unsafe { tmp.assume_init() };
70        Ok(Box::new(g.init_full(geometry, filter, parallel)?) as _)
71    }
72
73    fn dyn_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        // SAFETY: This function is never called after `dyn_init`.
75        unsafe { self.assume_init_ref() }.fmt(f)
76    }
77}
78
79/// Boxed [`Gain`].
80///
81/// Because [`Gain`] traits can have different associated types, it cannot simply be wrapped in a [`Box`] like `Box<dyn Gain>`.
82/// [`BoxedGain`] provides the ability to wrap any [`Gain`] in a common type.
83#[derive(Gain)]
84pub struct BoxedGain {
85    g: Box<dyn DGain>,
86}
87
88#[cfg(feature = "lightweight")]
89unsafe impl Send for BoxedGain {}
90#[cfg(feature = "lightweight")]
91unsafe impl Sync for BoxedGain {}
92
93impl std::fmt::Debug for BoxedGain {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        self.g.as_ref().dyn_fmt(f)
96    }
97}
98
99impl Gain for BoxedGain {
100    type G = DynGainCalculatorGenerator;
101
102    fn init_full(
103        self,
104        geometry: &Geometry,
105        filter: Option<&HashMap<usize, BitVec>>,
106        parallel: bool,
107    ) -> Result<Self::G, GainError> {
108        let Self { mut g, .. } = self;
109        Ok(DynGainCalculatorGenerator {
110            g: g.dyn_init(geometry, filter, parallel)?,
111        })
112    }
113
114    // GRCOV_EXCL_START
115    fn init(self) -> Result<Self::G, GainError> {
116        unimplemented!()
117    }
118    // GRCOV_EXCL_STOP
119}
120
121/// Trait to convert [`Gain`] to [`BoxedGain`].
122pub trait IntoBoxedGain {
123    /// Convert [`Gain`] to [`BoxedGain`].
124    #[must_use]
125    fn into_boxed(self) -> BoxedGain;
126}
127
128impl<
129    #[cfg(feature = "lightweight")] GG: GainCalculatorGenerator + Send + Sync + 'static,
130    #[cfg(not(feature = "lightweight"))] G: Gain + 'static,
131    #[cfg(feature = "lightweight")] G: Gain<G = GG> + Send + Sync + 'static,
132> IntoBoxedGain for G
133{
134    fn into_boxed(self) -> BoxedGain {
135        BoxedGain {
136            g: Box::new(MaybeUninit::new(self)),
137        }
138    }
139}
140
#[cfg(test)]
pub mod tests {
    use autd3_core::gain::Drive;

    use super::*;
    use crate::datagram::gain::tests::TestGain;

    use crate::firmware::fpga::{EmitIntensity, Phase};

    // Transducers per device in every geometry built by these tests.
    const NUM_TRANSDUCERS: usize = 2;

    // Boxes a `TestGain` and initializes it through `BoxedGain::init_full`. Every transducer
    // of device `i` should get `Phase(i + 1)` / `EmitIntensity(i + 1)`, and the per-device
    // drives are checked against `expect`.
    //
    // Parameters:
    // - `expect` maps each device index to its expected drives.
    // - `enabled[i]` is written to device `i`'s `enable` flag.
    // - `n` is passed to `create_geometry` (presumably the device count; TODO confirm).
    //
    // The `enabled` case expects no entry for the disabled device, so it relies on
    // `geometry.devices()` skipping disabled devices.
    // NOTE(review): the `parallel` case differs from `serial` only in device count.
    // `init_full` is always called with `parallel = false`; confirm this is intended.
    #[rstest::rstest]
    #[test]
    #[case::serial(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: EmitIntensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: EmitIntensity(0x02) }; NUM_TRANSDUCERS])
        ].into_iter().collect(),
        vec![true; 2],
        2)]
    #[case::parallel(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: EmitIntensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: EmitIntensity(0x02) }; NUM_TRANSDUCERS]),
            (2, vec![Drive { phase: Phase(0x03), intensity: EmitIntensity(0x03) }; NUM_TRANSDUCERS]),
            (3, vec![Drive { phase: Phase(0x04), intensity: EmitIntensity(0x04) }; NUM_TRANSDUCERS]),
            (4, vec![Drive { phase: Phase(0x05), intensity: EmitIntensity(0x05) }; NUM_TRANSDUCERS]),
        ].into_iter().collect(),
        vec![true; 5],
        5)]
    #[case::enabled(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: EmitIntensity(0x01) }; NUM_TRANSDUCERS]),
        ].into_iter().collect(),
        vec![true, false],
        2)]
    fn boxed_gain_unsafe(
        #[case] expect: HashMap<usize, Vec<Drive>>,
        #[case] enabled: Vec<bool>,
        #[case] n: u16,
    ) -> anyhow::Result<()> {
        use crate::datagram::tests::create_geometry;

        // Apply the per-device enable flags before building the gain.
        let mut geometry = create_geometry(n, NUM_TRANSDUCERS as _);
        geometry
            .iter_mut()
            .zip(enabled.iter())
            .for_each(|(dev, &e)| dev.enable = e);
        // Every transducer of a device gets a drive derived from that device's index;
        // `into_boxed` then erases the concrete `TestGain` type.
        let g = TestGain::new(
            |dev| {
                let dev_idx = dev.idx();
                move |_| Drive {
                    phase: Phase(dev_idx as u8 + 1),
                    intensity: EmitIntensity(dev_idx as u8 + 1),
                }
            },
            &geometry,
        )
        .into_boxed();

        // Initialize without a filter and serially, then evaluate every transducer of every
        // device yielded by `geometry.devices()`.
        let mut f = g.init_full(&geometry, None, false)?;
        assert_eq!(
            expect,
            geometry
                .devices()
                .map(|dev| {
                    let f = GainCalculatorGenerator::generate(&mut f, dev);
                    (dev.idx(), dev.iter().map(|tr| f.calc(tr)).collect())
                })
                .collect()
        );

        Ok(())
    }

    // The `Debug` output of a boxed gain must be identical to that of the wrapped gain.
    #[test]
    fn boxed_gain_dbg_unsafe() {
        let g = TestGain::null();
        assert_eq!(format!("{:?}", g), format!("{:?}", g.into_boxed()));
    }
}