//! Predictive cache prefetching — oxigdal_cache_advanced/predictive/mod.rs

pub mod advanced;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use scirs2_core::ndarray::{Array1, Array2};
use tokio::sync::RwLock;

use crate::multi_tier::CacheKey;
18
19fn rand_normal(mean: f64, std_dev: f64) -> f64 {
21 let u1 = fastrand::f64();
22 let u2 = fastrand::f64();
23 let u1 = if u1 < 1e-10 { 1e-10 } else { u1 };
25 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
26 mean + z0 * std_dev
27}
28
/// A single observed cache access, fed into the predictors.
#[derive(Debug, Clone)]
pub struct AccessRecord {
    /// Key that was accessed.
    pub key: CacheKey,
    /// When the access happened (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Whether the access was a read or a write.
    pub access_type: AccessType,
}
39
/// Kind of cache access recorded in an `AccessRecord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccessType {
    /// The key was read.
    Read,
    /// The key was written.
    Write,
}
48
/// A predicted future access produced by one of the predictors.
#[derive(Debug, Clone)]
pub struct Prediction {
    /// Key expected to be accessed next.
    pub key: CacheKey,
    /// Confidence score; the predictors here emit values in [0, 1].
    pub confidence: f64,
    /// Estimated time of the next access, when the predictor can supply one.
    pub predicted_time: Option<chrono::DateTime<chrono::Utc>>,
}
59
60impl Prediction {
61 pub fn is_confident(&self, threshold: f64) -> bool {
63 self.confidence >= threshold
64 }
65}
66
/// Markov-chain access predictor: learns key-to-key transition frequencies.
pub struct MarkovPredictor {
    /// Normalized transition probabilities, derived from `transition_counts`.
    transitions: HashMap<CacheKey, HashMap<CacheKey, f64>>,
    /// Raw observed transition counts.
    transition_counts: HashMap<CacheKey, HashMap<CacheKey, u64>>,
    /// Most recently accessed key (the chain's current state).
    current_key: Option<CacheKey>,
    /// Intended chain order. NOTE(review): prediction only consults the
    /// order-1 `transitions` map; `history` is maintained but never read.
    order: usize,
    /// Sliding window of the last `order` keys.
    history: VecDeque<CacheKey>,
}
81
82impl MarkovPredictor {
83 pub fn new(order: usize) -> Self {
85 Self {
86 transitions: HashMap::new(),
87 transition_counts: HashMap::new(),
88 current_key: None,
89 order,
90 history: VecDeque::with_capacity(order),
91 }
92 }
93
94 pub fn record_access(&mut self, key: CacheKey) {
96 if let Some(prev_key) = self.current_key.clone() {
97 let next_counts = self.transition_counts.entry(prev_key.clone()).or_default();
99
100 *next_counts.entry(key.clone()).or_insert(0) += 1;
101
102 self.update_probabilities(&prev_key);
104 }
105
106 if self.history.len() >= self.order {
108 self.history.pop_front();
109 }
110 self.history.push_back(key.clone());
111
112 self.current_key = Some(key);
113 }
114
115 fn update_probabilities(&mut self, from_key: &CacheKey) {
117 if let Some(counts) = self.transition_counts.get(from_key) {
118 let total: u64 = counts.values().sum();
119
120 if total > 0 {
121 let probabilities: HashMap<CacheKey, f64> = counts
122 .iter()
123 .map(|(k, count)| (k.clone(), *count as f64 / total as f64))
124 .collect();
125
126 self.transitions.insert(from_key.clone(), probabilities);
127 }
128 }
129 }
130
131 pub fn predict(&self, top_n: usize) -> Vec<Prediction> {
133 if let Some(current) = &self.current_key {
134 if let Some(transitions) = self.transitions.get(current) {
135 let mut predictions: Vec<_> = transitions
136 .iter()
137 .map(|(key, prob)| Prediction {
138 key: key.clone(),
139 confidence: *prob,
140 predicted_time: None,
141 })
142 .collect();
143
144 predictions.sort_by(|a, b| {
145 b.confidence
146 .partial_cmp(&a.confidence)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149
150 predictions.truncate(top_n);
151 return predictions;
152 }
153 }
154
155 Vec::new()
156 }
157
158 pub fn state_count(&self) -> usize {
160 self.transitions.len()
161 }
162
163 pub fn clear(&mut self) {
165 self.transitions.clear();
166 self.transition_counts.clear();
167 self.current_key = None;
168 self.history.clear();
169 }
170}
171
/// Detects per-key periodic access patterns from inter-access intervals.
pub struct TemporalPatternDetector {
    /// Recent (key, timestamp) pairs, oldest first, bounded by `max_history`.
    access_history: VecDeque<(CacheKey, chrono::DateTime<chrono::Utc>)>,
    /// Maximum number of accesses kept in `access_history`.
    max_history: usize,
    /// Per-key inter-access intervals, in whole seconds.
    patterns: HashMap<CacheKey, Vec<i64>>,
}
182
183impl TemporalPatternDetector {
184 pub fn new(max_history: usize) -> Self {
186 Self {
187 access_history: VecDeque::with_capacity(max_history),
188 max_history,
189 patterns: HashMap::new(),
190 }
191 }
192
193 pub fn record_access(&mut self, key: CacheKey, timestamp: chrono::DateTime<chrono::Utc>) {
195 if self.access_history.len() >= self.max_history {
196 self.access_history.pop_front();
197 }
198
199 self.access_history.push_back((key.clone(), timestamp));
200
201 self.detect_pattern(&key);
203 }
204
205 fn detect_pattern(&mut self, key: &CacheKey) {
207 let accesses: Vec<_> = self
208 .access_history
209 .iter()
210 .filter(|(k, _)| k == key)
211 .map(|(_, ts)| *ts)
212 .collect();
213
214 if accesses.len() < 3 {
215 return;
216 }
217
218 let mut intervals = Vec::new();
220 for i in 1..accesses.len() {
221 let interval = (accesses[i] - accesses[i - 1]).num_seconds();
222 intervals.push(interval);
223 }
224
225 self.patterns.insert(key.clone(), intervals);
227 }
228
229 pub fn predict_next_access(&self, key: &CacheKey) -> Option<chrono::DateTime<chrono::Utc>> {
231 if let Some(intervals) = self.patterns.get(key) {
232 if intervals.is_empty() {
233 return None;
234 }
235
236 let mut sorted_intervals = intervals.clone();
238 sorted_intervals.sort();
239 let median_interval = sorted_intervals[sorted_intervals.len() / 2];
240
241 let last_access = self
243 .access_history
244 .iter()
245 .rev()
246 .find(|(k, _)| k == key)
247 .map(|(_, ts)| *ts);
248
249 if let Some(last) = last_access {
250 return Some(last + chrono::Duration::seconds(median_interval));
251 }
252 }
253
254 None
255 }
256
257 pub fn predict(&self, key: &CacheKey) -> Option<Prediction> {
259 if let Some(next_time) = self.predict_next_access(key) {
260 let intervals = self.patterns.get(key)?;
261
262 let confidence = if intervals.len() < 2 {
264 0.5
265 } else {
266 let mean: f64 =
267 intervals.iter().map(|&x| x as f64).sum::<f64>() / intervals.len() as f64;
268 let variance: f64 = intervals
269 .iter()
270 .map(|&x| {
271 let diff = x as f64 - mean;
272 diff * diff
273 })
274 .sum::<f64>()
275 / intervals.len() as f64;
276
277 let std_dev = variance.sqrt();
278 let cv = if mean > 0.0 { std_dev / mean } else { 1.0 };
279
280 (1.0 / (1.0 + cv)).clamp(0.0, 1.0)
282 };
283
284 Some(Prediction {
285 key: key.clone(),
286 confidence,
287 predicted_time: Some(next_time),
288 })
289 } else {
290 None
291 }
292 }
293
294 pub fn clear(&mut self) {
296 self.access_history.clear();
297 self.patterns.clear();
298 }
299}
300
/// Tracks which keys are accessed close together (within a sliding window).
pub struct SpatialPatternDetector {
    /// Symmetric co-occurrence counts between key pairs.
    co_occurrences: HashMap<CacheKey, HashMap<CacheKey, u64>>,
    /// Recently accessed keys, bounded by `window_size`.
    window: VecDeque<CacheKey>,
    /// Maximum number of keys kept in `window`.
    window_size: usize,
}
311
312impl SpatialPatternDetector {
313 pub fn new(window_size: usize) -> Self {
315 Self {
316 co_occurrences: HashMap::new(),
317 window: VecDeque::with_capacity(window_size),
318 window_size,
319 }
320 }
321
322 pub fn record_access(&mut self, key: CacheKey) {
324 for other_key in &self.window {
326 let co_occurs = self.co_occurrences.entry(key.clone()).or_default();
328 *co_occurs.entry(other_key.clone()).or_insert(0) += 1;
329
330 let co_occurs_reverse = self.co_occurrences.entry(other_key.clone()).or_default();
332 *co_occurs_reverse.entry(key.clone()).or_insert(0) += 1;
333 }
334
335 if self.window.len() >= self.window_size {
337 self.window.pop_front();
338 }
339 self.window.push_back(key);
340 }
341
342 pub fn get_related_keys(&self, key: &CacheKey, top_n: usize) -> Vec<Prediction> {
344 if let Some(co_occurs) = self.co_occurrences.get(key) {
345 let total: u64 = co_occurs.values().sum();
346
347 if total == 0 {
348 return Vec::new();
349 }
350
351 let mut predictions: Vec<_> = co_occurs
352 .iter()
353 .map(|(k, count)| Prediction {
354 key: k.clone(),
355 confidence: *count as f64 / total as f64,
356 predicted_time: None,
357 })
358 .collect();
359
360 predictions.sort_by(|a, b| {
361 b.confidence
362 .partial_cmp(&a.confidence)
363 .unwrap_or(std::cmp::Ordering::Equal)
364 });
365
366 predictions.truncate(top_n);
367 predictions
368 } else {
369 Vec::new()
370 }
371 }
372
373 pub fn clear(&mut self) {
375 self.co_occurrences.clear();
376 self.window.clear();
377 }
378}
379
/// Minimal two-layer feed-forward predictor over a one-hot key vocabulary.
pub struct NeuralPredictor {
    /// Number of distinct keys seen so far (input/output dimension).
    vocab_size: usize,
    /// Hidden-layer width.
    hidden_size: usize,
    /// Input-to-hidden weights, shape (vocab_size, hidden_size).
    w1: Option<Array2<f64>>,
    /// Hidden-to-output weights, shape (hidden_size, vocab_size).
    w2: Option<Array2<f64>>,
    /// Hidden-layer bias.
    b1: Option<Array1<f64>>,
    /// Output-layer bias.
    b2: Option<Array1<f64>>,
    /// Key to vocabulary index.
    key_to_idx: HashMap<CacheKey, usize>,
    /// Vocabulary index to key.
    idx_to_key: Vec<CacheKey>,
    /// Reserved for future training; currently unused.
    #[allow(dead_code)]
    learning_rate: f64,
    /// Reserved for future training; currently unused.
    #[allow(dead_code)]
    training_enabled: bool,
}
406
407impl NeuralPredictor {
408 pub fn new(hidden_size: usize) -> Self {
410 Self {
411 vocab_size: 0,
412 hidden_size,
413 w1: None,
414 w2: None,
415 b1: None,
416 b2: None,
417 key_to_idx: HashMap::new(),
418 idx_to_key: Vec::new(),
419 learning_rate: 0.01,
420 training_enabled: true,
421 }
422 }
423
424 fn add_to_vocab(&mut self, key: &CacheKey) -> usize {
426 if let Some(&idx) = self.key_to_idx.get(key) {
427 idx
428 } else {
429 let idx = self.vocab_size;
430 self.key_to_idx.insert(key.clone(), idx);
431 self.idx_to_key.push(key.clone());
432 self.vocab_size += 1;
433
434 if self.vocab_size > 0 {
436 self.initialize_weights();
437 }
438
439 idx
440 }
441 }
442
443 fn initialize_weights(&mut self) {
445 fastrand::seed(42);
447
448 let scale_w1 = (2.0 / (self.vocab_size + self.hidden_size) as f64).sqrt();
450 let scale_w2 = (2.0 / (self.hidden_size + self.vocab_size) as f64).sqrt();
451
452 let w1_data: Vec<f64> = (0..self.vocab_size * self.hidden_size)
453 .map(|_| rand_normal(0.0, scale_w1))
454 .collect();
455
456 let w2_data: Vec<f64> = (0..self.hidden_size * self.vocab_size)
457 .map(|_| rand_normal(0.0, scale_w2))
458 .collect();
459
460 self.w1 = Some(
461 Array2::from_shape_vec((self.vocab_size, self.hidden_size), w1_data)
462 .unwrap_or_else(|_| Array2::zeros((self.vocab_size, self.hidden_size))),
463 );
464
465 self.w2 = Some(
466 Array2::from_shape_vec((self.hidden_size, self.vocab_size), w2_data)
467 .unwrap_or_else(|_| Array2::zeros((self.hidden_size, self.vocab_size))),
468 );
469
470 self.b1 = Some(Array1::zeros(self.hidden_size));
471 self.b2 = Some(Array1::zeros(self.vocab_size));
472 }
473
474 fn forward(&self, input_idx: usize) -> Option<Array1<f64>> {
476 if input_idx >= self.vocab_size {
477 return None;
478 }
479
480 let w1 = self.w1.as_ref()?;
481 let w2 = self.w2.as_ref()?;
482 let b1 = self.b1.as_ref()?;
483 let b2 = self.b2.as_ref()?;
484
485 let mut input = Array1::zeros(self.vocab_size);
487 input[input_idx] = 1.0;
488
489 let hidden = w1.t().dot(&input) + b1;
491 let hidden_activated = hidden.mapv(|x| x.max(0.0));
492
493 let output = w2.t().dot(&hidden_activated) + b2;
495 let output_exp = output.mapv(|x| x.exp());
496 let sum_exp: f64 = output_exp.sum();
497
498 Some(output_exp / sum_exp)
499 }
500
501 pub fn record_access(&mut self, key: CacheKey) {
503 let _idx = self.add_to_vocab(&key);
504 }
506
507 pub fn predict(&mut self, current_key: &CacheKey, top_n: usize) -> Vec<Prediction> {
509 if let Some(&idx) = self.key_to_idx.get(current_key) {
510 if let Some(output) = self.forward(idx) {
511 let mut predictions: Vec<_> = output
512 .iter()
513 .enumerate()
514 .map(|(i, &prob)| Prediction {
515 key: self.idx_to_key.get(i).cloned().unwrap_or_default(),
516 confidence: prob,
517 predicted_time: None,
518 })
519 .collect();
520
521 predictions.sort_by(|a, b| {
522 b.confidence
523 .partial_cmp(&a.confidence)
524 .unwrap_or(std::cmp::Ordering::Equal)
525 });
526
527 predictions.truncate(top_n);
528 return predictions;
529 }
530 }
531
532 Vec::new()
533 }
534
535 pub fn clear(&mut self) {
537 self.w1 = None;
538 self.w2 = None;
539 self.b1 = None;
540 self.b2 = None;
541 self.key_to_idx.clear();
542 self.idx_to_key.clear();
543 self.vocab_size = 0;
544 }
545}
546
/// Combines the Markov, temporal, spatial, and neural predictors behind
/// async RwLocks, averaging per-key confidences across sub-predictors.
pub struct EnsemblePredictor {
    /// Markov transition predictor.
    markov: Arc<RwLock<MarkovPredictor>>,
    /// Periodic-interval detector.
    temporal: Arc<RwLock<TemporalPatternDetector>>,
    /// Co-occurrence detector.
    spatial: Arc<RwLock<SpatialPatternDetector>>,
    /// Neural predictor. NOTE(review): it is fed accesses but is not
    /// consulted in `predict` — confirm whether that is intentional.
    neural: Arc<RwLock<NeuralPredictor>>,
    /// Minimum averaged confidence required to keep a prediction.
    confidence_threshold: f64,
}
560
561impl EnsemblePredictor {
562 pub fn new() -> Self {
564 Self {
565 markov: Arc::new(RwLock::new(MarkovPredictor::new(2))),
566 temporal: Arc::new(RwLock::new(TemporalPatternDetector::new(1000))),
567 spatial: Arc::new(RwLock::new(SpatialPatternDetector::new(10))),
568 neural: Arc::new(RwLock::new(NeuralPredictor::new(64))),
569 confidence_threshold: 0.5,
570 }
571 }
572
573 pub fn with_threshold(mut self, threshold: f64) -> Self {
575 self.confidence_threshold = threshold;
576 self
577 }
578
579 pub async fn record_access(&self, record: AccessRecord) {
581 let mut markov = self.markov.write().await;
582 markov.record_access(record.key.clone());
583 drop(markov);
584
585 let mut temporal = self.temporal.write().await;
586 temporal.record_access(record.key.clone(), record.timestamp);
587 drop(temporal);
588
589 let mut spatial = self.spatial.write().await;
590 spatial.record_access(record.key.clone());
591 drop(spatial);
592
593 let mut neural = self.neural.write().await;
594 neural.record_access(record.key);
595 }
596
597 pub async fn predict(&self, current_key: &CacheKey, top_n: usize) -> Vec<Prediction> {
599 let mut all_predictions = Vec::new();
600
601 let markov = self.markov.read().await;
603 let markov_predictions = markov.predict(top_n);
604 all_predictions.extend(markov_predictions);
605 drop(markov);
606
607 let temporal = self.temporal.read().await;
609 if let Some(temporal_pred) = temporal.predict(current_key) {
610 all_predictions.push(temporal_pred);
611 }
612 drop(temporal);
613
614 let spatial = self.spatial.read().await;
616 let spatial_predictions = spatial.get_related_keys(current_key, top_n);
617 all_predictions.extend(spatial_predictions);
618 drop(spatial);
619
620 let mut aggregated: HashMap<CacheKey, Vec<f64>> = HashMap::new();
622 for pred in all_predictions {
623 aggregated
624 .entry(pred.key.clone())
625 .or_default()
626 .push(pred.confidence);
627 }
628
629 let mut final_predictions: Vec<_> = aggregated
631 .into_iter()
632 .map(|(key, confidences)| {
633 let avg_confidence = confidences.iter().sum::<f64>() / confidences.len() as f64;
634 Prediction {
635 key,
636 confidence: avg_confidence,
637 predicted_time: None,
638 }
639 })
640 .filter(|p| p.confidence >= self.confidence_threshold)
641 .collect();
642
643 final_predictions.sort_by(|a, b| {
644 b.confidence
645 .partial_cmp(&a.confidence)
646 .unwrap_or(std::cmp::Ordering::Equal)
647 });
648
649 final_predictions.truncate(top_n);
650 final_predictions
651 }
652
653 pub async fn clear(&self) {
655 self.markov.write().await.clear();
656 self.temporal.write().await.clear();
657 self.spatial.write().await.clear();
658 self.neural.write().await.clear();
659 }
660}
661
impl Default for EnsemblePredictor {
    /// Equivalent to `EnsemblePredictor::new()`.
    fn default() -> Self {
        Self::new()
    }
}
667
/// Unit tests; `CacheKey` is exercised with plain `String` values here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_markov_predictor() {
        let mut predictor = MarkovPredictor::new(1);

        // Alternating A/B accesses create A->B and B->A transitions, so the
        // predictor has at least one candidate from the current state.
        predictor.record_access("A".to_string());
        predictor.record_access("B".to_string());
        predictor.record_access("A".to_string());
        predictor.record_access("B".to_string());

        let predictions = predictor.predict(3);
        assert!(!predictions.is_empty());
    }

    #[test]
    fn test_temporal_pattern_detector() {
        let mut detector = TemporalPatternDetector::new(100);

        // Three accesses 10s apart: enough (>= 3) to form an interval pattern.
        let now = chrono::Utc::now();
        detector.record_access("A".to_string(), now);
        detector.record_access("A".to_string(), now + chrono::Duration::seconds(10));
        detector.record_access("A".to_string(), now + chrono::Duration::seconds(20));

        let prediction = detector.predict(&"A".to_string());
        assert!(prediction.is_some());
    }

    #[test]
    fn test_spatial_pattern_detector() {
        let mut detector = SpatialPatternDetector::new(5);

        // All keys fall within one window, so "A" accumulates co-occurrences.
        detector.record_access("A".to_string());
        detector.record_access("B".to_string());
        detector.record_access("C".to_string());
        detector.record_access("A".to_string());
        detector.record_access("B".to_string());

        let related = detector.get_related_keys(&"A".to_string(), 3);
        assert!(!related.is_empty());
    }

    #[tokio::test]
    async fn test_ensemble_predictor() {
        let predictor = EnsemblePredictor::new();

        let now = chrono::Utc::now();
        predictor
            .record_access(AccessRecord {
                key: "A".to_string(),
                timestamp: now,
                access_type: AccessType::Read,
            })
            .await;

        predictor
            .record_access(AccessRecord {
                key: "B".to_string(),
                timestamp: now + chrono::Duration::seconds(1),
                access_type: AccessType::Read,
            })
            .await;

        // The ensemble filters by confidence, so only bound the result count.
        let predictions = predictor.predict(&"A".to_string(), 5).await;
        assert!(predictions.len() <= 5);
    }
}