use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use libm;
use spin::Mutex;
/// Coarse classification of a stream of accesses, used to decide how far
/// ahead (if at all) to prefetch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessPattern {
    /// Offsets move forward in small steps.
    Sequential,
    /// Offsets move backward in small steps.
    ReverseSequential,
    /// Offsets advance by a roughly constant step.
    Strided,
    /// No exploitable structure was found.
    Random,
    /// The same offsets are revisited.
    Looping,
}

impl AccessPattern {
    /// Number of ranges to prefetch once this pattern is trusted.
    ///
    /// A value of `0` means "do not prefetch".
    pub fn prefetch_distance(&self) -> usize {
        match *self {
            Self::Sequential => 16,
            Self::ReverseSequential => 8,
            Self::Strided => 4,
            Self::Looping => 2,
            Self::Random => 0,
        }
    }

    /// Minimum classifier confidence (in `[0, 1]`) required before acting
    /// on this pattern.
    pub fn confidence_threshold(&self) -> f32 {
        match *self {
            Self::Sequential => 0.7,
            Self::ReverseSequential => 0.75,
            Self::Strided => 0.8,
            Self::Looping => 0.85,
            Self::Random => 0.95,
        }
    }
}
/// One observed access, as stored in a `PatternWindow`.
#[derive(Debug, Clone, Copy)]
pub struct AccessRecord {
/// Starting offset of the access (same units as `size`); consecutive
/// offsets are differenced to detect patterns.
pub offset: u64,
/// Caller-supplied timestamp. Stored, but not read by the pattern logic in
/// this file. NOTE(review): units are not established here; confirm with callers.
pub timestamp: u64,
/// Length of the access; the engine also uses it as the length of each
/// range it suggests prefetching.
pub size: u64,
/// `true` for reads; the engine does not suggest prefetches after writes.
pub is_read: bool,
}
/// Sliding window over the most recent accesses, stored oldest-first.
///
/// Keeping the records in chronological order means every consumer can
/// compute consecutive deltas, and find the newest entries, by plain
/// index/iteration order. That stays true even after old records have been
/// evicted.
#[derive(Debug, Clone)]
pub struct PatternWindow {
    /// Recorded accesses, oldest at index 0 and newest last; never longer
    /// than `window_size`.
    accesses: Vec<AccessRecord>,
    /// Maximum number of records retained; `0` means nothing is retained.
    window_size: usize,
}

impl PatternWindow {
    /// Creates an empty window that retains at most `window_size` accesses.
    pub fn new(window_size: usize) -> Self {
        Self {
            accesses: Vec::with_capacity(window_size),
            window_size,
        }
    }

    /// Appends `record` as the newest access. When the window is already
    /// full, the oldest access is evicted first.
    ///
    /// A window created with `window_size == 0` ignores every record
    /// instead of panicking.
    pub fn add_access(&mut self, record: AccessRecord) {
        if self.window_size == 0 {
            return;
        }
        if self.accesses.len() >= self.window_size {
            // Shifting is O(window_size). That is cheap for the small windows
            // used by the engine (16 entries), and it keeps `accesses` in
            // chronological order, which an overwrite-in-place ring buffer
            // would not.
            self.accesses.remove(0);
        }
        self.accesses.push(record);
    }

