Skip to main content

oxiarc_lzma/
range_coder.rs

//! Range coder for LZMA compression.
//!
//! The range coder is an entropy coding method similar to arithmetic coding.
//! LZMA uses a specific variant with:
//! - 32-bit range tracking
//! - Normalization when the range drops below 2^24
//! - An 11-bit probability model (probabilities are scaled to 2^11 = 2048, so 1024 = 50%)

9use oxiarc_core::error::{OxiArcError, Result};
10use std::io::Read;
11
/// Precision of the adaptive bit model: probabilities are stored as fixed-point
/// fractions of `1 << PROB_BITS`.
pub const PROB_BITS: u32 = 11;

/// Fixed-point scale of the bit model (`1 << PROB_BITS` = 2048). A stored
/// probability `p` stands for `p / PROB_MAX`.
pub const PROB_MAX: u16 = 1 << PROB_BITS;

/// Starting probability of every bit model: half of `PROB_MAX` (1024), i.e. 50%.
pub const PROB_INIT: u16 = PROB_MAX / 2;

/// Adaptation rate: each update moves a probability `1 / (1 << MOVE_BITS)`
/// (= 1/32) of the remaining distance toward the bit that was just coded.
pub const MOVE_BITS: u32 = 5;

/// Renormalization threshold: the range is shifted left by one byte whenever it
/// falls below 2^24.
const TOP_VALUE: u32 = 0x0100_0000;

