/// Hints that the cache line containing `ptr` will soon be read, asking for it
/// to be brought into L1 (`prfm pldl1keep`).
///
/// This is purely a performance hint: the pointer is never dereferenced, so
/// any address (including null) may be passed and nothing observable changes
/// apart from cache state.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn prefetch_read_l1(ptr: *const u8) {
    // SAFETY: PRFM is an architectural hint that does not fault on invalid
    // addresses. It writes no memory (hence `readonly`, which avoids turning the
    // hint into a compiler memory barrier), does not touch the stack, and
    // leaves the condition flags untouched.
    unsafe {
        core::arch::asm!(
            "prfm pldl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, readonly, preserves_flags)
        );
    }
}
/// Hints that the cache line containing `ptr` will be read, asking for it to
/// be brought into L2 (`prfm pldl2keep`).
///
/// This is purely a performance hint: the pointer is never dereferenced, so
/// any address may be passed and nothing observable changes apart from cache
/// state.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn prefetch_read_l2(ptr: *const u8) {
    // SAFETY: PRFM is an architectural hint that does not fault on invalid
    // addresses. It writes no memory (hence `readonly`, which avoids turning the
    // hint into a compiler memory barrier), does not touch the stack, and
    // leaves the condition flags untouched.
    unsafe {
        core::arch::asm!(
            "prfm pldl2keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, readonly, preserves_flags)
        );
    }
}
/// Hints that the cache line containing `ptr` will be read, asking for it to
/// be brought into L3 (`prfm pldl3keep`).
///
/// This is purely a performance hint: the pointer is never dereferenced, so
/// any address may be passed and nothing observable changes apart from cache
/// state.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn prefetch_read_l3(ptr: *const u8) {
    // SAFETY: PRFM is an architectural hint that does not fault on invalid
    // addresses. It writes no memory (hence `readonly`, which avoids turning the
    // hint into a compiler memory barrier), does not touch the stack, and
    // leaves the condition flags untouched.
    unsafe {
        core::arch::asm!(
            "prfm pldl3keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, readonly, preserves_flags)
        );
    }
}
/// Hints that the cache line containing `ptr` will soon be written, asking for
/// it to be brought into L1 in preparation for a store (`prfm pstl1keep`).
///
/// The parameter is `*const u8` (kept for API compatibility) even though the
/// intent is a future write. The hint itself never reads or writes through
/// `ptr`, so any address may be passed.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn prefetch_write_l1(ptr: *const u8) {
    // SAFETY: PRFM (including the store-intent PST variant) is an architectural
    // hint that does not fault and performs no store. It writes no memory (hence
    // `readonly`), does not touch the stack, and leaves the condition flags
    // untouched.
    unsafe {
        core::arch::asm!(
            "prfm pstl1keep, [{ptr}]",
            ptr = in(reg) ptr,
            options(nostack, readonly, preserves_flags)
        );
    }
}
/// Issues prefetch hints for the leading bytes of `vector`.
///
/// The line at byte offset 0 is hinted into L1. The line at offset
/// `ARM_CACHE_LINE` is hinted into L2, and the lines at offsets
/// `2 * ARM_CACHE_LINE` and `3 * ARM_CACHE_LINE` are hinted into L3. A hint is
/// only issued for an offset strictly inside the slice, so at most the first
/// `4 * ARM_CACHE_LINE` bytes are covered. An empty slice is a no-op.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub fn prefetch_vector_neon(vector: &[f32]) {
    // NOTE(review): assumes a 128-byte cache line; confirm against target cores.
    const ARM_CACHE_LINE: usize = 128;

    if vector.is_empty() {
        return;
    }
    let base = vector.as_ptr().cast::<u8>();
    let vector_bytes = std::mem::size_of_val(vector);

    // `wrapping_add` needs no `unsafe`. Each offset is checked to be inside the
    // slice before use, and prefetch hints never dereference the pointer anyway.
    prefetch_read_l1(base);
    if vector_bytes > ARM_CACHE_LINE {
        prefetch_read_l2(base.wrapping_add(ARM_CACHE_LINE));
    }
    if vector_bytes > ARM_CACHE_LINE * 2 {
        prefetch_read_l3(base.wrapping_add(ARM_CACHE_LINE * 2));
    }
    if vector_bytes > ARM_CACHE_LINE * 3 {
        prefetch_read_l3(base.wrapping_add(ARM_CACHE_LINE * 3));
    }
}
/// Returns the prefetch distance to use for a vector with `dimension` elements.
///
/// The distance is a step function of `dimension` that saturates at 16:
///
/// | dimension    | distance |
/// |--------------|----------|
/// | `0..=128`    | 4        |
/// | `129..=384`  | 6        |
/// | `385..=768`  | 10       |
/// | `769..=1536` | 14       |
/// | `> 1536`     | 16       |
///
/// NOTE(review): the unit of the distance (items ahead) is decided by callers.
#[must_use]
#[inline]
pub fn calculate_prefetch_distance_neon(dimension: usize) -> usize {
    if dimension <= 128 {
        4
    } else if dimension <= 384 {
        6
    } else if dimension <= 768 {
        10
    } else if dimension <= 1536 {
        14
    } else {
        16
    }
}
/// Fallback for targets other than aarch64: L1 read prefetch is a no-op.
#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn prefetch_read_l1(_ptr: *const u8) {
    // No prefetch instruction is emitted on this target.
}
/// Fallback for targets other than aarch64: L2 read prefetch is a no-op.
#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn prefetch_read_l2(_ptr: *const u8) {
    // No prefetch instruction is emitted on this target.
}
/// Fallback for targets other than aarch64: L3 read prefetch is a no-op.
#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn prefetch_read_l3(_ptr: *const u8) {
    // No prefetch instruction is emitted on this target.
}
/// Fallback for targets other than aarch64: L1 write prefetch is a no-op.
#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn prefetch_write_l1(_ptr: *const u8) {
    // No prefetch instruction is emitted on this target.
}
/// Fallback for targets other than aarch64: vector prefetching is a no-op.
#[inline(always)]
#[cfg(not(target_arch = "aarch64"))]
pub fn prefetch_vector_neon(_vector: &[f32]) {
    // No prefetch instruction is emitted on this target.
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefetch_read_l1_safe() {
        let data = vec![0u8; 4096];
        prefetch_read_l1(data.as_ptr());
    }

    /// Prefetch is only a hint, so a null pointer must not fault.
    #[test]
    fn test_prefetch_read_l1_null_safe() {
        prefetch_read_l1(std::ptr::null());
    }

    /// The remaining hint variants must also tolerate a null pointer.
    #[test]
    fn test_prefetch_other_variants_null_safe() {
        prefetch_read_l2(std::ptr::null());
        prefetch_read_l3(std::ptr::null());
        prefetch_write_l1(std::ptr::null());
    }

    #[test]
    fn test_prefetch_vector_neon() {
        let vector: Vec<f32> = (0..768).map(|i| i as f32).collect();
        prefetch_vector_neon(&vector);
    }

    #[test]
    fn test_prefetch_vector_neon_empty() {
        let vector: Vec<f32> = vec![];
        prefetch_vector_neon(&vector);
    }

    /// Exercise lengths on both sides of the 128/256/384-byte thresholds
    /// (32/64/96 `f32`s) where `prefetch_vector_neon` issues an extra hint.
    #[test]
    fn test_prefetch_vector_neon_threshold_lengths() {
        for &len in &[1usize, 31, 32, 33, 63, 64, 65, 95, 96, 97] {
            let vector = vec![0.0f32; len];
            prefetch_vector_neon(&vector);
        }
    }

    #[test]
    fn test_calculate_prefetch_distance() {
        assert_eq!(calculate_prefetch_distance_neon(128), 4);
        assert_eq!(calculate_prefetch_distance_neon(384), 6);
        assert_eq!(calculate_prefetch_distance_neon(768), 10);
        assert_eq!(calculate_prefetch_distance_neon(1536), 14);
        assert_eq!(calculate_prefetch_distance_neon(3072), 16);
    }

    /// Pin both edges of every tier, including zero and the saturating top.
    #[test]
    fn test_calculate_prefetch_distance_boundaries() {
        assert_eq!(calculate_prefetch_distance_neon(0), 4);
        assert_eq!(calculate_prefetch_distance_neon(129), 6);
        assert_eq!(calculate_prefetch_distance_neon(385), 10);
        assert_eq!(calculate_prefetch_distance_neon(769), 14);
        assert_eq!(calculate_prefetch_distance_neon(1537), 16);
        assert_eq!(calculate_prefetch_distance_neon(usize::MAX), 16);
    }

    #[test]
    fn test_all_prefetch_variants() {
        let data = vec![0u8; 256];
        let ptr = data.as_ptr();
        prefetch_read_l1(ptr);
        prefetch_read_l2(ptr);
        prefetch_read_l3(ptr);
        prefetch_write_l1(ptr);
    }
}