use crate::Dataset;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
pub struct StreamPrefetchOptimizer<T>
where
T: Clone,
{
config: PrefetchOptimizerConfig,
pattern_analyzer: Arc<Mutex<AccessPatternAnalyzer>>,
prefetch_buffer: Arc<RwLock<PrefetchBuffer<T>>>,
metrics: Arc<Mutex<PrefetchMetrics>>,
worker_handles: Vec<thread::JoinHandle<()>>,
shutdown: Arc<AtomicBool>,
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct PrefetchOptimizerConfig {
pub max_buffer_size: usize,
pub worker_count: usize,
pub prediction_confidence_threshold: f64,
pub learning_rate: f64,
pub max_lookahead_distance: usize,
pub adaptive_buffer_resizing: bool,
pub buffer_resize_factor: f64,
pub min_buffer_utilization: f64,
pub pattern_window_size: usize,
pub cross_epoch_learning: bool,
}
impl Default for PrefetchOptimizerConfig {
fn default() -> Self {
Self {
max_buffer_size: 1000,
worker_count: 2,
prediction_confidence_threshold: 0.7,
learning_rate: 0.1,
max_lookahead_distance: 100,
adaptive_buffer_resizing: true,
buffer_resize_factor: 1.5,
min_buffer_utilization: 0.3,
pattern_window_size: 500,
cross_epoch_learning: true,
}
}
}
#[derive(Debug)]
pub struct AccessPatternAnalyzer {
access_history: VecDeque<AccessEvent>,
patterns: HashMap<PatternSignature, PatternPrediction>,
detection_state: PatternDetectionState,
config: PrefetchOptimizerConfig,
}
#[derive(Debug, Clone)]
pub struct AccessEvent {
pub index: usize,
pub timestamp: Instant,
pub access_type: AccessType,
pub context: AccessContext,
}
#[derive(Debug, Clone, PartialEq)]
pub enum AccessType {
Sequential,
Random,
Strided { stride: usize },
Repetitive { cycle_length: usize },
}
#[derive(Debug, Clone)]
pub struct AccessContext {
pub epoch: Option<usize>,
pub batch_index: Option<usize>,
pub worker_id: Option<usize>,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PatternSignature {
pub pattern_type: PatternType,
pub window_hash: u64,
pub context_hash: u64,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum PatternType {
Sequential,
Strided,
Cyclic,
RandomWalk,
HotSpot,
}
#[derive(Debug, Clone)]
pub struct PatternPrediction {
pub next_indices: Vec<usize>,
pub confidence: f64,
pub last_updated: Instant,
pub usage_count: usize,
pub accuracy_history: VecDeque<bool>,
}
#[derive(Debug)]
pub struct PatternDetectionState {
pub current_sequence: VecDeque<usize>,
pub stride_detector: StrideDetector,
pub cycle_detector: CycleDetector,
pub hotspot_detector: HotspotDetector,
}
#[derive(Debug)]
pub struct StrideDetector {
pub candidate_strides: HashMap<usize, usize>, pub min_sequence_length: usize,
}
#[derive(Debug)]
pub struct CycleDetector {
pub candidate_cycles: HashMap<Vec<usize>, usize>, pub max_cycle_length: usize,
}
#[derive(Debug)]
pub struct HotspotDetector {
pub access_counts: HashMap<usize, usize>, pub temporal_windows: VecDeque<HashMap<usize, usize>>,
pub window_size: Duration,
}
#[derive(Debug)]
pub struct PrefetchBuffer<T>
where
T: Clone,
{
buffer: HashMap<usize, BufferedSample<T>>,
access_order: VecDeque<usize>,
current_size: AtomicUsize,
max_size: usize,
utilization_stats: UtilizationStats,
}
#[derive(Debug)]
pub struct BufferedSample<T>
where
T: Clone,
{
pub data: (Tensor<T>, Tensor<T>),
pub load_time: Instant,
pub access_count: usize,
pub prediction_confidence: f64,
}
#[derive(Debug, Default)]
pub struct UtilizationStats {
pub hit_count: AtomicUsize,
pub miss_count: AtomicUsize,
pub eviction_count: AtomicUsize,
pub total_requests: AtomicUsize,
}
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct PrefetchMetrics {
pub hit_rate: f64,
pub prediction_accuracy: f64,
pub buffer_utilization: f64,
pub average_latency_us: f64,
pub patterns_learned: usize,
pub prefetch_efficiency: f64,
pub bandwidth_utilization: f64,
pub memory_overhead: f64,
}
impl<T> StreamPrefetchOptimizer<T>
where
T: Clone + Default + Send + Sync + 'static,
{
pub fn new(config: PrefetchOptimizerConfig) -> Self {
let pattern_analyzer = Arc::new(Mutex::new(AccessPatternAnalyzer::new(config.clone())));
let prefetch_buffer = Arc::new(RwLock::new(PrefetchBuffer::new(config.max_buffer_size)));
let metrics = Arc::new(Mutex::new(PrefetchMetrics::default()));
let shutdown = Arc::new(AtomicBool::new(false));
Self {
config,
pattern_analyzer,
prefetch_buffer,
metrics,
worker_handles: Vec::new(),
shutdown,
}
}
pub fn start<D>(&mut self, dataset: Arc<D>) -> Result<()>
where
D: Dataset<T> + Send + Sync + 'static,
{
for worker_id in 0..self.config.worker_count {
let dataset_clone = Arc::clone(&dataset);
let pattern_analyzer = Arc::clone(&self.pattern_analyzer);
let prefetch_buffer = Arc::clone(&self.prefetch_buffer);
let metrics = Arc::clone(&self.metrics);
let shutdown = Arc::clone(&self.shutdown);
let config = self.config.clone();
let handle = thread::spawn(move || {
Self::prefetch_worker(
worker_id,
dataset_clone,
pattern_analyzer,
prefetch_buffer,
metrics,
shutdown,
config,
);
});
self.worker_handles.push(handle);
}
Ok(())
}
pub fn get(&self, index: usize, context: AccessContext) -> Result<(Tensor<T>, Tensor<T>)> {
let start_time = Instant::now();
self.record_access(index, context.clone());
if let Some(sample) = self.get_from_buffer(index) {
self.update_hit_metrics(start_time);
return Ok(sample.data);
}
self.update_miss_metrics(start_time);
Err(TensorError::invalid_argument(format!(
"Data not available in prefetch buffer for index {index}"
)))
}
fn record_access(&self, index: usize, context: AccessContext) {
let event = AccessEvent {
index,
timestamp: Instant::now(),
access_type: AccessType::Sequential, context,
};
if let Ok(mut analyzer) = self.pattern_analyzer.lock() {
analyzer.record_access(event);
}
}
fn get_from_buffer(&self, index: usize) -> Option<BufferedSample<T>> {
if let Ok(mut buffer) = self.prefetch_buffer.write() {
buffer.get_sample(index)
} else {
None
}
}
fn update_hit_metrics(&self, start_time: Instant) {
let latency = start_time.elapsed().as_micros() as f64;
if let Ok(mut metrics) = self.metrics.lock() {
let total_requests = metrics.hit_rate + metrics.prediction_accuracy + 1.0;
metrics.hit_rate = (metrics.hit_rate * (total_requests - 1.0) + 1.0) / total_requests;
metrics.average_latency_us =
(metrics.average_latency_us * (total_requests - 1.0) + latency) / total_requests;
}
}
fn update_miss_metrics(&self, start_time: Instant) {
let latency = start_time.elapsed().as_micros() as f64;
if let Ok(mut metrics) = self.metrics.lock() {
let total_requests = metrics.hit_rate + metrics.prediction_accuracy + 1.0;
metrics.hit_rate = (metrics.hit_rate * (total_requests - 1.0)) / total_requests;
metrics.average_latency_us =
(metrics.average_latency_us * (total_requests - 1.0) + latency) / total_requests;
}
}
fn prefetch_worker<D>(
worker_id: usize,
dataset: Arc<D>,
pattern_analyzer: Arc<Mutex<AccessPatternAnalyzer>>,
prefetch_buffer: Arc<RwLock<PrefetchBuffer<T>>>,
_metrics: Arc<Mutex<PrefetchMetrics>>,
shutdown: Arc<AtomicBool>,
_config: PrefetchOptimizerConfig,
) where
D: Dataset<T> + Send + Sync + 'static,
{
while !shutdown.load(Ordering::Relaxed) {
let predictions = if let Ok(analyzer) = pattern_analyzer.lock() {
analyzer.get_predictions()
} else {
Vec::new()
};
for prediction in predictions {
for &index in &prediction.next_indices {
if index < dataset.len() {
if let Ok(sample) = dataset.get(index) {
let buffered_sample = BufferedSample {
data: sample,
load_time: Instant::now(),
access_count: 0,
prediction_confidence: prediction.confidence,
};
if let Ok(mut buffer) = prefetch_buffer.write() {
buffer.add_sample(index, buffered_sample);
}
}
}
}
}
thread::sleep(Duration::from_millis(10));
}
println!("Prefetch worker {worker_id} shutting down");
}
pub fn get_metrics(&self) -> PrefetchMetrics {
if let Ok(metrics) = self.metrics.lock() {
PrefetchMetrics {
hit_rate: metrics.hit_rate,
prediction_accuracy: metrics.prediction_accuracy,
buffer_utilization: metrics.buffer_utilization,
average_latency_us: metrics.average_latency_us,
patterns_learned: metrics.patterns_learned,
prefetch_efficiency: metrics.prefetch_efficiency,
bandwidth_utilization: metrics.bandwidth_utilization,
memory_overhead: metrics.memory_overhead,
}
} else {
PrefetchMetrics::default()
}
}
pub fn stop(&mut self) {
self.shutdown.store(true, Ordering::Relaxed);
while let Some(handle) = self.worker_handles.pop() {
let _ = handle.join();
}
}
}
impl AccessPatternAnalyzer {
fn new(config: PrefetchOptimizerConfig) -> Self {
Self {
access_history: VecDeque::with_capacity(config.pattern_window_size),
patterns: HashMap::new(),
detection_state: PatternDetectionState::new(),
config,
}
}
fn record_access(&mut self, event: AccessEvent) {
self.access_history.push_back(event.clone());
if self.access_history.len() > self.config.pattern_window_size {
self.access_history.pop_front();
}
self.detection_state.current_sequence.push_back(event.index);
if self.detection_state.current_sequence.len() > 100 {
self.detection_state.current_sequence.pop_front();
}
self.analyze_patterns();
}
fn analyze_patterns(&mut self) {
self.detect_sequential_patterns();
self.detect_strided_patterns();
self.detect_cyclic_patterns();
self.detect_hotspot_patterns();
}
fn detect_sequential_patterns(&mut self) {
if self.access_history.len() < 3 {
return;
}
let recent_accesses: Vec<usize> = self
.access_history
.iter()
.rev()
.take(10)
.map(|event| event.index)
.collect();
let mut sequential_count = 0;
for window in recent_accesses.windows(2) {
if window[1] == window[0] + 1 {
sequential_count += 1;
}
}
if sequential_count >= 5 {
let signature = PatternSignature {
pattern_type: PatternType::Sequential,
window_hash: self.hash_sequence(&recent_accesses),
context_hash: 0, };
let next_index = recent_accesses[0] + 1;
let prediction = PatternPrediction {
next_indices: vec![next_index, next_index + 1, next_index + 2],
confidence: 0.9,
last_updated: Instant::now(),
usage_count: 1,
accuracy_history: VecDeque::new(),
};
self.patterns.insert(signature, prediction);
}
}
fn detect_strided_patterns(&mut self) {
self.detection_state
.stride_detector
.analyze(&self.access_history);
}
fn detect_cyclic_patterns(&mut self) {
self.detection_state
.cycle_detector
.analyze(&self.access_history);
}
fn detect_hotspot_patterns(&mut self) {
self.detection_state
.hotspot_detector
.analyze(&self.access_history);
}
fn get_predictions(&self) -> Vec<PatternPrediction> {
self.patterns
.values()
.filter(|p| p.confidence >= self.config.prediction_confidence_threshold)
.cloned()
.collect()
}
fn hash_sequence(&self, sequence: &[usize]) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
sequence.hash(&mut hasher);
hasher.finish()
}
}
impl<T> PrefetchBuffer<T>
where
T: Clone,
{
fn new(max_size: usize) -> Self {
Self {
buffer: HashMap::new(),
access_order: VecDeque::new(),
current_size: AtomicUsize::new(0),
max_size,
utilization_stats: UtilizationStats::default(),
}
}
fn add_sample(&mut self, index: usize, sample: BufferedSample<T>) {
if self.current_size.load(Ordering::Relaxed) >= self.max_size {
self.evict_lru();
}
self.buffer.insert(index, sample);
self.access_order.push_back(index);
self.current_size.fetch_add(1, Ordering::Relaxed);
}
fn get_sample(&mut self, index: usize) -> Option<BufferedSample<T>> {
if let Some(mut sample) = self.buffer.remove(&index) {
sample.access_count += 1;
if let Some(pos) = self.access_order.iter().position(|&x| x == index) {
self.access_order.remove(pos);
self.access_order.push_back(index);
}
let updated_sample = BufferedSample {
data: sample.data.clone(),
load_time: sample.load_time,
access_count: sample.access_count,
prediction_confidence: sample.prediction_confidence,
};
self.buffer.insert(index, updated_sample);
self.utilization_stats
.hit_count
.fetch_add(1, Ordering::Relaxed);
self.utilization_stats
.total_requests
.fetch_add(1, Ordering::Relaxed);
Some(sample)
} else {
self.utilization_stats
.miss_count
.fetch_add(1, Ordering::Relaxed);
self.utilization_stats
.total_requests
.fetch_add(1, Ordering::Relaxed);
None
}
}
fn evict_lru(&mut self) {
if let Some(lru_index) = self.access_order.pop_front() {
self.buffer.remove(&lru_index);
self.current_size.fetch_sub(1, Ordering::Relaxed);
self.utilization_stats
.eviction_count
.fetch_add(1, Ordering::Relaxed);
}
}
}
impl PatternDetectionState {
fn new() -> Self {
Self {
current_sequence: VecDeque::new(),
stride_detector: StrideDetector::new(),
cycle_detector: CycleDetector::new(),
hotspot_detector: HotspotDetector::new(),
}
}
}
impl StrideDetector {
fn new() -> Self {
Self {
candidate_strides: HashMap::new(),
min_sequence_length: 5,
}
}
fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
if access_history.len() < self.min_sequence_length {
return;
}
let indices: Vec<usize> = access_history.iter().map(|e| e.index).collect();
for window_size in 3..=self.min_sequence_length {
if indices.len() >= window_size {
let window = &indices[indices.len() - window_size..];
if let Some(stride) = self.detect_stride(window) {
*self.candidate_strides.entry(stride).or_insert(0) += 1;
}
}
}
}
fn detect_stride(&self, window: &[usize]) -> Option<usize> {
if window.len() < 3 {
return None;
}
let first_diff = window[1] as i64 - window[0] as i64;
for i in 2..window.len() {
let diff = window[i] as i64 - window[i - 1] as i64;
if diff != first_diff {
return None;
}
}
if first_diff > 0 {
Some(first_diff as usize)
} else {
None
}
}
}
impl CycleDetector {
fn new() -> Self {
Self {
candidate_cycles: HashMap::new(),
max_cycle_length: 20,
}
}
fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
let indices: Vec<usize> = access_history.iter().map(|e| e.index).collect();
for cycle_len in 2..=self.max_cycle_length.min(indices.len() / 2) {
if indices.len() >= cycle_len * 2 {
let potential_cycle = &indices[indices.len() - cycle_len..];
let prev_cycle = &indices[indices.len() - cycle_len * 2..indices.len() - cycle_len];
if potential_cycle == prev_cycle {
*self
.candidate_cycles
.entry(potential_cycle.to_vec())
.or_insert(0) += 1;
}
}
}
}
}
impl HotspotDetector {
fn new() -> Self {
Self {
access_counts: HashMap::new(),
temporal_windows: VecDeque::new(),
window_size: Duration::from_secs(60),
}
}
fn analyze(&mut self, access_history: &VecDeque<AccessEvent>) {
for event in access_history {
*self.access_counts.entry(event.index).or_insert(0) += 1;
}
if let Some(latest_event) = access_history.back() {
let cutoff_time = latest_event.timestamp - self.window_size;
while let Some(front_window) = self.temporal_windows.front() {
if front_window.is_empty() {
self.temporal_windows.pop_front();
} else {
break;
}
}
let mut recent_window = HashMap::new();
for event in access_history {
if event.timestamp >= cutoff_time {
*recent_window.entry(event.index).or_insert(0) += 1;
}
}
if !recent_window.is_empty() {
self.temporal_windows.push_back(recent_window);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tenflowers_core::Tensor;
#[test]
fn test_optimizer_creation() {
let config = PrefetchOptimizerConfig::default();
let optimizer: StreamPrefetchOptimizer<f32> = StreamPrefetchOptimizer::new(config);
assert_eq!(optimizer.config.max_buffer_size, 1000);
assert_eq!(optimizer.config.worker_count, 2);
}
#[test]
fn test_access_pattern_analyzer() {
let config = PrefetchOptimizerConfig {
prediction_confidence_threshold: 0.5, ..Default::default()
};
let mut analyzer = AccessPatternAnalyzer::new(config);
for i in 0..15 {
let event = AccessEvent {
index: i,
timestamp: Instant::now(),
access_type: AccessType::Sequential,
context: AccessContext {
epoch: Some(0),
batch_index: Some(i / 4),
worker_id: Some(0),
},
};
analyzer.record_access(event);
}
let _predictions = analyzer.get_predictions();
assert!(analyzer.access_history.len() == 15);
}
#[test]
fn test_prefetch_buffer() {
let mut buffer: PrefetchBuffer<f32> = PrefetchBuffer::new(5);
let sample_data = (
Tensor::from_vec(vec![1.0, 2.0], &[2]).expect("test: tensor creation should succeed"),
Tensor::from_vec(vec![0.0], &[1]).expect("test: tensor creation should succeed"),
);
let buffered_sample = BufferedSample {
data: sample_data,
load_time: Instant::now(),
access_count: 0,
prediction_confidence: 0.8,
};
buffer.add_sample(0, buffered_sample);
assert_eq!(buffer.current_size.load(Ordering::Relaxed), 1);
let retrieved = buffer.get_sample(0);
assert!(retrieved.is_some());
assert_eq!(
retrieved
.expect("test: operation should succeed")
.access_count,
1
);
}
#[test]
fn test_stride_detector() {
let mut detector = StrideDetector::new();
let events: Vec<AccessEvent> = (0..10)
.map(|i| AccessEvent {
index: i * 3, timestamp: Instant::now(),
access_type: AccessType::Sequential,
context: AccessContext {
epoch: Some(0),
batch_index: None,
worker_id: None,
},
})
.collect();
let access_history: VecDeque<AccessEvent> = events.into();
detector.analyze(&access_history);
assert!(detector.candidate_strides.contains_key(&3));
}
#[test]
fn test_metrics_tracking() {
let config = PrefetchOptimizerConfig::default();
let optimizer: StreamPrefetchOptimizer<f32> = StreamPrefetchOptimizer::new(config);
let metrics = optimizer.get_metrics();
assert_eq!(metrics.hit_rate, 0.0);
assert_eq!(metrics.patterns_learned, 0);
}
}