// lean_ctx/core/predictive_prefetch.rs
use std::collections::HashMap;
/// Maximum number of candidates returned by [`PrefetchModel::predict`].
const MAX_PREFETCH: usize = 5;

/// Minimum score a candidate must reach to be returned by [`PrefetchModel::predict`].
const MIN_CONFIDENCE: f64 = 0.3;
/// Learns which hashes tend to be accessed after which others, and uses those
/// learned transitions to suggest hashes worth prefetching.
///
/// Hashes are opaque `u64` values; presumably hashes of file paths (going by the
/// `path_hash` parameter name) — confirm against callers.
#[derive(Debug, Clone)]
pub struct PrefetchModel {
    /// Source hash -> `(target hash, weight)` pairs. A weight grows by 1.0 each
    /// time the target is accessed shortly after the source, and is boosted when
    /// a prediction of that target is reported as a hit.
    transitions: HashMap<u64, Vec<(u64, f64)>>,
    /// Number of predictions reported through `report_hit`.
    predictions_made: u64,
    /// Number of reported predictions that were actually accessed.
    predictions_hit: u64,
    /// Access history, oldest first; trimmed from the front to bound its size.
    recent_accesses: Vec<u64>,
}
30impl Default for PrefetchModel {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl PrefetchModel {
37 pub fn new() -> Self {
38 Self {
39 transitions: HashMap::with_capacity(128),
40 predictions_made: 0,
41 predictions_hit: 0,
42 recent_accesses: Vec::with_capacity(64),
43 }
44 }
45
46 pub fn observe(&mut self, path_hash: u64) {
48 let window = self.recent_accesses.len().min(3);
50 if window > 0 {
51 for &prev in &self.recent_accesses[self.recent_accesses.len() - window..] {
52 let entry = self.transitions.entry(prev).or_default();
53 if let Some(pair) = entry.iter_mut().find(|(h, _)| *h == path_hash) {
54 pair.1 += 1.0;
55 } else {
56 entry.push((path_hash, 1.0));
57 }
58 }
59 }
60
61 self.recent_accesses.push(path_hash);
62 if self.recent_accesses.len() > 100 {
63 self.recent_accesses.drain(..50);
64 }
65
66 if self.transitions.len() > 2000 {
68 self.prune_transitions();
69 }
70 }
71
72 pub fn predict(&self, current_hash: u64, active_hashes: &[u64]) -> Vec<(u64, f64)> {
75 let mut candidates: HashMap<u64, f64> = HashMap::new();
76
77 if let Some(transitions) = self.transitions.get(¤t_hash) {
79 let total: f64 = transitions.iter().map(|(_, w)| w).sum();
80 if total > 0.0 {
81 for &(target, weight) in transitions {
82 let prob = weight / total;
83 *candidates.entry(target).or_insert(0.0) += prob * 0.6;
84 }
85 }
86 }
87
88 for &active in active_hashes.iter().take(5) {
90 if let Some(transitions) = self.transitions.get(&active) {
91 let total: f64 = transitions.iter().map(|(_, w)| w).sum();
92 if total > 0.0 {
93 for &(target, weight) in transitions {
94 let prob = weight / total;
95 *candidates.entry(target).or_insert(0.0) += prob * 0.3;
96 }
97 }
98 }
99 }
100
101 if candidates.is_empty() {
103 let last_n: Vec<u64> = self
104 .recent_accesses
105 .iter()
106 .rev()
107 .take(10)
108 .copied()
109 .collect();
110 for &h in &last_n {
111 *candidates.entry(h).or_insert(0.0) += 0.1;
112 }
113 }
114
115 let active_set: std::collections::HashSet<u64> = active_hashes.iter().copied().collect();
117 candidates.retain(|h, _| !active_set.contains(h) && *h != current_hash);
118
119 let mut sorted: Vec<(u64, f64)> = candidates.into_iter().collect();
121 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
122 sorted.truncate(MAX_PREFETCH);
123
124 sorted.retain(|(_, conf)| *conf >= MIN_CONFIDENCE);
126 sorted
127 }
128
129 pub fn report_hit(&mut self, predicted_hash: u64, was_accessed: bool) {
131 self.predictions_made += 1;
132 if was_accessed {
133 self.predictions_hit += 1;
134
135 if let Some(&last) = self.recent_accesses.last() {
137 if let Some(transitions) = self.transitions.get_mut(&last) {
138 if let Some(pair) = transitions.iter_mut().find(|(h, _)| *h == predicted_hash) {
139 pair.1 *= 1.2; }
141 }
142 }
143 }
144 }
145
146 pub fn accuracy(&self) -> f64 {
148 if self.predictions_made == 0 {
149 return 0.0;
150 }
151 self.predictions_hit as f64 / self.predictions_made as f64
152 }
153
154 pub fn free_energy(&self) -> f64 {
156 1.0 - self.accuracy()
157 }
158
159 pub fn should_prefetch(&self) -> bool {
162 self.predictions_made >= 10 && self.accuracy() > 0.2
163 }
164
165 fn prune_transitions(&mut self) {
166 for transitions in self.transitions.values_mut() {
168 if transitions.len() > 10 {
169 transitions
170 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
171 transitions.truncate(10);
172 }
173 }
174 self.transitions.retain(|_, v| !v.is_empty());
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_learns_transitions() {
        let mut model = PrefetchModel::new();
        let a = 1u64;
        let b = 2u64;

        // Alternate A, B so that A -> B becomes the dominant learned transition.
        for _ in 0..30 {
            model.observe(a);
            model.observe(b);
        }

        let predictions = model.predict(a, &[]);
        assert!(
            !predictions.is_empty(),
            "Expected predictions after 30 A→B transitions"
        );
        assert!(
            predictions.iter().any(|(h, _)| *h == b),
            "Expected B in predictions, got: {predictions:?}"
        );
    }

    #[test]
    fn empty_model_returns_no_predictions_above_threshold() {
        let model = PrefetchModel::new();
        let predictions = model.predict(42, &[]);
        // An untrained model has neither transitions nor history to fall back on,
        // so it must return nothing at all (checking only the threshold would
        // pass vacuously on any filtered output).
        assert!(
            predictions.is_empty(),
            "Expected no predictions from an untrained model, got: {predictions:?}"
        );
    }

    #[test]
    fn predictions_respect_filters_and_ordering() {
        let mut model = PrefetchModel::new();
        for _ in 0..30 {
            model.observe(1);
            model.observe(2);
            model.observe(3);
        }

        let predictions = model.predict(1, &[2]);
        assert!(predictions.len() <= MAX_PREFETCH);
        assert!(
            predictions
                .iter()
                .all(|&(h, conf)| h != 1 && h != 2 && conf >= MIN_CONFIDENCE),
            "Current/active hashes or sub-threshold scores leaked: {predictions:?}"
        );
        assert!(predictions.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn accuracy_tracking() {
        let mut model = PrefetchModel::new();
        assert!(model.accuracy().abs() < 1e-12);
        model.report_hit(1, true);
        model.report_hit(2, true);
        model.report_hit(3, false);
        assert!((model.accuracy() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn free_energy_decreases_with_accuracy() {
        let mut model = PrefetchModel::new();
        for i in 0..20 {
            model.report_hit(i, true);
        }
        assert!(model.free_energy() < 0.1);
    }

    #[test]
    fn should_prefetch_needs_enough_history() {
        let mut model = PrefetchModel::new();
        for i in 0..9 {
            model.report_hit(i, true);
        }
        assert!(!model.should_prefetch());
        model.report_hit(9, true);
        assert!(model.should_prefetch());
    }
}