Skip to main content

lean_ctx/core/
predictive_prefetch.rs

1//! Predictive Prefetch via Free Energy Minimization.
2//!
3//! Scientific basis: Karl Friston's Free Energy Principle (2010) — the system minimizes
4//! "surprise" (unexpected information requests) by maintaining a generative model of what
5//! files will be needed next and proactively loading them when resources permit.
6//!
7//! The model combines:
8//! 1. Co-access history (Hebbian associations)
9//! 2. Graph neighborhood (import/call relationships)
10//! 3. Recency patterns (temporal locality)
11
12use std::collections::HashMap;
13
/// Maximum files to prefetch per prediction cycle.
const MAX_PREFETCH: usize = 5;
/// Minimum prediction confidence to trigger prefetch.
const MIN_CONFIDENCE: f64 = 0.3;

/// How many of the most recent accesses are credited with a transition to a newly observed file.
const LEARNING_WINDOW: usize = 3;
/// Length of the recent-access history above which it is trimmed.
const MAX_RECENT_ACCESSES: usize = 100;
/// Number of oldest history entries dropped when the history is trimmed.
const RECENT_ACCESSES_TRIM: usize = 50;

/// Number of transition sources above which the table is pruned.
const MAX_TRANSITION_SOURCES: usize = 2000;
/// Number of transition sources kept by a prune. Being lower than
/// `MAX_TRANSITION_SOURCES` means a prune is not re-triggered by the very next access.
const PRUNED_TRANSITION_SOURCES: usize = 1500;
/// Number of outgoing transitions kept per source by a prune.
const MAX_TRANSITIONS_PER_SOURCE: usize = 10;

/// Multiplier applied to a transition weight when its prediction was correct.
const HIT_REWARD: f64 = 1.2;
/// Upper bound for rewarded weights, so repeated rewards can never overflow to infinity
/// (which would turn the derived probabilities into NaN).
const MAX_TRANSITION_WEIGHT: f64 = 1.0e12;

/// Contribution factor of transitions leaving the current file.
const CURRENT_FILE_SIGNAL: f64 = 0.6;
/// Contribution factor of transitions leaving each considered active file.
const ACTIVE_FILE_SIGNAL: f64 = 0.3;
/// Number of leading entries of `active_hashes` used as temporal context.
const MAX_CONTEXT_FILES: usize = 5;
/// Number of most recent accesses used by the cold-start fallback.
const COLD_START_WINDOW: usize = 10;
/// Score added per occurrence of a file in the cold-start window.
const COLD_START_SIGNAL: f64 = 0.1;

/// Minimum number of reported predictions before prefetching is enabled.
const MIN_REPORTS_FOR_PREFETCH: u64 = 10;
/// Accuracy that must be exceeded before prefetching is enabled.
const MIN_ACCURACY_FOR_PREFETCH: f64 = 0.2;

/// Learns file-to-file access transitions, predicts upcoming accesses, and tracks how
/// accurate those predictions turned out to be.
///
/// Files are identified by `u64` path hashes supplied by the caller.
#[derive(Debug, Clone)]
pub struct PrefetchModel {
    /// Transition weights: for each source file, the files observed shortly after it and
    /// an accumulated weight. Weights are normalized into probabilities in `predict`.
    transitions: HashMap<u64, Vec<(u64, f64)>>,
    /// Number of predictions reported through `report_hit`.
    predictions_made: u64,
    /// Number of reported predictions that were actually accessed.
    predictions_hit: u64,
    /// Recent access sequence (oldest first) used for learning and the cold-start fallback.
    recent_accesses: Vec<u64>,
}

impl Default for PrefetchModel {
    fn default() -> Self {
        Self::new()
    }
}

impl PrefetchModel {
    /// Creates an empty model with no learned transitions and no accuracy history.
    pub fn new() -> Self {
        Self {
            transitions: HashMap::with_capacity(128),
            predictions_made: 0,
            predictions_hit: 0,
            recent_accesses: Vec::with_capacity(64),
        }
    }

    /// Record a file access and learn transition patterns.
    ///
    /// Each of the last `LEARNING_WINDOW` accesses gets its transition to `path_hash`
    /// strengthened by 1.0. A file that occurs several times in that window is credited
    /// once per occurrence, and a transition from `path_hash` to itself is recorded like
    /// any other.
    pub fn observe(&mut self, path_hash: u64) {
        let window_start = self.recent_accesses.len().saturating_sub(LEARNING_WINDOW);
        for &prev in &self.recent_accesses[window_start..] {
            let targets = self.transitions.entry(prev).or_default();
            if let Some(pair) = targets.iter_mut().find(|(h, _)| *h == path_hash) {
                pair.1 += 1.0;
            } else {
                targets.push((path_hash, 1.0));
            }
        }

        self.recent_accesses.push(path_hash);
        if self.recent_accesses.len() > MAX_RECENT_ACCESSES {
            self.recent_accesses.drain(..RECENT_ACCESSES_TRIM);
        }

        if self.transitions.len() > MAX_TRANSITION_SOURCES {
            self.prune_transitions();
        }
    }

