1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
pub use crate::{
    common::{Drive, Segment},
    datagram::{DatagramS, Gain, GainCache, GainFilter, IntoGainCache, Modulation},
    error::AUTDInternalError,
    geometry::{Device, Geometry, Transducer},
    operation::{GainOp, NullOp, Operation},
};
pub use autd3_derive::Gain;

use std::collections::HashMap;

/// Gain to transform gain data
///
/// Wraps an inner gain together with a transform function. When the wrapper is
/// calculated, every drive produced by the inner gain is passed through the
/// function along with the device and transducer that the drive belongs to, and
/// the function's return value replaces the original drive.
#[derive(Gain)]
// NOTE(review): `no_gain_transform` presumably tells the derive not to generate
// transform support for this type itself — confirm against `autd3_derive`.
#[no_gain_transform]
pub struct Transform<G: Gain + 'static, F: Fn(&Device, &Transducer, Drive) -> Drive + 'static> {
    /// Inner gain whose calculated drives are transformed.
    gain: G,
    /// Function applied to each drive: `(device, transducer, original drive) -> new drive`.
    f: F,
}

/// Conversion of a gain into a [`Transform`] that post-processes its drive data.
pub trait IntoTransform<G: Gain> {
    /// Wrap this gain so that `f` is applied to its drive data.
    ///
    /// # Arguments
    ///
    /// * `f` - transform function. It receives the device, the transducer, and the original drive data (in that order) and returns the transformed drive.
    ///
    fn with_transform<F>(self, f: F) -> Transform<G, F>
    where
        F: Fn(&Device, &Transducer, Drive) -> Drive;
}

impl<G: Gain + 'static, F: Fn(&Device, &Transducer, Drive) -> Drive> Transform<G, F> {
    #[doc(hidden)]
    pub fn new(gain: G, f: F) -> Self {
        Self { gain, f }
    }
}

impl<G, F> Gain for Transform<G, F>
where
    G: Gain + 'static,
    F: Fn(&Device, &Transducer, Drive) -> Drive + 'static,
{
    /// Calculate the inner gain, then run every resulting drive through the
    /// transform function.
    ///
    /// The result keeps the keys returned by the inner gain. For each key, the
    /// drive at position `i` is replaced by the transform function's output,
    /// which is called with `geometry[key]`, `geometry[key][i]`, and the
    /// original drive.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner gain's `calc`.
    fn calc(
        &self,
        geometry: &Geometry,
        filter: GainFilter,
    ) -> Result<HashMap<usize, Vec<Drive>>, AUTDInternalError> {
        let original = self.gain.calc(geometry, filter)?;

        let mut transformed = HashMap::with_capacity(original.len());
        for (dev_idx, drives) in original {
            let mut mapped = Vec::with_capacity(drives.len());
            for (tr_idx, drive) in drives.into_iter().enumerate() {
                // Look the device up per drive (not once per key) so that an
                // empty drive list never touches the geometry.
                let device = &geometry[dev_idx];
                mapped.push((self.f)(device, &device[tr_idx], drive));
            }
            transformed.insert(dev_idx, mapped);
        }

        Ok(transformed)
    }
}

#[cfg(test)]
mod tests {
    use super::{super::tests::TestGain, *};

    use crate::{datagram::Datagram, geometry::tests::create_geometry};

    /// A transform that ignores its inputs and always returns one fixed drive
    /// must yield that drive for every transducer of every device.
    #[test]
    fn test_gain_transform() -> anyhow::Result<()> {
        let geometry = create_geometry(1, 249);

        let replacement = Drive::random();
        let transformed =
            TestGain { d: Drive::null() }.with_transform(move |_, _, _| replacement);

        let expected: HashMap<_, _> = geometry
            .devices()
            .map(|dev| (dev.idx(), vec![replacement; dev.num_transducers()]))
            .collect();
        let actual = transformed.calc(&geometry, GainFilter::All)?;

        assert_eq!(expected, actual);

        Ok(())
    }

    /// Plain function used as the transform in the derive test; it discards
    /// all of its arguments.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn f(_dev: &Device, _tr: &Transducer, _d: Drive) -> Drive {
        Drive::null()
    }

    /// Calling `operation()` on a transform built from a function item must
    /// compile and run.
    #[test]
    fn test_gain_transform_derive() {
        let transformed = TestGain { d: Drive::null() }.with_transform(f);
        let _ = transformed.operation();
    }
}