use crate::error::{Result, ZiporaError};
use crate::system::{CpuFeatures, get_cpu_features};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// Cache-level hint attached to each prefetch request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
    /// High temporal locality: pull the line into L1 (and outer levels).
    L1Temporal,
    /// Moderate temporal locality: pull the line into L2 and beyond.
    L2Temporal,
    /// Low temporal locality: pull the line into L3 only.
    L3Temporal,
    /// Streaming data: minimize cache pollution.
    NonTemporal,
}

impl PrefetchLocality {
    /// Maps this locality level to the matching x86 `_MM_HINT_*` constant.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    fn to_x86_hint(self) -> i32 {
        match self {
            Self::L1Temporal => _MM_HINT_T0,
            Self::L2Temporal => _MM_HINT_T1,
            Self::L3Temporal => _MM_HINT_T2,
            Self::NonTemporal => _MM_HINT_NTA,
        }
    }
}

/// Memory access pattern classification produced by the stride detector.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AccessPattern {
    /// Constant-stride stream; `confidence` counts consecutive equal strides.
    Sequential { stride: isize, confidence: u8 },
    /// A small set of recurring strides.
    Strided { stride: isize, distance: usize },
    /// No detectable structure; `entropy` scores the stride distribution.
    Random { entropy: f32 },
    /// Linked-structure traversal with the given indirection depth.
    PointerChasing { indirection_level: u8 },
    /// Not enough history to classify yet.
    Unknown,
}

impl AccessPattern {
    /// Picks the cache-level hint best suited to this pattern.
    fn optimal_locality(&self) -> PrefetchLocality {
        match *self {
            // A confident sequential stream is streaming data — keep it out
            // of the close caches to avoid pollution.
            AccessPattern::Sequential { confidence, .. } if confidence >= 7 => {
                PrefetchLocality::NonTemporal
            }
            // Random and pointer-chasing accesses want data as close as possible.
            AccessPattern::Random { .. } | AccessPattern::PointerChasing { .. } => {
                PrefetchLocality::L1Temporal
            }
            // Strided, low-confidence sequential, and unknown patterns: L2.
            _ => PrefetchLocality::L2Temporal,
        }
    }

    /// Scales the configured base lookahead distance for this pattern.
    fn optimal_distance(&self, base_distance: usize) -> usize {
        match *self {
            // Confident streams can run far ahead.
            AccessPattern::Sequential { confidence, .. } if confidence >= 7 => base_distance * 2,
            // Random access gains little from deep lookahead.
            AccessPattern::Random { .. } => base_distance / 4,
            // Pointer chasing can only usefully look one hop ahead.
            AccessPattern::PointerChasing { .. } => 1,
            _ => base_distance,
        }
    }
}
/// Tunable parameters controlling prefetch aggressiveness and throttling.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Baseline lookahead distance (stride multiples ahead of the access).
    pub base_distance: usize,
    /// Maximum number of prefetch hints issued per call.
    pub max_degree: usize,
    /// Whether the distance adapts to the detected pattern and accuracy.
    pub adaptive_distance: bool,
    /// Whether bandwidth/accuracy throttling is applied at all.
    pub enable_throttle: bool,
    /// Desired fraction of issued prefetches that end up being used.
    pub target_accuracy: f32,
    /// Assumed main-memory latency, in CPU cycles.
    pub memory_latency_cycles: usize,
    /// Bandwidth budget for prefetch traffic, in GB/s.
    pub max_bandwidth_gbps: f32,
}

impl Default for PrefetchConfig {
    fn default() -> Self {
        Self {
            base_distance: 8,
            max_degree: 4,
            adaptive_distance: true,
            enable_throttle: true,
            target_accuracy: 0.70,
            memory_latency_cycles: 250,
            max_bandwidth_gbps: 40.0,
        }
    }
}

impl PrefetchConfig {
    /// Preset tuned for long sequential scans: deeper lookahead, more hints.
    pub fn sequential_optimized() -> Self {
        Self {
            base_distance: 16,
            max_degree: 8,
            ..Self::default()
        }
    }

    /// Preset tuned for random access: shallow lookahead, few hints.
    pub fn random_optimized() -> Self {
        Self {
            base_distance: 2,
            max_degree: 2,
            ..Self::default()
        }
    }

