// autd3_driver/datagram/stm/foci/implement.rs

1use std::{borrow::Borrow, sync::Arc};
2
3use crate::{error::AUTDDriverError, geometry::Device};
4
5use super::{ControlPoints, FociSTMGenerator, FociSTMIterator, FociSTMIteratorGenerator};
6
/// Cursor over a shared, indexable sequence of foci used by a foci STM.
///
/// The backing sequence `I` (for example a `Vec<C>` or a `[C; M]`) is held behind
/// an [`Arc`], so cloning the handle only bumps a reference count and never copies
/// the foci themselves. `i` is the position of the next element to be produced.
///
/// Trait bounds intentionally live on the `impl` blocks that need them rather than
/// on the type itself, so naming the type does not require repeating the
/// `ControlPoints<N>: From<C>` bound at every use site.
pub struct VecFociSTMIterator<const N: usize, C, I> {
    /// Backing sequence of foci, shared between every handle created from it.
    foci: Arc<I>,
    /// Index of the next element to read from `foci`.
    i: usize,
    /// Ties the element type `C` to the struct; `C` is otherwise only reachable
    /// through `I`'s trait bounds.
    _phantom: std::marker::PhantomData<C>,
}
15
16impl<const N: usize, C, I> FociSTMIterator<N> for VecFociSTMIterator<N, C, I>
17where
18    I: Borrow<[C]> + Send + Sync,
19    C: Clone + Send + Sync,
20    ControlPoints<N>: From<C>,
21{
22    fn next(&mut self) -> ControlPoints<N> {
23        let p = <I as Borrow<[C]>>::borrow(&self.foci)[self.i]
24            .clone()
25            .into();
26        self.i += 1;
27        p
28    }
29}
30
31impl<const N: usize, C, I> FociSTMIteratorGenerator<N> for VecFociSTMIterator<N, C, I>
32where
33    I: Borrow<[C]> + Send + Sync,
34    C: Clone + Send + Sync,
35    ControlPoints<N>: From<C>,
36{
37    type Iterator = VecFociSTMIterator<N, C, I>;
38
39    fn generate(&mut self, _: &Device) -> Self::Iterator {
40        Self::Iterator {
41            foci: self.foci.clone(),
42            i: 0,
43            _phantom: std::marker::PhantomData,
44        }
45    }
46}
47
48impl<const N: usize, C> FociSTMGenerator<N> for Vec<C>
49where
50    C: Clone + Send + Sync,
51    ControlPoints<N>: From<C>,
52{
53    type T = VecFociSTMIterator<N, C, Vec<C>>;
54
55    fn init(self) -> Result<Self::T, AUTDDriverError> {
56        Ok(VecFociSTMIterator {
57            foci: Arc::new(self),
58            i: 0,
59            _phantom: std::marker::PhantomData,
60        })
61    }
62
63    fn len(&self) -> usize {
64        Vec::len(self)
65    }
66}
67
68impl<const M: usize, const N: usize, C> FociSTMGenerator<N> for [C; M]
69where
70    C: Clone + Send + Sync,
71    ControlPoints<N>: From<C>,
72{
73    type T = VecFociSTMIterator<N, C, [C; M]>;
74
75    fn init(self) -> Result<Self::T, AUTDDriverError> {
76        Ok(VecFociSTMIterator {
77            foci: Arc::new(self),
78            i: 0,
79            _phantom: std::marker::PhantomData,
80        })
81    }
82
83    fn len(&self) -> usize {
84        M
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::super::FociSTM;
    use crate::{
        common::{Freq, Hz, kHz},
        error::AUTDDriverError,
        geometry::Point3,
    };

    use autd3_core::firmware::SamplingConfig;

    /// `n` foci, all at the origin; only the count matters for the sampling-config tests.
    fn origin_foci(n: usize) -> Vec<Point3> {
        (0..n).map(|_| Point3::origin()).collect()
    }

    #[rstest::rstest]
    #[case(SamplingConfig::new(1. * Hz), 0.5 * Hz, 2)]
    #[case(SamplingConfig::new(10. * Hz), 1. * Hz, 10)]
    #[case(SamplingConfig::new(20. * Hz), 2. * Hz, 10)]
    fn from_freq(#[case] expect: SamplingConfig, #[case] freq: Freq<f32>, #[case] n: usize) {
        let stm = FociSTM {
            foci: origin_foci(n),
            config: freq,
        };
        assert_eq!(stm.sampling_config(), Ok(expect));
    }

    #[rstest::rstest]
    #[case(SamplingConfig::new(1. * Hz).into_nearest(), 0.5 * Hz, 2)]
    #[case(SamplingConfig::new(0.98 * Hz).into_nearest(), 0.49 * Hz, 2)]
    #[case(SamplingConfig::new(10. * Hz).into_nearest(), 1. * Hz, 10)]
    #[case(SamplingConfig::new(20. * Hz).into_nearest(), 2. * Hz, 10)]
    fn from_freq_nearest(
        #[case] expect: SamplingConfig,
        #[case] freq: Freq<f32>,
        #[case] n: usize,
    ) {
        let stm = FociSTM {
            foci: origin_foci(n),
            config: freq,
        };
        assert_eq!(stm.into_nearest().sampling_config(), Ok(expect));
    }

    #[rstest::rstest]
    #[case(
        Ok(SamplingConfig::new(Duration::from_millis(1000))),
        Duration::from_millis(2000),
        2
    )]
    #[case(
        Ok(SamplingConfig::new(Duration::from_millis(100))),
        Duration::from_millis(1000),
        10
    )]
    #[case(
        Ok(SamplingConfig::new(Duration::from_millis(50))),
        Duration::from_millis(500),
        10
    )]
    #[case(
        Err(AUTDDriverError::STMPeriodInvalid(
            2,
            Duration::from_millis(2000) + Duration::from_nanos(1)
        )),
        Duration::from_millis(2000) + Duration::from_nanos(1),
        2
    )]
    fn from_period(
        #[case] expect: Result<SamplingConfig, AUTDDriverError>,
        #[case] p: Duration,
        #[case] n: usize,
    ) {
        let stm = FociSTM {
            foci: origin_foci(n),
            config: p,
        };
        assert_eq!(stm.sampling_config(), expect);
    }

    #[rstest::rstest]
    #[case(
        SamplingConfig::new(Duration::from_millis(1000)).into_nearest(),
        Duration::from_millis(2000),
        2
    )]
    #[case(
        SamplingConfig::new(Duration::from_millis(100)).into_nearest(),
        Duration::from_millis(1000),
        10
    )]
    #[case(
        SamplingConfig::new(Duration::from_millis(50)).into_nearest(),
        Duration::from_millis(500),
        10
    )]
    #[case(
        SamplingConfig::new(Duration::from_millis(1000)).into_nearest(),
        Duration::from_millis(2000) + Duration::from_nanos(1),
        2
    )]
    fn from_period_nearest(#[case] expect: SamplingConfig, #[case] p: Duration, #[case] n: usize) {
        let stm = FociSTM {
            foci: origin_foci(n),
            config: p,
        };
        assert_eq!(stm.into_nearest().sampling_config(), Ok(expect));
    }

    #[rstest::rstest]
    #[case(SamplingConfig::new(4. * kHz), 10)]
    #[case(SamplingConfig::new(8. * kHz), 10)]
    fn from_sampling_config(#[case] config: SamplingConfig, #[case] n: usize) {
        let stm = FociSTM {
            foci: origin_foci(n),
            config,
        };
        assert_eq!(stm.sampling_config(), Ok(config));
    }
}