1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
pub use crate::{
    derive::*,
    error::AUTDInternalError,
    firmware::fpga::{Drive, Segment},
    firmware::operation::{GainOp, NullOp},
    geometry::{Device, Geometry, Transducer},
};
pub use autd3_derive::Gain;

use super::GainCalcResult;

/// [`Gain`] that post-processes the drives produced by another gain.
///
/// When computed, the wrapped `gain` is calculated first. Then, for a given [`Device`],
/// `f` is called to build a device-specific transform of type `FT`. That transform receives
/// each [`Transducer`] together with the [`Drive`] the wrapped gain produced for it, and
/// returns the drive to use instead.
///
/// NOTE(review): the `'static` bounds on `FT` and `F` are presumably required by the boxed
/// closures returned through `GainCalcResult` — confirm before relaxing them.
#[derive(Gain)]
// NOTE(review): helper attribute read by `#[derive(Gain)]`; presumably it stops the derive
// from emitting a `with_transform` impl for `Transform` itself — confirm in `autd3_derive`.
#[no_gain_transform]
pub struct Transform<
    G: Gain,
    FT: Fn(&Transducer, Drive) -> Drive + Send + Sync + 'static,
    F: Fn(&Device) -> FT + 'static,
> {
    // Wrapped gain whose drives are fed into the transform.
    gain: G,
    // Called with a device to produce that device's per-transducer transform.
    f: F,
}

/// Conversion into a [`Transform`] that post-processes the drives of gain `G`.
pub trait IntoTransform<G: Gain> {
    /// Builds a [`Transform`] over gain `G` whose drives are routed through `f`.
    ///
    /// `f` is given a [`Device`] and returns a function that maps a [`Transducer`] and the
    /// [`Drive`] computed by the wrapped gain to the drive that should be used instead.
    fn with_transform<FT, F>(self, f: F) -> Transform<G, FT, F>
    where
        FT: Fn(&Transducer, Drive) -> Drive + Send + Sync,
        F: Fn(&Device) -> FT;
}

impl<G: Gain, FT: Fn(&Transducer, Drive) -> Drive + Send + Sync, F: Fn(&Device) -> FT>
    Transform<G, FT, F>
{
    #[doc(hidden)]
    pub fn new(gain: G, f: F) -> Self {
        Self { gain, f }
    }
}

impl<G, FT, F> Gain for Transform<G, FT, F>
where
    G: Gain,
    FT: Fn(&Transducer, Drive) -> Drive + Send + Sync + 'static,
    F: Fn(&Device) -> FT + 'static,
{
    /// Calculates the wrapped gain and routes each of its drives through the
    /// per-device transform produced by `self.f`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the wrapped gain's `calc`.
    fn calc(&self, geometry: &Geometry) -> GainCalcResult {
        let source = self.gain.calc(geometry)?;
        // Borrow only the factory so the returned closure does not capture `self` as a whole.
        let make_transform = &self.f;
        Ok(Box::new(move |device| {
            // The transform is built before the source is queried for this device;
            // both are user code, so this call order is kept deliberately.
            let transform = make_transform(device);
            let source_drive = source(device);
            Box::new(move |tr| transform(tr, source_drive(tr)))
        }))
    }
}

#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::{super::tests::TestGain, *};

    use crate::{defined::FREQ_40K, geometry::tests::create_geometry};

    /// A transform that ignores its inputs must replace every drive on every
    /// device with its own constant output.
    #[test]
    fn test() {
        let geometry = create_geometry(1, 249, FREQ_40K);

        let mut rng = rand::thread_rng();
        let d: Drive = Drive::new(Phase::new(rng.gen()), EmitIntensity::new(rng.gen()));

        let gain = TestGain::null(&geometry).with_transform(move |_| move |_, _| d);
        // Calculate once: the result does not depend on which device is inspected.
        let calc = gain.calc(&geometry).unwrap();

        assert_eq!(
            geometry
                .devices()
                .map(|dev| (dev.idx(), vec![d; dev.num_transducers()]))
                .collect::<HashMap<_, _>>(),
            geometry
                .devices()
                .map(|dev| (dev.idx(), dev.iter().map(calc(dev)).collect()))
                .collect()
        );
    }

    /// An identity transform must observe exactly the drives produced by the
    /// wrapped gain, i.e. `Transform` forwards the source drive unchanged.
    #[test]
    fn test_forwards_source_drive() {
        let geometry = create_geometry(1, 249, FREQ_40K);

        let source = TestGain::null(&geometry);
        let source_calc = source.calc(&geometry).unwrap();

        let gain = TestGain::null(&geometry).with_transform(|_| |_, d| d);
        let calc = gain.calc(&geometry).unwrap();

        geometry.devices().for_each(|dev| {
            let expected: Vec<Drive> = dev.iter().map(source_calc(dev)).collect();
            let actual: Vec<Drive> = dev.iter().map(calc(dev)).collect();
            assert_eq!(expected, actual);
        });
    }
}