    /// Classifies the recorded accesses and returns the pattern together
    /// with a confidence in `[0, 1]`.
    ///
    /// Rules are tried in order over the deltas between consecutive offsets:
    /// - fewer than 3 accesses: `(Random, 0.0)`;
    /// - more than 80% of deltas forward and under 1 MiB: `Sequential`;
    /// - more than 80% of deltas backward and under 1 MiB: `ReverseSequential`;
    /// - non-zero first delta, and more than 70% of deltas within 4096 of it:
    ///   `Strided`;
    /// - more than half of the distinct offsets seen more than once: `Looping`;
    /// - otherwise `Random`, with confidence `1 - max(forward, backward ratio)`.
    pub fn detect_pattern(&self) -> (AccessPattern, f32) {
        const NEAR: i64 = 1024 * 1024;
        const STRIDE_TOLERANCE: u64 = 4096;

        if self.accesses.len() < 3 {
            return (AccessPattern::Random, 0.0);
        }

        // Signed distance between consecutive offsets. `wrapping_sub` followed
        // by `as i64` equals the true signed difference whenever it fits in an
        // i64, and unlike `a as i64 - b as i64` it can never overflow-panic.
        let deltas: Vec<i64> = self
            .accesses
            .windows(2)
            .map(|pair| pair[1].offset.wrapping_sub(pair[0].offset) as i64)
            .collect();
        let total = deltas.len() as f32;
        let fraction = |count: usize| count as f32 / total;

        let sequential_ratio = fraction(deltas.iter().filter(|&&d| d > 0 && d < NEAR).count());
        if sequential_ratio > 0.8 {
            return (AccessPattern::Sequential, sequential_ratio);
        }

        let reverse_ratio = fraction(deltas.iter().filter(|&&d| d < 0 && d > -NEAR).count());
        if reverse_ratio > 0.8 {
            return (AccessPattern::ReverseSequential, reverse_ratio);
        }

        if let Some(&first_delta) = deltas.first() {
            if first_delta != 0 {
                // `abs_diff` avoids the overflow (and `abs(i64::MIN)` panic)
                // that `(d - first_delta).abs()` could hit.
                let stride_ratio = fraction(
                    deltas
                        .iter()
                        .filter(|&&d| d.abs_diff(first_delta) < STRIDE_TOLERANCE)
                        .count(),
                );
                if stride_ratio > 0.7 {
                    return (AccessPattern::Strided, stride_ratio);
                }
            }
        }

        let mut visits: BTreeMap<u64, u32> = BTreeMap::new();
        for access in &self.accesses {
            *visits.entry(access.offset).or_insert(0) += 1;
        }
        let revisited = visits.values().filter(|&&count| count > 1).count();
        let loop_ratio = revisited as f32 / visits.len().max(1) as f32;
        if loop_ratio > 0.5 {
            return (AccessPattern::Looping, loop_ratio);
        }

        (
            AccessPattern::Random,
            1.0 - sequential_ratio.max(reverse_ratio),
        )
    }
}
/// Small fixed-weight two-layer network mapping 4 window features to a
/// score for each of 5 access patterns.
#[derive(Debug, Clone)]
pub struct PatternPredictor {
/// Input-to-hidden weights, indexed `[feature][hidden_unit]`.
weights_ih: Vec<Vec<f32>>,
/// Hidden-to-output weights, indexed `[hidden_unit][output_class]`.
weights_ho: Vec<Vec<f32>>,
/// One bias per hidden unit.
bias_h: Vec<f32>,
/// One bias per output class.
bias_o: Vec<f32>,
}
impl Default for PatternPredictor {
fn default() -> Self {
Self::new()
}
}
impl PatternPredictor {
    /// Builds the predictor with its hand-picked weights: 4 input features,
    /// then 8 ReLU hidden units, then 5 output scores.
    pub fn new() -> Self {
        Self {
            weights_ih: vec![
                vec![0.5, -0.3, 0.2, 0.1, 0.4, -0.2, 0.3, 0.1],
                vec![-0.2, 0.4, -0.1, 0.3, 0.2, 0.5, -0.3, 0.2],
                vec![0.3, 0.1, 0.5, -0.2, -0.1, 0.3, 0.4, -0.1],
                vec![0.1, -0.2, 0.3, 0.4, 0.5, -0.3, 0.1, 0.2],
            ],
            weights_ho: vec![
                vec![0.6, -0.2, 0.1, -0.3, 0.2],
                vec![-0.1, 0.5, 0.3, 0.2, -0.2],
                vec![0.3, 0.2, 0.4, -0.1, 0.3],
                vec![-0.2, 0.4, -0.3, 0.5, 0.1],
                vec![0.5, -0.3, 0.2, 0.1, 0.4],
                vec![0.2, 0.3, -0.2, 0.4, -0.1],
                vec![-0.3, 0.1, 0.5, 0.2, 0.3],
                vec![0.4, -0.1, 0.3, -0.2, 0.5],
            ],
            bias_h: vec![0.1, -0.1, 0.2, -0.2, 0.1, 0.2, -0.1, 0.1],
            bias_o: vec![0.0, 0.1, -0.1, 0.0, 0.1],
        }
    }

    /// Rectified linear unit: `x` if positive, otherwise `0.0`.
    fn relu(x: f32) -> f32 {
        if x > 0.0 {
            x
        } else {
            0.0
        }
    }

