use crate::Dataset;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use tenflowers_core::{Result, Tensor};
/// Classification of how a consumer walks through dataset indices.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Monotonically increasing indices with a fixed step.
    Sequential { stride: usize },
    /// No detectable structure; prediction yields nothing for this variant.
    Random,
    /// A fixed sequence of indices that repeats.
    Cyclic { pattern: Vec<usize> },
    /// Arithmetic progression `start, start + stride, start + 2*stride, ...`.
    Strided { start: usize, stride: usize },
}
/// Running counters describing observed accesses and prefetch effectiveness.
#[derive(Debug, Clone, Default)]
pub struct AccessStats {
    /// Total number of `get` calls observed.
    pub total_accesses: u64,
    /// Accesses classified as sequential.
    /// NOTE(review): not incremented anywhere in this file — confirm it is
    /// maintained elsewhere or intentionally unused.
    pub sequential_accesses: u64,
    /// Accesses classified as random.
    /// NOTE(review): also not incremented in this file.
    pub random_accesses: u64,
    /// Times the detector produced at least one prediction.
    pub pattern_hits: u64,
    /// Times the detector produced no prediction.
    pub pattern_misses: u64,
    /// Requests served from the prefetch cache.
    pub prefetch_hits: u64,
    /// Requests that missed the prefetch cache.
    pub prefetch_misses: u64,
    /// Accumulated per prefetched entry as `size_of` the tensor-pair handle,
    /// not the payload size — an approximation at best.
    pub bandwidth_saved: u64,
}
impl AccessStats {
    /// Fraction of prediction rounds that yielded a prediction
    /// (`pattern_hits` over all recorded pattern outcomes);
    /// 0.0 when nothing has been recorded.
    pub fn pattern_accuracy(&self) -> f64 {
        Self::safe_ratio(self.pattern_hits, self.pattern_hits + self.pattern_misses)
    }

    /// Fraction of cache probes served from the prefetch cache;
    /// 0.0 when no prefetch outcome has been recorded.
    pub fn prefetch_efficiency(&self) -> f64 {
        Self::safe_ratio(self.prefetch_hits, self.prefetch_hits + self.prefetch_misses)
    }

    /// Share of all accesses classified as sequential;
    /// 0.0 when no accesses have been recorded.
    pub fn sequential_ratio(&self) -> f64 {
        Self::safe_ratio(self.sequential_accesses, self.total_accesses)
    }

    /// `numerator / denominator`, defined as 0.0 for a zero denominator.
    fn safe_ratio(numerator: u64, denominator: u64) -> f64 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f64 / denominator as f64
        }
    }
}
/// A cached sample produced by the prefetch worker.
#[derive(Debug)]
struct PrefetchEntry<T> {
    /// The prefetched sample pair — presumably (features, labels);
    /// TODO confirm against `Dataset::get`'s contract.
    data: (Tensor<T>, Tensor<T>),
    /// When the entry was inserted or last touched; drives TTL eviction
    /// and oldest-first eviction when the cache is full.
    timestamp: Instant,
    /// How many times the consumer has read this entry from the cache.
    access_count: u32,
}
/// Sliding-window detector that classifies recent index accesses into
/// [`AccessPattern`]s with a confidence score.
#[derive(Debug)]
struct PatternDetector {
    /// Most recent access indices, oldest first; bounded by `max_history`.
    access_history: VecDeque<usize>,
    /// Detected patterns mapped to their confidence in [0, 1].
    detected_patterns: HashMap<AccessPattern, f64>,
    /// Maximum number of accesses kept in `access_history`.
    max_history: usize,
    /// Minimum history length (and minimum cycle length) before analysis runs.
    min_pattern_length: usize,
}
impl PatternDetector {
    /// Creates a detector that remembers at most `max_history` accesses.
    fn new(max_history: usize) -> Self {
        Self {
            access_history: VecDeque::new(),
            detected_patterns: HashMap::new(),
            max_history,
            // Runs shorter than 3 are too noisy to classify.
            min_pattern_length: 3,
        }
    }

    /// Records an access and re-runs pattern analysis over the history.
    fn record_access(&mut self, index: usize) {
        self.access_history.push_back(index);
        while self.access_history.len() > self.max_history {
            self.access_history.pop_front();
        }
        self.analyze_patterns();
    }