    /// Preset tuned for pointer chasing: one step ahead, modest degree.
    pub fn pointer_chase_optimized() -> Self {
        Self {
            base_distance: 1,
            max_degree: 3,
            ..Self::default()
        }
    }
}
/// Online detector that classifies a stream of addresses/offsets into an
/// access pattern (sequential, strided, or random) from observed strides.
#[derive(Debug, Clone)]
struct StrideDetector {
    /// Previously observed address; `None` until the first `detect` call.
    /// Using `Option` instead of a `0` sentinel makes offset 0 a valid
    /// first observation (callers pass slice offsets, where 0 is common).
    last_addr: Option<usize>,
    /// Stride seen on the previous step; used to grow/reset `confidence`.
    last_stride: isize,
    /// Number of consecutive identical strides observed (saturating).
    confidence: u8,
    /// Sliding window of recent strides for strided/random classification.
    history: VecDeque<isize>,
    /// Capacity of the sliding window.
    max_history: usize,
}

impl StrideDetector {
    fn new() -> Self {
        Self {
            last_addr: None,
            last_stride: 0,
            confidence: 0,
            history: VecDeque::with_capacity(8),
            max_history: 8,
        }
    }

    /// Feeds one address and returns the current best-guess pattern.
    ///
    /// Returns `Some(AccessPattern::Unknown)` until enough history has
    /// accumulated (always returns `Some`; the `Option` is kept for caller
    /// compatibility).
    fn detect(&mut self, addr: usize) -> Option<AccessPattern> {
        // First observation: just remember it, nothing to classify yet.
        let Some(prev) = self.last_addr else {
            self.last_addr = Some(addr);
            return Some(AccessPattern::Unknown);
        };
        let stride = addr as isize - prev as isize;
        self.history.push_back(stride);
        if self.history.len() > self.max_history {
            self.history.pop_front();
        }
        if stride == self.last_stride {
            self.confidence = self.confidence.saturating_add(1);
        } else {
            // Stride changed: restart the confidence count on the new stride.
            self.confidence = 0;
            self.last_stride = stride;
        }
        self.last_addr = Some(addr);
        if self.confidence >= 3 {
            // Several identical strides in a row: treat as sequential.
            Some(AccessPattern::Sequential {
                stride,
                confidence: self.confidence,
            })
        } else if self.history.len() >= 4 {
            let unique_strides: std::collections::HashSet<_> =
                self.history.iter().copied().collect();
            if unique_strides.len() <= 2 {
                // At most two distinct strides in the window: strided pattern.
                Some(AccessPattern::Strided {
                    stride,
                    distance: stride.unsigned_abs(),
                })
            } else {
                // Many distinct strides: classify as random, scored by entropy.
                let entropy = self.calculate_entropy();
                Some(AccessPattern::Random { entropy })
            }
        } else {
            Some(AccessPattern::Unknown)
        }
    }

    /// Shannon entropy (in bits) of the stride distribution in `history`.
    fn calculate_entropy(&self) -> f32 {
        if self.history.is_empty() {
            return 0.0;
        }
        let mut frequencies = std::collections::HashMap::new();
        for &stride in &self.history {
            *frequencies.entry(stride).or_insert(0) += 1;
        }
        let total = self.history.len() as f32;
        frequencies
            .values()
            .map(|&count| {
                let p = count as f32 / total;
                -p * p.log2()
            })
            .sum()
    }
}
/// Tracks prefetch traffic over ~100 ms windows and signals when the
/// configured bandwidth budget is being approached.
#[derive(Debug)]
struct BandwidthMonitor {
    /// Bytes issued as prefetch hints during the current window.
    bytes_prefetched: AtomicUsize,
    /// Start of the current measurement window.
    interval_start: Instant,
    /// Bandwidth budget, converted to MB/s at construction.
    max_bandwidth_mbps: f32,
}

impl BandwidthMonitor {
    /// Creates a monitor with a budget given in GB/s.
    fn new(max_bandwidth_gbps: f32) -> Self {
        Self {
            bytes_prefetched: AtomicUsize::new(0),
            interval_start: Instant::now(),
            max_bandwidth_mbps: max_bandwidth_gbps * 1000.0,
        }
    }