    /// Predict which files will be accessed next, based on current state.
    ///
    /// Combines three signals:
    /// 1. transitions leaving `current_hash` (factor 0.6),
    /// 2. transitions leaving the first `MAX_CONTEXT_FILES` entries of `active_hashes`
    ///    (factor 0.3 each),
    /// 3. if neither produced any candidate, the most recent accesses (0.1 per occurrence).
    ///
    /// `current_hash` and every entry of `active_hashes` are excluded from the result.
    ///
    /// Returns at most `MAX_PREFETCH` `(path_hash, confidence)` pairs with confidence of at
    /// least `MIN_CONFIDENCE`, sorted by confidence descending; ties are ordered by
    /// ascending hash so the result is deterministic. The confidence is an additive score
    /// and can exceed 1.0 when several signals agree.
    pub fn predict(&self, current_hash: u64, active_hashes: &[u64]) -> Vec<(u64, f64)> {
        let mut candidates: HashMap<u64, f64> = HashMap::new();

        // Signal 1: direct transitions from the current file.
        self.add_transition_signal(current_hash, CURRENT_FILE_SIGNAL, &mut candidates);

        // Signal 2: transitions from recently active files (temporal context).
        for &active in active_hashes.iter().take(MAX_CONTEXT_FILES) {
            self.add_transition_signal(active, ACTIVE_FILE_SIGNAL, &mut candidates);
        }

        // Signal 3: recent accesses as a fallback when nothing has been learned for the
        // current context (cold start).
        if candidates.is_empty() {
            for &recent in self.recent_accesses.iter().rev().take(COLD_START_WINDOW) {
                *candidates.entry(recent).or_insert(0.0) += COLD_START_SIGNAL;
            }
        }

        let active_set: std::collections::HashSet<u64> = active_hashes.iter().copied().collect();

        // Filter before ranking: a NaN score fails the `>=` test and is dropped here, so it
        // can never occupy one of the top-k slots.
        let mut ranked: Vec<(u64, f64)> = candidates
            .into_iter()
            .filter(|&(hash, confidence)| {
                hash != current_hash
                    && !active_set.contains(&hash)
                    && confidence >= MIN_CONFIDENCE
            })
            .collect();

        // The hash tie-break makes both the order and the top-k cut independent of the
        // hash map's iteration order.
        ranked.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked.truncate(MAX_PREFETCH);
        ranked
    }

    /// Report whether a predicted file was actually accessed (feedback loop).
    ///
    /// Every call counts as one prediction made. On a hit, the transition from the most
    /// recently observed file to `predicted_hash` is multiplied by `HIT_REWARD` (capped at
    /// `MAX_TRANSITION_WEIGHT`), provided such a transition has been learned.
    ///
    /// NOTE(review): the rewarded source is whatever `observe` saw last, so the effect
    /// depends on whether callers report before or after observing the predicted file —
    /// confirm against the call sites.
    pub fn report_hit(&mut self, predicted_hash: u64, was_accessed: bool) {
        self.predictions_made += 1;
        if !was_accessed {
            return;
        }
        self.predictions_hit += 1;

        let rewarded = self
            .recent_accesses
            .last()
            .and_then(|last| self.transitions.get_mut(last))
            .and_then(|targets| targets.iter_mut().find(|(h, _)| *h == predicted_hash));
        if let Some(pair) = rewarded {
            pair.1 = (pair.1 * HIT_REWARD).min(MAX_TRANSITION_WEIGHT);
        }
    }

    /// Current prediction accuracy (0.0 - 1.0); 0.0 while nothing has been reported.
    pub fn accuracy(&self) -> f64 {
        if self.predictions_made == 0 {
            return 0.0;
        }
        self.predictions_hit as f64 / self.predictions_made as f64
    }

    /// Free Energy = surprise metric, computed as `1.0 - accuracy()`.
    /// High value means predictions are poor.
    pub fn free_energy(&self) -> f64 {
        1.0 - self.accuracy()
    }

