1use autd3_core::{common::Freq, derive::*, firmware::SamplingConfig};
2
3use super::sampling_mode::{Nearest, SamplingMode};
4
/// Options controlling the shape and sampling of a [`Square`] modulation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SquareOption {
    /// Value emitted during the low part of each period. Defaults to `u8::MIN`.
    pub low: u8,
    /// Value emitted during the high part of each period (the high part comes first).
    /// Defaults to `u8::MAX`.
    pub high: u8,
    /// Fraction of each period spent at `high`. Must lie in `[0, 1]`; any other value
    /// (including NaN) makes `calc` return an error. Defaults to `0.5`.
    pub duty: f32,
    /// Sampling configuration the frequency is validated against and that the generated
    /// buffer is reported with. Defaults to [`SamplingConfig::FREQ_4K`].
    pub sampling_config: SamplingConfig,
}
17
18impl Default for SquareOption {
19 fn default() -> Self {
20 Self {
21 low: u8::MIN,
22 high: u8::MAX,
23 duty: 0.5,
24 sampling_config: SamplingConfig::FREQ_4K,
25 }
26 }
27}
28
/// Square-wave [`Modulation`].
///
/// `S` is the frequency specification; it must be convertible into a [`SamplingMode`].
/// This file exercises it with float and integer `Hz` frequencies and with [`Nearest`].
#[derive(Modulation, Clone, Copy, PartialEq, Debug)]
pub struct Square<S: Into<SamplingMode> + Clone + Copy + std::fmt::Debug> {
    /// Frequency of the square wave; converted into a [`SamplingMode`] when the buffer
    /// is generated.
    pub freq: S,
    /// Output levels, duty ratio and sampling configuration.
    pub option: SquareOption,
}
37
38impl<S: Into<SamplingMode> + Clone + Copy + std::fmt::Debug> Square<S> {
39 #[must_use]
41 pub const fn new(freq: S, option: SquareOption) -> Self {
42 Self { freq, option }
43 }
44}
45
46impl Square<Freq<f32>> {
47 #[must_use]
59 pub const fn into_nearest(self) -> Square<Nearest> {
60 Square {
61 freq: Nearest(self.freq),
62 option: self.option,
63 }
64 }
65}
66
67impl<S: Into<SamplingMode> + Clone + Copy + std::fmt::Debug> Modulation for Square<S> {
68 fn calc(self) -> Result<Vec<u8>, ModulationError> {
69 if !(0.0..=1.0).contains(&self.option.duty) {
70 return Err(ModulationError::new("duty must be in range from 0 to 1"));
71 }
72
73 let sampling_mode: SamplingMode = self.freq.into();
74 let (n, rep) = sampling_mode.validate(self.option.sampling_config)?;
75 let high = self.option.high;
76 let low = self.option.low;
77 let duty = self.option.duty;
78 Ok((0..rep)
79 .map(|i| (n + i) / rep)
80 .flat_map(|size| {
81 let n_high = (size as f32 * duty) as usize;
82 vec![high; n_high]
83 .into_iter()
84 .chain(vec![low; size as usize - n_high])
85 })
86 .collect())
87 }
88
89 fn sampling_config(&self) -> SamplingConfig {
90 self.option.sampling_config
91 }
92}
93
#[cfg(test)]
mod tests {
    use autd3_driver::common::Hz;

    use super::*;

