// tenflowers_dataset/predictive_prefetch.rs

1//! Predictive prefetching with access pattern learning
2//!
3//! This module provides intelligent prefetching capabilities that learn from
4//! access patterns to predict and preload data before it's requested.
5
6use crate::Dataset;
7use std::collections::{HashMap, VecDeque};
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::{Arc, Mutex, RwLock};
10use std::thread::{self, JoinHandle};
11use std::time::{Duration, Instant};
12use tenflowers_core::{Result, Tensor};
13
/// Access pattern types detected by the system.
///
/// Derives `Eq`/`Hash` so a pattern can serve as the key of the detector's
/// confidence map (`HashMap<AccessPattern, f64>`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Sequential access (i, i+1, i+2, ...)
    Sequential { stride: usize },
    /// Random access with no detectable pattern
    Random,
    /// Repeating pattern (cycles through a set of indices)
    Cyclic { pattern: Vec<usize> },
    /// Strided access (i, i+k, i+2k, ...)
    Strided { start: usize, stride: usize },
}
26
/// Statistics about access patterns and prefetch effectiveness.
///
/// All counters are plain `u64` fields; derived ratios are exposed through
/// the methods below, which return 0.0 instead of dividing by zero.
#[derive(Debug, Clone, Default)]
pub struct AccessStats {
    /// Total number of accesses observed.
    pub total_accesses: u64,
    /// Accesses counted as sequential (see `sequential_ratio`).
    pub sequential_accesses: u64,
    /// Accesses counted as random.
    pub random_accesses: u64,
    /// Times pattern prediction produced candidate indices.
    pub pattern_hits: u64,
    /// Times pattern prediction produced nothing.
    pub pattern_misses: u64,
    /// Lookups served from the prefetch cache.
    pub prefetch_hits: u64,
    /// Lookups that missed the prefetch cache.
    pub prefetch_misses: u64,
    pub bandwidth_saved: u64, // Bytes saved by avoiding disk I/O
}

