#[cfg(target_arch = "x86")]
use core::arch::x86::{
__m128i, __m256i, __m512i, _mm_loadu_si128, _mm_storeu_si128, _mm256_loadu_si256,
_mm256_storeu_si256, _mm512_loadu_si512, _mm512_storeu_si512,
};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m128i, __m256i, __m512i, _mm_loadu_si128, _mm_storeu_si128, _mm256_loadu_si256,
_mm256_storeu_si256, _mm512_loadu_si512, _mm512_storeu_si512,
};
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
use std::arch::is_x86_feature_detected;
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
use std::sync::OnceLock;
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
use core::arch::aarch64::{uint8x16_t, vld1q_u8, vst1q_u8};
/// Copies at least `copy_at_least` bytes from `src` to `dst`, possibly copying more.
///
/// Buffers are passed as `(pointer, length)` pairs. To use wide loads and stores, the copy
/// may round `copy_at_least` up to a chunk size, but it never reads or writes more than
/// `min(src.1, dst.1)` bytes. Bytes of `dst` past `copy_at_least` (within that limit) may
/// therefore be overwritten with the corresponding bytes of `src`.
///
/// Strategy, in order:
/// 1. `copy_at_least <= 16` and both buffers hold at least 16 bytes: one 16-byte copy.
/// 2. `copy_at_least <= 32`: exact copy using byte moves or overlapping 8-byte moves.
/// 3. The widest enabled SIMD kernel (AVX-512 → AVX2 → SSE2 on x86, NEON on aarch64) whose
///    rounded-up length still fits in both buffers.
/// 4. A word-sized scalar loop if its rounded-up length fits, else an exact
///    `copy_from_nonoverlapping`.
///
/// # Safety
///
/// - `src.0` must be valid for reads of `src.1` bytes and `dst.0` valid for writes of
///   `dst.1` bytes.
/// - `copy_at_least` must not exceed `src.1` or `dst.1`.
/// - The source and destination regions must not overlap.
#[inline(always)]
pub(crate) unsafe fn copy_bytes_overshooting(
    src: (*const u8, usize),
    dst: (*mut u8, usize),
    copy_at_least: usize,
) {
    // Catch contract violations in debug builds before any out-of-bounds access; the
    // debug verification below would otherwise itself read past the buffers.
    debug_assert!(
        copy_at_least <= src.1 && copy_at_least <= dst.1,
        "copy_at_least ({copy_at_least}) exceeds a buffer (src: {}, dst: {})",
        src.1,
        dst.1,
    );
    if copy_at_least == 0 {
        return;
    }
    let min_buffer_size = core::cmp::min(src.1, dst.1);
    if copy_at_least <= 16 && min_buffer_size >= 16 {
        // SAFETY: both buffers hold at least 16 bytes, the most `single_op_copy_16` touches.
        unsafe { single_op_copy_16(src.0, dst.0, copy_at_least) };
        debug_assert_eq_copy(src, dst, copy_at_least);
        return;
    }
    if copy_at_least <= 32 {
        // SAFETY: every access below stays within `0..copy_at_least`, which fits in both
        // buffers per the caller contract.
        unsafe {
            if copy_at_least <= 8 {
                // 1..=8 bytes: byte by byte, so nothing past `copy_at_least` is touched.
                let mut i = 0;
                while i < copy_at_least {
                    dst.0.add(i).write(src.0.add(i).read());
                    i += 1;
                }
            } else if copy_at_least <= 16 {
                // 9..=16 bytes: two overlapping 8-byte moves cover `0..8` and
                // `copy_at_least - 8..copy_at_least`.
                let lo: u64 = src.0.cast::<u64>().read_unaligned();
                let hi_offset = copy_at_least - 8;
                let hi: u64 = src.0.add(hi_offset).cast::<u64>().read_unaligned();
                dst.0.cast::<u64>().write_unaligned(lo);
                dst.0.add(hi_offset).cast::<u64>().write_unaligned(hi);
            } else {
                // 17..=32 bytes: copy the first 16 bytes and the last 16 bytes; the two
                // halves overlap whenever `copy_at_least < 32`.
                let lo: u64 = src.0.cast::<u64>().read_unaligned();
                let hi: u64 = src.0.add(8).cast::<u64>().read_unaligned();
                dst.0.cast::<u64>().write_unaligned(lo);
                dst.0.add(8).cast::<u64>().write_unaligned(hi);
                let tail_off = copy_at_least - 16;
                let tail_lo: u64 = src.0.add(tail_off).cast::<u64>().read_unaligned();
                let tail_hi: u64 = src.0.add(copy_at_least - 8).cast::<u64>().read_unaligned();
                dst.0.add(tail_off).cast::<u64>().write_unaligned(tail_lo);
                dst.0
                    .add(copy_at_least - 8)
                    .cast::<u64>()
                    .write_unaligned(tail_hi);
            }
        }
        debug_assert_eq_copy(src, dst, copy_at_least);
        return;
    }
    // Tries a SIMD kernel with chunk size `$chunk`: round the length up to a multiple of
    // the chunk and use the kernel only if the rounded length fits in both buffers.
    macro_rules! try_chunk_kernel {
        ($chunk:expr, $kernel:ident) => {{
            if copy_at_least >= $chunk {
                let rounded = copy_at_least.next_multiple_of($chunk);
                if min_buffer_size >= rounded {
                    // SAFETY: `rounded <= min_buffer_size`, so the kernel stays in bounds;
                    // each expansion site has ensured the required CPU feature is present.
                    unsafe { $kernel(src.0, dst.0, rounded) };
                    debug_assert_eq_copy(src, dst, copy_at_least);
                    return;
                }
            }
        }};
    }
    // With `std`, pick the widest kernel the running CPU supports.
    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let caps = detect_x86_caps();
        if caps.avx512f {
            try_chunk_kernel!(64, copy_avx512);
        }
        if caps.avx2 {
            try_chunk_kernel!(32, copy_avx2);
        }
        if caps.sse2 {
            try_chunk_kernel!(16, copy_sse2);
        }
    }
    // Without `std`, only kernels enabled at compile time are available.
    #[cfg(all(not(feature = "std"), any(target_arch = "x86", target_arch = "x86_64")))]
    {
        #[cfg(target_feature = "avx512f")]
        try_chunk_kernel!(64, copy_avx512);
        #[cfg(target_feature = "avx2")]
        try_chunk_kernel!(32, copy_avx2);
        #[cfg(target_feature = "sse2")]
        try_chunk_kernel!(16, copy_sse2);
    }
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    try_chunk_kernel!(16, copy_neon);
    // Scalar fallback: word-sized chunks if the rounded length fits, otherwise an exact copy.
    let scalar_chunk = core::mem::size_of::<usize>();
    let rounded = copy_at_least.next_multiple_of(scalar_chunk);
    if min_buffer_size >= rounded {
        // SAFETY: `rounded` is a multiple of the word size and fits in both buffers.
        unsafe { copy_scalar(src.0, dst.0, rounded) };
    } else {
        // SAFETY: `copy_at_least` fits in both buffers and the regions do not overlap
        // (caller contract).
        unsafe { dst.0.copy_from_nonoverlapping(src.0, copy_at_least) };
    }
    debug_assert_eq_copy(src, dst, copy_at_least);
}
/// Copies up to 16 bytes from `src` to `dst` with a single vector move where available.
///
/// Only the first `len` (at most 16) bytes are guaranteed to match afterwards. The NEON and
/// SSE2 paths always move exactly 16 bytes regardless of `len`. The portable fallback moves
/// two possibly-overlapping 8-byte words: bytes `0..8` and `len - 8..len` (or just `0..8`
/// when `len <= 8`).
///
/// # Safety
///
/// `src` must be valid for reads and `dst` valid for writes of at least 16 bytes, because
/// the vector paths copy 16 bytes unconditionally.
#[inline(always)]
unsafe fn single_op_copy_16(src: *const u8, dst: *mut u8, len: usize) {
    debug_assert!(len <= 16);
    // NEON is a compile-time feature here, so this path always returns when compiled in.
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    unsafe {
        let v: uint8x16_t = vld1q_u8(src);
        vst1q_u8(dst, v);
        return;
    }
    // With `std`, SSE2 availability is checked at runtime (cached by `detect_x86_caps`).
    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    unsafe {
        if detect_x86_caps().sse2 {
            copy_sse2(src, dst, 16);
            return;
        }
    }
    // Without `std`, SSE2 is used only when enabled as a compile-time target feature.
    #[cfg(all(
        not(feature = "std"),
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "sse2"
    ))]
    unsafe {
        copy_sse2(src, dst, 16);
        return;
    }
    // Portable fallback. The allow is needed because, on some targets, one of the cfg'd
    // blocks above returns unconditionally and makes this block unreachable.
    #[allow(unreachable_code)]
    unsafe {
        let lo: u64 = src.cast::<u64>().read_unaligned();
        let hi_offset = len.saturating_sub(8);
        let hi: u64 = src.add(hi_offset).cast::<u64>().read_unaligned();
        dst.cast::<u64>().write_unaligned(lo);
        dst.add(hi_offset).cast::<u64>().write_unaligned(hi);
    }
}
/// Debug-build check that the first `_len` bytes of `_dst` equal those of `_src`.
///
/// Compiles to nothing without `debug_assertions`. Callers must pass pointers valid for
/// reads of `_len` bytes; the buffer lengths carried in the tuples are not consulted.
#[inline(always)]
fn debug_assert_eq_copy(_src: (*const u8, usize), _dst: (*mut u8, usize), _len: usize) {
    #[cfg(debug_assertions)]
    {
        let (src_ptr, _) = _src;
        let (dst_ptr, _) = _dst;
        // SAFETY: every caller has just copied `_len` bytes between these buffers, so both
        // pointers are valid for `_len` bytes of reads.
        let (copied_from, copied_to) = unsafe {
            (
                core::slice::from_raw_parts(src_ptr, _len),
                core::slice::from_raw_parts(dst_ptr.cast_const(), _len),
            )
        };
        debug_assert_eq!(copied_from, copied_to);
    }
}
/// Benchmark-only entry point that forwards straight to [`copy_bytes_overshooting`].
///
/// # Safety
///
/// Identical contract to [`copy_bytes_overshooting`]; all arguments are passed through
/// unchanged.
#[cfg(feature = "bench_internals")]
#[inline(always)]
pub(crate) unsafe fn copy_bytes_overshooting_for_bench(
    src: (*const u8, usize),
    dst: (*mut u8, usize),
    copy_at_least: usize,
) {
    // SAFETY: the caller upholds `copy_bytes_overshooting`'s contract, which this wrapper
    // shares verbatim.
    unsafe { copy_bytes_overshooting(src, dst, copy_at_least) }
}
/// Returns the chunk size in bytes that `copy_bytes_overshooting` would use for a large copy
/// on the current build and CPU, so tests can reason about which kernel is active.
///
/// Follows the dispatcher's order. On x86 it tries AVX-512 (64), then AVX2 (32), then
/// SSE2 (16). These are detected at runtime with `std`, or taken from compile-time target
/// features without it. On aarch64 it uses NEON (16). Otherwise it falls back to the
/// scalar word size.
#[cfg(test)]
#[inline]
pub(crate) fn active_chunk_size_for_tests() -> usize {
    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    {
        let caps = detect_x86_caps();
        if caps.avx512f {
            return 64;
        }
        if caps.avx2 {
            return 32;
        }
        if caps.sse2 {
            return 16;
        }
    }
    // `cfg!` yields a constant `bool`, so, unlike a run of `#[cfg] { return … }` blocks,
    // enabling several target features at once cannot trigger `unreachable_code` warnings.
    let x86_no_std = cfg!(all(
        not(feature = "std"),
        any(target_arch = "x86", target_arch = "x86_64")
    ));
    let neon = cfg!(all(target_arch = "aarch64", target_feature = "neon"));
    if x86_no_std && cfg!(target_feature = "avx512f") {
        64
    } else if x86_no_std && cfg!(target_feature = "avx2") {
        32
    } else if (x86_no_std && cfg!(target_feature = "sse2")) || neon {
        16
    } else {
        core::mem::size_of::<usize>()
    }
}
/// Copies `len` bytes from `src` to `dst` one unaligned machine word at a time.
///
/// # Safety
///
/// - `len` must be a multiple of `size_of::<usize>()`. Otherwise the last iteration
///   reads and writes past `len` bytes. This is checked in debug builds.
/// - `src` must be valid for reads and `dst` valid for writes of `len` bytes; no alignment
///   is required.
#[inline(always)]
unsafe fn copy_scalar(mut src: *const u8, mut dst: *mut u8, len: usize) {
    debug_assert!(
        len.is_multiple_of(core::mem::size_of::<usize>()),
        "copy_scalar expects len to be a multiple of the word size (dispatcher rounds up)",
    );
    // SAFETY: `src` is valid for `len` bytes, so `src + len` is at most one past the end.
    let end = unsafe { src.add(len) };
    while src < end {
        // SAFETY: `len` is a whole number of words, so each word access ends at or before
        // `end` and stays within both buffers.
        unsafe {
            dst.cast::<usize>()
                .write_unaligned(src.cast::<usize>().read_unaligned());
            src = src.add(core::mem::size_of::<usize>());
            dst = dst.add(core::mem::size_of::<usize>());
        }
    }
}
/// x86 SIMD capabilities of the running CPU.
///
/// Filled in once by `detect_x86_caps` and consulted by the copy dispatcher to choose a
/// kernel.
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
#[derive(Copy, Clone)]
struct X86Caps {
    /// AVX-512 Foundation is available (selects the 64-byte `copy_avx512` kernel).
    avx512f: bool,
    /// AVX2 is available (selects the 32-byte `copy_avx2` kernel).
    avx2: bool,
    /// SSE2 is available (selects the 16-byte `copy_sse2` kernel).
    sse2: bool,
}
/// Returns the running CPU's x86 SIMD capabilities, probing them only on the first call.
///
/// The probe result is stored in a process-wide `OnceLock`; subsequent calls just copy the
/// cached value out.
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
#[inline(always)]
fn detect_x86_caps() -> X86Caps {
    static CAPS: OnceLock<X86Caps> = OnceLock::new();
    let cached = CAPS.get_or_init(|| {
        let avx512f = is_x86_feature_detected!("avx512f");
        let avx2 = is_x86_feature_detected!("avx2");
        let sse2 = is_x86_feature_detected!("sse2");
        X86Caps { avx512f, avx2, sse2 }
    });
    *cached
}
/// Copies `len` bytes from `src` to `dst` in 16-byte SSE2 chunks.
///
/// # Safety
///
/// - The CPU must support SSE2.
/// - `len` must be a multiple of 16. Otherwise the last iteration reads and writes past
///   `len` bytes. This is checked in debug builds.
/// - `src` must be valid for reads and `dst` valid for writes of `len` bytes.
// Builds without `std` and without the `sse2` target feature (pre-SSE2 32-bit x86) never
// reference this kernel, so allow dead code there, as the sibling kernels do.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
#[allow(dead_code)]
unsafe fn copy_sse2(mut src: *const u8, mut dst: *mut u8, len: usize) {
    debug_assert!(
        len.is_multiple_of(16),
        "copy_sse2 expects len to be a multiple of 16 (callers round up)",
    );
    // SAFETY: `src` is valid for `len` bytes, so `src + len` is at most one past the end.
    let end = unsafe { src.add(len) };
    while src < end {
        // SAFETY: `len` is a multiple of 16, so each 16-byte access ends at or before `end`
        // and stays within both buffers; SSE2 is enabled for this function.
        unsafe {
            let v: __m128i = _mm_loadu_si128(src.cast::<__m128i>());
            _mm_storeu_si128(dst.cast::<__m128i>(), v);
            src = src.add(16);
            dst = dst.add(16);
        }
    }
}
/// Copies `len` bytes from `src` to `dst` using 32-byte AVX2 moves, two per iteration.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `len` must be a multiple of 32 (checked in debug builds); other lengths are not
///   handled correctly.
/// - `src` must be valid for reads and `dst` valid for writes of `len` bytes.
// Only referenced when AVX2 can be selected (runtime detection with `std`, or the `avx2`
// target feature without it); other builds would otherwise warn about dead code.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[allow(dead_code)]
unsafe fn copy_avx2(src: *const u8, dst: *mut u8, len: usize) {
    debug_assert!(
        len.is_multiple_of(32),
        "copy_avx2 expects len to be a multiple of 32 (dispatcher rounds up)",
    );
    // Longest prefix made of whole 64-byte (two-vector) steps.
    let paired_len = len - len % 64;
    let mut offset = 0usize;
    while offset < paired_len {
        // SAFETY: `offset + 64 <= paired_len <= len`, inside both buffers.
        unsafe {
            let first: __m256i = _mm256_loadu_si256(src.add(offset).cast::<__m256i>());
            let second: __m256i = _mm256_loadu_si256(src.add(offset + 32).cast::<__m256i>());
            _mm256_storeu_si256(dst.add(offset).cast::<__m256i>(), first);
            _mm256_storeu_si256(dst.add(offset + 32).cast::<__m256i>(), second);
        }
        offset += 64;
    }
    // With `len` a multiple of 32, at most one 32-byte vector remains.
    if offset < len {
        // SAFETY: `offset + 32 <= len` because `len` is a multiple of 32.
        unsafe {
            let last: __m256i = _mm256_loadu_si256(src.add(offset).cast::<__m256i>());
            _mm256_storeu_si256(dst.add(offset).cast::<__m256i>(), last);
        }
    }
}
/// Copies `len` bytes from `src` to `dst` in 64-byte AVX-512 chunks.
///
/// # Safety
///
/// - The CPU must support AVX-512F.
/// - `len` must be a multiple of 64. Otherwise the last iteration reads and writes past
///   `len` bytes. This is checked in debug builds.
/// - `src` must be valid for reads and `dst` valid for writes of `len` bytes.
// Only referenced when AVX-512 can be selected (runtime detection with `std`, or the
// `avx512f` target feature without it); other builds would otherwise warn about dead code.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f")]
#[allow(dead_code)]
unsafe fn copy_avx512(mut src: *const u8, mut dst: *mut u8, len: usize) {
    debug_assert!(
        len.is_multiple_of(64),
        "copy_avx512 expects len to be a multiple of 64 (dispatcher rounds up)",
    );
    // SAFETY: `src` is valid for `len` bytes, so `src + len` is at most one past the end.
    let end = unsafe { src.add(len) };
    while src < end {
        // SAFETY: `len` is a multiple of 64, so each 64-byte access ends at or before `end`
        // and stays within both buffers; AVX-512F is enabled for this function.
        unsafe {
            let v: __m512i = _mm512_loadu_si512(src.cast::<__m512i>());
            _mm512_storeu_si512(dst.cast::<__m512i>(), v);
            src = src.add(64);
            dst = dst.add(64);
        }
    }
}
/// Copies `len` bytes from `src` to `dst` in 16-byte NEON chunks.
///
/// # Safety
///
/// - `len` must be a multiple of 16. Otherwise the last iteration reads and writes past
///   `len` bytes. This is checked in debug builds.
/// - `src` must be valid for reads and `dst` valid for writes of `len` bytes.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[inline(always)]
unsafe fn copy_neon(mut src: *const u8, mut dst: *mut u8, len: usize) {
    debug_assert!(
        len.is_multiple_of(16),
        "copy_neon expects len to be a multiple of 16 (dispatcher rounds up)",
    );
    // SAFETY: `src` is valid for `len` bytes, so `src + len` is at most one past the end.
    let end = unsafe { src.add(len) };
    while src < end {
        // SAFETY: `len` is a multiple of 16, so each 16-byte access ends at or before `end`
        // and stays within both buffers; NEON is enabled at compile time for this item.
        unsafe {
            let v: uint8x16_t = vld1q_u8(src);
            vst1q_u8(dst, v);
            src = src.add(16);
            dst = dst.add(16);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn copy_bytes_overshooting_zero_len_is_noop() {
        let src = [1_u8, 2, 3, 4];
        let mut dst = [9_u8, 9, 9, 9];
        unsafe {
            copy_bytes_overshooting((src.as_ptr(), src.len()), (dst.as_mut_ptr(), dst.len()), 0);
        }
        assert_eq!(dst, [9_u8, 9, 9, 9]);
    }

    #[test]
    fn copy_bytes_overshooting_fallback_exact_copy_when_caps_are_tight() {
        let len = 65;
        let src = vec![5_u8; len];
        let mut dst = vec![0_u8; len];
        unsafe {
            copy_bytes_overshooting((src.as_ptr(), len), (dst.as_mut_ptr(), len), len);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn copy_bytes_overshooting_single_op_small() {
        for len in 1..=16 {
            let mut src = [0u8; 32];
            for (i, b) in src.iter_mut().enumerate() {
                *b = i as u8;
            }
            let mut dst = [0u8; 32];
            unsafe {
                copy_bytes_overshooting((src.as_ptr(), 32), (dst.as_mut_ptr(), 32), len);
            }
            assert_eq!(&dst[..len], &src[..len], "len={len}");
        }
    }

    /// Sweeps lengths across every dispatch threshold with varying spare capacity. This
    /// exercises the single 16-byte op and the exact 1..=8, 9..=16 and 17..=32 paths. It
    /// also covers the active SIMD kernel (with and without room to round up), the scalar
    /// loop and the exact fallback. Guard bytes check that nothing is written past `dst.1`.
    #[test]
    fn copy_bytes_overshooting_sweep_lengths_and_slack() {
        use alloc::vec::Vec;
        const GUARD: usize = 64;
        for len in 0..=160_usize {
            for slack in [0_usize, 1, 7, 15, 31, 63, 64] {
                let cap = len + slack;
                let src: Vec<u8> = (0..cap).map(|i| (i * 7 + 3) as u8).collect();
                let mut dst = vec![0xAA_u8; cap + GUARD];
                unsafe {
                    copy_bytes_overshooting((src.as_ptr(), cap), (dst.as_mut_ptr(), cap), len);
                }
                assert_eq!(&dst[..len], &src[..len], "len={len} slack={slack}");
                assert!(
                    dst[cap..].iter().all(|&b| b == 0xAA),
                    "wrote past dst.1: len={len} slack={slack}"
                );
            }
        }
    }

    #[test]
    fn copy_scalar_copies_requested_bytes() {
        let src = [11_u8, 12, 13, 14, 15, 16, 17, 18];
        let mut dst = [0_u8; 8];
        unsafe { copy_scalar(src.as_ptr(), dst.as_mut_ptr(), src.len()) };
        assert_eq!(dst, src);
    }

    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    #[test]
    fn copy_sse2_copies_full_chunk_when_available() {
        if !std::arch::is_x86_feature_detected!("sse2") {
            return;
        }
        let src = [7_u8; 16];
        let mut dst = [0_u8; 16];
        unsafe { copy_sse2(src.as_ptr(), dst.as_mut_ptr(), 16) };
        assert_eq!(dst, src);
    }

    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    #[test]
    fn copy_avx2_copies_full_chunk_when_available() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let src = [8_u8; 32];
        let mut dst = [0_u8; 32];
        unsafe { copy_avx2(src.as_ptr(), dst.as_mut_ptr(), 32) };
        assert_eq!(dst, src);
    }

    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    #[test]
    fn copy_avx2_copies_full_unroll2_iteration() {
        use alloc::vec::Vec;
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let src: Vec<u8> = (0..64u8).collect();
        let mut dst = [0_u8; 64];
        unsafe { copy_avx2(src.as_ptr(), dst.as_mut_ptr(), 64) };
        assert_eq!(&dst[..], &src[..]);
    }

    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    #[test]
    fn copy_avx2_copies_unroll2_loop_plus_residual_tail() {
        use alloc::vec::Vec;
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let src: Vec<u8> = (0..96u8).collect();
        let mut dst = [0_u8; 96];
        unsafe { copy_avx2(src.as_ptr(), dst.as_mut_ptr(), 96) };
        assert_eq!(&dst[..], &src[..]);
        assert_eq!(&dst[60..68], &[60, 61, 62, 63, 64, 65, 66, 67]);
    }

    #[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
    #[test]
    fn copy_avx512_copies_full_chunk_when_available() {
        if !std::arch::is_x86_feature_detected!("avx512f") {
            return;
        }
        let src = [9_u8; 64];
        let mut dst = [0_u8; 64];
        unsafe { copy_avx512(src.as_ptr(), dst.as_mut_ptr(), 64) };
        assert_eq!(dst, src);
    }
}