    /// Adds `bytes` to the current window's traffic counter.
    fn record_prefetch(&self, bytes: usize) {
        self.bytes_prefetched.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Closes the window once it is older than 100 ms and reports whether the
    /// observed rate exceeded 80% of the budget; within an open window this
    /// always returns `false`.
    fn should_throttle(&mut self) -> bool {
        let window_secs = self.interval_start.elapsed().as_secs_f32();
        if window_secs <= 0.1 {
            return false;
        }
        // Window expired: grab-and-reset the counter, restart the window.
        // (`&mut self` guarantees exclusivity, so swap == load + store here.)
        let bytes = self.bytes_prefetched.swap(0, Ordering::Relaxed);
        self.interval_start = Instant::now();
        let observed_mbps = (bytes as f32 / window_secs) / 1_000_000.0;
        observed_mbps > (self.max_bandwidth_mbps * 0.8)
    }
}
/// Adjusts prefetch aggressiveness from observed accuracy, re-evaluated every
/// 100 recorded outcomes.
#[derive(Debug, Clone)]
struct AccuracyThrottler {
    /// Outcomes recorded since the last adjustment.
    prefetches_issued: usize,
    /// How many of those were actually consumed by a real access.
    prefetches_used: usize,
    /// Multiplier applied to prefetch distances (1.0 = neutral).
    aggressiveness: f32,
}

impl AccuracyThrottler {
    fn new() -> Self {
        Self {
            prefetches_issued: 0,
            prefetches_used: 0,
            aggressiveness: 1.0,
        }
    }

    /// Records one prefetch outcome; every 100 outcomes the aggressiveness
    /// multiplier is re-derived and the counters reset.
    fn record_prefetch(&mut self, was_useful: bool) {
        self.prefetches_issued += 1;
        self.prefetches_used += usize::from(was_useful);
        if self.prefetches_issued % 100 == 0 {
            self.adjust_aggressiveness();
        }
    }

    /// Maps the accuracy over the last batch onto an aggressiveness tier.
    fn adjust_aggressiveness(&mut self) {
        if self.prefetches_issued == 0 {
            return;
        }
        let accuracy = self.prefetches_used as f32 / self.prefetches_issued as f32;
        self.aggressiveness = if accuracy > 0.85 {
            1.2
        } else if accuracy > 0.60 {
            1.0
        } else if accuracy > 0.40 {
            0.7
        } else {
            0.3
        };
        // Start a fresh accounting batch.
        self.prefetches_issued = 0;
        self.prefetches_used = 0;
    }

    /// True when accuracy has dropped far enough to back off entirely.
    fn should_throttle(&self) -> bool {
        self.aggressiveness < 0.5
    }

    /// Scales a prefetch distance by the current multiplier (at least 1).
    fn scale_distance(&self, distance: usize) -> usize {
        ((distance as f32 * self.aggressiveness) as usize).max(1)
    }
}
/// Counters describing prefetch effectiveness, plus the last detected pattern.
#[derive(Debug, Clone, Default)]
pub struct PrefetchMetrics {
    /// Total prefetch hints issued.
    pub prefetches_issued: usize,
    /// Hints later confirmed useful by the caller.
    pub useful_prefetches: usize,
    /// Hints later reported as never used (cache pollution).
    pub wasted_prefetches: usize,
    /// Hints that arrived too late to hide latency.
    pub late_prefetches: usize,
    /// Most recent pattern reported by the stride detector, if any.
    pub current_pattern: Option<AccessPattern>,
}

impl PrefetchMetrics {
    /// Fraction of issued prefetches that polluted the cache (0.0 when none issued).
    pub fn pollution_ratio(&self) -> f32 {
        match self.prefetches_issued {
            0 => 0.0,
            n => self.wasted_prefetches as f32 / n as f32,
        }
    }

    /// Fraction of issued prefetches confirmed useful (1.0 when none issued).
    pub fn accuracy(&self) -> f32 {
        match self.prefetches_issued {
            0 => 1.0,
            n => self.useful_prefetches as f32 / n as f32,
        }
    }

