h264_parser/
bitreader.rs

1use crate::{Error, Result};
2
3pub struct BitReader<'a> {
4    data: &'a [u8],
5    byte_pos: usize,
6    bit_pos: u8,
7}
8
9impl<'a> BitReader<'a> {
10    pub fn new(data: &'a [u8]) -> Self {
11        Self {
12            data,
13            byte_pos: 0,
14            bit_pos: 0,
15        }
16    }
17
18    pub fn position(&self) -> (usize, u8) {
19        (self.byte_pos, self.bit_pos)
20    }
21
22    pub fn seek(&mut self, byte_pos: usize, bit_pos: u8) -> Result<()> {
23        if byte_pos >= self.data.len() || (byte_pos == self.data.len() - 1 && bit_pos > 7) {
24            return Err(Error::BitstreamError("Seek position out of bounds".into()));
25        }
26        self.byte_pos = byte_pos;
27        self.bit_pos = bit_pos;
28        Ok(())
29    }
30
31    pub fn available_bits(&self) -> usize {
32        if self.byte_pos >= self.data.len() {
33            return 0;
34        }
35        (self.data.len() - self.byte_pos - 1) * 8 + (8 - self.bit_pos as usize)
36    }
37
38    pub fn read_bit(&mut self) -> Result<bool> {
39        if self.byte_pos >= self.data.len() {
40            return Err(Error::UnexpectedEof);
41        }
42
43        let bit = (self.data[self.byte_pos] >> (7 - self.bit_pos)) & 1;
44
45        self.bit_pos += 1;
46        if self.bit_pos == 8 {
47            self.bit_pos = 0;
48            self.byte_pos += 1;
49        }
50
51        Ok(bit != 0)
52    }
53
54    pub fn read_bits(&mut self, n: u32) -> Result<u32> {
55        if n > 32 {
56            return Err(Error::BitstreamError(
57                "Cannot read more than 32 bits".into(),
58            ));
59        }
60
61        let mut value = 0u32;
62        for _ in 0..n {
63            value = (value << 1) | (self.read_bit()? as u32);
64        }
65        Ok(value)
66    }
67
68    pub fn read_flag(&mut self) -> Result<bool> {
69        self.read_bit()
70    }
71
72    pub fn read_u8(&mut self) -> Result<u8> {
73        self.read_bits(8).map(|v| v as u8)
74    }
75
76    pub fn read_u16(&mut self) -> Result<u16> {
77        self.read_bits(16).map(|v| v as u16)
78    }
79
80    pub fn peek_bits(&mut self, n: u32) -> Result<u32> {
81        let saved_byte = self.byte_pos;
82        let saved_bit = self.bit_pos;
83
84        let value = self.read_bits(n)?;
85
86        self.byte_pos = saved_byte;
87        self.bit_pos = saved_bit;
88
89        Ok(value)
90    }
91
92    pub fn skip_bits(&mut self, n: u32) -> Result<()> {
93        for _ in 0..n {
94            self.read_bit()?;
95        }
96        Ok(())
97    }
98
99    pub fn byte_aligned(&self) -> bool {
100        self.bit_pos == 0
101    }
102
103    pub fn align_to_byte(&mut self) {
104        if self.bit_pos != 0 {
105            self.bit_pos = 0;
106            self.byte_pos += 1;
107        }
108    }
109
110    pub fn more_rbsp_data(&self) -> bool {
111        if self.byte_pos >= self.data.len() {
112            return false;
113        }
114
115        if self.byte_pos == self.data.len() - 1 {
116            let remaining_byte = self.data[self.byte_pos];
117            if self.bit_pos >= 8 {
118                return false;
119            }
120            let bits_left = 8 - self.bit_pos;
121            if bits_left == 0 || bits_left > 8 {
122                return false;
123            }
124
125            // Get the remaining bits from current position
126            let shift_amount = self.bit_pos;
127            let remaining_bits = remaining_byte << shift_amount;
128
129            // Check if remaining bits match the RBSP stop bit pattern
130            // The stop bit pattern is a single 1 followed by zeros
131            // In the most significant position after shifting
132            let stop_pattern = 0x80; // 10000000
133
134            return remaining_bits != stop_pattern;
135        }
136
137        true
138    }
139
140    pub fn rbsp_trailing_bits(&mut self) -> Result<()> {
141        if !self.read_flag()? {
142            return Err(Error::BitstreamError("Expected rbsp_stop_one_bit".into()));
143        }
144
145        while !self.byte_aligned() {
146            if self.read_flag()? {
147                return Err(Error::BitstreamError(
148                    "Expected rbsp_alignment_zero_bit".into(),
149                ));
150            }
151        }
152
153        Ok(())
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_bits() {
        let data: [u8; 2] = [0b1011_0011, 0b0101_0101];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.read_bits(4).unwrap(), 0b1011);
        assert_eq!(reader.read_bits(4).unwrap(), 0b0011);
        assert_eq!(reader.read_bits(8).unwrap(), 0b0101_0101);
    }

    #[test]
    fn test_read_bits_limits() {
        let data: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut reader = BitReader::new(&data);

        // At most 32 bits may be requested in a single call.
        assert!(reader.read_bits(33).is_err());
        assert_eq!(reader.read_bits(32).unwrap(), 0xDEAD_BEEF);

        // The buffer is exhausted, so any further read fails.
        assert!(reader.read_bit().is_err());
        assert!(reader.read_bits(1).is_err());
    }

    #[test]
    fn test_read_flag() {
        let data: [u8; 2] = [0b1000_0000, 0b0100_0000];
        let mut reader = BitReader::new(&data);

        assert!(reader.read_flag().unwrap());
        assert!(!reader.read_flag().unwrap());
    }

    #[test]
    fn test_read_u8_u16() {
        let data: [u8; 3] = [0xAB, 0x12, 0x34];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.read_u8().unwrap(), 0xAB);
        assert_eq!(reader.read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn test_peek_bits() {
        let data: [u8; 1] = [0b1111_0000];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.peek_bits(4).unwrap(), 0b1111);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1111);
        assert_eq!(reader.read_bits(4).unwrap(), 0b0000);
    }

    #[test]
    fn test_skip_and_available_bits() {
        let data: [u8; 2] = [0xFF, 0x00];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.available_bits(), 16);
        reader.skip_bits(3).unwrap();
        assert_eq!(reader.available_bits(), 13);
        assert_eq!(reader.read_bits(5).unwrap(), 0b1_1111);
        assert_eq!(reader.position(), (1, 0));
    }

    #[test]
    fn test_seek() {
        let data: [u8; 2] = [0b0000_0001, 0b1000_0000];
        let mut reader = BitReader::new(&data);

        // A read may straddle the byte boundary after seeking mid-byte.
        reader.seek(0, 7).unwrap();
        assert_eq!(reader.read_bits(2).unwrap(), 0b11);
        assert_eq!(reader.position(), (1, 1));

        // Positions beyond the data are rejected.
        assert!(reader.seek(2, 1).is_err());
        assert!(reader.seek(3, 0).is_err());
    }

    #[test]
    fn test_byte_alignment() {
        let data: [u8; 2] = [0xff, 0x00];
        let mut reader = BitReader::new(&data);

        assert!(reader.byte_aligned());
        reader.read_bits(3).unwrap();
        assert!(!reader.byte_aligned());
        reader.align_to_byte();
        assert!(reader.byte_aligned());
        assert_eq!(reader.position(), (1, 0));
    }

    #[test]
    fn test_more_rbsp_data() {
        // 0x80 = 1000_0000: only the rbsp_stop_one_bit followed by alignment zeros,
        // so there is no payload data left.
        let data: [u8; 1] = [0x80];
        let reader = BitReader::new(&data);
        assert!(!reader.more_rbsp_data());

        // 0xC0 = 1100_0000: one payload bit precedes the stop bit.
        let data: [u8; 1] = [0xC0];
        let mut reader = BitReader::new(&data);
        assert!(reader.more_rbsp_data());

        // After consuming that payload bit only the trailing bits remain.
        reader.read_flag().unwrap();
        assert!(!reader.more_rbsp_data());
    }

    #[test]
    fn test_rbsp_trailing_bits() {
        // Payload bit `1`, then the stop bit and six alignment zeros.
        let data: [u8; 1] = [0b1100_0000];
        let mut reader = BitReader::new(&data);
        assert!(reader.read_flag().unwrap());
        reader.rbsp_trailing_bits().unwrap();
        assert!(reader.byte_aligned());
        assert_eq!(reader.available_bits(), 0);

        // Missing stop bit.
        let data: [u8; 1] = [0x00];
        assert!(BitReader::new(&data).rbsp_trailing_bits().is_err());

        // Non-zero alignment bit.
        let data: [u8; 1] = [0x81];
        assert!(BitReader::new(&data).rbsp_trailing_bits().is_err());
    }
}