    /// Numerically stable softmax: subtracts the maximum before
    /// exponentiating, so large scores do not overflow `expf`.
    fn softmax(inputs: &[f32]) -> Vec<f32> {
        let max = inputs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = inputs.iter().map(|&x| libm::expf(x - max)).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|&x| x / sum).collect()
    }

    /// Runs the network on `features` and returns the most probable pattern
    /// together with its softmax probability.
    ///
    /// Output index mapping: 0 = Sequential, 1 = ReverseSequential,
    /// 2 = Strided, 3 = Looping, 4 = Random. This order differs from the
    /// declaration order of `AccessPattern`.
    pub fn predict(&self, features: &[f32; 4]) -> (AccessPattern, f32) {
        // Hidden layer: relu(bias_h[i] + sum_j features[j] * weights_ih[j][i]).
        let mut hidden = [0.0; 8];
        for (i, h) in hidden.iter_mut().enumerate() {
            let mut sum = self.bias_h[i];
            for (j, &feat) in features.iter().enumerate() {
                sum += feat * self.weights_ih[j][i];
            }
            *h = Self::relu(sum);
        }

        // Output layer: one linear score per pattern class.
        let mut output = [0.0; 5];
        for (i, out) in output.iter_mut().enumerate() {
            let mut sum = self.bias_o[i];
            for (j, &h) in hidden.iter().enumerate() {
                sum += h * self.weights_ho[j][i];
            }
            *out = sum;
        }

        let probabilities = Self::softmax(&output);
        debug_assert!(!probabilities.is_empty(), "probabilities must be non-empty");
        // Incomparable (NaN) pairs are treated as "less" so argmax never panics.
        let max_idx = probabilities
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Less))
            .map(|(idx, _)| idx)
            .unwrap_or(4);
        let pattern = match max_idx {
            0 => AccessPattern::Sequential,
            1 => AccessPattern::ReverseSequential,
            2 => AccessPattern::Strided,
            3 => AccessPattern::Looping,
            _ => AccessPattern::Random,
        };
        (pattern, probabilities[max_idx])
    }

    /// Summarises `window` as the 4 scaled features consumed by `predict`:
    /// `[mean delta / 1e6, delta std-dev / 1e5, distinct-offset ratio,
    /// mean size / 1e6]`. An empty window yields all zeros.
    ///
    /// Deltas are taken between consecutive entries of `window.accesses` in
    /// storage order. That assumes the window is stored oldest-first.
    fn extract_features(window: &PatternWindow) -> [f32; 4] {
        let accesses = &window.accesses;
        if accesses.is_empty() {
            return [0.0; 4];
        }
        let count = accesses.len() as f32;

        // `wrapping_sub` + `as i64` equals the signed difference whenever it
        // fits in an i64, and never panics on offsets straddling i64::MAX.
        let deltas: Vec<i64> = accesses
            .windows(2)
            .map(|pair| pair[1].offset.wrapping_sub(pair[0].offset) as i64)
            .collect();

        let avg_delta = if deltas.is_empty() {
            0.0
        } else {
            // Accumulate in i128 so a window of large jumps cannot overflow.
            deltas.iter().map(|&d| i128::from(d)).sum::<i128>() as f32 / deltas.len() as f32
        };

        // Sample variance (n - 1 denominator); needs at least two deltas.
        let variance = if deltas.len() > 1 {
            deltas
                .iter()
                .map(|&d| {
                    let diff = d as f32 - avg_delta;
                    diff * diff
                })
                .sum::<f32>()
                / (deltas.len() - 1) as f32
        } else {
            0.0
        };

        let mut distinct = BTreeMap::new();
        for access in accesses {
            distinct.insert(access.offset, ());
        }
        let unique_ratio = distinct.len() as f32 / count;

        // Accumulate in u128 so large access sizes cannot overflow.
        let avg_size = accesses.iter().map(|a| u128::from(a.size)).sum::<u128>() as f32 / count;

        [
            avg_delta / 1_000_000.0,
            libm::sqrtf(variance) / 100_000.0,
            unique_ratio,
            avg_size / 1_000_000.0,
        ]
    }
}
/// Counters accumulated by `MlPrefetchEngine`.
#[derive(Debug, Clone, Default)]
pub struct PrefetchStats {
/// Number of prefetch ranges returned by `record_and_predict`.
pub total_prefetches: u64,
/// Accesses whose offset matched a pending prefetch.
pub prefetch_hits: u64,
/// Pending prefetches discarded by `expire_prefetches` without being hit.
pub prefetch_misses: u64,
/// Sum of the sizes of all returned prefetch ranges.
pub bytes_prefetched: u64,
}
// Process-wide prefetch engine behind the `MlPrefetch` facade. Access is
// serialized through a `spin::Mutex`; it is created lazily on first use.
lazy_static! {
static ref ML_PREFETCH: Mutex<MlPrefetchEngine> = Mutex::new(MlPrefetchEngine::new());
}
/// Prefetch engine that combines the rule-based `PatternWindow` detector
/// with the neural `PatternPredictor` and tracks hit/miss statistics.
pub struct MlPrefetchEngine {
/// Recent access history.
window: PatternWindow,
/// Fixed-weight neural classifier.
predictor: PatternPredictor,
/// Hit/miss and volume counters.
stats: PrefetchStats,
/// Pending prefetches: offset mapped to prefetched size. An entry is removed
/// when an access hits it, and all entries are cleared by `expire_prefetches`.
prefetch_queue: BTreeMap<u64, u64>,
}
impl Default for MlPrefetchEngine {
fn default() -> Self {
Self::new()
}
}
impl MlPrefetchEngine {
pub fn new() -> Self {
Self {
window: PatternWindow::new(16),
predictor: PatternPredictor::new(),
stats: PrefetchStats::default(),
prefetch_queue: BTreeMap::new(),
}
}
pub fn record_and_predict(
&mut self,
offset: u64,
size: u64,
timestamp: u64,
is_read: bool,
) -> Vec<(u64, u64)> {
let record = AccessRecord {
offset,
timestamp,
size,
is_read,
};
self.window.add_access(record);
if self.prefetch_queue.remove(&offset).is_some() {
self.stats.prefetch_hits += 1;
}
if !is_read {
return Vec::new();
}
let (rule_pattern, rule_confidence) = self.window.detect_pattern();
let features = PatternPredictor::extract_features(&self.window);
let (ml_pattern, ml_confidence) = self.predictor.predict(&features);
let (pattern, confidence) = if ml_confidence > rule_confidence {
(ml_pattern, ml_confidence)
} else {
(rule_pattern, rule_confidence)
};
let threshold = pattern.confidence_threshold();
if confidence < threshold {
return Vec::new();
}
let distance = pattern.prefetch_distance();
if distance == 0 {
return Vec::new();
}
let mut prefetches = Vec::new();
match pattern {
AccessPattern::Sequential => {
for i in 1..=distance {
let prefetch_offset = offset + (i as u64 * size);
prefetches.push((prefetch_offset, size));
self.prefetch_queue.insert(prefetch_offset, size);
}
}
AccessPattern::ReverseSequential => {
for i in 1..=distance {
if let Some(prefetch_offset) = offset.checked_sub(i as u64 * size) {
prefetches.push((prefetch_offset, size));
self.prefetch_queue.insert(prefetch_offset, size);
}
}
}
AccessPattern::Strided => {
if self.window.accesses.len() >= 2 {
let stride = self.window.accesses[self.window.accesses.len() - 1].offset as i64
- self.window.accesses[self.window.accesses.len() - 2].offset as i64;
if stride > 0 {
for i in 1..=distance {
let prefetch_offset = (offset as i64 + (i as i64 * stride)) as u64;
prefetches.push((prefetch_offset, size));
self.prefetch_queue.insert(prefetch_offset, size);
}
}
}
}
AccessPattern::Looping => {
let recent_offsets: Vec<u64> = self
.window
.accesses
.iter()
.rev()
.take(distance)
.map(|a| a.offset)
.collect();
for prefetch_offset in recent_offsets {
if prefetch_offset != offset {
prefetches.push((prefetch_offset, size));
self.prefetch_queue.insert(prefetch_offset, size);
}
}
}
AccessPattern::Random => {}
}
self.stats.total_prefetches += prefetches.len() as u64;
self.stats.bytes_prefetched += prefetches.iter().map(|(_, s)| s).sum::<u64>();
prefetches
}
pub fn expire_prefetches(&mut self) {
let expired = self.prefetch_queue.len();
self.stats.prefetch_misses += expired as u64;
self.prefetch_queue.clear();
}
pub fn hit_rate(&self) -> f32 {
let total = self.stats.prefetch_hits + self.stats.prefetch_misses;
if total == 0 {
return 0.0;
}
self.stats.prefetch_hits as f32 / total as f32
}
pub fn get_stats(&self) -> PrefetchStats {
self.stats.clone()
}
}
/// Zero-sized facade over the process-wide engine stored in `ML_PREFETCH`.
///
/// Each call holds the global lock only for the duration of that call.
pub struct MlPrefetch;

