frame_header/
lib.rs

1use serde::{Deserialize, Serialize};
2use std::io::{self, Read, Write};
3
4#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
5pub enum Endianness {
6    LittleEndian,
7    BigEndian,
8}
9
10#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
11pub enum EncodingFlag {
12    PCMSigned = 0,
13    PCMFloat = 1,
14    Opus = 2,
15    FLAC = 3,
16    AAC = 4,
17    H264 = 5,
18}
19
20#[derive(Serialize, Deserialize, Debug)]
21pub struct FrameHeader {
22    encoding: EncodingFlag,
23    sample_size: u16,
24    sample_rate: u32,
25    channels: u8,
26    bits_per_sample: u8,
27    endianness: Endianness,
28
29    #[cfg(not(target_arch = "wasm32"))]
30    id: Option<u64>,
31    pts: Option<u64>,
32
33    #[cfg(target_arch = "wasm32")]
34    #[serde(
35        serialize_with = "serialize_id_wasm",
36        deserialize_with = "deserialize_id_wasm"
37    )]
38    id: Option<u64>,
39}
40
/// Serde `serialize_with` helper used for `FrameHeader::id` on `wasm32`.
///
/// Emits `Some(id)` as its decimal string representation and `None` as a
/// serde "none".
// NOTE(review): presumably strings are used so JavaScript consumers do not
// lose precision on 64-bit values — confirm against the wasm callers.
#[cfg(target_arch = "wasm32")]
fn serialize_id_wasm<S>(id: &Option<u64>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // `Option<String>` serializes through `serialize_some` / `serialize_none`,
    // which is exactly the pair of calls this helper has to make.
    let as_text: Option<String> = id.as_ref().map(|value| value.to_string());
    as_text.serialize(serializer)
}
51
/// Serde `deserialize_with` helper used for `FrameHeader::id` on `wasm32`.
///
/// Reads an optional string and parses it as a `u64`. A present string that
/// is not a valid `u64` is reported as a custom deserialization error; an
/// absent value yields `None`.
#[cfg(target_arch = "wasm32")]
fn deserialize_id_wasm<'de, D>(deserializer: D) -> Result<Option<u64>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw = Option::<String>::deserialize(deserializer)?;
    // Option<Result<_>> -> Result<Option<_>>: `None` stays `Ok(None)`.
    raw.map(|text| {
        text.parse::<u64>()
            .map_err(<D::Error as serde::de::Error>::custom)
    })
    .transpose()
}
64
65impl FrameHeader {
66    const MAGIC_WORD: u32 = 0x2A;
67    const MAGIC_SHIFT: u32 = 26;
68    const MAGIC_MASK: u32 = 0x3F << 26;
69
70    const SAMPLE_RATE_SHIFT: u32 = 24;
71    const SAMPLE_RATE_MASK: u32 = 0x3 << 24;
72
73    const BITS_SHIFT: u32 = 22;
74    const BITS_MASK: u32 = 0x3 << 22;
75
76    const PTS_SHIFT: u32 = 21;
77    const PTS_MASK: u32 = 0x1 << 21;
78
79    const ID_SHIFT: u32 = 20;
80    const ID_MASK: u32 = 0x1 << 20;
81
82    const ENCODING_SHIFT: u32 = 17;
83    const ENCODING_MASK: u32 = 0x7 << 17;
84
85    const ENDIAN_SHIFT: u32 = 16;
86    const ENDIAN_MASK: u32 = 0x1 << 16;
87
88    const CHANNELS_SHIFT: u32 = 12;
89    const CHANNELS_MASK: u32 = 0xF << 12;
90
91    const SAMPLE_SIZE_MASK: u32 = 0xFFF;
92
93    const VALID_SAMPLE_RATES: [u32; 4] = [16000, 44100, 48000, 96000];
94    const MAX_SAMPLE_SIZE: u16 = 0xFFF;
95
96    pub fn new(
97        encoding: EncodingFlag,
98        sample_size: u16,
99        sample_rate: u32,
100        channels: u8,
101        bits_per_sample: u8,
102        endianness: Endianness,
103        id: Option<u64>,
104        pts: Option<u64>,
105    ) -> Result<Self, String> {
106        if channels == 0 || channels > 16 {
107            return Err("Channel count must be between 1 and 16".to_string());
108        }
109
110        match bits_per_sample {
111            16 | 24 | 32 => {}
112            _ => return Err("Bits per sample must be 16, 24, or 32".to_string()),
113        }
114
115        if sample_size > Self::MAX_SAMPLE_SIZE {
116            return Err(format!(
117                "Sample size exceeds maximum value ({})",
118                Self::MAX_SAMPLE_SIZE
119            ));
120        }
121
122        if !Self::VALID_SAMPLE_RATES.contains(&sample_rate) {
123            return Err(format!(
124                "Invalid sample rate: {}. Must be one of: {:?}",
125                sample_rate,
126                Self::VALID_SAMPLE_RATES
127            ));
128        }
129
130        Ok(FrameHeader {
131            encoding,
132            sample_size,
133            sample_rate,
134            channels,
135            bits_per_sample,
136            endianness,
137            id,
138            pts,
139        })
140    }
141
142    pub fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
143        let mut header: u32 = Self::MAGIC_WORD << Self::MAGIC_SHIFT;
144
145        let sample_rate_code = match self.sample_rate {
146            16000 => 0,
147            44100 => 1,
148            48000 => 2,
149            96000 => 3,
150            _ => {
151                return Err(io::Error::new(
152                    io::ErrorKind::InvalidInput,
153                    "Invalid sample rate",
154                ))
155            }
156        };
157        header |= sample_rate_code << Self::SAMPLE_RATE_SHIFT;
158
159        let bits_code = match self.bits_per_sample {
160            16 => 0,
161            24 => 1,
162            32 => 2,
163            _ => {
164                return Err(io::Error::new(
165                    io::ErrorKind::InvalidInput,
166                    "Invalid bits per sample",
167                ))
168            }
169        };
170        header |= bits_code << Self::BITS_SHIFT;
171
172        header |= (self.pts.is_some() as u32) << Self::PTS_SHIFT;
173        header |= (self.id.is_some() as u32) << Self::ID_SHIFT;
174        header |= (self.encoding as u32) << Self::ENCODING_SHIFT;
175        header |= (self.endianness as u32) << Self::ENDIAN_SHIFT;
176        header |= ((self.channels - 1) as u32) << Self::CHANNELS_SHIFT;
177        header |= self.sample_size as u32;
178
179        writer.write_all(&header.to_be_bytes())?;
180
181        if let Some(id) = self.id {
182            writer.write_all(&id.to_be_bytes())?;
183        }
184
185        if let Some(pts) = self.pts {
186            writer.write_all(&pts.to_be_bytes())?;
187        }
188
189        Ok(())
190    }
191
192    pub fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
193        let mut header_bytes = [0u8; 4];
194        reader.read_exact(&mut header_bytes)?;
195        let header = u32::from_be_bytes(header_bytes);
196
197        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
198            return Err(io::Error::new(
199                io::ErrorKind::InvalidData,
200                "Invalid header magic word",
201            ));
202        }
203
204        let sample_rate = match (header & Self::SAMPLE_RATE_MASK) >> Self::SAMPLE_RATE_SHIFT {
205            0 => 16000,
206            1 => 44100,
207            2 => 48000,
208            3 => 96000,
209            _ => {
210                return Err(io::Error::new(
211                    io::ErrorKind::InvalidData,
212                    "Invalid sample rate code",
213                ))
214            }
215        };
216
217        let bits_per_sample = match (header & Self::BITS_MASK) >> Self::BITS_SHIFT {
218            0 => 16,
219            1 => 24,
220            2 => 32,
221            _ => {
222                return Err(io::Error::new(
223                    io::ErrorKind::InvalidData,
224                    "Invalid bits per sample code",
225                ))
226            }
227        };
228
229        let has_pts = (header & Self::PTS_MASK) >> Self::PTS_SHIFT == 1;
230        let has_id = (header & Self::ID_MASK) >> Self::ID_SHIFT == 1;
231
232        let encoding = match (header & Self::ENCODING_MASK) >> Self::ENCODING_SHIFT {
233            0 => EncodingFlag::PCMSigned,
234            1 => EncodingFlag::PCMFloat,
235            2 => EncodingFlag::Opus,
236            3 => EncodingFlag::FLAC,
237            4 => EncodingFlag::AAC,
238            _ => {
239                return Err(io::Error::new(
240                    io::ErrorKind::InvalidData,
241                    "Invalid encoding flag",
242                ))
243            }
244        };
245
246        let endianness = if (header & Self::ENDIAN_MASK) >> Self::ENDIAN_SHIFT == 0 {
247            Endianness::LittleEndian
248        } else {
249            Endianness::BigEndian
250        };
251
252        let channels = (((header & Self::CHANNELS_MASK) >> Self::CHANNELS_SHIFT) + 1) as u8;
253        let sample_size = (header & Self::SAMPLE_SIZE_MASK) as u16;
254
255        let id = if has_id {
256            let mut id_bytes = [0u8; 8];
257            reader.read_exact(&mut id_bytes)?;
258            Some(u64::from_be_bytes(id_bytes))
259        } else {
260            None
261        };
262
263        let pts = if has_pts {
264            let mut pts_bytes = [0u8; 8];
265            reader.read_exact(&mut pts_bytes)?;
266            Some(u64::from_be_bytes(pts_bytes))
267        } else {
268            None
269        };
270
271        Ok(FrameHeader {
272            encoding,
273            sample_size,
274            sample_rate,
275            channels,
276            bits_per_sample,
277            endianness,
278            id,
279            pts,
280        })
281    }
282
283    pub fn validate_header(header_bytes: &[u8]) -> Result<bool, String> {
284        if header_bytes.len() < 4 {
285            return Err("Header too small".to_string());
286        }
287
288        let header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
289
290        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
291            return Ok(false);
292        }
293
294        let encoding = (header & Self::ENCODING_MASK) >> Self::ENCODING_SHIFT;
295        if encoding > 4 {
296            return Ok(false);
297        }
298
299        let sample_rate_code = (header & Self::SAMPLE_RATE_MASK) >> Self::SAMPLE_RATE_SHIFT;
300        if sample_rate_code > 3 {
301            return Ok(false);
302        }
303
304        let channels = (((header & Self::CHANNELS_MASK) >> Self::CHANNELS_SHIFT) + 1) as u8;
305        if channels == 0 || channels > 16 {
306            return Ok(false);
307        }
308
309        let bits_code = (header & Self::BITS_MASK) >> Self::BITS_SHIFT;
310        if bits_code > 2 {
311            return Ok(false);
312        }
313
314        Ok(true)
315    }
316
317    pub fn size(&self) -> usize {
318        4 + // Base header
319        (self.id.is_some() as usize) * 8 + // Optional ID
320        (self.pts.is_some() as usize) * 8 // Optional PTS
321    }
322
323    // Getter methods
324    pub fn encoding(&self) -> &EncodingFlag {
325        &self.encoding
326    }
327
328    pub fn sample_size(&self) -> u16 {
329        self.sample_size
330    }
331
332    pub fn sample_rate(&self) -> u32 {
333        self.sample_rate
334    }
335
336    pub fn channels(&self) -> u8 {
337        self.channels
338    }
339
340    pub fn bits_per_sample(&self) -> u8 {
341        self.bits_per_sample
342    }
343
344    pub fn endianness(&self) -> &Endianness {
345        &self.endianness
346    }
347
348    pub fn id(&self) -> Option<u64> {
349        self.id
350    }
351
352    pub fn pts(&self) -> Option<u64> {
353        self.pts
354    }
355
356    // Extract methods
357    pub fn extract_sample_count(header_bytes: &[u8]) -> Result<u16, String> {
358        if header_bytes.len() < 4 {
359            return Err("Header too small".to_string());
360        }
361
362        let header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
363
364        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
365            return Err("Invalid magic word".to_string());
366        }
367
368        Ok((header & Self::SAMPLE_SIZE_MASK) as u16)
369    }
370
371    pub fn extract_encoding(header_bytes: &[u8]) -> Result<EncodingFlag, String> {
372        if header_bytes.len() < 4 {
373            return Err("Header too small".to_string());
374        }
375
376        let header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
377
378        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
379            return Err("Invalid magic word".to_string());
380        }
381
382        match (header & Self::ENCODING_MASK) >> Self::ENCODING_SHIFT {
383            0 => Ok(EncodingFlag::PCMSigned),
384            1 => Ok(EncodingFlag::PCMFloat),
385            2 => Ok(EncodingFlag::Opus),
386            3 => Ok(EncodingFlag::FLAC),
387            4 => Ok(EncodingFlag::AAC),
388            _ => Err("Invalid encoding flag".to_string()),
389        }
390    }
391
392    pub fn extract_id(header_bytes: &[u8]) -> Result<Option<u64>, String> {
393        if header_bytes.len() < 4 {
394            return Err("Header too small".to_string());
395        }
396
397        let header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
398
399        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
400            return Err("Invalid magic word".to_string());
401        }
402
403        if (header & Self::ID_MASK) >> Self::ID_SHIFT == 0 {
404            return Ok(None);
405        }
406
407        if header_bytes.len() < 12 {
408            return Err("Header indicates ID present but buffer too small".to_string());
409        }
410
411        Ok(Some(u64::from_be_bytes(
412            header_bytes[4..12].try_into().unwrap(),
413        )))
414    }
415
416    pub fn extract_pts(header_bytes: &[u8]) -> Result<Option<u64>, String> {
417        if header_bytes.len() < 4 {
418            return Err("Header too small".to_string());
419        }
420
421        let header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
422
423        if (header & Self::MAGIC_MASK) >> Self::MAGIC_SHIFT != Self::MAGIC_WORD {
424            return Err("Invalid magic word".to_string());
425        }
426
427        let has_pts = (header & Self::PTS_MASK) >> Self::PTS_SHIFT == 1;
428        if !has_pts {
429            return Ok(None);
430        }
431
432        let has_id = (header & Self::ID_MASK) >> Self::ID_SHIFT == 1;
433        let pts_offset = 4 + if has_id { 8 } else { 0 };
434
435        if header_bytes.len() < pts_offset + 8 {
436            return Err("Header indicates PTS present but buffer too small".to_string());
437        }
438
439        Ok(Some(u64::from_be_bytes(
440            header_bytes[pts_offset..pts_offset + 8].try_into().unwrap(),
441        )))
442    }
443
444    // Patch methods
445    pub fn patch_bits_per_sample(header_bytes: &mut [u8], bits: u8) -> Result<(), String> {
446        if !Self::validate_header(header_bytes)? {
447            return Err("Invalid header".to_string());
448        }
449
450        let bits_code = match bits {
451            16 => 0,
452            24 => 1,
453            32 => 2,
454            _ => return Err("Bits per sample must be 16, 24, or 32".to_string()),
455        };
456
457        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
458        header &= !Self::BITS_MASK;
459        header |= (bits_code << Self::BITS_SHIFT) & Self::BITS_MASK;
460        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
461        Ok(())
462    }
463
464    pub fn patch_sample_size(header_bytes: &mut [u8], new_sample_size: u16) -> Result<(), String> {
465        if !Self::validate_header(header_bytes)? {
466            return Err("Invalid header".to_string());
467        }
468
469        if new_sample_size > Self::MAX_SAMPLE_SIZE {
470            return Err(format!(
471                "Sample size exceeds maximum value ({})",
472                Self::MAX_SAMPLE_SIZE
473            ));
474        }
475
476        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
477        header &= !Self::SAMPLE_SIZE_MASK;
478        header |= new_sample_size as u32;
479        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
480        Ok(())
481    }
482
483    pub fn patch_encoding(header_bytes: &mut [u8], encoding: EncodingFlag) -> Result<(), String> {
484        if !Self::validate_header(header_bytes)? {
485            return Err("Invalid header".to_string());
486        }
487
488        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
489        header &= !Self::ENCODING_MASK;
490        header |= ((encoding as u32) << Self::ENCODING_SHIFT) & Self::ENCODING_MASK;
491        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
492        Ok(())
493    }
494
495    pub fn patch_sample_rate(header_bytes: &mut [u8], sample_rate: u32) -> Result<(), String> {
496        if !Self::validate_header(header_bytes)? {
497            return Err("Invalid header".to_string());
498        }
499
500        let rate_code = match sample_rate {
501            16000 => 0,
502            44100 => 1,
503            48000 => 2,
504            96000 => 3,
505            _ => {
506                return Err(format!(
507                    "Invalid sample rate: {}. Must be one of: {:?}",
508                    sample_rate,
509                    Self::VALID_SAMPLE_RATES
510                ))
511            }
512        };
513
514        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
515        header &= !Self::SAMPLE_RATE_MASK;
516        header |= (rate_code << Self::SAMPLE_RATE_SHIFT) & Self::SAMPLE_RATE_MASK;
517        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
518        Ok(())
519    }
520
521    pub fn patch_channels(header_bytes: &mut [u8], channels: u8) -> Result<(), String> {
522        if !Self::validate_header(header_bytes)? {
523            return Err("Invalid header".to_string());
524        }
525
526        if channels == 0 || channels > 16 {
527            return Err("Channel count must be between 1 and 16".to_string());
528        }
529
530        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
531        header &= !Self::CHANNELS_MASK;
532        header |= (((channels - 1) as u32) << Self::CHANNELS_SHIFT) & Self::CHANNELS_MASK;
533        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
534        Ok(())
535    }
536
537    pub fn patch_id(header_bytes: &mut [u8], id: Option<u64>) -> Result<(), String> {
538        if !Self::validate_header(header_bytes)? {
539            return Err("Invalid header".to_string());
540        }
541
542        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
543        header &= !Self::ID_MASK;
544        header |= ((id.is_some() as u32) << Self::ID_SHIFT) & Self::ID_MASK;
545        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
546
547        if let Some(id_value) = id {
548            if header_bytes.len() < 12 {
549                return Err("Buffer too small to add ID".to_string());
550            }
551            header_bytes[4..12].copy_from_slice(&id_value.to_be_bytes());
552        }
553
554        Ok(())
555    }
556
557    pub fn patch_pts(header_bytes: &mut [u8], pts: Option<u64>) -> Result<(), String> {
558        if !Self::validate_header(header_bytes)? {
559            return Err("Invalid header".to_string());
560        }
561
562        let mut header = u32::from_be_bytes(header_bytes[..4].try_into().unwrap());
563        header &= !Self::PTS_MASK;
564        header |= ((pts.is_some() as u32) << Self::PTS_SHIFT) & Self::PTS_MASK;
565
566        let has_id = (header & Self::ID_MASK) >> Self::ID_SHIFT == 1;
567        let pts_offset = 4 + if has_id { 8 } else { 0 };
568
569        if let Some(pts_value) = pts {
570            if header_bytes.len() < pts_offset + 8 {
571                return Err("Buffer too small to add PTS".to_string());
572            }
573            header_bytes[pts_offset..pts_offset + 8].copy_from_slice(&pts_value.to_be_bytes());
574        }
575
576        header_bytes[..4].copy_from_slice(&header.to_be_bytes());
577        Ok(())
578    }
579}
580
581#[cfg(test)]
582mod tests {
583    use super::*;
584
585    fn create_test_header() -> Vec<u8> {
586        let header = FrameHeader::new(
587            EncodingFlag::PCMSigned,
588            1024,
589            48000,
590            2,
591            24,
592            Endianness::LittleEndian,
593            None,
594            None,
595        )
596        .unwrap();
597        let mut buffer = Vec::new();
598        header.encode(&mut buffer).unwrap();
599        buffer
600    }
601
602    fn create_header_with_pts() -> Vec<u8> {
603        let header = FrameHeader::new(
604            EncodingFlag::PCMSigned,
605            1024,
606            48000,
607            2,
608            24,
609            Endianness::LittleEndian,
610            None,
611            Some(0x1234567890ABCDEF),
612        )
613        .unwrap();
614        let mut buffer = Vec::new();
615        header.encode(&mut buffer).unwrap();
616        buffer
617    }
618
619    fn create_header_with_id_and_pts() -> Vec<u8> {
620        let header = FrameHeader::new(
621            EncodingFlag::PCMSigned,
622            1024,
623            48000,
624            2,
625            24,
626            Endianness::LittleEndian,
627            Some(0xDEADBEEF),
628            Some(0xFEEDFACE),
629        )
630        .unwrap();
631        let mut buffer = Vec::new();
632        header.encode(&mut buffer).unwrap();
633        buffer
634    }
635
636    #[test]
637    fn test_pts_handling() {
638        // Test header with PTS
639        let header_bytes = create_header_with_pts();
640        assert_eq!(header_bytes.len(), 12); // 4 bytes header + 8 bytes PTS
641
642        let decoded = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
643        assert_eq!(decoded.pts(), Some(0x1234567890ABCDEF));
644        assert_eq!(decoded.size(), 12);
645
646        // Test header with both ID and PTS
647        let header_bytes = create_header_with_id_and_pts();
648        assert_eq!(header_bytes.len(), 20); // 4 bytes header + 8 bytes ID + 8 bytes PTS
649
650        let decoded = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
651        assert_eq!(decoded.id(), Some(0xDEADBEEF));
652        assert_eq!(decoded.pts(), Some(0xFEEDFACE));
653        assert_eq!(decoded.size(), 20);
654
655        // Test patching PTS
656        let mut header_bytes = create_test_header();
657        assert_eq!(header_bytes.len(), 4); // No PTS initially
658
659        let mut extended_bytes = vec![0; 12];
660        extended_bytes[..4].copy_from_slice(&header_bytes);
661
662        assert!(FrameHeader::patch_pts(&mut extended_bytes, Some(0xCAFEBABE)).is_ok());
663        let updated = FrameHeader::decode(&mut &extended_bytes[..]).unwrap();
664        assert_eq!(updated.pts(), Some(0xCAFEBABE));
665    }
666
667    #[test]
668    fn test_extract_pts() {
669        // Test header with PTS
670        let header_with_pts = create_header_with_pts();
671        let pts = FrameHeader::extract_pts(&header_with_pts).unwrap();
672        assert_eq!(pts, Some(0x1234567890ABCDEF));
673
674        // Test header without PTS
675        let header_no_pts = create_test_header();
676        let pts = FrameHeader::extract_pts(&header_no_pts).unwrap();
677        assert_eq!(pts, None);
678
679        // Test invalid cases
680        let mut invalid_header = header_with_pts.clone();
681        invalid_header[0] = 0; // Corrupt magic word
682        assert!(FrameHeader::extract_pts(&invalid_header).is_err());
683
684        // Test truncated header with PTS flag set
685        let truncated = header_with_pts[..4].to_vec();
686        assert!(FrameHeader::extract_pts(&truncated).is_err());
687    }
688
689    #[test]
690    fn test_encoding_roundtrip_with_pts() {
691        let original = FrameHeader::new(
692            EncodingFlag::Opus,
693            2048,
694            48000,
695            8,
696            16,
697            Endianness::LittleEndian,
698            Some(0xDEADBEEF),
699            Some(0xCAFEBABE),
700        )
701        .unwrap();
702
703        let mut buffer = Vec::new();
704        original.encode(&mut buffer).unwrap();
705
706        let decoded = FrameHeader::decode(&mut &buffer[..]).unwrap();
707
708        assert_eq!(decoded.pts(), original.pts());
709        assert_eq!(decoded.id(), original.id());
710        assert_eq!(decoded.size(), original.size());
711        assert_eq!(buffer.len(), decoded.size());
712    }
713
714    #[test]
715    fn test_patch_operations() {
716        let mut header_bytes = create_test_header();
717
718        // Test sample size patching
719        assert!(FrameHeader::patch_sample_size(&mut header_bytes, 2048).is_ok());
720        let updated = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
721        assert_eq!(updated.sample_size(), 2048);
722
723        // Test encoding patching
724        assert!(FrameHeader::patch_encoding(&mut header_bytes, EncodingFlag::FLAC).is_ok());
725        let updated = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
726        assert_eq!(*updated.encoding(), EncodingFlag::FLAC);
727
728        // Test sample rate patching
729        assert!(FrameHeader::patch_sample_rate(&mut header_bytes, 96000).is_ok());
730        let updated = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
731        assert_eq!(updated.sample_rate(), 96000);
732
733        // Test bits per sample patching
734        assert!(FrameHeader::patch_bits_per_sample(&mut header_bytes, 32).is_ok());
735        let updated = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
736        assert_eq!(updated.bits_per_sample(), 32);
737
738        // Test channels patching
739        assert!(FrameHeader::patch_channels(&mut header_bytes, 16).is_ok());
740        let updated = FrameHeader::decode(&mut &header_bytes[..]).unwrap();
741        assert_eq!(updated.channels(), 16);
742
743        // Test PTS patching
744        let mut extended_bytes = vec![0; 20]; // Enough space for header + id + pts
745        extended_bytes[..header_bytes.len()].copy_from_slice(&header_bytes);
746        assert!(FrameHeader::patch_pts(&mut extended_bytes, Some(0xCAFEBABE)).is_ok());
747        let updated = FrameHeader::decode(&mut &extended_bytes[..]).unwrap();
748        assert_eq!(updated.pts(), Some(0xCAFEBABE));
749    }
750
751    #[test]
752    fn test_extract_operations() {
753        let header_bytes = create_header_with_id_and_pts();
754
755        assert_eq!(
756            FrameHeader::extract_sample_count(&header_bytes).unwrap(),
757            1024
758        );
759        assert_eq!(
760            FrameHeader::extract_encoding(&header_bytes).unwrap(),
761            EncodingFlag::PCMSigned
762        );
763        assert_eq!(
764            FrameHeader::extract_id(&header_bytes).unwrap(),
765            Some(0xDEADBEEF)
766        );
767        assert_eq!(
768            FrameHeader::extract_pts(&header_bytes).unwrap(),
769            Some(0xFEEDFACE)
770        );
771
772        // Test with invalid header
773        let mut invalid_header = header_bytes.clone();
774        invalid_header[0] = 0; // Corrupt magic word
775        assert!(FrameHeader::extract_sample_count(&invalid_header).is_err());
776        assert!(FrameHeader::extract_encoding(&invalid_header).is_err());
777        assert!(FrameHeader::extract_id(&invalid_header).is_err());
778        assert!(FrameHeader::extract_pts(&invalid_header).is_err());
779    }
780
781    #[test]
782    fn test_patch_validation() {
783        let mut header_bytes = create_test_header();
784
785        // Test invalid sample size
786        assert!(FrameHeader::patch_sample_size(&mut header_bytes, 5000).is_err());
787
788        // Test invalid sample rate
789        assert!(FrameHeader::patch_sample_rate(&mut header_bytes, 192000).is_err());
790
791        // Test invalid channels
792        assert!(FrameHeader::patch_channels(&mut header_bytes, 17).is_err());
793        assert!(FrameHeader::patch_channels(&mut header_bytes, 0).is_err());
794
795        // Test invalid bits per sample
796        assert!(FrameHeader::patch_bits_per_sample(&mut header_bytes, 20).is_err());
797    }
798
799    #[test]
800    fn test_sample_size_extraction() {
801        let header = FrameHeader::new(
802            EncodingFlag::PCMSigned,
803            1024,
804            48000,
805            2,
806            24,
807            Endianness::LittleEndian,
808            None,
809            None,
810        )
811        .unwrap();
812
813        let mut buffer = Vec::new();
814        header.encode(&mut buffer).unwrap();
815
816        let extracted = FrameHeader::extract_sample_count(&buffer).unwrap();
817        assert_eq!(extracted, 1024, "Sample size extraction failed");
818
819        // Now test with a decoded header to verify consistency
820        let decoded = FrameHeader::decode(&mut &buffer[..]).unwrap();
821        assert_eq!(decoded.sample_size(), 1024, "Sample size decode failed");
822    }
823
824    //This test ensures field boundaries by setting each field to its maximum value and verifying no corruption.
825    #[test]
826    fn test_bit_layout() {
827        let header = FrameHeader::new(
828            EncodingFlag::PCMSigned, // 000
829            0xFFF,                   // 111111111111
830            48000,                   // 01
831            16,                      // 1111
832            32,                      // 10
833            Endianness::BigEndian,   // 1
834            Some(1),                 // 1
835            Some(1),                 // 1
836        )
837        .unwrap();
838
839        let mut buffer = Vec::new();
840        header.encode(&mut buffer).unwrap();
841        let decoded = FrameHeader::decode(&mut &buffer[..]).unwrap();
842
843        // Verify max values are preserved
844        assert_eq!(decoded.sample_size(), 0xFFF);
845        assert_eq!(decoded.channels(), 16);
846        assert_eq!(decoded.bits_per_sample(), 32);
847        assert_eq!(decoded.endianness(), &Endianness::BigEndian);
848        assert!(decoded.id().is_some());
849        assert!(decoded.pts().is_some());
850    }
851    #[test]
852    fn test_valid_opus_and_flac_sample_sizes_with_varied_pts_and_ids() {
853        let opus_sample_sizes = [80, 160, 240, 480, 960, 1920, 2880];
854        let flac_sample_sizes = [512, 1024, 2048];
855        let sample_rates = [16000, 44100, 48000, 96000];
856        let channels_list = [1, 2, 8, 16];
857        let bits_list = [16, 24, 32];
858        let endianness_list = [Endianness::LittleEndian, Endianness::BigEndian];
859        let pts_values = [
860            1_670_000_000_000_000,
861            1_671_000_000_000_000,
862            1_672_000_000_000_000,
863            1_673_000_000_000_000,
864            1_674_000_000_000_000,
865            1_675_000_000_000_000,
866        ];
867        let id_values = [
868            0xFFFFFFFFFFFFFFFF,
869            0x0123456789ABCDEF,
870            0xDEADBEEFDEADBEEF,
871            0,
872            1,
873            42,
874        ];
875        for &encoding in &[EncodingFlag::Opus, EncodingFlag::FLAC] {
876            let sample_sizes: &[u16] = match encoding {
877                EncodingFlag::Opus => &opus_sample_sizes,
878                EncodingFlag::FLAC => &flac_sample_sizes,
879                _ => continue,
880            };
881            for &sample_size in sample_sizes {
882                for &sample_rate in &sample_rates {
883                    for &channels in &channels_list {
884                        for &bits in &bits_list {
885                            for &endianness in &endianness_list {
886                                for &id_val in &id_values {
887                                    for &pts_val in &pts_values {
888                                        let header = FrameHeader::new(
889                                            encoding,
890                                            sample_size,
891                                            sample_rate,
892                                            channels,
893                                            bits,
894                                            endianness,
895                                            Some(id_val),
896                                            Some(pts_val),
897                                        );
898
899                                        assert!(
900                                        header.is_ok(),
901                                        "Failed to create header for encoding: {:?}, sample_size: {}, sample_rate: {}, channels: {}, bits: {}, endianness: {:?}, id: {:?}, pts: {:?}",
902                                        encoding, sample_size, sample_rate, channels, bits, endianness, id_val, pts_val
903                                    );
904
905                                        let header = header.unwrap();
906                                        let mut buffer = Vec::new();
907
908                                        assert!(
909                                        header.encode(&mut buffer).is_ok(),
910                                        "Failed to encode header for encoding: {:?}, sample_size: {}, sample_rate: {}, channels: {}, bits: {}, endianness: {:?}, id: {:?}, pts: {:?}",
911                                        encoding, sample_size, sample_rate, channels, bits, endianness, id_val, pts_val
912                                    );
913
914                                        let decoded = FrameHeader::decode(&mut &buffer[..]);
915
916                                        assert!(
917                                        decoded.is_ok(),
918                                        "Failed to decode header for encoding: {:?}, sample_size: {}, sample_rate: {}, channels: {}, bits: {}, endianness: {:?}, id: {:?}, pts: {:?}",
919                                        encoding, sample_size, sample_rate, channels, bits, endianness, id_val, pts_val
920                                    );
921
922                                        let decoded = decoded.unwrap();
923                                        assert_eq!(*decoded.encoding(), encoding);
924                                        assert_eq!(decoded.sample_size(), sample_size);
925                                        assert_eq!(decoded.sample_rate(), sample_rate);
926                                        assert_eq!(decoded.channels(), channels);
927                                        assert_eq!(decoded.bits_per_sample(), bits);
928                                        assert_eq!(*decoded.endianness(), endianness);
929                                        assert_eq!(decoded.id(), Some(id_val));
930                                        assert_eq!(decoded.pts(), Some(pts_val));
931                                    }
932                                }
933                            }
934                        }
935                    }
936                }
937            }
938        }
939    }
940
941    #[test]
942    fn test_patch_field_isolation() {
943        // Create a header with known values for all fields
944        let original = FrameHeader::new(
945            EncodingFlag::PCMSigned,
946            1024,
947            48000,
948            4,
949            24,
950            Endianness::LittleEndian,
951            Some(0xDEADBEEF),
952            Some(0xCAFEBABE),
953        )
954        .unwrap();
955
956        let mut buffer = Vec::new();
957        original.encode(&mut buffer).unwrap();
958
959        // Test sample size patching
960        let mut test_buffer = buffer.clone();
961        FrameHeader::patch_sample_size(&mut test_buffer, 2048).unwrap();
962        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
963        assert_eq!(updated.sample_size(), 2048); // Changed field
964        assert_eq!(*updated.encoding(), *original.encoding());
965        assert_eq!(updated.sample_rate(), original.sample_rate());
966        assert_eq!(updated.channels(), original.channels());
967        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
968        assert_eq!(*updated.endianness(), *original.endianness());
969        assert_eq!(updated.id(), original.id());
970        assert_eq!(updated.pts(), original.pts());
971
972        // Test encoding patching
973        let mut test_buffer = buffer.clone();
974        FrameHeader::patch_encoding(&mut test_buffer, EncodingFlag::FLAC).unwrap();
975        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
976        assert_eq!(*updated.encoding(), EncodingFlag::FLAC); // Changed field
977        assert_eq!(updated.sample_size(), original.sample_size());
978        assert_eq!(updated.sample_rate(), original.sample_rate());
979        assert_eq!(updated.channels(), original.channels());
980        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
981        assert_eq!(*updated.endianness(), *original.endianness());
982        assert_eq!(updated.id(), original.id());
983        assert_eq!(updated.pts(), original.pts());
984
985        // Test sample rate patching
986        let mut test_buffer = buffer.clone();
987        FrameHeader::patch_sample_rate(&mut test_buffer, 96000).unwrap();
988        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
989        assert_eq!(updated.sample_rate(), 96000); // Changed field
990        assert_eq!(*updated.encoding(), *original.encoding());
991        assert_eq!(updated.sample_size(), original.sample_size());
992        assert_eq!(updated.channels(), original.channels());
993        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
994        assert_eq!(*updated.endianness(), *original.endianness());
995        assert_eq!(updated.id(), original.id());
996        assert_eq!(updated.pts(), original.pts());
997
998        // Test bits per sample patching
999        let mut test_buffer = buffer.clone();
1000        FrameHeader::patch_bits_per_sample(&mut test_buffer, 32).unwrap();
1001        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
1002        assert_eq!(updated.bits_per_sample(), 32); // Changed field
1003        assert_eq!(*updated.encoding(), *original.encoding());
1004        assert_eq!(updated.sample_size(), original.sample_size());
1005        assert_eq!(updated.sample_rate(), original.sample_rate());
1006        assert_eq!(updated.channels(), original.channels());
1007        assert_eq!(*updated.endianness(), *original.endianness());
1008        assert_eq!(updated.id(), original.id());
1009        assert_eq!(updated.pts(), original.pts());
1010
1011        // Test channels patching
1012        let mut test_buffer = buffer.clone();
1013        FrameHeader::patch_channels(&mut test_buffer, 8).unwrap();
1014        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
1015        assert_eq!(updated.channels(), 8); // Changed field
1016        assert_eq!(*updated.encoding(), *original.encoding());
1017        assert_eq!(updated.sample_size(), original.sample_size());
1018        assert_eq!(updated.sample_rate(), original.sample_rate());
1019        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
1020        assert_eq!(*updated.endianness(), *original.endianness());
1021        assert_eq!(updated.id(), original.id());
1022        assert_eq!(updated.pts(), original.pts());
1023
1024        // Test ID patching
1025        let mut test_buffer = buffer.clone();
1026        FrameHeader::patch_id(&mut test_buffer, Some(0xFEEDFACE)).unwrap();
1027        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
1028        assert_eq!(updated.id(), Some(0xFEEDFACE)); // Changed field
1029        assert_eq!(*updated.encoding(), *original.encoding());
1030        assert_eq!(updated.sample_size(), original.sample_size());
1031        assert_eq!(updated.sample_rate(), original.sample_rate());
1032        assert_eq!(updated.channels(), original.channels());
1033        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
1034        assert_eq!(*updated.endianness(), *original.endianness());
1035        assert_eq!(updated.pts(), original.pts());
1036
1037        // Test PTS patching
1038        let mut test_buffer = buffer.clone();
1039        FrameHeader::patch_pts(&mut test_buffer, Some(0xF00DFACE)).unwrap();
1040        let updated = FrameHeader::decode(&mut &test_buffer[..]).unwrap();
1041        assert_eq!(updated.pts(), Some(0xF00DFACE)); // Changed field
1042        assert_eq!(*updated.encoding(), *original.encoding());
1043        assert_eq!(updated.sample_size(), original.sample_size());
1044        assert_eq!(updated.sample_rate(), original.sample_rate());
1045        assert_eq!(updated.channels(), original.channels());
1046        assert_eq!(updated.bits_per_sample(), original.bits_per_sample());
1047        assert_eq!(*updated.endianness(), *original.endianness());
1048        assert_eq!(updated.id(), original.id());
1049    }
1050
1051    #[test]
1052    fn test_magic_word_off_by_one() {
1053        // Create a valid header as base
1054        let valid_header = FrameHeader::new(
1055            EncodingFlag::PCMSigned,
1056            1024,
1057            48000,
1058            2,
1059            24,
1060            Endianness::LittleEndian,
1061            None,
1062            None,
1063        )
1064        .unwrap();
1065        let mut valid_buffer = Vec::new();
1066        valid_header.encode(&mut valid_buffer).unwrap();
1067
1068        // Test magic word off by one higher
1069        let mut buffer = valid_buffer.clone();
1070        let mut header = u32::from_be_bytes(buffer[..4].try_into().unwrap());
1071        header = (header & !FrameHeader::MAGIC_MASK)
1072            | ((FrameHeader::MAGIC_WORD + 1) << FrameHeader::MAGIC_SHIFT);
1073        buffer[..4].copy_from_slice(&header.to_be_bytes());
1074        assert!(
1075            FrameHeader::decode(&mut &buffer[..]).is_err(),
1076            "Failed to detect magic word too high"
1077        );
1078        assert!(
1079            !FrameHeader::validate_header(&buffer).unwrap(),
1080            "Validation passed with magic word too high"
1081        );
1082
1083        // Test magic word off by one lower
1084        let mut buffer = valid_buffer.clone();
1085        let mut header = u32::from_be_bytes(buffer[..4].try_into().unwrap());
1086        header = (header & !FrameHeader::MAGIC_MASK)
1087            | ((FrameHeader::MAGIC_WORD - 1) << FrameHeader::MAGIC_SHIFT);
1088        buffer[..4].copy_from_slice(&header.to_be_bytes());
1089        assert!(
1090            FrameHeader::decode(&mut &buffer[..]).is_err(),
1091            "Failed to detect magic word too low"
1092        );
1093        assert!(
1094            !FrameHeader::validate_header(&buffer).unwrap(),
1095            "Validation passed with magic word too low"
1096        );
1097
1098        // Test magic word shifted right by one bit
1099        let mut buffer = valid_buffer.clone();
1100        let mut header = u32::from_be_bytes(buffer[..4].try_into().unwrap());
1101        header = (header & !FrameHeader::MAGIC_MASK)
1102            | ((FrameHeader::MAGIC_WORD >> 1) << FrameHeader::MAGIC_SHIFT);
1103        buffer[..4].copy_from_slice(&header.to_be_bytes());
1104        assert!(
1105            FrameHeader::decode(&mut &buffer[..]).is_err(),
1106            "Failed to detect magic word bit-shifted right"
1107        );
1108        assert!(
1109            !FrameHeader::validate_header(&buffer).unwrap(),
1110            "Validation passed with magic word bit-shifted right"
1111        );
1112
1113        // Test magic word shifted left by one bit
1114        let mut buffer = valid_buffer.clone();
1115        let mut header = u32::from_be_bytes(buffer[..4].try_into().unwrap());
1116        header = (header & !FrameHeader::MAGIC_MASK)
1117            | ((FrameHeader::MAGIC_WORD << 1) << FrameHeader::MAGIC_SHIFT);
1118        buffer[..4].copy_from_slice(&header.to_be_bytes());
1119        assert!(
1120            FrameHeader::decode(&mut &buffer[..]).is_err(),
1121            "Failed to detect magic word bit-shifted left"
1122        );
1123        assert!(
1124            !FrameHeader::validate_header(&buffer).unwrap(),
1125            "Validation passed with magic word bit-shifted left"
1126        );
1127
1128        // Test magic word at wrong bit position (shifted by 1 in final header)
1129        let mut buffer = valid_buffer;
1130        let mut header = u32::from_be_bytes(buffer[..4].try_into().unwrap());
1131        header = (header & !FrameHeader::MAGIC_MASK)
1132            | (FrameHeader::MAGIC_WORD << (FrameHeader::MAGIC_SHIFT + 1));
1133        buffer[..4].copy_from_slice(&header.to_be_bytes());
1134        assert!(
1135            FrameHeader::decode(&mut &buffer[..]).is_err(),
1136            "Failed to detect magic word at wrong position"
1137        );
1138        assert!(
1139            !FrameHeader::validate_header(&buffer).unwrap(),
1140            "Validation passed with magic word at wrong position"
1141        );
1142    }
1143}