    /// Fraction of issued prefetches that were late (0.0 when none issued).
    pub fn late_ratio(&self) -> f32 {
        match self.prefetches_issued {
            0 => 0.0,
            n => self.late_prefetches as f32 / n as f32,
        }
    }
}
/// Adaptive prefetch engine: detects access patterns from observed offsets,
/// issues hardware prefetch hints, and throttles itself based on bandwidth
/// usage and measured prefetch accuracy.
pub struct PrefetchStrategy {
    /// Tuning knobs (distances, degrees, throttling thresholds).
    config: PrefetchConfig,
    /// Detected CPU capabilities (process-wide, from `get_cpu_features`).
    // NOTE(review): stored but not consulted by any method in this file —
    // presumably reserved for feature-gated prefetch paths; confirm.
    cpu_features: &'static CpuFeatures,
    /// Online classifier fed the most recent access offset.
    stride_detector: StrideDetector,
    /// Windowed bandwidth accounting for throttling decisions.
    bandwidth_monitor: BandwidthMonitor,
    /// Accuracy-driven aggressiveness scaling.
    accuracy_throttler: AccuracyThrottler,
    /// Accumulated effectiveness counters.
    metrics: PrefetchMetrics,
    /// Current (possibly adapted) lookahead distance.
    current_distance: usize,
}
impl PrefetchStrategy {
    /// Creates a strategy from `config`.
    ///
    /// The initial lookahead distance is `config.base_distance`; it is later
    /// adjusted per detected pattern when `adaptive_distance` is enabled.
    pub fn new(config: PrefetchConfig) -> Self {
        Self {
            current_distance: config.base_distance,
            bandwidth_monitor: BandwidthMonitor::new(config.max_bandwidth_gbps),
            cpu_features: get_cpu_features(),
            stride_detector: StrideDetector::new(),
            accuracy_throttler: AccuracyThrottler::new(),
            metrics: PrefetchMetrics::default(),
            config,
        }
    }

    /// Classifies the most recent access and issues prefetches tuned to the
    /// detected pattern.
    ///
    /// `access_pattern` holds byte offsets into `data`; only the last offset
    /// is fed to the stride detector. Does nothing when either slice is empty
    /// or throttling is active.
    pub fn adaptive_prefetch(&mut self, data: &[u8], access_pattern: &[usize]) {
        if access_pattern.is_empty() || self.should_throttle() || data.is_empty() {
            return;
        }
        let pattern = if let Some(&addr) = access_pattern.last() {
            self.stride_detector.detect(addr)
        } else {
            Some(AccessPattern::Unknown)
        };
        if let Some(pat) = pattern {
            self.metrics.current_pattern = Some(pat);
            if self.config.adaptive_distance {
                // Scale the pattern-specific distance by measured accuracy.
                self.current_distance = self
                    .accuracy_throttler
                    .scale_distance(pat.optimal_distance(self.config.base_distance));
            }
            let locality = pat.optimal_locality();
            match pat {
                // Sequential and strided streams share the same forward-looking
                // prefetch shape; only the stride magnitude differs.
                AccessPattern::Sequential { stride, .. }
                | AccessPattern::Strided { stride, .. } => {
                    self.sequential_prefetch_internal_safe(
                        data,
                        stride.unsigned_abs(),
                        self.current_distance,
                        locality,
                    );
                }
                AccessPattern::Random { .. } => {
                    // Random access: only prefetch the most recent few targets.
                    let base = data.as_ptr();
                    for &addr in access_pattern.iter().rev().take(2) {
                        if addr < data.len() {
                            // SAFETY: `addr < data.len()`, so `base.add(addr)`
                            // stays within the `data` allocation.
                            unsafe {
                                self.issue_prefetch(base.add(addr), locality);
                            }
                        }
                    }
                }
                _ => {}
            }
        }
    }

    /// Prefetches up to `count` locations ahead at `stride`-byte steps with a
    /// non-temporal hint (streamed data is not expected to be reused soon).
    /// Offsets falling outside `data` are skipped.
    pub fn sequential_prefetch(&mut self, data: &[u8], stride: usize, count: usize) {
        if self.should_throttle() || data.is_empty() {
            return;
        }
        self.sequential_prefetch_internal_safe(
            data,
            stride,
            count,
            PrefetchLocality::NonTemporal,
        );
    }