    // Exact sampling with the default options (4 kHz, duty 0.5). Each frequency is tried
    // both as a float (`150.*Hz`) and as an integer (`150*Hz`), and both must agree.
    // - 150 Hz: 80 samples over 3 periods of 26, 27 and 27 samples, each with 13 high
    //   samples (floor(len * 0.5)).
    // - 200 Hz: a single 20-sample period (10 high, 10 low).
    // - 781.25 Hz: 128 samples over 25 periods. There are 22 periods of 5 samples
    //   (2 high) followed by 3 periods of 6 samples (3 high).
    // The error cases cover a frequency that is not exactly representable at 4 kHz,
    // frequencies at/above Nyquist (2 kHz), and zero.
    #[rstest::rstest]
    #[case(
        Ok(vec![
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ]),
        150.*Hz
    )]
    #[case(
        Ok(vec![
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
            255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ]),
        150*Hz
    )]
    #[case(
        Ok(vec![255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        200.*Hz
    )]
    #[case(
        Ok(vec![255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        200*Hz
    )]
    #[case(
        Ok(vec![
            255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255,
            0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0,
            255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255,
            0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0,
            255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255, 0, 0, 0, 255, 255,
            255, 0, 0, 0, 255, 255, 255, 0, 0, 0, 255, 255, 255, 0, 0, 0
        ]),
        781.25*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency (150.01 Hz) cannot be output with the sampling config (SamplingConfig::Freq(4000 Hz)).")),
        150.01*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency (2000 Hz) is equal to or greater than the Nyquist frequency (2000 Hz)")),
        2000.*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency (2000 Hz) is equal to or greater than the Nyquist frequency (2000 Hz)")),
        2000*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency (4000 Hz) is equal to or greater than the Nyquist frequency (2000 Hz)")),
        4000.*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency (4000 Hz) is equal to or greater than the Nyquist frequency (2000 Hz)")),
        4000*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency must not be zero. If intentional, use `Static` instead.")),
        0*Hz
    )]
    #[case(
        Err(ModulationError::new("Frequency must not be zero. If intentional, use `Static` instead.")),
        0.*Hz
    )]
    fn with_freq_float_exact(
        #[case] expect: Result<Vec<u8>, ModulationError>,
        #[case] freq: impl Into<SamplingMode> + Copy + std::fmt::Debug,
    ) {
        let m = Square::new(freq, SquareOption::default());
        // Also pins the `SquareOption::default()` values.
        assert_eq!(u8::MIN, m.option.low);
        assert_eq!(u8::MAX, m.option.high);
        assert_eq!(0.5, m.option.duty);
        assert_eq!(SamplingConfig::FREQ_4K, m.sampling_config());
        assert_eq!(expect, m.calc());
    }

    // With `Nearest`, 150 Hz no longer needs the exact 80-sample / 3-period buffer.
    // It yields a single 27-sample period (13 high, 14 low); 200 Hz is unchanged.
    #[rstest::rstest]
    #[case(
        Ok(vec![
            255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0,
        ]),
        150.*Hz
    )]
    #[case(
        Ok(vec![255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        200.*Hz
    )]
    fn into_nearest(#[case] expect: Result<Vec<u8>, ModulationError>, #[case] freq: Freq<f32>) {
        let m = Square {
            freq,
            option: SquareOption::default(),
        }
        .into_nearest();
        assert_eq!(u8::MIN, m.option.low);
        assert_eq!(u8::MAX, m.option.high);
        assert_eq!(0.5, m.option.duty);
        assert_eq!(SamplingConfig::FREQ_4K, m.sampling_config());
        assert_eq!(expect, m.calc());
    }

    // If `low` equals `high` (both u8::MAX), every sample must be u8::MAX.
    #[test]
    fn with_low() -> Result<(), Box<dyn std::error::Error>> {
        let m = Square {
            freq: 150. * Hz,
            option: SquareOption {
                low: u8::MAX,
                ..Default::default()
            },
        };
        assert!(m.calc()?.iter().all(|&x| x == u8::MAX));

        Ok(())
    }

    // If `high` equals `low` (both u8::MIN), every sample must be u8::MIN.
    #[test]
    fn with_high() -> Result<(), Box<dyn std::error::Error>> {
        let m = Square {
            freq: 150. * Hz,
            option: SquareOption {
                high: u8::MIN,
                ..Default::default()
            },
        };
        assert!(m.calc()?.iter().all(|&x| x == u8::MIN));

        Ok(())
    }

    // Boundary duties: 0.0 gives an all-low buffer, 1.0 an all-high buffer.
    // The explicit `#[test]` alongside `#[rstest]` is redundant but harmless.
    #[rstest::rstest]
    #[case(u8::MIN, 0.0)]
    #[case(u8::MAX, 1.0)]
    #[test]
    fn with_duty(#[case] expect: u8, #[case] duty: f32) -> Result<(), Box<dyn std::error::Error>> {
        let m = Square {
            freq: 150. * Hz,
            option: SquareOption {
                duty,
                ..Default::default()
            },
        };
        assert!(m.calc()?.iter().all(|&x| x == expect));

        Ok(())
    }

    // Duties just outside [0, 1] must be rejected before any sampling validation.
    #[rstest::rstest]
    #[case("duty must be in range from 0 to 1", -0.1)]
    #[case("duty must be in range from 0 to 1", 1.1)]
    #[test]
    fn duty_out_of_range(#[case] expect: &str, #[case] duty: f32) {
        assert_eq!(
            Some(ModulationError::new(expect)),
            Square {
                freq: 150. * Hz,
                option: SquareOption {
                    duty,
                    ..Default::default()
                },
            }
            .calc()
            .err()
        );
    }
}