impl AccessStats {
    /// Shared safe-division helper: `hits / total`, or 0.0 when `total`
    /// is zero. Previously this guard was duplicated in all three ratio
    /// methods.
    fn ratio(hits: u64, total: u64) -> f64 {
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Calculate pattern prediction accuracy: hits / (hits + misses).
    pub fn pattern_accuracy(&self) -> f64 {
        Self::ratio(self.pattern_hits, self.pattern_hits + self.pattern_misses)
    }

    /// Calculate prefetch efficiency: cache hits / total cache lookups.
    pub fn prefetch_efficiency(&self) -> f64 {
        Self::ratio(
            self.prefetch_hits,
            self.prefetch_hits + self.prefetch_misses,
        )
    }

    /// Calculate sequential access ratio: sequential / total accesses.
    pub fn sequential_ratio(&self) -> f64 {
        Self::ratio(self.sequential_accesses, self.total_accesses)
    }
}
70
/// Prefetch entry stored in cache.
#[derive(Debug)]
struct PrefetchEntry<T> {
    /// Cached (features, label) pair as returned by `Dataset::get`.
    data: (Tensor<T>, Tensor<T>),
    /// Insertion / last-hit time; used both for TTL expiry and as the
    /// LRU eviction key (refreshed on every cache hit).
    timestamp: Instant,
    /// Number of times this entry served a cache hit.
    access_count: u32,
}
78
/// Pattern detector that learns access patterns.
#[derive(Debug)]
struct PatternDetector {
    /// Recent access history, oldest at the front, capped at `max_history`.
    access_history: VecDeque<usize>,
    /// Detected patterns and their confidence scores (ratios in [0, 1]).
    detected_patterns: HashMap<AccessPattern, f64>,
    /// Maximum history size to keep
    max_history: usize,
    /// Minimum pattern length to detect; also the minimum history length
    /// before any analysis runs (set to 3 in `new`).
    min_pattern_length: usize,
}
91
92impl PatternDetector {
93    fn new(max_history: usize) -> Self {
94        Self {
95            access_history: VecDeque::new(),
96            detected_patterns: HashMap::new(),
97            max_history,
98            min_pattern_length: 3,
99        }
100    }
101
102    /// Record a new access
103    fn record_access(&mut self, index: usize) {
104        self.access_history.push_back(index);
105
106        // Trim history if too long
107        while self.access_history.len() > self.max_history {
108            self.access_history.pop_front();
109        }
110
111        // Analyze patterns
112        self.analyze_patterns();
113    }
114
115    /// Analyze recent accesses to detect patterns
116    fn analyze_patterns(&mut self) {
117        if self.access_history.len() < self.min_pattern_length {
118            return;
119        }
120
121        // Clear old patterns with low confidence
122        self.detected_patterns
123            .retain(|_, confidence| *confidence > 0.1);
124
125        // Check for sequential pattern
126        self.detect_sequential_pattern();
127
128        // Check for strided pattern
129        self.detect_strided_pattern();
130
131        // Check for cyclic pattern
132        self.detect_cyclic_pattern();
133    }
134
135    /// Detect sequential access pattern
136    fn detect_sequential_pattern(&mut self) {
137        let history: Vec<_> = self.access_history.iter().cloned().collect();
138        let mut sequential_count = 0;
139
140        for window in history.windows(2) {
141            if window[1] == window[0] + 1 {
142                sequential_count += 1;
143            }
144        }
145
146        let confidence = sequential_count as f64 / (history.len() - 1) as f64;
147        if confidence > 0.7 {
148            self.detected_patterns
149                .insert(AccessPattern::Sequential { stride: 1 }, confidence);
150        }
151    }
152
153    /// Detect strided access pattern
154    fn detect_strided_pattern(&mut self) {
155        let history: Vec<_> = self.access_history.iter().cloned().collect();
156        if history.len() < 3 {
157            return;
158        }
159
160        // Try different stride values
161        for stride in 2..=10 {
162            let mut matches = 0;
163            let start = history[0];
164
165            for (i, &index) in history.iter().enumerate() {
166                if index == start + i * stride {
167                    matches += 1;
168                }
169            }
170
171            let confidence = matches as f64 / history.len() as f64;
172            if confidence > 0.8 {
173                self.detected_patterns
174                    .insert(AccessPattern::Strided { start, stride }, confidence);
175            }
176        }
177    }
178
179    /// Detect cyclic access pattern
180    fn detect_cyclic_pattern(&mut self) {
181        let history: Vec<_> = self.access_history.iter().cloned().collect();
182
183        // Look for repeating subsequences
184        for pattern_length in self.min_pattern_length..=(history.len() / 2) {
185            if history.len() < pattern_length * 2 {
186                continue;
187            }
188
189            let pattern: Vec<_> = history[history.len() - pattern_length..].to_vec();
190            let mut repeats = 0;
191            let mut total_checks = 0;
192
193            let mut pos = history.len() - pattern_length * 2;
194            while pos < history.len() - pattern_length {
195                total_checks += 1;
196                let segment = &history[pos..pos + pattern_length];
197                if segment == pattern {
198                    repeats += 1;
199                }
200                pos += pattern_length;
201            }
202
203            if total_checks > 0 {
204                let confidence = repeats as f64 / total_checks as f64;
205                if confidence > 0.8 {
206                    self.detected_patterns
207                        .insert(AccessPattern::Cyclic { pattern }, confidence);
208                }
209            }
210        }
211    }
212
213    /// Predict next indices based on detected patterns
214    fn predict_next(&self, current_index: usize, count: usize) -> Vec<usize> {
215        let mut predictions = Vec::new();
216
217        // Use the pattern with highest confidence
218        if let Some((pattern, _)) = self.detected_patterns.iter().max_by(|a, b| {
219            a.1.partial_cmp(b.1)
220                .expect("partial_cmp should not return None for valid values")
221        }) {
222            match pattern {
223                AccessPattern::Sequential { stride } => {
224                    for i in 1..=count {
225                        predictions.push(current_index + i * stride);
226                    }
227                }
228                AccessPattern::Strided { start: _, stride } => {
229                    let next_in_sequence = current_index + stride;
230                    predictions.push(next_in_sequence);
231                    for i in 1..count {
232                        predictions.push(next_in_sequence + i * stride);
233                    }
234                }
235                AccessPattern::Cyclic { pattern } => {
236                    if let Some(current_pos) = pattern.iter().position(|&x| x == current_index) {
237                        for i in 1..=count {
238                            let next_pos = (current_pos + i) % pattern.len();
239                            predictions.push(pattern[next_pos]);
240                        }
241                    }
242                }
243                AccessPattern::Random => {
244                    // No predictions for random access
245                }
246            }
247        }
248
249        predictions
250    }
251
252    /// Get the most confident pattern
253    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
254        self.detected_patterns
255            .iter()
256            .max_by(|a, b| {
257                a.1.partial_cmp(b.1)
258                    .expect("partial_cmp should not return None for valid values")
259            })
260            .map(|(pattern, _)| pattern.clone())
261    }
262}
263
/// Predictive prefetcher that learns access patterns and preloads data.
///
/// Owns a background worker thread that drains `prefetch_queue` and fills
/// `prefetch_cache`; the worker is signalled via `shutdown_signal` and
/// joined when the prefetcher is dropped.
pub struct PredictivePrefetcher<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    /// Reference to the dataset (shared with the worker thread)
    dataset: Arc<D>,
    /// Pattern detector for learning access patterns
    pattern_detector: Arc<RwLock<PatternDetector>>,
    /// Prefetch cache: sample index -> preloaded entry
    prefetch_cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
    /// Configuration
    config: PrefetchConfig,
    /// Background prefetch worker; `None` after it is joined in `Drop`
    worker_handle: Option<JoinHandle<()>>,
    /// Shutdown signal polled by the worker loop
    shutdown_signal: Arc<AtomicBool>,
    /// Statistics
    stats: Arc<RwLock<AccessStats>>,
    /// Pending prefetch requests, consumed by the worker
    prefetch_queue: Arc<Mutex<VecDeque<usize>>>,
}
287
/// Configuration for predictive prefetching.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Upper bound on how many items are prefetched ahead of the consumer.
    pub max_prefetch_count: usize,
    /// Upper bound on the number of cached items.
    pub max_cache_size: usize,
    /// How many recent accesses the pattern detector remembers.
    pub pattern_history_size: usize,
    /// How long a cached entry may live before it is expired.
    pub cache_ttl: Duration,
    /// How long the background worker sleeps between queue polls.
    pub worker_sleep_duration: Duration,
    /// Enable bandwidth optimization.
    /// NOTE(review): not read anywhere in this module — confirm before
    /// relying on it.
    pub bandwidth_optimization: bool,
}

