riff_wave/
reader.rs

1// riff-wave -- Basic support for reading and writing wave PCM files.
2// Copyright (c) 2016 Kevin Brothaler and the riff-wave project authors.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// A copy of the License has been included in the root of the repository.
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use std::error;
14use std::fmt;
15use std::io;
16use std::io::{Read, Seek, SeekFrom};
17use std::result;
18
19use byteorder::{LittleEndian, ReadBytesExt};
20
21use super::{Format, PcmFormat};
22use super::{FORMAT_EXTENDED, FORMAT_UNCOMPRESSED_PCM};
23
24// MARK: Error types
25
26/// Represents an error that occurred while reading a wave file.
27#[derive(Debug)]
28pub enum ReadError {
29    /// The file format is incorrect or unsupported.
30    Format(ReadErrorKind),
31    /// An IO error occurred.
32    Io(io::Error),
33}
34
35/// Represents a result when reading a wave file.
36pub type ReadResult<T> = result::Result<T, ReadError>;
37
/// Represents a file format error, when the wave file is incorrect or unsupported.
#[derive(Debug)]
pub enum ReadErrorKind {
    /// The file does not start with a "RIFF" tag and chunk size.
    NotARiffFile,
    /// The file doesn't continue with "WAVE" after the RIFF chunk header.
    NotAWaveFile,
    /// This file is not an uncompressed PCM wave file. Only uncompressed files are supported.
    /// Carries the offending format (or extended sub-format) code.
    NotAnUncompressedPcmWaveFile(u16),
    /// The "fmt " chunk, or its extension in extended files, is too short to
    /// hold the header data we need, so the file can't be parsed.
    FmtChunkTooShort,
    /// The number of channels is zero, which is invalid.
    NumChannelsIsZero,
    /// The sample rate is zero, which is invalid.
    SampleRateIsZero,
    /// Only 8-bit, 16-bit, 24-bit and 32-bit PCM files are supported.
    /// Carries the bits per sample that was found.
    UnsupportedBitsPerSample(u16),
    /// We don't currently support extended PCM wave files where the valid
    /// bits per sample differs from the container size.
    /// Carries `(bits_per_sample, valid_bits_per_sample)`.
    InvalidBitsPerSample(u16, u16),
}

impl ReadErrorKind {
    /// Returns a short, static, human-readable description of this error kind.
    ///
    /// Deliberately not named `to_string`. An inherent `to_string` returning
    /// `&str` shadows `ToString::to_string` (which `Display` provides), so
    /// `kind.to_string()` would silently yield a `&str`. Deleting such a method
    /// would also make the `Display` impl below recurse forever.
    fn as_str(&self) -> &'static str {
        match *self {
            ReadErrorKind::NotARiffFile => "not a RIFF file",
            ReadErrorKind::NotAWaveFile => "not a WAVE file",
            ReadErrorKind::NotAnUncompressedPcmWaveFile(_) => "Not an uncompressed wave file",
            ReadErrorKind::FmtChunkTooShort => "fmt_ chunk is too short",
            ReadErrorKind::NumChannelsIsZero => "Number of channels is zero",
            ReadErrorKind::SampleRateIsZero => "Sample rate is zero",
            ReadErrorKind::UnsupportedBitsPerSample(_) => "Unsupported bits per sample",
            ReadErrorKind::InvalidBitsPerSample(_, _) => {
                "A bits per sample of less than the container size is not currently supported"
            }
        }
    }
}

