lzma_rust2/enc/
encoder.rs

1use alloc::{vec, vec::Vec};
2
3use super::{
4    encoder_fast::FastEncoderMode,
5    encoder_normal::NormalEncoderMode,
6    lz::{LzEncoder, MfType},
7    range_enc::{RangeEncoder, RangeEncoderBuffer},
8};
9use crate::{
10    get_dist_state, state::State, LengthCoder, LiteralCoder, LiteralSubCoder, LzmaCoder, Write,
11    ALIGN_BITS, ALIGN_MASK, ALIGN_SIZE, DIST_MODEL_END, DIST_MODEL_START, DIST_STATES,
12    FULL_DISTANCES, LOW_SYMBOLS, MATCH_LEN_MAX, MATCH_LEN_MIN, MID_SYMBOLS, REPS,
13};
14
/// Threshold on `uncompressed_size` used by `encode_for_lzma2`: once the
/// counter exceeds it, no further symbol is encoded in that call.
/// The value is 2 MiB minus `MATCH_LEN_MAX`.
const LZMA2_UNCOMPRESSED_LIMIT: u32 = (2 << 20) - MATCH_LEN_MAX as u32;
/// Threshold on the range encoder's pending size used by `encode_for_lzma2`:
/// once the pending size exceeds it, no further symbol is encoded in that call.
/// The value is 64 KiB minus 26.
const LZMA2_COMPRESSED_LIMIT: u32 = (64 << 10) - 26;

