autd3_modulation_audio_file/
wav.rs1use autd3_core::{common::Hz, derive::*};
2use hound::SampleFormat;
3
4use std::{fmt::Debug, path::Path};
5
6use crate::error::AudioFileError;
7
/// [`Modulation`] whose data comes from a WAV file.
///
/// Instances are created with [`Wav::new`], which reads the file and stores the
/// samples already converted to 8-bit values.
#[derive(Modulation, Debug, Clone)]
pub struct Wav {
    /// Header information of the source file; its `sample_rate` drives the sampling configuration.
    spec: hound::WavSpec,
    /// Samples converted to `u8`, returned as-is by `calc`.
    buffer: Vec<u8>,
}
14
15impl Wav {
16 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, AudioFileError> {
18 let path = path.as_ref().to_path_buf();
19 let mut reader = hound::WavReader::open(&path)?;
20 let spec = reader.spec();
21 if spec.channels != 1 {
22 return Err(AudioFileError::Wav(hound::Error::Unsupported));
23 }
24 let buffer = match spec.sample_format {
25 SampleFormat::Int => {
26 let raw_buffer = reader.samples::<i32>().collect::<Result<Vec<_>, _>>()?;
27 match spec.bits_per_sample {
28 8 => raw_buffer
29 .iter()
30 .map(|i| (i - i8::MIN as i32) as _)
31 .collect(),
32 16 => raw_buffer
33 .iter()
34 .map(|i| ((i - i16::MIN as i32) as f32 / 257.).round() as _)
35 .collect(),
36 24 => raw_buffer
37 .iter()
38 .map(|i| ((i + 8388608i32) as f32 / 65793.).round() as _)
39 .collect(),
40 32 => raw_buffer
41 .iter()
42 .map(|&i| ((i as i64 - i32::MIN as i64) as f32 / 16843009.).round() as _)
43 .collect(),
44 _ => return Err(AudioFileError::Wav(hound::Error::Unsupported)), }
46 }
47 SampleFormat::Float => {
48 let raw_buffer = reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?;
49 match spec.bits_per_sample {
50 32 => raw_buffer
51 .iter()
52 .map(|&i| ((i + 1.0) / 2. * 255.).round() as _)
53 .collect(),
54 _ => return Err(AudioFileError::Wav(hound::Error::Unsupported)), }
56 }
57 };
58
59 Ok(Self { spec, buffer })
60 }
61
62 pub fn encode<P: AsRef<Path>, M: Modulation>(m: M, path: P) -> Result<(), AudioFileError> {
72 let sample_rate = m.sampling_config().freq()?.hz();
73 if !autd3_core::utils::float::is_integer(sample_rate as f64) {
74 return Err(AudioFileError::Wav(hound::Error::Unsupported));
75 }
76 let sample_rate = sample_rate as u32;
77 let buffer = m.calc(&FirmwareLimits {
78 mod_buf_size_max: u32::MAX,
79 ..FirmwareLimits::unused()
80 })?;
81
82 let spec = hound::WavSpec {
83 channels: 1,
84 sample_rate,
85 bits_per_sample: 8,
86 sample_format: SampleFormat::Int,
87 };
88 let mut writer = hound::WavWriter::create(&path, spec)?;
89 buffer
90 .into_iter()
91 .try_for_each(|b| writer.write_sample(b.wrapping_add(128) as i8))?;
92 writer.finalize()?;
93 Ok(())
94 }
95}
96
97impl Modulation for Wav {
98 fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
99 Ok(self.buffer)
100 }
101
102 fn sampling_config(&self) -> SamplingConfig {
103 SamplingConfig::new(self.spec.sample_rate as f32 * Hz)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `data` to a fresh WAV file at `path` described by `spec`.
    fn create_wav(
        path: impl AsRef<Path>,
        spec: hound::WavSpec,
        data: &[impl hound::Sample + Copy],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut writer = hound::WavWriter::create(path, spec)?;
        for &sample in data {
            writer.write_sample(sample)?;
        }
        writer.finalize()?;
        Ok(())
    }

    /// Maximum, zero and minimum samples of each format must decode to 0xFF, 0x80 and 0x00.
    #[rstest::rstest]
    #[case::i8(
        vec![0xFF, 0x80, 0x00],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 8,
            sample_format: hound::SampleFormat::Int,
        },
        &[i8::MAX, 0, i8::MIN]
    )]
    #[case::i16(
        vec![0xFF, 0x80, 0x00],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        },
        &[i16::MAX, 0, i16::MIN]
    )]
    #[case::i24(
        vec![0xFF, 0x80, 0x00],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 24,
            sample_format: hound::SampleFormat::Int,
        },
        &[8388607, 0, -8388608]
    )]
    #[case::i32(
        vec![0xFF, 0x80, 0x00],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Int,
        },
        &[i32::MAX, 0, i32::MIN]
    )]
    #[case::f32(
        vec![0xFF, 0x80, 0x00],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        },
        &[1., 0., -1.]
    )]
    fn wav(
        #[case] expect: Vec<u8>,
        #[case] spec: hound::WavSpec,
        #[case] data: &[impl hound::Sample + Copy],
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Keep the temp dir guard alive for the whole test.
        let tmp = tempfile::tempdir()?;
        let file = tmp.path().join("tmp.wav");
        create_wav(&file, spec, data)?;

        let modulation = Wav::new(file)?;
        let rate = modulation.sampling_config().freq()?.hz() as u32;
        assert_eq!(spec.sample_rate, rate);
        assert_eq!(Ok(expect), modulation.calc(&FirmwareLimits::unused()));

        Ok(())
    }

    /// A two-channel file must be rejected by `Wav::new`.
    #[test]
    fn wav_new_unsupported() -> Result<(), Box<dyn std::error::Error>> {
        let tmp = tempfile::tempdir()?;
        let file = tmp.path().join("tmp.wav");
        let stereo = hound::WavSpec {
            channels: 2,
            sample_rate: 4000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Int,
        };
        create_wav(&file, stereo, &[0, 0])?;

        assert!(Wav::new(file).is_err());
        Ok(())
    }

    /// `Wav::encode` must emit mono 8-bit PCM that round-trips through `Wav::new`.
    #[test]
    fn encode_writes_expected_wav() -> Result<(), Box<dyn std::error::Error>> {
        #[derive(Clone)]
        struct TestMod {
            data: Vec<u8>,
            rate: f32,
        }
        impl Modulation for TestMod {
            fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
                Ok(self.data)
            }
            fn sampling_config(&self) -> SamplingConfig {
                SamplingConfig::new(self.rate * Hz)
            }
        }

        let tmp = tempfile::tempdir()?;
        let file = tmp.path().join("enc.wav");
        let data = vec![0u8, 128u8, 255u8];
        Wav::encode(
            TestMod {
                data: data.clone(),
                rate: 4000.0,
            },
            &file,
        )?;

        // Inspect the raw file first: header fields, then the signed samples.
        let mut reader = hound::WavReader::open(&file)?;
        let header = reader.spec();
        assert_eq!(header.channels, 1);
        assert_eq!(header.bits_per_sample, 8);
        assert_eq!(header.sample_format, hound::SampleFormat::Int);
        assert_eq!(header.sample_rate, 4000);

        let raw = reader.samples::<i8>().collect::<Result<Vec<_>, _>>()?;
        assert_eq!(raw, vec![-128, 0, 127]);

        // Decoding the written file must give back the original bytes.
        let round_trip = Wav::new(&file)?.calc(&FirmwareLimits::unused())?;
        assert_eq!(round_trip, data);

        Ok(())
    }

    /// A sampling configuration built from divider 3 must be rejected as unsupported by `Wav::encode`.
    #[test]
    fn encode_rejects_non_integer_rate() -> Result<(), Box<dyn std::error::Error>> {
        struct TestMod;
        impl Modulation for TestMod {
            fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
                unreachable!()
            }
            fn sampling_config(&self) -> SamplingConfig {
                SamplingConfig::new(std::num::NonZeroU16::new(3).unwrap())
            }
        }

        let tmp = tempfile::tempdir()?;
        let file = tmp.path().join("enc_err.wav");
        let err = Wav::encode(TestMod, &file);
        assert!(
            matches!(err, Err(AudioFileError::Wav(hound::Error::Unsupported))),
            "unexpected error: {err:?}"
        );
        Ok(())
    }
}