// memkit 0.1.1-beta.1
//
// Deterministic, intent-driven memory allocation for systems requiring
// predictable performance.
//! SIMD-accelerated memory operations.
//!
//! Uses platform intrinsics when available for fast memory initialization.

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86")]
use std::arch::x86::*;

/// Fill `len` bytes starting at `ptr` with zeros using the fastest
/// available method (AVX2, then SSE2, then a plain `ptr::write_bytes`).
///
/// NOTE(review): although declared safe, this function writes through a raw
/// pointer, so the *caller* must guarantee `ptr` is valid for `len` byte
/// writes — a safe fn cannot enforce that. Consider making this an
/// `unsafe fn` with a `# Safety` section in a future breaking release.
#[inline(always)]
pub fn fill_zero(ptr: *mut u8, len: usize) {
    // A zero-length fill is a no-op regardless of pointer validity.
    if len == 0 {
        return;
    }
    // A null pointer with a non-zero length is always a caller bug.
    debug_assert!(!ptr.is_null(), "fill_zero called with null ptr and len > 0");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        // Runtime dispatch: prefer the widest supported vector, but only when
        // the buffer holds at least one full vector.
        if is_x86_feature_detected!("avx2") && len >= 32 {
            // SAFETY: AVX2 support verified at runtime; caller guarantees
            // `ptr` is valid for `len` bytes.
            unsafe { fill_zero_avx2(ptr, len) };
            return;
        }
        if is_x86_feature_detected!("sse2") && len >= 16 {
            // SAFETY: SSE2 support verified at runtime; caller guarantees
            // `ptr` is valid for `len` bytes.
            unsafe { fill_zero_sse2(ptr, len) };
            return;
        }
    }

    // Portable fallback — compiles down to an ordinary memset.
    // SAFETY: caller guarantees `ptr` is valid for `len` writes.
    unsafe { std::ptr::write_bytes(ptr, 0, len) };
}

/// Fill `len` bytes starting at `ptr` with `value` using the fastest
/// available method (AVX2, then SSE2, then a plain `ptr::write_bytes`).
///
/// NOTE(review): although declared safe, this function writes through a raw
/// pointer, so the *caller* must guarantee `ptr` is valid for `len` byte
/// writes. Consider an `unsafe fn` with a `# Safety` section in a future
/// breaking release.
#[inline(always)]
pub fn fill_byte(ptr: *mut u8, value: u8, len: usize) {
    // A zero-length fill is a no-op regardless of pointer validity.
    if len == 0 {
        return;
    }
    // A null pointer with a non-zero length is always a caller bug.
    debug_assert!(!ptr.is_null(), "fill_byte called with null ptr and len > 0");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        // Runtime dispatch: widest supported vector first, gated on the buffer
        // holding at least one full vector.
        if is_x86_feature_detected!("avx2") && len >= 32 {
            // SAFETY: AVX2 verified at runtime; caller guarantees `ptr` is
            // valid for `len` bytes.
            unsafe { fill_byte_avx2(ptr, value, len) };
            return;
        }
        if is_x86_feature_detected!("sse2") && len >= 16 {
            // SAFETY: SSE2 verified at runtime; caller guarantees `ptr` is
            // valid for `len` bytes.
            unsafe { fill_byte_sse2(ptr, value, len) };
            return;
        }
    }

    // Portable fallback — compiles down to an ordinary memset.
    // SAFETY: caller guarantees `ptr` is valid for `len` writes.
    unsafe { std::ptr::write_bytes(ptr, value, len) };
}

