1use std::fs::File;
4use std::io::{Cursor, Read};
5use crate::header::*;
6
7use thiserror::Error;
8
9#[derive(Error, Debug)]
10pub enum DecodeError {
11    #[error("Invalid chunk tag, expected '{expected:?}', found '{found:?}'")]
12    InvalidTag {
13        expected: &'static str,
14        found: String,
15    },
16    #[error("Invalid chunk attribute, attribute {attribute:?} must be greater than {expected:?}, found {found:?} instead")]
17    InvalidChunkAttribute {
18        attribute: &'static str,
19        expected: u32,
20        found: u32,
21    },
22    #[error("Invalid chunk attribute, attribute {attribute} must be on of {expected:?}, found 0x{found:02x}")]
23    InvalidChunkAttributeRange {
24        attribute: &'static str,
25        expected: &'static [u32],
26        found: u32,
27    },
28    #[error("Unsupported wav-format, attribute {attribute} must be one of {expected:?}, found 0x{found:02x}")]
29    UnsupportedWav {
30        attribute: &'static str,
31        expected: &'static [u32],
32        found: u32,
33    },
34    #[error("Unsupported wav encoding, module only supports PCM data")]
35    UnsupportedEncoding,
36    #[error("Failed to open file")]
37    FileOpen {
38        #[source]
39        source: std::io::Error,
40    },
41    #[error("Unsupported system, please use a 32-bit system or higher")]
42    UnsupportedSystem,
43    #[error("Unable to read data")]
44    ReadFail {
45        #[source]
46        source: std::io::Error,
47    },
48}
49
50pub fn from_file(file: File) -> Result<WavData, DecodeError> {
52    let mut r: Reader = match Reader::from_file(file) {
54        Err(err) => return Err(err),
55        Ok(r) => r,
56    };
57    let header = match r.read_header() {
59        Err(err) => return Err(err),
60        Ok(head) => head,
61    };
62    let samples = match r.get_samples_f32() {
64        Err(err) => return Err(err),
65        Ok(samples) => samples,
66    };
67    Ok(WavData{header, samples})
68}
69
70pub fn from_file_str(file_path: &str) -> Result<WavData, DecodeError> {
72    let f = match File::open(file_path) {
73        Ok(f) => f,
74        Err(err) => return Err(DecodeError::FileOpen { source: err }),
75    };
76    from_file(f)
77}
78
79pub struct Reader {
81    pub cur: Cursor<Vec<u8>>,
82    pub header: Option<WavHeader>,
83}
84impl Reader {    
85    pub fn from_file(file: File) -> Result<Reader, DecodeError> {
87        let mut data: Vec<u8> = Vec::new();
88        let mut f = file;
89        match f.read_to_end(&mut data) {
90            Ok(_) => {},
91            Err(err) => return Err(DecodeError::ReadFail { source: err }),
92        };
93        Self::from_vec(data)
94    }
95    pub fn from_vec(data: Vec<u8>) -> Result<Reader, DecodeError> {
97        let reader = Reader {
98            cur: Cursor::new(data),
99            header: None,
100        };
101        Ok(reader)
102    }
103    pub fn read_header(&mut self) -> Result<WavHeader, DecodeError> {
105        let mut header = WavHeader::new();
106        let riff_tag = self.read_str4();
108        if riff_tag != "RIFF" {
109            return Err(DecodeError::InvalidTag { expected: "RIFF", found: riff_tag.to_string() });
110        }
111        let chunk_size = self.read_u32().unwrap_or(0);
113        if chunk_size < 8 {
114            return Err(DecodeError::InvalidChunkAttribute {
115                attribute: "chunk size",
116                expected: 7,
117                found: chunk_size,
118            });
119        }
120        let wave_tag = self.read_str4();
122        if  wave_tag != "WAVE" {
123            return Err(DecodeError::InvalidTag { expected: "WAVE", found: riff_tag.to_string() });
124        }
125
126        let _ = self.read_list_chunk(&mut header);
129
130        let fmt_tag = self.read_str4();
132        if fmt_tag != "fmt " {
133            return Err(DecodeError::InvalidTag { expected: "fmt ", found: riff_tag.to_string() });
134        }
135        let chunk_size = self.read_u32().unwrap_or(0);
136        let format_tag = self.read_u16().unwrap_or(0);
138        match format_tag {
139            0x0001 => header.sample_format = SampleFormat::Int,
140            0x0003 => header.sample_format = SampleFormat::Float,
141            0x0006 => header.sample_format = SampleFormat::WaveFromatALaw,
142            0x0007 => header.sample_format = SampleFormat::WaveFormatMuLaw,
143            0xFFFE => header.sample_format = SampleFormat::SubFormat,
144            0x0055 => return Err(DecodeError::UnsupportedWav {
145                attribute: "format tag (0x0055: MP3)",
146                expected: &[0x0001, 0x0003, 0x0006, 0x0007, 0xFFFE],
147                found: format_tag as u32,
148            }),
149            _ => return Err(DecodeError::InvalidChunkAttributeRange {
150                attribute: "format tag",
151                expected: &[0x0001, 0x0003, 0x0006, 0x0007, 0xFFFE],
152                found: format_tag as u32,
153            }),
154        }
155        let ch = self.read_u16().unwrap_or(0);
157        if 1 <= ch && ch <= 2 {
158            header.channels = ch;
159        }
160        header.sample_rate = self.read_u32().unwrap_or(0);
162        if header.sample_rate < 32 {
163            return Err(DecodeError::InvalidChunkAttribute {
164                attribute: "sample rate",
165                expected: 31,
166                found: header.sample_rate,
167            });
168        }
169        let bytes_per_sec = self.read_u32().unwrap_or(0);
171        if bytes_per_sec < 8 {
172            return Err(DecodeError::InvalidChunkAttribute {
173                attribute: "bytes per second",
174                expected: 7,
175                found: header.sample_rate,
176            });
177        }
178        let _data_block_size = self.read_u16().unwrap_or(0);
180        let bits_per_sample = self.read_u16().unwrap_or(0);
182        if bits_per_sample < 8 {
183            return Err(DecodeError::InvalidChunkAttribute {
184                attribute: "bits per sample",
185                expected: 7,
186                found: header.sample_rate,
187            });
188        }
189        header.bits_per_sample = bits_per_sample;
190        let pos = self.cur.position() + chunk_size as u64 - 16;
192        self.cur.set_position(pos);
193
194        let _ = self.read_list_chunk(&mut header);
197
198        self.header = Some(header.clone());
200        Ok(header)
201    }
202
203    pub fn read_list_chunk(&mut self, header: &mut WavHeader) -> Result<usize, DecodeError> {
207        let begin_position = self.cur.position();
209
210        let info_tag = self.read_str4();
212        if info_tag != "LIST" {
213            self.cur.set_position(begin_position);
214            return Err(DecodeError::InvalidTag {
215                expected: "LIST",
216                found: info_tag.to_string()
217            });
218        }
219        let Some(read_size) = self.read_u32() else {
221            self.cur.set_position(begin_position);
222            return Err(DecodeError::ReadFail { source: std::io::Error::new(std::io::ErrorKind::Other, "Unable to read u32") })
223        };
224        let Ok(read_size) = read_size.try_into() else {
225            self.cur.set_position(begin_position);
226            return Err(DecodeError::UnsupportedSystem)
227        };
228
229        let mut data = vec![0; read_size];
231        match self.cur.read_exact(&mut data) {
232            Ok(_) => {
233                Ok(self.analize_list_chunk(data, header))
234            },
235            Err(err) => {
236                self.cur.set_position(begin_position);
237                Err(DecodeError::ReadFail { source: err })
238            }
239        }
240    }
241
242    pub fn analize_list_chunk(&mut self, data: Vec<u8>, header: &mut WavHeader) -> usize {
243        let data_len = data.len() as u64;
244        let mut cur = Cursor::new(data);
245        let mut chunk_tag = [0u8; 4];
247        let chunk_tag = match cur.read_exact(&mut chunk_tag) {
248            Ok(_) => String::from_utf8_lossy(&chunk_tag),
249            Err(_) => return 0, };
251        if chunk_tag != "INFO" {
252            return 0;
253        }
254        let mut items = vec![];
267        let mut result = 0;
268        while cur.position() < data_len {
269            let mut chunk_tag = [0u8; 4];
271            let chunk_tag = match cur.read_exact(&mut chunk_tag) {
272                Ok(_) => String::from_utf8_lossy(&chunk_tag),
273                Err(_) => break,
274            };
275            let mut chunk_size = [0u8; 4];
277            let chunk_size = match cur.read_exact(&mut chunk_size) {
278                Ok(_) => u32::from_le_bytes(chunk_size),
279                Err(_) => break,
280            };
281            let mut data = vec![0; chunk_size as usize];
282            let data = match cur.read_exact(&mut data) {
283                Ok(_) => String::from_utf8_lossy(&data),
284                Err(_) => break,
285            };
286            let item = ListChunkItem {
288                id: chunk_tag.trim_end_matches('\0').to_string(),
289                value: data.trim_end_matches('\0').to_string(),
290            };
291            items.push(item);
292            result += 1;
293        }
294        header.list_chunk = Some(ListChunk{items});
295        result
296    }
297
298    pub fn get_samples_f32(&mut self) -> Result<Vec<f32>, DecodeError> {
299        let mut result:Vec<f32> = Vec::new();
300        loop {
301            let chunk_tag = self.read_str4();
303            if chunk_tag == "" { break; }
304            let size = self.read_u32().unwrap_or(0) as u64;
305            if size == 0 { continue }
308            if chunk_tag != "data" {
310                self.cur.set_position(self.cur.position() + size);
311                continue;
312            }
313            let h = &self.header.clone().unwrap();
315            let num_sample = (size / (h.bits_per_sample / 8) as u64) as u64;
316            match h.sample_format {
317                SampleFormat::Float => {
319                    match h.bits_per_sample {
320                        32 => {
321                            for _ in 0..num_sample {
322                                let lv = self.read_f32().unwrap_or(0.0);
323                                result.push(lv);
324                            }
325                        },
326                        64 => {
327                            for _ in 0..num_sample {
328                                let lv = self.read_f64().unwrap_or(0.0);
329                                result.push(lv as f32); }
331                        },
332                        _ => return Err(DecodeError::UnsupportedWav {
333                            attribute: "bits per float sample",
334                            expected: &[32, 64],
335                            found: h.bits_per_sample as u32,
336                        }),
337                    }
338                },
339                SampleFormat::Int => {
341                    match h.bits_per_sample {
342                        8 => {
343                            for _ in 0..num_sample {
344                                let lv = self.read_u8().unwrap_or(0);
346                                let fv = lv.wrapping_sub(128) as i8 as f32 / (0xFF as f32 / 2.0);
347                                result.push(fv);
348                            }
349                        },
350                        16 => {
351                            for _ in 0..num_sample {
352                                let lv = self.read_i16().unwrap_or(0);
353                                let fv = lv as f32 / (0xFFFF as f32 / 2.0);
354                                result.push(fv);
355                            }
356                        },
357                        24 => {
358                            for _ in 0..num_sample {
359                                let lv = self.read_i24().unwrap_or(0);
360                                let fv = lv as f32 / (0xFFFFFF as f32 / 2.0);
361                                result.push(fv);
362                            }
363                        },
364                        32 => {
365                            for _ in 0..num_sample {
366                                let lv = self.read_i32().unwrap_or(0);
367                                let fv = lv as f32 / (0xFFFFFFFFu32 as f32 / 2.0);
368                                result.push(fv);
369                            }
370                        },
371                        _ => return Err(DecodeError::UnsupportedWav {
372                            attribute: "bits per integer sample",
373                            expected: &[8, 16, 24, 32],
374                            found: h.bits_per_sample as u32,
375                        }),
376                    }
377                },
378                _ => return Err(DecodeError::UnsupportedEncoding),
379            }
380        }
381        Ok(result)
382    }
383
384    pub fn read_str4(&mut self) -> String {
385        let mut buf = [0u8; 4];
386        match self.cur.read(&mut buf) {
387            Ok(sz) => {
388                if sz < 4 {
389                    return String::from("");
390                }
391            },
392            Err(_) => return String::from(""),
393        }
394        let s = String::from_utf8_lossy(&buf);
395        String::from(s)
396    }
397
398    pub fn read_f32(&mut self) -> Option<f32> {
399        match self.read_u32() {
400            Some(v) => Some(f32::from_bits(v)),
401            None => None,
402        }
403    }
404
405    pub fn read_f64(&mut self) -> Option<f64> {
406        match self.read_u64() {
407            Some(v) => Some(f64::from_bits(v)),
408            None => None,
409        }
410    }
411
412    pub fn read_u64(&mut self) -> Option<u64> {
413        let mut buf = [0u8; 8];
414        match self.cur.read(&mut buf) {
415            Ok(v) => v,
416            Err(_) => return None,
417        };
418        Some(u64::from_le_bytes(buf))
419    }
420
421    pub fn read_u32(&mut self) -> Option<u32> {
422        let mut buf = [0u8; 4];
423        match self.cur.read(&mut buf) {
424            Ok(v) => v,
425            Err(_) => return None,
426        };
427        Some(u32::from_le_bytes(buf))
428    }
429
430    pub fn read_i32(&mut self) -> Option<i32> {
431        let mut buf = [0u8; 4];
432        match self.cur.read(&mut buf) {
433            Ok(v) => v,
434            Err(_) => return None,
435        };
436        Some(i32::from_le_bytes(buf))
437    }
438
439    pub fn read_u24(&mut self) -> Option<u32> {
440        let mut buf = [0u8; 3];
441        match self.cur.read(&mut buf) {
442            Ok(v) => v,
443            Err(_) => return None,
444        };
445        let result = 
446            (buf[0] as u32) << 0 | 
447            (buf[1] as u32) << 8 |
448            (buf[2] as u32) << 16;
449        Some(result)
450    }
451
452    pub fn read_i24(&mut self) -> Option<i32> {
453        let mut buf = [0u8; 3];
454        match self.cur.read(&mut buf) {
455            Ok(v) => v,
456            Err(_) => return None,
457        };
458        let buf4 = [0, buf[0], buf[1], buf[2]];
459        Some(i32::from_le_bytes(buf4) >> 8)
460    }
461
462    pub fn read_u16(&mut self) -> Option<u16> {
463        let mut buf = [0u8; 2];
464        match self.cur.read(&mut buf) {
465            Ok(v) => v,
466            Err(_) => return None,
467        };
468        let result = 
469            (buf[0] as u16) << 0 |
470            (buf[1] as u16) << 8;
471        Some(result)
472    }
473
474    pub fn read_i16(&mut self) -> Option<i16> {
475        let mut buf = [0u8; 2];
476        match self.cur.read(&mut buf) {
477            Ok(v) => v,
478            Err(_) => return None,
479        };
480        Some(i16::from_le_bytes(buf))
481    }
482
483    pub fn read_u8(&mut self) -> Option<u8> {
484        let mut buf = [0u8; 1];
485        match self.cur.read(&mut buf) {
486            Ok(_) => Some(buf[0]),
487            Err(_) => None,
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian signed 16-bit and 24-bit decoding, including sign extension.
    #[test]
    fn read_it() {
        let i16_cases: [([u8; 2], i16); 5] = [
            ([1, 0], 1),
            ([0, 1], 0x100),
            ([0xFF, 0xFF], -1),
            ([0xFE, 0xFF], -2),
            ([0xFD, 0xFF], -3),
        ];
        for (bytes, expected) in i16_cases {
            let mut reader = Reader::from_vec(bytes.to_vec()).unwrap();
            assert_eq!(reader.read_i16(), Some(expected), "bytes {:?}", bytes);
        }

        let mut reader = Reader::from_vec(vec![0xFF, 0xFF, 0xFF]).unwrap();
        assert_eq!(reader.read_i24(), Some(-1));
    }
}
514