oximedia_codec/jpegxl/
bitreader.rs1use crate::error::{CodecError, CodecResult};
8
9pub struct BitReader<'a> {
14 data: &'a [u8],
15 byte_pos: usize,
17 bit_pos: u8,
19}
20
21impl<'a> BitReader<'a> {
22 pub fn new(data: &'a [u8]) -> Self {
24 Self {
25 data,
26 byte_pos: 0,
27 bit_pos: 0,
28 }
29 }
30
31 pub fn read_bits(&mut self, n: u8) -> CodecResult<u32> {
37 if n == 0 {
38 return Ok(0);
39 }
40 if n > 32 {
41 return Err(CodecError::InvalidBitstream(
42 "Cannot read more than 32 bits at once".into(),
43 ));
44 }
45 if self.remaining_bits() < n as usize {
46 return Err(CodecError::InvalidBitstream(
47 "Not enough bits remaining in stream".into(),
48 ));
49 }
50
51 let mut result: u32 = 0;
52 let mut bits_read: u8 = 0;
53
54 while bits_read < n {
55 let bits_available_in_byte = 8 - self.bit_pos;
56 let bits_needed = n - bits_read;
57 let bits_to_read = bits_available_in_byte.min(bits_needed);
58
59 let byte_val = self.data[self.byte_pos] as u32;
60 let mask = (1u32 << bits_to_read) - 1;
61 let extracted = (byte_val >> self.bit_pos) & mask;
62
63 result |= extracted << bits_read;
64 bits_read += bits_to_read;
65
66 self.bit_pos += bits_to_read;
67 if self.bit_pos >= 8 {
68 self.bit_pos = 0;
69 self.byte_pos += 1;
70 }
71 }
72
73 Ok(result)
74 }
75
76 pub fn read_bool(&mut self) -> CodecResult<bool> {
78 Ok(self.read_bits(1)? != 0)
79 }
80
81 pub fn read_u8(&mut self, n: u8) -> CodecResult<u8> {
83 if n > 8 {
84 return Err(CodecError::InvalidBitstream(
85 "Cannot read more than 8 bits into u8".into(),
86 ));
87 }
88 self.read_bits(n).map(|v| v as u8)
89 }
90
91 pub fn read_u16(&mut self, n: u8) -> CodecResult<u16> {
93 if n > 16 {
94 return Err(CodecError::InvalidBitstream(
95 "Cannot read more than 16 bits into u16".into(),
96 ));
97 }
98 self.read_bits(n).map(|v| v as u16)
99 }
100
101 pub fn read_u32(&mut self, n: u8) -> CodecResult<u32> {
103 self.read_bits(n)
104 }
105
106 pub fn read_u64(&mut self) -> CodecResult<u64> {
114 let selector = self.read_bits(2)?;
115 match selector {
116 0 => Ok(0),
117 1 => {
118 let extra = self.read_bits(4)? as u64;
119 Ok(1 + extra)
120 }
121 2 => {
122 let extra = self.read_bits(8)? as u64;
123 Ok(17 + extra)
124 }
125 3 => {
126 let mut value = self.read_bits(12)? as u64;
128 let mut shift = 12u32;
129 while shift < 60 {
130 let more = self.read_bool()?;
131 if more {
132 let chunk = self.read_bits(8)? as u64;
133 value |= chunk << shift;
134 shift += 8;
135 } else {
136 break;
137 }
138 }
139 if shift >= 60 {
141 let chunk = self.read_bits(4)? as u64;
142 value |= chunk << shift;
143 }
144 Ok(273 + value)
145 }
146 _ => Err(CodecError::InvalidBitstream("Invalid U64 selector".into())),
147 }
148 }
149
150 pub fn remaining_bits(&self) -> usize {
152 if self.byte_pos >= self.data.len() {
153 return 0;
154 }
155 (self.data.len() - self.byte_pos) * 8 - self.bit_pos as usize
156 }
157
158 pub fn align_to_byte(&mut self) {
162 if self.bit_pos != 0 {
163 self.bit_pos = 0;
164 self.byte_pos += 1;
165 }
166 }
167
168 pub fn byte_position(&self) -> usize {
170 self.byte_pos
171 }
172
173 pub fn is_empty(&self) -> bool {
175 self.remaining_bits() == 0
176 }
177}
178
/// Accumulates bits into a growable byte buffer, least-significant bit first.
///
/// Bits fill each byte starting from bit 0. A byte is appended to the buffer
/// as soon as all eight of its bits have been written.
#[derive(Debug, Clone, Default)]
pub struct BitWriter {
    /// Completed bytes.
    data: Vec<u8>,
    /// Byte currently being filled; only its low `bit_pos` bits are meaningful.
    current_byte: u8,
    /// Number of bits (0..=7) already stored in `current_byte`.
    bit_pos: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty writer whose buffer has room for `bytes` bytes.
    pub fn with_capacity(bytes: usize) -> Self {
        Self {
            data: Vec::with_capacity(bytes),
            current_byte: 0,
            bit_pos: 0,
        }
    }