impl Default for PrefetchConfig {
    /// Default settings: prefetch up to 8 items, cache up to 128 entries,
    /// learn from the last 50 accesses, expire entries after five minutes,
    /// and poll the queue every 10 ms.
    fn default() -> Self {
        PrefetchConfig {
            bandwidth_optimization: true,
            worker_sleep_duration: Duration::from_millis(10),
            cache_ttl: Duration::from_secs(300), // 5 minutes
            pattern_history_size: 50,
            max_cache_size: 128,
            max_prefetch_count: 8,
        }
    }
}
317
impl<T, D> PredictivePrefetcher<T, D>
where
    T: Clone + Send + Sync + 'static,
    D: Dataset<T> + Send + Sync + 'static,
{
    /// Create a new predictive prefetcher with the default [`PrefetchConfig`].
    pub fn new(dataset: Arc<D>) -> Self {
        Self::with_config(dataset, PrefetchConfig::default())
    }

    /// Create a new predictive prefetcher with custom configuration.
    ///
    /// Spawns the background prefetch worker thread immediately; it is
    /// signalled to stop and joined in `Drop`.
    pub fn with_config(dataset: Arc<D>, config: PrefetchConfig) -> Self {
        let pattern_detector = Arc::new(RwLock::new(PatternDetector::new(
            config.pattern_history_size,
        )));
        let prefetch_cache = Arc::new(RwLock::new(HashMap::new()));
        let shutdown_signal = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(RwLock::new(AccessStats::default()));
        let prefetch_queue = Arc::new(Mutex::new(VecDeque::new()));

        // Start background prefetch worker
        let worker_handle = Self::start_prefetch_worker(
            dataset.clone(),
            prefetch_cache.clone(),
            prefetch_queue.clone(),
            shutdown_signal.clone(),
            config.clone(),
            stats.clone(),
        );

        Self {
            dataset,
            pattern_detector,
            prefetch_cache,
            config,
            worker_handle: Some(worker_handle),
            shutdown_signal,
            stats,
            prefetch_queue,
        }
    }

    /// Start the background prefetch worker.
    ///
    /// Each loop iteration: drain up to `max_prefetch_count` indices from
    /// the queue, load each via `dataset.get`, insert results into the
    /// cache (evicting the oldest entry when full), drop entries older than
    /// `cache_ttl`, then sleep `worker_sleep_duration`.
    fn start_prefetch_worker(
        dataset: Arc<D>,
        cache: Arc<RwLock<HashMap<usize, PrefetchEntry<T>>>>,
        queue: Arc<Mutex<VecDeque<usize>>>,
        shutdown: Arc<AtomicBool>,
        config: PrefetchConfig,
        stats: Arc<RwLock<AccessStats>>,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            while !shutdown.load(Ordering::Relaxed) {
                // Process prefetch requests. The queue lock is confined to
                // this block so it is released before any dataset I/O below.
                let indices_to_prefetch: Vec<usize> = {
                    let mut queue_guard = queue.lock().expect("lock should not be poisoned");
                    let mut indices = Vec::new();

                    // Take up to max_prefetch_count items
                    for _ in 0..config.max_prefetch_count {
                        if let Some(index) = queue_guard.pop_front() {
                            indices.push(index);
                        } else {
                            break;
                        }
                    }
                    indices
                };

                // Prefetch the data. A failed `get` is silently skipped: the
                // consumer will simply load that index directly on demand.
                for index in indices_to_prefetch {
                    if let Ok(data) = dataset.get(index) {
                        let mut cache_guard =
                            cache.write().expect("write lock should not be poisoned");

                        // Check cache size limit
                        if cache_guard.len() >= config.max_cache_size {
                            // Remove oldest entries. `timestamp` doubles as
                            // the LRU key because hits refresh it in `get`.
                            let oldest_key = cache_guard
                                .iter()
                                .min_by_key(|(_, entry)| entry.timestamp)
                                .map(|(k, _)| *k);

                            if let Some(key) = oldest_key {
                                cache_guard.remove(&key);
                            }
                        }

                        cache_guard.insert(
                            index,
                            PrefetchEntry {
                                data,
                                timestamp: Instant::now(),
                                access_count: 0,
                            },
                        );

                        // Update stats. Lock order is cache -> stats, which
                        // matches the hit/miss paths in `get`.
                        // NOTE(review): this adds the size of the tuple of
                        // tensor handles, not the tensor payloads, so
                        // `bandwidth_saved` underestimates real I/O savings.
                        let mut stats_guard =
                            stats.write().expect("write lock should not be poisoned");
                        stats_guard.bandwidth_saved +=
                            std::mem::size_of::<(Tensor<T>, Tensor<T>)>() as u64;
                    }
                }

                // Clean up expired entries
                {
                    let mut cache_guard = cache.write().expect("write lock should not be poisoned");
                    let now = Instant::now();
                    cache_guard
                        .retain(|_, entry| now.duration_since(entry.timestamp) < config.cache_ttl);
                }

                thread::sleep(config.worker_sleep_duration);
            }
        })
    }

    /// Get data with predictive prefetching.
    ///
    /// Records the access for pattern learning, serves from the prefetch
    /// cache when possible, and otherwise queues predicted future indices
    /// before falling through to a direct `dataset.get`.
    pub fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        // Update statistics
        {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.total_accesses += 1;
        }

        // Record access pattern
        {
            let mut detector = self
                .pattern_detector
                .write()
                .expect("write lock should not be poisoned");
            detector.record_access(index);
        }

        // Check cache first. A write lock is required even on the lookup
        // path because hits mutate `access_count` and refresh `timestamp`.
        {
            let mut cache = self
                .prefetch_cache
                .write()
                .expect("write lock should not be poisoned");
            if let Some(entry) = cache.get_mut(&index) {
                entry.access_count += 1;
                entry.timestamp = Instant::now(); // Update LRU

                let mut stats = self
                    .stats
                    .write()
                    .expect("write lock should not be poisoned");
                stats.prefetch_hits += 1;

                // NOTE(review): the hit path returns without queueing new
                // predictions, so a long run of hits lets the prefetch
                // pipeline drain until the next miss restarts it.
                return Ok(entry.data.clone());
            } else {
                let mut stats = self
                    .stats
                    .write()
                    .expect("write lock should not be poisoned");
                stats.prefetch_misses += 1;
            }
        }

        // Predict and queue future accesses
        self.predict_and_queue_prefetch(index);

        // Get data from dataset
        self.dataset.get(index)
    }

    /// Predict future accesses and queue them for prefetching.
    ///
    /// Counts a `pattern_hit` when the detector produced any predictions
    /// and a `pattern_miss` otherwise.
    fn predict_and_queue_prefetch(&self, current_index: usize) {
        let predictions = {
            let detector = self
                .pattern_detector
                .read()
                .expect("read lock should not be poisoned");
            detector.predict_next(current_index, self.config.max_prefetch_count)
        };

        if !predictions.is_empty() {
            let mut queue = self
                .prefetch_queue
                .lock()
                .expect("lock should not be poisoned");
            for predicted_index in predictions {
                // Only queue if not already cached
                // NOTE(review): re-acquires the cache read lock on every
                // iteration while holding the queue mutex; hoisting the
                // guard out of the loop would cut lock traffic.
                let cache = self
                    .prefetch_cache
                    .read()
                    .expect("read lock should not be poisoned");
                if !cache.contains_key(&predicted_index) {
                    queue.push_back(predicted_index);
                }
            }

            // Update pattern prediction stats
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_hits += 1;
        } else {
            let mut stats = self
                .stats
                .write()
                .expect("write lock should not be poisoned");
            stats.pattern_misses += 1;
        }
    }

    /// Get a snapshot (clone) of the current statistics.
    pub fn stats(&self) -> AccessStats {
        self.stats
            .read()
            .expect("read lock should not be poisoned")
            .clone()
    }

    /// Get the dominant access pattern, if one has been detected.
    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
        self.pattern_detector
            .read()
            .expect("read lock should not be poisoned")
            .dominant_pattern()
    }

    /// Clear the prefetch cache. Statistics and learned patterns are kept.
    pub fn clear_cache(&self) {
        let mut cache = self
            .prefetch_cache
            .write()
            .expect("write lock should not be poisoned");
        cache.clear();
    }

    /// Get cache occupancy as `(current_len, configured_max)`.
    pub fn cache_info(&self) -> (usize, usize) {
        let cache = self
            .prefetch_cache
            .read()
            .expect("read lock should not be poisoned");
        (cache.len(), self.config.max_cache_size)
    }
}
564
565impl<T, D> Drop for PredictivePrefetcher<T, D>
566where
567    T: Clone + Send + Sync + 'static,
568    D: Dataset<T> + Send + Sync + 'static,
569{
570    fn drop(&mut self) {
571        // Signal shutdown and wait for worker to finish
572        self.shutdown_signal.store(true, Ordering::Relaxed);
573
574        if let Some(handle) = self.worker_handle.take() {
575            let _ = handle.join();
576        }
577    }
578}
579
/// Dataset wrapper that provides predictive prefetching.
///
/// Thin adapter that owns a [`PredictivePrefetcher`] and routes all reads
/// through it.
pub struct PredictivePrefetchDataset<T, D: Dataset<T>>
where
    T: Clone + Send + Sync + 'static,
    D: Send + Sync + 'static,
{
    // All `get` calls are delegated to this prefetcher.
    prefetcher: PredictivePrefetcher<T, D>,
}
588
589impl<T, D> PredictivePrefetchDataset<T, D>
590where
591    T: Clone + Send + Sync + 'static,
592    D: Dataset<T> + Send + Sync + 'static,
593{
594    /// Create a new predictive prefetch dataset
595    pub fn new(dataset: D) -> Self {
596        Self {
597            prefetcher: PredictivePrefetcher::new(Arc::new(dataset)),
598        }
599    }
600
601    /// Create with custom configuration
602    pub fn with_config(dataset: D, config: PrefetchConfig) -> Self {
603        Self {
604            prefetcher: PredictivePrefetcher::with_config(Arc::new(dataset), config),
605        }
606    }
607
608    /// Get access statistics
609    pub fn stats(&self) -> AccessStats {
610        self.prefetcher.stats()
611    }
612
613    /// Get the dominant access pattern
614    pub fn dominant_pattern(&self) -> Option<AccessPattern> {
615        self.prefetcher.dominant_pattern()
616    }
617}
618
619impl<T, D> Dataset<T> for PredictivePrefetchDataset<T, D>
620where
621    T: Clone + Send + Sync + 'static,
622    D: Dataset<T> + Send + Sync + 'static,
623{
624    fn len(&self) -> usize {
625        self.prefetcher.dataset.len()
626    }
627
628    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
629        self.prefetcher.get(index)
630    }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    /// Sequential accesses 0..5 must make `Sequential { stride: 1 }` the
    /// dominant detected pattern.
    #[test]
    fn test_pattern_detector_sequential() {
        let mut detector = PatternDetector::new(10);

        // Create sequential access pattern
        for i in 0..5 {
            detector.record_access(i);
        }

        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Sequential { stride: 1 })
        ));
    }

    /// Accesses 0, 2, 4, 6, 8 must be recognized as `Strided` with
    /// start 0 and stride 2.
    #[test]
    fn test_pattern_detector_strided() {
        let mut detector = PatternDetector::new(10);

        // Create strided access pattern (0, 2, 4, 6, 8)
        for i in 0..5 {
            detector.record_access(i * 2);
        }

        let dominant = detector.dominant_pattern();
        assert!(matches!(
            dominant,
            Some(AccessPattern::Strided {
                start: 0,
                stride: 2
            })
        ));
    }

    /// End-to-end smoke test: sequential gets through the prefetcher are
    /// counted in `total_accesses`; the background worker gets a short
    /// window to run before stats are read.
    #[test]
    fn test_predictive_prefetcher() {
        // Create test dataset
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset = Arc::new(TensorDataset::new(features, labels));

        // Small limits and a fast worker poll keep the test quick.
        let config = PrefetchConfig {
            max_prefetch_count: 2,
            max_cache_size: 10,
            pattern_history_size: 10,
            cache_ttl: Duration::from_secs(60),
            worker_sleep_duration: Duration::from_millis(1),
            bandwidth_optimization: true,
        };

        let prefetcher = PredictivePrefetcher::with_config(dataset, config);

        // Access in sequential pattern
        let _ = prefetcher.get(0).expect("index should be in bounds");
        let _ = prefetcher.get(1).expect("index should be in bounds");
        let _ = prefetcher.get(2).expect("index should be in bounds");

        // Give prefetcher time to work
        thread::sleep(Duration::from_millis(50));

        let stats = prefetcher.stats();
        assert!(stats.total_accesses >= 3);
    }

    /// The `Dataset` wrapper must delegate `len`/`get` and count accesses.
    #[test]
    fn test_predictive_prefetch_dataset() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let base_dataset = TensorDataset::new(features, labels);

        let dataset = PredictivePrefetchDataset::new(base_dataset);

        assert_eq!(dataset.len(), 2);

        // One sample: features row of shape [2], scalar label.
        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);

        let stats = dataset.stats();
        assert_eq!(stats.total_accesses, 1);
    }
}