/// Chunk width in bytes on x86_64 (32).
///
/// `copy_match` compares the match distance against this value and, when the
/// distance is at least this large, copies in units of this many bytes via
/// `copy_chunks::<CHUNK_SIZE>`.
// NOTE(review): presumably chosen to match a 256-bit vector load/store — confirm.
#[cfg(target_arch = "x86_64")]
pub const CHUNK_SIZE: usize = 32;
/// Chunk width in bytes on aarch64 (16).
///
/// Same role as on other targets: threshold and copy unit for the chunked
/// path of `copy_match`.
// NOTE(review): presumably chosen to match a 128-bit vector load/store — confirm.
#[cfg(target_arch = "aarch64")]
pub const CHUNK_SIZE: usize = 16;
/// Chunk width in bytes on every target other than x86_64 and aarch64 (8).
///
/// Same role as on other targets: threshold and copy unit for the chunked
/// path of `copy_match`.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
pub const CHUNK_SIZE: usize = 8;
/// Copies at least `length` bytes from `src` to `dst` in whole `N`-byte chunks.
///
/// One chunk is always copied, even when `length` is 0; after that, chunks
/// keep being copied until the write cursor reaches `dst + length`. The total
/// number of bytes transferred is therefore `length` rounded up to a multiple
/// of `N` (and never less than `N`).
///
/// # Safety
///
/// * `src` must be readable and `dst` writable for the rounded-up byte count
///   described above, not merely for `length` bytes.
/// * Each `N`-byte source chunk must not overlap the `N`-byte destination
///   chunk it is copied to, because `copy_nonoverlapping` is used.
/// * `N` must be non-zero; with `N == 0` and a non-zero `length` the loop
///   never terminates.
#[inline(always)]
pub unsafe fn copy_chunks<const N: usize>(
    mut src: *const u8,
    mut dst: *mut u8,
    length: usize,
) {
    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        let stop = dst.add(length);
        // Copy-then-test: the first chunk is transferred unconditionally.
        loop {
            core::ptr::copy_nonoverlapping(src, dst, N);
            src = src.add(N);
            dst = dst.add(N);
            if dst >= stop {
                break;
            }
        }
    }
}
/// Fills `length` bytes at `dst` by repeatedly copying the 8-byte word found
/// `dist` bytes behind the write cursor.
///
/// The main loop moves two words per pass while more than 8 bytes remain; a
/// trailing loop moves single words until the cursor reaches `dst + length`.
/// The `_src` argument is not used; the source is always derived as
/// `dst - dist`.
///
/// For the output to repeat the preceding data correctly, `dist` should be at
/// least 8 (which is how `copy_match` calls this function): with a smaller
/// distance a word load would include bytes that have not been written yet.
///
/// # Safety
///
/// * `dst - dist` up to `dst + length` must lie inside one allocation that is
///   valid for reads and writes.
/// * Stores are whole 8-byte words, so up to 7 bytes past `dst + length` may
///   be written; that slack must be writable as well.
#[inline(always)]
pub unsafe fn copy_stride_word(
    _src: *const u8,
    mut dst: *mut u8,
    dist: usize,
    length: usize,
) {
    /// Copies the 8 bytes located `dist` bytes before `at` onto `at` and
    /// returns the cursor advanced by 8.
    #[inline(always)]
    unsafe fn echo_word(at: *mut u8, dist: usize) -> *mut u8 {
        // SAFETY: forwarded from the enclosing function's contract.
        unsafe {
            let word = at.sub(dist).cast::<u64>().read_unaligned();
            at.cast::<u64>().write_unaligned(word);
            at.add(8)
        }
    }

    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        let stop = dst.add(length);
        let paired_stop = dst.add(length.saturating_sub(8));

        // Two words per pass while more than one word is still outstanding.
        while dst < paired_stop {
            dst = echo_word(dst, dist);
            dst = echo_word(dst, dist);
        }
        // Finish with single words.
        while dst < stop {
            dst = echo_word(dst, dist);
        }
    }
}
/// Whether the running CPU reports AVX2 support.
///
/// The detection macro runs the first time the value is dereferenced; later
/// reads reuse the stored result. `copy_match` consults this before calling
/// `copy_short_avx2`.
#[cfg(target_arch = "x86_64")]
static HAS_AVX2: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::arch::is_x86_feature_detected!("avx2"));
/// Whether the running CPU reports AVX-512BW support.
///
/// The detection macro runs the first time the value is dereferenced; later
/// reads reuse the stored result. `copy_match` consults this before calling
/// `copy_chunks_avx512`.
#[cfg(target_arch = "x86_64")]
static HAS_AVX512BW: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::arch::is_x86_feature_detected!("avx512bw"));
/// Writes `length` bytes at `out_pos`, each equal to the byte `dist` positions
/// before it, so source and destination may overlap.
///
/// The copy strategy is picked from `dist`:
///
/// * x86_64 only: `dist >= 64` and AVX-512BW detected → `copy_chunks_avx512`.
/// * `dist >= CHUNK_SIZE` → `copy_chunks::<CHUNK_SIZE>`.
/// * `dist == 1` → a single-byte fill with `write_bytes`.
/// * `dist >= 8` → `copy_stride_word`.
/// * anything else → `copy_short_avx2` when AVX2 is detected on x86_64,
///   otherwise `copy_short_scalar`.
///
/// # Safety
///
/// * `out_pos - dist` must be inside the same allocation as `out_pos`, and
///   the source bytes must be initialised.
/// * Only the `dist == 1` and scalar paths touch exactly `length` bytes. The
///   other paths copy whole words or chunks (up to 64 bytes at a time) and may
///   read and write past `out_pos + length`; the caller must provide that
///   slack.
/// * `dist` is expected to be at least 1. A zero distance is not handled: it
///   reaches the short-distance branch, whose AVX2 variant indexes a table
///   with `dist - 2`.
#[inline(always)]
pub unsafe fn copy_match(
    out_pos: *mut u8,
    dist: usize,
    length: usize,
) {
    // SAFETY: the caller guarantees `dist` bytes before `out_pos` are in bounds.
    let src = unsafe { out_pos.sub(dist) };

    #[cfg(target_arch = "x86_64")]
    {
        // Widest path first; the distance test is kept ahead of the feature
        // flag so the flag is only consulted for sufficiently distant matches.
        if dist >= 64 && *HAS_AVX512BW {
            // SAFETY: AVX-512BW was detected; the caller provides the slack.
            unsafe { copy_chunks_avx512(src, out_pos, length) };
            return;
        }
    }

    match dist {
        // SAFETY (all arms): forwarded from this function's contract.
        far if far >= CHUNK_SIZE => unsafe {
            copy_chunks::<CHUNK_SIZE>(src, out_pos, length)
        },
        1 => unsafe { core::ptr::write_bytes(out_pos, *src, length) },
        mid if mid >= 8 => unsafe { copy_stride_word(src, out_pos, dist, length) },
        _ => {
            #[cfg(target_arch = "x86_64")]
            {
                if *HAS_AVX2 {
                    // SAFETY: AVX2 was detected; the caller provides the slack.
                    unsafe { copy_short_avx2(src, out_pos, dist, length) };
                    return;
                }
            }
            unsafe { copy_short_scalar(src, out_pos, dist, length) };
        }
    }
}
/// Copies at least `length` bytes from `src` to `dst` using unaligned 64-byte
/// vector loads and stores.
///
/// One 64-byte block is always copied; further blocks follow until the write
/// cursor reaches `dst + length`, so the byte count is `length` rounded up to
/// a multiple of 64 (and never less than 64).
///
/// # Safety
///
/// * The CPU must support the `avx512bw` target feature.
/// * `src` must be readable and `dst` writable for the rounded-up byte count.
/// * Each block is loaded completely before it is stored. When the regions
///   overlap, `dst` must therefore be at least 64 bytes ahead of `src` for
///   later blocks to see already-written data (`copy_match` only takes this
///   path when `dist >= 64`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512bw")]
unsafe fn copy_chunks_avx512(
    mut src: *const u8,
    mut dst: *mut u8,
    length: usize,
) {
    use core::arch::x86_64::*;
    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        let stop = dst.add(length);
        // Copy-then-test: the first block is transferred unconditionally.
        loop {
            let block = _mm512_loadu_si512(src.cast::<__m512i>());
            _mm512_storeu_si512(dst.cast::<__m512i>(), block);
            src = src.add(64);
            dst = dst.add(64);
            if dst >= stop {
                break;
            }
        }
    }
}
/// Byte-by-byte overlapping copy: every one of the `length` bytes at `dst` is
/// set to the byte `dist` positions before it, in ascending address order.
///
/// Exactly `length` bytes are written. The `_src` argument is not used; the
/// source is always derived as `dst - dist`.
///
/// # Safety
///
/// `dst - dist` up to `dst + length` must lie inside one allocation that is
/// valid for reads and writes.
#[inline(always)]
unsafe fn copy_short_scalar(
    _src: *const u8,
    dst: *mut u8,
    dist: usize,
    length: usize,
) {
    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        let stop = dst.add(length);
        let mut cursor = dst;
        while cursor < stop {
            // Read first, then write, so overlapping ranges replicate forward.
            let byte = cursor.sub(dist).read();
            cursor.write(byte);
            cursor = cursor.add(1);
        }
    }
}
/// 32 bytes of shuffle indices, forced to 32-byte alignment so the table
/// entries can be fetched with an aligned 256-bit load.
#[cfg(target_arch = "x86_64")]
#[repr(align(32))]
struct AlignedPerm([u8; 32]);

