use crate::wide_utils::{FromBitmask, WideUtilsExt};
#[cfg(target_arch = "aarch64")]
use std::arch::asm;
/// Compile-time marker selecting which cache level a prefetch targets.
///
/// `HINT` is an architecture-specific code consumed by the prefetch functions in this module,
/// which translate it into an `_MM_HINT_*` operand on x86_64 or a `prfm` operation on aarch64.
/// The numeric values differ between architectures.
pub trait CacheLevel {
    /// Architecture-specific hint code for this level.
    const HINT: i32;
}

/// Non-temporal prefetch (`_MM_HINT_NTA` on x86_64, `prfm pldl1strm` on aarch64).
#[derive(Debug, Clone, Copy)]
pub struct NTA;
impl CacheLevel for NTA {
    // Same code on every architecture.
    const HINT: i32 = 0;
}

/// Prefetch targeting L1 (`_MM_HINT_T0` on x86_64, `prfm pldl1keep` on aarch64).
#[derive(Debug, Clone, Copy)]
pub struct L1;
impl CacheLevel for L1 {
    #[cfg(target_arch = "x86_64")]
    const HINT: i32 = 3;
    // aarch64 code. Targets without a prefetch path ignore the value, but the constant must
    // still exist for this impl to compile there.
    #[cfg(not(target_arch = "x86_64"))]
    const HINT: i32 = 1;
}

/// Prefetch targeting L2 (`_MM_HINT_T1` on x86_64, `prfm pldl2keep` on aarch64).
#[derive(Debug, Clone, Copy)]
pub struct L2;
impl CacheLevel for L2 {
    // Same code on every architecture.
    const HINT: i32 = 2;
}