    /// Should we actively prefetch? Only when at least `MIN_REPORTS_FOR_PREFETCH`
    /// predictions have been reported and accuracy exceeds `MIN_ACCURACY_FOR_PREFETCH`.
    pub fn should_prefetch(&self) -> bool {
        self.predictions_made >= MIN_REPORTS_FOR_PREFETCH
            && self.accuracy() > MIN_ACCURACY_FOR_PREFETCH
    }

    /// Adds `signal * P(target | source)` to `candidates` for every transition leaving
    /// `source`, where the probability is the transition weight over the source's total.
    fn add_transition_signal(&self, source: u64, signal: f64, candidates: &mut HashMap<u64, f64>) {
        if let Some(targets) = self.transitions.get(&source) {
            let total: f64 = targets.iter().map(|&(_, weight)| weight).sum();
            // A non-positive or non-finite total cannot be normalized into probabilities.
            if total > 0.0 && total.is_finite() {
                for &(target, weight) in targets {
                    *candidates.entry(target).or_insert(0.0) += weight / total * signal;
                }
            }
        }
    }

    /// Shrinks the transition table: keeps the strongest `MAX_TRANSITIONS_PER_SOURCE`
    /// transitions per source, then keeps only the `PRUNED_TRANSITION_SOURCES` sources
    /// with the largest total weight. Ties are broken by hash for determinism.
    fn prune_transitions(&mut self) {
        for targets in self.transitions.values_mut() {
            if targets.len() > MAX_TRANSITIONS_PER_SOURCE {
                targets.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
                targets.truncate(MAX_TRANSITIONS_PER_SOURCE);
            }
        }
        self.transitions.retain(|_, targets| !targets.is_empty());

        // Truncating per-source lists never removes a source, so without this step the
        // table would grow without bound and every access would trigger a full prune.
        if self.transitions.len() > PRUNED_TRANSITION_SOURCES {
            let mut strengths: Vec<(u64, f64)> = self
                .transitions
                .iter()
                .map(|(&source, targets)| {
                    (source, targets.iter().map(|&(_, weight)| weight).sum::<f64>())
                })
                .collect();
            strengths.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
            for &(source, _) in strengths.iter().skip(PRUNED_TRANSITION_SOURCES) {
                self.transitions.remove(&source);
            }
        }
    }
}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// A repeated A → B access pattern must make B a prediction for A.
    #[test]
    fn model_learns_transitions() {
        let mut model = PrefetchModel::new();
        let a = 1u64;
        let b = 2u64;

        // Repeated strong pattern: A → B (30 times builds high weight)
        for _ in 0..30 {
            model.observe(a);
            model.observe(b);
        }

        // After observing A, should predict B with high confidence
        let predictions = model.predict(a, &[]);
        assert!(
            !predictions.is_empty(),
            "Expected predictions after 30 A→B transitions"
        );
        assert!(
            predictions.iter().any(|(h, _)| *h == b),
            "Expected B in predictions, got: {predictions:?}"
        );
        assert!(
            predictions.iter().all(|(h, _)| *h != a),
            "The current file must never be predicted, got: {predictions:?}"
        );
    }

    /// A fresh model has neither transitions nor access history, so it has nothing to
    /// predict. (Checking only `conf >= MIN_CONFIDENCE` would be vacuous: `predict`
    /// already filters on that threshold.)
    #[test]
    fn empty_model_returns_no_predictions_above_threshold() {
        let model = PrefetchModel::new();
        let predictions = model.predict(42, &[]);
        assert!(
            predictions.is_empty(),
            "Expected no predictions from a fresh model, got: {predictions:?}"
        );
    }

    /// Accuracy is hits divided by reported predictions.
    #[test]
    fn accuracy_tracking() {
        let mut model = PrefetchModel::new();
        assert_eq!(model.accuracy(), 0.0);
        model.report_hit(1, true);
        model.report_hit(2, true);
        model.report_hit(3, false);
        assert!((model.accuracy() - 2.0 / 3.0).abs() < 1e-12);
    }

    /// Free energy starts at its maximum and drops as predictions are confirmed.
    #[test]
    fn free_energy_decreases_with_accuracy() {
        let mut model = PrefetchModel::new();
        assert_eq!(model.free_energy(), 1.0);
        for i in 0..20 {
            model.report_hit(i, true);
        }
        assert!(model.free_energy() < 0.1);
    }

    /// Prefetching is gated on having enough reported predictions with decent accuracy.
    #[test]
    fn should_prefetch_requires_history() {
        let mut model = PrefetchModel::new();
        assert!(!model.should_prefetch());
        for i in 0..20 {
            model.report_hit(i, true);
        }
        assert!(model.should_prefetch());
    }
}