// oxigdal_cache_advanced/predictive/mod.rs
//! Predictive prefetching with ML
//!
//! Learns access patterns and predicts future cache accesses:
//! - Temporal pattern learning
//! - Spatial pattern learning
//! - Markov chain prediction
//! - Neural network predictor
//! - Confidence-based prefetch decisions
//! - Advanced ML models (Transformer, LSTM, Hybrid)
10
pub mod advanced;

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use scirs2_core::ndarray::{Array1, Array2};
use tokio::sync::RwLock;

use crate::multi_tier::CacheKey;
18
19/// Generate normal distributed random number using Box-Muller transform
20fn rand_normal(mean: f64, std_dev: f64) -> f64 {
21    let u1 = fastrand::f64();
22    let u2 = fastrand::f64();
23    // Avoid log(0) by ensuring u1 > 0
24    let u1 = if u1 < 1e-10 { 1e-10 } else { u1 };
25    let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
26    mean + z0 * std_dev
27}
28
29/// Access pattern record
30#[derive(Debug, Clone)]
31pub struct AccessRecord {
32    /// Key accessed
33    pub key: CacheKey,
34    /// Timestamp
35    pub timestamp: chrono::DateTime<chrono::Utc>,
36    /// Access type (read/write)
37    pub access_type: AccessType,
38}
39
/// Kind of cache access.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessType {
    /// A read of an entry.
    Read,
    /// A write/update of an entry.
    Write,
}
48
49/// Prediction with confidence
50#[derive(Debug, Clone)]
51pub struct Prediction {
52    /// Predicted key
53    pub key: CacheKey,
54    /// Confidence score (0.0 - 1.0)
55    pub confidence: f64,
56    /// Predicted access time
57    pub predicted_time: Option<chrono::DateTime<chrono::Utc>>,
58}
59
60impl Prediction {
61    /// Check if prediction is confident enough
62    pub fn is_confident(&self, threshold: f64) -> bool {
63        self.confidence >= threshold
64    }
65}
66
67/// Markov chain predictor
68/// Predicts next access based on current state
69pub struct MarkovPredictor {
70    /// Transition matrix (key -> next_key -> probability)
71    transitions: HashMap<CacheKey, HashMap<CacheKey, f64>>,
72    /// Transition counts for learning
73    transition_counts: HashMap<CacheKey, HashMap<CacheKey, u64>>,
74    /// Current state
75    current_key: Option<CacheKey>,
76    /// Order of Markov chain (n-gram size)
77    order: usize,
78    /// Recent access history
79    history: VecDeque<CacheKey>,
80}
81
82impl MarkovPredictor {
83    /// Create new Markov predictor
84    pub fn new(order: usize) -> Self {
85        Self {
86            transitions: HashMap::new(),
87            transition_counts: HashMap::new(),
88            current_key: None,
89            order,
90            history: VecDeque::with_capacity(order),
91        }
92    }
93
94    /// Record an access
95    pub fn record_access(&mut self, key: CacheKey) {
96        if let Some(prev_key) = self.current_key.clone() {
97            // Update transition counts
98            let next_counts = self.transition_counts.entry(prev_key.clone()).or_default();
99
100            *next_counts.entry(key.clone()).or_insert(0) += 1;
101
102            // Rebuild probabilities
103            self.update_probabilities(&prev_key);
104        }
105
106        // Update history
107        if self.history.len() >= self.order {
108            self.history.pop_front();
109        }
110        self.history.push_back(key.clone());
111
112        self.current_key = Some(key);
113    }
114
115    /// Update transition probabilities for a key
116    fn update_probabilities(&mut self, from_key: &CacheKey) {
117        if let Some(counts) = self.transition_counts.get(from_key) {
118            let total: u64 = counts.values().sum();
119
120            if total > 0 {
121                let probabilities: HashMap<CacheKey, f64> = counts
122                    .iter()
123                    .map(|(k, count)| (k.clone(), *count as f64 / total as f64))
124                    .collect();
125
126                self.transitions.insert(from_key.clone(), probabilities);
127            }
128        }
129    }
130
131    /// Predict next keys
132    pub fn predict(&self, top_n: usize) -> Vec<Prediction> {
133        if let Some(current) = &self.current_key {
134            if let Some(transitions) = self.transitions.get(current) {
135                let mut predictions: Vec<_> = transitions
136                    .iter()
137                    .map(|(key, prob)| Prediction {
138                        key: key.clone(),
139                        confidence: *prob,
140                        predicted_time: None,
141                    })
142                    .collect();
143
144                predictions.sort_by(|a, b| {
145                    b.confidence
146                        .partial_cmp(&a.confidence)
147                        .unwrap_or(std::cmp::Ordering::Equal)
148                });
149
150                predictions.truncate(top_n);
151                return predictions;
152            }
153        }
154
155        Vec::new()
156    }
157
158    /// Get number of states in the model
159    pub fn state_count(&self) -> usize {
160        self.transitions.len()
161    }
162
163    /// Clear the model
164    pub fn clear(&mut self) {
165        self.transitions.clear();
166        self.transition_counts.clear();
167        self.current_key = None;
168        self.history.clear();
169    }
170}
171
172/// Temporal pattern detector
173/// Detects periodic access patterns
174pub struct TemporalPatternDetector {
175    /// Access history with timestamps
176    access_history: VecDeque<(CacheKey, chrono::DateTime<chrono::Utc>)>,
177    /// Maximum history size
178    max_history: usize,
179    /// Detected patterns (key -> period in seconds)
180    patterns: HashMap<CacheKey, Vec<i64>>,
181}
182
183impl TemporalPatternDetector {
184    /// Create new temporal pattern detector
185    pub fn new(max_history: usize) -> Self {
186        Self {
187            access_history: VecDeque::with_capacity(max_history),
188            max_history,
189            patterns: HashMap::new(),
190        }
191    }
192
193    /// Record an access
194    pub fn record_access(&mut self, key: CacheKey, timestamp: chrono::DateTime<chrono::Utc>) {
195        if self.access_history.len() >= self.max_history {
196            self.access_history.pop_front();
197        }
198
199        self.access_history.push_back((key.clone(), timestamp));
200
201        // Detect patterns for this key
202        self.detect_pattern(&key);
203    }
204
205    /// Detect access pattern for a key
206    fn detect_pattern(&mut self, key: &CacheKey) {
207        let accesses: Vec<_> = self
208            .access_history
209            .iter()
210            .filter(|(k, _)| k == key)
211            .map(|(_, ts)| *ts)
212            .collect();
213
214        if accesses.len() < 3 {
215            return;
216        }
217
218        // Calculate intervals between accesses
219        let mut intervals = Vec::new();
220        for i in 1..accesses.len() {
221            let interval = (accesses[i] - accesses[i - 1]).num_seconds();
222            intervals.push(interval);
223        }
224
225        // Store intervals as pattern
226        self.patterns.insert(key.clone(), intervals);
227    }
228
229    /// Predict next access time for a key
230    pub fn predict_next_access(&self, key: &CacheKey) -> Option<chrono::DateTime<chrono::Utc>> {
231        if let Some(intervals) = self.patterns.get(key) {
232            if intervals.is_empty() {
233                return None;
234            }
235
236            // Use median interval as prediction
237            let mut sorted_intervals = intervals.clone();
238            sorted_intervals.sort();
239            let median_interval = sorted_intervals[sorted_intervals.len() / 2];
240
241            // Find last access time
242            let last_access = self
243                .access_history
244                .iter()
245                .rev()
246                .find(|(k, _)| k == key)
247                .map(|(_, ts)| *ts);
248
249            if let Some(last) = last_access {
250                return Some(last + chrono::Duration::seconds(median_interval));
251            }
252        }
253
254        None
255    }
256
257    /// Get prediction with confidence
258    pub fn predict(&self, key: &CacheKey) -> Option<Prediction> {
259        if let Some(next_time) = self.predict_next_access(key) {
260            let intervals = self.patterns.get(key)?;
261
262            // Calculate confidence based on pattern stability
263            let confidence = if intervals.len() < 2 {
264                0.5
265            } else {
266                let mean: f64 =
267                    intervals.iter().map(|&x| x as f64).sum::<f64>() / intervals.len() as f64;
268                let variance: f64 = intervals
269                    .iter()
270                    .map(|&x| {
271                        let diff = x as f64 - mean;
272                        diff * diff
273                    })
274                    .sum::<f64>()
275                    / intervals.len() as f64;
276
277                let std_dev = variance.sqrt();
278                let cv = if mean > 0.0 { std_dev / mean } else { 1.0 };
279
280                // Lower coefficient of variation = higher confidence
281                (1.0 / (1.0 + cv)).clamp(0.0, 1.0)
282            };
283
284            Some(Prediction {
285                key: key.clone(),
286                confidence,
287                predicted_time: Some(next_time),
288            })
289        } else {
290            None
291        }
292    }
293
294    /// Clear patterns
295    pub fn clear(&mut self) {
296        self.access_history.clear();
297        self.patterns.clear();
298    }
299}
300
301/// Spatial pattern detector
302/// Detects related keys that are often accessed together
303pub struct SpatialPatternDetector {
304    /// Co-occurrence matrix (key1 -> key2 -> count)
305    co_occurrences: HashMap<CacheKey, HashMap<CacheKey, u64>>,
306    /// Recent access window
307    window: VecDeque<CacheKey>,
308    /// Window size
309    window_size: usize,
310}
311
312impl SpatialPatternDetector {
313    /// Create new spatial pattern detector
314    pub fn new(window_size: usize) -> Self {
315        Self {
316            co_occurrences: HashMap::new(),
317            window: VecDeque::with_capacity(window_size),
318            window_size,
319        }
320    }
321
322    /// Record an access
323    pub fn record_access(&mut self, key: CacheKey) {
324        // Update co-occurrences with all keys in current window (bidirectional)
325        for other_key in &self.window {
326            // Record key -> other_key
327            let co_occurs = self.co_occurrences.entry(key.clone()).or_default();
328            *co_occurs.entry(other_key.clone()).or_insert(0) += 1;
329
330            // Record other_key -> key (bidirectional)
331            let co_occurs_reverse = self.co_occurrences.entry(other_key.clone()).or_default();
332            *co_occurs_reverse.entry(key.clone()).or_insert(0) += 1;
333        }
334
335        // Update window
336        if self.window.len() >= self.window_size {
337            self.window.pop_front();
338        }
339        self.window.push_back(key);
340    }
341
342    /// Get related keys
343    pub fn get_related_keys(&self, key: &CacheKey, top_n: usize) -> Vec<Prediction> {
344        if let Some(co_occurs) = self.co_occurrences.get(key) {
345            let total: u64 = co_occurs.values().sum();
346
347            if total == 0 {
348                return Vec::new();
349            }
350
351            let mut predictions: Vec<_> = co_occurs
352                .iter()
353                .map(|(k, count)| Prediction {
354                    key: k.clone(),
355                    confidence: *count as f64 / total as f64,
356                    predicted_time: None,
357                })
358                .collect();
359
360            predictions.sort_by(|a, b| {
361                b.confidence
362                    .partial_cmp(&a.confidence)
363                    .unwrap_or(std::cmp::Ordering::Equal)
364            });
365
366            predictions.truncate(top_n);
367            predictions
368        } else {
369            Vec::new()
370        }
371    }
372
373    /// Clear patterns
374    pub fn clear(&mut self) {
375        self.co_occurrences.clear();
376        self.window.clear();
377    }
378}
379
380/// Simple neural network predictor
381/// Uses a feed-forward network to predict access patterns
382pub struct NeuralPredictor {
383    /// Input size (vocabulary size)
384    vocab_size: usize,
385    /// Hidden layer size
386    hidden_size: usize,
387    /// Weights for input to hidden layer
388    w1: Option<Array2<f64>>,
389    /// Weights for hidden to output layer
390    w2: Option<Array2<f64>>,
391    /// Bias for hidden layer
392    b1: Option<Array1<f64>>,
393    /// Bias for output layer
394    b2: Option<Array1<f64>>,
395    /// Key to index mapping
396    key_to_idx: HashMap<CacheKey, usize>,
397    /// Index to key mapping
398    idx_to_key: Vec<CacheKey>,
399    /// Learning rate
400    #[allow(dead_code)]
401    learning_rate: f64,
402    /// Training enabled
403    #[allow(dead_code)]
404    training_enabled: bool,
405}
406
407impl NeuralPredictor {
408    /// Create new neural predictor
409    pub fn new(hidden_size: usize) -> Self {
410        Self {
411            vocab_size: 0,
412            hidden_size,
413            w1: None,
414            w2: None,
415            b1: None,
416            b2: None,
417            key_to_idx: HashMap::new(),
418            idx_to_key: Vec::new(),
419            learning_rate: 0.01,
420            training_enabled: true,
421        }
422    }
423
424    /// Add key to vocabulary
425    fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
426        if let Some(&idx) = self.key_to_idx.get(key) {
427            idx
428        } else {
429            let idx = self.vocab_size;
430            self.key_to_idx.insert(key.clone(), idx);
431            self.idx_to_key.push(key.clone());
432            self.vocab_size += 1;
433
434            // Reinitialize weights if needed
435            if self.vocab_size > 0 {
436                self.initialize_weights();
437            }
438
439            idx
440        }
441    }
442
443    /// Initialize weights
444    fn initialize_weights(&mut self) {
445        // Seed fastrand for reproducibility
446        fastrand::seed(42);
447
448        // Xavier initialization
449        let scale_w1 = (2.0 / (self.vocab_size + self.hidden_size) as f64).sqrt();
450        let scale_w2 = (2.0 / (self.hidden_size + self.vocab_size) as f64).sqrt();
451
452        let w1_data: Vec<f64> = (0..self.vocab_size * self.hidden_size)
453            .map(|_| rand_normal(0.0, scale_w1))
454            .collect();
455
456        let w2_data: Vec<f64> = (0..self.hidden_size * self.vocab_size)
457            .map(|_| rand_normal(0.0, scale_w2))
458            .collect();
459
460        self.w1 = Some(
461            Array2::from_shape_vec((self.vocab_size, self.hidden_size), w1_data)
462                .unwrap_or_else(|_| Array2::zeros((self.vocab_size, self.hidden_size))),
463        );
464
465        self.w2 = Some(
466            Array2::from_shape_vec((self.hidden_size, self.vocab_size), w2_data)
467                .unwrap_or_else(|_| Array2::zeros((self.hidden_size, self.vocab_size))),
468        );
469
470        self.b1 = Some(Array1::zeros(self.hidden_size));
471        self.b2 = Some(Array1::zeros(self.vocab_size));
472    }
473
474    /// Forward pass
475    fn forward(&self, input_idx: usize) -> Option<Array1<f64>> {
476        if input_idx >= self.vocab_size {
477            return None;
478        }
479
480        let w1 = self.w1.as_ref()?;
481        let w2 = self.w2.as_ref()?;
482        let b1 = self.b1.as_ref()?;
483        let b2 = self.b2.as_ref()?;
484
485        // One-hot encoding
486        let mut input = Array1::zeros(self.vocab_size);
487        input[input_idx] = 1.0;
488
489        // Hidden layer with ReLU
490        let hidden = w1.t().dot(&input) + b1;
491        let hidden_activated = hidden.mapv(|x| x.max(0.0));
492
493        // Output layer with softmax
494        let output = w2.t().dot(&hidden_activated) + b2;
495        let output_exp = output.mapv(|x| x.exp());
496        let sum_exp: f64 = output_exp.sum();
497
498        Some(output_exp / sum_exp)
499    }
500
501    /// Record access (for training)
502    pub fn record_access(&mut self, key: CacheKey) {
503        let _idx = self.add_to_vocab(&key);
504        // Training would happen here if we implement backpropagation
505    }
506
507    /// Predict next keys
508    pub fn predict(&mut self, current_key: &CacheKey, top_n: usize) -> Vec<Prediction> {
509        if let Some(&idx) = self.key_to_idx.get(current_key) {
510            if let Some(output) = self.forward(idx) {
511                let mut predictions: Vec<_> = output
512                    .iter()
513                    .enumerate()
514                    .map(|(i, &prob)| Prediction {
515                        key: self.idx_to_key.get(i).cloned().unwrap_or_default(),
516                        confidence: prob,
517                        predicted_time: None,
518                    })
519                    .collect();
520
521                predictions.sort_by(|a, b| {
522                    b.confidence
523                        .partial_cmp(&a.confidence)
524                        .unwrap_or(std::cmp::Ordering::Equal)
525                });
526
527                predictions.truncate(top_n);
528                return predictions;
529            }
530        }
531
532        Vec::new()
533    }
534
535    /// Clear model
536    pub fn clear(&mut self) {
537        self.w1 = None;
538        self.w2 = None;
539        self.b1 = None;
540        self.b2 = None;
541        self.key_to_idx.clear();
542        self.idx_to_key.clear();
543        self.vocab_size = 0;
544    }
545}
546
547/// Ensemble predictor combining multiple prediction methods
548pub struct EnsemblePredictor {
549    /// Markov predictor
550    markov: Arc<RwLock<MarkovPredictor>>,
551    /// Temporal predictor
552    temporal: Arc<RwLock<TemporalPatternDetector>>,
553    /// Spatial predictor
554    spatial: Arc<RwLock<SpatialPatternDetector>>,
555    /// Neural predictor
556    neural: Arc<RwLock<NeuralPredictor>>,
557    /// Confidence threshold for prefetching
558    confidence_threshold: f64,
559}
560
561impl EnsemblePredictor {
562    /// Create new ensemble predictor
563    pub fn new() -> Self {
564        Self {
565            markov: Arc::new(RwLock::new(MarkovPredictor::new(2))),
566            temporal: Arc::new(RwLock::new(TemporalPatternDetector::new(1000))),
567            spatial: Arc::new(RwLock::new(SpatialPatternDetector::new(10))),
568            neural: Arc::new(RwLock::new(NeuralPredictor::new(64))),
569            confidence_threshold: 0.5,
570        }
571    }
572
573    /// Set confidence threshold
574    pub fn with_threshold(mut self, threshold: f64) -> Self {
575        self.confidence_threshold = threshold;
576        self
577    }
578
579    /// Record an access
580    pub async fn record_access(&self, record: AccessRecord) {
581        let mut markov = self.markov.write().await;
582        markov.record_access(record.key.clone());
583        drop(markov);
584
585        let mut temporal = self.temporal.write().await;
586        temporal.record_access(record.key.clone(), record.timestamp);
587        drop(temporal);
588
589        let mut spatial = self.spatial.write().await;
590        spatial.record_access(record.key.clone());
591        drop(spatial);
592
593        let mut neural = self.neural.write().await;
594        neural.record_access(record.key);
595    }
596
597    /// Predict next keys to prefetch
598    pub async fn predict(&self, current_key: &CacheKey, top_n: usize) -> Vec<Prediction> {
599        let mut all_predictions = Vec::new();
600
601        // Get predictions from Markov
602        let markov = self.markov.read().await;
603        let markov_predictions = markov.predict(top_n);
604        all_predictions.extend(markov_predictions);
605        drop(markov);
606
607        // Get predictions from temporal
608        let temporal = self.temporal.read().await;
609        if let Some(temporal_pred) = temporal.predict(current_key) {
610            all_predictions.push(temporal_pred);
611        }
612        drop(temporal);
613
614        // Get predictions from spatial
615        let spatial = self.spatial.read().await;
616        let spatial_predictions = spatial.get_related_keys(current_key, top_n);
617        all_predictions.extend(spatial_predictions);
618        drop(spatial);
619
620        // Aggregate predictions by key
621        let mut aggregated: HashMap<CacheKey, Vec<f64>> = HashMap::new();
622        for pred in all_predictions {
623            aggregated
624                .entry(pred.key.clone())
625                .or_default()
626                .push(pred.confidence);
627        }
628
629        // Average confidences and filter by threshold
630        let mut final_predictions: Vec<_> = aggregated
631            .into_iter()
632            .map(|(key, confidences)| {
633                let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
634                Prediction {
635                    key,
636                    confidence: avg_confidence,
637                    predicted_time: None,
638                }
639            })
640            .filter(|p| p.confidence >= self.confidence_threshold)
641            .collect();
642
643        final_predictions.sort_by(|a, b| {
644            b.confidence
645                .partial_cmp(&a.confidence)
646                .unwrap_or(std::cmp::Ordering::Equal)
647        });
648
649        final_predictions.truncate(top_n);
650        final_predictions
651    }
652
653    /// Clear all predictors
654    pub async fn clear(&self) {
655        self.markov.write().await.clear();
656        self.temporal.write().await.clear();
657        self.spatial.write().await.clear();
658        self.neural.write().await.clear();
659    }
660}
661
662impl Default for EnsemblePredictor {
663    fn default() -> Self {
664        Self::new()
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_markov_predictor() {
        let mut markov = MarkovPredictor::new(1);

        // Alternate A/B so the A->B and B->A transitions have data.
        for key in ["A", "B", "A", "B"] {
            markov.record_access(key.to_string());
        }

        assert!(!markov.predict(3).is_empty());
    }

    #[test]
    fn test_temporal_pattern_detector() {
        let mut detector = TemporalPatternDetector::new(100);

        // Three evenly spaced accesses establish a 10-second pattern.
        let start = chrono::Utc::now();
        for step in 0..3i64 {
            detector.record_access(
                "A".to_string(),
                start + chrono::Duration::seconds(10 * step),
            );
        }

        assert!(detector.predict(&"A".to_string()).is_some());
    }

    #[test]
    fn test_spatial_pattern_detector() {
        let mut detector = SpatialPatternDetector::new(5);

        for key in ["A", "B", "C", "A", "B"] {
            detector.record_access(key.to_string());
        }

        assert!(!detector.get_related_keys(&"A".to_string(), 3).is_empty());
    }

    #[tokio::test]
    async fn test_ensemble_predictor() {
        let ensemble = EnsemblePredictor::new();
        let start = chrono::Utc::now();

        ensemble
            .record_access(AccessRecord {
                key: "A".to_string(),
                timestamp: start,
                access_type: AccessType::Read,
            })
            .await;
        ensemble
            .record_access(AccessRecord {
                key: "B".to_string(),
                timestamp: start + chrono::Duration::seconds(1),
                access_type: AccessType::Read,
            })
            .await;

        // May or may not produce predictions depending on pattern strength,
        // but never more than requested.
        let predictions = ensemble.predict(&"A".to_string(), 5).await;
        assert!(predictions.len() <= 5);
    }
}