27/// Range decoder for LZMA decompression.
28#[derive(Debug)]
29pub struct RangeDecoder<R: Read> {
30    reader: R,
31    range: u32,
32    code: u32,
33    corrupted: bool,
34}
35
36impl<R: Read> RangeDecoder<R> {
37    /// Create a new range decoder.
38    pub fn new(mut reader: R) -> Result<Self> {
39        // Read first byte (should be 0x00)
40        let mut buf = [0u8; 1];
41        reader.read_exact(&mut buf)?;
42
43        if buf[0] != 0x00 {
44            return Err(OxiArcError::invalid_header(
45                "Invalid LZMA stream start byte",
46            ));
47        }
48
49        // Read initial code value (4 bytes, big-endian)
50        let mut code_buf = [0u8; 4];
51        reader.read_exact(&mut code_buf)?;
52        let code = u32::from_be_bytes(code_buf);
53
54        Ok(Self {
55            reader,
56            range: 0xFFFF_FFFF,
57            code,
58            corrupted: false,
59        })
60    }
61
62    /// Create a range decoder for raw LZMA2 stream (no header byte).
63    pub fn new_lzma2(mut reader: R) -> Result<Self> {
64        // Read initial code value (5 bytes for LZMA2)
65        let mut code_buf = [0u8; 5];
66        reader.read_exact(&mut code_buf)?;
67
68        // First byte should be 0
69        if code_buf[0] != 0 {
70            return Err(OxiArcError::invalid_header("Invalid LZMA2 stream"));
71        }
72
73        let code = u32::from_be_bytes([code_buf[1], code_buf[2], code_buf[3], code_buf[4]]);
74
75        Ok(Self {
76            reader,
77            range: 0xFFFF_FFFF,
78            code,
79            corrupted: false,
80        })
81    }
82
83    /// Normalize the range (refill when range gets small).
84    fn normalize(&mut self) -> Result<()> {
85        if self.range < TOP_VALUE {
86            let mut buf = [0u8; 1];
87            self.reader.read_exact(&mut buf)?;
88            self.range <<= 8;
89            self.code = (self.code << 8) | buf[0] as u32;
90        }
91        Ok(())
92    }
93
94    /// Decode a single bit with the given probability.
95    pub fn decode_bit(&mut self, prob: &mut u16) -> Result<u32> {
96        self.normalize()?;
97
98        let bound = (self.range >> PROB_BITS) * (*prob as u32);
99
100        if self.code < bound {
101            // Bit is 0
102            self.range = bound;
103            *prob += (PROB_MAX - *prob) >> MOVE_BITS;
104            Ok(0)
105        } else {
106            // Bit is 1
107            self.range -= bound;
108            self.code -= bound;
109            *prob -= *prob >> MOVE_BITS;
110            Ok(1)
111        }
112    }
113
114    /// Decode a bit with fixed 50% probability.
115    pub fn decode_direct_bit(&mut self) -> Result<u32> {
116        self.normalize()?;
117
118        self.range >>= 1;
119        self.code = self.code.wrapping_sub(self.range);
120
121        let bit = if (self.code as i32) < 0 {
122            self.code = self.code.wrapping_add(self.range);
123            0
124        } else {
125            1
126        };
127
128        Ok(bit)
129    }
130
131    /// Decode multiple bits with fixed probability.
132    pub fn decode_direct_bits(&mut self, count: u32) -> Result<u32> {
133        let mut result = 0u32;
134        for _ in 0..count {
135            result = (result << 1) | self.decode_direct_bit()?;
136        }
137        Ok(result)
138    }
139
140    /// Decode a bit tree (reverse order).
141    pub fn decode_bit_tree_reverse(&mut self, probs: &mut [u16], num_bits: u32) -> Result<u32> {
142        let mut result = 0u32;
143        let mut index = 1usize;
144
145        for i in 0..num_bits {
146            let bit = self.decode_bit(&mut probs[index])?;
147            index = (index << 1) | bit as usize;
148            result |= bit << i;
149        }
150
151        Ok(result)
152    }
153
154    /// Decode a bit tree (normal order).
155    pub fn decode_bit_tree(&mut self, probs: &mut [u16], num_bits: u32) -> Result<u32> {
156        let mut index = 1usize;
157
158        for _ in 0..num_bits {
159            let bit = self.decode_bit(&mut probs[index])?;
160            index = (index << 1) | bit as usize;
161        }
162
163        Ok((index as u32) - (1 << num_bits))
164    }
165
166    /// Check if the stream is corrupted.
167    pub fn is_corrupted(&self) -> bool {
168        self.corrupted
169    }
170
171    /// Check if decoding finished correctly.
172    pub fn is_finished_ok(&self) -> bool {
173        self.code == 0
174    }
175}
176
177/// Range encoder for LZMA compression.
178#[derive(Debug)]
179pub struct RangeEncoder {
180    /// Output buffer.
181    buffer: Vec<u8>,
182    /// Current range.
183    range: u32,
184    /// Low value.
185    low: u64,
186    /// Cache byte.
187    cache: u8,
188    /// Cache size.
189    cache_size: u64,
190}
191
192impl RangeEncoder {
193    /// Create a new range encoder.
194    pub fn new() -> Self {
195        Self {
196            buffer: Vec::new(),
197            range: 0xFFFF_FFFF,
198            low: 0,
199            cache: 0,
200            cache_size: 1,
201        }
202    }
203
204    /// Shift low and write bytes.
205    ///
206    /// This uses the carry-handling cache mechanism from the LZMA SDK.
207    /// The low value is a 64-bit accumulator where bits 32-39 represent overflow (carry).
208    fn shift_low(&mut self) {
209        // Check if we can output bytes:
210        // - low < 0xFF000000: no pending carry propagation needed
211        // - low > 0xFFFFFFFF: there's a carry to propagate
212        if self.low < 0xFF00_0000 || self.low > 0xFFFF_FFFF {
213            // Output pending bytes with carry propagation
214            let mut tmp = self.cache;
215            let carry = (self.low >> 32) as u8;
216
217            loop {
218                let byte = tmp.wrapping_add(carry);
219                self.buffer.push(byte);
220                tmp = 0xFF; // Subsequent bytes are 0xFF (will become 0x00 if carry)
221                self.cache_size -= 1;
222                if self.cache_size == 0 {
223                    break;
224                }
225            }
226
227            // New cache is the top byte of the 32-bit low value
228            self.cache = (self.low >> 24) as u8;
229        }
230
231        // Always increment cache_size (tracks pending bytes)
232        self.cache_size += 1;
233
234        // Shift low left by 8 bits, keeping only 32 bits
235        self.low = (self.low << 8) & 0xFFFF_FFFF;
236    }
237
238    /// Normalize the range.
239    fn normalize(&mut self) {
240        if self.range < TOP_VALUE {
241            self.range <<= 8;
242            self.shift_low();
243        }
244    }
245
246    /// Encode a single bit with the given probability.
247    pub fn encode_bit(&mut self, prob: &mut u16, bit: u32) {
248        let bound = (self.range >> PROB_BITS) * (*prob as u32);
249
250        if bit == 0 {
251            self.range = bound;
252            *prob += (PROB_MAX - *prob) >> MOVE_BITS;
253        } else {
254            self.low += bound as u64;
255            self.range -= bound;
256            *prob -= *prob >> MOVE_BITS;
257        }
258
259        self.normalize();
260    }
261
262    /// Encode a bit with fixed 50% probability.
263    pub fn encode_direct_bit(&mut self, bit: u32) {
264        self.range >>= 1;
265        if bit != 0 {
266            self.low += self.range as u64;
267        }
268        self.normalize();
269    }
270
271    /// Encode multiple bits with fixed probability.
272    pub fn encode_direct_bits(&mut self, value: u32, count: u32) {
273        for i in (0..count).rev() {
274            self.encode_direct_bit((value >> i) & 1);
275        }
276    }
277
278    /// Encode a bit tree (reverse order).
279    pub fn encode_bit_tree_reverse(&mut self, probs: &mut [u16], num_bits: u32, value: u32) {
280        let mut index = 1usize;
281
282        for i in 0..num_bits {
283            let bit = (value >> i) & 1;
284            self.encode_bit(&mut probs[index], bit);
285            index = (index << 1) | bit as usize;
286        }
287    }
288
289    /// Encode a bit tree (normal order).
290    pub fn encode_bit_tree(&mut self, probs: &mut [u16], num_bits: u32, value: u32) {
291        let mut index = 1usize;
292
293        for i in (0..num_bits).rev() {
294            let bit = (value >> i) & 1;
295            self.encode_bit(&mut probs[index], bit);
296            index = (index << 1) | bit as usize;
297        }
298    }
299
300    /// Flush the encoder.
301    pub fn flush(&mut self) {
302        for _ in 0..5 {
303            self.shift_low();
304        }
305    }
306
307    /// Get the encoded data.
308    pub fn finish(mut self) -> Vec<u8> {
309        self.flush();
310        self.buffer
311    }
312}
313
314impl Default for RangeEncoder {
315    fn default() -> Self {
316        Self::new()
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Advance a xorshift32 generator: deterministic pseudo-random test data
    /// without an RNG dependency.
    fn xorshift32(state: &mut u32) -> u32 {
        *state ^= *state << 13;
        *state ^= *state >> 17;
        *state ^= *state << 5;
        *state
    }

    /// Open a decoder over `bytes`, panicking if the stream header is rejected.
    fn decoder_for(bytes: Vec<u8>) -> RangeDecoder<Cursor<Vec<u8>>> {
        RangeDecoder::new(Cursor::new(bytes)).expect("valid LZMA operation")
    }

    #[test]
    fn test_prob_constants() {
        assert_eq!(PROB_INIT, 1024);
        assert_eq!(PROB_MAX, 2048);
    }

    #[test]
    fn test_range_encoder_basic() {
        let encoder = RangeEncoder::new();
        assert_eq!(encoder.range, 0xFFFF_FFFF);
    }

    #[test]
    fn test_encode_decode_bits() {
        let bits = [0, 1, 0, 1];

        let mut encoder = RangeEncoder::new();
        let mut prob = PROB_INIT;
        for &bit in &bits {
            encoder.encode_bit(&mut prob, bit);
        }
        let encoded = encoder.finish();

        // The encoder output already includes the leading 0x00 byte through its
        // cache mechanism, so it is fed to the decoder directly.
        let mut decoder = decoder_for(encoded);
        let mut prob = PROB_INIT;
        for &bit in &bits {
            assert_eq!(
                decoder.decode_bit(&mut prob).expect("valid LZMA operation"),
                bit
            );
        }
        assert!(decoder.is_finished_ok());
    }

    #[test]
    fn test_known_encoder_output() {
        // Pin exact bytes so a change that breaks encoder and decoder in the
        // same way is still caught.
        let mut encoder = RangeEncoder::new();
        let mut prob = PROB_INIT;
        for bit in [0, 1, 0, 1] {
            encoder.encode_bit(&mut prob, bit);
        }
        assert_eq!(encoder.finish(), vec![0x00, 0x51, 0xF4, 0x1D, 0xD8]);

        // During this flush the low window passes through 0xFFFF_FE00 and
        // 0xFFFE_0000, so two 0xFF bytes queue behind the cached 0x9F before
        // they are written.
        let mut encoder = RangeEncoder::new();
        encoder.encode_direct_bits(0b101, 3);
        assert_eq!(encoder.finish(), vec![0x00, 0x9F, 0xFF, 0xFF, 0xFE]);
    }

    #[test]
    fn test_direct_bits_round_trip() {
        let mut seed = 0x2545_F491u32;
        let cases: Vec<(u32, u32)> = (1..=32u32)
            .map(|count| {
                let mask = if count == 32 { u32::MAX } else { (1u32 << count) - 1 };
                (xorshift32(&mut seed) & mask, count)
            })
            .collect();

        let mut encoder = RangeEncoder::new();
        for &(value, count) in &cases {
            encoder.encode_direct_bits(value, count);
        }

        let mut decoder = decoder_for(encoder.finish());
        for &(value, count) in &cases {
            assert_eq!(
                decoder.decode_direct_bits(count).expect("valid LZMA operation"),
                value,
                "width {count}"
            );
        }
        assert!(decoder.is_finished_ok());
    }

    #[test]
    fn test_bit_tree_round_trip() {
        const NUM_BITS: u32 = 3;
        let values = [0u32, 7, 5, 1, 6, 2, 4, 3, 5, 5];

        let mut encoder = RangeEncoder::new();
        let mut tree = vec![PROB_INIT; 1 << NUM_BITS];
        let mut reverse_tree = vec![PROB_INIT; 1 << NUM_BITS];
        for &value in &values {
            encoder.encode_bit_tree(&mut tree, NUM_BITS, value);
            encoder.encode_bit_tree_reverse(&mut reverse_tree, NUM_BITS, value);
        }

        let mut decoder = decoder_for(encoder.finish());
        let mut tree = vec![PROB_INIT; 1 << NUM_BITS];
        let mut reverse_tree = vec![PROB_INIT; 1 << NUM_BITS];
        for &value in &values {
            assert_eq!(
                decoder
                    .decode_bit_tree(&mut tree, NUM_BITS)
                    .expect("valid LZMA operation"),
                value
            );
            assert_eq!(
                decoder
                    .decode_bit_tree_reverse(&mut reverse_tree, NUM_BITS)
                    .expect("valid LZMA operation"),
                value
            );
        }
        assert!(decoder.is_finished_ok());
    }

    #[test]
    fn test_long_mixed_stream_round_trip() {
        const STEPS: usize = 5000;
        let mut seed = 0x9E37_79B9u32;

        // Mostly-zero bits push the adaptive probabilities toward their extremes,
        // and interleaved random direct bits keep `low` moving. Over thousands of
        // steps this makes carry propagation through the cache very likely.
        let bits: Vec<u32> = (0..STEPS)
            .map(|_| u32::from(xorshift32(&mut seed) % 16 == 0))
            .collect();
        let directs: Vec<u32> = (0..STEPS).map(|_| xorshift32(&mut seed) & 0xFF).collect();

        let mut encoder = RangeEncoder::new();
        let mut probs = [PROB_INIT; 4];
        for (i, (&bit, &direct)) in bits.iter().zip(&directs).enumerate() {
            encoder.encode_bit(&mut probs[i % 4], bit);
            encoder.encode_direct_bits(direct, 8);
        }

        let mut decoder = decoder_for(encoder.finish());
        let mut probs = [PROB_INIT; 4];
        for (i, (&bit, &direct)) in bits.iter().zip(&directs).enumerate() {
            assert_eq!(
                decoder.decode_bit(&mut probs[i % 4]).expect("valid LZMA operation"),
                bit,
                "bit at step {i}"
            );
            assert_eq!(
                decoder.decode_direct_bits(8).expect("valid LZMA operation"),
                direct,
                "direct bits at step {i}"
            );
        }
        assert!(decoder.is_finished_ok());
        assert!(!decoder.is_corrupted());
    }

    #[test]
    fn test_rejects_bad_start_byte() {
        assert!(RangeDecoder::new(Cursor::new(vec![0x01u8, 0, 0, 0, 0])).is_err());
        assert!(RangeDecoder::new_lzma2(Cursor::new(vec![0x01u8, 0, 0, 0, 0])).is_err());
        // Too short to hold the five init bytes.
        assert!(RangeDecoder::new(Cursor::new(vec![0x00u8, 0, 0])).is_err());
    }
}