Skip to main content

microscope_memory/
predictive_cache.rs

//! Predictive Cache — pre-fetches blocks based on ThoughtGraph patterns.
//!
//! After each recall, the cache predicts what the *next* query will be
//! based on recognized thought patterns, and pre-loads the expected result
//! blocks. When a prediction hits, its source pattern is rewarded; a miss
//! penalizes the pattern and sharply decays the prediction's confidence.
//!
//! This is a feedback loop: good patterns → accurate predictions → rewards → stronger patterns.
//!
//! Binary format: `predictive_cache.bin`, identified by the `PRC1` magic bytes.
11
12use std::fs;
13use std::io::Write;
14use std::path::Path;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17use crate::thought_graph::{ThoughtGraphState, PATTERN_BOOST_WEIGHT};
18
// ─── Constants ──────────────────────────────────────

/// Upper bound on how many predictions the cache keeps at once.
const MAX_PREDICTIONS: usize = 50;
/// Upper bound on how many result blocks a single prediction stores.
const MAX_BLOCKS_PER_PREDICTION: usize = 30;
/// Pattern-strength reward for a full cache hit (scaled down for partial hits).
const HIT_REWARD: f32 = 0.3;
/// Pattern-strength penalty for a cache miss.
const MISS_PENALTY: f32 = 0.05;
/// Multiplicative confidence decay applied to every prediction per `predict_next` call.
const PREDICTION_DECAY: f32 = 0.98;
/// Confidence threshold: predictions below it are evicted and ignored by `check`.
const MIN_CONFIDENCE: f32 = 0.1;
/// Patterns seen fewer times than this are not used for prediction.
const MIN_PATTERN_FREQ: u32 = 3;
28
// ─── Prediction ─────────────────────────────────────

/// One cached guess about an upcoming recall.
///
/// Reads as: "if the next query hashes to `predicted_query_hash`, its results
/// will probably include `blocks`".
#[derive(Clone, Debug)]
pub struct Prediction {
    /// Hash of the query this prediction anticipates.
    pub predicted_query_hash: u64,
    /// Block ids expected among that query's results.
    pub blocks: Vec<u32>,
    /// Trust in this prediction; decays over time and is adjusted when evaluated.
    pub confidence: f32,
    /// Id of the thought pattern that produced this prediction.
    pub pattern_id: u32,
    /// Creation time in milliseconds since the Unix epoch.
    pub created_ms: u64,
}
40
// ─── CacheStats ─────────────────────────────────────

/// Running counters describing how well predictions have performed.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Predictions generated so far (remote counts may be merged in).
    pub total_predictions: u32,
    /// Evaluations classified as full hits.
    pub total_hits: u32,
    /// Evaluations with no overlap at all.
    pub total_misses: u32,
    /// Evaluations with some, but not enough, overlap to count as a hit.
    pub total_partial_hits: u32,
    /// Number of predictions currently held.
    pub current_predictions: usize,
    /// Mean confidence of the held predictions.
    pub avg_confidence: f32,
}

