gistools/readers/las/laz/
arithmetic_decoder.rs

1use crate::{
2    parsers::Reader,
3    readers::util::{U32I32F32, U64I64F64, ValueType32, ValueType64},
4};
5use alloc::{rc::Rc, vec, vec::Vec};
6use core::cell::RefCell;
7
// NOTE: this header byte must change whenever an incompatible change is made
/// The header byte for arithmetic coding
pub const AC_HEADER_BYTE: u8 = 0x02;
/// The buffer size
pub const AC_BUFFER_SIZE: u32 = 1 << 12;

/// Renormalization threshold: shorter interval lengths get renormalized
pub const AC_MIN_LENGTH: u32 = 1 << 24;
/// Maximum arithmetic-coding interval length
pub const AC_MAX_LENGTH: u32 = u32::MAX;

// Limits for binary (bit) models
/// Number of interval-length bits discarded before the multiplication (bit models)
pub const BM_LENGTH_SHIFT: u32 = 13;
/// Count limit above which adaptive bit models halve their counts
pub const BM_MAX_COUNT: u32 = 1 << BM_LENGTH_SHIFT;

// Limits for general (multi-symbol) models
/// Number of interval-length bits discarded before the multiplication (general models)
pub const DM_LENGTH_SHIFT: u32 = 15;
/// Count limit above which adaptive general models halve their counts
pub const DM_MAX_COUNT: u32 = 1 << DM_LENGTH_SHIFT;
30
31/// A "Corrector" wrapper that handles two type of arithmetic models
32#[derive(Debug, Clone)]
33pub enum Corrector {
34    /// ArithmeticModel
35    ArithmeticModel(ArithmeticModel),
36    /// ArithmeticBitModel
37    ArithmeticBitModel(ArithmeticBitModel),
38}
39impl Corrector {
40    /// Initialize the model
41    pub fn init(&mut self, table: Option<Vec<u32>>) {
42        match self {
43            Corrector::ArithmeticModel(m) => m.init(table),
44            Corrector::ArithmeticBitModel(m) => m.init(table),
45        }
46    }
47    /// Update the model
48    pub fn update(&mut self) {
49        match self {
50            Corrector::ArithmeticModel(m) => m.update(),
51            Corrector::ArithmeticBitModel(m) => m.update(),
52        }
53    }
54    /// Get the model. Most models are ArithmeticModel, so when we ask for a model, we normally
55    /// want the arithmetic model. The Bit Model is only the first model.
56    pub fn get_model(&mut self) -> Option<&mut ArithmeticModel> {
57        match self {
58            Corrector::ArithmeticModel(m) => Some(m),
59            Corrector::ArithmeticBitModel(_) => None,
60        }
61    }
62    /// Get the bit model. The FIRST model is the bit model.
63    pub fn get_bit_model(&mut self) -> Option<&mut ArithmeticBitModel> {
64        match self {
65            Corrector::ArithmeticModel(_) => None,
66            Corrector::ArithmeticBitModel(m) => Some(m),
67        }
68    }
69}
70impl From<ArithmeticModel> for Corrector {
71    fn from(m: ArithmeticModel) -> Self {
72        Corrector::ArithmeticModel(m)
73    }
74}
75impl From<ArithmeticBitModel> for Corrector {
76    fn from(m: ArithmeticBitModel) -> Self {
77        Corrector::ArithmeticBitModel(m)
78    }
79}
80
81/// <https://github.com/LASzip/LASzip/blob/master/src/arithmeticdecoder.cpp>
82#[derive(Debug, Default, Clone, PartialEq)]
83pub struct ArithmeticDecoder<T: Reader> {
84    /// The data to read from
85    pub reader: Rc<RefCell<T>>,
86    /// The current value
87    pub value: u32,
88    /// The current length
89    pub length: u32,
90}
91impl<T: Reader> ArithmeticDecoder<T> {
92    /// Create a new arithmetic decoder
93    pub fn new(reader: Rc<RefCell<T>>) -> Self {
94        Self { reader, value: 0, length: AC_MAX_LENGTH }
95    }
96
97    /// Initialize the decoder
98    ///
99    /// ## Parameters
100    /// - `reallyInit`: If set to true, initializes the value
101    pub fn init(&mut self, really_init: bool) {
102        self.length = AC_MAX_LENGTH;
103        if really_init {
104            self.value = self.reader.borrow().uint32_be(None);
105        }
106    }
107
108    /// Read bits
109    ///
110    /// ## Parameters
111    /// - `bits`: The number of bits
112    ///
113    /// ## Returns
114    /// The decoded bits
115    pub fn read_bits(&mut self, mut bits: u32) -> u32 {
116        assert!(bits != 0 && (bits <= 32));
117
118        if bits > 19 {
119            let tmp = self.read_short() as u32;
120            bits -= 16;
121            let tmp1 = self.read_bits(bits) * 65_536;
122            return tmp1 | tmp;
123        }
124
125        self.length >>= bits;
126        let sym = self.value / self.length;
127        self.value -= self.length * sym; // update interval
128
129        if self.length < AC_MIN_LENGTH {
130            self.renorm_dec_interval(); // renormalization
131        }
132        if sym >= 1 << bits {
133            panic!("4711");
134        }
135
136        sym
137    }
138
139    /// Decode a bit
140    ///
141    /// ## Parameters
142    /// - `m`: The arithmetic bit model
143    ///
144    /// ## Returns
145    /// The decoded bit
146    pub fn decode_bit(&mut self, m: &mut ArithmeticBitModel) -> u32 {
147        let x = m.bit0_prob * (self.length >> BM_LENGTH_SHIFT); // product l x p0
148        let sym = if self.value >= x { 1 } else { 0 }; // decision
149        // update & shift interval
150        if sym == 0 {
151            self.length = x;
152            m.bit0_count += 1;
153        } else {
154            self.value -= x; // shifted interval base = 0
155            self.length -= x;
156        }
157
158        if self.length < AC_MIN_LENGTH {
159            self.renorm_dec_interval(); // renormalization
160        }
161        m.bits_until_update -= 1;
162        if m.bits_until_update == 0 {
163            m.update(); // periodic model update
164        }
165
166        sym // return data bit value
167    }
168
169    /// Decode the symbol
170    ///
171    /// ## Parameters
172    /// - `m`: The arithmetic model
173    ///
174    /// ## Returns
175    /// The decoded symbol
176    pub fn decode_symbol(&mut self, m: &mut ArithmeticModel) -> u32 {
177        let mut sym;
178        let mut x;
179        let mut y = self.length;
180
181        if m.decoder_table_index != NULL_POINTER {
182            self.length >>= DM_LENGTH_SHIFT;
183            let dv = self.value / self.length;
184            let t = dv >> m.table_shift;
185
186            sym = m.distribution[(m.decoder_table_index + t) as usize]; // initial decision based on table look-up
187            let mut n = m.distribution[(m.decoder_table_index + t + 1) as usize] + 1;
188
189            while n > sym + 1 {
190                // finish with bisection search
191                let k = (sym + n) >> 1;
192                if m.distribution[k as usize] > dv {
193                    n = k;
194                } else {
195                    sym = k;
196                }
197            }
198            // compute products
199            x = m.distribution[sym as usize] * self.length;
200            if sym != m.last_symbol {
201                y = m.distribution[(sym + 1) as usize] * self.length;
202            }
203        } else {
204            // decode using only multiplications
205            sym = 0;
206            x = sym;
207            self.length >>= DM_LENGTH_SHIFT;
208            let mut n = m.symbols;
209            let mut k = n >> 1;
210            // decode via bisection search
211            loop {
212                let z = self.length * m.distribution[k as usize];
213                if z > self.value {
214                    n = k;
215                    y = z; // value is smaller
216                } else {
217                    sym = k;
218                    x = z; // value is larger or equal
219                }
220                k = (sym + n) >> 1;
221                if k == sym {
222                    break;
223                }
224            }
225        }
226
227        self.value -= x; // update interval
228        self.length = y - x;
229
230        if self.length < AC_MIN_LENGTH {
231            self.renorm_dec_interval(); // renormalization
232        }
233
234        m.distribution[(m.symbol_count_index + sym) as usize] += 1;
235        m.symbols_until_update -= 1;
236        if m.symbols_until_update == 0 {
237            m.update(); // periodic model update
238        }
239        assert!(sym < m.symbols);
240
241        sym
242    }
243
244    /// ## Returns
245    /// The decoded bit
246    pub fn read_bit(&mut self) -> u32 {
247        self.length >>= 1;
248        let sym = self.value / self.length; // decode symbol, change length
249        self.value -= self.length * sym; // update interval
250
251        if self.length < AC_MIN_LENGTH {
252            self.renorm_dec_interval(); // renormalization
253        }
254        if sym >= 2 {
255            panic!("4711");
256        }
257
258        sym
259    }
260
261    /// ## Returns
262    /// The decoded byte
263    pub fn read_byte(&mut self) -> u8 {
264        self.length >>= 8;
265        let sym = self.value / self.length; // decode symbol, change length
266        self.value -= self.length * sym; // update interval
267
268        if self.length < AC_MIN_LENGTH {
269            self.renorm_dec_interval(); // renormalization
270        }
271        if sym >= 1 << 8 {
272            panic!("4711");
273        }
274
275        sym as u8
276    }
277
278    /// ## Returns
279    /// The decoded short
280    pub fn read_short(&mut self) -> u16 {
281        self.length >>= 16;
282        let sym = self.value / self.length; // decode symbol, change length
283        self.value -= self.length * sym; // update interval
284
285        if self.length < AC_MIN_LENGTH {
286            self.renorm_dec_interval(); // renormalization
287        }
288        if sym >= 65_536 {
289            panic!("4711");
290        }
291
292        sym as u16
293    }
294
295    /// ## Returns
296    /// The decoded int
297    pub fn read_int(&mut self) -> u32 {
298        let lower_int = self.read_short() as u32;
299        let upper_int = self.read_short() as u32;
300        (upper_int << 16) | lower_int
301    }
302
303    /// ## Returns
304    /// The decoded float
305    pub fn read_float(&mut self) -> f32 {
306        // danger in float reinterpretation
307        let u32i32f32 = U32I32F32::new(self.read_int(), ValueType32::U32);
308        u32i32f32.f32()
309    }
310
311    /// ## Returns
312    /// The decoded int64
313    pub fn read_int64(&mut self) -> u64 {
314        let lower_int = self.read_int() as u64;
315        let upper_int = self.read_int() as u64;
316        (upper_int << 32) | lower_int
317    }
318
319    /// ## Returns
320    /// The decoded double
321    pub fn read_double(&mut self) -> f64 {
322        let u64i64f64 = U64I64F64::new(self.read_int64(), ValueType64::U64);
323        u64i64f64.f64()
324    }
325
326    /// Renormalize the decoder interval
327    pub fn renorm_dec_interval(&mut self) {
328        loop {
329            let byte = self.reader.borrow().uint8(None) as u32;
330            self.value = (self.value << 8) | byte;
331            self.length <<= 8;
332            if self.length >= AC_MIN_LENGTH {
333                break;
334            }
335        }
336    }
337}
338
339const NULL_POINTER: u32 = u32::MAX;
340
341/// Arithmetic Model
342#[derive(Debug, Default, Clone, PartialEq)]
343pub struct ArithmeticModel {
344    /// The distribution
345    pub distribution: Vec<u32>,
346    /// The symbol count index
347    pub symbol_count_index: u32,
348    /// The decoder table index
349    pub decoder_table_index: u32,
350    /// The total count
351    pub total_count: u32,
352    /// The update cycle
353    pub update_cycle: u32,
354    /// The symbols until update
355    pub symbols_until_update: u32,
356    /// The last symbol
357    pub last_symbol: u32,
358    /// The table size
359    pub table_size: u32,
360    /// The table shift
361    pub table_shift: u32,
362    /// The symbols
363    pub symbols: u32,
364    /// The compress
365    pub compress: bool,
366}
367impl ArithmeticModel {
368    /// Create a new ArithmeticModel
369    pub fn new(symbols: u32, compress: bool) -> Self {
370        Self {
371            distribution: Vec::new(),
372            symbol_count_index: 0,
373            decoder_table_index: NULL_POINTER,
374            total_count: 0,
375            update_cycle: 0,
376            symbols_until_update: 0,
377            last_symbol: 0,
378            table_size: 0,
379            table_shift: 0,
380            symbols,
381            compress,
382        }
383    }
384
385    /// Initialize the model
386    pub fn init(&mut self, table: Option<Vec<u32>>) {
387        if self.distribution.is_empty() {
388            if self.symbols < 2 || self.symbols > (1 << 11) {
389                panic!("invalid number of symbols");
390            }
391            self.last_symbol = self.symbols - 1;
392            if !self.compress && self.symbols > 16 {
393                let mut table_bits = 3;
394                while self.symbols > 1 << (table_bits + 2) {
395                    table_bits += 1;
396                }
397                self.table_size = 1 << table_bits;
398                self.table_shift = DM_LENGTH_SHIFT - table_bits;
399                self.distribution = vec![0; (2 * self.symbols + self.table_size + 2) as usize];
400                self.decoder_table_index = 2 * self.symbols;
401            } else {
402                // small alphabet: no table needed
403                self.decoder_table_index = NULL_POINTER;
404                self.table_shift = 0;
405                self.table_size = 0;
406                self.distribution = vec![0; 2 * self.symbols as usize];
407            }
408            self.symbol_count_index = self.symbols;
409        }
410
411        self.total_count = 0;
412        self.update_cycle = self.symbols;
413        if let Some(table) = table {
414            for k in 0..self.symbols {
415                self.distribution[(self.symbol_count_index + k) as usize] = table[k as usize];
416            }
417        } else {
418            for k in 0..self.symbols {
419                self.distribution[(self.symbol_count_index + k) as usize] = 1;
420            }
421        }
422
423        self.update();
424        self.update_cycle = (self.symbols + 6) >> 1;
425        self.symbols_until_update = self.update_cycle;
426    }
427
428    /// Update the model
429    pub fn update(&mut self) {
430        // halve counts when a threshold is reached
431        self.total_count += self.update_cycle;
432        if self.total_count > DM_MAX_COUNT {
433            self.total_count = 0;
434            for n in 0..self.symbols {
435                self.distribution[(self.symbol_count_index + n) as usize] =
436                    (self.distribution[(self.symbol_count_index + n) as usize] + 1) >> 1;
437                self.total_count += self.distribution[(self.symbol_count_index + n) as usize];
438            }
439        }
440
441        // compute cumulative distribution, decoder table
442        let mut sum = 0;
443        let mut s = 0;
444        let scale = 0x80000000 / self.total_count;
445
446        if self.compress || self.table_size == 0 {
447            for k in 0..self.symbols {
448                self.distribution[k as usize] =
449                    ((scale as u64 * sum as u64) >> (31 - DM_LENGTH_SHIFT)) as u32;
450                sum += self.distribution[(self.symbol_count_index + k) as usize];
451            }
452        } else {
453            for k in 0..self.symbols {
454                self.distribution[k as usize] =
455                    ((scale as u64 * sum as u64) >> (31 - DM_LENGTH_SHIFT)) as u32;
456                sum += self.distribution[(self.symbol_count_index + k) as usize];
457                let w = self.distribution[k as usize] >> self.table_shift;
458                while s < w {
459                    s += 1;
460                    self.distribution[(self.decoder_table_index + s) as usize] = k - 1;
461                }
462            }
463            self.distribution[self.decoder_table_index as usize] = 0;
464            while s <= self.table_size {
465                s += 1;
466                self.distribution[(self.decoder_table_index + s) as usize] = self.symbols - 1;
467            }
468        }
469
470        // set frequency of model updates
471        self.update_cycle = (5 * self.update_cycle) >> 2;
472        let max_cycle = (self.symbols + 6) << 3;
473        if self.update_cycle > max_cycle {
474            self.update_cycle = max_cycle;
475        }
476        self.symbols_until_update = self.update_cycle;
477    }
478}
479
480/// Arithmetic Bit Model
481#[derive(Debug, Default, Clone, PartialEq)]
482pub struct ArithmeticBitModel {
483    // start with frequent updates
484    /// update cycle
485    pub update_cycle: u32,
486    /// bits until update
487    pub bits_until_update: u32,
488    // initialization to equiprobable model
489    /// bit 0 probability
490    pub bit0_prob: u32,
491    /// bit 0 count
492    pub bit0_count: u32,
493    /// bit count
494    pub bit_count: u32,
495}
496
497impl ArithmeticBitModel {
498    /// Create a new ArithmeticBitModel
499    pub fn new() -> Self {
500        Self {
501            update_cycle: 4,
502            bits_until_update: 4,
503            bit0_prob: 1 << (BM_LENGTH_SHIFT - 1),
504            bit0_count: 1,
505            bit_count: 2,
506        }
507    }
508
509    /// Initialize the model
510    pub fn init(&mut self, _table: Option<Vec<u32>>) {
511        self.update_cycle = 4;
512        self.bits_until_update = 4;
513        self.bit0_prob = 1 << (BM_LENGTH_SHIFT - 1);
514        self.bit0_count = 1;
515        self.bit_count = 2;
516    }
517
518    /// Update the model
519    pub fn update(&mut self) {
520        // halve counts when a threshold is reached
521        self.bit_count += self.update_cycle;
522        if self.bit_count > BM_MAX_COUNT {
523            self.bit_count = (self.bit_count + 1) >> 1;
524            self.bit0_count = (self.bit0_count + 1) >> 1;
525            if self.bit0_count == self.bit_count {
526                self.bit_count += 1;
527            }
528        }
529        // compute scaled bit 0 probability
530        let scale = 0x80000000 / self.bit_count;
531        self.bit0_prob = (self.bit0_count * scale) >> (31 - BM_LENGTH_SHIFT);
532        // set frequency of model updates
533        self.update_cycle = (5 * self.update_cycle) >> 2;
534        if self.update_cycle > 64 {
535            self.update_cycle = 64;
536        }
537        self.bits_until_update = self.update_cycle;
538    }
539}