impl fmt::Display for ReadErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
82
83impl fmt::Display for ReadError {
84    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85        match *self {
86            ReadError::Format(ref err_kind) => write!(f, "Format error: {}", err_kind),
87            ReadError::Io(ref err) => write!(f, "IO error: {}", err),
88        }
89    }
90}
91
92impl error::Error for ReadError {
93    fn cause(&self) -> Option<&dyn error::Error> {
94        match *self {
95            ReadError::Format(_) => None,
96            ReadError::Io(ref err) => Some(err),
97        }
98    }
99}
100
101impl From<io::Error> for ReadError {
102    fn from(err: io::Error) -> ReadError {
103        ReadError::Io(err)
104    }
105}
106
107// MARK: Validation and parsing functions
108
109fn validate_pcm_format(format: u16) -> ReadResult<Format> {
110    match format {
111        FORMAT_UNCOMPRESSED_PCM => Ok(Format::UncompressedPcm),
112        FORMAT_EXTENDED => Ok(Format::Extended),
113        _ => Err(ReadError::Format(
114            ReadErrorKind::NotAnUncompressedPcmWaveFile(format),
115        )),
116    }
117}
118
119fn validate_pcm_subformat(sub_format: u16) -> ReadResult<()> {
120    match sub_format {
121        FORMAT_UNCOMPRESSED_PCM => Ok(()),
122        _ => Err(ReadError::Format(
123            ReadErrorKind::NotAnUncompressedPcmWaveFile(sub_format),
124        )),
125    }
126}
127
128fn validate_fmt_header_is_large_enough(size: u32, min_size: u32) -> ReadResult<()> {
129    if size < min_size {
130        Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort))
131    } else {
132        Ok(())
133    }
134}
135
136trait ReadWaveExt: Read + Seek {
137    fn read_wave_header(&mut self) -> ReadResult<PcmFormat> {
138        self.validate_is_riff_file()?;
139        self.validate_is_wave_file()?;
140
141        // The fmt subchunk should be at least 14 bytes for wave files, and 16 bytes
142        // for PCM wave files. The check is done twice so an appropriate error message
143        // can be returned depending on the type of file.
144        let fmt_subchunk_size = self.skip_until_subchunk(b"fmt ")?;
145        validate_fmt_header_is_large_enough(fmt_subchunk_size, 14)?;
146        let format = validate_pcm_format(self.read_u16::<LittleEndian>()?)?;
147        validate_fmt_header_is_large_enough(fmt_subchunk_size, 16)?;
148
149        let num_channels = self.read_u16::<LittleEndian>()?;
150        let sample_rate = self.read_u32::<LittleEndian>()?;
151        let _ = self.read_u32::<LittleEndian>()?;                   // Byte rate, ignored.
152        let _ = self.read_u16::<LittleEndian>()?;                   // Block align, ignored.
153        let bits_per_sample = self.read_u16::<LittleEndian>()?;
154
155        match format {
156            Format::UncompressedPcm => self.skip_over_remainder(16, fmt_subchunk_size)?,
157            Format::Extended => self.validate_extended_format(bits_per_sample)?,
158        }
159
160        if num_channels == 0 {
161            return Err(ReadError::Format(ReadErrorKind::NumChannelsIsZero));
162        } else if sample_rate == 0 {
163            return Err(ReadError::Format(ReadErrorKind::SampleRateIsZero));
164        } else if bits_per_sample != 8 && bits_per_sample != 16 
165        	   && bits_per_sample != 24 && bits_per_sample != 32 {
166            return Err(ReadError::Format(
167            	ReadErrorKind::UnsupportedBitsPerSample(bits_per_sample)));
168        }
169
170        Ok(PcmFormat {
171            num_channels: num_channels,
172            sample_rate: sample_rate,
173            bits_per_sample: bits_per_sample,
174        })
175    }
176
177    fn validate_extended_format(&mut self, bits_per_sample: u16) -> ReadResult<()> {
178        let extra_info_size = self.read_u16::<LittleEndian>()?;
179        validate_fmt_header_is_large_enough(extra_info_size.into(), 22)?;
180
181        let sample_info = self.read_u16::<LittleEndian>()?;
182        let _ = self.read_u32::<LittleEndian>()?;                   // Channel mask, ignored.
183        validate_pcm_subformat(self.read_u16::<LittleEndian>()?)?;
184        self.skip_over_remainder(8, extra_info_size.into())?;       // Ignore the rest of the GUID.
185
186        if sample_info != bits_per_sample {
187            // We don't currently support wave files where the bits per sample
188            // doesn't entirely fill the allocated bits per sample.
189            return Err(ReadError::Format(ReadErrorKind::InvalidBitsPerSample(
190                bits_per_sample,
191                sample_info,
192            )));
193        }
194
195        Ok(())
196    }
197
198    fn skip_over_remainder(&mut self, read_so_far: u32, size: u32) -> ReadResult<()> {
199        if read_so_far < size {
200            let remainder = size - read_so_far;
201            self.seek(SeekFrom::Current(remainder.into()))?;
202        }
203        Ok(())
204    }
205
206    fn validate_is_riff_file(&mut self) -> ReadResult<()> {
207        self.validate_tag(b"RIFF", ReadErrorKind::NotARiffFile)?;
208        // The next four bytes represent the chunk size. We're not going to
209        // validate it, so that we can still try to read files that might have
210        // an incorrect chunk size, so let's skip over it.
211        let _ = self.read_chunk_size()?;
212        Ok(())
213    }
214
215    fn validate_is_wave_file(&mut self) -> ReadResult<()> {
216        self.validate_tag(b"WAVE", ReadErrorKind::NotAWaveFile)?;
217        Ok(())
218    }
219
220    fn validate_tag(&mut self, expected_tag: &[u8; 4], err_kind: ReadErrorKind) -> ReadResult<()> {
221        let tag = self.read_tag()?;
222        if &tag != expected_tag {
223            return Err(ReadError::Format(err_kind));
224        }
225        Ok(())
226    }
227
228    fn skip_until_subchunk(&mut self, matching_tag: &[u8; 4]) -> ReadResult<u32> {
229        loop {
230            let tag = self.read_tag()?;
231            let subchunk_size = self.read_chunk_size()?;
232
233            if &tag == matching_tag {
234                return Ok(subchunk_size);
235            } else {
236                self.seek(SeekFrom::Current(subchunk_size.into()))?;
237            }
238        }
239    }
240
241    fn read_tag(&mut self) -> ReadResult<[u8; 4]> {
242        let mut tag: [u8; 4] = [0; 4];
243        self.read_exact(&mut tag)?;
244        Ok(tag)
245    }
246
247    fn read_chunk_size(&mut self) -> ReadResult<u32> {
248        Ok(self.read_u32::<LittleEndian>()?)
249    }
250}
251
252impl<T> ReadWaveExt for T where T: Read + Seek {}
253
254/// Helper struct that takes ownership of a reader and can be used to read data
255/// from a PCM wave file.
256pub struct WaveReader<T>
257where
258    T: Read + Seek,
259{
260    /// Represents the PCM format for this wave file.
261    pub pcm_format: PcmFormat,
262
263    // The underlying reader that we'll use to read data.
264    reader: T,
265}
266
267// TODO what should we do if an incorrect read_* method is called? Return the
268// error in the result? Also, the read methods might need to return optionals
269// instead so we have a better way of flagging EOF.
270impl<T> WaveReader<T>
271where
272    T: Read + Seek,
273{
274    /// Returns a new wave reader for the given reader.
275    pub fn new(mut reader: T) -> ReadResult<WaveReader<T>> {
276        let pcm_format = reader.read_wave_header()?;
277        let _ = reader.skip_until_subchunk(b"data")?;
278
279        Ok(WaveReader {
280            pcm_format: pcm_format,
281            reader: reader,
282        })
283    }
284
285    /// Reads a single sample as an unsigned 8-bit value.
286    pub fn read_sample_u8(&mut self) -> io::Result<u8> {
287        self.read_sample(|reader| reader.read_u8())
288    }
289
290    /// Reads a single sample as a signed 16-bit value.
291    pub fn read_sample_i16(&mut self) -> io::Result<i16> {
292        self.read_sample(|reader| reader.read_i16::<LittleEndian>())
293    }
294
295    /// Reads a single sample as a signed 24-bit value. The value will be padded
296    /// to fit in a 32-bit buffer.
297    pub fn read_sample_i24(&mut self) -> io::Result<i32> {
298        self.read_sample(|reader| reader.read_int::<LittleEndian>(3))
299            .map(|x| x as i32)
300    }
301
302    /// Reads a single sample as a signed 32-bit value.
303    pub fn read_sample_i32(&mut self) -> io::Result<i32> {
304        self.read_sample(|reader| reader.read_i32::<LittleEndian>())
305    }
306
307    fn read_sample<F, S>(&mut self, read_data: F) -> io::Result<S>
308    where
309        F: Fn(&mut T) -> io::Result<S>,
310    {
311        Ok(read_data(&mut self.reader)?)
312    }
313
314    /// Consumes this reader, returning the underlying value.
315    pub fn into_inner(self) -> T {
316        self.reader
317    }
318}
319
320// MARK: Tests
321
322#[cfg(test)]
323mod tests {
324    use std::fmt::Debug;
325    use std::io;
326    use std::io::{Cursor, Read};
327
328    use byteorder::{ByteOrder, LittleEndian};
329
330    use super::super::{Format, PcmFormat};
331    use super::super::{FORMAT_EXTENDED, FORMAT_UNCOMPRESSED_PCM};
332    use super::{validate_fmt_header_is_large_enough, validate_pcm_format, validate_pcm_subformat};
333    use super::{ReadError, ReadErrorKind, ReadWaveExt, WaveReader};
334
335    // RIFF header tests
336
337    #[test]
338    fn test_validate_is_riff_file_ok() {
339        let mut data = Cursor::new(b"RIFF    ");
340        assert_matches!(Ok(()), data.validate_is_riff_file());
341    }
342
343    #[test]
344    fn test_validate_is_riff_file_err_incomplete() {
345        let mut data = Cursor::new(b"RIF     ");
346        assert_matches!(
347            Err(ReadError::Format(ReadErrorKind::NotARiffFile)),
348            data.validate_is_riff_file()
349        );
350    }
351
352    #[test]
353    fn test_validate_is_riff_file_err_something_else() {
354        let mut data = Cursor::new(b"JPEG     ");
355        assert_matches!(
356            Err(ReadError::Format(ReadErrorKind::NotARiffFile)),
357            data.validate_is_riff_file()
358        );
359    }
360
361    // Wave tag tests
362
363    #[test]
364    fn test_validate_is_wave_file_ok() {
365        let mut data = Cursor::new(b"WAVE");
366        assert_matches!(Ok(()), data.validate_is_wave_file());
367    }
368
369    #[test]
370    fn test_validate_is_wave_file_err_incomplete() {
371        let mut data = Cursor::new(b"WAV ");
372        assert_matches!(
373            Err(ReadError::Format(ReadErrorKind::NotAWaveFile)),
374            data.validate_is_wave_file()
375        );
376    }
377
378    #[test]
379    fn test_validate_is_wave_file_err_something_else() {
380        let mut data = Cursor::new(b"JPEG");
381        assert_matches!(
382            Err(ReadError::Format(ReadErrorKind::NotAWaveFile)),
383            data.validate_is_wave_file()
384        );
385    }
386
387    // Skipping to subchunk tests
388    // After reading in the file header, we also need to read in the "fmt " subchunk.
389    // The file might contain other subchunks that we don't currently support, so
390    // we'll need to skip over them.
391
392    #[test]
393    fn test_skip_until_subchunk() {
394        // A size of 0.
395        let mut data = Cursor::new(b"RIFF    WAVEfmt \x00\x00\x00\x00");
396        let _ = data.validate_is_riff_file();
397        let _ = data.validate_is_wave_file();
398        let size = data.skip_until_subchunk(b"fmt ");
399        assert_eq!(0, size.unwrap());
400    }
401
402    #[test]
403    fn test_skip_until_second_subchunk() {
404        // A size of 0.
405        let mut data = Cursor::new(b"RIFF    WAVEfmt \x00\x00\x00\x00data\x00\x00\x00\x00");
406        let _ = data.validate_is_riff_file();
407        let _ = data.validate_is_wave_file();
408        let _ = data.skip_until_subchunk(b"fmt ");
409        let size = data.skip_until_subchunk(b"data");
410        assert_eq!(0, size.unwrap());
411    }
412
413    #[test]
414    #[should_panic]
415    fn test_cant_read_first_subchunk_after_second() {
416        // A size of 0.
417        let mut data = Cursor::new(b"RIFF    WAVEdata\x00\x00\x00\x00fmt \x00\x00\x00\x00");
418        let _ = data.validate_is_riff_file();
419        let _ = data.validate_is_wave_file();
420        let _ = data.skip_until_subchunk(b"fmt ");
421        let size = data.skip_until_subchunk(b"data");
422        assert_eq!(0, size.unwrap());
423    }
424
425    // Wave format validation tests. We only support uncompressed PCM files,
426    // which can be in the "canonical" format or an "extended" format.
427
428    #[test]
429    fn test_validate_pcm_format_ok_uncompressed() {
430        assert_matches!(
431            Ok(Format::UncompressedPcm),
432            validate_pcm_format(FORMAT_UNCOMPRESSED_PCM)
433        );
434    }
435
436    #[test]
437    fn test_validate_pcm_format_ok_extended() {
438        assert_matches!(Ok(Format::Extended), validate_pcm_format(FORMAT_EXTENDED));
439    }
440
441    #[test]
442    fn test_validate_pcm_format_err_not_uncompressed() {
443        assert_matches!(
444            Err(ReadError::Format(
445                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
446            )),
447            validate_pcm_format(12345)
448        );
449    }
450
451    // Wave subformat validation tests. We only support uncompressed PCM files.
452
453    #[test]
454    fn test_validate_pcm_subformat_ok_uncompressed() {
455        assert_matches!(Ok(()), validate_pcm_subformat(FORMAT_UNCOMPRESSED_PCM));
456    }
457
458    #[test]
459    fn test_validate_pcm_subformat_err_extended_format_value_not_valid_for_subformat() {
460        assert_matches!(
461            Err(ReadError::Format(
462                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
463            )),
464            validate_pcm_subformat(FORMAT_EXTENDED)
465        );
466    }
467
468    #[test]
469    fn test_validate_pcm_subformat_err_not_uncompressed() {
470        assert_matches!(
471            Err(ReadError::Format(
472                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
473            )),
474            validate_pcm_subformat(12345)
475        );
476    }
477
478    // Validation tests for ensuring the header is large enough to read in the data we need.
479
480    #[test]
481    fn test_validate_fmt_header_is_large_enough_matches() {
482        assert_matches!(Ok(()), validate_fmt_header_is_large_enough(16, 16));
483    }
484
485    #[test]
486    fn test_validate_fmt_header_is_large_enough_more_than_we_need() {
487        assert_matches!(Ok(()), validate_fmt_header_is_large_enough(22, 16));
488    }
489
490    #[test]
491    fn test_validate_fmt_header_is_large_enough_too_small() {
492        assert_matches!(
493            Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort)),
494            validate_fmt_header_is_large_enough(14, 16)
495        );
496    }
497
498    // Wave header validation tests.
499
500    #[test]
501    fn test_validate_pcm_header_missing_fmt_chunk() {
502        let mut data = Cursor::new(b"RIFF    WAVE");
503        assert_matches!(Err(ReadError::Io(ref err)) if err.kind() == io::ErrorKind::UnexpectedEof,
504            			data.read_wave_header());
505    }
506
507    #[test]
508    fn test_validate_pcm_header_fmt_chunk_too_small() {
509        let mut data = Cursor::new(
510            b"RIFF    WAVE\
511                                     fmt \x0C\x00\x00\x00",
512        );
513        assert_matches!(
514            Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort)),
515            data.read_wave_header()
516        );
517    }
518
519    #[test]
520    fn test_validate_pcm_header_fmt_chunk_too_small_pcm() {
521        let mut data = Cursor::new(
522            b"RIFF    WAVE\
523                                     fmt \x0E\x00\x00\x00\
524                                     \x01\x00",
525        );
526        assert_matches!(
527            Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort)),
528            data.read_wave_header()
529        );
530    }
531
532    #[test]
533    fn test_validate_pcm_header_not_pcm_format() {
534        let mut data = Cursor::new(
535            b"RIFF    WAVE\
536                                     fmt \x0E\x00\x00\x00\
537                                     \x02\x00",
538        );
539        assert_matches!(
540            Err(ReadError::Format(
541                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
542            )),
543            data.read_wave_header()
544        );
545    }
546
547    #[test]
548    fn test_validate_pcm_header_dont_accept_zero_channels() {
549        let mut data = Cursor::new(
550            b"RIFF    WAVE\
551                                     fmt \x10\x00\x00\x00\
552                                     \x01\x00\
553                                     \x00\x00\
554                                     \x00\x00\x00\x00\
555                                     \x00\x00\x00\x00\
556                                     \x00\x00\
557                                     \x00\x00" as &[u8],
558        );
559        assert_matches!(
560            Err(ReadError::Format(ReadErrorKind::NumChannelsIsZero)),
561            data.read_wave_header()
562        );
563    }
564
565    #[test]
566    fn test_validate_pcm_header_dont_accept_zero_sample_rate() {
567        let mut data = Cursor::new(
568            b"RIFF    WAVE\
569                                     fmt \x10\x00\x00\x00\
570                                     \x01\x00\
571                                     \x01\x00\
572                                     \x00\x00\x00\x00\
573                                     \x00\x00\x00\x00\
574                                     \x00\x00\
575                                     \x00\x00" as &[u8],
576        );
577        assert_matches!(
578            Err(ReadError::Format(ReadErrorKind::SampleRateIsZero)),
579            data.read_wave_header()
580        );
581    }
582
583    // Standard wave files
584
585    #[test]
586    fn test_validate_pcm_header_validate_bits_per_sample_standard() {
587        let mut vec = Vec::new();
588        vec.extend_from_slice(
589            b"RIFF    WAVE\
590	                            fmt \x10\x00\x00\x00\
591	                            \x01\x00\
592	                            \x01\x00\
593	                            \x44\xAC\x00\x00\
594	                            \x00\x00\x00\x00\
595	                            \x00\x00\
596	                            \x08\x00",
597        );
598
599        let mut cursor = Cursor::new(vec.clone());
600        assert_matches!(Ok(_), cursor.read_wave_header());
601
602        vec[34] = 16;
603        let mut cursor = Cursor::new(vec.clone());
604        assert_matches!(Ok(_), cursor.read_wave_header());
605
606        vec[34] = 24;
607        let mut cursor = Cursor::new(vec.clone());
608        assert_matches!(Ok(_), cursor.read_wave_header());
609
610        vec[34] = 32;
611        let mut cursor = Cursor::new(vec.clone());
612        assert_matches!(Ok(_), cursor.read_wave_header());
613
614        vec[34] = 48;
615        let mut cursor = Cursor::new(vec.clone());
616        assert_matches!(
617            Err(ReadError::Format(ReadErrorKind::UnsupportedBitsPerSample(
618                _
619            ))),
620            cursor.read_wave_header()
621        );
622
623        vec[34] = 0;
624        let mut cursor = Cursor::new(vec.clone());
625        assert_matches!(
626            Err(ReadError::Format(ReadErrorKind::UnsupportedBitsPerSample(
627                _
628            ))),
629            cursor.read_wave_header()
630        );
631    }
632
633    #[test]
634    fn test_validate_pcm_header_8bit_mono_example_standard() {
635        let mut vec = Vec::new();
636        vec.extend_from_slice(
637            b"RIFF    WAVE\
638	                            fmt \x10\x00\x00\x00\
639	                            \x01\x00\
640	                            \x01\x00\
641	                            \x44\xAC\x00\x00\
642	                            \x00\x00\x00\x00\
643	                            \x00\x00\
644	                            \x08\x00",
645        );
646        let mut cursor = Cursor::new(vec.clone());
647
648        assert_matches!(
649            Ok(PcmFormat {
650                num_channels: 1,
651                sample_rate: 44100,
652                bits_per_sample: 8,
653            }),
654            cursor.read_wave_header()
655        );
656    }
657
658    #[test]
659    fn test_validate_pcm_header_8bit_mono_example_standard_with_extra_cb_data() {
660        let mut vec = Vec::new();
661        vec.extend_from_slice(
662            b"RIFF    WAVE\
663	                            fmt \x10\x00\x00\x00\
664	                            \x01\x00\
665	                            \x01\x00\
666	                            \x44\xAC\x00\x00\
667	                            \x00\x00\x00\x00\
668	                            \x00\x00\
669	                            \x08\x00\
670	                            \x00\x00\x00\x00",
671        );
672        let mut cursor = Cursor::new(vec.clone());
673
674        assert_matches!(
675            Ok(PcmFormat {
676                num_channels: 1,
677                sample_rate: 44100,
678                bits_per_sample: 8,
679            }),
680            cursor.read_wave_header()
681        );
682    }
683
684    // Extended format
685
686    #[test]
687    fn test_validate_pcm_header_extended_format_too_small() {
688        let mut vec = Vec::new();
689        vec.extend_from_slice(
690            b"RIFF    WAVE\
691		                        fmt \x10\x00\x00\x00\
692		                        \xFE\xFF\
693		                        \x01\x00\
694		                        \x44\xAC\x00\x00\
695		                        \x00\x00\x00\x00\
696		                        \x00\x00\
697		                        \x08\x00\
698		                        \x02\x00\x00\x00",
699        );
700        let mut cursor = Cursor::new(vec.clone());
701
702        assert_matches!(
703            Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort)),
704            cursor.read_wave_header()
705        );
706    }
707
708    #[test]
709    fn test_validate_pcm_header_extended_format_not_pcm_format() {
710        let mut vec = Vec::new();
711        vec.extend_from_slice(
712            b"RIFF    WAVE\
713		                        fmt \x10\x00\x00\x00\
714		                        \xFE\xFF\
715		                        \x01\x00\
716		                        \x44\xAC\x00\x00\
717		                        \x00\x00\x00\x00\
718		                        \x00\x00\
719		                        \x08\x00\
720		                        \x16\x00\
721		                        \x08\x00\
722		                        \x00\x00\x00\x00\
723		                        \x09\x00\x00\x00\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71",
724        );
725        let mut cursor = Cursor::new(vec.clone());
726
727        assert_matches!(
728            Err(ReadError::Format(
729                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
730            )),
731            cursor.read_wave_header()
732        );
733    }
734
735    #[test]
736    fn test_validate_pcm_header_extended_format_sample_rates_dont_match() {
737        let mut vec = Vec::new();
738        vec.extend_from_slice(
739            b"RIFF    WAVE\
740		                        fmt \x10\x00\x00\x00\
741		                        \xFE\xFF\
742		                        \x01\x00\
743		                        \x44\xAC\x00\x00\
744		                        \x00\x00\x00\x00\
745		                        \x00\x00\
746		                        \x08\x00\
747		                        \x16\x00\
748		                        \x10\x00\
749		                        \x00\x00\x00\x00\
750		                        \x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
751        );
752        let mut cursor = Cursor::new(vec.clone());
753
754        assert_matches!(
755            Err(ReadError::Format(ReadErrorKind::InvalidBitsPerSample(_, _))),
756            cursor.read_wave_header()
757        );
758    }
759
760    #[test]
761    fn test_validate_pcm_header_extended_format_sample_rates_ok() {
762        let mut vec = Vec::new();
763        vec.extend_from_slice(
764            b"RIFF    WAVE\
765	                            fmt \x10\x00\x00\x00\
766	                            \xFE\xFF\
767	                            \x01\x00\
768	                            \x44\xAC\x00\x00\
769	                            \x00\x00\x00\x00\
770	                            \x00\x00\
771	                            \x08\x00\
772	                            \x16\x00\
773	                            \x08\x00\
774	                            \x00\x00\x00\x00\
775	                            \x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
776        );
777        let mut cursor = Cursor::new(vec.clone());
778
779        assert_matches!(Ok(_), cursor.read_wave_header());
780    }
781
782    #[test]
783    fn test_validate_pcm_header_8bit_mono_example_extended() {
784        let mut vec = Vec::new();
785        vec.extend_from_slice(
786            b"RIFF    WAVE\
787		                        fmt \x10\x00\x00\x00\
788		                        \xFE\xFF\
789		                        \x01\x00\
790		                        \x44\xAC\x00\x00\
791		                        \x00\x00\x00\x00\
792		                        \x00\x00\
793		                        \x08\x00\
794		                        \x16\x00\
795		                        \x08\x00\
796		                        \x00\x00\x00\x00\
797		                        \x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
798        );
799        let mut cursor = Cursor::new(vec.clone());
800
801        assert_matches!(
802            Ok(PcmFormat {
803                num_channels: 1,
804                sample_rate: 44100,
805                bits_per_sample: 8,
806            }),
807            cursor.read_wave_header()
808        );
809    }
810
811    #[test]
812    fn test_validate_extended_format_too_short() {
813        // Extended size is less than 22 -- should fail.
814        let mut data = Cursor::new(b"\x0F\x00\x00\x00");
815        assert_matches!(
816            Err(ReadError::Format(ReadErrorKind::FmtChunkTooShort)),
817            data.validate_extended_format(16)
818        );
819    }
820
821    #[test]
822    fn test_validate_extended_format_not_pcm() {
823        let mut data = Cursor::new(
824            b"\x16\x00\
825                                     \x10\x00\
826                                     \x00\x00\x00\x00\
827                                     \xFF\xFF\x00\x00\x00\x00\x00\x00\
828                                     \x00\x00\x00\x00\x00\x00\x00\x00",
829        );
830        assert_matches!(
831            Err(ReadError::Format(
832                ReadErrorKind::NotAnUncompressedPcmWaveFile(_)
833            )),
834            data.validate_extended_format(16)
835        );
836    }
837
838    #[test]
839    fn test_validate_extended_format_sample_rate_doesnt_match() {
840        let mut data = Cursor::new(
841            b"\x16\x00\
842                                     \x0F\x00\
843                                     \x00\x00\x00\x00\
844                                     \x01\x00\x00\x00\x00\x00\x00\x00\
845                                     \x00\x00\x00\x00\x00\x00\x00\x00",
846        );
847        assert_matches!(
848            Err(ReadError::Format(ReadErrorKind::InvalidBitsPerSample(_, _))),
849            data.validate_extended_format(16)
850        );
851    }
852
853    #[test]
854    fn test_validate_extended_format_sample_rate_ok() {
855        let mut data = Cursor::new(
856            b"\x16\x00\
857                                     \x10\x00\
858                                     \x00\x00\x00\x00\
859                                     \x01\x00\x00\x00\x00\x00\x00\x00\
860                                     \x00\x00\x00\x00\x00\x00\x00\x00",
861        );
862        assert_matches!(Ok(()), data.validate_extended_format(16));
863    }
864
865    // Misc tests
866
867    #[test]
868    fn test_skip_over_remainder() {
869        let mut data = Cursor::new(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ");
870        let mut buf = [0u8; 4];
871
872        let _ = data.skip_over_remainder(0, 0);
873        let _ = data.read(&mut buf);
874        assert_eq!(b"ABCD", &buf);
875
876        let _ = data.skip_over_remainder(4, 4);
877        let _ = data.read(&mut buf);
878        assert_eq!(b"EFGH", &buf);
879
880        let _ = data.skip_over_remainder(0, 4);
881        let _ = data.read(&mut buf);
882        assert_eq!(b"MNOP", &buf);
883
884        let _ = data.skip_over_remainder(4, 8);
885        let _ = data.read(&mut buf);
886        assert_eq!(b"UVWX", &buf);
887    }
888
889    // Wave reader tests
890
891    #[test]
892    fn test_reading_data_from_data_chunk_u8() {
893        let raw_data = b"\x00\x01\x02\x03\
894                         \x04\x05\x06\x07\
895                         \x08\x09\x0A\x0B\
896                         \x0C\x0D\x0E\x0F";
897
898        let expected_results = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
899
900        test_reading_data_from_data_chunk(
901            raw_data,
902            &expected_results,
903            1,
904            WaveReader::read_sample_u8,
905        );
906    }
907
908    #[test]
909    fn test_reading_data_from_data_chunk_i16() {
910        let raw_data = b"\x00\x01\x01\x01\
911                         \x02\x01\x03\x01\
912                         \x04\x01\x05\x01\
913                         \x06\x01\x07\x01";
914        let expected_results = [256, 257, 258, 259, 260, 261, 262, 263];
915
916        test_reading_data_from_data_chunk(
917            raw_data,
918            &expected_results,
919            2,
920            WaveReader::read_sample_i16,
921        );
922    }
923
924    #[test]
925    fn test_reading_data_from_data_chunk_i24() {
926        let raw_data = b"\x01\x01\x02\
927                         \x02\x01\x02\
928                         \x03\x01\x02\
929                         \x04\x01\x02\
930                         \x05\x01\x02";
931        let expected_results = [
932            65536 * 2 + 256 + 1 + 0,
933            65536 * 2 + 256 + 1 + 1,
934            65536 * 2 + 256 + 1 + 2,
935            65536 * 2 + 256 + 1 + 3,
936            65536 * 2 + 256 + 1 + 4,
937        ];
938
939        test_reading_data_from_data_chunk(
940            raw_data,
941            &expected_results,
942            3,
943            WaveReader::read_sample_i24,
944        );
945    }
946
947    #[test]
948    fn test_reading_data_from_data_chunk_i32() {
949        let raw_data = b"\x00\x01\x02\x03\
950                         \x04\x05\x06\x07\
951                         \x08\x09\x0A\x0B\
952                         \x0C\x0D\x0E\x0F";
953        let expected_results = [50462976, 117835012, 185207048, 252579084];
954
955        test_reading_data_from_data_chunk(
956            raw_data,
957            &expected_results,
958            4,
959            WaveReader::read_sample_i32,
960        );
961    }
962
963    fn test_reading_data_from_data_chunk<S, F>(
964        raw_data: &[u8],
965        expected_results: &[S],
966        bytes_per_num: u16,
967        read_sample: F,
968    ) where
969        S: PartialEq + Debug,
970        F: Fn(&mut WaveReader<Cursor<Vec<u8>>>) -> io::Result<S>,
971    {
972        let vec = create_standard_in_memory_riff_wave(1, 44100, bytes_per_num * 8, raw_data);
973        let cursor = Cursor::new(vec.clone());
974        let mut wave_reader = WaveReader::new(cursor).unwrap();
975
976        for expected in expected_results {
977            let next_sample = read_sample(&mut wave_reader).unwrap();
978            assert_eq!(*expected, next_sample);
979        }
980    }
981
982    trait VecExt {
983        fn extend_with_header_for_standard_wave(&mut self, data_size: usize);
984
985        fn extend_with_standard_fmt_subchunk(
986            &mut self,
987            num_channels: u16,
988            sample_rate: u32,
989            bits_per_sample: u16,
990        );
991
992        fn extend_with_data_subchunk(&mut self, raw_data: &[u8]);
993
994        fn extend_with_u16(&mut self, val: u16);
995
996        fn extend_with_u32(&mut self, val: u32);
997    }
998
999    impl VecExt for Vec<u8> {
1000        fn extend_with_header_for_standard_wave(&mut self, data_size: usize) {
1001            self.extend_from_slice(b"RIFF");                    // Identifier
1002            self.extend_with_u32(36 + data_size as u32);        // File size minus eight bytes
1003            self.extend_from_slice(b"WAVE");                    // Identifier
1004        }
1005
1006        fn extend_with_standard_fmt_subchunk(
1007            &mut self,
1008            num_channels: u16,
1009            sample_rate: u32,
1010            bits_per_sample: u16,
1011        ) {
1012            self.extend_from_slice(b"fmt \x10\x00\x00\x00");    // "fmt " chunk and size
1013            self.extend_from_slice(b"\x01\x00");                // PCM Format
1014            self.extend_with_u16(num_channels);                 // Number of channels
1015            self.extend_with_u32(sample_rate);                  // Sample rate
1016
1017            let bytes_per_sample = bits_per_sample / 8;
1018            let block_align = num_channels * bytes_per_sample;
1019            let byte_rate = block_align as u32 * sample_rate;
1020
1021            self.extend_with_u32(byte_rate);                    // Byte rate
1022            self.extend_with_u16(block_align);                  // Block align
1023            self.extend_with_u16(bits_per_sample);              // Bits per sample
1024        }
1025
1026        fn extend_with_data_subchunk(&mut self, raw_data: &[u8]) {
1027            self.extend_from_slice(b"data");                    // Start of "data" subchunk.
1028            self.extend_with_u32(raw_data.len() as u32);        // Size of data subchunk.
1029            self.extend_from_slice(raw_data);                   // The actual data, as bytes.
1030        }
1031
1032        fn extend_with_u16(&mut self, val: u16) {
1033            let mut buf_16: [u8; 2] = [0; 2];
1034            LittleEndian::write_u16(&mut buf_16, val);
1035            self.extend_from_slice(&buf_16);
1036        }
1037
1038        fn extend_with_u32(&mut self, val: u32) {
1039            let mut buf_32: [u8; 4] = [0; 4];
1040            LittleEndian::write_u32(&mut buf_32, val);
1041            self.extend_from_slice(&buf_32);
1042        }
1043    }
1044
1045    fn create_standard_in_memory_riff_wave(
1046        num_channels: u16,
1047        sample_rate: u32,
1048        bits_per_sample: u16,
1049        data: &[u8],
1050    ) -> Vec<u8> {
1051        let mut vec = Vec::new();
1052
1053        vec.extend_with_header_for_standard_wave(data.len());
1054        vec.extend_with_standard_fmt_subchunk(num_channels, sample_rate, bits_per_sample);
1055        vec.extend_with_data_subchunk(data);
1056
1057        vec
1058    }
1059}