/// Fill `count` `u32` elements starting at `ptr` with `value` using SIMD
/// when available, falling back to a scalar loop otherwise.
///
/// NOTE(review): although declared safe, this function writes through a raw
/// pointer, so the *caller* must guarantee `ptr` is valid (and suitably
/// aligned for `u32`) for `count` element writes. Consider an `unsafe fn`
/// with a `# Safety` section in a future breaking release.
#[inline(always)]
pub fn fill_u32(ptr: *mut u32, value: u32, count: usize) {
    // A zero-length fill is a no-op regardless of pointer validity.
    if count == 0 {
        return;
    }
    // A null pointer with a non-zero count is always a caller bug.
    debug_assert!(!ptr.is_null(), "fill_u32 called with null ptr and count > 0");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        // Thresholds match one full vector: 8 u32s (AVX2) / 4 u32s (SSE2).
        if is_x86_feature_detected!("avx2") && count >= 8 {
            // SAFETY: AVX2 verified at runtime; caller guarantees `ptr` is
            // valid for `count` u32 writes.
            unsafe { fill_u32_avx2(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("sse2") && count >= 4 {
            // SAFETY: SSE2 verified at runtime; caller guarantees `ptr` is
            // valid for `count` u32 writes.
            unsafe { fill_u32_sse2(ptr, value, count) };
            return;
        }
    }

    // Portable scalar fallback.
    for i in 0..count {
        // SAFETY: caller guarantees `ptr` is valid for `count` elements,
        // and `i < count`.
        unsafe { ptr.add(i).write(value) };
    }
}

/// Fill `count` `u64` elements starting at `ptr` with `value` using SIMD
/// when available, falling back to a scalar loop otherwise.
///
/// NOTE(review): although declared safe, this function writes through a raw
/// pointer, so the *caller* must guarantee `ptr` is valid (and suitably
/// aligned for `u64`) for `count` element writes. Consider an `unsafe fn`
/// with a `# Safety` section in a future breaking release.
#[inline(always)]
pub fn fill_u64(ptr: *mut u64, value: u64, count: usize) {
    // A zero-length fill is a no-op regardless of pointer validity.
    if count == 0 {
        return;
    }
    // A null pointer with a non-zero count is always a caller bug.
    debug_assert!(!ptr.is_null(), "fill_u64 called with null ptr and count > 0");

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        // Thresholds match one full vector: 4 u64s (AVX2) / 2 u64s (SSE2).
        if is_x86_feature_detected!("avx2") && count >= 4 {
            // SAFETY: AVX2 verified at runtime; caller guarantees `ptr` is
            // valid for `count` u64 writes.
            unsafe { fill_u64_avx2(ptr, value, count) };
            return;
        }
        if is_x86_feature_detected!("sse2") && count >= 2 {
            // SAFETY: SSE2 verified at runtime; caller guarantees `ptr` is
            // valid for `count` u64 writes.
            unsafe { fill_u64_sse2(ptr, value, count) };
            return;
        }
    }

    // Portable scalar fallback.
    for i in 0..count {
        // SAFETY: caller guarantees `ptr` is valid for `count` elements,
        // and `i < count`.
        unsafe { ptr.add(i).write(value) };
    }
}

/// Fill a slice with a repeated f32 value using SIMD.
#[inline(always)]
pub fn fill_f32(ptr: *mut f32, value: f32, count: usize) {
    // An f32 fill is bit-for-bit identical to a u32 fill of its bit pattern,
    // so delegate to the integer implementation.
    let bits = value.to_bits();
    fill_u32(ptr.cast::<u32>(), bits, count);
}

/// Fill a slice with a repeated f64 value using SIMD.
#[inline(always)]
pub fn fill_f64(ptr: *mut f64, value: f64, count: usize) {
    // An f64 fill is bit-for-bit identical to a u64 fill of its bit pattern,
    // so delegate to the integer implementation.
    let bits = value.to_bits();
    fill_u64(ptr.cast::<u64>(), bits, count);
}

// ============================================================================
// AVX2 Implementations (256-bit / 32-byte)
// ============================================================================

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_zero_avx2(ptr: *mut u8, len: usize) {
    // Number of full 32-byte chunks, plus the leftover byte count.
    let chunks = len / 32;
    let tail = len % 32;
    let zero = _mm256_setzero_si256();

    // One zeroed 256-bit store per full chunk (unaligned stores are fine).
    for i in 0..chunks {
        _mm256_storeu_si256(ptr.add(i * 32) as *mut __m256i, zero);
    }

    // Final partial chunk (< 32 bytes) via a plain memset.
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * 32), 0, tail);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_byte_avx2(ptr: *mut u8, value: u8, len: usize) {
    // Number of full 32-byte chunks, plus the leftover byte count.
    let chunks = len / 32;
    let tail = len % 32;
    // Splat `value` into all 32 byte lanes of a 256-bit register.
    let pattern = _mm256_set1_epi8(value as i8);

    for i in 0..chunks {
        _mm256_storeu_si256(ptr.add(i * 32) as *mut __m256i, pattern);
    }

    // Final partial chunk (< 32 bytes) via a plain memset.
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * 32), value, tail);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_u32_avx2(ptr: *mut u32, value: u32, count: usize) {
    // 8 u32 lanes fit in one 256-bit register.
    let chunks = count / 8;
    let pattern = _mm256_set1_epi32(value as i32);

    for i in 0..chunks {
        _mm256_storeu_si256(ptr.add(i * 8) as *mut __m256i, pattern);
    }

    // Scalar stores for the final < 8 elements.
    for i in (chunks * 8)..count {
        ptr.add(i).write(value);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "avx2")]
unsafe fn fill_u64_avx2(ptr: *mut u64, value: u64, count: usize) {
    // 4 u64 lanes fit in one 256-bit register.
    let chunks = count / 4;
    let pattern = _mm256_set1_epi64x(value as i64);

    for i in 0..chunks {
        _mm256_storeu_si256(ptr.add(i * 4) as *mut __m256i, pattern);
    }

    // Scalar stores for the final < 4 elements.
    for i in (chunks * 4)..count {
        ptr.add(i).write(value);
    }
}

