lzma_rust2/lz/
lz_encoder.rs

1use alloc::{vec, vec::Vec};
2use core::ops::Deref;
3
4use super::{bt4::Bt4, extend_match, hc4::Hc4};
5use crate::Write;
6
/// Alignment, in bytes, of the offset by which the window is moved (a 64-byte cache line).
/// `move_window` rounds its move offset down to a multiple of this value.
const MOVE_BLOCK_ALIGN: i32 = 64;
/// Mask that rounds an `i32` down to a multiple of `MOVE_BLOCK_ALIGN`.
/// For a power-of-two alignment, the two's-complement negation equals `!(align - 1)`.
const MOVE_BLOCK_ALIGN_MASK: i32 = -MOVE_BLOCK_ALIGN;
10
11pub(crate) trait MatchFind {
12    fn find_matches(&mut self, encoder: &mut LzEncoderData, matches: &mut Matches);
13    fn skip(&mut self, encoder: &mut LzEncoderData, len: usize);
14}
15
16pub(crate) enum MatchFinders {
17    Hc4(Hc4),
18    Bt4(Bt4),
19}
20
21impl MatchFind for MatchFinders {
22    fn find_matches(&mut self, encoder: &mut LzEncoderData, matches: &mut Matches) {
23        match self {
24            MatchFinders::Hc4(m) => m.find_matches(encoder, matches),
25            MatchFinders::Bt4(m) => m.find_matches(encoder, matches),
26        }
27    }
28
29    fn skip(&mut self, encoder: &mut LzEncoderData, len: usize) {
30        match self {
31            MatchFinders::Hc4(m) => m.skip(encoder, len),
32            MatchFinders::Bt4(m) => m.skip(encoder, len),
33        }
34    }
35}
36
/// Match finders to use when encoding.
///
/// Defaults to [`MfType::Hc4`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MfType {
    /// Hash chain for 4-byte entries (lower quality but faster).
    #[default]
    Hc4,
    /// Binary tree for 4-byte entries (higher quality but slower).
    Bt4,
}
51
52impl MfType {
53    #[inline]
54    fn get_memory_usage(self, dict_size: u32) -> u32 {
55        match self {
56            MfType::Hc4 => Hc4::get_mem_usage(dict_size),
57            MfType::Bt4 => Bt4::get_mem_usage(dict_size),
58        }
59    }
60}
61
62pub(crate) struct LzEncoder {
63    pub(crate) data: LzEncoderData,
64    pub(crate) matches: Matches,
65    pub(crate) match_finder: MatchFinders,
66}
67
/// Sliding-window buffer plus the cursors the encoder and match finder share.
pub(crate) struct LzEncoderData {
    /// History retained before `read_pos + 1` when the window moves:
    /// `extra_size_before + dict_size`.
    pub(crate) keep_size_before: u32,
    /// Lookahead required after the read position: `extra_size_after + match_len_max`.
    pub(crate) keep_size_after: u32,
    /// Upper bound on match lengths.
    pub(crate) match_len_max: u32,
    /// "Nice" match length passed to the encoder at construction.
    pub(crate) nice_len: u32,
    /// The window storage itself.
    pub(crate) buf: Vec<u8>,
    /// Length of `buf`.
    pub(crate) buf_size: usize,
    /// `buf_size - 2`: the last index at which a two-byte (`u16`) read still fits in `buf`.
    pub(crate) buf_limit_u16: usize,
    /// Current encoding position; `-1` until encoding has started.
    pub(crate) read_pos: i32,
    /// Positions below this value have enough lookahead to be encoded; `-1` initially.
    pub(crate) read_limit: i32,
    /// Set once the input has ended (see `set_finishing`).
    pub(crate) finishing: bool,
    /// One past the last byte written into `buf`.
    pub(crate) write_pos: i32,
    /// Bytes passed by `move_pos` that have not yet been run through the match finder.
    pub(crate) pending_size: u32,
}
82
/// Fixed-capacity list of matches produced by a match finder.
pub(crate) struct Matches {
    /// Match lengths; only the first `count` entries are meaningful.
    pub(crate) len: Vec<u32>,
    /// Match distances minus one (`verify_matches` adds 1 back); parallel to `len`.
    pub(crate) dist: Vec<i32>,
    /// Number of valid entries in `len` / `dist`.
    pub(crate) count: u32,
}