    /// Prefetches up to `max_degree` scattered addresses with an L1 hint.
    pub fn random_prefetch(&mut self, addresses: &[&u8]) {
        if addresses.is_empty() || self.should_throttle() {
            return;
        }
        let limit = addresses.len().min(self.config.max_degree);
        for &addr_ref in addresses.iter().take(limit) {
            // SAFETY: each `&u8` is a live, valid address by construction.
            unsafe {
                self.issue_prefetch(addr_ref as *const u8, PrefetchLocality::L1Temporal);
            }
        }
    }

    /// Records that a previously issued prefetch was consumed by a real access.
    pub fn record_useful_prefetch(&mut self) {
        self.accuracy_throttler.record_prefetch(true);
        self.metrics.useful_prefetches += 1;
    }

    /// Records that a previously issued prefetch was never used (pollution).
    pub fn record_wasted_prefetch(&mut self) {
        self.accuracy_throttler.record_prefetch(false);
        self.metrics.wasted_prefetches += 1;
    }

    /// Returns the accumulated prefetch metrics.
    pub fn metrics(&self) -> &PrefetchMetrics {
        &self.metrics
    }

    /// Clears all accumulated metrics.
    pub fn reset_metrics(&mut self) {
        self.metrics = PrefetchMetrics::default();
    }

    /// True when either bandwidth or accuracy throttling says to back off.
    /// Always false when throttling is disabled in the configuration.
    #[inline]
    fn should_throttle(&mut self) -> bool {
        if !self.config.enable_throttle {
            return false;
        }
        self.bandwidth_monitor.should_throttle() || self.accuracy_throttler.should_throttle()
    }

    /// Bounds-checked sequential prefetch: issues hints at offsets
    /// `stride * 1..=count` (capped by `max_degree`) that fall inside `data`.
    #[inline]
    fn sequential_prefetch_internal_safe(
        &mut self,
        data: &[u8],
        stride: usize,
        count: usize,
        locality: PrefetchLocality,
    ) {
        let effective_count = count.min(self.config.max_degree);
        let base = data.as_ptr();
        let len = data.len();
        for i in 1..=effective_count {
            // saturating_mul keeps huge strides from wrapping past the check.
            let offset = stride.saturating_mul(i);
            if offset < len {
                // SAFETY: `offset < len`, so the address is inside `data`.
                unsafe {
                    self.issue_prefetch(base.add(offset), locality);
                }
            }
        }
    }

    /// Issues one hardware prefetch hint for `addr` and updates bandwidth and
    /// metric accounting (one cache line per hint). On architectures other
    /// than x86_64/aarch64 only the accounting runs.
    ///
    /// # Safety
    ///
    /// Callers should ensure `addr` points into a live allocation; prefetch
    /// instructions are hints and are documented not to fault on x86/ARM, but
    /// out-of-range hints waste bandwidth and skew the accounting.
    #[inline]
    unsafe fn issue_prefetch(&mut self, addr: *const u8, locality: PrefetchLocality) {
        const CACHE_LINE_SIZE: usize = 64;
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::{_MM_HINT_T0, _MM_HINT_T1, _MM_HINT_T2, _MM_HINT_NTA, _mm_prefetch};
            // SAFETY: PREFETCHh is a hint; the intrinsic only requires a
            // pointer value, which the caller contract provides.
            unsafe {
                match locality {
                    PrefetchLocality::L1Temporal => {
                        _mm_prefetch::<_MM_HINT_T0>(addr as *const i8);
                    }
                    PrefetchLocality::L2Temporal => {
                        _mm_prefetch::<_MM_HINT_T1>(addr as *const i8);
                    }
                    PrefetchLocality::L3Temporal => {
                        _mm_prefetch::<_MM_HINT_T2>(addr as *const i8);
                    }
                    PrefetchLocality::NonTemporal => {
                        _mm_prefetch::<_MM_HINT_NTA>(addr as *const i8);
                    }
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // PRFM PLDL{1,2,3}KEEP / PLDL1STRM map onto the locality levels.
            match locality {
                PrefetchLocality::L1Temporal => {
                    unsafe { std::arch::asm!("prfm pldl1keep, [{0}]", in(reg) addr, options(nostack)); }
                }
                PrefetchLocality::L2Temporal => {
                    unsafe { std::arch::asm!("prfm pldl2keep, [{0}]", in(reg) addr, options(nostack)); }
                }
                PrefetchLocality::L3Temporal => {
                    unsafe { std::arch::asm!("prfm pldl3keep, [{0}]", in(reg) addr, options(nostack)); }
                }
                PrefetchLocality::NonTemporal => {
                    unsafe { std::arch::asm!("prfm pldl1strm, [{0}]", in(reg) addr, options(nostack)); }
                }
            }
        }
        self.bandwidth_monitor.record_prefetch(CACHE_LINE_SIZE);
        self.metrics.prefetches_issued += 1;
    }
}
/// Manual `Debug`: only the plain-data fields are printed (the detector,
/// monitors, and CPU-feature reference are omitted).
impl std::fmt::Debug for PrefetchStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("PrefetchStrategy");
        dbg.field("config", &self.config);
        dbg.field("current_distance", &self.current_distance);
        dbg.field("metrics", &self.metrics);
        dbg.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must expose the documented baseline values.
    #[test]
    fn test_prefetch_config_defaults() {
        let config = PrefetchConfig::default();
        assert_eq!(config.base_distance, 8);
        assert_eq!(config.max_degree, 4);
        assert!(config.adaptive_distance);
        assert!(config.enable_throttle);
    }

