// oxigdal_cache_advanced/predictive/advanced.rs
1//! Advanced ML-based prediction models
2//!
3//! Enhanced prediction models with:
4//! - Transformer-based attention mechanism
5//! - LSTM for temporal sequences
6//! - Hybrid predictor combining multiple models
7//! - Online learning and adaptation
8//! - Model selection based on data characteristics
9//! - Prediction confidence calibration
10
11use crate::error::{CacheError, Result};
12use crate::multi_tier::CacheKey;
13use scirs2_core::ndarray::{Array1, Array2, Axis};
14use std::collections::{HashMap, VecDeque};
15
16/// Generate normal distributed random number using Box-Muller transform
17fn rand_normal(mean: f64, std_dev: f64) -> f64 {
18    let u1 = fastrand::f64();
19    let u2 = fastrand::f64();
20    // Avoid log(0) by ensuring u1 > 0
21    let u1 = if u1 < 1e-10 { 1e-10 } else { u1 };
22    let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
23    mean + z0 * std_dev
24}
25
/// Transformer-based predictor using attention mechanism
///
/// Interns cache keys into a vocabulary, keeps a bounded window of recent
/// accesses, and scores likely next keys via a self-attention pass over
/// one-hot-style embeddings of that window.
pub struct TransformerPredictor {
    /// Embedding dimension; also the side length of every projection matrix
    embedding_dim: usize,
    /// Number of attention heads (stored but unused; attention is single-head)
    #[allow(dead_code)]
    num_heads: usize,
    /// Maximum number of recent accesses kept in `sequence`
    seq_length: usize,
    /// Query projection weights (lazily initialized on first recorded key)
    w_query: Option<Array2<f64>>,
    /// Key projection weights (lazily initialized on first recorded key)
    w_key: Option<Array2<f64>>,
    /// Value projection weights (lazily initialized on first recorded key)
    w_value: Option<Array2<f64>>,
    /// Output projection weights (lazily initialized on first recorded key)
    w_output: Option<Array2<f64>>,
    /// Maps each known cache key to its vocabulary index
    key_to_idx: HashMap<CacheKey, usize>,
    /// Reverse lookup: vocabulary index -> cache key
    idx_to_key: Vec<CacheKey>,
    /// Recent access history as vocabulary indices (bounded by `seq_length`)
    sequence: VecDeque<usize>,
    /// Number of distinct keys seen so far
    vocab_size: usize,
}
52
53impl TransformerPredictor {
54    /// Create new transformer predictor
55    pub fn new(embedding_dim: usize, num_heads: usize, seq_length: usize) -> Self {
56        Self {
57            embedding_dim,
58            num_heads,
59            seq_length,
60            w_query: None,
61            w_key: None,
62            w_value: None,
63            w_output: None,
64            key_to_idx: HashMap::new(),
65            idx_to_key: Vec::new(),
66            sequence: VecDeque::with_capacity(seq_length),
67            vocab_size: 0,
68        }
69    }
70
71    /// Initialize weights
72    fn initialize_weights(&mut self) {
73        // Seed fastrand for reproducibility
74        fastrand::seed(42);
75        let scale = (2.0 / self.embedding_dim as f64).sqrt();
76
77        let q_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
78            .map(|_| rand_normal(0.0, scale))
79            .collect();
80
81        let k_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
82            .map(|_| rand_normal(0.0, scale))
83            .collect();
84
85        let v_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
86            .map(|_| rand_normal(0.0, scale))
87            .collect();
88
89        let o_data: Vec<f64> = (0..self.embedding_dim * self.embedding_dim)
90            .map(|_| rand_normal(0.0, scale))
91            .collect();
92
93        self.w_query = Some(
94            Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), q_data)
95                .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
96        );
97
98        self.w_key = Some(
99            Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), k_data)
100                .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
101        );
102
103        self.w_value = Some(
104            Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), v_data)
105                .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
106        );
107
108        self.w_output = Some(
109            Array2::from_shape_vec((self.embedding_dim, self.embedding_dim), o_data)
110                .unwrap_or_else(|_| Array2::zeros((self.embedding_dim, self.embedding_dim))),
111        );
112    }
113
114    /// Add key to vocabulary
115    fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
116        if let Some(&idx) = self.key_to_idx.get(key) {
117            idx
118        } else {
119            let idx = self.vocab_size;
120            self.key_to_idx.insert(key.clone(), idx);
121            self.idx_to_key.push(key.clone());
122            self.vocab_size += 1;
123
124            if self.w_query.is_none() {
125                self.initialize_weights();
126            }
127
128            idx
129        }
130    }
131
132    /// Compute multi-head attention
133    fn attention(
134        &self,
135        query: &Array2<f64>,
136        key: &Array2<f64>,
137        value: &Array2<f64>,
138    ) -> Result<Array2<f64>> {
139        let w_q = self
140            .w_query
141            .as_ref()
142            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
143        let w_k = self
144            .w_key
145            .as_ref()
146            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
147        let w_v = self
148            .w_value
149            .as_ref()
150            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
151        let w_o = self
152            .w_output
153            .as_ref()
154            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
155
156        // Project to Q, K, V
157        let q_proj = query.dot(w_q);
158        let k_proj = key.dot(w_k);
159        let v_proj = value.dot(w_v);
160
161        // Compute attention scores
162        let scores = q_proj.dot(&k_proj.t()) / (self.embedding_dim as f64).sqrt();
163
164        // Apply softmax
165        let scores_exp = scores.mapv(|x| x.exp());
166        let scores_sum = scores_exp.sum_axis(Axis(1));
167        let attention_weights = &scores_exp / &scores_sum.insert_axis(Axis(1));
168
169        // Apply attention to values
170        let attended = attention_weights.dot(&v_proj);
171
172        // Output projection
173        Ok(attended.dot(w_o))
174    }
175
176    /// Record access
177    pub fn record_access(&mut self, key: CacheKey) {
178        let idx = self.add_to_vocab(&key);
179
180        if self.sequence.len() >= self.seq_length {
181            self.sequence.pop_front();
182        }
183        self.sequence.push_back(idx);
184    }
185
186    /// Predict next keys
187    pub fn predict(&self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
188        if self.sequence.is_empty() {
189            return Ok(Vec::new());
190        }
191
192        // Create embedding matrix for sequence
193        let mut embeddings = Array2::zeros((self.sequence.len(), self.embedding_dim));
194        for (i, &idx) in self.sequence.iter().enumerate() {
195            // Simple one-hot-like embedding
196            if idx < self.embedding_dim {
197                embeddings[[i, idx]] = 1.0;
198            }
199        }
200
201        // Compute self-attention
202        let output = self.attention(&embeddings, &embeddings, &embeddings)?;
203
204        // Use last output for prediction
205        let last_output = output.row(output.nrows() - 1);
206
207        // Compute scores for all vocabulary items
208        let mut scores: Vec<(CacheKey, f64)> = self
209            .idx_to_key
210            .iter()
211            .enumerate()
212            .map(|(idx, key)| {
213                let score = if idx < last_output.len() {
214                    last_output[idx]
215                } else {
216                    0.0
217                };
218                (key.clone(), score)
219            })
220            .collect();
221
222        // Sort by score
223        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
224        scores.truncate(top_n);
225
226        // Normalize to probabilities
227        let sum: f64 = scores.iter().map(|(_, s)| s.exp()).sum();
228        if sum > 0.0 {
229            for (_, score) in &mut scores {
230                *score = score.exp() / sum;
231            }
232        }
233
234        Ok(scores)
235    }
236
237    /// Clear predictor
238    pub fn clear(&mut self) {
239        self.sequence.clear();
240        self.key_to_idx.clear();
241        self.idx_to_key.clear();
242        self.vocab_size = 0;
243        self.w_query = None;
244        self.w_key = None;
245        self.w_value = None;
246        self.w_output = None;
247    }
248}
249
/// LSTM-based predictor for temporal sequences
///
/// Runs each recorded access through a single LSTM cell, carrying hidden
/// and cell state forward, and scores vocabulary entries from the current
/// hidden state at prediction time.
pub struct LSTMPredictor {
    /// Size of the hidden and cell state vectors
    hidden_size: usize,
    /// Number of distinct keys seen; one-hot input width for the cell
    vocab_size: usize,
    /// Forget gate weights, shape (vocab_size + hidden_size, hidden_size)
    w_forget: Option<Array2<f64>>,
    /// Input gate weights, same shape as `w_forget`
    w_input: Option<Array2<f64>>,
    /// Output gate weights, same shape as `w_forget`
    w_output: Option<Array2<f64>>,
    /// Cell candidate weights, same shape as `w_forget`
    w_cell: Option<Array2<f64>>,
    /// Current hidden state (None until weights are initialized)
    hidden_state: Option<Array1<f64>>,
    /// Current cell state (None until weights are initialized)
    cell_state: Option<Array1<f64>>,
    /// Maps each known cache key to its vocabulary index
    key_to_idx: HashMap<CacheKey, usize>,
    /// Reverse lookup: vocabulary index -> cache key
    idx_to_key: Vec<CacheKey>,
}
273
274impl LSTMPredictor {
275    /// Create new LSTM predictor
276    pub fn new(hidden_size: usize) -> Self {
277        Self {
278            hidden_size,
279            vocab_size: 0,
280            w_forget: None,
281            w_input: None,
282            w_output: None,
283            w_cell: None,
284            hidden_state: None,
285            cell_state: None,
286            key_to_idx: HashMap::new(),
287            idx_to_key: Vec::new(),
288        }
289    }
290
291    /// Initialize weights
292    fn initialize_weights(&mut self) {
293        // Seed fastrand for reproducibility
294        fastrand::seed(42);
295        let input_size = self.vocab_size + self.hidden_size;
296        let scale = (2.0 / input_size as f64).sqrt();
297
298        let wf_data: Vec<f64> = (0..input_size * self.hidden_size)
299            .map(|_| rand_normal(0.0, scale))
300            .collect();
301
302        let wi_data: Vec<f64> = (0..input_size * self.hidden_size)
303            .map(|_| rand_normal(0.0, scale))
304            .collect();
305
306        let wo_data: Vec<f64> = (0..input_size * self.hidden_size)
307            .map(|_| rand_normal(0.0, scale))
308            .collect();
309
310        let wc_data: Vec<f64> = (0..input_size * self.hidden_size)
311            .map(|_| rand_normal(0.0, scale))
312            .collect();
313
314        self.w_forget = Some(
315            Array2::from_shape_vec((input_size, self.hidden_size), wf_data)
316                .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
317        );
318
319        self.w_input = Some(
320            Array2::from_shape_vec((input_size, self.hidden_size), wi_data)
321                .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
322        );
323
324        self.w_output = Some(
325            Array2::from_shape_vec((input_size, self.hidden_size), wo_data)
326                .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
327        );
328
329        self.w_cell = Some(
330            Array2::from_shape_vec((input_size, self.hidden_size), wc_data)
331                .unwrap_or_else(|_| Array2::zeros((input_size, self.hidden_size))),
332        );
333
334        self.hidden_state = Some(Array1::zeros(self.hidden_size));
335        self.cell_state = Some(Array1::zeros(self.hidden_size));
336    }
337
338    /// Add key to vocabulary
339    fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
340        if let Some(&idx) = self.key_to_idx.get(key) {
341            idx
342        } else {
343            let idx = self.vocab_size;
344            self.key_to_idx.insert(key.clone(), idx);
345            self.idx_to_key.push(key.clone());
346            self.vocab_size += 1;
347
348            // Reinitialize weights when vocabulary changes
349            self.initialize_weights();
350
351            idx
352        }
353    }
354
355    /// Sigmoid activation
356    fn sigmoid(x: f64) -> f64 {
357        1.0 / (1.0 + (-x).exp())
358    }
359
360    /// Forward pass through LSTM cell
361    fn forward(&mut self, input_idx: usize) -> Result<Array1<f64>> {
362        let w_f = self
363            .w_forget
364            .as_ref()
365            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
366        let w_i = self
367            .w_input
368            .as_ref()
369            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
370        let w_o = self
371            .w_output
372            .as_ref()
373            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
374        let w_c = self
375            .w_cell
376            .as_ref()
377            .ok_or_else(|| CacheError::Prediction("Weights not initialized".to_string()))?;
378
379        let h_prev = self
380            .hidden_state
381            .as_ref()
382            .ok_or_else(|| CacheError::Prediction("Hidden state not initialized".to_string()))?;
383        let c_prev = self
384            .cell_state
385            .as_ref()
386            .ok_or_else(|| CacheError::Prediction("Cell state not initialized".to_string()))?;
387
388        // One-hot encode input
389        let mut input = Array1::zeros(self.vocab_size);
390        if input_idx < self.vocab_size {
391            input[input_idx] = 1.0;
392        }
393
394        // Concatenate input and hidden state
395        let mut combined = Array1::zeros(self.vocab_size + self.hidden_size);
396        for i in 0..self.vocab_size {
397            combined[i] = input[i];
398        }
399        for i in 0..self.hidden_size {
400            combined[self.vocab_size + i] = h_prev[i];
401        }
402
403        // Compute gates
404        let forget_gate = w_f.t().dot(&combined).mapv(Self::sigmoid);
405        let input_gate = w_i.t().dot(&combined).mapv(Self::sigmoid);
406        let output_gate = w_o.t().dot(&combined).mapv(Self::sigmoid);
407        let cell_candidate = w_c.t().dot(&combined).mapv(|x| x.tanh());
408
409        // Update cell state
410        let new_cell = &forget_gate * c_prev + &input_gate * &cell_candidate;
411
412        // Compute new hidden state
413        let new_hidden = &output_gate * &new_cell.mapv(|x| x.tanh());
414
415        // Update states
416        self.cell_state = Some(new_cell);
417        self.hidden_state = Some(new_hidden.clone());
418
419        Ok(new_hidden)
420    }
421
422    /// Record access
423    pub fn record_access(&mut self, key: CacheKey) -> Result<()> {
424        let idx = self.add_to_vocab(&key);
425        self.forward(idx)?;
426        Ok(())
427    }
428
429    /// Predict next keys
430    pub fn predict(&mut self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
431        let hidden = self
432            .hidden_state
433            .as_ref()
434            .ok_or_else(|| CacheError::Prediction("Not trained".to_string()))?;
435
436        // Use hidden state to score vocabulary items
437        let mut scores: Vec<(CacheKey, f64)> = self
438            .idx_to_key
439            .iter()
440            .enumerate()
441            .map(|(idx, key)| {
442                let score = if idx < hidden.len() { hidden[idx] } else { 0.0 };
443                (key.clone(), score)
444            })
445            .collect();
446
447        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
448        scores.truncate(top_n);
449
450        // Normalize to probabilities
451        let sum: f64 = scores.iter().map(|(_, s)| s.exp()).sum();
452        if sum > 0.0 {
453            for (_, score) in &mut scores {
454                *score = score.exp() / sum;
455            }
456        }
457
458        Ok(scores)
459    }
460
461    /// Reset states
462    pub fn reset(&mut self) {
463        self.hidden_state = Some(Array1::zeros(self.hidden_size));
464        self.cell_state = Some(Array1::zeros(self.hidden_size));
465    }
466
467    /// Clear predictor
468    pub fn clear(&mut self) {
469        self.key_to_idx.clear();
470        self.idx_to_key.clear();
471        self.vocab_size = 0;
472        self.w_forget = None;
473        self.w_input = None;
474        self.w_output = None;
475        self.w_cell = None;
476        self.hidden_state = None;
477        self.cell_state = None;
478    }
479}
480
/// Hybrid predictor combining multiple models
///
/// Ensembles the transformer and LSTM predictors, weighting each model's
/// predictions by a softmax over its average recently reported accuracy.
pub struct HybridPredictor {
    /// Transformer-based sub-predictor
    transformer: TransformerPredictor,
    /// LSTM-based sub-predictor
    lstm: LSTMPredictor,
    /// Per-model ensemble weights, keyed by name ("transformer" / "lstm")
    model_weights: HashMap<String, f64>,
    /// Recent (model name, accuracy) reports used to adapt the weights
    performance_history: VecDeque<(String, f64)>,
    /// Maximum number of accuracy reports retained
    history_size: usize,
}
494
495impl HybridPredictor {
496    /// Create new hybrid predictor
497    pub fn new(embedding_dim: usize, hidden_size: usize, seq_length: usize) -> Self {
498        let mut model_weights = HashMap::new();
499        model_weights.insert("transformer".to_string(), 0.5);
500        model_weights.insert("lstm".to_string(), 0.5);
501
502        Self {
503            transformer: TransformerPredictor::new(embedding_dim, 4, seq_length),
504            lstm: LSTMPredictor::new(hidden_size),
505            model_weights,
506            performance_history: VecDeque::with_capacity(100),
507            history_size: 100,
508        }
509    }
510
511    /// Record access
512    pub fn record_access(&mut self, key: CacheKey) -> Result<()> {
513        self.transformer.record_access(key.clone());
514        self.lstm.record_access(key)?;
515        Ok(())
516    }
517
518    /// Update model weights based on performance
519    fn update_weights(&mut self) {
520        if self.performance_history.len() < 10 {
521            return;
522        }
523
524        let mut model_scores: HashMap<String, f64> = HashMap::new();
525        let mut model_counts: HashMap<String, usize> = HashMap::new();
526
527        for (model, score) in &self.performance_history {
528            *model_scores.entry(model.clone()).or_insert(0.0) += score;
529            *model_counts.entry(model.clone()).or_insert(0) += 1;
530        }
531
532        // Calculate average scores
533        let avg_scores: Vec<(String, f64)> = model_scores
534            .into_iter()
535            .map(|(model, total)| {
536                let count = model_counts.get(&model).copied().unwrap_or(1);
537                (model, total / count as f64)
538            })
539            .collect();
540
541        // Normalize to weights (softmax)
542        let sum: f64 = avg_scores.iter().map(|(_, s)| s.exp()).sum();
543        if sum > 0.0 {
544            for (model, score) in avg_scores {
545                self.model_weights.insert(model, score.exp() / sum);
546            }
547        }
548    }
549
550    /// Predict with model ensemble
551    pub fn predict(&mut self, top_n: usize) -> Result<Vec<(CacheKey, f64)>> {
552        // Get predictions from both models
553        let transformer_preds = self.transformer.predict(top_n)?;
554        let lstm_preds = self.lstm.predict(top_n)?;
555
556        // Combine predictions with weights
557        let mut combined_scores: HashMap<CacheKey, f64> = HashMap::new();
558
559        let transformer_weight = self
560            .model_weights
561            .get("transformer")
562            .copied()
563            .unwrap_or(0.5);
564        let lstm_weight = self.model_weights.get("lstm").copied().unwrap_or(0.5);
565
566        for (key, score) in transformer_preds {
567            *combined_scores.entry(key).or_insert(0.0) += score * transformer_weight;
568        }
569
570        for (key, score) in lstm_preds {
571            *combined_scores.entry(key).or_insert(0.0) += score * lstm_weight;
572        }
573
574        // Sort and return top predictions
575        let mut results: Vec<(CacheKey, f64)> = combined_scores.into_iter().collect();
576        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
577        results.truncate(top_n);
578
579        Ok(results)
580    }
581
582    /// Report prediction accuracy for online learning
583    pub fn report_accuracy(&mut self, model_name: &str, accuracy: f64) {
584        if self.performance_history.len() >= self.history_size {
585            self.performance_history.pop_front();
586        }
587        self.performance_history
588            .push_back((model_name.to_string(), accuracy));
589        self.update_weights();
590    }
591
592    /// Get current model weights
593    pub fn get_weights(&self) -> &HashMap<String, f64> {
594        &self.model_weights
595    }
596
597    /// Clear all predictors
598    pub fn clear(&mut self) {
599        self.transformer.clear();
600        self.lstm.clear();
601        self.performance_history.clear();
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_predictor() {
        let mut predictor = TransformerPredictor::new(16, 2, 5);

        // Record a short access sequence, then make sure prediction runs.
        for key in ["key1", "key2", "key3"] {
            predictor.record_access(key.to_string());
        }

        assert!(predictor.predict(3).is_ok());
    }

    #[test]
    fn test_lstm_predictor() {
        let mut predictor = LSTMPredictor::new(32);

        // Each recorded access runs a forward pass and must succeed.
        assert!(predictor.record_access("key1".to_string()).is_ok());
        assert!(predictor.record_access("key2".to_string()).is_ok());

        assert!(predictor.predict(3).is_ok());
    }

    #[test]
    fn test_hybrid_predictor() {
        let mut predictor = HybridPredictor::new(16, 32, 5);

        assert!(predictor.record_access("key1".to_string()).is_ok());
        assert!(predictor.record_access("key2".to_string()).is_ok());

        assert!(predictor.predict(3).is_ok());
    }

    #[test]
    fn test_hybrid_online_learning() {
        let mut predictor = HybridPredictor::new(16, 32, 5);

        // Weight updates only trigger once at least 10 accuracy reports
        // exist, so feed enough data points to exercise the update path.
        for _ in 0..10 {
            predictor.report_accuracy("transformer", 0.8);
            predictor.report_accuracy("lstm", 0.6);
        }

        let weights = predictor.get_weights();
        let transformer_weight = weights.get("transformer").copied().unwrap_or(0.0);
        let lstm_weight = weights.get("lstm").copied().unwrap_or(0.0);

        // The more accurate model should dominate the ensemble.
        assert!(transformer_weight > lstm_weight);
    }
}