Skip to main content

lz4rip_encode/
hashtable.rs

1#[cfg(feature = "alloc")]
2use alloc::boxed::Box;
3
/// Count matching bytes between `input[*cur..]` and `source[candidate..]`,
/// advance `*cur` past them, and return how many matched.
///
/// The comparison stops before `input[input.len() - end_offset]` and never
/// reads past the end of `source`. It compares a `usize` at a time, then
/// steps down through `u32`/`u16`/`u8` for the tail, using unaligned raw
/// pointer reads instead of bounds-checked indexing.
///
/// When there is nothing to compare (`*cur + end_offset >= input.len()` or
/// `candidate >= source.len()`), 0 is returned and `*cur` is left unchanged.
/// Callers are still expected to pass `candidate <= source.len()`; this is
/// checked with `debug_assert!`.
#[cfg(not(feature = "paranoid"))]
#[inline]
pub(crate) fn count_same_bytes_inbounds(
    input: &[u8],
    cur: &mut usize,
    source: &[u8],
    candidate: usize,
    end_offset: usize,
) -> usize {
    debug_assert!(candidate <= source.len());
    // Clamping keeps `source.as_ptr().add(candidate)` below within the slice
    // (at most one past the end) even when the debug-only check above is
    // violated in a release build. An out-of-range candidate had zero
    // comparable bytes anyway, so the result is unchanged.
    let candidate = candidate.min(source.len());
    // `saturating_add` prevents an oversized `end_offset` from wrapping and
    // widening the comparison window past the end of `input`.
    let max_input = input.len().saturating_sub((*cur).saturating_add(end_offset));
    let max_cand = source.len() - candidate;
    let start = *cur;
    let input_end = start + max_input.min(max_cand);

    // SAFETY: `candidate <= source.len()`, so forming `src_ptr` is in bounds.
    // `input_end - start` is at most `input.len() - end_offset - start` and at
    // most `source.len() - candidate`. Every read below covers bytes in
    // `*cur..input_end` of `input` and the same number of bytes from `src_ptr`,
    // so both stay inside their slices.
    unsafe {
        let mut src_ptr = source.as_ptr().add(candidate);
        let inp_base = input.as_ptr();

        const STEP: usize = core::mem::size_of::<usize>();
        while *cur + STEP <= input_end {
            let diff = (inp_base.add(*cur) as *const usize).read_unaligned()
                ^ (src_ptr as *const usize).read_unaligned();
            if diff == 0 {
                *cur += STEP;
                src_ptr = src_ptr.add(STEP);
            } else {
                // `to_le` makes the lowest-address byte the least significant
                // one, so trailing zero bits count matching leading bytes.
                *cur += (diff.to_le().trailing_zeros() / 8) as usize;
                return *cur - start;
            }
        }

        // On 32-bit targets STEP is already 4, so the word loop covered this.
        #[cfg(target_pointer_width = "64")]
        if input_end - *cur >= 4 {
            let diff = (inp_base.add(*cur) as *const u32).read_unaligned()
                ^ (src_ptr as *const u32).read_unaligned();
            if diff == 0 {
                *cur += 4;
                src_ptr = src_ptr.add(4);
            } else {
                *cur += (diff.to_le().trailing_zeros() / 8) as usize;
                return *cur - start;
            }
        }

        if input_end - *cur >= 2
            && (inp_base.add(*cur) as *const u16).read_unaligned()
                == (src_ptr as *const u16).read_unaligned()
        {
            *cur += 2;
            src_ptr = src_ptr.add(2);
        }

        // Final byte: either the one left after a matching u16, or the first
        // byte of a u16 that did not match as a whole.
        if *cur < input_end && *inp_base.add(*cur) == *src_ptr {
            *cur += 1;
        }
    }

    *cur - start
}
73
/// Count matching bytes (paranoid: safe `chunks_exact` 8-byte compare).
///
/// Same contract as the unchecked variant. It compares `input[*cur..]`
/// against `source[candidate..]`, never past `input.len() - end_offset` or
/// the end of `source`, advances `*cur` by the match length and returns it.
/// When there is no room to compare (`*cur + end_offset >= input.len()` or
/// `candidate >= source.len()`), it returns 0 instead of panicking. This
/// matches the saturating arithmetic of the non-paranoid build.
///
/// Uses the same idiom as lz4_flex's safe encoder and the test-only
/// `count_same_bytes` helper: `chunks_exact(8).zip(..)` avoids a per-iteration
/// bounds check and autovectorizes, then a byte tail. `to_le` makes
/// `trailing_zeros` count from the lowest-address mismatching byte.
#[cfg(feature = "paranoid")]
#[inline]
pub(crate) fn count_same_bytes_inbounds(
    input: &[u8],
    cur: &mut usize,
    source: &[u8],
    candidate: usize,
    end_offset: usize,
) -> usize {
    const STEP: usize = 8;
    debug_assert!(candidate <= source.len());
    // Saturating arithmetic mirrors the non-paranoid build. Plain subtraction
    // here used to wrap in release and could let matching run into the
    // reserved `end_offset` tail.
    let max_input = input.len().saturating_sub((*cur).saturating_add(end_offset));
    let max_cand = source.len().saturating_sub(candidate);
    let limit = max_input.min(max_cand);
    if limit == 0 {
        // Nothing to compare. Returning here also avoids slicing
        // `input[*cur..]` when `*cur > input.len()`.
        return 0;
    }
    let cur_slice = &input[*cur..*cur + limit];
    let cand_slice = &source[candidate..candidate + limit];

    let mut num = 0;
    for (a, b) in cur_slice
        .chunks_exact(STEP)
        .zip(cand_slice.chunks_exact(STEP))
    {
        let av = u64::from_ne_bytes(a.try_into().unwrap());
        let bv = u64::from_ne_bytes(b.try_into().unwrap());
        if av == bv {
            num += STEP;
        } else {
            num += ((av ^ bv).to_le().trailing_zeros() / 8) as usize;
            *cur += num;
            return num;
        }
    }
    num += cur_slice[num..]
        .iter()
        .zip(&cand_slice[num..])
        .take_while(|(a, b)| a == b)
        .count();

    *cur += num;
    num
}
122
/// Load the four bytes starting at `input[n]` as a native-endian `u32`.
///
/// Release builds perform no bounds check.
///
/// # Preconditions
/// `n + 4 <= input.len()` must hold. It is verified only by a `debug_assert!`.
/// The function is not marked `unsafe`, so upholding this is the
/// responsibility of the crate-internal callers.
#[cfg(not(feature = "paranoid"))]
#[inline]
pub(crate) fn get_batch_inbounds(input: &[u8], n: usize) -> u32 {
    debug_assert!(n + 4 <= input.len());
    let base = input.as_ptr();
    // SAFETY: callers promise `n + 4 <= input.len()`, so all four bytes read
    // here lie inside `input`. `read_unaligned` tolerates any alignment.
    unsafe { base.add(n).cast::<u32>().read_unaligned() }
}
134
/// Load the four bytes `input[n..n + 4]` as a native-endian `u32`.
///
/// Paranoid build: uses ordinary slice indexing, so an out-of-range `n`
/// panics instead of reading out of bounds.
#[cfg(feature = "paranoid")]
#[inline]
pub(crate) fn get_batch_inbounds(input: &[u8], n: usize) -> u32 {
    let window = &input[n..n + 4];
    let bytes: [u8; 4] = window.try_into().unwrap();
    u32::from_ne_bytes(bytes)
}
141
/// Load a pointer-width word from `input[n..]` as a native-endian `usize`.
///
/// Bounds-checked: panics if fewer than `size_of::<usize>()` bytes remain
/// after `n`.
#[inline]
#[cfg(target_pointer_width = "64")]
pub(crate) fn get_batch_arch(input: &[u8], n: usize) -> usize {
    const WIDTH: usize = core::mem::size_of::<usize>();
    let word = &input[n..n + WIDTH];
    let bytes: [u8; WIDTH] = word.try_into().unwrap();
    usize::from_ne_bytes(bytes)
}
150
/// Load a pointer-width word from `input[n..]` as a native-endian `usize`,
/// without a bounds check in release builds.
///
/// # Safety
/// The caller must ensure `n + size_of::<usize>() <= input.len()`. This is
/// only verified by a `debug_assert!`.
#[inline]
#[cfg(all(target_pointer_width = "64", not(feature = "paranoid")))]
unsafe fn get_batch_arch_unchecked(input: &[u8], n: usize) -> usize {
    debug_assert!(n + core::mem::size_of::<usize>() <= input.len());
    let first = input.as_ptr();
    // SAFETY: the caller guarantees the whole word `input[n..n + size_of::<usize>()]`
    // is in bounds. `read_unaligned` tolerates any alignment.
    unsafe { first.add(n).cast::<usize>().read_unaligned() }
}
157
// Knuth's multiplicative hash constant for the 4-byte hash. It is close to
// 2^32 / φ (≈ 2_654_435_769.5), i.e. the golden-ratio fraction of the 32-bit
// range. The old wording "golden ratio * 2^32" would not even fit in a u32.
const KNUTH: u32 = 2_654_435_761;
160
/// Multiplier for the 5-byte hash on 64-bit little-endian targets
/// (the `prime5bytes` value used by C lz4's `LZ4_hash5`).
#[cfg(all(target_pointer_width = "64", target_endian = "little"))]
const PRIME5: usize = 889_523_592_379;