    // Each preset tweaks distance/degree for its target workload.
    #[test]
    fn test_prefetch_config_presets() {
        let seq = PrefetchConfig::sequential_optimized();
        assert_eq!(seq.base_distance, 16);
        assert_eq!(seq.max_degree, 8);
        let rand = PrefetchConfig::random_optimized();
        assert_eq!(rand.base_distance, 2);
        assert_eq!(rand.max_degree, 2);
        let ptr = PrefetchConfig::pointer_chase_optimized();
        assert_eq!(ptr.base_distance, 1);
        assert_eq!(ptr.max_degree, 3);
    }

    // A constant 8-byte stride should eventually classify as Sequential;
    // the inner assertions only fire once the detector reports Sequential.
    #[test]
    fn test_stride_detector() {
        let mut detector = StrideDetector::new();
        for i in 0..10 {
            let pattern = detector.detect(i * 8);
            if i >= 3 {
                if let Some(AccessPattern::Sequential { stride, confidence }) = pattern {
                    assert_eq!(stride, 8);
                    assert!(confidence >= 3);
                }
            }
        }
    }

    // Locality choice per pattern: confident sequential streams are
    // non-temporal; random and pointer-chasing want L1.
    #[test]
    fn test_access_pattern_locality() {
        let seq = AccessPattern::Sequential {
            stride: 64,
            confidence: 8,
        };
        assert_eq!(seq.optimal_locality(), PrefetchLocality::NonTemporal);
        let rand = AccessPattern::Random { entropy: 0.9 };
        assert_eq!(rand.optimal_locality(), PrefetchLocality::L1Temporal);
        let ptr = AccessPattern::PointerChasing {
            indirection_level: 2,
        };
        assert_eq!(ptr.optimal_locality(), PrefetchLocality::L1Temporal);
    }

    // The ratio helpers divide each counter by prefetches_issued.
    #[test]
    fn test_prefetch_metrics() {
        let mut metrics = PrefetchMetrics::default();
        metrics.prefetches_issued = 100;
        metrics.useful_prefetches = 70;
        metrics.wasted_prefetches = 20;
        metrics.late_prefetches = 10;
        assert_eq!(metrics.accuracy(), 0.70);
        assert_eq!(metrics.pollution_ratio(), 0.20);
        assert_eq!(metrics.late_ratio(), 0.10);
    }

    // 90% accuracy keeps the throttler aggressive; 30% accuracy throttles.
    #[test]
    fn test_accuracy_throttler() {
        let mut throttler = AccuracyThrottler::new();
        for _ in 0..90 {
            throttler.record_prefetch(true);
        }
        for _ in 0..10 {
            throttler.record_prefetch(false);
        }
        assert!(!throttler.should_throttle());
        assert!(throttler.aggressiveness >= 0.9);
        throttler = AccuracyThrottler::new();
        for _ in 0..30 {
            throttler.record_prefetch(true);
        }
        for _ in 0..70 {
            throttler.record_prefetch(false);
        }
        assert!(throttler.should_throttle());
        assert!(throttler.aggressiveness < 0.5);
    }