impl CacheStats {
    /// Fraction of evaluated predictions that paid off, counting a partial hit as half a hit.
    ///
    /// Returns `0.0` before any evaluation has been recorded.
    pub fn hit_rate(&self) -> f32 {
        // Sum in u64: three u32 counters (possibly inflated by merged remote stats)
        // can exceed u32::MAX, which would panic in debug builds.
        let total = u64::from(self.total_hits)
            + u64::from(self.total_misses)
            + u64::from(self.total_partial_hits);
        if total == 0 {
            return 0.0;
        }
        (self.total_hits as f32 + self.total_partial_hits as f32 * 0.5) / total as f32
    }
}
62
63// ─── PredictiveCache ────────────────────────────────
64
65pub struct PredictiveCache {
66    pub predictions: Vec<Prediction>,
67    pub stats: CacheStats,
68}
69
70impl PredictiveCache {
71    pub fn load_or_init(output_dir: &Path) -> Self {
72        let path = output_dir.join("predictive_cache.bin");
73        if path.exists() {
74            load_cache(&path)
75        } else {
76            Self {
77                predictions: Vec::new(),
78                stats: CacheStats::default(),
79            }
80        }
81    }
82
83    /// Check if we have a prediction for this query. Returns cached blocks + confidence.
84    /// This is called BEFORE the actual search, so the results can be used immediately.
85    pub fn check(&self, query_hash: u64) -> Option<(Vec<u32>, f32)> {
86        self.predictions
87            .iter()
88            .find(|p| p.predicted_query_hash == query_hash && p.confidence >= MIN_CONFIDENCE)
89            .map(|p| (p.blocks.clone(), p.confidence))
90    }
91
92    /// Evaluate prediction accuracy after a recall completes.
93    /// Compares predicted blocks against actual results.
94    /// Returns: (hit_type, overlap_count) where hit_type is "hit", "partial", or "miss".
95    pub fn evaluate(
96        &mut self,
97        query_hash: u64,
98        actual_results: &[u32],
99        thought_graph: &mut ThoughtGraphState,
100    ) -> (&'static str, usize) {
101        let prediction = self
102            .predictions
103            .iter()
104            .find(|p| p.predicted_query_hash == query_hash);
105
106        let prediction = match prediction {
107            Some(p) => p.clone(),
108            None => return ("none", 0),
109        };
110
111        // Count overlap
112        let overlap = prediction
113            .blocks
114            .iter()
115            .filter(|b| actual_results.contains(b))
116            .count();
117
118        let hit_type = if overlap == 0 {
119            // Miss
120            self.stats.total_misses += 1;
121            // Penalize the pattern
122            if let Some(pattern) = thought_graph
123                .patterns
124                .iter_mut()
125                .find(|p| p.id == prediction.pattern_id)
126            {
127                pattern.strength = (pattern.strength - MISS_PENALTY).max(0.0);
128            }
129            // Decay prediction confidence
130            if let Some(pred) = self
131                .predictions
132                .iter_mut()
133                .find(|p| p.predicted_query_hash == query_hash)
134            {
135                pred.confidence *= 0.5; // harsh decay on miss
136            }
137            "miss"
138        } else if overlap >= prediction.blocks.len() / 2 || overlap >= 3 {
139            // Hit — majority of predicted blocks were correct
140            self.stats.total_hits += 1;
141            // Reward the pattern
142            if let Some(pattern) = thought_graph
143                .patterns
144                .iter_mut()
145                .find(|p| p.id == prediction.pattern_id)
146            {
147                pattern.strength = (pattern.strength + HIT_REWARD).min(5.0);
148            }
149            // Boost prediction confidence
150            if let Some(pred) = self
151                .predictions
152                .iter_mut()
153                .find(|p| p.predicted_query_hash == query_hash)
154            {
155                pred.confidence = (pred.confidence + 0.2).min(1.0);
156            }
157            "hit"
158        } else {
159            // Partial hit
160            self.stats.total_partial_hits += 1;
161            let reward = HIT_REWARD * (overlap as f32 / prediction.blocks.len() as f32);
162            if let Some(pattern) = thought_graph
163                .patterns
164                .iter_mut()
165                .find(|p| p.id == prediction.pattern_id)
166            {
167                pattern.strength = (pattern.strength + reward).min(5.0);
168            }
169            "partial"
170        };
171
172        (hit_type, overlap)
173    }
174
175    /// Generate predictions for the next likely query based on current session state.
176    /// Called after each recall to pre-load the cache.
177    pub fn predict_next(&mut self, thought_graph: &ThoughtGraphState) {
178        // Decay all existing predictions
179        for pred in &mut self.predictions {
180            pred.confidence *= PREDICTION_DECAY;
181        }
182        self.predictions.retain(|p| p.confidence >= MIN_CONFIDENCE);
183
184        let session_hashes: Vec<u64> = thought_graph
185            .nodes
186            .iter()
187            .filter(|n| n.session_id == thought_graph.current_session_id)
188            .map(|n| n.query_hash)
189            .collect();
190
191        if session_hashes.is_empty() {
192            return;
193        }
194
195        let now_ms = now_epoch_ms();
196
197        for pattern in &thought_graph.patterns {
198            if pattern.frequency < MIN_PATTERN_FREQ {
199                continue;
200            }
201            if pattern.result_blocks.is_empty() {
202                continue;
203            }
204
205            let seq = &pattern.sequence;
206
207            // Check if session trail matches any prefix of this pattern
208            // If trail ends with seq[0..n], predict seq[n] with its blocks
209            for prefix_len in 1..seq.len() {
210                if session_hashes.len() < prefix_len {
211                    continue;
212                }
213
214                let trail_start = session_hashes.len() - prefix_len;
215                let trail = &session_hashes[trail_start..];
216
217                if trail == &seq[..prefix_len] {
218                    let predicted_hash = seq[prefix_len];
219
220                    // Don't duplicate predictions for same hash
221                    if self
222                        .predictions
223                        .iter()
224                        .any(|p| p.predicted_query_hash == predicted_hash)
225                    {
226                        continue;
227                    }
228
229                    let confidence = pattern.strength
230                        * PATTERN_BOOST_WEIGHT
231                        * (prefix_len as f32 / seq.len() as f32);
232
233                    let blocks: Vec<u32> = pattern
234                        .result_blocks
235                        .iter()
236                        .take(MAX_BLOCKS_PER_PREDICTION)
237                        .copied()
238                        .collect();
239
240                    self.predictions.push(Prediction {
241                        predicted_query_hash: predicted_hash,
242                        blocks,
243                        confidence: confidence.min(1.0),
244                        pattern_id: pattern.id,
245                        created_ms: now_ms,
246                    });
247
248                    self.stats.total_predictions += 1;
249                }
250            }
251        }
252
253        // Cap predictions
254        if self.predictions.len() > MAX_PREDICTIONS {
255            self.predictions
256                .sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
257            self.predictions.truncate(MAX_PREDICTIONS);
258        }
259
260        // Update avg confidence
261        if !self.predictions.is_empty() {
262            self.stats.avg_confidence = self.predictions.iter().map(|p| p.confidence).sum::<f32>()
263                / self.predictions.len() as f32;
264        }
265        self.stats.current_predictions = self.predictions.len();
266    }
267
268    /// Export stats for cross-instance exchange.
269    pub fn export_stats(&self) -> (u32, u32, u32, f32) {
270        (
271            self.stats.total_predictions,
272            self.stats.total_hits,
273            self.stats.total_misses,
274            self.stats.hit_rate(),
275        )
276    }
277
278    /// Merge remote stats (additive, for reporting).
279    pub fn merge_stats(&mut self, remote_predictions: u32, remote_hits: u32, remote_misses: u32) {
280        self.stats.total_predictions += remote_predictions;
281        self.stats.total_hits += remote_hits;
282        self.stats.total_misses += remote_misses;
283    }
284
285    /// Dream cleanup: remove predictions with very low confidence.
286    pub fn dream_cleanup(&mut self) {
287        self.predictions.retain(|p| p.confidence > 0.1);
288        self.stats.current_predictions = self.predictions.len();
289    }
290
291    /// Save to binary.
292    pub fn save(&self, output_dir: &Path) -> Result<(), String> {
293        save_cache(&output_dir.join("predictive_cache.bin"), self)
294    }
295}
296
// ─── Binary I/O ─────────────────────────────────────