    /// Prunes low-confidence patterns and re-detects every pattern kind.
    fn analyze_patterns(&mut self) {
        if self.access_history.len() < self.min_pattern_length {
            return;
        }
        // NOTE(review): confidences are never decayed, so this retain only
        // prunes entries if a detector ever stores a value <= 0.1 — confirm
        // whether a decay step was intended here.
        self.detected_patterns
            .retain(|_, confidence| *confidence > 0.1);
        self.detect_sequential_pattern();
        self.detect_strided_pattern();
        self.detect_cyclic_pattern();
    }

    /// Detects stride-1 sequential access: confidence is the share of
    /// consecutive pairs `(a, a + 1)` in the history.
    fn detect_sequential_pattern(&mut self) {
        let history: Vec<_> = self.access_history.iter().cloned().collect();
        let mut sequential_count = 0;
        for window in history.windows(2) {
            if window[1] == window[0] + 1 {
                sequential_count += 1;
            }
        }
        // No underflow: analyze_patterns guarantees history.len() >= 3.
        let confidence = sequential_count as f64 / (history.len() - 1) as f64;
        if confidence > 0.7 {
            self.detected_patterns
                .insert(AccessPattern::Sequential { stride: 1 }, confidence);
        }
    }

    /// Detects arithmetic progressions with strides 2..=10, anchored at the
    /// oldest remembered access (stride 1 is handled by the sequential check).
    fn detect_strided_pattern(&mut self) {
        let history: Vec<_> = self.access_history.iter().cloned().collect();
        if history.len() < 3 {
            return;
        }
        for stride in 2..=10 {
            let mut matches = 0;
            let start = history[0];
            for (i, &index) in history.iter().enumerate() {
                if index == start + i * stride {
                    matches += 1;
                }
            }
            let confidence = matches as f64 / history.len() as f64;
            if confidence > 0.8 {
                self.detected_patterns
                    .insert(AccessPattern::Strided { start, stride }, confidence);
            }
        }
    }

    /// Detects a repeating cycle: the candidate is the most recent
    /// `pattern_length` accesses, and confidence is the fraction of earlier
    /// complete segments of the history that match it exactly.
    fn detect_cyclic_pattern(&mut self) {
        let history: Vec<_> = self.access_history.iter().cloned().collect();
        for pattern_length in self.min_pattern_length..=(history.len() / 2) {
            if history.len() < pattern_length * 2 {
                continue;
            }
            let pattern: Vec<_> = history[history.len() - pattern_length..].to_vec();
            let mut repeats = 0;
            let mut total_checks = 0;
            // Walk backwards over every complete earlier segment. The previous
            // loop started at `len - 2 * pattern_length` with exit bound
            // `len - pattern_length`, so it performed exactly ONE comparison:
            // `total_checks` was always 1 and confidence was always 0.0 or
            // 1.0, regardless of how much history supported the cycle.
            let mut pos = history.len() - pattern_length;
            while pos >= pattern_length {
                pos -= pattern_length;
                total_checks += 1;
                if history[pos..pos + pattern_length] == pattern[..] {
                    repeats += 1;
                }
            }
            if total_checks > 0 {
                let confidence = repeats as f64 / total_checks as f64;
                if confidence > 0.8 {
                    self.detected_patterns
                        .insert(AccessPattern::Cyclic { pattern }, confidence);
                }
            }
        }
    }

    /// Predicts up to `count` indices expected after `current_index`, using
    /// the highest-confidence detected pattern. Returns an empty vector when
    /// no pattern is known, the pattern is `Random`, or (for cyclic patterns)
    /// `current_index` does not lie on the cycle.
    fn predict_next(&self, current_index: usize, count: usize) -> Vec<usize> {
        let mut predictions = Vec::new();
        if let Some((pattern, _)) = self.detected_patterns.iter().max_by(|a, b| {
            // Confidences are ratios of counts with non-zero denominators,
            // so they are never NaN and partial_cmp cannot return None.
            a.1.partial_cmp(b.1)
                .expect("partial_cmp should not return None for valid values")
        }) {
            match pattern {
                AccessPattern::Sequential { stride } => {
                    for i in 1..=count {
                        predictions.push(current_index + i * stride);
                    }
                }
                AccessPattern::Strided { start: _, stride } => {
                    let next_in_sequence = current_index + stride;
                    predictions.push(next_in_sequence);
                    for i in 1..count {
                        predictions.push(next_in_sequence + i * stride);
                    }
                }
                AccessPattern::Cyclic { pattern } => {
                    if let Some(current_pos) = pattern.iter().position(|&x| x == current_index) {
                        for i in 1..=count {
                            let next_pos = (current_pos + i) % pattern.len();
                            predictions.push(pattern[next_pos]);
                        }
                    }
                }
                AccessPattern::Random => {
                    // Random access is inherently unpredictable; no prefetch.
                }
            }
        }
        predictions
    }