impl MlPrefetch {
    /// Records an access on the shared engine and returns the suggested
    /// `(offset, size)` prefetch ranges.
    pub fn predict(offset: u64, size: u64, timestamp: u64, is_read: bool) -> Vec<(u64, u64)> {
        ML_PREFETCH
            .lock()
            .record_and_predict(offset, size, timestamp, is_read)
    }

    /// Hit rate of the shared engine.
    pub fn hit_rate() -> f32 {
        ML_PREFETCH.lock().hit_rate()
    }

    /// Snapshot of the shared engine's statistics.
    pub fn stats() -> PrefetchStats {
        ML_PREFETCH.lock().get_stats()
    }

    /// Expires all pending prefetches on the shared engine, counting them as misses.
    pub fn expire() {
        ML_PREFETCH.lock().expire_prefetches();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_properties() {
        assert_eq!(AccessPattern::Sequential.prefetch_distance(), 16);
        assert!(AccessPattern::Random.prefetch_distance() == 0);
        assert!(AccessPattern::Sequential.confidence_threshold() < 0.8);
    }

    #[test]
    fn test_sequential_detection() {
        let mut window = PatternWindow::new(16);
        for i in 0..10 {
            window.add_access(AccessRecord {
                offset: i * 4096,
                timestamp: i,
                size: 4096,
                is_read: true,
            });
        }
        let (pattern, confidence) = window.detect_pattern();
        assert_eq!(pattern, AccessPattern::Sequential);
        assert!(confidence > 0.8);
    }

    #[test]
    fn test_random_detection() {
        let mut window = PatternWindow::new(16);
        let offsets = [0, 100000, 50000, 200000, 10000, 150000];
        for (i, &offset) in offsets.iter().enumerate() {
            window.add_access(AccessRecord {
                offset,
                timestamp: i as u64,
                size: 4096,
                is_read: true,
            });
        }
        let (pattern, _) = window.detect_pattern();
        assert_eq!(pattern, AccessPattern::Random);
    }

    #[test]
    fn test_strided_detection() {
        let mut window = PatternWindow::new(16);
        for i in 0..10 {
            window.add_access(AccessRecord {
                offset: i * 8192,
                timestamp: i,
                size: 4096,
                is_read: true,
            });
        }
        let (pattern, confidence) = window.detect_pattern();
        assert!(pattern == AccessPattern::Sequential || pattern == AccessPattern::Strided);
        assert!(confidence > 0.7);
    }

    #[test]
    fn test_neural_network_prediction() {
        let predictor = PatternPredictor::new();
        let features = [1.0, 0.1, 0.9, 0.5];
        let (pattern, confidence) = predictor.predict(&features);
        assert!(confidence > 0.0 && confidence <= 1.0);
        assert!(matches!(
            pattern,
            AccessPattern::Sequential
                | AccessPattern::ReverseSequential
                | AccessPattern::Strided
                | AccessPattern::Random
                | AccessPattern::Looping
        ));
    }

    #[test]
    fn test_prefetch_generation() {
        let mut engine = MlPrefetchEngine::new();
        for i in 0..10 {
            engine.record_and_predict(i * 4096, 4096, i, true);
        }
        let prefetches = engine.record_and_predict(10 * 4096, 4096, 10, true);
        assert!(!prefetches.is_empty());
        // Never more than the sequential prefetch distance.
        assert!(prefetches.len() <= 16);
    }

    #[test]
    fn test_prefetch_hit_tracking() {
        let mut engine = MlPrefetchEngine::new();
        for i in 0..5 {
            engine.record_and_predict(i * 4096, 4096, i, true);
        }
        let prefetches = engine.record_and_predict(5 * 4096, 4096, 5, true);
        // A steady sequential stream must trigger prefetching; otherwise the
        // hit assertion below would be skipped and the test would pass vacuously.
        assert!(!prefetches.is_empty());
        let hits_before = engine.stats.prefetch_hits;
        let (prefetch_offset, _) = prefetches[0];
        engine.record_and_predict(prefetch_offset, 4096, 6, true);
        assert_eq!(engine.stats.prefetch_hits, hits_before + 1);
    }

    #[test]
    fn test_no_prefetch_for_writes() {
        let mut engine = MlPrefetchEngine::new();
        for i in 0..5 {
            let prefetches = engine.record_and_predict(i * 4096, 4096, i, false);
            assert!(prefetches.is_empty());
        }
    }

    #[test]
    fn test_statistics() {
        let mut engine = MlPrefetchEngine::new();
        for i in 0..10 {
            engine.record_and_predict(i * 4096, 4096, i, true);
        }
        let stats = engine.get_stats();
        assert!(stats.total_prefetches > 0);
        // Every suggested range in this stream is 4096 bytes long.
        assert_eq!(stats.bytes_prefetched, stats.total_prefetches * 4096);
        assert!(stats.prefetch_hits > 0);

        engine.expire_prefetches();
        let rate = engine.hit_rate();
        assert!(rate > 0.0 && rate <= 1.0);
    }
}