    // A fresh strategy starts at the configured base distance with no metrics.
    #[test]
    fn test_prefetch_strategy_creation() {
        let config = PrefetchConfig::default();
        let strategy = PrefetchStrategy::new(config.clone());
        assert_eq!(strategy.current_distance, config.base_distance);
        assert_eq!(strategy.metrics.prefetches_issued, 0);
    }

    // Bounds-checked sequential prefetch issues at least one hint on a
    // sufficiently large buffer.
    #[test]
    fn test_sequential_prefetch_safe() {
        let mut strategy = PrefetchStrategy::new(PrefetchConfig::sequential_optimized());
        let data: Vec<u64> = vec![0; 1000];
        // Reinterpret the u64 buffer as bytes for the byte-oriented API.
        let data_bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8)
        };
        strategy.sequential_prefetch(data_bytes, 64, 8);
        assert!(strategy.metrics.prefetches_issued > 0);
    }

    // Random prefetch is capped by max_degree (2 for the random preset).
    #[test]
    fn test_random_prefetch_safe() {
        let mut strategy = PrefetchStrategy::new(PrefetchConfig::random_optimized());
        let data: Vec<u64> = vec![0; 1000];
        {
            let data_bytes = unsafe {
                std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8)
            };
            let addrs = [
                &data_bytes[10 * 8],
                &data_bytes[20 * 8],
            ];
            strategy.random_prefetch(&addrs);
        }
        assert!(strategy.metrics.prefetches_issued > 0);
        assert!(strategy.metrics.prefetches_issued <= 2);
    }

    // Feeding a strided offset history should populate current_pattern;
    // only the last offset reaches the detector, so the result may still be
    // Unknown — the if-let merely documents the expected eventual shape.
    #[test]
    fn test_adaptive_prefetch_detection() {
        let mut strategy = PrefetchStrategy::new(PrefetchConfig::default());
        let data: Vec<u64> = vec![0; 1000];
        let pattern = vec![0, 64, 128, 192, 256];
        let data_bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8)
        };
        strategy.adaptive_prefetch(data_bytes, &pattern);
        if let Some(AccessPattern::Sequential { .. }) = strategy.metrics.current_pattern {
            // Sequential detection is acceptable but not required here.
        }
    }

    // Timing-sensitive: only run in release builds, where prefetch volume is
    // high enough to exceed the tiny bandwidth budget within the window.
    #[test]
    #[cfg(not(debug_assertions))]
    fn test_prefetch_throttling() {
        use std::thread;
        use std::time::Duration;
        let mut config = PrefetchConfig::default();
        // Deliberately tiny budget so throttling must trigger.
        config.max_bandwidth_gbps = 0.001;
        let mut strategy = PrefetchStrategy::new(config);
        let data: Vec<u64> = vec![0; 10000];
        let data_bytes = unsafe {
            std::slice::from_raw_parts(data.as_ptr() as *const u8, data.len() * 8)
        };
        for _ in 0..5000 {
            strategy.sequential_prefetch(data_bytes, 64, 8);
        }
        // Let the 100 ms measurement window expire before checking.
        thread::sleep(Duration::from_millis(110));
        let should_be_throttled = strategy.bandwidth_monitor.should_throttle();
        assert!(should_be_throttled, "Bandwidth throttling should trigger: ~23 MB/s >> 0.8 MB/s threshold");
    }

    // Usefulness feedback must be reflected in the metrics counters.
    #[test]
    fn test_record_prefetch_usefulness() {
        let mut strategy = PrefetchStrategy::new(PrefetchConfig::default());
        strategy.record_useful_prefetch();
        strategy.record_useful_prefetch();
        strategy.record_wasted_prefetch();
        assert_eq!(strategy.metrics.useful_prefetches, 2);
        assert_eq!(strategy.metrics.wasted_prefetches, 1);
    }
}