// ============================================================================
// SSE2 Implementations (128-bit / 16-byte)
// ============================================================================

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_zero_sse2(ptr: *mut u8, len: usize) {
    // Number of full 16-byte chunks, plus the leftover byte count.
    let chunks = len / 16;
    let tail = len % 16;
    let zero = _mm_setzero_si128();

    // One zeroed 128-bit store per full chunk (unaligned stores are fine).
    for i in 0..chunks {
        _mm_storeu_si128(ptr.add(i * 16) as *mut __m128i, zero);
    }

    // Final partial chunk (< 16 bytes) via a plain memset.
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * 16), 0, tail);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_byte_sse2(ptr: *mut u8, value: u8, len: usize) {
    // Number of full 16-byte chunks, plus the leftover byte count.
    let chunks = len / 16;
    let tail = len % 16;
    // Splat `value` into all 16 byte lanes of a 128-bit register.
    let pattern = _mm_set1_epi8(value as i8);

    for i in 0..chunks {
        _mm_storeu_si128(ptr.add(i * 16) as *mut __m128i, pattern);
    }

    // Final partial chunk (< 16 bytes) via a plain memset.
    if tail != 0 {
        std::ptr::write_bytes(ptr.add(chunks * 16), value, tail);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_u32_sse2(ptr: *mut u32, value: u32, count: usize) {
    // 4 u32 lanes fit in one 128-bit register.
    let chunks = count / 4;
    let pattern = _mm_set1_epi32(value as i32);

    for i in 0..chunks {
        _mm_storeu_si128(ptr.add(i * 4) as *mut __m128i, pattern);
    }

    // Scalar stores for the final < 4 elements.
    for i in (chunks * 4)..count {
        ptr.add(i).write(value);
    }
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
#[target_feature(enable = "sse2")]
unsafe fn fill_u64_sse2(ptr: *mut u64, value: u64, count: usize) {
    // 2 u64 lanes fit in one 128-bit register.
    let pairs = count / 2;
    let pattern = _mm_set1_epi64x(value as i64);

    for i in 0..pairs {
        _mm_storeu_si128(ptr.add(i * 2) as *mut __m128i, pattern);
    }

    // At most one trailing element when `count` is odd.
    if count % 2 == 1 {
        ptr.add(pairs * 2).write(value);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fill_zero() {
        // 100 bytes covers the SIMD main loop plus a partial tail chunk.
        let mut buf = vec![0xFFu8; 100];
        fill_zero(buf.as_mut_ptr(), buf.len());
        assert!(buf.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_fill_zero_edge_lengths() {
        // len == 0 must be a no-op (and must not touch the pointer).
        fill_zero(std::ptr::null_mut(), 0);
        // Below every SIMD threshold: exercises the scalar fallback.
        let mut small = vec![0xFFu8; 7];
        fill_zero(small.as_mut_ptr(), small.len());
        assert!(small.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_fill_byte() {
        let mut buf = vec![0u8; 100];
        fill_byte(buf.as_mut_ptr(), 0xAB, buf.len());
        assert!(buf.iter().all(|&b| b == 0xAB));
    }

    #[test]
    fn test_fill_u32() {
        let mut buf = vec![0u32; 100];
        fill_u32(buf.as_mut_ptr(), 0xDEADBEEF, buf.len());
        assert!(buf.iter().all(|&v| v == 0xDEADBEEF));
    }

    #[test]
    fn test_fill_u32_odd_count() {
        // Odd count exercises both the vector loop and the scalar remainder.
        let mut buf = vec![0u32; 13];
        fill_u32(buf.as_mut_ptr(), 0x0123_4567, buf.len());
        assert!(buf.iter().all(|&v| v == 0x0123_4567));
    }

    #[test]
    fn test_fill_u64() {
        let mut buf = vec![0u64; 100];
        fill_u64(buf.as_mut_ptr(), 0xCAFEBABE_DEADBEEF, buf.len());
        assert!(buf.iter().all(|&v| v == 0xCAFEBABE_DEADBEEF));
    }

    #[test]
    fn test_fill_u64_odd_count() {
        let mut buf = vec![0u64; 9];
        fill_u64(buf.as_mut_ptr(), 0x1111_2222_3333_4444, buf.len());
        assert!(buf.iter().all(|&v| v == 0x1111_2222_3333_4444));
    }

    #[test]
    fn test_fill_f32() {
        let mut buf = vec![0.0f32; 100];
        fill_f32(buf.as_mut_ptr(), 3.14, buf.len());
        // The fill copies the exact bit pattern of the value, so exact
        // equality is the correct check (epsilon would mask a wrong pattern).
        assert!(buf.iter().all(|&v| v == 3.14));
    }

    #[test]
    fn test_fill_f64() {
        // Previously untested: the f64 wrapper over fill_u64.
        let mut buf = vec![0.0f64; 100];
        fill_f64(buf.as_mut_ptr(), 2.718281828, buf.len());
        assert!(buf.iter().all(|&v| v == 2.718281828));
    }
}