oxiarc_lzma/
range_coder.rs1use oxiarc_core::error::{OxiArcError, Result};
10use std::io::Read;
11
/// Number of bits of precision in an adaptive bit probability.
pub const PROB_BITS: u32 = 11;

/// Exclusive upper bound of a probability value (`2^PROB_BITS`).
pub const PROB_MAX: u16 = 1 << PROB_BITS;

/// Starting value for a fresh probability: half of [`PROB_MAX`], i.e. both
/// bit values are considered equally likely.
pub const PROB_INIT: u16 = PROB_MAX >> 1;

/// Shift applied when adapting a probability after each coded bit; a larger
/// value makes the probability move more slowly.
pub const MOVE_BITS: u32 = 5;

/// Normalization threshold: the coders shift in/out another byte whenever
/// `range` drops below this value (`2^24`).
const TOP_VALUE: u32 = 0x0100_0000;
26
27#[derive(Debug)]
29pub struct RangeDecoder<R: Read> {
30 reader: R,
31 range: u32,
32 code: u32,
33 corrupted: bool,
34}
35
36impl<R: Read> RangeDecoder<R> {
37 pub fn new(mut reader: R) -> Result<Self> {
39 let mut buf = [0u8; 1];
41 reader.read_exact(&mut buf)?;
42
43 if buf[0] != 0x00 {
44 return Err(OxiArcError::invalid_header(
45 "Invalid LZMA stream start byte",
46 ));
47 }
48
49 let mut code_buf = [0u8; 4];
51 reader.read_exact(&mut code_buf)?;
52 let code = u32::from_be_bytes(code_buf);
53
54 Ok(Self {
55 reader,
56 range: 0xFFFF_FFFF,
57 code,
58 corrupted: false,
59 })
60 }
61
62 pub fn new_lzma2(mut reader: R) -> Result<Self> {
64 let mut code_buf = [0u8; 5];
66 reader.read_exact(&mut code_buf)?;
67
68 if code_buf[0] != 0 {
70 return Err(OxiArcError::invalid_header("Invalid LZMA2 stream"));
71 }
72
73 let code = u32::from_be_bytes([code_buf[1], code_buf[2], code_buf[3], code_buf[4]]);
74
75 Ok(Self {
76 reader,
77 range: 0xFFFF_FFFF,
78 code,
79 corrupted: false,
80 })
81 }
82
83 fn normalize(&mut self) -> Result<()> {
85 if self.range < TOP_VALUE {
86 let mut buf = [0u8; 1];
87 self.reader.read_exact(&mut buf)?;
88 self.range <<= 8;
89 self.code = (self.code << 8) | buf[0] as u32;
90 }
91 Ok(())
92 }
93
94 pub fn decode_bit(&mut self, prob: &mut u16) -> Result<u32> {
96 self.normalize()?;
97
98 let bound = (self.range >> PROB_BITS) * (*prob as u32);
99
100 if self.code < bound {
101 self.range = bound;
103 *prob += (PROB_MAX - *prob) >> MOVE_BITS;
104 Ok(0)
105 } else {
106 self.range -= bound;
108 self.code -= bound;
109 *prob -= *prob >> MOVE_BITS;
110 Ok(1)
111 }
112 }
113
114 pub fn decode_direct_bit(&mut self) -> Result<u32> {
116 self.normalize()?;
117
118 self.range >>= 1;
119 self.code = self.code.wrapping_sub(self.range);
120
121 let bit = if (self.code as i32) < 0 {
122 self.code = self.code.wrapping_add(self.range);
123 0
124 } else {
125 1
126 };
127
128 Ok(bit)
129 }
130
131 pub fn decode_direct_bits(&mut self, count: u32) -> Result<u32> {
133 let mut result = 0u32;
134 for _ in 0..count {
135 result = (result << 1) | self.decode_direct_bit()?;
136 }
137 Ok(result)
138 }
139
140 pub fn decode_bit_tree_reverse(&mut self, probs: &mut [u16], num_bits: u32) -> Result<u32> {
142 let mut result = 0u32;
143 let mut index = 1usize;
144
145 for i in 0..num_bits {
146 let bit = self.decode_bit(&mut probs[index])?;
147 index = (index << 1) | bit as usize;
148 result |= bit << i;
149 }
150
151 Ok(result)
152 }
153
154 pub fn decode_bit_tree(&mut self, probs: &mut [u16], num_bits: u32) -> Result<u32> {
156 let mut index = 1usize;
157
158 for _ in 0..num_bits {
159 let bit = self.decode_bit(&mut probs[index])?;
160 index = (index << 1) | bit as usize;
161 }
162
163 Ok((index as u32) - (1 << num_bits))
164 }
165
166 pub fn is_corrupted(&self) -> bool {
168 self.corrupted
169 }
170
171 pub fn is_finished_ok(&self) -> bool {
173 self.code == 0
174 }
175}
176
177#[derive(Debug)]
179pub struct RangeEncoder {
180 buffer: Vec<u8>,
182 range: u32,
184 low: u64,
186 cache: u8,
188 cache_size: u64,
190}
191
192impl RangeEncoder {
193 pub fn new() -> Self {
195 Self {
196 buffer: Vec::new(),
197 range: 0xFFFF_FFFF,
198 low: 0,
199 cache: 0,
200 cache_size: 1,
201 }
202 }
203
204 fn shift_low(&mut self) {
209 if self.low < 0xFF00_0000 || self.low > 0xFFFF_FFFF {
213 let mut tmp = self.cache;
215 let carry = (self.low >> 32) as u8;
216
217 loop {
218 let byte = tmp.wrapping_add(carry);
219 self.buffer.push(byte);
220 tmp = 0xFF; self.cache_size -= 1;
222 if self.cache_size == 0 {
223 break;
224 }
225 }
226
227 self.cache = (self.low >> 24) as u8;
229 }
230
231 self.cache_size += 1;
233
234 self.low = (self.low << 8) & 0xFFFF_FFFF;
236 }
237
238 fn normalize(&mut self) {
240 if self.range < TOP_VALUE {
241 self.range <<= 8;
242 self.shift_low();
243 }
244 }
245
246 pub fn encode_bit(&mut self, prob: &mut u16, bit: u32) {
248 let bound = (self.range >> PROB_BITS) * (*prob as u32);
249
250 if bit == 0 {
251 self.range = bound;
252 *prob += (PROB_MAX - *prob) >> MOVE_BITS;
253 } else {
254 self.low += bound as u64;
255 self.range -= bound;
256 *prob -= *prob >> MOVE_BITS;
257 }
258
259 self.normalize();
260 }
261
262 pub fn encode_direct_bit(&mut self, bit: u32) {
264 self.range >>= 1;
265 if bit != 0 {
266 self.low += self.range as u64;
267 }
268 self.normalize();
269 }
270
271 pub fn encode_direct_bits(&mut self, value: u32, count: u32) {
273 for i in (0..count).rev() {
274 self.encode_direct_bit((value >> i) & 1);
275 }
276 }
277
278 pub fn encode_bit_tree_reverse(&mut self, probs: &mut [u16], num_bits: u32, value: u32) {
280 let mut index = 1usize;
281
282 for i in 0..num_bits {
283 let bit = (value >> i) & 1;
284 self.encode_bit(&mut probs[index], bit);
285 index = (index << 1) | bit as usize;
286 }
287 }
288
289 pub fn encode_bit_tree(&mut self, probs: &mut [u16], num_bits: u32, value: u32) {
291 let mut index = 1usize;
292
293 for i in (0..num_bits).rev() {
294 let bit = (value >> i) & 1;
295 self.encode_bit(&mut probs[index], bit);
296 index = (index << 1) | bit as usize;
297 }
298 }
299
300 pub fn flush(&mut self) {
302 for _ in 0..5 {
303 self.shift_low();
304 }
305 }
306
307 pub fn finish(mut self) -> Vec<u8> {
309 self.flush();
310 self.buffer
311 }
312}
313
314impl Default for RangeEncoder {
315 fn default() -> Self {
316 Self::new()
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// The derived probability constants have their expected numeric values.
    #[test]
    fn test_prob_constants() {
        assert_eq!(PROB_INIT, 1024);
        assert_eq!(PROB_MAX, 2048);
    }

    /// A fresh encoder starts with the full 32-bit range.
    #[test]
    fn test_range_encoder_basic() {
        let fresh = RangeEncoder::new();
        assert_eq!(fresh.range, 0xFFFF_FFFF);
    }

    /// Bits written with an adaptive probability are read back unchanged.
    #[test]
    fn test_encode_decode_bits() {
        const BITS: [u32; 4] = [0, 1, 0, 1];

        // Encode the pattern with a probability that adapts as it goes.
        let mut encoder = RangeEncoder::new();
        let mut enc_prob = PROB_INIT;
        for &bit in &BITS {
            encoder.encode_bit(&mut enc_prob, bit);
        }
        let encoded = encoder.finish();

        // Decode with a fresh probability so both sides adapt identically.
        let mut decoder = RangeDecoder::new(Cursor::new(encoded)).expect("valid LZMA operation");
        let mut dec_prob = PROB_INIT;
        for &expected in &BITS {
            let decoded = decoder
                .decode_bit(&mut dec_prob)
                .expect("valid LZMA operation");
            assert_eq!(decoded, expected);
        }
    }
}