/// Prefetch targeting L3 (`_MM_HINT_T2` on x86_64, `prfm pldl3keep` on aarch64).
#[derive(Debug, Clone, Copy)]
pub struct L3;
impl CacheLevel for L3 {
    #[cfg(target_arch = "x86_64")]
    const HINT: i32 = 1;
    // aarch64 code. Targets without a prefetch path ignore the value, but the constant must
    // still exist for this impl to compile there.
    #[cfg(not(target_arch = "x86_64"))]
    const HINT: i32 = 3;
}
/// Issues a software prefetch for the element `offset` positions past `base`.
///
/// The target address is `base + offset * size_of::<T>()` bytes and `L` selects the cache
/// level (see [`CacheLevel`]). The address is computed with wrapping pointer arithmetic and is
/// never dereferenced, so it may lie outside the object `base` refers to. On targets other
/// than x86_64 and aarch64 this is a no-op.
#[inline(always)]
pub fn prefetch_address<T, L: CacheLevel>(base: &T, offset: u32) {
    // `pointer::add` is UB if the result leaves the allocation `base` lives in, and nothing
    // here guarantees that; `wrapping_add` has no such precondition and suffices for a hint.
    let ptr = (base as *const T).wrapping_add(offset as usize) as *const i8;
    #[cfg(target_arch = "x86_64")]
    {
        use std::arch::x86_64::*;
        // SAFETY: `_mm_prefetch` only hints the cache hierarchy; it does not dereference
        // `ptr` and does not fault on any address.
        unsafe {
            match L::HINT {
                0 => _mm_prefetch(ptr, _MM_HINT_NTA),
                1 => _mm_prefetch(ptr, _MM_HINT_T2),
                2 => _mm_prefetch(ptr, _MM_HINT_T1),
                // 3 (L1) and any unrecognised code.
                _ => _mm_prefetch(ptr, _MM_HINT_T0),
            }
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `prfm` is a hint that does not fault; the asm uses no stack and leaves the
        // condition flags untouched, which the options below tell the compiler.
        unsafe {
            match L::HINT {
                0 => asm!(
                    "prfm pldl1strm, [{0}]",
                    in(reg) ptr,
                    options(nostack, preserves_flags, readonly)
                ),
                2 => asm!(
                    "prfm pldl2keep, [{0}]",
                    in(reg) ptr,
                    options(nostack, preserves_flags, readonly)
                ),
                3 => asm!(
                    "prfm pldl3keep, [{0}]",
                    in(reg) ptr,
                    options(nostack, preserves_flags, readonly)
                ),
                // 1 (L1) and any unrecognised code.
                _ => asm!(
                    "prfm pldl1keep, [{0}]",
                    in(reg) ptr,
                    options(nostack, preserves_flags, readonly)
                ),
            }
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        let _ = ptr;
    }
}
/// Issues a software prefetch for each of eight element offsets relative to `base`.
///
/// Prefetches are issued in order, one per entry of `offsets`. Entry `i` targets
/// `base + offsets[i] * size_of::<T>()` bytes and `L` selects the cache level (see
/// [`CacheLevel`]). Addresses are computed with wrapping pointer arithmetic and never
/// dereferenced, so they may lie outside the object `base` refers to. On targets other than
/// x86_64 and aarch64 this is a no-op.
#[inline(always)]
pub fn prefetch_eight_offsets<T, L: CacheLevel>(base: &T, offsets: &[u32; 8]) {
    let base_ptr = base as *const T;
    for &offset in offsets {
        // `wrapping_add` rather than `add`: the targets may leave `*base`'s allocation, which
        // would make `add` UB, and a prefetch never dereferences them anyway.
        let ptr = base_ptr.wrapping_add(offset as usize);
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::*;
            let ptr = ptr as *const i8;
            // SAFETY: `_mm_prefetch` only hints the cache hierarchy; it does not dereference
            // `ptr` and does not fault on any address.
            unsafe {
                match L::HINT {
                    0 => _mm_prefetch(ptr, _MM_HINT_NTA),
                    1 => _mm_prefetch(ptr, _MM_HINT_T2),
                    2 => _mm_prefetch(ptr, _MM_HINT_T1),
                    // 3 (L1) and any unrecognised code.
                    _ => _mm_prefetch(ptr, _MM_HINT_T0),
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            let ptr = ptr as *const u8;
            // SAFETY: `prfm` is a hint that does not fault; the asm uses no stack and leaves
            // the condition flags untouched, which the options below tell the compiler.
            unsafe {
                match L::HINT {
                    0 => asm!(
                        "prfm pldl1strm, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    2 => asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    3 => asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    // 1 (L1) and any unrecognised code.
                    _ => asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                }
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let _ = ptr;
        }
    }
}
/// Issues eight software prefetches relative to `base`, with lanes selected by `mask`.
///
/// Lane `i` targets `base + offsets[i] * size_of::<T>()` bytes when it is selected by `mask`.
/// The lane/bit correspondence is whatever `FromBitmask::from_bitmask` defines. Lanes that are
/// not selected get element offset 0, so they prefetch `base` itself rather than being
/// skipped; all eight prefetches are always issued, in lane order. `L` selects the cache level
/// (see [`CacheLevel`]). Addresses use wrapping pointer arithmetic and are never dereferenced.
/// On targets other than x86_64 and aarch64 this is a no-op.
#[inline(always)]
pub fn prefetch_eight_masked<T, L: CacheLevel>(base: &T, offsets: [u32; 8], mask: u8) {
    let base_ptr = base as *const T;
    // Branch-free lane selection: unselected lanes fall back to element offset 0.
    let mask_simd = wide::u32x8::from_bitmask(mask);
    let selected = mask_simd.blend(wide::u32x8::from(offsets), wide::u32x8::splat(0));
    // Keep element offsets (not byte offsets) here. The scaling by `size_of::<T>()` happens
    // in pointer width inside `wrapping_add`, so it cannot wrap at `u32::MAX` the way a u32
    // byte-offset product would for large offsets or large `T`.
    let element_offsets = selected.widen_to_u64x8().to_array();
    for offset in element_offsets {
        // Deriving the address from `base_ptr` (instead of an integer round-trip) keeps its
        // provenance; `wrapping_add` has no in-bounds precondition.
        let ptr = base_ptr.wrapping_add(offset as usize);
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::*;
            let ptr = ptr as *const i8;
            // SAFETY: `_mm_prefetch` only hints the cache hierarchy; it does not dereference
            // `ptr` and does not fault on any address.
            unsafe {
                match L::HINT {
                    0 => _mm_prefetch(ptr, _MM_HINT_NTA),
                    1 => _mm_prefetch(ptr, _MM_HINT_T2),
                    2 => _mm_prefetch(ptr, _MM_HINT_T1),
                    // 3 (L1) and any unrecognised code.
                    _ => _mm_prefetch(ptr, _MM_HINT_T0),
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            let ptr = ptr as *const u8;
            // SAFETY: `prfm` is a hint that does not fault; the asm uses no stack and leaves
            // the condition flags untouched, which the options below tell the compiler.
            unsafe {
                match L::HINT {
                    0 => asm!(
                        "prfm pldl1strm, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    2 => asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    3 => asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                    // 1 (L1) and any unrecognised code.
                    _ => asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags, readonly)
                    ),
                }
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            let _ = ptr;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_level_constants() {
        let hints = [NTA::HINT, L1::HINT, L2::HINT, L3::HINT];
        assert!(hints.iter().all(|&hint| hint >= 0));
        // Each level needs its own code, otherwise the prefetch dispatch would conflate them.
        for (i, a) in hints.iter().enumerate() {
            assert!(hints[i + 1..].iter().all(|b| b != a));
        }
    }

    // All calls below pass `&data[0]` rather than `&data`. With `&data` the element type `T`
    // is inferred as `Vec<u32>`, so offsets are scaled by `size_of::<Vec<u32>>()` and measured
    // from the `Vec` header instead of the element buffer.

    #[test]
    fn test_prefetch_single_address() {
        let data = vec![0u32; 100];
        prefetch_address::<_, NTA>(&data[0], 10);
        prefetch_address::<_, L1>(&data[0], 20);
        prefetch_address::<_, L2>(&data[0], 30);
        prefetch_address::<_, L3>(&data[0], 40);
    }

    #[test]
    fn test_prefetch_eight_addresses() {
        let data = vec![0u32; 100];
        let offsets = [10, 20, 30, 40, 50, 60, 70, 80];
        prefetch_eight_offsets::<_, NTA>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L2>(&data[0], &offsets);
        prefetch_eight_offsets::<_, L3>(&data[0], &offsets);
    }

    #[test]
    fn test_prefetch_eight_masked() {
        let data = vec![0u32; 100];
        let offsets = [10, 20, 30, 40, 50, 60, 70, 80];
        for mask in [0xFF, 0x00, 0xAA, 0x55, 0x0F, 0xF0] {
            prefetch_eight_masked::<_, L1>(&data[0], offsets, mask);
        }
    }

    #[test]
    fn test_different_data_types() {
        let u32_data = vec![0u32; 100];
        let u64_data = vec![0u64; 100];
        let f32_data = vec![0.0f32; 100];
        let offsets = [1, 2, 3, 4, 5, 6, 7, 8];
        prefetch_eight_offsets::<_, L1>(&u32_data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&u64_data[0], &offsets);
        prefetch_eight_offsets::<_, L1>(&f32_data[0], &offsets);
        prefetch_eight_masked::<_, L1>(&u32_data[0], offsets, 0xFF);
        prefetch_eight_masked::<_, L1>(&u64_data[0], offsets, 0xAA);
        prefetch_eight_masked::<_, L1>(&f32_data[0], offsets, 0x55);
    }
}