    /// Returns the detected pattern with the highest confidence, if any.
    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
        self.detected_patterns
            .iter()
            .max_by(|a, b| {
                a.1.partial_cmp(b.1)
                    .expect("partial_cmp should not return None for valid values")
            })
            .map(|(pattern, _)| pattern.clone())
    }
}
/// Cache-backed dataset reader that learns the consumer's access pattern and
/// prefetches predicted samples on a background worker thread.
pub struct PredictivePrefetcher<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    /// Source of samples; shared with the worker thread.
    dataset: Arc<D>,
    /// Learns and predicts the access pattern.
    pattern_detector: Arc<RwLock<PatternDetector>>,
    /// Samples loaded ahead of time, keyed by dataset index.
    prefetch_cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
    /// Tuning knobs (cache size, TTL, batch size, ...).
    config: PrefetchConfig,
    /// Worker thread handle; taken and joined in `Drop`.
    worker_handle: Option<JoinHandle<()>>,
    /// Set to true to tell the worker loop to exit.
    shutdown_signal: Arc<AtomicBool>,
    /// Shared access/prefetch counters.
    stats: Arc<RwLock<AccessStats>>,
    /// Indices awaiting prefetch by the worker.
    prefetch_queue: Arc<Mutex<VecDeque<usize>>>,
}
/// Tuning parameters for [`PredictivePrefetcher`].
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Maximum number of indices predicted per access and drained by the
    /// worker per round.
    pub max_prefetch_count: usize,
    /// Maximum entries kept in the prefetch cache before eviction.
    pub max_cache_size: usize,
    /// Number of recent accesses the pattern detector remembers.
    pub pattern_history_size: usize,
    /// Cached entries untouched for longer than this are evicted.
    pub cache_ttl: Duration,
    /// How long the worker sleeps between prefetch rounds.
    pub worker_sleep_duration: Duration,
    /// NOTE(review): never read in this file — confirm whether bandwidth
    /// optimization is implemented elsewhere or still pending.
    pub bandwidth_optimization: bool,
}
impl Default for PrefetchConfig {
fn default() -> Self {
Self {
max_prefetch_count: 8,
max_cache_size: 128,
pattern_history_size: 50,
cache_ttl: Duration::from_secs(300), worker_sleep_duration: Duration::from_millis(10),
bandwidth_optimization: true,
}
}
}
impl<T, D> PredictivePrefetcher<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Creates a prefetcher over `dataset` with the default configuration.
    pub fn new(dataset: Arc<D>) -> Self {
        Self::with_config(dataset, PrefetchConfig::default())
    }

    /// Creates a prefetcher with an explicit `config` and spawns the
    /// background worker thread that services the prefetch queue.
    pub fn with_config(dataset: Arc<D>, config: PrefetchConfig) -> Self {
        let pattern_detector = Arc::new(RwLock::new(PatternDetector::new(
            config.pattern_history_size,
        )));
        let prefetch_cache = Arc::new(RwLock::new(HashMap::new()));
        let shutdown_signal = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(RwLock::new(AccessStats::default()));
        let prefetch_queue = Arc::new(Mutex::new(VecDeque::new()));
        let worker_handle = Self::start_prefetch_worker(
            dataset.clone(),
            prefetch_cache.clone(),
            prefetch_queue.clone(),
            shutdown_signal.clone(),
            config.clone(),
            stats.clone(),
        );
        Self {
            dataset,
            pattern_detector,
            prefetch_cache,
            config,
            worker_handle: Some(worker_handle),
            shutdown_signal,
            stats,
            prefetch_queue,
        }
    }

    /// Worker loop: drains up to `max_prefetch_count` queued indices, loads
    /// them into the cache (evicting the oldest entry when full), expires
    /// entries older than `cache_ttl`, then sleeps. Exits when `shutdown`
    /// is set.
    fn start_prefetch_worker(
        dataset: Arc<D>,
        cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
        queue: Arc<Mutex<VecDeque<usize>>>,
        shutdown: Arc<AtomicBool>,
        config: PrefetchConfig,
        stats: Arc<RwLock<AccessStats>>,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            while !shutdown.load(Ordering::Relaxed) {
                // Take a batch of indices; the queue lock is released before
                // any dataset I/O happens.
                let indices_to_prefetch: Vec<usize> = {
                    let mut queue_guard = queue.lock().expect("lock should not be poisoned");
                    let mut indices = Vec::new();
                    for _ in 0..config.max_prefetch_count {
                        if let Some(index) = queue_guard.pop_front() {
                            indices.push(index);
                        } else {
                            break;
                        }
                    }
                    indices
                };
                for index in indices_to_prefetch {
                    // Failed loads are skipped: the consumer will surface the
                    // same error when it requests the index directly.
                    if let Ok(data) = dataset.get(index) {
                        let mut cache_guard =
                            cache.write().expect("write lock should not be poisoned");
                        if cache_guard.len() >= config.max_cache_size {
                            // Evict the least recently touched entry.
                            let oldest_key = cache_guard
                                .iter()
                                .min_by_key(|(_, entry)| entry.timestamp)
                                .map(|(k, _)| *k);
                            if let Some(key) = oldest_key {
                                cache_guard.remove(&key);
                            }
                        }
                        cache_guard.insert(
                            index,
                            PrefetchEntry {
                                data,
                                timestamp: Instant::now(),
                                access_count: 0,
                            },
                        );
                        // NOTE(review): this counts the size of the tensor
                        // handles, not the tensor payloads — confirm whether
                        // bandwidth accounting should use actual data sizes.
                        let mut stats_guard =
                            stats.write().expect("write lock should not be poisoned");
                        stats_guard.bandwidth_saved +=
                            std::mem::size_of::<(Tensor<T>, Tensor<T>)>() as u64;
                    }
                }
                {
                    // Expire stale entries.
                    let mut cache_guard =
                        cache.write().expect("write lock should not be poisoned");
                    let now = Instant::now();
                    cache_guard
                        .retain(|_, entry| now.duration_since(entry.timestamp) < config.cache_ttl);
                }
                thread::sleep(config.worker_sleep_duration);
            }
        })
    }

    /// Fetches the sample at `index`, serving from the prefetch cache when
    /// possible, and queues predicted future indices for the worker.
    ///
    /// # Errors
    /// Propagates any error from the underlying dataset on a cache miss.
    pub fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.total_accesses += 1;
        }
        {
            let mut detector = self
                .pattern_detector
                .write()
                .expect("write lock should not be poisoned");
            detector.record_access(index);
        }
        // Probe the cache in its own scope so the lock is released before
        // predict_and_queue_prefetch re-acquires it (calling it while the
        // guard is alive would deadlock on the RwLock).
        let cached = {
            let mut cache = self
                .prefetch_cache
                .write()
                .expect("write lock should not be poisoned");
            cache.get_mut(&index).map(|entry| {
                entry.access_count += 1;
                entry.timestamp = Instant::now();
                entry.data.clone()
            })
        };
        {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            if cached.is_some() {
                stats.prefetch_hits += 1;
            } else {
                stats.prefetch_misses += 1;
            }
        }
        // Queue predictions on EVERY access. The previous version only did so
        // on cache misses, which stalled the prefetch pipeline exactly when it
        // was working: each cache hit stopped feeding the queue.
        self.predict_and_queue_prefetch(index);
        match cached {
            Some(data) => Ok(data),
            None => self.dataset.get(index),
        }
    }

    /// Asks the pattern detector for likely upcoming indices and enqueues
    /// those not already cached. Updates pattern hit/miss counters (a "hit"
    /// means the detector produced at least one prediction).
    fn predict_and_queue_prefetch(&self, current_index: usize) {
        let predictions = {
            let detector = self
                .pattern_detector
                .read()
                .expect("read lock should not be poisoned");
            detector.predict_next(current_index, self.config.max_prefetch_count)
        };
        if !predictions.is_empty() {
            // Acquire the cache lock once for the whole batch instead of
            // re-acquiring it for every predicted index inside the loop.
            let cache = self
                .prefetch_cache
                .read()
                .expect("read lock should not be poisoned");
            let mut queue = self
                .prefetch_queue
                .lock()
                .expect("lock should not be poisoned");
            for predicted_index in predictions {
                if !cache.contains_key(&predicted_index) {
                    queue.push_back(predicted_index);
                }
            }
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_hits += 1;
        } else {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_misses += 1;
        }
    }

    /// Returns a snapshot of the access/prefetch statistics.
    pub fn stats(&self) -> AccessStats {
        self.stats
            .read()
            .expect("read lock should not be poisoned")
            .clone()
    }

    /// Returns the currently dominant detected access pattern, if any.
    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
        self.pattern_detector
            .read()
            .expect("read lock should not be poisoned")
            .dominant_pattern()
    }

    /// Drops every cached entry.
    pub fn clear_cache(&self) {
        let mut cache = self
            .prefetch_cache
            .write()
            .expect("write lock should not be poisoned");
        cache.clear();
    }

    /// Returns `(current_entry_count, max_cache_size)`.
    pub fn cache_info(&self) -> (usize, usize) {
        let cache = self
            .prefetch_cache
            .read()
            .expect("read lock should not be poisoned");
        (cache.len(), self.config.max_cache_size)
    }
}
impl<T, D> Drop for PredictivePrefetcher<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Signals the worker to stop, then joins it so no detached thread
    /// outlives the prefetcher. The signal is stored before the join so the
    /// worker observes it on its next loop iteration.
    fn drop(&mut self) {
        self.shutdown_signal.store(true, Ordering::Relaxed);
        if let Some(handle) = self.worker_handle.take() {
            // Join errors (a panicked worker) are ignored: drop must not panic.
            let _ = handle.join();
        }
    }
}
/// [`Dataset`] adapter that routes every `get` through a
/// [`PredictivePrefetcher`], adding pattern-driven background prefetching.
pub struct PredictivePrefetchDataset<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    /// The wrapped prefetcher, which owns the underlying dataset.
    prefetcher: PredictivePrefetcher<T, D>,
}
impl<T, D> PredictivePrefetchDataset<T, D>
where
T: Clone + Send + Sync + 'static,
D: Dataset<T> + Send + Sync + 'static,
{
pub fn new(dataset: D) -> Self {
Self {
prefetcher: PredictivePrefetcher::new(Arc::new(dataset)),
}
}
pub fn with_config(dataset: D, config: PrefetchConfig) -> Self {
Self {
prefetcher: PredictivePrefetcher::with_config(Arc::new(dataset), config),
}
}
pub fn stats(&self) -> AccessStats {
self.prefetcher.stats()
}
pub fn dominant_pattern(&self) -> Option<AccessPattern> {
self.prefetcher.dominant_pattern()
}
}
impl<T, D> Dataset<T> for PredictivePrefetchDataset<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Number of samples in the underlying dataset.
    fn len(&self) -> usize {
        self.prefetcher.dataset.len()
    }

    /// Fetches a sample through the prefetcher (cache-aware, records the
    /// access for pattern detection).
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.prefetcher.get(index)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    // A strictly increasing index stream (0, 1, 2, 3, 4) should be
    // classified as stride-1 sequential access.
    #[test]
    fn test_pattern_detector_sequential() {
        let mut detector = PatternDetector::new(10);
        for i in 0..5 {
            detector.record_access(i);
        }
        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Sequential { stride: 1 })
        ));
    }

    // Indices 0, 2, 4, 6, 8 form an arithmetic progression and should be
    // detected as a strided pattern with start 0 and stride 2.
    #[test]
    fn test_pattern_detector_strided() {
        let mut detector = PatternDetector::new(10);
        for i in 0..5 {
            detector.record_access(i * 2);
        }
        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Strided {
                start: 0,
                stride: 2
            })
        ));
    }

    // End-to-end smoke test: sequential gets through the prefetcher should
    // be counted in the stats. The sleep gives the worker thread time to run.
    #[test]
    fn test_predictive_prefetcher() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = Arc::new(TensorDataset::new(features, labels));
        let config = PrefetchConfig {
            max_prefetch_count: 2,
            max_cache_size: 10,
            pattern_history_size: 10,
            cache_ttl: Duration::from_secs(60),
            worker_sleep_duration: Duration::from_millis(1),
            bandwidth_optimization: true,
        };
        let prefetcher = PredictivePrefetcher::with_config(dataset, config);
        let _ = prefetcher.get(0).expect("index should be in bounds");
        let _ = prefetcher.get(1).expect("index should be in bounds");
        let _ = prefetcher.get(2).expect("index should be in bounds");
        thread::sleep(Duration::from_millis(50));
        let stats = prefetcher.stats();
        assert!(stats.total_accesses >= 3);
    }

    // The Dataset adapter should delegate len/get and count accesses.
    #[test]
    fn test_predictive_prefetch_dataset() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let base_dataset = TensorDataset::new(features, labels);
        let dataset = PredictivePrefetchDataset::new(base_dataset);
        assert_eq!(dataset.len(), 2);
        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);
        let stats = dataset.stats();
        assert_eq!(stats.total_accesses, 1);
    }
}