#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
/// Cache-line size in bytes assumed by this module; `PrefetchDistance::bytes`
/// multiplies a line count by this value.
pub const CACHE_LINE_BYTES: usize = 64;
/// Page size in bytes (4096).
// NOTE(review): no use of this constant is visible in this part of the file.
pub const PAGE_BYTES: usize = 4096;
/// A prefetch distance expressed as a number of cache lines.
///
/// The numeric value of each variant is defined by `PrefetchDistance::lines`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchDistance {
/// One cache line.
Near,
/// Four cache lines.
Medium,
/// Eight cache lines.
Far,
/// A caller-supplied number of cache lines, used as-is.
Custom(usize),
}
impl PrefetchDistance {
    /// Returns the distance as a number of cache lines.
    ///
    /// The named variants map to fixed counts (`Near` = 1, `Medium` = 4,
    /// `Far` = 8); `Custom(n)` returns `n` unchanged, including `0`.
    #[must_use]
    pub fn lines(self) -> usize {
        match self {
            Self::Near => 1,
            Self::Medium => 4,
            Self::Far => 8,
            Self::Custom(n) => n,
        }
    }

    /// Returns the distance in bytes, i.e. `lines()` times `CACHE_LINE_BYTES`.
    ///
    /// The product saturates at `usize::MAX` instead of overflowing, so a very
    /// large `Custom` distance neither panics (debug builds) nor silently wraps
    /// around to a small, wrong distance (release builds).
    #[must_use]
    pub fn bytes(self) -> usize {
        self.lines().saturating_mul(CACHE_LINE_BYTES)
    }
}
/// Classification of the spacing between consecutive access offsets, as
/// returned by `detect_stride`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StridePattern {
/// Every step equals the element size; also returned when fewer than two
/// offsets are supplied.
Sequential,
/// Every step is the same value (carried in the variant) and that value
/// differs from the element size.
Constant(usize),
/// The steps are not all equal.
Irregular,
}
/// Classifies the spacing between consecutive entries of `offsets`.
///
/// Differences are taken with wrapping subtraction, so a decreasing sequence
/// produces a large wrapped stride rather than a panic.
///
/// Returns:
/// * `StridePattern::Sequential` when there are fewer than two offsets, or
///   when every difference equals `element_bytes`;
/// * `StridePattern::Constant(d)` when every difference equals the same `d`
///   and `d != element_bytes`;
/// * `StridePattern::Irregular` when the differences are not all equal.
#[must_use]
pub fn detect_stride(offsets: &[usize], element_bytes: usize) -> StridePattern {
    let mut deltas = offsets
        .windows(2)
        .map(|pair| pair[1].wrapping_sub(pair[0]));

    // No adjacent pair means there is nothing to compare.
    let stride = match deltas.next() {
        Some(delta) => delta,
        None => return StridePattern::Sequential,
    };

    if deltas.any(|delta| delta != stride) {
        StridePattern::Irregular
    } else if stride == element_bytes {
        StridePattern::Sequential
    } else {
        StridePattern::Constant(stride)
    }
}
/// Hints the CPU to pull the cache line containing `ptr` into cache for reading.
///
/// On `x86_64` this issues `_mm_prefetch` with the `_MM_HINT_T0` hint. On every
/// other target the body is compiled out and the call does nothing, which is
/// why `ptr` may be unused there.
#[allow(unused_variables)]
pub fn prefetch_read<T>(ptr: *const T) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

        let byte_ptr = ptr.cast::<i8>();
        // SAFETY: the intrinsic only issues a cache hint for the address; no
        // value is read from or written through `byte_ptr` by this code.
        // NOTE(review): relies on the prefetch instruction not faulting on
        // arbitrary addresses — confirm against the Intel intrinsic reference.
        unsafe {
            _mm_prefetch(byte_ptr, _MM_HINT_T0);
        }
    }
}
/// Issues a read prefetch for every `step`-th element of `data`, starting with
/// the first element.
///
/// `step` is `distance.bytes()` divided by the size of `T`, clamped to at least
/// one element. Zero-sized types are treated as one byte wide so the division
/// is always defined. An empty slice results in no prefetches.
///
/// The elements are visited with `Iterator::step_by`, so there is no running
/// index that could overflow `usize` when the stride is very large.
pub fn prefetch_slice<T>(data: &[T], distance: PrefetchDistance) {
    let element_bytes = core::mem::size_of::<T>().max(1);
    // `step_by` panics on a zero step; the clamp guarantees at least 1.
    let step = (distance.bytes() / element_bytes).max(1);
    for element in data.iter().step_by(step) {
        prefetch_read(element as *const T);
    }
}
/// Derives prefetch parameters (distance, loop step, elements per cache line)
/// from an element size and a target cache level.
#[derive(Debug, Clone)]
pub struct PrefetchPlanner {
/// Element size in bytes; the methods treat `0` as `1` when dividing.
element_bytes: usize,
/// Cache level that selects the recommended prefetch distance.
cache_level: CacheLevel,
}
/// Cache level a `PrefetchPlanner` is configured for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevel {
/// Mapped to `PrefetchDistance::Near` by `PrefetchPlanner::recommended_distance`.
L1,
/// Mapped to `PrefetchDistance::Medium` by `PrefetchPlanner::recommended_distance`.
L2,
}
impl PrefetchPlanner {
#[must_use]
pub fn new(element_bytes: usize, cache_level: CacheLevel) -> Self {
Self {
element_bytes,
cache_level,
}
}
#[must_use]
pub fn recommended_distance(&self) -> PrefetchDistance {
match self.cache_level {
CacheLevel::L1 => PrefetchDistance::Near,
CacheLevel::L2 => PrefetchDistance::Medium,
}
}
#[must_use]
pub fn loop_step(&self) -> usize {
let dist = self.recommended_distance().bytes();
(dist / self.element_bytes.max(1)).max(1)
}
#[must_use]
pub fn elements_per_cache_line(&self) -> usize {
CACHE_LINE_BYTES / self.element_bytes.max(1)
}
}
/// Sums `data` into an `f64`, prefetching ahead of the element being added.
///
/// While element `i` is accumulated, element `i + step` is prefetched if it
/// exists, where `step` is `distance.bytes()` divided by the size of an `f32`,
/// clamped to at least one. Elements are added in slice order starting from
/// `0.0`, so an empty slice yields `0.0`.
#[must_use]
pub fn prefetch_sum(data: &[f32], distance: PrefetchDistance) -> f64 {
    let lookahead = (distance.bytes() / core::mem::size_of::<f32>()).max(1);
    // Indices below this bound still have an element `lookahead` positions
    // ahead of them; it is 0 when the slice is shorter than the lookahead.
    let prefetchable = data.len().saturating_sub(lookahead);

    data.iter()
        .enumerate()
        .fold(0.0f64, |total, (index, &value)| {
            if index < prefetchable {
                prefetch_read(&data[index + lookahead] as *const f32);
            }
            total + f64::from(value)
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- PrefetchDistance ----

    /// `Near` spans a single cache line.
    #[test]
    fn test_prefetch_distance_lines_near() {
        let near = PrefetchDistance::Near;
        assert_eq!(near.lines(), 1);
    }

    /// `Medium` spans four cache lines.
    #[test]
    fn test_prefetch_distance_lines_medium() {
        let medium = PrefetchDistance::Medium;
        assert_eq!(medium.lines(), 4);
    }

    /// `Far` spans eight cache lines.
    #[test]
    fn test_prefetch_distance_lines_far() {
        let far = PrefetchDistance::Far;
        assert_eq!(far.lines(), 8);
    }

    /// `Custom` reports exactly the line count it was given.
    #[test]
    fn test_prefetch_distance_custom() {
        let custom = PrefetchDistance::Custom(16);
        assert_eq!(custom.lines(), 16);
    }

    /// Four lines of 64 bytes each are 256 bytes.
    #[test]
    fn test_prefetch_distance_bytes() {
        let medium = PrefetchDistance::Medium;
        assert_eq!(medium.bytes(), 256);
    }

    // ---- detect_stride ----

    /// A step equal to the element size is sequential.
    #[test]
    fn test_detect_stride_sequential() {
        let pattern = detect_stride(&[0, 4, 8, 12], 4);
        assert_eq!(pattern, StridePattern::Sequential);
    }

    /// A uniform step that differs from the element size is a constant stride.
    #[test]
    fn test_detect_stride_constant() {
        let pattern = detect_stride(&[0, 16, 32, 48], 4);
        assert_eq!(pattern, StridePattern::Constant(16));
    }

    /// Non-uniform steps are irregular.
    #[test]
    fn test_detect_stride_irregular() {
        let pattern = detect_stride(&[0, 4, 20, 24], 4);
        assert_eq!(pattern, StridePattern::Irregular);
    }

    /// A single offset has no step to compare and counts as sequential.
    #[test]
    fn test_detect_stride_single_element() {
        let pattern = detect_stride(&[42], 4);
        assert_eq!(pattern, StridePattern::Sequential);
    }

    /// No offsets at all also counts as sequential.
    #[test]
    fn test_detect_stride_empty() {
        let no_offsets: [usize; 0] = [];
        let pattern = detect_stride(&no_offsets, 4);
        assert_eq!(pattern, StridePattern::Sequential);
    }

    // ---- prefetch_slice ----

    /// Smoke test: walking a 1024-element slice must complete without panicking.
    #[test]
    fn test_prefetch_slice_does_not_panic() {
        let mut data: Vec<f32> = Vec::with_capacity(1024);
        for value in 0..1024u16 {
            data.push(f32::from(value));
        }
        prefetch_slice(&data, PrefetchDistance::Medium);
    }

    // ---- PrefetchPlanner ----

    /// L2 uses the medium distance: 256 bytes / 4-byte elements = 64.
    #[test]
    fn test_prefetch_planner_loop_step_l2() {
        let step = PrefetchPlanner::new(4, CacheLevel::L2).loop_step();
        assert_eq!(step, 64);
    }

    /// L1 uses the near distance: 64 bytes / 4-byte elements = 16.
    #[test]
    fn test_prefetch_planner_loop_step_l1() {
        let step = PrefetchPlanner::new(4, CacheLevel::L1).loop_step();
        assert_eq!(step, 16);
    }

    /// Sixteen 4-byte elements fit in a 64-byte cache line.
    #[test]
    fn test_prefetch_planner_elements_per_cache_line() {
        let per_line = PrefetchPlanner::new(4, CacheLevel::L1).elements_per_cache_line();
        assert_eq!(per_line, 16);
    }

    /// L1 recommends the near distance.
    #[test]
    fn test_prefetch_planner_recommended_distance_l1() {
        let distance = PrefetchPlanner::new(4, CacheLevel::L1).recommended_distance();
        assert_eq!(distance, PrefetchDistance::Near);
    }

    // ---- prefetch_sum ----

    /// Prefetching must not change the arithmetic result.
    #[test]
    fn test_prefetch_sum_correctness() {
        let total = prefetch_sum(&[1.0, 2.0, 3.0, 4.0], PrefetchDistance::Near);
        let error = (total - 10.0).abs();
        assert!(error < 1e-9);
    }

    /// Summing nothing yields zero.
    #[test]
    fn test_prefetch_sum_empty() {
        let total = prefetch_sum(&[], PrefetchDistance::Medium);
        assert_eq!(total, 0.0);
    }
}