#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "arm")]
use std::arch::arm::*;
/// Sets `len` bytes starting at `ptr` to zero.
///
/// At runtime this picks the first vector path that the CPU supports and whose register width
/// fits within `len`:
/// - x86/x86_64: AVX-512F (64 bytes), then AVX2 (32), then SSE2 (16).
/// - AArch64/ARM: NEON (16).
///
/// Everything else falls back to [`std::ptr::write_bytes`]. That covers buffers shorter than one
/// register and CPUs without those features.
///
/// # Pointer contract
///
/// This function is not marked `unsafe`, yet it writes through a raw pointer.
/// - When `len > 0`, the caller must guarantee that `ptr` is non-null and valid for writes of
///   `len` bytes. Otherwise the behavior is undefined.
/// - When `len == 0`, the function returns immediately and never uses `ptr`.
/// - Debug builds panic on a null `ptr` with a non-zero `len`.
// The safe signature is kept for backward compatibility with existing callers. Without the allow,
// Clippy's deny-by-default `not_unsafe_ptr_arg_deref` lint rejects this function.
// NOTE(review): make this an `unsafe fn` at the next breaking release.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[inline(always)]
pub fn fill_zero(ptr: *mut u8, len: usize) {
    if len == 0 {
        return;
    }
    debug_assert!(!ptr.is_null(), "fill_zero: null pointer with non-zero length");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("avx512f") && len >= 64 {
            // SAFETY: AVX-512F was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_zero_avx512(ptr, len) };
            return;
        }
        if is_x86_feature_detected!("avx2") && len >= 32 {
            // SAFETY: AVX2 was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_zero_avx2(ptr, len) };
            return;
        }
        if is_x86_feature_detected!("sse2") && len >= 16 {
            // SAFETY: SSE2 was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_zero_sse2(ptr, len) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") && len >= 16 {
            // SAFETY: NEON was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_zero_neon(ptr, len) };
            return;
        }
    }
    #[cfg(target_arch = "arm")]
    {
        if is_arm_feature_detected!("neon") && len >= 16 {
            // SAFETY: NEON was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_zero_neon(ptr, len) };
            return;
        }
    }
    // SAFETY: the caller guarantees `ptr` is non-null and valid for `len` byte writes.
    unsafe { std::ptr::write_bytes(ptr, 0, len) };
}
/// Sets `len` bytes starting at `ptr` to `value`.
///
/// At runtime this picks the first vector path that the CPU supports and whose register width
/// fits within `len`:
/// - x86/x86_64: AVX-512F (64 bytes), then AVX2 (32), then SSE2 (16).
/// - AArch64/ARM: NEON (16).
///
/// Everything else falls back to [`std::ptr::write_bytes`].
///
/// # Pointer contract
///
/// This function is not marked `unsafe`, yet it writes through a raw pointer.
/// - When `len > 0`, the caller must guarantee that `ptr` is non-null and valid for writes of
///   `len` bytes. Otherwise the behavior is undefined.
/// - When `len == 0`, the function returns immediately and never uses `ptr`.
/// - Debug builds panic on a null `ptr` with a non-zero `len`.
// The safe signature is kept for backward compatibility with existing callers. Without the allow,
// Clippy's deny-by-default `not_unsafe_ptr_arg_deref` lint rejects this function.
// NOTE(review): make this an `unsafe fn` at the next breaking release.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[inline(always)]
pub fn fill_byte(ptr: *mut u8, value: u8, len: usize) {
    if len == 0 {
        return;
    }
    debug_assert!(!ptr.is_null(), "fill_byte: null pointer with non-zero length");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("avx512f") && len >= 64 {
            // SAFETY: AVX-512F was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_byte_avx512(ptr, value, len) };
            return;
        }
        if is_x86_feature_detected!("avx2") && len >= 32 {
            // SAFETY: AVX2 was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_byte_avx2(ptr, value, len) };
            return;
        }
        if is_x86_feature_detected!("sse2") && len >= 16 {
            // SAFETY: SSE2 was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_byte_sse2(ptr, value, len) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") && len >= 16 {
            // SAFETY: NEON was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_byte_neon(ptr, value, len) };
            return;
        }
    }
    #[cfg(target_arch = "arm")]
    {
        if is_arm_feature_detected!("neon") && len >= 16 {
            // SAFETY: NEON was detected at runtime; the caller guarantees `len` writable bytes.
            unsafe { fill_byte_neon(ptr, value, len) };
            return;
        }
    }
    // SAFETY: the caller guarantees `ptr` is non-null and valid for `len` byte writes.
    unsafe { std::ptr::write_bytes(ptr, value, len) };
}
/// Writes `value` into `count` consecutive `u32` slots starting at `ptr`.
///
/// At runtime this picks the first vector path that the CPU supports and whose register fits
/// within `count` elements:
/// - x86/x86_64: AVX-512F (16 elements), then AVX2 (8), then SSE2 (4).
/// - AArch64/ARM: NEON (4).
///
/// Shorter inputs, and CPUs without those features, use a scalar loop.
///
/// # Pointer contract
///
/// This function is not marked `unsafe`, yet it writes through a raw pointer.
/// - When `count > 0`, the caller must guarantee that `ptr` is non-null, aligned for `u32`, and
///   valid for writes of `count` elements. Otherwise the behavior is undefined.
/// - When `count == 0`, the function returns immediately and never uses `ptr`.
/// - Debug builds panic on a null or misaligned `ptr`.
// The safe signature is kept for backward compatibility with existing callers. Without the allow,
// Clippy's deny-by-default `not_unsafe_ptr_arg_deref` lint rejects this function.
// NOTE(review): make this an `unsafe fn` at the next breaking release.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[inline(always)]
pub fn fill_u32(ptr: *mut u32, value: u32, count: usize) {
    if count == 0 {
        return;
    }
    debug_assert!(!ptr.is_null(), "fill_u32: null pointer with non-zero count");
    debug_assert!(ptr.is_aligned(), "fill_u32: pointer not aligned for u32");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("avx512f") && count >= 16 {
            // SAFETY: AVX-512F detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u32_avx512(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("avx2") && count >= 8 {
            // SAFETY: AVX2 detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u32_avx2(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("sse2") && count >= 4 {
            // SAFETY: SSE2 detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u32_sse2(ptr, value, count) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") && count >= 4 {
            // SAFETY: NEON detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u32_neon(ptr, value, count) };
            return;
        }
    }
    // Mirrors the ARM branch in `fill_zero`/`fill_byte`. Without it, `fill_u32_neon` is compiled
    // on 32-bit ARM but never called.
    #[cfg(target_arch = "arm")]
    {
        if is_arm_feature_detected!("neon") && count >= 4 {
            // SAFETY: NEON detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u32_neon(ptr, value, count) };
            return;
        }
    }
    for i in 0..count {
        // SAFETY: `i < count`, and the caller guarantees `count` writable, aligned slots.
        unsafe { ptr.add(i).write(value) };
    }
}
/// Writes `value` into `count` consecutive `u64` slots starting at `ptr`.
///
/// At runtime this picks the first vector path that the CPU supports and whose register fits
/// within `count` elements:
/// - x86/x86_64: AVX-512F (8 elements), then AVX2 (4), then SSE2 (2).
/// - AArch64/ARM: NEON (2).
///
/// Shorter inputs, and CPUs without those features, use a scalar loop.
///
/// # Pointer contract
///
/// This function is not marked `unsafe`, yet it writes through a raw pointer.
/// - When `count > 0`, the caller must guarantee that `ptr` is non-null, aligned for `u64`, and
///   valid for writes of `count` elements. Otherwise the behavior is undefined.
/// - When `count == 0`, the function returns immediately and never uses `ptr`.
/// - Debug builds panic on a null or misaligned `ptr`.
// The safe signature is kept for backward compatibility with existing callers. Without the allow,
// Clippy's deny-by-default `not_unsafe_ptr_arg_deref` lint rejects this function.
// NOTE(review): make this an `unsafe fn` at the next breaking release.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
#[inline(always)]
pub fn fill_u64(ptr: *mut u64, value: u64, count: usize) {
    if count == 0 {
        return;
    }
    debug_assert!(!ptr.is_null(), "fill_u64: null pointer with non-zero count");
    debug_assert!(ptr.is_aligned(), "fill_u64: pointer not aligned for u64");
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if is_x86_feature_detected!("avx512f") && count >= 8 {
            // SAFETY: AVX-512F detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u64_avx512(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("avx2") && count >= 4 {
            // SAFETY: AVX2 detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u64_avx2(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("sse2") && count >= 2 {
            // SAFETY: SSE2 detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u64_sse2(ptr, value, count) };
            return;
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if is_aarch64_feature_detected!("neon") && count >= 2 {
            // SAFETY: NEON detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u64_neon(ptr, value, count) };
            return;
        }
    }
    // Mirrors the ARM branch in `fill_zero`/`fill_byte`. Without it, `fill_u64_neon` is compiled
    // on 32-bit ARM but never called.
    #[cfg(target_arch = "arm")]
    {
        if is_arm_feature_detected!("neon") && count >= 2 {
            // SAFETY: NEON detected; the caller guarantees `count` writable, aligned slots.
            unsafe { fill_u64_neon(ptr, value, count) };
            return;
        }
    }
    for i in 0..count {
        // SAFETY: `i < count`, and the caller guarantees `count` writable, aligned slots.
        unsafe { ptr.add(i).write(value) };
    }
}
/// Writes `count` copies of `value` into consecutive `f32` slots starting at `ptr`.
///
/// The work is delegated to `fill_u32` using the IEEE-754 bit pattern of `value`. Every stored
/// element is therefore bit-for-bit identical to `value`, including signed zeros and NaN payloads.
///
/// When `count > 0`, `ptr` must be non-null, aligned for `f32`, and valid for writes of `count`
/// elements. When `count == 0`, nothing is written.
#[inline(always)]
pub fn fill_f32(ptr: *mut f32, value: f32, count: usize) {
    let bits = f32::to_bits(value);
    fill_u32(ptr.cast::<u32>(), bits, count);
}
/// Writes `count` copies of `value` into consecutive `f64` slots starting at `ptr`.
///
/// The work is delegated to `fill_u64` using the IEEE-754 bit pattern of `value`. Every stored
/// element is therefore bit-for-bit identical to `value`, including signed zeros and NaN payloads.
///
/// When `count > 0`, `ptr` must be non-null, aligned for `f64`, and valid for writes of `count`
/// elements. When `count == 0`, nothing is written.
#[inline(always)]
pub fn fill_f64(ptr: *mut f64, value: f64, count: usize) {
    let bits = f64::to_bits(value);
    fill_u64(ptr.cast::<u64>(), bits, count);
}
/// Zeroes `len` bytes at `ptr` using unaligned 32-byte AVX2 stores.
///
/// Whole 32-byte chunks go through `_mm256_storeu_si256`. The trailing `len % 32` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support AVX2, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_zero_avx2(ptr: *mut u8, len: usize) {
    const WIDTH: usize = 32;
    let zeros = _mm256_setzero_si256();
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm256_storeu_si256(ptr.add(chunk * WIDTH) as *mut __m256i, zeros);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), 0, tail);
    }
}
/// Writes `value` to `len` bytes at `ptr` using unaligned 32-byte AVX2 stores.
///
/// Whole 32-byte chunks go through `_mm256_storeu_si256`. The trailing `len % 32` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support AVX2, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_byte_avx2(ptr: *mut u8, value: u8, len: usize) {
    const WIDTH: usize = 32;
    let splat = _mm256_set1_epi8(value as i8);
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm256_storeu_si256(ptr.add(chunk * WIDTH) as *mut __m256i, splat);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), value, tail);
    }
}
/// Writes `value` into `count` consecutive `u32` slots at `ptr`, eight at a time with AVX2.
///
/// Whole groups of eight use unaligned 256-bit stores. The leftover `count % 8` elements are
/// written one by one.
///
/// # Safety
///
/// The CPU must support AVX2, and `ptr` must be aligned for `u32` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_u32_avx2(ptr: *mut u32, value: u32, count: usize) {
    const LANES: usize = 8;
    let splat = _mm256_set1_epi32(value as i32);
    let groups = count / LANES;
    for group in 0..groups {
        _mm256_storeu_si256(ptr.add(group * LANES) as *mut __m256i, splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Writes `value` into `count` consecutive `u64` slots at `ptr`, four at a time with AVX2.
///
/// Whole groups of four use unaligned 256-bit stores. The leftover `count % 4` elements are
/// written one by one.
///
/// # Safety
///
/// The CPU must support AVX2, and `ptr` must be aligned for `u64` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_u64_avx2(ptr: *mut u64, value: u64, count: usize) {
    const LANES: usize = 4;
    let splat = _mm256_set1_epi64x(value as i64);
    let groups = count / LANES;
    for group in 0..groups {
        _mm256_storeu_si256(ptr.add(group * LANES) as *mut __m256i, splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Zeroes `len` bytes at `ptr` using unaligned 16-byte SSE2 stores.
///
/// Whole 16-byte chunks go through `_mm_storeu_si128`. The trailing `len % 16` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support SSE2, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_zero_sse2(ptr: *mut u8, len: usize) {
    const WIDTH: usize = 16;
    let zeros = _mm_setzero_si128();
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm_storeu_si128(ptr.add(chunk * WIDTH) as *mut __m128i, zeros);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), 0, tail);
    }
}
/// Writes `value` to `len` bytes at `ptr` using unaligned 16-byte SSE2 stores.
///
/// Whole 16-byte chunks go through `_mm_storeu_si128`. The trailing `len % 16` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support SSE2, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_byte_sse2(ptr: *mut u8, value: u8, len: usize) {
    const WIDTH: usize = 16;
    let splat = _mm_set1_epi8(value as i8);
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm_storeu_si128(ptr.add(chunk * WIDTH) as *mut __m128i, splat);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), value, tail);
    }
}
/// Writes `value` into `count` consecutive `u32` slots at `ptr`, four at a time with SSE2.
///
/// Whole groups of four use unaligned 128-bit stores. The leftover `count % 4` elements are
/// written one by one.
///
/// # Safety
///
/// The CPU must support SSE2, and `ptr` must be aligned for `u32` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_u32_sse2(ptr: *mut u32, value: u32, count: usize) {
    const LANES: usize = 4;
    let splat = _mm_set1_epi32(value as i32);
    let groups = count / LANES;
    for group in 0..groups {
        _mm_storeu_si128(ptr.add(group * LANES) as *mut __m128i, splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Writes `value` into `count` consecutive `u64` slots at `ptr`, two at a time with SSE2.
///
/// Whole pairs use unaligned 128-bit stores. An odd final element is written on its own.
///
/// # Safety
///
/// The CPU must support SSE2, and `ptr` must be aligned for `u64` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_u64_sse2(ptr: *mut u64, value: u64, count: usize) {
    const LANES: usize = 2;
    let splat = _mm_set1_epi64x(value as i64);
    let pairs = count / LANES;
    for pair in 0..pairs {
        _mm_storeu_si128(ptr.add(pair * LANES) as *mut __m128i, splat);
    }
    if count % LANES == 1 {
        ptr.add(pairs * LANES).write(value);
    }
}
/// Zeroes `len` bytes at `ptr` using unaligned 64-byte AVX-512F stores.
///
/// Whole 64-byte chunks go through `_mm512_storeu_si512`. The trailing `len % 64` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx512f")]
unsafe fn fill_zero_avx512(ptr: *mut u8, len: usize) {
    const WIDTH: usize = 64;
    let zeros = _mm512_setzero_si512();
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm512_storeu_si512(ptr.add(chunk * WIDTH) as *mut __m512i, zeros);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), 0, tail);
    }
}
/// Writes `value` to `len` bytes at `ptr` using unaligned 64-byte AVX-512F stores.
///
/// Whole 64-byte chunks go through `_mm512_storeu_si512`. The trailing `len % 64` bytes are
/// written with `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx512f")]
unsafe fn fill_byte_avx512(ptr: *mut u8, value: u8, len: usize) {
    const WIDTH: usize = 64;
    let splat = _mm512_set1_epi8(value as i8);
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        _mm512_storeu_si512(ptr.add(chunk * WIDTH) as *mut __m512i, splat);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), value, tail);
    }
}
/// Writes `value` into `count` consecutive `u32` slots at `ptr`, sixteen at a time with AVX-512F.
///
/// Whole groups of sixteen use unaligned 512-bit stores. The leftover `count % 16` elements are
/// written one by one.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `ptr` must be aligned for `u32` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx512f")]
unsafe fn fill_u32_avx512(ptr: *mut u32, value: u32, count: usize) {
    const LANES: usize = 16;
    let splat = _mm512_set1_epi32(value as i32);
    let groups = count / LANES;
    for group in 0..groups {
        _mm512_storeu_si512(ptr.add(group * LANES) as *mut __m512i, splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Writes `value` into `count` consecutive `u64` slots at `ptr`, eight at a time with AVX-512F.
///
/// Whole groups of eight use unaligned 512-bit stores. The leftover `count % 8` elements are
/// written one by one.
///
/// # Safety
///
/// The CPU must support AVX-512F, and `ptr` must be aligned for `u64` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx512f")]
unsafe fn fill_u64_avx512(ptr: *mut u64, value: u64, count: usize) {
    const LANES: usize = 8;
    let splat = _mm512_set1_epi64(value as i64);
    let groups = count / LANES;
    for group in 0..groups {
        _mm512_storeu_si512(ptr.add(group * LANES) as *mut __m512i, splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Zeroes `len` bytes at `ptr` using 16-byte NEON stores.
///
/// Whole 16-byte chunks go through `vst1q_u8`. The trailing `len % 16` bytes are written with
/// `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support NEON, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
#[target_feature(enable = "neon")]
unsafe fn fill_zero_neon(ptr: *mut u8, len: usize) {
    const WIDTH: usize = 16;
    let zeros = vdupq_n_u8(0);
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        vst1q_u8(ptr.add(chunk * WIDTH), zeros);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), 0, tail);
    }
}
/// Writes `value` to `len` bytes at `ptr` using 16-byte NEON stores.
///
/// Whole 16-byte chunks go through `vst1q_u8`. The trailing `len % 16` bytes are written with
/// `std::ptr::write_bytes`.
///
/// # Safety
///
/// The CPU must support NEON, and `ptr` must be valid for writes of `len` bytes.
#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
#[target_feature(enable = "neon")]
unsafe fn fill_byte_neon(ptr: *mut u8, value: u8, len: usize) {
    const WIDTH: usize = 16;
    let splat = vdupq_n_u8(value);
    let chunks = len / WIDTH;
    for chunk in 0..chunks {
        vst1q_u8(ptr.add(chunk * WIDTH), splat);
    }
    let tail = len % WIDTH;
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * WIDTH), value, tail);
    }
}
/// Writes `value` into `count` consecutive `u32` slots at `ptr`, four at a time with NEON.
///
/// Whole groups of four use `vst1q_u32`. The leftover `count % 4` elements are written one by
/// one.
///
/// # Safety
///
/// The CPU must support NEON, and `ptr` must be aligned for `u32` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
#[target_feature(enable = "neon")]
unsafe fn fill_u32_neon(ptr: *mut u32, value: u32, count: usize) {
    const LANES: usize = 4;
    let splat = vdupq_n_u32(value);
    let groups = count / LANES;
    for group in 0..groups {
        vst1q_u32(ptr.add(group * LANES), splat);
    }
    for idx in (groups * LANES)..count {
        ptr.add(idx).write(value);
    }
}
/// Writes `value` into `count` consecutive `u64` slots at `ptr`, two at a time with NEON.
///
/// Whole pairs use `vst1q_u64`. An odd final element is written on its own.
///
/// # Safety
///
/// The CPU must support NEON, and `ptr` must be aligned for `u64` and valid for writes of
/// `count` elements.
#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
#[target_feature(enable = "neon")]
unsafe fn fill_u64_neon(ptr: *mut u64, value: u64, count: usize) {
    const LANES: usize = 2;
    let splat = vdupq_n_u64(value);
    let pairs = count / LANES;
    for pair in 0..pairs {
        vst1q_u64(ptr.add(pair * LANES), splat);
    }
    if count % LANES == 1 {
        ptr.add(pairs * LANES).write(value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer lengths that reach every dispatch path: empty, scalar-only, exactly one 16-, 32- or
    /// 64-byte register, and lengths that leave a ragged tail after the vector loop.
    const LENGTHS: [usize; 14] = [0, 1, 3, 7, 15, 16, 17, 31, 32, 33, 63, 64, 65, 100];

    /// Runs `fill` on the interior of a guarded buffer for every length in `LENGTHS`.
    ///
    /// Checks that every interior element equals `expected` and that the single guard element on
    /// each side still equals `guard`. The guards catch out-of-bounds writes. Equality is decided
    /// by `same`, so floats can be compared bit-for-bit.
    fn check<T, S, F>(guard: T, expected: T, same: S, fill: F)
    where
        T: Copy,
        S: Fn(T, T) -> bool,
        F: Fn(*mut T, usize),
    {
        for &len in LENGTHS.iter() {
            let mut buf = vec![guard; len + 2];
            fill(buf[1..].as_mut_ptr(), len);
            assert!(same(buf[0], guard), "underrun at len {}", len);
            assert!(
                buf[1..len + 1].iter().all(|&v| same(v, expected)),
                "wrong fill at len {}",
                len
            );
            assert!(same(buf[len + 1], guard), "overrun at len {}", len);
        }
    }

    #[test]
    fn test_fill_zero() {
        check(0xFFu8, 0u8, |a, b| a == b, fill_zero);
    }

    #[test]
    fn test_fill_byte() {
        check(0u8, 0xABu8, |a, b| a == b, |p, n| fill_byte(p, 0xAB, n));
    }

    #[test]
    fn test_fill_u32() {
        check(0u32, 0xDEAD_BEEFu32, |a, b| a == b, |p, n| fill_u32(p, 0xDEAD_BEEF, n));
    }

    #[test]
    fn test_fill_u64() {
        let value = 0xCAFE_BABE_DEAD_BEEFu64;
        check(0u64, value, |a, b| a == b, |p, n| fill_u64(p, value, n));
    }

    #[test]
    fn test_fill_f32() {
        // The fill copies raw bits, so compare bit patterns: this also covers -0.0 and NaN.
        for &value in [1.5f32, -0.0, f32::INFINITY, f32::NAN].iter() {
            check(
                7.0f32,
                value,
                |a: f32, b: f32| a.to_bits() == b.to_bits(),
                |p, n| fill_f32(p, value, n),
            );
        }
    }

    #[test]
    fn test_fill_f64() {
        for &value in [1.5f64, -0.0, f64::NEG_INFINITY, f64::NAN].iter() {
            check(
                7.0f64,
                value,
                |a: f64, b: f64| a.to_bits() == b.to_bits(),
                |p, n| fill_f64(p, value, n),
            );
        }
    }

    #[test]
    fn zero_length_never_touches_pointer() {
        // Every entry point returns before using `ptr` when the length is zero.
        fill_zero(std::ptr::null_mut(), 0);
        fill_byte(std::ptr::null_mut(), 1, 0);
        fill_u32(std::ptr::null_mut(), 1, 0);
        fill_u64(std::ptr::null_mut(), 1, 0);
        fill_f32(std::ptr::null_mut(), 1.0, 0);
        fill_f64(std::ptr::null_mut(), 1.0, 0);
    }
}