/// Value `dist_price_count` is reloaded with whenever `update_dist_prices` runs.
const DIST_PRICE_UPDATE_INTERVAL: u32 = FULL_DISTANCES as u32;
/// Value `align_price_count` is reloaded with whenever `update_align_prices` runs.
const ALIGN_PRICE_UPDATE_INTERVAL: u32 = ALIGN_SIZE as u32;
/// Value each `LengthEncoder` counter is reloaded with when its price row is refreshed.
const PRICE_UPDATE_INTERVAL: usize = 32;
21
/// The mode to use when encoding.
///
/// `LzmaEncoder::new` uses this to decide whether a `FastEncoderMode` or a
/// `NormalEncoderMode` is created to pick symbols.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodeMode {
    /// Fast mode (lower quality but faster).
    Fast,
    /// Normal mode (higher quality but slower).
    Normal,
}
30
/// Strategy interface through which `LzmaEncoder` asks a mode for the next
/// symbol to encode.
pub(crate) trait LzmaEncoderTrait {
    /// Chooses the next symbol and returns its length.
    ///
    /// The caller (`encode_symbol`) reads `encoder.data.back` afterwards to
    /// learn what kind of symbol was chosen: `-1` means a literal, a value
    /// below `REPS` is a repeated-match index, and anything else is a match
    /// distance offset by `REPS`.
    fn get_next_symbol(&mut self, encoder: &mut LzmaEncoder) -> u32;
    /// Resets any state kept by the mode; does nothing unless overridden.
    fn reset(&mut self) {}
}
35
/// The concrete symbol-selection mode paired with an `LzmaEncoder`.
pub(crate) enum LzmaEncoderModes {
    /// Created by `LzmaEncoder::new` for `EncodeMode::Fast`.
    Fast(FastEncoderMode),
    /// Created by `LzmaEncoder::new` for `EncodeMode::Normal`.
    Normal(NormalEncoderMode),
}
40
41impl LzmaEncoderTrait for LzmaEncoderModes {
42    fn get_next_symbol(&mut self, encoder: &mut LzmaEncoder) -> u32 {
43        match self {
44            LzmaEncoderModes::Fast(a) => a.get_next_symbol(encoder),
45            LzmaEncoderModes::Normal(a) => a.get_next_symbol(encoder),
46        }
47    }
48
49    fn reset(&mut self) {
50        match self {
51            LzmaEncoderModes::Fast(a) => a.reset(),
52            LzmaEncoderModes::Normal(a) => a.reset(),
53        }
54    }
55}
56
/// LZMA symbol encoder: owns the probability models, the match finder and
/// the price tables consulted when symbols are selected.
pub(crate) struct LzmaEncoder {
    /// Probability models, coder state and the recent-distance array (`reps`).
    pub(crate) coder: LzmaCoder,
    /// Match finder / input window created by `new` for the requested `MfType`.
    pub(crate) lz: LzEncoder,
    /// Encodes and prices literal bytes.
    pub(crate) literal_encoder: LiteralEncoder,
    /// Encodes and prices the length of normal matches (used by `encode_match`).
    pub(crate) match_len_encoder: LengthEncoder,
    /// Encodes and prices the length of repeated matches (used by `encode_rep_match`).
    pub(crate) rep_len_encoder: LengthEncoder,
    /// Counters and cached price tables.
    pub(crate) data: LzmaEncData,
}
65
/// Bookkeeping counters and cached price tables of an `LzmaEncoder`.
pub(crate) struct LzmaEncData {
    /// The `nice_len` value passed to `LzmaEncoder::new`; not read in this
    /// file (presumably consumed by the mode implementations — confirm).
    pub(crate) nice_len: usize,
    /// Decremented once per encoded normal match; when it is `<= 0`,
    /// `update_prices` recomputes the distance prices.
    dist_price_count: i32,
    /// Decremented whenever a normal match encodes align bits; when it is
    /// `<= 0`, `update_prices` recomputes the align prices.
    align_price_count: i32,
    /// Number of distance slots priced per distance state:
    /// `get_dist_slot(dict_size - 1) + 1`.
    dist_slot_prices_size: u32,
    /// Slot prices indexed `[dist_state][dist_slot]`; for slots at or above
    /// `DIST_MODEL_END` the price of the direct bits is included.
    dist_slot_prices: Vec<Vec<u32>>,
    /// Complete distance prices indexed `[dist_state][dist]` for distances
    /// below `FULL_DISTANCES`.
    full_dist_prices: [[u32; FULL_DISTANCES]; DIST_STATES],
    /// Prices of the align bits, indexed by `dist & ALIGN_MASK`.
    align_prices: [u32; ALIGN_SIZE],
    /// Kind of the symbol last chosen by the mode: `-1` for a literal, a
    /// value below `REPS` for a repeated match, otherwise distance + `REPS`.
    pub(crate) back: i32,
    /// Raised by `find_matches`/`skip` and lowered by the length of every
    /// encoded symbol; `reset` sets it to `-1`.
    pub(crate) read_ahead: i32,
    /// Sum of the lengths of the symbols encoded since the counter was last
    /// cleared by `reset_uncompressed_size`.
    pub(crate) uncompressed_size: u32,
}
78
79impl LzmaEncoder {
80    pub(crate) fn get_dist_slot(dist: u32) -> u32 {
81        if dist <= DIST_MODEL_START as u32 {
82            return dist;
83        }
84        let mut n = dist;
85        let mut i = 31;
86
87        if (n & 0xFFFF0000) == 0 {
88            n <<= 16;
89            i = 15;
90        }
91
92        if (n & 0xFF000000) == 0 {
93            n <<= 8;
94            i -= 8;
95        }
96
97        if (n & 0xF0000000) == 0 {
98            n <<= 4;
99            i -= 4;
100        }
101
102        if (n & 0xC0000000) == 0 {
103            n <<= 2;
104            i -= 2;
105        }
106
107        if (n & 0x80000000) == 0 {
108            i -= 1;
109        }
110
111        (i << 1) + ((dist >> (i - 1)) & 1)
112    }
113
114    pub(crate) fn get_mem_usage(
115        mode: EncodeMode,
116        dict_size: u32,
117        extra_size_before: u32,
118        mf: MfType,
119    ) -> u32 {
120        let mut m = 80;
121        match mode {
122            EncodeMode::Fast => {
123                m += FastEncoderMode::get_memory_usage(dict_size, extra_size_before, mf);
124            }
125            EncodeMode::Normal => {
126                m += NormalEncoderMode::get_memory_usage(dict_size, extra_size_before, mf);
127            }
128        }
129        m
130    }
131}
132
133impl LzmaEncoder {
134    #[allow(clippy::too_many_arguments)]
135    pub(crate) fn new(
136        mode: EncodeMode,
137        lc: u32,
138        lp: u32,
139        pb: u32,
140        mf: MfType,
141        depth_limit: i32,
142        dict_size: u32,
143        nice_len: usize,
144    ) -> (Self, LzmaEncoderModes) {
145        let fast_mode = mode == EncodeMode::Fast;
146        let mut mode: LzmaEncoderModes = if fast_mode {
147            LzmaEncoderModes::Fast(FastEncoderMode::default())
148        } else {
149            LzmaEncoderModes::Normal(NormalEncoderMode::new())
150        };
151        let (extra_size_before, extra_size_after) = if fast_mode {
152            (
153                FastEncoderMode::EXTRA_SIZE_BEFORE,
154                FastEncoderMode::EXTRA_SIZE_AFTER,
155            )
156        } else {
157            (
158                NormalEncoderMode::EXTRA_SIZE_BEFORE,
159                NormalEncoderMode::EXTRA_SIZE_AFTER,
160            )
161        };
162        let lz = match mf {
163            MfType::Hc4 => LzEncoder::new_hc4(
164                dict_size,
165                extra_size_before,
166                extra_size_after,
167                nice_len as _,
168                MATCH_LEN_MAX as _,
169                depth_limit,
170            ),
171            MfType::Bt4 => LzEncoder::new_bt4(
172                dict_size,
173                extra_size_before,
174                extra_size_after,
175                nice_len as _,
176                MATCH_LEN_MAX as _,
177                depth_limit,
178            ),
179        };
180
181        let literal_encoder = LiteralEncoder::new(lc, lp);
182        let match_len_encoder = LengthEncoder::new(pb, nice_len);
183        let rep_len_encoder = LengthEncoder::new(pb, nice_len);
184        let dist_slot_price_size = LzmaEncoder::get_dist_slot(dict_size - 1) + 1;
185        let mut e = Self {
186            coder: LzmaCoder::new(pb as usize),
187            lz,
188            literal_encoder,
189            match_len_encoder,
190            rep_len_encoder,
191            data: LzmaEncData {
192                nice_len,
193                dist_price_count: 0,
194                align_price_count: 0,
195                dist_slot_prices_size: dist_slot_price_size,
196                dist_slot_prices: vec![vec![0; dist_slot_price_size as usize]; DIST_STATES],
197                full_dist_prices: [[0; FULL_DISTANCES]; DIST_STATES],
198                align_prices: Default::default(),
199                back: 0,
200                read_ahead: -1,
201                uncompressed_size: 0,
202            },
203        };
204        e.reset(&mut mode);
205
206        (e, mode)
207    }
208
209    pub(crate) fn reset(&mut self, mode: &mut dyn LzmaEncoderTrait) {
210        self.coder.reset();
211        self.literal_encoder.reset();
212        self.match_len_encoder.reset();
213        self.rep_len_encoder.reset();
214        self.data.dist_price_count = 0;
215        self.data.align_price_count = 0;
216        self.data.uncompressed_size += (self.data.read_ahead + 1) as u32;
217        self.data.read_ahead = -1;
218        mode.reset();
219    }
220
    /// Clears the running count of encoded bytes (`uncompressed_size`).
    #[inline(always)]
    pub(crate) fn reset_uncompressed_size(&mut self) {
        self.data.uncompressed_size = 0;
    }
225
226    #[allow(unused)]
227    pub(crate) fn encode_for_lzma1<W: Write>(
228        &mut self,
229        rc: &mut RangeEncoder<W>,
230        mode: &mut dyn LzmaEncoderTrait,
231    ) -> crate::Result<()> {
232        if !self.lz.is_started() && !self.encode_init(rc)? {
233            return Ok(());
234        }
235        while self.encode_symbol(rc, mode)? {}
236        Ok(())
237    }
238
239    #[allow(unused)]
240    pub(crate) fn encode_lzma1_end_marker<W: Write>(
241        &mut self,
242        rc: &mut RangeEncoder<W>,
243    ) -> crate::Result<()> {
244        let pos_state = (self.lz.get_pos() - self.data.read_ahead) as u32 & self.coder.pos_mask;
245        rc.encode_bit(
246            &mut self.coder.is_match[self.coder.state.get() as usize],
247            pos_state as usize,
248            1,
249        )?;
250        rc.encode_bit(&mut self.coder.is_rep, self.coder.state.get() as usize, 0)?;
251        self.encode_match(u32::MAX, MATCH_LEN_MIN as u32, pos_state, rc)?;
252        Ok(())
253    }
254
255    fn encode_init<W: Write>(&mut self, rc: &mut RangeEncoder<W>) -> crate::Result<bool> {
256        debug_assert_eq!(self.data.read_ahead, -1);
257        if !self.lz.has_enough_data(0) {
258            return Ok(false);
259        }
260        self.skip(1);
261        let state = self.coder.state.get() as usize;
262        rc.encode_bit(&mut self.coder.is_match[state], 0, 0)?;
263        self.literal_encoder
264            .encode_init(&self.lz, &self.data, &mut self.coder, rc)?;
265        self.data.read_ahead -= 1;
266        debug_assert_eq!(self.data.read_ahead, -1);
267        self.data.uncompressed_size += 1;
268        debug_assert_eq!(self.data.uncompressed_size, 1);
269        Ok(true)
270    }
271
272    fn encode_symbol<W: Write>(
273        &mut self,
274        rc: &mut RangeEncoder<W>,
275        mode: &mut dyn LzmaEncoderTrait,
276    ) -> crate::Result<bool> {
277        if !self.lz.has_enough_data(self.data.read_ahead + 1) {
278            return Ok(false);
279        }
280        let len = mode.get_next_symbol(self);
281
282        debug_assert!(self.data.read_ahead >= 0);
283        let pos_state = (self.lz.get_pos() - self.data.read_ahead) as u32 & self.coder.pos_mask;
284
285        if self.data.back == -1 {
286            debug_assert_eq!(len, 1);
287            let state = self.coder.state.get() as usize;
288            rc.encode_bit(&mut self.coder.is_match[state], pos_state as _, 0)?;
289            self.literal_encoder
290                .encode(&self.lz, &self.data, &mut self.coder, rc)?;
291        } else {
292            let state = self.coder.state.get() as usize;
293            rc.encode_bit(&mut self.coder.is_match[state], pos_state as usize, 1)?;
294            if self.data.back < REPS as i32 {
295                let state = self.coder.state.get() as usize;
296                rc.encode_bit(&mut self.coder.is_rep, state, 1)?;
297                self.encode_rep_match(self.data.back as u32, len, pos_state, rc)?;
298            } else {
299                let state = self.coder.state.get() as usize;
300                rc.encode_bit(&mut self.coder.is_rep, state, 0)?;
301                self.encode_match((self.data.back - REPS as i32) as u32, len, pos_state, rc)?;
302            }
303        }
304        self.data.read_ahead -= len as i32;
305        self.data.uncompressed_size += len;
306        Ok(true)
307    }
308
309    fn encode_match<W: Write>(
310        &mut self,
311        dist: u32,
312        len: u32,
313        pos_state: u32,
314        rc: &mut RangeEncoder<W>,
315    ) -> crate::Result<()> {
316        self.coder.state.update_match();
317        self.match_len_encoder.encode(len, pos_state, rc)?;
318        let dist_slot = LzmaEncoder::get_dist_slot(dist);
319        rc.encode_bit_tree(
320            &mut self.coder.dist_slots[get_dist_state(len) as usize],
321            dist_slot,
322        )?;
323
324        if dist_slot as usize >= DIST_MODEL_START {
325            let footer_bits = (dist_slot >> 1).wrapping_sub(1);
326            let base = (2 | (dist_slot & 1)) << footer_bits;
327            let dist_reduced = dist - base;
328
329            if dist_slot < DIST_MODEL_END as u32 {
330                rc.encode_reverse_bit_tree(
331                    self.coder
332                        .get_dist_special(dist_slot as usize - DIST_MODEL_START),
333                    dist_reduced,
334                )?;
335            } else {
336                rc.encode_direct_bits(dist_reduced >> ALIGN_BITS, footer_bits - ALIGN_BITS as u32)?;
337                rc.encode_reverse_bit_tree(
338                    &mut self.coder.dist_align,
339                    dist_reduced & ALIGN_MASK as u32,
340                )?;
341                self.data.align_price_count -= 1;
342            }
343        }
344
345        self.coder.reps[3] = self.coder.reps[2];
346        self.coder.reps[2] = self.coder.reps[1];
347        self.coder.reps[1] = self.coder.reps[0];
348        self.coder.reps[0] = dist as i32;
349
350        self.data.dist_price_count -= 1;
351        Ok(())
352    }
353
354    fn encode_rep_match<W: Write>(
355        &mut self,
356        rep: u32,
357        len: u32,
358        pos_state: u32,
359        rc: &mut RangeEncoder<W>,
360    ) -> crate::Result<()> {
361        if rep == 0 {
362            let state = self.coder.state.get() as usize;
363            rc.encode_bit(&mut self.coder.is_rep0, state, 0)?;
364            let state = self.coder.state.get() as usize;
365            rc.encode_bit(
366                &mut self.coder.is_rep0_long[state],
367                pos_state as usize,
368                if len == 1 { 0 } else { 1 },
369            )?;
370        } else {
371            let dist = self.coder.reps[rep as usize];
372            let state = self.coder.state.get() as usize;
373
374            rc.encode_bit(&mut self.coder.is_rep0, state, 1)?;
375
376            if rep == 1 {
377                let state = self.coder.state.get() as usize;
378                rc.encode_bit(&mut self.coder.is_rep1, state, 0)?;
379            } else {
380                let state = self.coder.state.get() as usize;
381                rc.encode_bit(&mut self.coder.is_rep1, state, 1)?;
382                let state = self.coder.state.get() as usize;
383                rc.encode_bit(&mut self.coder.is_rep2, state, rep - 2)?;
384
385                if rep == 3 {
386                    self.coder.reps[3] = self.coder.reps[2];
387                }
388                self.coder.reps[2] = self.coder.reps[1];
389            }
390
391            self.coder.reps[1] = self.coder.reps[0];
392            self.coder.reps[0] = dist;
393        }
394
395        if len == 1 {
396            self.coder.state.update_short_rep();
397        } else {
398            self.rep_len_encoder.encode(len, pos_state, rc)?;
399            self.coder.state.update_long_rep();
400        }
401        Ok(())
402    }
403
    /// Runs the match finder once and counts that step in `read_ahead`.
    ///
    /// In debug builds the result is additionally checked with
    /// `verify_matches`.
    pub(crate) fn find_matches(&mut self) {
        self.data.read_ahead += 1;
        self.lz.find_matches();
        debug_assert!(self.lz.verify_matches());
    }
409
    /// Makes the match finder skip `len` bytes and adds `len` to
    /// `read_ahead`.
    pub(crate) fn skip(&mut self, len: usize) {
        // NOTE(review): `len as i32` truncates silently for values above
        // `i32::MAX`; callers presumably only pass small lengths — confirm.
        self.data.read_ahead += len as i32;
        self.lz.skip(len)
    }
414
415    pub(crate) fn get_any_match_price(&self, state: &State, pos_state: u32) -> u32 {
416        RangeEncoder::get_bit_price(
417            self.coder.is_match[state.get() as usize][pos_state as usize] as _,
418            1,
419        )
420    }
421
422    pub(crate) fn get_normal_match_price(&self, any_match_price: u32, state: &State) -> u32 {
423        any_match_price
424            + RangeEncoder::get_bit_price(self.coder.is_rep[state.get() as usize] as _, 0)
425    }
426
427    pub(crate) fn get_any_rep_price(&self, any_match_price: u32, state: &State) -> u32 {
428        any_match_price
429            + RangeEncoder::get_bit_price(self.coder.is_rep[state.get() as usize] as _, 1)
430    }
431
432    pub(crate) fn get_short_rep_price(
433        &self,
434        any_rep_price: u32,
435        state: &State,
436        pos_state: u32,
437    ) -> u32 {
438        any_rep_price
439            + RangeEncoder::get_bit_price(self.coder.is_rep0[state.get() as usize] as _, 0)
440            + RangeEncoder::get_bit_price(
441                self.coder.is_rep0_long[state.get() as usize][pos_state as usize] as _,
442                0,
443            )
444    }
445
446    pub(crate) fn get_long_rep_price(
447        &self,
448        any_rep_price: u32,
449        rep: u32,
450        state: &State,
451        pos_state: u32,
452    ) -> u32 {
453        let mut price = any_rep_price;
454
455        if rep == 0 {
456            price += RangeEncoder::get_bit_price(self.coder.is_rep0[state.get() as usize] as _, 0)
457                + RangeEncoder::get_bit_price(
458                    self.coder.is_rep0_long[state.get() as usize][pos_state as usize] as _,
459                    1,
460                );
461        } else {
462            price += RangeEncoder::get_bit_price(self.coder.is_rep0[state.get() as usize] as _, 1);
463
464            if rep == 1 {
465                price +=
466                    RangeEncoder::get_bit_price(self.coder.is_rep1[state.get() as usize] as _, 0)
467            } else {
468                price +=
469                    RangeEncoder::get_bit_price(self.coder.is_rep1[state.get() as usize] as _, 1)
470                        + RangeEncoder::get_bit_price(
471                            self.coder.is_rep2[state.get() as usize] as _,
472                            rep as i32 - 2,
473                        );
474            }
475        }
476
477        price
478    }
479
480    pub(crate) fn get_long_rep_and_len_price(
481        &self,
482        rep: u32,
483        len: u32,
484        state: &State,
485        pos_state: u32,
486    ) -> u32 {
487        let any_match_price = self.get_any_match_price(state, pos_state);
488        let any_rep_price = self.get_any_rep_price(any_match_price, state);
489        let long_rep_price = self.get_long_rep_price(any_rep_price, rep, state, pos_state);
490        long_rep_price + self.rep_len_encoder.get_price(len as _, pos_state as _)
491    }
492
493    pub(crate) fn get_match_and_len_price(
494        &self,
495        normal_match_price: u32,
496        dist: u32,
497        len: u32,
498        pos_state: u32,
499    ) -> u32 {
500        let mut price =
501            normal_match_price + self.match_len_encoder.get_price(len as _, pos_state as _);
502        let dist_state = get_dist_state(len);
503
504        if dist < FULL_DISTANCES as u32 {
505            price += self.data.full_dist_prices[dist_state as usize][dist as usize];
506        } else {
507            // Note that distSlotPrices includes also
508            // the price of direct bits.
509            let dist_slot = LzmaEncoder::get_dist_slot(dist);
510            price += self.data.dist_slot_prices[dist_state as usize][dist_slot as usize]
511                + self.data.align_prices[(dist & ALIGN_MASK as u32) as usize];
512        }
513
514        price
515    }
516
517    pub(crate) fn update_dist_prices(&mut self) {
518        self.data.dist_price_count = DIST_PRICE_UPDATE_INTERVAL as _;
519
520        for dist_state in 0..DIST_STATES {
521            for dist_slot in 0..self.data.dist_slot_prices_size as usize {
522                self.data.dist_slot_prices[dist_state][dist_slot] =
523                    RangeEncoder::get_bit_tree_price(
524                        &mut self.coder.dist_slots[dist_state],
525                        dist_slot as u32,
526                    );
527            }
528
529            for dist_slot in DIST_MODEL_END as u32..self.data.dist_slot_prices_size {
530                let count = (dist_slot >> 1) - 1 - ALIGN_BITS as u32;
531                self.data.dist_slot_prices[dist_state][dist_slot as usize] +=
532                    RangeEncoder::get_direct_bits_price(count);
533            }
534
535            for dist in 0..DIST_MODEL_START {
536                self.data.full_dist_prices[dist_state][dist] =
537                    self.data.dist_slot_prices[dist_state][dist];
538            }
539        }
540
541        let mut dist = DIST_MODEL_START;
542        for dist_slot in DIST_MODEL_START..DIST_MODEL_END {
543            let footer_bits = (dist_slot >> 1) - 1;
544            let base = (2 | (dist_slot & 1)) << footer_bits;
545
546            let limit = self
547                .coder
548                .get_dist_special(dist_slot - DIST_MODEL_START)
549                .len();
550            for _i in 0..limit {
551                let dist_reduced = dist - base;
552                let price = RangeEncoder::get_reverse_bit_tree_price(
553                    self.coder.get_dist_special(dist_slot - DIST_MODEL_START),
554                    dist_reduced as u32,
555                );
556
557                for dist_state in 0..DIST_STATES {
558                    self.data.full_dist_prices[dist_state][dist] =
559                        self.data.dist_slot_prices[dist_state][dist_slot] + price;
560                }
561                dist += 1;
562            }
563        }
564
565        debug_assert_eq!(dist, FULL_DISTANCES);
566    }
567
568    fn update_align_prices(&mut self) {
569        self.data.align_price_count = ALIGN_PRICE_UPDATE_INTERVAL as i32;
570
571        for i in 0..ALIGN_SIZE {
572            self.data.align_prices[i] =
573                RangeEncoder::get_reverse_bit_tree_price(&mut self.coder.dist_align, i as u32);
574        }
575    }
576
577    pub(crate) fn update_prices(&mut self) {
578        if self.data.dist_price_count <= 0 {
579            self.update_dist_prices();
580        }
581
582        if self.data.align_price_count <= 0 {
583            self.update_align_prices();
584        }
585        self.match_len_encoder.update_prices();
586        self.rep_len_encoder.update_prices();
587    }
588}
589
590impl LzmaEncoder {
591    pub fn encode_for_lzma2(
592        &mut self,
593        rc: &mut RangeEncoder<RangeEncoderBuffer>,
594        mode: &mut dyn LzmaEncoderTrait,
595    ) -> crate::Result<bool> {
596        if !self.lz.is_started() && !self.encode_init(rc)? {
597            return Ok(false);
598        }
599        while self.data.uncompressed_size <= LZMA2_UNCOMPRESSED_LIMIT
600            && rc.get_pending_size() <= LZMA2_COMPRESSED_LIMIT
601        {
602            if !self.encode_symbol(rc, mode)? {
603                return Ok(false);
604            }
605        }
606        Ok(true)
607    }
608}
609
/// Encodes and prices literal bytes, choosing one of `1 << (lc + lp)`
/// sub-encoders per literal.
pub(crate) struct LiteralEncoder {
    /// Maps a (previous byte, position) pair to a sub-encoder index.
    coder: LiteralCoder,
    /// One probability set per sub-encoder index.
    sub_encoders: Vec<LiteralSubEncoder>,
}
614
/// A single set of literal probabilities together with the routines that
/// encode and price a byte against it.
#[derive(Clone)]
struct LiteralSubEncoder {
    /// Holds the `probs` array used for every literal bit.
    coder: LiteralSubCoder,
}
619
620impl LiteralEncoder {
621    pub(crate) fn new(lc: u32, lp: u32) -> Self {
622        Self {
623            coder: LiteralCoder::new(lc, lp),
624            sub_encoders: vec![LiteralSubEncoder::new(); 1 << (lc + lp)],
625        }
626    }
627
628    pub(crate) fn reset(&mut self) {
629        for ele in self.sub_encoders.iter_mut() {
630            ele.reset();
631        }
632    }
633
634    pub(crate) fn encode_init<W: Write>(
635        &mut self,
636        lz: &LzEncoder,
637        data: &LzmaEncData,
638        coder: &mut LzmaCoder,
639        rc: &mut RangeEncoder<W>,
640    ) -> crate::Result<()> {
641        debug_assert!(data.read_ahead >= 0);
642        self.sub_encoders[0].encode(lz, data, coder, rc)
643    }
644
645    pub(crate) fn encode<W: Write>(
646        &mut self,
647        lz: &LzEncoder,
648        data: &LzmaEncData,
649        coder: &mut LzmaCoder,
650        rc: &mut RangeEncoder<W>,
651    ) -> crate::Result<()> {
652        debug_assert!(data.read_ahead >= 0);
653        let i = self.coder.get_sub_coder_index(
654            lz.get_byte_backward(1 + data.read_ahead) as _,
655            (lz.get_pos() - data.read_ahead) as u32,
656        );
657        self.sub_encoders[i as usize].encode(lz, data, coder, rc)
658    }
659
660    pub(crate) fn get_price(
661        &self,
662        encoder: &LzmaEncoder,
663        cur_byte: u32,
664        match_byte: u32,
665        prev_byte: u32,
666        pos: u32,
667        state: &State,
668    ) -> u32 {
669        let mut price = RangeEncoder::get_bit_price(
670            encoder.coder.is_match[state.get() as usize][(pos & encoder.coder.pos_mask) as usize]
671                as _,
672            0,
673        );
674        let i = self.coder.get_sub_coder_index(prev_byte, pos) as usize;
675        price += if state.is_literal() {
676            self.sub_encoders[i].get_normal_price(cur_byte)
677        } else {
678            self.sub_encoders[i].get_matched_price(cur_byte, match_byte)
679        };
680        price
681    }
682}
683
684impl LiteralSubEncoder {
685    fn new() -> Self {
686        Self {
687            coder: LiteralSubCoder::new(),
688        }
689    }
690
    /// Resets the wrapped `LiteralSubCoder`.
    fn reset(&mut self) {
        self.coder.reset()
    }
694
695    fn encode<W: Write>(
696        &mut self,
697        lz: &LzEncoder,
698        data: &LzmaEncData,
699        coder: &mut LzmaCoder,
700        rc: &mut RangeEncoder<W>,
701    ) -> crate::Result<()> {
702        let mut symbol = lz.get_byte_backward(data.read_ahead) as u32 | 0x100;
703
704        if coder.state.is_literal() {
705            let mut subencoder_index;
706            let mut bit;
707
708            loop {
709                subencoder_index = symbol >> 8;
710                bit = (symbol >> 7) & 1;
711                rc.encode_bit(&mut self.coder.probs, subencoder_index as _, bit as _)?;
712                symbol <<= 1;
713                if symbol >= 0x10000 {
714                    break;
715                }
716            }
717        } else {
718            let mut match_byte = lz.get_byte_backward(coder.reps[0] + 1 + data.read_ahead) as u32;
719            let mut offset = 0x100;
720            let mut subencoder_index;
721            let mut match_bit;
722            let mut bit;
723
724            loop {
725                match_byte <<= 1;
726                match_bit = match_byte & offset;
727                subencoder_index = offset + match_bit + (symbol >> 8);
728                bit = (symbol >> 7) & 1;
729                rc.encode_bit(&mut self.coder.probs, subencoder_index as _, bit)?;
730                symbol <<= 1;
731                offset &= !(match_byte ^ symbol);
732                if symbol >= 0x10000 {
733                    break;
734                }
735            }
736        }
737
738        coder.state.update_literal();
739        Ok(())
740    }
741
742    fn get_normal_price(&self, symbol: u32) -> u32 {
743        let mut price: u32 = 0;
744        let mut subencoder_index;
745        let mut bit;
746        let mut symbol = symbol | 0x100;
747        loop {
748            subencoder_index = symbol >> 8;
749            bit = (symbol >> 7) & 1;
750            price += RangeEncoder::get_bit_price(
751                self.coder.probs[subencoder_index as usize] as _,
752                bit as _,
753            );
754            symbol <<= 1;
755            if symbol >= (0x100 << 8) {
756                break;
757            }
758        }
759        price
760    }
761
762    fn get_matched_price(&self, symbol: u32, mut match_byte: u32) -> u32 {
763        let mut price = 0;
764        let mut offset = 0x100;
765        let mut subencoder_index;
766        let mut match_bit;
767        let mut bit;
768        let mut symbol = symbol | 0x100;
769        loop {
770            match_byte <<= 1;
771            match_bit = match_byte & offset;
772            subencoder_index = offset + match_bit + (symbol >> 8);
773            bit = (symbol >> 7) & 1;
774            price += RangeEncoder::get_bit_price(
775                self.coder.probs[subencoder_index as usize] as _,
776                bit as _,
777            );
778            symbol <<= 1;
779            offset &= !(match_byte ^ symbol);
780            if symbol >= (0x100 << 8) {
781                break;
782            }
783        }
784        price
785    }
786}
787
/// Encodes match lengths and caches their prices per position state.
pub(crate) struct LengthEncoder {
    /// Probability models for the choice bits and the low/mid/high trees.
    coder: LengthCoder,
    /// Per position state: decremented on every encoded length; when it is
    /// `<= 0`, `update_prices` rebuilds that state's price row.
    counters: Vec<i32>,
    /// Cached prices indexed `[pos_state][len - MATCH_LEN_MIN]`.
    prices: Vec<Vec<u32>>,
}
793
794impl LengthEncoder {
795    pub(crate) fn new(pb: u32, nice_len: usize) -> Self {
796        let pos_states = 1usize << pb;
797        let counters = vec![0; pos_states];
798        let len_symbols = (nice_len - MATCH_LEN_MIN + 1).max(LOW_SYMBOLS + MID_SYMBOLS);
799        let prices = vec![vec![0; len_symbols]; pos_states];
800        Self {
801            coder: LengthCoder::new(),
802            counters,
803            prices,
804        }
805    }
806
807    fn reset(&mut self) {
808        self.coder.reset();
809        self.counters.fill(0);
810    }
811
812    fn encode<W: Write>(
813        &mut self,
814        len: u32,
815        pos_state: u32,
816        rc: &mut RangeEncoder<W>,
817    ) -> crate::Result<()> {
818        let mut len = len as usize - MATCH_LEN_MIN;
819        if len < LOW_SYMBOLS {
820            rc.encode_bit(&mut self.coder.choice, 0, 0)?;
821            rc.encode_bit_tree(&mut self.coder.low[pos_state as usize], len as _)?;
822        } else {
823            rc.encode_bit(&mut self.coder.choice, 0, 1)?;
824            len -= LOW_SYMBOLS;
825            if len < MID_SYMBOLS {
826                rc.encode_bit(&mut self.coder.choice, 1, 0)?;
827                rc.encode_bit_tree(&mut self.coder.mid[pos_state as usize], len as _)?;
828            } else {
829                rc.encode_bit(&mut self.coder.choice, 1, 1)?;
830                rc.encode_bit_tree(&mut self.coder.high, (len - MID_SYMBOLS) as _)?;
831            }
832        }
833        self.counters[pos_state as usize] = self.counters[pos_state as usize].wrapping_sub(1);
834        Ok(())
835    }
836
    /// Returns the cached price of length `len` for `pos_state`.
    ///
    /// The row is indexed by `len - MATCH_LEN_MIN`, so `len` must be at
    /// least `MATCH_LEN_MIN` and within the row created by `new`.
    pub(crate) fn get_price(&self, len: usize, pos_state: usize) -> u32 {
        self.prices[pos_state][len - MATCH_LEN_MIN]
    }
840
841    fn update_prices(&mut self) {
842        for pos_state in 0..self.counters.len() {
843            if self.counters[pos_state] <= 0 {
844                self.counters[pos_state] = PRICE_UPDATE_INTERVAL as _;
845                self.update_prices_with_state(pos_state);
846            }
847        }
848    }
849
850    fn update_prices_with_state(&mut self, pos_state: usize) {
851        let mut choice0_price = RangeEncoder::get_bit_price(self.coder.choice[0] as _, 0);
852        let mut start = 0;
853        for i in start..LOW_SYMBOLS {
854            self.prices[pos_state][i] = choice0_price
855                + RangeEncoder::get_bit_tree_price(&mut self.coder.low[pos_state], i as _);
856        }
857        start = LOW_SYMBOLS;
858        choice0_price = RangeEncoder::get_bit_price(self.coder.choice[0] as _, 1);
859        let mut choice1_price = RangeEncoder::get_bit_price(self.coder.choice[1] as _, 0);
860        for i in start..(LOW_SYMBOLS + MID_SYMBOLS) {
861            self.prices[pos_state][i] = choice0_price
862                + choice1_price
863                + RangeEncoder::get_bit_tree_price(
864                    &mut self.coder.mid[pos_state],
865                    (i - start) as u32,
866                );
867        }
868        start = LOW_SYMBOLS + MID_SYMBOLS;
869        choice1_price = RangeEncoder::get_bit_price(self.coder.choice[1] as _, 1);
870        for i in start..self.prices[pos_state].len() {
871            self.prices[pos_state][i] = choice0_price
872                + choice1_price
873                + RangeEncoder::get_bit_tree_price(&mut self.coder.high, (i - start) as u32)
874        }
875    }
876}