impl Matches {
    /// Creates an empty list with room for `count_max` matches.
    /// Both arrays are zero-filled.
    pub(crate) fn new(count_max: usize) -> Self {
        let len = vec![0u32; count_max];
        let dist = vec![0i32; count_max];
        Self { len, dist, count: 0 }
    }
}
98
99impl LzEncoder {
100    pub(crate) fn get_memory_usage(
101        dict_size: u32,
102        extra_size_before: u32,
103        extra_size_after: u32,
104        match_len_max: u32,
105        mf: MfType,
106    ) -> u32 {
107        get_buf_size(
108            dict_size,
109            extra_size_before,
110            extra_size_after,
111            match_len_max,
112        ) + mf.get_memory_usage(dict_size)
113    }
114
115    pub(crate) fn new_hc4(
116        dict_size: u32,
117        extra_size_before: u32,
118        extra_size_after: u32,
119        nice_len: u32,
120        match_len_max: u32,
121        depth_limit: i32,
122    ) -> Self {
123        Self::new(
124            dict_size,
125            extra_size_before,
126            extra_size_after,
127            nice_len,
128            match_len_max,
129            MatchFinders::Hc4(Hc4::new(dict_size, nice_len, depth_limit)),
130        )
131    }
132
133    pub(crate) fn new_bt4(
134        dict_size: u32,
135        extra_size_before: u32,
136        extra_size_after: u32,
137        nice_len: u32,
138        match_len_max: u32,
139        depth_limit: i32,
140    ) -> Self {
141        Self::new(
142            dict_size,
143            extra_size_before,
144            extra_size_after,
145            nice_len,
146            match_len_max,
147            MatchFinders::Bt4(Bt4::new(dict_size, nice_len, depth_limit)),
148        )
149    }
150
151    fn new(
152        dict_size: u32,
153        extra_size_before: u32,
154        extra_size_after: u32,
155        nice_len: u32,
156        match_len_max: u32,
157        match_finder: MatchFinders,
158    ) -> Self {
159        let buf_size = get_buf_size(
160            dict_size,
161            extra_size_before,
162            extra_size_after,
163            match_len_max,
164        );
165        let buf_size = buf_size as usize;
166        let buf = vec![0; buf_size];
167        let buf_limit_u16 = buf_size.checked_sub(size_of::<u16>()).unwrap();
168
169        let keep_size_before = extra_size_before + dict_size;
170        let keep_size_after = extra_size_after + match_len_max;
171
172        Self {
173            data: LzEncoderData {
174                keep_size_before,
175                keep_size_after,
176                match_len_max,
177                nice_len,
178                buf,
179                buf_size,
180                buf_limit_u16,
181                read_pos: -1,
182                read_limit: -1,
183                finishing: false,
184                write_pos: 0,
185                pending_size: 0,
186            },
187            matches: Matches::new(nice_len as usize - 1),
188            match_finder,
189        }
190    }
191
192    pub(crate) fn normalize(positions: &mut [i32], norm_offset: i32) {
193        #[cfg(all(feature = "std", feature = "optimization", target_arch = "x86_64"))]
194        {
195            if std::arch::is_x86_feature_detected!("avx2") {
196                // SAFETY: We've checked that the CPU supports AVX2.
197                return unsafe { normalize_avx2(positions, norm_offset) };
198            }
199            if std::arch::is_x86_feature_detected!("sse4.1") {
200                // SAFETY: We've checked that the CPU supports SSE4.1.
201                return unsafe { normalize_sse41(positions, norm_offset) };
202            }
203        }
204
205        #[cfg(all(feature = "std", feature = "optimization", target_arch = "aarch64"))]
206        {
207            if std::arch::is_aarch64_feature_detected!("neon") {
208                // SAFETY: We've checked that the CPU supports NEON.
209                return unsafe { normalize_neon(positions, norm_offset) };
210            }
211        }
212
213        normalize_scalar(positions, norm_offset);
214    }
215
216    pub(crate) fn find_matches(&mut self) {
217        self.match_finder
218            .find_matches(&mut self.data, &mut self.matches)
219    }
220
221    pub(crate) fn matches(&mut self) -> &mut Matches {
222        &mut self.matches
223    }
224
225    pub(crate) fn skip(&mut self, len: usize) {
226        self.match_finder.skip(&mut self.data, len)
227    }
228
229    pub(crate) fn set_preset_dict(&mut self, dict_size: u32, preset_dict: &[u8]) {
230        self.data
231            .set_preset_dict(dict_size, preset_dict, &mut self.match_finder)
232    }
233
234    pub(crate) fn set_finishing(&mut self) {
235        self.data.set_finishing(&mut self.match_finder)
236    }
237
238    pub(crate) fn fill_window(&mut self, input: &[u8]) -> usize {
239        self.data.fill_window(input, &mut self.match_finder)
240    }
241
242    pub(crate) fn set_flushing(&mut self) {
243        self.data.set_flushing(&mut self.match_finder)
244    }
245
246    pub(crate) fn verify_matches(&self) -> bool {
247        self.data.verify_matches(&self.matches)
248    }
249}
250
251impl LzEncoderData {
252    pub(crate) fn is_started(&self) -> bool {
253        self.read_pos != -1
254    }
255
256    pub(crate) fn read_buffer(&self) -> &[u8] {
257        &self.buf[self.read_pos as usize..]
258    }
259
260    fn set_preset_dict(
261        &mut self,
262        dict_size: u32,
263        preset_dict: &[u8],
264        match_finder: &mut dyn MatchFind,
265    ) {
266        debug_assert!(!self.is_started());
267        debug_assert_eq!(self.write_pos, 0);
268        let copy_size = preset_dict.len().min(dict_size as usize);
269        let offset = preset_dict.len() - copy_size;
270        self.buf[0..copy_size].copy_from_slice(&preset_dict[offset..(offset + copy_size)]);
271        self.write_pos += copy_size as i32;
272        match_finder.skip(self, copy_size);
273    }
274
275    fn move_window(&mut self) {
276        let move_offset =
277            (self.read_pos + 1 - self.keep_size_before as i32) & MOVE_BLOCK_ALIGN_MASK;
278        let move_size = self.write_pos - move_offset;
279
280        debug_assert!(move_size >= 0);
281        debug_assert!(move_offset >= 0);
282
283        let move_size = move_size as usize;
284        let offset = move_offset as usize;
285
286        self.buf.copy_within(offset..offset + move_size, 0);
287
288        self.read_pos -= move_offset;
289        self.read_limit -= move_offset;
290        self.write_pos -= move_offset;
291    }
292
293    fn fill_window(&mut self, input: &[u8], match_finder: &mut dyn MatchFind) -> usize {
294        debug_assert!(!self.finishing);
295        if self.read_pos >= (self.buf_size as i32 - self.keep_size_after as i32) {
296            self.move_window();
297        }
298        let len = if input.len() as i32 > self.buf_size as i32 - self.write_pos {
299            (self.buf_size as i32 - self.write_pos) as usize
300        } else {
301            input.len()
302        };
303        let d_start = self.write_pos as usize;
304        let d_end = d_start + len;
305        self.buf[d_start..d_end].copy_from_slice(&input[..len]);
306        self.write_pos += len as i32;
307        if self.write_pos >= self.keep_size_after as i32 {
308            self.read_limit = self.write_pos - self.keep_size_after as i32;
309        }
310        self.process_pending_bytes(match_finder);
311        len
312    }
313
314    fn process_pending_bytes(&mut self, match_finder: &mut dyn MatchFind) {
315        if self.pending_size > 0 && self.read_pos < self.read_limit {
316            self.read_pos -= self.pending_size as i32;
317            let old_pending = self.pending_size;
318            self.pending_size = 0;
319            match_finder.skip(self, old_pending as _);
320            debug_assert!(self.pending_size < old_pending)
321        }
322    }
323
324    fn set_flushing(&mut self, match_finder: &mut dyn MatchFind) {
325        self.read_limit = self.write_pos - 1;
326        self.process_pending_bytes(match_finder);
327    }
328
329    fn set_finishing(&mut self, match_finder: &mut dyn MatchFind) {
330        self.read_limit = self.write_pos - 1;
331        self.finishing = true;
332        self.process_pending_bytes(match_finder);
333    }
334
335    pub fn has_enough_data(&self, already_read_len: i32) -> bool {
336        self.read_pos - already_read_len < self.read_limit
337    }
338
339    pub(crate) fn copy_uncompressed<W: Write>(
340        &self,
341        out: &mut W,
342        backward: i32,
343        len: usize,
344    ) -> crate::Result<()> {
345        let start = (self.read_pos + 1 - backward) as usize;
346        out.write_all(&self.buf[start..(start + len)])
347    }
348
349    #[inline(always)]
350    pub(crate) fn get_avail(&self) -> i32 {
351        debug_assert_ne!(self.read_pos, -1);
352        self.write_pos - self.read_pos
353    }
354
355    #[inline(always)]
356    pub(crate) fn get_pos(&self) -> i32 {
357        self.read_pos
358    }
359
360    #[inline(always)]
361    pub(crate) fn get_byte(&self, forward: i32, backward: i32) -> u8 {
362        self.buf[(self.read_pos + forward - backward) as usize]
363    }
364
365    #[inline(always)]
366    pub(crate) fn get_byte_by_pos(&self, pos: i32) -> u8 {
367        self.buf[pos as usize]
368    }
369
370    #[inline(always)]
371    pub(crate) fn get_byte_backward(&self, backward: i32) -> u8 {
372        self.buf[(self.read_pos - backward) as usize]
373    }
374
375    #[inline(always)]
376    pub(crate) fn get_current_byte(&self) -> u8 {
377        self.buf[self.read_pos as usize]
378    }
379
380    #[inline(always)]
381    pub(crate) fn get_match_len(&self, dist: i32, len_limit: i32) -> usize {
382        extend_match(&self.buf, self.read_pos, 0, dist + 1, len_limit) as usize
383    }
384
385    #[inline(always)]
386    pub(crate) fn get_match_len2(&self, forward: i32, dist: i32, len_limit: i32) -> u32 {
387        if len_limit <= 0 {
388            return 0;
389        }
390        extend_match(&self.buf, self.read_pos + forward, 0, dist + 1, len_limit) as u32
391    }
392
393    #[inline(always)]
394    pub(crate) fn get_match_len_fast_reject<const MATCH_LEN_MIN: usize>(
395        &self,
396        dist: i32,
397        len_limit: i32,
398    ) -> usize {
399        let match_dist = dist + 1;
400        let read_pos = self.read_pos as usize;
401
402        // Fast rejection
403        #[cfg(feature = "optimization")]
404        unsafe {
405            // SAFETY: We clamp the read positions in range of the buffer.
406            let clamped0 = read_pos.min(self.buf_limit_u16);
407            let clamped1 = (read_pos - match_dist as usize).min(self.buf_limit_u16);
408
409            if core::ptr::read_unaligned(self.buf.as_ptr().add(clamped0) as *const u16)
410                != core::ptr::read_unaligned(self.buf.as_ptr().add(clamped1) as *const u16)
411            {
412                return 0;
413            }
414        }
415        #[cfg(not(feature = "optimization"))]
416        if self.buf[read_pos] != self.buf[read_pos - match_dist as usize]
417            || self.buf[read_pos + 1] != self.buf[read_pos + 1 - match_dist as usize]
418        {
419            return 0;
420        }
421
422        extend_match(&self.buf, self.read_pos, 2, match_dist, len_limit) as usize
423    }
424
425    fn verify_matches(&self, matches: &Matches) -> bool {
426        let len_limit = self.get_avail().min(self.match_len_max as i32);
427
428        for i in 0..matches.count as usize {
429            let match_distance = matches.dist[i] + 1;
430            let actual_len = extend_match(&self.buf, self.read_pos, 0, match_distance, len_limit);
431
432            if actual_len as u32 != matches.len[i] {
433                return false;
434            }
435        }
436
437        true
438    }
439
440    pub(crate) fn move_pos(
441        &mut self,
442        required_for_flushing: i32,
443        required_for_finishing: i32,
444    ) -> i32 {
445        debug_assert!(required_for_flushing >= required_for_finishing);
446        self.read_pos += 1;
447        let mut avail = self.write_pos - self.read_pos;
448        if avail < required_for_flushing && (avail < required_for_finishing || !self.finishing) {
449            self.pending_size += 1;
450            avail = 0;
451        }
452        avail
453    }
454}
455
456impl Deref for LzEncoder {
457    type Target = LzEncoderData;
458
459    fn deref(&self) -> &Self::Target {
460        &self.data
461    }
462}
463
/// Computes the length in bytes of the LZ window buffer.
///
/// The buffer is the sum of three parts:
/// - history: `dict_size + extra_size_before`;
/// - lookahead: `match_len_max + extra_size_after`;
/// - a reserve of half the dictionary plus 256 KiB, capped at 512 MiB. The reserve is
///   presumably there to reduce how often `move_window` has to run.
fn get_buf_size(
    dict_size: u32,
    extra_size_before: u32,
    extra_size_after: u32,
    match_len_max: u32,
) -> u32 {
    const RESERVE_BASE: u32 = 256 << 10;
    const RESERVE_CAP: u32 = 512 << 20;

    let history = extra_size_before + dict_size;
    let lookahead = extra_size_after + match_len_max;
    let reserve = (dict_size / 2 + RESERVE_BASE).min(RESERVE_CAP);
    history + lookahead + reserve
}
475
/// Portable normalization: replaces each position `p` with `max(p, norm_offset) - norm_offset`.
///
/// Positions at or below `norm_offset` become `0`. This matches the per-lane operation of
/// the AVX2, SSE4.1 and NEON kernels, which also use this function for their unaligned
/// head and tail. `i32::saturating_sub` would be wrong here: it saturates at `i32::MIN`,
/// not at zero, so the result would depend on which code path ran.
#[inline(always)]
fn normalize_scalar(positions: &mut [i32], norm_offset: i32) {
    for p in positions.iter_mut() {
        // `wrapping_sub` mirrors the SIMD subtract instructions. For `norm_offset >= 0` the
        // result always lies in `0..=i32::MAX`, so it never actually wraps.
        *p = (*p).max(norm_offset).wrapping_sub(norm_offset);
    }
}
482
/// Normalization implementation using ARM NEON for 128-bit SIMD processing.
///
/// Each aligned 4 x i32 vector is replaced lane-wise by `max(p, norm_offset) - norm_offset`.
/// The unaligned elements before and after the aligned region go through `normalize_scalar`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports NEON.
#[cfg(all(feature = "std", feature = "optimization", target_arch = "aarch64"))]
#[target_feature(enable = "neon")]
unsafe fn normalize_neon(positions: &mut [i32], norm_offset: i32) {
    use core::arch::aarch64::*;

    // Create a 128-bit vector with the offset broadcast to all 4 lanes.
    let norm_v = vdupq_n_s32(norm_offset);

    // Split the slice into a 16-byte aligned middle part and unaligned ends.
    // `int32x4_t` is the NEON vector type for 4 x i32, which is 16 bytes.
    let (prefix, chunks, suffix) = positions.align_to_mut::<int32x4_t>();

    normalize_scalar(prefix, norm_offset);

    for chunk in chunks {
        let ptr = chunk as *mut int32x4_t as *mut i32;

        let data = vld1q_s32(ptr);

        // Clamp-then-subtract on 4 integers at once: max(p, offset) - offset,
        // so values at or below the offset become 0.
        let max_val = vmaxq_s32(data, norm_v);
        let result = vsubq_s32(max_val, norm_v);

        vst1q_s32(ptr, result);
    }

    normalize_scalar(suffix, norm_offset);
}
512
/// Normalization implementation using AVX2 for 256-bit SIMD processing.
///
/// Every aligned 8 x i32 vector is replaced lane-wise by `max(p, norm_offset) - norm_offset`.
/// The unaligned elements before and after the aligned region are handed to
/// `normalize_scalar`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(all(feature = "std", feature = "optimization", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn normalize_avx2(positions: &mut [i32], norm_offset: i32) {
    use core::arch::x86_64::*;

    let offset = _mm256_set1_epi32(norm_offset);

    // `align_to_mut` guarantees every element of `body` is 32-byte aligned. The aligned
    // load/store intrinsics below require exactly that.
    let (head, body, tail) = positions.align_to_mut::<__m256i>();

    normalize_scalar(head, norm_offset);

    for vector in body.iter_mut() {
        let slot: *mut __m256i = vector;
        let clamped = _mm256_max_epi32(_mm256_load_si256(slot), offset);
        _mm256_store_si256(slot, _mm256_sub_epi32(clamped, offset));
    }

    normalize_scalar(tail, norm_offset);
}
542
/// Normalization implementation using SSE4.1 for 128-bit SIMD processing.
///
/// Every aligned 4 x i32 vector is replaced lane-wise by `max(p, norm_offset) - norm_offset`.
/// The signed 32-bit max (`_mm_max_epi32`) is the SSE4.1 instruction this path depends on.
/// Unaligned head and tail elements are handed to `normalize_scalar`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE4.1.
#[cfg(all(feature = "std", feature = "optimization", target_arch = "x86_64"))]
#[target_feature(enable = "sse4.1")]
unsafe fn normalize_sse41(positions: &mut [i32], norm_offset: i32) {
    use core::arch::x86_64::*;

    let offset = _mm_set1_epi32(norm_offset);

    // `align_to_mut` guarantees every element of `body` is 16-byte aligned, as required by
    // the aligned load/store intrinsics.
    let (head, body, tail) = positions.align_to_mut::<__m128i>();

    normalize_scalar(head, norm_offset);

    for vector in body.iter_mut() {
        let slot: *mut __m128i = vector;
        let clamped = _mm_max_epi32(_mm_load_si128(slot), offset);
        _mm_store_si128(slot, _mm_sub_epi32(clamped, offset));
    }

    normalize_scalar(tail, norm_offset);
}