/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` instead of an error.
fn now_epoch_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis() as u64
}
305
306fn save_cache(path: &Path, cache: &PredictiveCache) -> Result<(), String> {
307    let mut buf = Vec::with_capacity(256);
308
309    // Header
310    buf.write_all(b"PRC1").map_err(|e| e.to_string())?;
311    buf.write_all(&(cache.predictions.len() as u32).to_le_bytes())
312        .map_err(|e| e.to_string())?;
313
314    // Stats
315    buf.write_all(&cache.stats.total_predictions.to_le_bytes())
316        .map_err(|e| e.to_string())?;
317    buf.write_all(&cache.stats.total_hits.to_le_bytes())
318        .map_err(|e| e.to_string())?;
319    buf.write_all(&cache.stats.total_misses.to_le_bytes())
320        .map_err(|e| e.to_string())?;
321    buf.write_all(&cache.stats.total_partial_hits.to_le_bytes())
322        .map_err(|e| e.to_string())?;
323
324    // Predictions (variable length)
325    for p in &cache.predictions {
326        buf.write_all(&p.predicted_query_hash.to_le_bytes())
327            .map_err(|e| e.to_string())?;
328        buf.write_all(&p.confidence.to_le_bytes())
329            .map_err(|e| e.to_string())?;
330        buf.write_all(&p.pattern_id.to_le_bytes())
331            .map_err(|e| e.to_string())?;
332        buf.write_all(&p.created_ms.to_le_bytes())
333            .map_err(|e| e.to_string())?;
334        buf.write_all(&(p.blocks.len() as u16).to_le_bytes())
335            .map_err(|e| e.to_string())?;
336        for &b in &p.blocks {
337            buf.write_all(&b.to_le_bytes()).map_err(|e| e.to_string())?;
338        }
339    }
340
341    fs::write(path, &buf).map_err(|e| e.to_string())
342}
343
344fn load_cache(path: &Path) -> PredictiveCache {
345    let data = match fs::read(path) {
346        Ok(d) => d,
347        Err(_) => {
348            return PredictiveCache {
349                predictions: Vec::new(),
350                stats: CacheStats::default(),
351            }
352        }
353    };
354
355    if data.len() < 24 || &data[0..4] != b"PRC1" {
356        return PredictiveCache {
357            predictions: Vec::new(),
358            stats: CacheStats::default(),
359        };
360    }
361
362    let pred_count = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
363
364    let total_predictions = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
365    let total_hits = u32::from_le_bytes([data[12], data[13], data[14], data[15]]);
366    let total_misses = u32::from_le_bytes([data[16], data[17], data[18], data[19]]);
367    let total_partial_hits = u32::from_le_bytes([data[20], data[21], data[22], data[23]]);
368
369    let mut offset = 24;
370    let mut predictions = Vec::with_capacity(pred_count);
371
372    for _ in 0..pred_count {
373        if offset + 22 > data.len() {
374            break;
375        }
376
377        let predicted_query_hash = u64::from_le_bytes([
378            data[offset],
379            data[offset + 1],
380            data[offset + 2],
381            data[offset + 3],
382            data[offset + 4],
383            data[offset + 5],
384            data[offset + 6],
385            data[offset + 7],
386        ]);
387        offset += 8;
388
389        let confidence = f32::from_le_bytes([
390            data[offset],
391            data[offset + 1],
392            data[offset + 2],
393            data[offset + 3],
394        ]);
395        offset += 4;
396
397        let pattern_id = u32::from_le_bytes([
398            data[offset],
399            data[offset + 1],
400            data[offset + 2],
401            data[offset + 3],
402        ]);
403        offset += 4;
404
405        let created_ms = u64::from_le_bytes([
406            data[offset],
407            data[offset + 1],
408            data[offset + 2],
409            data[offset + 3],
410            data[offset + 4],
411            data[offset + 5],
412            data[offset + 6],
413            data[offset + 7],
414        ]);
415        offset += 8;
416
417        if offset + 2 > data.len() {
418            break;
419        }
420        let block_count = u16::from_le_bytes([data[offset], data[offset + 1]]) as usize;
421        offset += 2;
422
423        if offset + block_count * 4 > data.len() {
424            break;
425        }
426        let mut blocks = Vec::with_capacity(block_count);
427        for _ in 0..block_count {
428            let b = u32::from_le_bytes([
429                data[offset],
430                data[offset + 1],
431                data[offset + 2],
432                data[offset + 3],
433            ]);
434            blocks.push(b);
435            offset += 4;
436        }
437
438        predictions.push(Prediction {
439            predicted_query_hash,
440            blocks,
441            confidence,
442            pattern_id,
443            created_ms,
444        });
445    }
446
447    let current_predictions = predictions.len();
448    let avg_confidence = if predictions.is_empty() {
449        0.0
450    } else {
451        predictions.iter().map(|p| p.confidence).sum::<f32>() / predictions.len() as f32
452    };
453
454    PredictiveCache {
455        predictions,
456        stats: CacheStats {
457            total_predictions,
458            total_hits,
459            total_misses,
460            total_partial_hits,
461            current_predictions,
462            avg_confidence,
463        },
464    }
465}
466
// ─── Tests ──────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::thought_graph::{ThoughtGraphState, ThoughtPattern};

    /// Thought graph from `load_or_init` on a path that should not exist
    /// (presumably an empty graph — depends on `thought_graph`).
    fn make_tg() -> ThoughtGraphState {
        ThoughtGraphState::load_or_init(Path::new("/nonexistent"))
    }

    /// Cache holding exactly `predictions`, with zeroed stats.
    fn cache_with(predictions: Vec<Prediction>) -> PredictiveCache {
        PredictiveCache {
            predictions,
            stats: CacheStats::default(),
        }
    }

    /// Prediction for `hash` attributed to pattern 0 and created at time 0.
    fn prediction(hash: u64, blocks: Vec<u32>, confidence: f32) -> Prediction {
        Prediction {
            predicted_query_hash: hash,
            blocks,
            confidence,
            pattern_id: 0,
            created_ms: 0,
        }
    }

    /// Pattern 0: sequence 0xBB → 0xAA, frequency 5, strength 1.0, result blocks 10/20/30.
    fn pattern_bb_aa() -> ThoughtPattern {
        ThoughtPattern {
            id: 0,
            sequence: vec![0xBB, 0xAA],
            frequency: 5,
            strength: 1.0,
            last_seen_ms: 0,
            result_blocks: vec![10, 20, 30],
        }
    }

    #[test]
    fn test_check_empty() {
        assert!(cache_with(Vec::new()).check(0xAA).is_none());
    }

    #[test]
    fn test_check_hit() {
        let cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.8)]);
        let (blocks, conf) = cache
            .check(0xAA)
            .expect("a prediction for 0xAA should be found");
        assert_eq!(blocks, vec![10, 20, 30]);
        assert!((conf - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_evaluate_hit() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.5)]);
        let mut tg = make_tg();
        tg.patterns.push(pattern_bb_aa());

        let (hit_type, overlap) = cache.evaluate(0xAA, &[10, 20, 30, 40], &mut tg);

        assert_eq!(hit_type, "hit");
        assert_eq!(overlap, 3);
        assert_eq!(cache.stats.total_hits, 1);
        assert!(tg.patterns[0].strength > 1.0, "pattern should be rewarded");
    }

    #[test]
    fn test_evaluate_miss() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10, 20, 30], 0.5)]);
        let mut tg = make_tg();
        tg.patterns.push(pattern_bb_aa());

        // None of the actual results were predicted.
        let (hit_type, overlap) = cache.evaluate(0xAA, &[100, 200, 300], &mut tg);

        assert_eq!(hit_type, "miss");
        assert_eq!(overlap, 0);
        assert_eq!(cache.stats.total_misses, 1);
        assert!(tg.patterns[0].strength < 1.0, "pattern should be penalized");
    }

    #[test]
    fn test_evaluate_no_prediction() {
        let mut cache = cache_with(Vec::new());
        let mut tg = make_tg();
        let (hit_type, _) = cache.evaluate(0xAA, &[10, 20], &mut tg);
        assert_eq!(hit_type, "none");
    }

    #[test]
    fn test_predict_next() {
        let mut cache = cache_with(Vec::new());
        let mut tg = make_tg();

        // Session 1 has recalled 0xAA once; pattern 0xAA → 0xBB should predict 0xBB next.
        tg.current_session_id = 1;
        tg.nodes.push(crate::thought_graph::ThoughtNode {
            timestamp_ms: 1000,
            query_hash: 0xAA,
            session_id: 1,
            result_count: 3,
            dominant_layer: 1,
            centroid_hash: 0,
        });
        tg.patterns.push(ThoughtPattern {
            sequence: vec![0xAA, 0xBB],
            strength: 2.0,
            last_seen_ms: 1000,
            ..pattern_bb_aa()
        });

        cache.predict_next(&tg);

        assert_eq!(cache.predictions.len(), 1);
        let predicted = &cache.predictions[0];
        assert_eq!(predicted.predicted_query_hash, 0xBB);
        assert_eq!(predicted.blocks, vec![10, 20, 30]);
        assert!(predicted.confidence > 0.0);
    }

    #[test]
    fn test_predict_decay() {
        let mut cache = cache_with(vec![prediction(0xAA, vec![10], MIN_CONFIDENCE + 0.01)]);
        let tg = make_tg();

        // A prediction just above the eviction threshold must decay out within 20 rounds.
        (0..20).for_each(|_| cache.predict_next(&tg));

        assert!(cache.predictions.is_empty());
    }

    #[test]
    fn test_hit_rate() {
        let mut stats = CacheStats::default();
        assert_eq!(stats.hit_rate(), 0.0);

        stats.total_hits = 7;
        stats.total_misses = 3;
        assert!((stats.hit_rate() - 0.7).abs() < 0.001);

        // (7 + 2*0.5) / (7+3+2) = 8/12 ≈ 0.667
        stats.total_partial_hits = 2;
        assert!((stats.hit_rate() - 0.6667).abs() < 0.01);
    }

    #[test]
    fn test_save_load_roundtrip() {
        let dir = tempfile::tempdir().unwrap();

        let cache = PredictiveCache {
            predictions: vec![
                Prediction {
                    pattern_id: 1,
                    created_ms: 12345,
                    ..prediction(0xAA, vec![10, 20], 0.75)
                },
                Prediction {
                    pattern_id: 2,
                    created_ms: 67890,
                    ..prediction(0xBB, vec![30, 40, 50], 0.5)
                },
            ],
            stats: CacheStats {
                total_predictions: 10,
                total_hits: 5,
                total_misses: 3,
                total_partial_hits: 2,
                current_predictions: 2,
                avg_confidence: 0.625,
            },
        };

        cache.save(dir.path()).unwrap();
        let loaded = PredictiveCache::load_or_init(dir.path());

        assert_eq!(loaded.predictions.len(), 2);
        let first = &loaded.predictions[0];
        assert_eq!(first.predicted_query_hash, 0xAA);
        assert_eq!(first.blocks, vec![10, 20]);
        assert!((first.confidence - 0.75).abs() < 0.001);
        assert_eq!(loaded.predictions[1].blocks, vec![30, 40, 50]);
        assert_eq!(loaded.stats.total_hits, 5);
        assert_eq!(loaded.stats.total_misses, 3);
        assert_eq!(loaded.stats.total_partial_hits, 2);
    }
}
688        assert_eq!(loaded.stats.total_hits, 5);
689        assert_eq!(loaded.stats.total_misses, 3);
690        assert_eq!(loaded.stats.total_partial_hits, 2);
691    }
692}