1use std::io::{Cursor, Read, Seek};
2
3use crate::error::{Error, Result};
4
5use super::DecodedAudio;
6
/// Stateless decoder for integer-PCM WAV data.
///
/// The unit struct carries no configuration; all work is done by associated functions such
/// as [`WavDecoder::decode`] and [`WavDecoder::decode_from_reader`].
#[derive(Clone, Copy, Debug, Default)]
pub struct WavDecoder;
13
14impl WavDecoder {
15 pub fn decode(data: &[u8]) -> Result<DecodedAudio> {
17 let cursor = Cursor::new(data);
18 Self::decode_from_reader(cursor)
19 }
20
21 pub fn decode_from_reader<R: Read + Seek>(reader: R) -> Result<DecodedAudio> {
23 let mut wav_reader = hound::WavReader::new(reader)
24 .map_err(|err| Error::InvalidInput(format!("failed to parse WAV header: {err}")))?;
25
26 let spec = wav_reader.spec();
27
28 if spec.sample_format != hound::SampleFormat::Int {
29 return Err(Error::InvalidInput(format!(
30 "unsupported WAV format: {:?} (only PCM is supported)",
31 spec.sample_format
32 )));
33 }
34
35 if spec.bits_per_sample != 16 && spec.bits_per_sample != 24 {
36 return Err(Error::InvalidInput(format!(
37 "unsupported bit depth: {} (only 16-bit and 24-bit PCM supported)",
38 spec.bits_per_sample
39 )));
40 }
41
42 if spec.channels > 2 {
43 return Err(Error::InvalidInput(format!(
44 "unsupported channel count: {} (only mono and stereo supported)",
45 spec.channels
46 )));
47 }
48
49 let samples = match spec.bits_per_sample {
50 16 => Self::decode_16bit(&mut wav_reader)?,
51 24 => Self::decode_24bit(&mut wav_reader)?,
52 _ => {
53 return Err(Error::InvalidInput(format!(
54 "internal error: unhandled bit depth {}",
55 spec.bits_per_sample
56 )));
57 }
58 };
59
60 let frame_count = samples.len() / spec.channels as usize;
61 let duration_sec = if spec.sample_rate > 0 {
62 frame_count as f64 / f64::from(spec.sample_rate)
63 } else {
64 0.0
65 };
66
67 Ok(DecodedAudio {
68 samples,
69 sample_rate: spec.sample_rate,
70 channels: spec.channels as u8,
71 bit_depth: spec.bits_per_sample,
72 duration_sec,
73 })
74 }
75
76 fn decode_16bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
77 wav_reader
78 .samples::<i16>()
79 .map(|sample_result| {
80 sample_result.map(Self::normalize_i16).map_err(|err| {
81 Error::InvalidInput(format!("failed to read 16-bit sample: {err}"))
82 })
83 })
84 .collect()
85 }
86
87 fn decode_24bit<R: Read + Seek>(wav_reader: &mut hound::WavReader<R>) -> Result<Vec<f32>> {
88 wav_reader
89 .samples::<i32>()
90 .map(|sample_result| {
91 sample_result.map(Self::normalize_i24).map_err(|err| {
92 Error::InvalidInput(format!("failed to read 24-bit sample: {err}"))
93 })
94 })
95 .collect()
96 }
97
98 #[inline]
99 fn normalize_i16(sample: i16) -> f32 {
100 f32::from(sample) / 32768.0
101 }
102
103 #[inline]
104 fn normalize_i24(sample: i32) -> f32 {
105 (sample as f32) / 8_388_608.0
106 }
107}
108
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    type TestResult<T> = std::result::Result<T, String>;

    #[test]
    fn test_decode_16bit_mono_44100hz() -> TestResult<()> {
        let wav_data = create_wav_header(44100, 1, 16, 4410)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 44100);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 4410);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_16bit_stereo_48000hz() -> TestResult<()> {
        let wav_data = create_wav_header(48000, 2, 16, 9600)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 48000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 16);
        assert_eq!(decoded.samples.len(), 9600);
        assert_eq!(decoded.frame_count(), 4800);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_mono_96000hz() -> TestResult<()> {
        let wav_data = create_wav_header(96000, 1, 24, 9600)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 96000);
        assert_eq!(decoded.channels, 1);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 9600);
        assert!((decoded.duration_sec - 0.1).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_24bit_stereo_192000hz() -> TestResult<()> {
        let wav_data = create_wav_header(192000, 2, 24, 19200)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        assert_eq!(decoded.sample_rate, 192000);
        assert_eq!(decoded.channels, 2);
        assert_eq!(decoded.bit_depth, 24);
        assert_eq!(decoded.samples.len(), 19200);
        assert_eq!(decoded.frame_count(), 9600);
        // 9600 stereo frames at 192 kHz.
        assert!((decoded.duration_sec - 0.05).abs() < 1e-6);
        assert!(decoded.is_normalized());

        Ok(())
    }

    #[test]
    fn test_decode_sine_wave_preserves_amplitude() -> TestResult<()> {
        let wav_data = create_sine_wave_wav(44100, 1, 16, 440.0, 0.1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;

        let max_amplitude = decoded
            .samples
            .iter()
            .map(|s| s.abs())
            .fold(0.0f32, f32::max);
        assert!(
            (max_amplitude - 0.8).abs() < 0.05,
            "expected max amplitude ~0.8, got {max_amplitude}"
        );

        Ok(())
    }

    #[test]
    fn test_reject_empty_data() {
        expect_invalid_input(WavDecoder::decode(&[]), "failed to parse WAV header");
    }

    #[test]
    fn test_reject_8bit_pcm() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 8, 8)?;
        expect_invalid_input(WavDecoder::decode(&wav_data), "unsupported bit depth: 8");
        Ok(())
    }

    #[test]
    fn test_reject_32bit_pcm() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 32, 8)?;
        expect_invalid_input(WavDecoder::decode(&wav_data), "unsupported bit depth: 32");
        Ok(())
    }

    #[test]
    fn test_reject_float_format() -> TestResult<()> {
        let wav_data = create_float_wav(8)?;
        expect_invalid_input(WavDecoder::decode(&wav_data), "unsupported WAV format");
        Ok(())
    }

    #[test]
    fn test_reject_more_than_two_channels() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 3, 16, 9)?;
        expect_invalid_input(
            WavDecoder::decode(&wav_data),
            "unsupported channel count: 3",
        );
        Ok(())
    }

    #[test]
    fn test_decode_zero_samples() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 0)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 0);
        assert_eq!(decoded.frame_count(), 0);
        Ok(())
    }

    #[test]
    fn test_decode_single_sample() -> TestResult<()> {
        let wav_data = create_wav_header(44_100, 1, 16, 1)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert_eq!(decoded.samples.len(), 1);
        assert_eq!(decoded.frame_count(), 1);
        Ok(())
    }

    #[test]
    fn test_normalization_bounds_16bit() {
        let min_i16 = WavDecoder::normalize_i16(i16::MIN);
        let max_i16 = WavDecoder::normalize_i16(i16::MAX);
        let zero = WavDecoder::normalize_i16(0);

        assert!((-1.0..=1.0).contains(&min_i16));
        assert!((-1.0..=1.0).contains(&max_i16));
        assert!((min_i16 + 1.0).abs() < f32::EPSILON);
        assert!(max_i16 < 1.0);
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_normalization_bounds_24bit() {
        let min_i24 = WavDecoder::normalize_i24(-8_388_608);
        let max_i24 = WavDecoder::normalize_i24(8_388_607);
        let zero = WavDecoder::normalize_i24(0);

        assert!((min_i24 + 1.0).abs() < f32::EPSILON);
        assert!(max_i24 > 0.0 && max_i24 < 1.0);
        assert!(zero.abs() < f32::EPSILON);
    }

    #[test]
    fn test_frame_count_calculation() -> TestResult<()> {
        let mono = create_wav_header(44_100, 1, 16, 4_410)?;
        let decoded_mono = WavDecoder::decode(&mono).map_err(|e| e.to_string())?;
        assert_eq!(decoded_mono.frame_count(), 4_410);

        let stereo = create_wav_header(44_100, 2, 16, 8_820)?;
        let decoded_stereo = WavDecoder::decode(&stereo).map_err(|e| e.to_string())?;
        assert_eq!(decoded_stereo.frame_count(), 4_410);
        Ok(())
    }

    #[test]
    fn test_duration_calculation_accuracy() -> TestResult<()> {
        let wav_data = create_wav_header(48_000, 2, 16, 96_000)?;
        let decoded = WavDecoder::decode(&wav_data).map_err(|e| e.to_string())?;
        assert!((decoded.duration_sec - 1.0).abs() < 1e-6);
        Ok(())
    }

    /// Asserts that `result` is an `Error::InvalidInput` whose message contains `needle`.
    fn expect_invalid_input<T>(result: std::result::Result<T, Error>, needle: &str) {
        match result {
            Err(Error::InvalidInput(msg)) => {
                assert!(msg.contains(needle), "unexpected error message: {msg}");
            }
            Err(_) => panic!("expected an InvalidInput error containing {needle:?}"),
            Ok(_) => panic!("expected decoding to fail with {needle:?}"),
        }
    }

    /// Builds an in-memory integer-PCM WAV file containing `num_samples` zero samples.
    ///
    /// `num_samples` counts individual samples across all channels, not frames.
    /// Supported bit depths are 8, 16, 24 and 32.
    fn create_wav_header(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        num_samples: usize,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create WAV writer: {err}"))?;

            for _ in 0..num_samples {
                let written = match bits_per_sample {
                    8 => writer.write_sample(0i8),
                    16 => writer.write_sample(0i16),
                    24 | 32 => writer.write_sample(0i32),
                    _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
                };
                written.map_err(|err| {
                    format!("failed to write {bits_per_sample}-bit sample: {err}")
                })?;
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }

    /// Builds an in-memory mono 32-bit float WAV file with `num_samples` zero samples.
    fn create_float_wav(num_samples: usize) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate: 44_100,
            channels: 1,
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        };

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create float WAV writer: {err}"))?;

            for _ in 0..num_samples {
                writer
                    .write_sample(0.0f32)
                    .map_err(|err| format!("failed to write float sample: {err}"))?;
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize float WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }

    /// Builds an in-memory integer-PCM WAV file holding a sine tone at 80% of full scale.
    ///
    /// Each frame repeats the same sample on every channel.
    fn create_sine_wave_wav(
        sample_rate: u32,
        channels: u16,
        bits_per_sample: u16,
        frequency: f32,
        duration_sec: f32,
    ) -> TestResult<Vec<u8>> {
        let spec = hound::WavSpec {
            sample_rate,
            channels,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = hound::WavWriter::new(&mut cursor, spec)
                .map_err(|err| format!("failed to create WAV writer for sine wave: {err}"))?;

            let num_samples = (sample_rate as f32 * duration_sec) as usize;
            let amplitude = match bits_per_sample {
                16 => 32767.0 * 0.8,
                24 => 8_388_607.0 * 0.8,
                _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
            };

            for i in 0..num_samples {
                let t = i as f32 / sample_rate as f32;
                let sample_f32 = amplitude * (2.0 * std::f32::consts::PI * frequency * t).sin();

                for _ in 0..channels {
                    let written = match bits_per_sample {
                        16 => writer.write_sample(sample_f32 as i16),
                        24 => writer.write_sample(sample_f32 as i32),
                        _ => return Err(format!("unsupported bit depth ({bits_per_sample})")),
                    };
                    written.map_err(|err| format!("failed to write sine sample: {err}"))?;
                }
            }

            writer
                .finalize()
                .map_err(|err| format!("failed to finalize sine wave WAV: {err}"))?;
        }

        Ok(cursor.into_inner())
    }
}