use std::convert::TryFrom;

use crate::PrefetchStrategy;
use super::{BenchmarkablePrefetch, PrefetchType};
/// Stride-based sequential prefetcher.
///
/// Tracks the difference ("stride") between consecutive accessed keys and, once the
/// same stride has repeated often enough for `confidence` to reach 0.5, predicts the
/// next keys along that stride. The `PrefetchStrategy` logic lives in the per-key-type
/// impls (`i32`, `i64`, `usize`).
///
/// `K` is deliberately unbounded here; trait bounds belong on the impls that need them.
#[derive(Debug, Clone)]
pub struct SequentialPrefetch<K> {
    /// Most recently observed key, if any.
    last_key: Option<K>,
    /// Most recently observed gap between consecutive keys, if one has been seen.
    stride: Option<i64>,
    /// Upper bound on the number of keys returned per prediction.
    prefetch_distance: usize,
    /// Second upper bound on the number of keys returned per prediction; the effective
    /// limit is the smaller of this and `prefetch_distance`.
    max_predictions: usize,
    /// Trust in the current stride, kept in `[0.0, 1.0]`; the strategy impls only
    /// produce predictions when it is at least 0.5.
    confidence: f64,
    /// Number of consecutive accesses that repeated the current stride.
    consecutive_hits: usize,
}

impl<K> SequentialPrefetch<K> {
    /// Creates a prefetcher with `prefetch_distance = 2`, `max_predictions = 3` and an
    /// initial confidence of 0.5.
    pub fn new() -> Self {
        Self::with_config(2, 3, 0.5)
    }

    /// Creates a prefetcher with explicit limits.
    ///
    /// * `prefetch_distance`, `max_predictions` — each prediction returns at most the
    ///   smaller of the two.
    /// * `initial_confidence` — the starting confidence, clamped to `[0.0, 1.0]`. This is
    ///   *not* a threshold: the prediction gate is fixed at 0.5. The first observed
    ///   stride sets confidence to 0.3, and `reset` restores 0.5 rather than this value.
    pub fn with_config(
        prefetch_distance: usize,
        max_predictions: usize,
        initial_confidence: f64,
    ) -> Self {
        Self {
            last_key: None,
            stride: None,
            prefetch_distance,
            max_predictions,
            // Keep the same [0, 1] invariant that the update logic maintains.
            confidence: initial_confidence.clamp(0.0, 1.0),
            consecutive_hits: 0,
        }
    }

    /// Returns the most recently detected stride, or `None` if none has been seen.
    pub fn current_stride(&self) -> Option<i64> {
        self.stride
    }

    /// Returns the current confidence in the detected stride.
    pub fn confidence(&self) -> f64 {
        self.confidence
    }
}

