autd3_modulation_audio_file/
wav.rs

1use autd3_core::{common::Hz, derive::*};
2use hound::SampleFormat;
3
4use std::{fmt::Debug, path::Path};
5
6use crate::error::AudioFileError;
7
8/// [`Modulation`] from Wav data.
9#[derive(Modulation, Debug, Clone)]
10pub struct Wav {
11    spec: hound::WavSpec,
12    buffer: Vec<u8>,
13}
14
15impl Wav {
16    /// Create a new [`Wav`].
17    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, AudioFileError> {
18        let path = path.as_ref().to_path_buf();
19        let mut reader = hound::WavReader::open(&path)?;
20        let spec = reader.spec();
21        if spec.channels != 1 {
22            return Err(AudioFileError::Wav(hound::Error::Unsupported));
23        }
24        let buffer = match spec.sample_format {
25            SampleFormat::Int => {
26                let raw_buffer = reader.samples::<i32>().collect::<Result<Vec<_>, _>>()?;
27                match spec.bits_per_sample {
28                    8 => raw_buffer
29                        .iter()
30                        .map(|i| (i - i8::MIN as i32) as _)
31                        .collect(),
32                    16 => raw_buffer
33                        .iter()
34                        .map(|i| ((i - i16::MIN as i32) as f32 / 257.).round() as _)
35                        .collect(),
36                    24 => raw_buffer
37                        .iter()
38                        .map(|i| ((i + 8388608i32) as f32 / 65793.).round() as _)
39                        .collect(),
40                    32 => raw_buffer
41                        .iter()
42                        .map(|&i| ((i as i64 - i32::MIN as i64) as f32 / 16843009.).round() as _)
43                        .collect(),
44                    _ => return Err(AudioFileError::Wav(hound::Error::Unsupported)), // GRCOV_EXCL_LINE
45                }
46            }
47            SampleFormat::Float => {
48                let raw_buffer = reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?;
49                match spec.bits_per_sample {
50                    32 => raw_buffer
51                        .iter()
52                        .map(|&i| ((i + 1.0) / 2. * 255.).round() as _)
53                        .collect(),
54                    _ => return Err(AudioFileError::Wav(hound::Error::Unsupported)), // GRCOV_EXCL_LINE
55                }
56            }
57        };
58
59        Ok(Self { spec, buffer })
60    }
61
62    /// Encode a [`Modulation`] into a mono 8-bit PCM WAV file.
63    ///
64    /// This writes the provided modulation's data into a wav file with:
65    /// - `channels = 1`
66    /// - `bits_per_sample = 8`
67    /// - `sample_format = Int`
68    /// - `sample_rate = sampling frequency of the modulation`
69    ///
70    /// The sample rate must be an integer number of hertz; otherwise this returns error.
71    pub fn encode<P: AsRef<Path>, M: Modulation>(m: M, path: P) -> Result<(), AudioFileError> {
72        let sample_rate = m.sampling_config().freq()?.hz();
73        if !autd3_core::utils::float::is_integer(sample_rate as f64) {
74            return Err(AudioFileError::Wav(hound::Error::Unsupported));
75        }
76        let sample_rate = sample_rate as u32;
77        let buffer = m.calc(&FirmwareLimits {
78            mod_buf_size_max: u32::MAX,
79            ..FirmwareLimits::unused()
80        })?;
81
82        let spec = hound::WavSpec {
83            channels: 1,
84            sample_rate,
85            bits_per_sample: 8,
86            sample_format: SampleFormat::Int,
87        };
88        let mut writer = hound::WavWriter::create(&path, spec)?;
89        buffer
90            .into_iter()
91            .try_for_each(|b| writer.write_sample(b.wrapping_add(128) as i8))?;
92        writer.finalize()?;
93        Ok(())
94    }
95}
96
97impl Modulation for Wav {
98    fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
99        Ok(self.buffer)
100    }
101
102    fn sampling_config(&self) -> SamplingConfig {
103        SamplingConfig::new(self.spec.sample_rate as f32 * Hz)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `data` to a new WAV file at `path` using `spec`.
    fn create_wav(
        path: impl AsRef<Path>,
        spec: hound::WavSpec,
        data: &[impl hound::Sample + Copy],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut writer = hound::WavWriter::create(path, spec)?;
        data.iter().try_for_each(|&s| writer.write_sample(s))?;
        writer.finalize()?;
        Ok(())
    }

    // Every case writes the maximum, zero and minimum value of its sample format and expects
    // them to decode to 0xFF, 0x80 and 0x00 respectively.
    #[rstest::rstest]
    #[case::i8(
        vec![
            0xFF,
            0x80,
            0x00
        ],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 8,
            sample_format: hound::SampleFormat::Int,
        },
        &[i8::MAX, 0, i8::MIN]
    )]
    #[case::i16(
        vec![
            0xFF,
            0x80,
            0x00
        ],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        },
        &[i16::MAX, 0, i16::MIN]
    )]
    #[case::i24(
        vec![
            0xFF,
            0x80,
            0x00
        ],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 24,
            sample_format: hound::SampleFormat::Int,
        },
        &[8388607, 0, -8388608]
    )]
    #[case::i32(
        vec![
            0xFF,
            0x80,
            0x00
        ],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Int,
        },
        &[i32::MAX, 0, i32::MIN]
    )]
    #[case::f32(
        vec![
            0xFF,
            0x80,
            0x00
        ],
        hound::WavSpec {
            channels: 1,
            sample_rate: 4000,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        },
        &[1., 0., -1.]
    )]
    fn wav(
        #[case] expect: Vec<u8>,
        #[case] spec: hound::WavSpec,
        #[case] data: &[impl hound::Sample + Copy],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("tmp.wav");
        create_wav(&path, spec, data)?;
        let m = Wav::new(path)?;
        assert_eq!(spec.sample_rate, m.sampling_config().freq()?.hz() as u32);
        assert_eq!(Ok(expect), m.calc(&FirmwareLimits::unused()));

        Ok(())
    }

    #[test]
    fn wav_new_unsupported() -> Result<(), Box<dyn std::error::Error>> {
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("tmp.wav");
        // A two-channel file must be rejected.
        create_wav(
            &path,
            hound::WavSpec {
                channels: 2,
                sample_rate: 4000,
                bits_per_sample: 32,
                sample_format: hound::SampleFormat::Int,
            },
            &[0, 0],
        )?;
        // Pin the specific error: a bare `is_err()` would also accept unrelated failures
        // such as an I/O error.
        let result = Wav::new(path);
        assert!(
            matches!(result, Err(AudioFileError::Wav(hound::Error::Unsupported))),
            "unexpected result: {result:?}"
        );
        Ok(())
    }

    #[test]
    fn encode_writes_expected_wav() -> Result<(), Box<dyn std::error::Error>> {
        #[derive(Clone)]
        struct TestMod {
            data: Vec<u8>,
            rate: f32,
        }
        impl Modulation for TestMod {
            fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
                Ok(self.data)
            }
            fn sampling_config(&self) -> SamplingConfig {
                SamplingConfig::new(self.rate * Hz)
            }
        }

        let dir = tempfile::tempdir()?;
        let path = dir.path().join("enc.wav");
        let data = vec![0u8, 128u8, 255u8];
        let m = TestMod {
            data: data.clone(),
            rate: 4000.0,
        };
        Wav::encode(m, &path)?;

        let mut reader = hound::WavReader::open(&path)?;
        let spec = reader.spec();
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.bits_per_sample, 8);
        assert_eq!(spec.sample_format, hound::SampleFormat::Int);
        assert_eq!(spec.sample_rate, 4000);

        // Bytes are stored shifted by -128 as signed samples.
        let samples = reader.samples::<i8>().collect::<Result<Vec<_>, _>>()?;
        assert_eq!(samples, vec![-128, 0, 127]);

        // Round trip: decoding the written file yields the original bytes.
        let decoded = Wav::new(&path)?;
        assert_eq!(decoded.calc(&FirmwareLimits::unused())?, data);

        Ok(())
    }

    #[test]
    fn encode_rejects_non_integer_rate() -> Result<(), Box<dyn std::error::Error>> {
        struct TestMod;
        impl Modulation for TestMod {
            // GRCOV_EXCL_START
            fn calc(self, _: &FirmwareLimits) -> Result<Vec<u8>, ModulationError> {
                unreachable!()
            }
            // GRCOV_EXCL_STOP
            fn sampling_config(&self) -> SamplingConfig {
                SamplingConfig::new(std::num::NonZeroU16::new(3).unwrap())
            }
        }
        let dir = tempfile::tempdir()?;
        let path = dir.path().join("enc_err.wav");
        let err = Wav::encode(TestMod, &path);
        match err {
            Err(AudioFileError::Wav(hound::Error::Unsupported)) => {}
            _ => panic!("unexpected error: {err:?}"), // GRCOV_EXCL_LINE
        }
        Ok(())
    }
}