use core::alloc::Layout;
use safe_allocator_api::RawAlloc;
#[cfg(all(
any(target_arch = "x86", target_arch = "x86_64"),
any(feature = "estimator-avx512", feature = "estimator-avx2")
))]
use std::is_x86_feature_detected;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "estimator-avx2")]
mod avx2;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "estimator-avx512")]
mod avx512;
/// Odd 32-bit multiplier applied by [`hash_u32`] (multiplicative hashing).
const GOLDEN_RATIO: u32 = 0x9E37_79B1;

/// Width, in bits, of a hash-table slot index.
const HASH_BITS: usize = 15;

/// Number of slots in the hash table (`2^HASH_BITS`).
const HASH_SIZE: usize = 1 << HASH_BITS;

/// Mask selecting the low `HASH_BITS` bits of a `u32`.
// Not referenced in this file; presumably consumed by the SIMD back ends — verify before removing.
#[allow(dead_code)]
const HASH_MASK: u32 = (1 << HASH_BITS) - 1;
/// Estimates how many LZ matches `bytes` contains.
///
/// A zeroed, 64-byte-aligned table of `HASH_SIZE` `u32` slots is allocated for the duration
/// of the call and handed to [`calculate_matches_impl`], which scans the input and
/// accumulates the match count.
///
/// The scan range ends `7` bytes before the end of the input (saturating at the start), so
/// the range is empty for inputs of 7 bytes or fewer.
///
/// # Panics
///
/// Panics if `RawAlloc::new_zeroed` returns an error.
pub fn estimate_num_lz_matches_fast(bytes: &[u8]) -> usize {
    const TABLE_BYTES: usize = size_of::<u32>() * HASH_SIZE;

    // SAFETY: 64 is a non-zero power of two and `TABLE_BYTES` (128 KiB) is far below the
    // `isize::MAX` limit that `Layout` requires.
    let table_layout = unsafe { Layout::from_size_align_unchecked(TABLE_BYTES, 64) };
    let mut table_memory = RawAlloc::new_zeroed(table_layout).unwrap();

    // SAFETY: relies on `RawAlloc::new_zeroed` returning at least `TABLE_BYTES` zeroed bytes
    // aligned as requested by `table_layout` (NOTE(review): confirm against the
    // `safe_allocator_api` docs). All-zero bytes are a valid `[u32; HASH_SIZE]`, and
    // `table_memory` outlives this exclusive borrow.
    let hash_table = unsafe { &mut *(table_memory.as_mut_ptr() as *mut [u32; HASH_SIZE]) };

    let mut matches = 0usize;
    let begin_ptr = bytes.as_ptr();
    // SAFETY: the offset never exceeds `bytes.len()`, so the result stays inside the slice
    // (or one past its end).
    let end_ptr = unsafe { begin_ptr.add(bytes.len().saturating_sub(7)) };

    calculate_matches_impl(hash_table, &mut matches, begin_ptr, end_ptr);
    matches
}
/// Counts matches over `[begin_ptr, end_ptr)` using the fastest back end available.
///
/// On x86/x86_64, and only when the matching Cargo feature is enabled, the CPU is probed at
/// runtime: the AVX-512 routine is tried first (requires `avx512f` and `avx512vl`), then the
/// AVX2 routine. Everything else falls through to [`calculate_matches_generic`].
///
/// NOTE(review): this function is safe to call but forwards the raw pointers straight into
/// unsafe code, so callers must still respect the contract of the selected back end.
#[inline(always)]
fn calculate_matches_impl(
    hash_table: &mut [u32; HASH_SIZE],
    matches: &mut usize,
    begin_ptr: *const u8,
    end_ptr: *const u8,
) {
    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        feature = "estimator-avx512"
    ))]
    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vl") {
        // SAFETY: the required CPU features were detected just above; pointer validity is
        // the caller's responsibility.
        unsafe {
            avx512::calculate_matches_avx512(hash_table, matches, begin_ptr, end_ptr);
        }
        return;
    }

    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        feature = "estimator-avx2"
    ))]
    if is_x86_feature_detected!("avx2") {
        // SAFETY: AVX2 support was detected just above; pointer validity is the caller's
        // responsibility.
        unsafe {
            avx2::calculate_matches_avx2(hash_table, matches, begin_ptr, end_ptr);
        }
        return;
    }

    // SAFETY: pointer validity is the caller's responsibility; the portable routine needs no
    // special CPU features.
    unsafe {
        calculate_matches_generic(hash_table, matches, begin_ptr, end_ptr);
    }
}
#[inline(never)] pub(crate) unsafe fn calculate_matches_generic(
hash_table: &mut [u32; HASH_SIZE],
matches: &mut usize,
mut begin_ptr: *const u8,
end_ptr: *const u8,
) {
while begin_ptr < end_ptr {
let d0 = read_4_byte_le_unaligned(begin_ptr, 0);
let d1 = read_4_byte_le_unaligned(begin_ptr, 1);
let d2 = read_4_byte_le_unaligned(begin_ptr, 2);
let d3 = read_4_byte_le_unaligned(begin_ptr, 3);
begin_ptr = begin_ptr.add(4);
let d0 = reduce_to_3byte(d0);
let d1 = reduce_to_3byte(d1);
let d2 = reduce_to_3byte(d2);
let d3 = reduce_to_3byte(d3);
let h0 = hash_u32(d0);
let h1 = hash_u32(d1);
let h2 = hash_u32(d2);
let h3 = hash_u32(d3);
let index0 = (h0 >> (32 - HASH_BITS)) as usize;
let index1 = (h1 >> (32 - HASH_BITS)) as usize;
let index2 = (h2 >> (32 - HASH_BITS)) as usize;
let index3 = (h3 >> (32 - HASH_BITS)) as usize;
*matches += (hash_table[index0] == d0) as usize;
*matches += (hash_table[index1] == d1) as usize;
*matches += (hash_table[index2] == d2) as usize;
*matches += (hash_table[index3] == d3) as usize;
hash_table[index0] = d0;
hash_table[index1] = d1;
hash_table[index2] = d2;
hash_table[index3] = d3;
}
}
/// Hashes `value` with a single wrapping multiplication by `GOLDEN_RATIO`.
///
/// Callers in this file keep only the top `HASH_BITS` bits of the result as a table index.
#[inline(always)]
#[allow(dead_code)]
pub(crate) fn hash_u32(value: u32) -> u32 {
    GOLDEN_RATIO.wrapping_mul(value)
}
/// Reads the four bytes at `ptr + offset` and interprets them as a little-endian `u32`.
///
/// The address does not need to be aligned.
///
/// # Safety
///
/// The four bytes at `ptr + offset .. ptr + offset + 4` must be readable and lie within the
/// same allocation as `ptr`.
#[inline(always)]
#[allow(dead_code)]
pub(crate) unsafe fn read_4_byte_le_unaligned(ptr: *const u8, offset: usize) -> u32 {
    // SAFETY: the caller guarantees the four bytes are readable; `[u8; 4]` has alignment 1,
    // so a plain read is valid at any address.
    let raw: [u8; 4] = unsafe { ptr.add(offset).cast::<[u8; 4]>().read() };
    u32::from_le_bytes(raw)
}
/// Clears the top byte of `value`, keeping only its low 24 bits.
///
/// For a word read as little-endian this keeps the first three bytes in memory order.
#[inline(always)]
#[allow(dead_code)]
pub(crate) fn reduce_to_3byte(value: u32) -> u32 {
    const LOW_24_BITS: u32 = (1 << 24) - 1;
    value & LOW_24_BITS
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::slice;
    use rstest::rstest;
    use std::borrow::ToOwned;
    use std::format;
    use std::vec::Vec;
    use std::{println, vec};

    /// Two different inputs must not hash to the same value.
    #[test]
    fn can_hash_u32() {
        let (first, second) = (hash_u32(1), hash_u32(2));
        assert_ne!(
            first, second,
            "Different inputs should produce different hashes"
        );
    }

    /// An empty slice yields no matches.
    #[test]
    fn is_zero_on_empty_input() {
        let no_bytes: Vec<u8> = vec![];
        let found = estimate_num_lz_matches_fast(&no_bytes);
        assert_eq!(found, 0, "Empty input should return 0 matches");
    }

    /// A 6-byte input is too short to be scanned and yields no matches.
    #[test]
    fn is_0_on_small_input() {
        let tiny: Vec<u8> = vec![1, 2, 3, 4, 5, 6];
        let found = estimate_num_lz_matches_fast(&tiny);
        assert_eq!(
            found, 0,
            "Input smaller than 7 bytes should return 0 matches"
        );
    }

    /// Data built from non-repeating values must stay below the allowed match rate.
    #[rstest]
    #[case(1 << 17, 0.001)]
    #[case(16777215, 0.001)]
    fn with_no_repetition_should_have_no_matches(
        #[case] test_size: usize,
        #[case] allowed_error: f32,
    ) {
        let is_128k = test_size == (1 << 17);
        let expected = (test_size as f32 * allowed_error) as usize;

        // The 128k case uses an increasing `u16` sequence; the other case uses an
        // increasing 3-byte counter.
        let data = if is_128k {
            let unique: Vec<u16> = (0..u16::MAX).collect();
            cast_u16_slice_to_u8_slice(&unique).to_vec()
        } else {
            generate_unique_3byte_sequence(test_size / 3)
        };

        let matches = estimate_num_lz_matches_fast(&data);

        let label = if is_128k {
            "128k".to_owned()
        } else {
            format!("long_distance_{test_size}")
        };
        let actual_error = (matches as f32 / test_size as f32) * 100.0;
        println!(
            "[res:no_matches_{}] matches: {}, expected: < {}, allowed_error: {:.1}%, actual_error: {:.3}%",
            label,
            matches,
            expected,
            allowed_error * 100.0,
            actual_error
        );

        assert!(
            matches < expected,
            "Sequence with no repetitions should have very few matches, \
             but got {matches} matches, expected at most {expected}"
        );
    }

    /// Emits the low three little-endian bytes of every integer in `0..length`.
    fn generate_unique_3byte_sequence(length: usize) -> Vec<u8> {
        let mut result = Vec::with_capacity(length * 3);
        for x in 0..length {
            result.extend_from_slice(&x.to_le_bytes()[..3]);
        }
        result
    }

    /// Data repeating every `match_interval` bytes must produce at least `min_matches`.
    #[rstest]
    #[case(1 << 17, 1 << 12, 113000)]
    #[case(1 << 17, 1 << 13, 95000)]
    #[case(1 << 17, 1 << 14, 60000)]
    #[case(1 << 17, 1 << 15, 13000)]
    #[case(1 << 17, 1 << 16, 450)]
    fn estimate_num_lz_matches_at_various_offsets(
        #[case] test_size: usize,
        #[case] match_interval: usize,
        #[case] min_matches: usize,
    ) {
        assert!(
            match_interval <= 1 << 16,
            "Match interval must be <= 64K due to u16 limits"
        );
        assert!(match_interval > 0, "Match interval must be positive");
        assert!(
            test_size >= match_interval,
            "Test size must be >= match interval"
        );
        assert!(
            test_size.is_multiple_of(2) && match_interval.is_multiple_of(2),
            "Test size and match interval must be even due to u16 alignment"
        );

        // A `u16` counter that wraps every `match_interval / 2` values, i.e. every
        // `match_interval` bytes.
        let period = match_interval / 2;
        let unique: Vec<u16> = (0..test_size / 2).map(|x| (x % period) as u16).collect();

        let matches = estimate_num_lz_matches_fast(cast_u16_slice_to_u8_slice(&unique));
        let expected = test_size - match_interval;
        let percentage = (matches as f32 / expected as f32) * 100.0;

        assert!(
            matches >= min_matches,
            "Got {matches} matches, which is below minimum threshold of {min_matches}"
        );
        println!(
            "[res:matches_{match_interval}_intervals_{test_size}] matches: {matches}, expected: < {expected}, minimum: {min_matches}, found: {percentage:.1}%"
        );
    }

    /// Reinterprets a `u16` slice as its underlying bytes (native byte order).
    fn cast_u16_slice_to_u8_slice(u16_slice: &[u16]) -> &[u8] {
        let byte_len = core::mem::size_of_val(u16_slice);
        // SAFETY: the pointer and byte length describe exactly the memory borrowed by
        // `u16_slice`, `u8` has no alignment requirement, and the returned slice shares the
        // input's lifetime.
        unsafe { slice::from_raw_parts(u16_slice.as_ptr().cast::<u8>(), byte_len) }
    }
}