    /// Appends the low `n` bits of `value`, least-significant bit first.
    ///
    /// `n == 0` writes nothing. Bits of `value` above position `n` are
    /// ignored. For `n > 32` the bits beyond the 32 that `value` holds are
    /// written as zeros.
    pub fn write_bits(&mut self, value: u32, n: u8) {
        let mut val = value;
        let mut remaining = n;

        while remaining > 0 {
            // Fill the free part of the current byte, capped at what is left.
            let bits_to_write = (8 - self.bit_pos).min(remaining);
            let mask = (1u32 << bits_to_write) - 1;
            // `mask` is at most 0xFF, so narrowing to `u8` cannot lose data.
            let bits = (val & mask) as u8;

            self.current_byte |= bits << self.bit_pos;
            self.bit_pos += bits_to_write;
            val >>= bits_to_write;
            remaining -= bits_to_write;

            if self.bit_pos >= 8 {
                self.data.push(self.current_byte);
                self.current_byte = 0;
                self.bit_pos = 0;
            }
        }
    }

    /// Appends a single bit: `1` for `true`, `0` for `false`.
    pub fn write_bool(&mut self, v: bool) {
        self.write_bits(u32::from(v), 1);
    }

    /// Appends `value` using a variable-length encoding.
    ///
    /// A 2-bit selector chooses the layout:
    ///
    /// * `0` — `value == 0`, no payload;
    /// * `1` — `value` in `1..=16`, stored as `value - 1` in 4 bits;
    /// * `2` — `value` in `17..=272`, stored as `value - 17` in 8 bits;
    /// * `3` — anything larger, stored as `value - 273`: a 12-bit group, then
    ///   8-bit groups each preceded by a set "more" flag while non-zero bits
    ///   remain (at most six). A cleared flag terminates the sequence early;
    ///   if all six groups were written, the top 4 bits follow with no flag.
    ///
    /// NOTE(review): the selector-3 layout matches `BitReader::read_u64` in
    /// this file but appears to differ from the `U64` field of the JPEG XL
    /// specification — confirm before emitting streams for other decoders.
    pub fn write_u64(&mut self, value: u64) {
        match value {
            0 => self.write_bits(0, 2),
            1..=16 => {
                self.write_bits(1, 2);
                self.write_bits((value - 1) as u32, 4);
            }
            17..=272 => {
                self.write_bits(2, 2);
                self.write_bits((value - 17) as u32, 8);
            }
            _ => {
                self.write_bits(3, 2);
                let mut remaining = value - 273;
                self.write_bits((remaining & 0xFFF) as u32, 12);
                remaining >>= 12;

                let mut shift = 12u32;
                while shift < 60 && remaining > 0 {
                    self.write_bool(true);
                    self.write_bits((remaining & 0xFF) as u32, 8);
                    remaining >>= 8;
                    shift += 8;
                }

                if shift < 60 {
                    // Ran out of payload early: emit the terminating flag.
                    self.write_bool(false);
                } else {
                    // All six 8-bit groups written: only the top 4 bits are left.
                    self.write_bits((remaining & 0xF) as u32, 4);
                }
            }
        }
    }

    /// Pads the current byte with zero bits and appends it to the buffer; a
    /// no-op when already on a byte boundary.
    pub fn align_to_byte(&mut self) {
        if self.bit_pos != 0 {
            self.data.push(self.current_byte);
            self.current_byte = 0;
            self.bit_pos = 0;
        }
    }

    /// Flushes any partially filled byte (zero-padded) and returns the buffer.
    pub fn finish(mut self) -> Vec<u8> {
        self.align_to_byte();
        self.data
    }

    /// Returns the number of completed bytes; a partially filled byte is not counted.
    pub fn bytes_written(&self) -> usize {
        self.data.len()
    }

    /// Returns the total number of bits written, including those in the
    /// partially filled byte.
    pub fn bits_written(&self) -> usize {
        self.data.len() * 8 + usize::from(self.bit_pos)
    }
}
306
// Unit tests for the bit reader and writer. They are pure in-memory checks,
// so they run as part of the normal test suite rather than being ignored.
#[cfg(test)]
mod tests {
    use super::*;

