Skip to main content

dicom_toolkit_codec/jpeg_ls/
bitstream.rs

//! Bitstream reader and writer with JPEG-LS FF-bitstuffing.
//!
//! Per ISO/IEC 14495-1 §A.1: after a 0xFF byte is written or read, the high bit
//! of the next byte is reserved and always 0 (a "stuffed" bit), so that byte
//! carries only 7 payload bits instead of 8. An 0xFF followed by a byte with the
//! high bit set is a marker, which the reader never consumes as data.
6
7use dicom_toolkit_core::error::{DcmError, DcmResult};
8
// ── BitReader ─────────────────────────────────────────────────────────────────
10
11/// Reads bits from a byte slice, handling JPEG-LS FF-bitstuffing.
12pub struct BitReader<'a> {
13    data: &'a [u8],
14    byte_pos: usize,
15    /// Cached bits (MSB-aligned in a u64).
16    cache: u64,
17    /// Number of valid bits in `cache`.
18    valid_bits: i32,
19    /// Position of the next 0xFF byte (for fast-path skipping).
20    next_ff: usize,
21}
22
23impl<'a> BitReader<'a> {
24    pub fn new(data: &'a [u8]) -> Self {
25        let next_ff = find_next_ff(data, 0);
26        let mut reader = Self {
27            data,
28            byte_pos: 0,
29            cache: 0,
30            valid_bits: 0,
31            next_ff,
32        };
33        reader.fill();
34        reader
35    }
36
37    /// Current byte position in the source data.
38    pub fn byte_position(&self) -> usize {
39        // Back-track: the cache may hold bytes that haven't been "consumed".
40        let mut pos = self.byte_pos;
41        let mut bits = self.valid_bits;
42        while bits > 0 && pos > 0 {
43            pos -= 1;
44            let bits_from_byte = if pos > 0 && self.data[pos - 1] == 0xFF {
45                7
46            } else {
47                8
48            };
49            bits -= bits_from_byte;
50        }
51        pos
52    }
53
54    /// Read `length` bits (up to 25) and return them right-aligned.
55    #[inline]
56    pub fn read_value(&mut self, length: i32) -> DcmResult<i32> {
57        debug_assert!(length > 0 && length <= 25);
58        if self.valid_bits < length {
59            self.fill();
60            if self.valid_bits < length {
61                return Err(DcmError::DecompressionError {
62                    reason: "JPEG-LS: unexpected end of bitstream".into(),
63                });
64            }
65        }
66        let result = (self.cache >> (64 - length)) as i32;
67        self.skip(length);
68        Ok(result)
69    }
70
71    /// Read a single bit.
72    #[inline]
73    pub fn read_bit(&mut self) -> DcmResult<bool> {
74        if self.valid_bits <= 0 {
75            self.fill();
76            if self.valid_bits <= 0 {
77                return Err(DcmError::DecompressionError {
78                    reason: "JPEG-LS: unexpected end of bitstream".into(),
79                });
80            }
81        }
82        let set = (self.cache & (1u64 << 63)) != 0;
83        self.skip(1);
84        Ok(set)
85    }
86
87    /// Peek at the top 8 bits of the cache (for lookup-table decoding).
88    #[inline]
89    pub fn peek_byte(&mut self) -> i32 {
90        if self.valid_bits < 8 {
91            self.fill();
92        }
93        (self.cache >> 56) as i32
94    }
95
96    /// Count leading zero bits (up to 16).
97    #[inline]
98    pub fn read_highbits(&mut self) -> DcmResult<i32> {
99        if self.valid_bits < 16 {
100            self.fill();
101        }
102        let mut count = 0i32;
103        let mut val = self.cache;
104        while count < 16 {
105            if (val & (1u64 << 63)) != 0 {
106                self.skip(count + 1);
107                return Ok(count);
108            }
109            val <<= 1;
110            count += 1;
111        }
112        // More than 16 leading zeros.
113        self.skip(15);
114        let mut highbits = 15i32;
115        loop {
116            if self.read_bit()? {
117                return Ok(highbits);
118            }
119            highbits += 1;
120        }
121    }
122
123    /// Skip `n` bits in the cache.
124    #[inline]
125    pub fn skip(&mut self, n: i32) {
126        self.valid_bits -= n;
127        self.cache <<= n as u32;
128    }
129
130    /// Fill the cache from the byte stream (with FF-bitstuffing handling).
131    fn fill(&mut self) {
132        // Fast path: no 0xFF nearby.
133        if self.byte_pos + 8 <= self.next_ff {
134            let bytes_to_read = ((64 - self.valid_bits) >> 3) as usize;
135            let bytes_to_read = bytes_to_read.min(self.data.len() - self.byte_pos);
136            for _ in 0..bytes_to_read {
137                self.cache |= (self.data[self.byte_pos] as u64) << (56 - self.valid_bits as u32);
138                self.byte_pos += 1;
139                self.valid_bits += 8;
140            }
141            return;
142        }
143
144        // Slow path: handle FF-bitstuffing.
145        while self.valid_bits < 56 {
146            if self.byte_pos >= self.data.len() {
147                return;
148            }
149
150            let val = self.data[self.byte_pos] as u64;
151
152            // Check if this is a marker (FF followed by >= 0x80).
153            if val == 0xFF
154                && (self.byte_pos + 1 >= self.data.len()
155                    || (self.data[self.byte_pos + 1] & 0x80) != 0)
156            {
157                return; // Don't read into markers.
158            }
159
160            self.cache |= val << (56 - self.valid_bits as u32);
161            self.byte_pos += 1;
162            self.valid_bits += 8;
163
164            // After reading 0xFF, the next byte has only 7 payload bits.
165            if val == 0xFF {
166                self.valid_bits -= 1;
167            }
168        }
169
170        self.next_ff = find_next_ff(self.data, self.byte_pos);
171    }
172}
173
/// Returns the index of the first `0xFF` byte in `data` at or after `start`,
/// or `data.len()` if there is none.
///
/// # Panics
/// Panics if `start > data.len()`.
fn find_next_ff(data: &[u8], start: usize) -> usize {
    let tail = &data[start..];
    match tail.iter().position(|&byte| byte == 0xFF) {
        Some(offset) => start + offset,
        None => data.len(),
    }
}
180
// ── BitWriter ─────────────────────────────────────────────────────────────────
182
/// Writes bits to a `Vec<u8>`, handling JPEG-LS FF-bitstuffing.
///
/// Bits are packed MSB-first into a 32-bit accumulator and moved to `output`
/// a byte at a time. Whenever an `0xFF` byte is emitted, the following byte
/// carries only 7 payload bits and its high bit is written as 0.
pub struct BitWriter {
    /// Bytes emitted so far.
    output: Vec<u8>,
    /// Accumulator (MSB-aligned in a u32).
    val_current: u32,
    /// Number of free bits remaining in `val_current` (starts at 32). It is
    /// negative only transiently inside `append_short`. It can exceed 32 after
    /// `flush` emits a partially filled final byte during `end_scan`.
    bit_pos: i32,
    /// Whether the last written byte was 0xFF.
    is_ff_written: bool,
}

