1use crate::Dataset;
7use std::collections::{HashMap, VecDeque};
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex, RwLock};
10use std::thread::{self, JoinHandle};
11use std::time::{Duration, Instant};
12use tenflowers_core::{Result, Tensor};
13
/// Access patterns the prefetcher can recognize in a stream of dataset indices.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Monotonically increasing indices with a fixed step size.
    Sequential { stride: usize },
    /// No detectable structure; nothing useful can be prefetched.
    Random,
    /// A repeating sequence of indices.
    Cyclic { pattern: Vec<usize> },
    /// Arithmetic progression `start, start + stride, start + 2*stride, ...`.
    Strided { start: usize, stride: usize },
}
26
/// Counters describing how well pattern detection and prefetching performed.
#[derive(Debug, Clone, Default)]
pub struct AccessStats {
    /// Total number of `get` calls observed.
    pub total_accesses: u64,
    /// Accesses classified as sequential.
    /// NOTE(review): not incremented anywhere in this file — confirm callers.
    pub sequential_accesses: u64,
    /// Accesses classified as random.
    /// NOTE(review): not incremented anywhere in this file — confirm callers.
    pub random_accesses: u64,
    /// Times the detector produced a non-empty prediction list.
    pub pattern_hits: u64,
    /// Times the detector produced no predictions.
    pub pattern_misses: u64,
    /// `get` calls served from the prefetch cache.
    pub prefetch_hits: u64,
    /// `get` calls that fell through to the underlying dataset.
    pub prefetch_misses: u64,
    /// Estimated bytes saved by prefetching (sized from the tensor handles,
    /// not the underlying buffers — an approximation).
    pub bandwidth_saved: u64,
}
39
40impl AccessStats {
41 pub fn pattern_accuracy(&self) -> f64 {
43 let total_predictions = self.pattern_hits + self.pattern_misses;
44 if total_predictions == 0 {
45 0.0
46 } else {
47 self.pattern_hits as f64 / total_predictions as f64
48 }
49 }
50
51 pub fn prefetch_efficiency(&self) -> f64 {
53 let total_prefetches = self.prefetch_hits + self.prefetch_misses;
54 if total_prefetches == 0 {
55 0.0
56 } else {
57 self.prefetch_hits as f64 / total_prefetches as f64
58 }
59 }
60
61 pub fn sequential_ratio(&self) -> f64 {
63 if self.total_accesses == 0 {
64 0.0
65 } else {
66 self.sequential_accesses as f64 / self.total_accesses as f64
67 }
68 }
69}
70
/// A cached sample pair produced by the background prefetch worker.
#[derive(Debug)]
struct PrefetchEntry<T> {
    /// The prefetched (features, labels) pair.
    data: (Tensor<T>, Tensor<T>),
    /// When the entry was inserted or last served (drives TTL eviction).
    timestamp: Instant,
    /// How many times a `get` has hit this entry.
    access_count: u32,
}
78
/// Observes a stream of dataset indices and infers the dominant access pattern.
#[derive(Debug)]
struct PatternDetector {
    /// Most recent accessed indices, oldest first, capped at `max_history`.
    access_history: VecDeque<usize>,
    /// Candidate patterns mapped to their confidence score in (0.0, 1.0].
    detected_patterns: HashMap<AccessPattern, f64>,
    /// Upper bound on `access_history` length.
    max_history: usize,
    /// Minimum history length before any analysis runs (also the minimum
    /// cycle length considered by cyclic detection).
    min_pattern_length: usize,
}
91
92impl PatternDetector {
93 fn new(max_history: usize) -> Self {
94 Self {
95 access_history: VecDeque::new(),
96 detected_patterns: HashMap::new(),
97 max_history,
98 min_pattern_length: 3,
99 }
100 }
101
102 fn record_access(&mut self, index: usize) {
104 self.access_history.push_back(index);
105
106 while self.access_history.len() > self.max_history {
108 self.access_history.pop_front();
109 }
110
111 self.analyze_patterns();
113 }
114
115 fn analyze_patterns(&mut self) {
117 if self.access_history.len() < self.min_pattern_length {
118 return;
119 }
120
121 self.detected_patterns
123 .retain(|_, confidence| *confidence > 0.1);
124
125 self.detect_sequential_pattern();
127
128 self.detect_strided_pattern();
130
131 self.detect_cyclic_pattern();
133 }
134
135 fn detect_sequential_pattern(&mut self) {
137 let history: Vec<_> = self.access_history.iter().cloned().collect();
138 let mut sequential_count = 0;
139
140 for window in history.windows(2) {
141 if window[1] == window[0] + 1 {
142 sequential_count += 1;
143 }
144 }
145
146 let confidence = sequential_count as f64 / (history.len() - 1) as f64;
147 if confidence > 0.7 {
148 self.detected_patterns
149 .insert(AccessPattern::Sequential { stride: 1 }, confidence);
150 }
151 }
152
153 fn detect_strided_pattern(&mut self) {
155 let history: Vec<_> = self.access_history.iter().cloned().collect();
156 if history.len() < 3 {
157 return;
158 }
159
160 for stride in 2..=10 {
162 let mut matches = 0;
163 let start = history[0];
164
165 for (i, &index) in history.iter().enumerate() {
166 if index == start + i * stride {
167 matches += 1;
168 }
169 }
170
171 let confidence = matches as f64 / history.len() as f64;
172 if confidence > 0.8 {
173 self.detected_patterns
174 .insert(AccessPattern::Strided { start, stride }, confidence);
175 }
176 }
177 }
178
179 fn detect_cyclic_pattern(&mut self) {
181 let history: Vec<_> = self.access_history.iter().cloned().collect();
182
183 for pattern_length in self.min_pattern_length..=(history.len() / 2) {
185 if history.len() < pattern_length * 2 {
186 continue;
187 }
188
189 let pattern: Vec<_> = history[history.len() - pattern_length..].to_vec();
190 let mut repeats = 0;
191 let mut total_checks = 0;
192
193 let mut pos = history.len() - pattern_length * 2;
194 while pos < history.len() - pattern_length {
195 total_checks += 1;
196 let segment = &history[pos..pos + pattern_length];
197 if segment == pattern {
198 repeats += 1;
199 }
200 pos += pattern_length;
201 }
202
203 if total_checks > 0 {
204 let confidence = repeats as f64 / total_checks as f64;
205 if confidence > 0.8 {
206 self.detected_patterns
207 .insert(AccessPattern::Cyclic { pattern }, confidence);
208 }
209 }
210 }
211 }
212
213 fn predict_next(&self, current_index: usize, count: usize) -> Vec<usize> {
215 let mut predictions = Vec::new();
216
217 if let Some((pattern, _)) = self.detected_patterns.iter().max_by(|a, b| {
219 a.1.partial_cmp(b.1)
220 .expect("partial_cmp should not return None for valid values")
221 }) {
222 match pattern {
223 AccessPattern::Sequential { stride } => {
224 for i in 1..=count {
225 predictions.push(current_index + i * stride);
226 }
227 }
228 AccessPattern::Strided { start: _, stride } => {
229 let next_in_sequence = current_index + stride;
230 predictions.push(next_in_sequence);
231 for i in 1..count {
232 predictions.push(next_in_sequence + i * stride);
233 }
234 }
235 AccessPattern::Cyclic { pattern } => {
236 if let Some(current_pos) = pattern.iter().position(|&x| x == current_index) {
237 for i in 1..=count {
238 let next_pos = (current_pos + i) % pattern.len();
239 predictions.push(pattern[next_pos]);
240 }
241 }
242 }
243 AccessPattern::Random => {
244 }
246 }
247 }
248
249 predictions
250 }
251
252 pub fn dominant_pattern(&self) -> Option<AccessPattern> {
254 self.detected_patterns
255 .iter()
256 .max_by(|a, b| {
257 a.1.partial_cmp(b.1)
258 .expect("partial_cmp should not return None for valid values")
259 })
260 .map(|(pattern, _)| pattern.clone())
261 }
262}
263
/// Wraps a dataset with pattern-driven background prefetching: accesses are
/// observed, likely future indices are predicted, and a worker thread loads
/// them into a bounded, TTL-evicted cache ahead of time.
pub struct PredictivePrefetcher<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    /// The underlying dataset; shared with the worker thread.
    dataset: Arc<D>,
    /// Observes the access stream and predicts future indices.
    pattern_detector: Arc<RwLock<PatternDetector>>,
    /// Index -> prefetched sample, bounded by `config.max_cache_size`.
    prefetch_cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
    /// Tuning knobs (cache size, TTL, worker cadence, ...).
    config: PrefetchConfig,
    /// Background worker; joined on drop. `None` only after `drop` takes it.
    worker_handle: Option<JoinHandle<()>>,
    /// Set to true to ask the worker loop to exit.
    shutdown_signal: Arc<AtomicBool>,
    /// Shared access/prefetch counters.
    stats: Arc<RwLock<AccessStats>>,
    /// Indices queued for the worker to prefetch, FIFO.
    prefetch_queue: Arc<Mutex<VecDeque<usize>>>,
}
287
/// Tuning parameters for [`PredictivePrefetcher`].
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Max indices prefetched per worker iteration, and the prediction depth
    /// requested from the pattern detector.
    pub max_prefetch_count: usize,
    /// Upper bound on cached entries; oldest entry is evicted when full.
    pub max_cache_size: usize,
    /// How many recent accesses the pattern detector retains.
    pub pattern_history_size: usize,
    /// Entries older than this are swept from the cache by the worker.
    pub cache_ttl: Duration,
    /// Sleep between worker iterations.
    pub worker_sleep_duration: Duration,
    /// Enables bandwidth-saving heuristics.
    /// NOTE(review): never read in this file — confirm intended use.
    pub bandwidth_optimization: bool,
}
304
305impl Default for PrefetchConfig {
306 fn default() -> Self {
307 Self {
308 max_prefetch_count: 8,
309 max_cache_size: 128,
310 pattern_history_size: 50,
311 cache_ttl: Duration::from_secs(300), worker_sleep_duration: Duration::from_millis(10),
313 bandwidth_optimization: true,
314 }
315 }
316}
317
impl<T, D> PredictivePrefetcher<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Creates a prefetcher over `dataset` with the default configuration.
    pub fn new(dataset: Arc<D>) -> Self {
        Self::with_config(dataset, PrefetchConfig::default())
    }

    /// Creates a prefetcher with an explicit configuration and immediately
    /// spawns the background prefetch worker thread.
    pub fn with_config(dataset: Arc<D>, config: PrefetchConfig) -> Self {
        let pattern_detector = Arc::new(RwLock::new(PatternDetector::new(
            config.pattern_history_size,
        )));
        let prefetch_cache = Arc::new(RwLock::new(HashMap::new()));
        let shutdown_signal = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(RwLock::new(AccessStats::default()));
        let prefetch_queue = Arc::new(Mutex::new(VecDeque::new()));

        // The worker owns clones of all shared state; `Drop` signals and
        // joins it.
        let worker_handle = Self::start_prefetch_worker(
            dataset.clone(),
            prefetch_cache.clone(),
            prefetch_queue.clone(),
            shutdown_signal.clone(),
            config.clone(),
            stats.clone(),
        );

        Self {
            dataset,
            pattern_detector,
            prefetch_cache,
            config,
            worker_handle: Some(worker_handle),
            shutdown_signal,
            stats,
            prefetch_queue,
        }
    }

    /// Spawns the background thread that drains the prefetch queue, loads
    /// samples, and maintains the cache (size cap + TTL sweep).
    ///
    /// Lock ordering note: the queue lock is taken and released before any
    /// dataset I/O; the cache write lock is taken before the stats write
    /// lock, matching the order used in `get` — keep it that way.
    fn start_prefetch_worker(
        dataset: Arc<D>,
        cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
        queue: Arc<Mutex<VecDeque<usize>>>,
        shutdown: Arc<AtomicBool>,
        config: PrefetchConfig,
        stats: Arc<RwLock<AccessStats>>,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            while !shutdown.load(Ordering::Relaxed) {
                // Drain up to `max_prefetch_count` indices, then drop the
                // queue lock before doing any loading.
                let indices_to_prefetch: Vec<usize> = {
                    let mut queue_guard = queue.lock().expect("lock should not be poisoned");
                    let mut indices = Vec::new();

                    for _ in 0..config.max_prefetch_count {
                        if let Some(index) = queue_guard.pop_front() {
                            indices.push(index);
                        } else {
                            break;
                        }
                    }
                    indices
                };

                for index in indices_to_prefetch {
                    // Failed loads are skipped silently; the foreground `get`
                    // surfaces the error if the index is actually requested.
                    if let Ok(data) = dataset.get(index) {
                        let mut cache_guard =
                            cache.write().expect("write lock should not be poisoned");

                        // Evict the oldest entry when the cache is full.
                        if cache_guard.len() >= config.max_cache_size {
                            let oldest_key = cache_guard
                                .iter()
                                .min_by_key(|(_, entry)| entry.timestamp)
                                .map(|(k, _)| *k);

                            if let Some(key) = oldest_key {
                                cache_guard.remove(&key);
                            }
                        }

                        cache_guard.insert(
                            index,
                            PrefetchEntry {
                                data,
                                timestamp: Instant::now(),
                                access_count: 0,
                            },
                        );

                        // NOTE(review): this measures the size of the tensor
                        // handles, not the underlying buffers — an estimate.
                        let mut stats_guard =
                            stats.write().expect("write lock should not be poisoned");
                        stats_guard.bandwidth_saved +=
                            std::mem::size_of::<(Tensor<T>, Tensor<T>)>() as u64;
                    }
                }

                // TTL sweep: drop entries older than `cache_ttl`.
                {
                    let mut cache_guard = cache.write().expect("write lock should not be poisoned");
                    let now = Instant::now();
                    cache_guard
                        .retain(|_, entry| now.duration_since(entry.timestamp) < config.cache_ttl);
                }

                thread::sleep(config.worker_sleep_duration);
            }
        })
    }

    /// Fetches one sample, serving from the prefetch cache when possible and
    /// otherwise falling back to the underlying dataset. Every call is
    /// recorded for pattern detection and may queue further prefetches.
    pub fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        // Count the access.
        {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.total_accesses += 1;
        }

        // Feed the index to the pattern detector.
        {
            let mut detector = self
                .pattern_detector
                .write()
                .expect("write lock should not be poisoned");
            detector.record_access(index);
        }

        // Fast path: serve from the prefetch cache. A write lock is needed
        // because hits mutate the entry (access count + timestamp refresh).
        {
            let mut cache = self
                .prefetch_cache
                .write()
                .expect("write lock should not be poisoned");
            if let Some(entry) = cache.get_mut(&index) {
                entry.access_count += 1;
                // Refresh the timestamp so hot entries survive TTL eviction.
                entry.timestamp = Instant::now();
                let mut stats = self
                    .stats
                    .write()
                    .expect("write lock should not be poisoned");
                stats.prefetch_hits += 1;

                return Ok(entry.data.clone());
            } else {
                let mut stats = self
                    .stats
                    .write()
                    .expect("write lock should not be poisoned");
                stats.prefetch_misses += 1;
            }
        }

        // Miss: queue predicted future indices, then load synchronously.
        self.predict_and_queue_prefetch(index);

        self.dataset.get(index)
    }

    /// Asks the detector for likely next indices and enqueues any that are
    /// not already cached for the background worker to load.
    fn predict_and_queue_prefetch(&self, current_index: usize) {
        let predictions = {
            let detector = self
                .pattern_detector
                .read()
                .expect("read lock should not be poisoned");
            detector.predict_next(current_index, self.config.max_prefetch_count)
        };

        if !predictions.is_empty() {
            let mut queue = self
                .prefetch_queue
                .lock()
                .expect("lock should not be poisoned");
            for predicted_index in predictions {
                // NOTE(review): the cache read lock is re-acquired once per
                // prediction; hoisting it out of the loop would be cheaper.
                let cache = self
                    .prefetch_cache
                    .read()
                    .expect("read lock should not be poisoned");
                if !cache.contains_key(&predicted_index) {
                    queue.push_back(predicted_index);
                }
            }

            // "Hit" here means "a prediction was produced", not that the
            // prediction was later confirmed correct.
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_hits += 1;
        } else {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_misses += 1;
        }
    }

    /// Returns a snapshot of the accumulated access statistics.
    pub fn stats(&self) -> AccessStats {
        self.stats
            .read()
            .expect("read lock should not be poisoned")
            .clone()
    }

    /// Returns the highest-confidence detected access pattern, if any.
    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
        self.pattern_detector
            .read()
            .expect("read lock should not be poisoned")
            .dominant_pattern()
    }

    /// Empties the prefetch cache. Statistics and the pattern detector are
    /// left untouched.
    pub fn clear_cache(&self) {
        let mut cache = self
            .prefetch_cache
            .write()
            .expect("write lock should not be poisoned");
        cache.clear();
    }

    /// Returns `(current_entry_count, configured_max_cache_size)`.
    pub fn cache_info(&self) -> (usize, usize) {
        let cache = self
            .prefetch_cache
            .read()
            .expect("read lock should not be poisoned");
        (cache.len(), self.config.max_cache_size)
    }
}
564
565impl<T, D> Drop for PredictivePrefetcher<T, D>
566where
567 T: Clone + Send + Sync + 'static,
568 D: Dataset<T> + Send + Sync + 'static,
569{
570 fn drop(&mut self) {
571 self.shutdown_signal.store(true, Ordering::Relaxed);
573
574 if let Some(handle) = self.worker_handle.take() {
575 let _ = handle.join();
576 }
577 }
578}
579
/// A [`Dataset`] adapter that transparently adds predictive prefetching:
/// callers use the normal `Dataset` API while accesses are analyzed and
/// future samples are loaded in the background.
pub struct PredictivePrefetchDataset<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    /// Owns the wrapped dataset, the cache, and the worker thread.
    prefetcher: PredictivePrefetcher<T, D>,
}
588
589impl<T, D> PredictivePrefetchDataset<T, D>
590where
591 T: Clone + Send + Sync + 'static,
592 D: Dataset<T> + Send + Sync + 'static,
593{
594 pub fn new(dataset: D) -> Self {
596 Self {
597 prefetcher: PredictivePrefetcher::new(Arc::new(dataset)),
598 }
599 }
600
601 pub fn with_config(dataset: D, config: PrefetchConfig) -> Self {
603 Self {
604 prefetcher: PredictivePrefetcher::with_config(Arc::new(dataset), config),
605 }
606 }
607
608 pub fn stats(&self) -> AccessStats {
610 self.prefetcher.stats()
611 }
612
613 pub fn dominant_pattern(&self) -> Option<AccessPattern> {
615 self.prefetcher.dominant_pattern()
616 }
617}
618
impl<T, D> Dataset<T> for PredictivePrefetchDataset<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Delegates to the wrapped dataset's length.
    fn len(&self) -> usize {
        self.prefetcher.dataset.len()
    }

    /// Delegates to the prefetcher, which may serve from its cache.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        self.prefetcher.get(index)
    }
}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// Stride-1 accesses should be recognized as the dominant pattern.
    #[test]
    fn test_pattern_detector_sequential() {
        let mut detector = PatternDetector::new(10);

        for i in 0..5 {
            detector.record_access(i);
        }

        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Sequential { stride: 1 })
        ));
    }

    /// Accesses at a fixed stride of 2 should be recognized as strided.
    #[test]
    fn test_pattern_detector_strided() {
        let mut detector = PatternDetector::new(10);

        for i in 0..5 {
            detector.record_access(i * 2);
        }

        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Strided {
                start: 0,
                stride: 2
            })
        ));
    }

    /// End-to-end smoke test: sequential gets are served and counted.
    #[test]
    fn test_predictive_prefetcher() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = Arc::new(TensorDataset::new(features, labels));

        // Small cache / short sleep so the worker has a chance to run
        // within the test's timing window.
        let config = PrefetchConfig {
            max_prefetch_count: 2,
            max_cache_size: 10,
            pattern_history_size: 10,
            cache_ttl: Duration::from_secs(60),
            worker_sleep_duration: Duration::from_millis(1),
            bandwidth_optimization: true,
        };

        let prefetcher = PredictivePrefetcher::with_config(dataset, config);

        let _ = prefetcher.get(0).expect("index should be in bounds");
        let _ = prefetcher.get(1).expect("index should be in bounds");
        let _ = prefetcher.get(2).expect("index should be in bounds");

        // Give the background worker time to process queued prefetches.
        thread::sleep(Duration::from_millis(50));

        let stats = prefetcher.stats();
        assert!(stats.total_accesses >= 3);
    }

    /// The Dataset adapter should delegate len/get and track stats.
    #[test]
    fn test_predictive_prefetch_dataset() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let base_dataset = TensorDataset::new(features, labels);

        let dataset = PredictivePrefetchDataset::new(base_dataset);

        assert_eq!(dataset.len(), 2);

        // One sample: a feature row of width 2 and a scalar label.
        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);

        let stats = dataset.stats();
        assert_eq!(stats.total_accesses, 1);
    }
}