impl<K> Default for SequentialPrefetch<K> {
    /// Same as [`SequentialPrefetch::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl PrefetchStrategy<i32> for SequentialPrefetch<i32> {
fn predict_next(&mut self, accessed_key: &i32) -> Vec<i32> {
if self.confidence < 0.5 {
return Vec::new();
}
let stride = self.stride.unwrap_or(1) as i32;
let mut predictions = Vec::with_capacity(self.max_predictions);
for i in 1..=self.max_predictions {
if predictions.len() >= self.prefetch_distance {
break;
}
let next_key = accessed_key + (stride * i as i32);
predictions.push(next_key);
}
predictions
}
fn update_access_pattern(&mut self, key: &i32) {
if let Some(last_key) = self.last_key {
let new_stride = (*key as i64) - (last_key as i64);
match self.stride {
Some(current_stride) => {
if new_stride == current_stride {
self.consecutive_hits += 1;
self.confidence = (self.confidence + 0.1).min(1.0);
} else {
self.consecutive_hits = 0;
self.confidence = (self.confidence - 0.2).max(0.0);
self.stride = Some(new_stride);
}
},
None => {
self.stride = Some(new_stride);
self.confidence = 0.3; }
}
}
self.last_key = Some(*key);
}
fn reset(&mut self) {
self.last_key = None;
self.stride = None;
self.confidence = 0.5;
self.consecutive_hits = 0;
}
}
impl PrefetchStrategy<i64> for SequentialPrefetch<i64> {
fn predict_next(&mut self, accessed_key: &i64) -> Vec<i64> {
if self.confidence < 0.5 {
return Vec::new();
}
let stride = self.stride.unwrap_or(1);
let mut predictions = Vec::with_capacity(self.max_predictions);
for i in 1..=self.max_predictions {
if predictions.len() >= self.prefetch_distance {
break;
}
let next_key = accessed_key + (stride * i as i64);
predictions.push(next_key);
}
predictions
}
fn update_access_pattern(&mut self, key: &i64) {
if let Some(last_key) = self.last_key {
let new_stride = *key - last_key;
match self.stride {
Some(current_stride) => {
if new_stride == current_stride {
self.consecutive_hits += 1;
self.confidence = (self.confidence + 0.1).min(1.0);
} else {
self.consecutive_hits = 0;
self.confidence = (self.confidence - 0.2).max(0.0);
self.stride = Some(new_stride);
}
},
None => {
self.stride = Some(new_stride);
self.confidence = 0.3;
}
}
}
self.last_key = Some(*key);
}
fn reset(&mut self) {
self.last_key = None;
self.stride = None;
self.confidence = 0.5;
self.consecutive_hits = 0;
}
}
impl PrefetchStrategy<usize> for SequentialPrefetch<usize> {
fn predict_next(&mut self, accessed_key: &usize) -> Vec<usize> {
if self.confidence < 0.5 {
return Vec::new();
}
let stride = self.stride.unwrap_or(1).max(1) as usize; let mut predictions = Vec::with_capacity(self.max_predictions);
for i in 1..=self.max_predictions {
if predictions.len() >= self.prefetch_distance {
break;
}
if let Some(next_key) = accessed_key.checked_add(stride * i) {
predictions.push(next_key);
}
}
predictions
}
fn update_access_pattern(&mut self, key: &usize) {
if let Some(last_key) = self.last_key {
let new_stride = (*key as i64) - (last_key as i64);
match self.stride {
Some(current_stride) => {
if new_stride == current_stride && new_stride > 0 {
self.consecutive_hits += 1;
self.confidence = (self.confidence + 0.1).min(1.0);
} else {
self.consecutive_hits = 0;
self.confidence = (self.confidence - 0.2).max(0.0);
if new_stride > 0 {
self.stride = Some(new_stride);
}
}
},
None => {
if new_stride > 0 {
self.stride = Some(new_stride);
self.confidence = 0.3;
}
}
}
}
self.last_key = Some(*key);
}
fn reset(&mut self) {
self.last_key = None;
self.stride = None;
self.confidence = 0.5;
self.consecutive_hits = 0;
}
}
impl BenchmarkablePrefetch<i32> for SequentialPrefetch<i32> {
fn prefetch_type(&self) -> PrefetchType {
PrefetchType::Sequential
}
}
impl BenchmarkablePrefetch<i64> for SequentialPrefetch<i64> {
fn prefetch_type(&self) -> PrefetchType {
PrefetchType::Sequential
}
}
impl BenchmarkablePrefetch<usize> for SequentialPrefetch<usize> {
fn prefetch_type(&self) -> PrefetchType {
PrefetchType::Sequential
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_stride_detection() {
        let mut strategy = SequentialPrefetch::<i32>::new();
        strategy.update_access_pattern(&0);
        strategy.update_access_pattern(&2);
        strategy.update_access_pattern(&4);
        strategy.update_access_pattern(&6);
        assert_eq!(strategy.current_stride(), Some(2));
        // Comparing whole vectors also pins the count, and fails with a readable diff
        // instead of an index-out-of-bounds panic.
        assert_eq!(strategy.predict_next(&8), vec![10, 12]);
    }

    #[test]
    fn test_sequential_respects_smaller_limit() {
        // prefetch_distance 5 vs max_predictions 3: the smaller limit wins.
        let mut strategy = SequentialPrefetch::<i32>::with_config(5, 3, 0.5);
        strategy.update_access_pattern(&0);
        strategy.update_access_pattern(&2);
        strategy.update_access_pattern(&4);
        strategy.update_access_pattern(&6);
        assert_eq!(strategy.predict_next(&6), vec![8, 10, 12]);
    }

    #[test]
    fn test_sequential_low_confidence_suppresses_predictions() {
        let mut strategy = SequentialPrefetch::<i32>::new();
        strategy.update_access_pattern(&0);
        strategy.update_access_pattern(&1);
        // The first stride observation sets confidence to 0.3, below the 0.5 gate.
        assert!(strategy.predict_next(&1).is_empty());
    }

    #[test]
    fn test_sequential_pattern_break() {
        let mut strategy = SequentialPrefetch::<i32>::new();
        strategy.update_access_pattern(&1);
        strategy.update_access_pattern(&2);
        strategy.update_access_pattern(&3);
        let initial_confidence = strategy.confidence();
        strategy.update_access_pattern(&10);
        assert!(strategy.confidence() < initial_confidence);
        // The new gap becomes the tracked stride.
        assert_eq!(strategy.current_stride(), Some(7));
    }

    #[test]
    fn test_sequential_usize_overflow_protection() {
        let mut strategy = SequentialPrefetch::<usize>::new();
        let large_key = usize::MAX - 1;
        strategy.update_access_pattern(&(large_key - 3));
        strategy.update_access_pattern(&(large_key - 2));
        strategy.update_access_pattern(&(large_key - 1));
        strategy.update_access_pattern(&large_key);
        assert_eq!(strategy.current_stride(), Some(1));
        // Confidence has reached the gate, so the overflow path is actually exercised:
        // only usize::MAX fits, and the next candidate must be dropped, not wrapped.
        assert_eq!(strategy.predict_next(&large_key), vec![usize::MAX]);
    }

    #[test]
    fn test_sequential_reset() {
        let mut strategy = SequentialPrefetch::<i32>::new();
        strategy.update_access_pattern(&1);
        strategy.update_access_pattern(&2);
        strategy.update_access_pattern(&3);
        strategy.reset();
        assert_eq!(strategy.current_stride(), None);
        assert_eq!(strategy.confidence(), 0.5);
    }

    #[test]
    fn test_sequential_negative_stride() {
        let mut strategy = SequentialPrefetch::<i32>::new();
        strategy.update_access_pattern(&10);
        strategy.update_access_pattern(&8);
        strategy.update_access_pattern(&6);
        strategy.update_access_pattern(&4);
        assert_eq!(strategy.current_stride(), Some(-2));
        // Four accesses lift confidence to the 0.5 gate, so predictions are non-empty
        // and this assertion is no longer vacuous.
        assert_eq!(strategy.predict_next(&4), vec![2, 0]);
    }
}