impl Default for BitWriter {
    fn default() -> Self {
        Self::new()
    }
}

impl BitWriter {
    /// Creates an empty writer with 4096 bytes of preallocated output capacity.
    pub fn new() -> Self {
        Self {
            output: Vec::with_capacity(4096),
            val_current: 0,
            bit_pos: 32,
            is_ff_written: false,
        }
    }

    /// Append `length` bits from `value` (right-aligned) to the bitstream.
    ///
    /// Supports any non-negative length. For lengths >= 32 the value must be 0
    /// (only zero-padding uses lengths that large). Bits of `value` above
    /// `length` are expected to be zero; they are not masked off.
    #[inline]
    pub fn append(&mut self, value: i32, length: i32) {
        debug_assert!(length >= 0);
        debug_assert!(length < 32 || value == 0, "only 0-bits for length >= 32");

        // Handle large zero-padding (e.g. unary prefix in Golomb escape codes).
        if length >= 32 {
            let mut remaining = length;
            while remaining >= 31 {
                self.append_short(0, 31);
                remaining -= 31;
            }
            if remaining > 0 {
                self.append_short(0, remaining);
            }
            return;
        }

        self.append_short(value, length);
    }

    /// Inner append for lengths 0..31.
    ///
    /// When the new bits do not fit in the accumulator, the high-order part of
    /// `value` fills the remaining low bits, the accumulator is flushed, and the
    /// low-order part is OR-ed in at the new position.
    #[inline]
    fn append_short(&mut self, value: i32, length: i32) {
        debug_assert!((0..32).contains(&length));
        if length == 0 {
            return;
        }
        self.bit_pos -= length;
        if self.bit_pos >= 0 {
            // `bit_pos` can still be >= 32 after `end_scan`'s partial-byte flush;
            // a shift by >= 32 would panic, and there is nothing to store then.
            if self.bit_pos < 32 {
                self.val_current |= (value as u32) << (self.bit_pos as u32);
            }
            return;
        }
        // Overflow: store the high-order bits that still fit, then flush.
        self.val_current |= (value as u32).wrapping_shr((-self.bit_pos) as u32);
        self.flush();
        // Each 0xFF written during the flush is followed by a 7-bit byte, so one
        // flush may not free enough room; flush once more if needed.
        if self.bit_pos < 0 {
            self.val_current |= (value as u32).wrapping_shr((-self.bit_pos) as u32);
            self.flush();
        }
        debug_assert!(self.bit_pos >= 0);
        // Bits of `value` already emitted are shifted out past bit 31 here.
        if self.bit_pos < 32 {
            self.val_current |= (value as u32) << (self.bit_pos as u32);
        }
    }