/// Multiplier for the 5-byte hash on 64-bit big-endian targets
/// (the `prime8bytes` value used by C lz4's `LZ4_hash5`).
#[cfg(all(target_pointer_width = "64", target_endian = "big"))]
const PRIME5: usize = 11_400_714_785_074_694_791;
167
/// Storage plus hashing policy for LZ4 match finding.
///
/// A table maps a hash of the bytes at an input position to a slot that
/// holds a previously stored position.
pub(crate) trait HashTable {
    /// Bounds-checked hash of the bytes starting at `input[pos]`, returned
    /// as a slot index.
    fn get_hash_at(input: &[u8], pos: usize) -> usize;

    /// The same hash as [`get_hash_at`](Self::get_hash_at). Implementations
    /// may skip bounds checking, so callers must make sure enough bytes
    /// remain after `pos`.
    ///
    /// The provided implementation simply forwards to the checked version.
    #[inline]
    fn get_hash_at_inbounds(input: &[u8], pos: usize) -> usize {
        <Self as HashTable>::get_hash_at(input, pos)
    }

    /// Read the position stored in slot `idx`.
    fn get_at(&self, idx: usize) -> usize;

    /// Store position `val` in slot `idx`.
    fn put_at(&mut self, idx: usize, val: usize);

    /// Reset every slot to zero.
    fn clear(&mut self);
}
186
/// Entry count of the default no-dict table (`u32` slots): 2048 x 4 B = 8 KB.
pub const DEFAULT_NODICT_ENTRIES: usize = 1 << 11;
/// Entry count of the default dict tables (`u16` slots): 4096 x 2 B = 8 KB.
pub const DEFAULT_DICT_ENTRIES: usize = 1 << 12;
/// Lower bound on a hash table's entry count: 256, i.e. an 8-bit index.
///
/// With fewer buckets, the 5 hashed input bytes collide so often that matches
/// are rarely found and the compressor degrades to emitting literals. This
/// mirrors C lz4's floor (`LZ4_MEMORY_USAGE_MIN = 10`, giving a
/// `1 << (10 - 2)` = 256-entry table).
pub const MIN_ENTRIES: usize = 1 << 8;
196
197/// Compile-time validation of a hash-table entry count `N`.
198///
199/// `N` must be a power of two so the index shift `64 - N.ilog2()` maps the hash
200/// onto exactly `[0, N)`, and at least [`MIN_ENTRIES`] so the shift is in range
201/// and the table carries enough index bits to match.
202const fn assert_valid_entries(n: usize) {
203    assert!(
204        n.is_power_of_two(),
205        "hash table entry count must be a power of two"
206    );
207    assert!(
208        n >= MIN_ENTRIES,
209        "hash table entry count must be at least MIN_ENTRIES (256)"
210    );
211}
212
/// Number of leading input bytes that `HashTableU32::get_hash_at` folds into
/// its hash on 64-bit targets. `5` selects the `PRIME5` multiply of an 8-byte
/// load; any other value falls back to the 4-byte `KNUTH` hash.
#[cfg(target_pointer_width = "64")]
const U32_HASH_BYTES: usize = 5;
215
216/// A hash table with `N` entries using 16-bit values (`2 * N` bytes).
217///
218/// `N` must be a power of two (checked at compile time in [`new`](Self::new)).
219/// Stored positions must fit in `u16`, so this is used only when dict + input
220/// stays below 64 KB.
221#[derive(Debug)]
222#[repr(align(64))]
223pub(crate) struct HashTableU32U16<const N: usize = DEFAULT_DICT_ENTRIES> {
224    #[cfg(feature = "alloc")]
225    dict: Box<[u16; N]>,
226    #[cfg(not(feature = "alloc"))]
227    dict: [u16; N],
228}
229impl<const N: usize> HashTableU32U16<N> {
230    #[cfg(feature = "alloc")]
231    #[inline]
232    pub(crate) fn new() -> Self {
233        const { assert_valid_entries(N) };
234        let dict = alloc::vec![0; N].into_boxed_slice().try_into().unwrap();
235        Self { dict }
236    }
237    #[cfg(not(feature = "alloc"))]
238    #[inline]
239    pub(crate) fn new() -> Self {
240        const { assert_valid_entries(N) };
241        Self { dict: [0u16; N] }
242    }
243}
244impl<const N: usize> HashTable for HashTableU32U16<N> {
245    #[inline]
246    fn get_at(&self, idx: usize) -> usize {
247        self.dict[idx] as usize
248    }
249    #[inline]
250    fn put_at(&mut self, idx: usize, val: usize) {
251        self.dict[idx] = val as u16;
252    }
253    #[inline]
254    fn clear(&mut self) {
255        self.dict.fill(0);
256    }
257    #[inline]
258    #[cfg(target_pointer_width = "64")]
259    fn get_hash_at(input: &[u8], pos: usize) -> usize {
260        let batch = get_batch_arch(input, pos);
261        (batch << 24).wrapping_mul(PRIME5) >> (64 - N.ilog2() as usize)
262    }
263    #[inline]
264    #[cfg(all(target_pointer_width = "64", not(feature = "paranoid")))]
265    fn get_hash_at_inbounds(input: &[u8], pos: usize) -> usize {
266        // SAFETY: callers guarantee pos + 8 <= input.len() via end_pos_check.
267        let batch = unsafe { get_batch_arch_unchecked(input, pos) };
268        (batch << 24).wrapping_mul(PRIME5) >> (64 - N.ilog2() as usize)
269    }
270    #[inline]
271    #[cfg(target_pointer_width = "32")]
272    fn get_hash_at(input: &[u8], pos: usize) -> usize {
273        let batch = u32::from_ne_bytes(input[pos..pos + 4].try_into().unwrap());
274        (batch.wrapping_mul(KNUTH) >> (32 - N.ilog2())) as usize
275    }
276}
277
278/// A hash table with `N` entries using 32-bit values (`4 * N` bytes).
279///
280/// `N` must be a power of two (checked at compile time in [`new`](Self::new)).
281#[derive(Debug)]
282pub struct HashTableU32<const N: usize = DEFAULT_NODICT_ENTRIES> {
283    #[cfg(feature = "alloc")]
284    dict: Box<[u32; N]>,
285    #[cfg(not(feature = "alloc"))]
286    dict: [u32; N],
287}
288impl<const N: usize> Default for HashTableU32<N> {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293impl<const N: usize> HashTableU32<N> {
294    #[cfg(feature = "alloc")]
295    #[inline]
296    /// Create a new zeroed hash table.
297    pub fn new() -> Self {
298        const { assert_valid_entries(N) };
299        let dict = alloc::vec![0; N].into_boxed_slice().try_into().unwrap();
300        Self { dict }
301    }
302    #[cfg(not(feature = "alloc"))]
303    #[inline]
304    /// Create a new zeroed hash table.
305    pub fn new() -> Self {
306        const { assert_valid_entries(N) };
307        Self { dict: [0u32; N] }
308    }
309
310    /// Zero all entries.
311    #[inline]
312    pub fn clear(&mut self) {
313        self.dict.fill(0);
314    }
315
316    /// Subtract `offset` from all entries (saturating).
317    #[cold]
318    pub fn reposition(&mut self, offset: u32) {
319        for i in self.dict.iter_mut() {
320            *i = i.saturating_sub(offset);
321        }
322    }
323}
324impl<const N: usize> HashTable for HashTableU32<N> {
325    #[inline]
326    fn get_at(&self, idx: usize) -> usize {
327        self.dict[idx] as usize
328    }
329    #[inline]
330    fn put_at(&mut self, idx: usize, val: usize) {
331        self.dict[idx] = val as u32;
332    }
333    #[inline]
334    fn clear(&mut self) {
335        self.dict.fill(0);
336    }
337    #[inline]
338    #[cfg(target_pointer_width = "64")]
339    fn get_hash_at(input: &[u8], pos: usize) -> usize {
340        if U32_HASH_BYTES == 5 {
341            let batch = get_batch_arch(input, pos);
342            (batch << 24).wrapping_mul(PRIME5) >> (64 - N.ilog2() as usize)
343        } else {
344            let batch = u32::from_ne_bytes(input[pos..pos + 4].try_into().unwrap());
345            (batch.wrapping_mul(KNUTH) >> (32 - N.ilog2())) as usize
346        }
347    }
348    #[inline]
349    #[cfg(all(target_pointer_width = "64", not(feature = "paranoid")))]
350    fn get_hash_at_inbounds(input: &[u8], pos: usize) -> usize {
351        // SAFETY: callers guarantee pos + 8 <= input.len() via end_pos_check.
352        let batch = unsafe { get_batch_arch_unchecked(input, pos) };
353        (batch << 24).wrapping_mul(PRIME5) >> (64 - N.ilog2() as usize)
354    }
355    #[inline]
356    #[cfg(target_pointer_width = "32")]
357    fn get_hash_at(input: &[u8], pos: usize) -> usize {
358        let batch = u32::from_ne_bytes(input[pos..pos + 4].try_into().unwrap());
359        (batch.wrapping_mul(KNUTH) >> (32 - N.ilog2())) as usize
360    }
361}