    /// Bits come out of each byte starting from the least-significant bit.
    #[test]
    fn test_bitreader_basic() {
        let data = [0b1010_0110u8, 0b1100_0011];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.read_bits(4).expect("ok"), 0b0110);
        assert_eq!(reader.read_bits(4).expect("ok"), 0b1010);
        assert_eq!(reader.read_bits(8).expect("ok"), 0b1100_0011);
    }

    /// A read that spans two bytes places the earlier byte's bits lowest.
    #[test]
    fn test_bitreader_cross_byte() {
        let data = [0xFF, 0x00];
        let mut reader = BitReader::new(&data);

        assert_eq!(reader.read_bits(4).expect("ok"), 0xF);
        assert_eq!(reader.read_bits(8).expect("ok"), 0x0F);
    }

    #[test]
    fn test_bitreader_bool() {
        let data = [0b0000_0101];
        let mut reader = BitReader::new(&data);

        assert!(reader.read_bool().expect("ok"));
        assert!(!reader.read_bool().expect("ok"));
        assert!(reader.read_bool().expect("ok"));
    }

    /// Reading past the end is reported as an error.
    #[test]
    fn test_bitreader_eof() {
        let data = [0xFF];
        let mut reader = BitReader::new(&data);
        let _ = reader.read_bits(8).expect("ok");
        assert!(reader.read_bits(1).is_err());
    }

    #[test]
    fn test_bitreader_remaining() {
        let data = [0xFF, 0xFF];
        let mut reader = BitReader::new(&data);
        assert_eq!(reader.remaining_bits(), 16);
        let _ = reader.read_bits(3).expect("ok");
        assert_eq!(reader.remaining_bits(), 13);
    }

    /// Aligning discards the rest of a partially read byte.
    #[test]
    fn test_bitreader_align() {
        let data = [0xFF, 0xAA];
        let mut reader = BitReader::new(&data);
        let _ = reader.read_bits(3).expect("ok");
        reader.align_to_byte();
        assert_eq!(reader.read_bits(8).expect("ok"), 0xAA);
    }

    #[test]
    fn test_bitwriter_basic() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b0110, 4);
        writer.write_bits(0b1010, 4);
        let data = writer.finish();
        assert_eq!(data, vec![0b1010_0110]);
    }

    /// The trailing partial byte is zero-padded by `finish`.
    #[test]
    fn test_bitwriter_cross_byte() {
        let mut writer = BitWriter::new();
        writer.write_bits(0xF, 4);
        writer.write_bits(0x0F, 8);
        let data = writer.finish();
        assert_eq!(data, vec![0xFF, 0x00]);
    }

    #[test]
    fn test_bitwriter_bool() {
        let mut writer = BitWriter::new();
        for bit in [true, false, true, false, false, false, false, false] {
            writer.write_bool(bit);
        }
        let data = writer.finish();
        assert_eq!(data, vec![0b0000_0101]);
    }

    #[test]
    fn test_roundtrip_bits() {
        let mut writer = BitWriter::new();
        writer.write_bits(42, 7);
        writer.write_bits(1023, 10);
        writer.write_bits(0, 3);
        writer.write_bits(255, 8);
        let data = writer.finish();

        let mut reader = BitReader::new(&data);
        assert_eq!(reader.read_bits(7).expect("ok"), 42);
        assert_eq!(reader.read_bits(10).expect("ok"), 1023);
        assert_eq!(reader.read_bits(3).expect("ok"), 0);
        assert_eq!(reader.read_bits(8).expect("ok"), 255);
    }

    /// Covers every selector, including the multi-chunk and final 4-bit
    /// paths of selector 3 (`1 << 40` and `u64::MAX`).
    #[test]
    fn test_roundtrip_u64() {
        for value in [
            0u64,
            1,
            5,
            16,
            17,
            100,
            272,
            273,
            1000,
            65535,
            1_000_000,
            1 << 40,
            u64::MAX,
        ] {
            let mut writer = BitWriter::new();
            writer.write_u64(value);
            let data = writer.finish();

            let mut reader = BitReader::new(&data);
            let decoded = reader.read_u64().expect("ok");
            assert_eq!(decoded, value, "U64 roundtrip failed for {value}");
        }
    }

    /// Aligning flushes the partial byte so the next write starts a new one.
    #[test]
    fn test_bitwriter_align() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b101, 3);
        writer.align_to_byte();
        writer.write_bits(0xAA, 8);
        let data = writer.finish();
        assert_eq!(data.len(), 2);
        assert_eq!(data[0], 0b0000_0101);
        assert_eq!(data[1], 0xAA);
    }
}