    /// Append `length` 1-bits.
    ///
    /// Any non-negative length is supported. The run is emitted in chunks of at
    /// most 31 bits.
    #[inline]
    pub fn append_ones(&mut self, length: i32) {
        debug_assert!(length >= 0);
        let mut remaining = length;
        while remaining > 0 {
            let chunk = remaining.min(31);
            // Build the mask in u32. `(1i32 << 31) - 1` overflows i32 and panics in
            // debug builds, while `(1u32 << 31) - 1` is exactly `i32::MAX`.
            let ones = ((1u32 << chunk) - 1) as i32;
            self.append_short(ones, chunk);
            remaining -= chunk;
        }
    }

    /// Finalize the bitstream: flush all pending bits, zero-padding the last byte.
    ///
    /// If the last byte written is 0xFF, the padding is computed for the 7-bit
    /// byte that must follow an 0xFF.
    pub fn end_scan(&mut self) {
        self.flush();
        if self.is_ff_written {
            self.append(0, (self.bit_pos - 1) % 8);
        } else {
            self.append(0, self.bit_pos % 8);
        }
        self.flush();
    }

    /// Consumes the writer and returns the written bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.output
    }

    /// Number of bytes emitted to the output so far.
    ///
    /// Bits still pending in the accumulator (before `end_scan`) are not counted.
    pub fn len(&self) -> usize {
        self.output.len()
    }

    /// Whether no byte has been emitted yet.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty()
    }

    /// Moves up to four bytes from the top of the accumulator into `output`,
    /// stopping once `bit_pos >= 32`.
    ///
    /// A byte is emitted even when fewer valid bits than it holds remain; the
    /// missing low bits are zero. After an 0xFF, only the top 7 bits are taken
    /// and written with a 0 high bit (the stuffed bit).
    fn flush(&mut self) {
        for _ in 0..4 {
            if self.bit_pos >= 32 {
                break;
            }
            if self.is_ff_written {
                // After 0xFF: write 7 bits (inserting the 0-stuffed bit).
                self.write_byte((self.val_current >> 25) as u8);
                self.val_current <<= 7;
                self.bit_pos += 7;
                self.is_ff_written = false;
            } else {
                let byte = (self.val_current >> 24) as u8;
                self.write_byte(byte);
                self.is_ff_written = byte == 0xFF;
                self.val_current <<= 8;
                self.bit_pos += 8;
            }
        }
    }

    /// Appends one byte to the output.
    fn write_byte(&mut self, byte: u8) {
        self.output.push(byte);
    }
}
319
// ── Tests ─────────────────────────────────────────────────────────────────────
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_simple_values() {
        let mut w = BitWriter::new();
        w.append(0b101, 3);
        w.append(0b11110000, 8);
        w.append(0b1, 1);
        w.end_scan();
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_value(3).unwrap(), 0b101);
        assert_eq!(r.read_value(8).unwrap(), 0b11110000);
        assert!(r.read_bit().unwrap());
    }

    #[test]
    fn ff_bitstuffing_roundtrip() {
        // 0xFF followed by 0b001: the writer must emit the 0xFF, then a byte whose
        // high bit is the stuffed 0 and whose next 3 bits are the payload 001.
        let mut w = BitWriter::new();
        w.append(0xFF, 8);
        w.append(0x01, 3);
        w.end_scan();
        let bytes = w.into_bytes();
        assert_eq!(bytes, vec![0xFF, 0x10]);

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_value(8).unwrap(), 0xFF);
        assert_eq!(r.read_value(3).unwrap(), 0x01);
    }

    #[test]
    fn roundtrip_consecutive_ff_bytes() {
        let mut w = BitWriter::new();
        for _ in 0..20 {
            w.append(0xFF, 8);
        }
        w.append(0b101, 3);
        w.end_scan();
        let bytes = w.into_bytes();

        // Every 0xFF must be followed by a byte with the high bit clear, and the
        // stream must not end on a bare 0xFF.
        for pair in bytes.windows(2) {
            if pair[0] == 0xFF {
                assert_eq!(pair[1] & 0x80, 0);
            }
        }
        assert_ne!(bytes.last(), Some(&0xFF));

        let mut r = BitReader::new(&bytes);
        for _ in 0..20 {
            assert_eq!(r.read_value(8).unwrap(), 0xFF);
        }
        assert_eq!(r.read_value(3).unwrap(), 0b101);
    }

    #[test]
    fn reader_stops_at_marker() {
        // 0xFF followed by a byte >= 0x80 (here EOI, 0xFFD9) is a marker and must
        // not be read as entropy-coded data.
        let data: [u8; 3] = [0xAB, 0xFF, 0xD9];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_value(8).unwrap(), 0xAB);
        assert_eq!(r.byte_position(), 1);
        assert!(r.read_bit().is_err());
    }

    #[test]
    fn read_highbits_counts_zeros() {
        let mut w = BitWriter::new();
        // 5 zeros then a 1-bit
        w.append(0b000001, 6);
        w.append(0b1, 1); // padding
        w.end_scan();
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_highbits().unwrap(), 5);
    }

    #[test]
    fn read_highbits_long_run() {
        // More than 16 zeros exercises the bit-by-bit fallback.
        let mut w = BitWriter::new();
        w.append(0, 20);
        w.append(1, 1);
        w.append(0b11, 2);
        w.end_scan();
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_highbits().unwrap(), 20);
        assert_eq!(r.read_value(2).unwrap(), 0b11);
    }

    #[test]
    fn roundtrip_many_small_values() {
        let mut w = BitWriter::new();
        for i in 0..100 {
            w.append(i & 0x1F, 5);
        }
        w.end_scan();
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        for i in 0..100 {
            assert_eq!(r.read_value(5).unwrap(), i & 0x1F);
        }
    }

    #[test]
    fn peek_byte_works() {
        let mut w = BitWriter::new();
        w.append(0b10110100, 8);
        w.end_scan();
        let bytes = w.into_bytes();

        let mut r = BitReader::new(&bytes);
        assert_eq!(r.peek_byte(), 0b10110100);
        // peek shouldn't consume
        assert_eq!(r.read_value(8).unwrap(), 0b10110100);
    }
}