/// Byte-shuffle index tables for distances 2 through 7.
///
/// Entry `k` belongs to distance `k + 2` and holds `i % (k + 2)` at byte
/// position `i`, i.e. the indices `0, 1, .., dist - 1` repeated across all
/// 32 bytes. `copy_short_avx2` looks entries up as `PERM_LUT[dist - 2]`.
#[cfg(target_arch = "x86_64")]
static PERM_LUT: [AlignedPerm; 6] = {
    /// Builds the index table for one repeat period.
    const fn tile_indices(period: usize) -> AlignedPerm {
        let mut lanes = [0u8; 32];
        let mut slot = 0;
        // `while` rather than `for`: iterators are not usable in `const fn`.
        while slot < lanes.len() {
            lanes[slot] = (slot % period) as u8;
            slot += 1;
        }
        AlignedPerm(lanes)
    }

    [
        tile_indices(2),
        tile_indices(3),
        tile_indices(4),
        tile_indices(5),
        tile_indices(6),
        tile_indices(7),
    ]
};
/// Overlapping copy for distances 2 through 7 using an AVX2 byte shuffle.
///
/// 16 bytes are loaded from `src`, duplicated into both 128-bit lanes, and
/// shuffled with the `PERM_LUT` entry for `dist` so that the first `dist`
/// source bytes repeat across a 32-byte vector, which is stored to `dst`.
/// When `length` exceeds 32, the remainder is produced with unaligned 8-byte
/// copies from `dst - dist`, advancing `dist` bytes per step so that each
/// step finalises `dist` more bytes.
///
/// # Safety
///
/// * The CPU must support the `avx2` target feature.
/// * `dist` must be in `2..=7` (checked only by a `debug_assert!`); the table
///   lookup uses `dist - 2`.
/// * `src` must be readable for 16 bytes.
/// * `dst` must be writable for 32 bytes regardless of `length`; when
///   `length > 32`, the 8-byte stores may reach up to 7 bytes past
///   `dst + length`.
// NOTE(review): the tail loop reads `dst - dist`, so `src` is presumably
// always `dst - dist` (as in `copy_match`) — confirm for other callers.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn copy_short_avx2(
    src: *const u8,
    mut dst: *mut u8,
    dist: usize,
    length: usize,
) {
    use core::arch::x86_64::*;
    debug_assert!((2..=7).contains(&dist));
    // SAFETY: the caller upholds the contract documented under `# Safety`.
    unsafe {
        // The byte shuffle works per 128-bit lane, so both lanes get the
        // same 16 source bytes.
        let seed = _mm_loadu_si128(src.cast::<__m128i>());
        let both_lanes = _mm256_inserti128_si256::<1>(_mm256_castsi128_si256(seed), seed);
        let indices = _mm256_load_si256(PERM_LUT[dist - 2].0.as_ptr().cast::<__m256i>());
        let pattern = _mm256_shuffle_epi8(both_lanes, indices);
        _mm256_storeu_si256(dst.cast::<__m256i>(), pattern);

        if length > 32 {
            let stop = dst.add(length);
            dst = dst.add(32);
            while dst < stop {
                // Only the first `dist` bytes of each word are final; the
                // rest are overwritten by the next step.
                let word = dst.sub(dist).cast::<u64>().read_unaligned();
                dst.cast::<u64>().write_unaligned(word);
                dst = dst.add(dist);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three 8-byte chunks reproduce the first 24 source bytes.
    #[test]
    fn copy_chunks_basic() {
        let mut src: Vec<u8> = (1..=32).collect();
        src.extend_from_slice(&[0u8; 8]);
        let mut dst = vec![0u8; 40];

        unsafe { copy_chunks::<8>(src.as_ptr(), dst.as_mut_ptr(), 24) };

        assert_eq!(&dst[..24], &src[..24]);
    }

    /// Distance 1 replicates a single byte.
    #[test]
    fn copy_match_rle() {
        let mut buf = vec![0u8; 64];
        buf[0] = 0xAB;

        unsafe { copy_match(buf.as_mut_ptr().add(1), 1, 31) };

        assert!(buf[1..32].iter().all(|&b| b == 0xAB));
    }

    /// Distance equal to the chunk size copies one chunk verbatim.
    #[test]
    fn copy_match_non_overlapping() {
        let mut buf = vec![0u8; 128];
        for (i, slot) in buf.iter_mut().take(CHUNK_SIZE).enumerate() {
            *slot = (i + 1) as u8;
        }
        let length = CHUNK_SIZE;

        unsafe { copy_match(buf.as_mut_ptr().add(CHUNK_SIZE), CHUNK_SIZE, length) };

        let (head, rest) = buf.split_at(CHUNK_SIZE);
        assert_eq!(&rest[..CHUNK_SIZE], head);
    }

    /// Distance 3 repeats a three-byte pattern.
    #[test]
    fn copy_match_short_dist() {
        let mut buf = vec![0u8; 128];
        buf[..3].copy_from_slice(b"ABC");
        let length = 30;

        unsafe { copy_match(buf.as_mut_ptr().add(3), 3, length) };

        for i in 0..length {
            assert_eq!(buf[3 + i], buf[i % 3], "mismatch at offset {i}");
        }
    }

    /// Distance 8 repeats an eight-byte pattern word by word.
    #[test]
    fn copy_stride_word_basic() {
        let mut buf = vec![0u8; 128];
        for (i, slot) in buf.iter_mut().take(8).enumerate() {
            *slot = (i + 1) as u8;
        }

        unsafe { copy_stride_word(buf.as_ptr(), buf.as_mut_ptr().add(8), 8, 24) };

        for i in 0..24 {
            assert_eq!(buf[8 + i], buf[i % 8], "mismatch at {i}");
        }
    }
}
#[cfg(test)]
mod avx2_tests {
    use super::*;

    /// Runs `copy_match` for every distance from 2 to 7 over a range of
    /// lengths and checks that the output repeats the seed bytes.
    #[test]
    fn copy_short_dist_extensive() {
        const LENGTHS: [usize; 16] = [1, 2, 3, 5, 7, 8, 15, 16, 30, 31, 32, 33, 63, 64, 100, 200];

        for dist in 2..=7usize {
            for length in LENGTHS {
                let mut buf = vec![0u8; 512];
                for (i, slot) in buf.iter_mut().take(dist).enumerate() {
                    *slot = (i as u8).wrapping_mul(37).wrapping_add(11);
                }

                unsafe { copy_match(buf.as_mut_ptr().add(dist), dist, length) };

                // The reference pattern is read back after the call, so the
                // expected byte at offset `i` is `seed[i % dist]`.
                let seed = buf[..dist].to_vec();
                let mut expected = vec![0u8; 512];
                expected[..dist].copy_from_slice(&seed);
                for i in 0..length {
                    expected[dist + i] = seed[i % dist];
                }

                for i in 0..length {
                    assert_eq!(
                        buf[dist + i], expected[dist + i],
                        "dist={dist} length={length}: mismatch at offset {i}, got {} expected {}",
                        buf[dist + i], expected[dist + i]
                    );
                }
            }
        }
    }
}