autd3_driver/datagram/gain/
boxed.rs

1use std::mem::MaybeUninit;
2
3use super::{Gain, GainCalculatorGenerator};
4
5pub use crate::geometry::{Device, Geometry};
6
7use autd3_core::{derive::*, gain::TransducerFilter};
8
9pub trait DGainCalculatorGenerator {
10    #[must_use]
11    fn dyn_generate(&mut self, device: &Device) -> Box<dyn GainCalculator>;
12}
13
14pub struct DynGainCalculatorGenerator {
15    g: Box<dyn DGainCalculatorGenerator>,
16}
17
18impl GainCalculatorGenerator for DynGainCalculatorGenerator {
19    type Calculator = Box<dyn GainCalculator>;
20
21    fn generate(&mut self, device: &Device) -> Box<dyn GainCalculator> {
22        self.g.dyn_generate(device)
23    }
24}
25
26impl<Calculator: GainCalculator + 'static, G: GainCalculatorGenerator<Calculator = Calculator>>
27    DGainCalculatorGenerator for G
28{
29    fn dyn_generate(&mut self, device: &Device) -> Box<dyn GainCalculator> {
30        Box::new(GainCalculatorGenerator::generate(self, device))
31    }
32}
33
34/// A dyn-compatible version of [`Gain`].
35trait DGain {
36    fn dyn_init(
37        &mut self,
38        geometry: &Geometry,
39        env: &Environment,
40        filter: &TransducerFilter,
41    ) -> Result<Box<dyn DGainCalculatorGenerator>, GainError>;
42    fn dyn_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
43}
44
45impl<G: DGainCalculatorGenerator + 'static, T: Gain<G = G>> DGain for MaybeUninit<T> {
46    fn dyn_init(
47        &mut self,
48        geometry: &Geometry,
49        env: &Environment,
50        filter: &TransducerFilter,
51    ) -> Result<Box<dyn DGainCalculatorGenerator>, GainError> {
52        let mut tmp: MaybeUninit<T> = MaybeUninit::uninit();
53        std::mem::swap(&mut tmp, self);
54        // SAFETY: This function is called only once from `Gain::init`.
55        let g = unsafe { tmp.assume_init() };
56        Ok(Box::new(g.init(geometry, env, filter)?) as _)
57    }
58
59    fn dyn_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        // SAFETY: This function is never called after `dyn_init`.
61        unsafe { self.assume_init_ref() }.fmt(f)
62    }
63}
64
65/// Boxed [`Gain`].
66///
67/// Because [`Gain`] traits can have different associated types, it cannot simply be wrapped in a [`Box`] like `Box<dyn Gain>`.
68/// [`BoxedGain`] provides the ability to wrap any [`Gain`] in a common type.
69#[derive(Gain)]
70pub struct BoxedGain {
71    g: Box<dyn DGain>,
72}
73
74impl BoxedGain {
75    /// Creates a new [`BoxedGain`].
76    #[must_use]
77    pub fn new<G: Gain + 'static>(g: G) -> Self {
78        Self {
79            g: Box::new(MaybeUninit::new(g)),
80        }
81    }
82}
83
84impl std::fmt::Debug for BoxedGain {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        self.g.as_ref().dyn_fmt(f)
87    }
88}
89
90impl Gain for BoxedGain {
91    type G = DynGainCalculatorGenerator;
92
93    fn init(
94        self,
95        geometry: &Geometry,
96        env: &Environment,
97        filter: &TransducerFilter,
98    ) -> Result<Self::G, GainError> {
99        let Self { mut g, .. } = self;
100        Ok(DynGainCalculatorGenerator {
101            g: g.dyn_init(geometry, env, filter)?,
102        })
103    }
104}
105
// Unit tests for `BoxedGain`: drive generation per device and `Debug` forwarding.
#[cfg(test)]
pub mod tests {
    use super::*;

    use crate::datagram::gain::tests::TestGain;

    use autd3_core::{
        gain::Drive,
        geometry::{Point3, UnitQuaternion},
    };

    // Number of transducers placed on every test device.
    const NUM_TRANSDUCERS: usize = 2;

    // Each device `i` is expected to produce drives whose phase and intensity are `i + 1`.
    #[rstest::rstest]
    #[test]
    #[case::serial(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: Intensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: Intensity(0x02) }; NUM_TRANSDUCERS])
        ].into_iter().collect(),
        2)]
    #[case::parallel(
        [
            (0, vec![Drive { phase: Phase(0x01), intensity: Intensity(0x01) }; NUM_TRANSDUCERS]),
            (1, vec![Drive { phase: Phase(0x02), intensity: Intensity(0x02) }; NUM_TRANSDUCERS]),
            (2, vec![Drive { phase: Phase(0x03), intensity: Intensity(0x03) }; NUM_TRANSDUCERS]),
            (3, vec![Drive { phase: Phase(0x04), intensity: Intensity(0x04) }; NUM_TRANSDUCERS]),
            (4, vec![Drive { phase: Phase(0x05), intensity: Intensity(0x05) }; NUM_TRANSDUCERS]),
        ].into_iter().collect(),
        5)]
    fn boxed_gain_unsafe(
        #[case] expect: HashMap<usize, Vec<Drive>>,
        #[case] n: u16,
    ) -> anyhow::Result<()> {
        // `n` devices, each with `NUM_TRANSDUCERS` transducers at the origin.
        let geometry = Geometry::new(
            (0..n)
                .map(|_| {
                    let transducers = (0..NUM_TRANSDUCERS)
                        .map(|_| Transducer::new(Point3::origin()))
                        .collect();
                    Device::new(UnitQuaternion::identity(), transducers)
                })
                .collect(),
        );

        let boxed = BoxedGain::new(TestGain::new(
            |dev| {
                let value = dev.idx() as u8 + 1;
                move |_| Drive {
                    phase: Phase(value),
                    intensity: Intensity(value),
                }
            },
            &geometry,
        ));

        let mut generator = boxed.init(
            &geometry,
            &Environment::new(),
            &TransducerFilter::all_enabled(),
        )?;

        let actual: HashMap<usize, Vec<Drive>> = geometry
            .iter()
            .map(|dev| {
                let calculator = GainCalculatorGenerator::generate(&mut generator, dev);
                (
                    dev.idx(),
                    dev.iter().map(|tr| calculator.calc(tr)).collect(),
                )
            })
            .collect();
        assert_eq!(expect, actual);

        Ok(())
    }

    // The boxed gain must format exactly like the gain it wraps.
    #[test]
    fn boxed_gain_dbg_unsafe() {
        let gain = TestGain::null();
        let expected = format!("{gain:?}");
        assert_eq!(expected, format!("{:?}", BoxedGain::new(gain)));
    }
}