Skip to main content

lean_ctx/core/
hebbian_cache.rs

1//! Hebbian Co-Access Cache with Boltzmann-Temperature Eviction.
2//!
3//! Scientific basis:
4//! - Hebb (1949): "Neurons that fire together wire together" — files accessed together
5//!   strengthen their association, making co-accessed files resistant to eviction.
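//! - Boltzmann distribution (statistical physics): P(evict) ∝ exp(-E/kT), where E is the
//!   "value" of a cache entry and T is the memory pressure. In the intended model, low T is
//!   nearly deterministic (only the lowest-value entries are evicted) and high T is
//!   stochastic (which prevents thrashing). Because exp(-E/kT) is monotonic in E, the
//!   eviction ranking computed here is currently deterministic at every temperature.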
9
10use std::collections::HashMap;
11use std::time::Instant;
12
/// Number of tracked co-access pairs above which a prune pass is triggered.
const MAX_ASSOCIATIONS: usize = 10 * 1_000;
/// Half-life, in seconds, of the exponential decay applied to Hebbian weights (five minutes).
const DECAY_HALF_LIFE_SECS: f64 = 5.0 * 60.0;
/// Decayed weight below which an association is considered dead and is pruned.
const PRUNE_THRESHOLD: f32 = 1e-2;
19
/// Tracks co-access patterns between files (Hebbian learning).
///
/// Pair weights are stored sparsely, keyed by an order-normalized pair of path
/// hashes, so `(a, b)` and `(b, a)` share a single entry.
#[derive(Debug)]
pub struct CoAccessMatrix {
    /// Sparse co-access weights: (path_hash_a, path_hash_b) → weight.
    weights: HashMap<(u64, u64), f32>,
    /// When each pair was last strengthened (drives the time decay of `weights`).
    timestamps: HashMap<(u64, u64), Instant>,
    /// Path hashes recorded in the current access burst (same tool-call window).
    current_burst: Vec<u64>,
    /// Start of the current burst window.
    burst_start: Instant,
}
30
31impl Default for CoAccessMatrix {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl CoAccessMatrix {
38    pub fn new() -> Self {
39        Self {
40            weights: HashMap::with_capacity(256),
41            timestamps: HashMap::with_capacity(256),
42            current_burst: Vec::with_capacity(8),
43            burst_start: Instant::now(),
44        }
45    }
46
47    /// Record a file access. If within the burst window (500ms), strengthens
48    /// associations with other files in the same burst.
49    pub fn record_access(&mut self, path_hash: u64) {
50        let now = Instant::now();
51        let burst_window = std::time::Duration::from_millis(500);
52
53        if now.duration_since(self.burst_start) > burst_window {
54            self.flush_burst();
55            self.burst_start = now;
56        }
57
58        self.current_burst.push(path_hash);
59    }
60
61    /// Flush current burst: strengthen all pairwise associations.
62    fn flush_burst(&mut self) {
63        if self.current_burst.len() < 2 {
64            self.current_burst.clear();
65            return;
66        }
67
68        let now = Instant::now();
69        let burst = std::mem::take(&mut self.current_burst);
70
71        for i in 0..burst.len() {
72            for j in (i + 1)..burst.len() {
73                let key = normalized_key(burst[i], burst[j]);
74                let w = self.weights.entry(key).or_insert(0.0);
75                *w += 1.0;
76                self.timestamps.insert(key, now);
77            }
78        }
79
80        if self.weights.len() > MAX_ASSOCIATIONS {
81            self.prune();
82        }
83    }
84
85    /// Get the association strength of a file with all currently active files.
86    /// Applies exponential decay based on elapsed time.
87    pub fn association_strength(&self, path_hash: u64, active_hashes: &[u64]) -> f32 {
88        let now = Instant::now();
89        let mut total = 0.0f32;
90
91        for &active in active_hashes {
92            let key = normalized_key(path_hash, active);
93            if let Some(&weight) = self.weights.get(&key) {
94                let elapsed = self.timestamps.get(&key).map_or(DECAY_HALF_LIFE_SECS, |t| {
95                    now.duration_since(*t).as_secs_f64()
96                });
97                let decay = (-elapsed * (2.0f64.ln()) / DECAY_HALF_LIFE_SECS).exp();
98                total += weight * decay as f32;
99            }
100        }
101
102        total
103    }
104
105    /// Remove weak associations to keep memory bounded.
106    fn prune(&mut self) {
107        let now = Instant::now();
108        self.weights.retain(|key, weight| {
109            let elapsed = self
110                .timestamps
111                .get(key)
112                .map_or(DECAY_HALF_LIFE_SECS * 2.0, |t| {
113                    now.duration_since(*t).as_secs_f64()
114                });
115            let decay = (-elapsed * (2.0f64.ln()) / DECAY_HALF_LIFE_SECS).exp();
116            let effective = *weight * decay as f32;
117            if effective < PRUNE_THRESHOLD {
118                self.timestamps.remove(key);
119                false
120            } else {
121                true
122            }
123        });
124    }
125
126    /// Force flush any pending burst (call at end of tool-call processing).
127    pub fn end_burst(&mut self) {
128        self.flush_burst();
129    }
130}
131
/// Order a pair of hashes so that `(a, b)` and `(b, a)` map to the same key.
///
/// The smaller hash always comes first.
fn normalized_key(a: u64, b: u64) -> (u64, u64) {
    (a.min(b), a.max(b))
}
140
/// Hash a file path to a `u64` key using std's `DefaultHasher`.
///
/// The result is deterministic within a build. Std documents the
/// `DefaultHasher` algorithm as unspecified across Rust releases, so these
/// hashes should not be persisted.
pub fn path_hash(path: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    Hash::hash(path, &mut hasher);
    hasher.finish()
}
148
149// ─── Boltzmann-Temperature Eviction ───────────────────────────────────────────
150
/// Per-entry signals that [`EntryEnergy::compute`] combines into an eviction "energy".
///
/// Higher energy = more valuable = less likely to be evicted.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EntryEnergy {
    /// Read count; contributes with diminishing returns (square root).
    pub read_count: u32,
    /// Presumably seconds since the last access; larger values score as less recent.
    pub recency_secs: f64,
    /// Co-access strength with the active files. Presumably produced by
    /// `CoAccessMatrix::association_strength`; confirm at the call site.
    pub association_strength: f32,
    /// Entry size in tokens; larger entries score lower.
    pub token_size: usize,
    /// Graph centrality score, added to the energy unscaled (range not enforced here).
    pub graph_centrality: f32,
}
160
161impl EntryEnergy {
162    /// Calculate the energy value E for this entry.
163    /// Combines multiple signals into a single scalar.
164    pub fn compute(&self) -> f64 {
165        // Recency contributes with log-decay (recent = high energy)
166        let recency_score = 1.0 / (1.0 + self.recency_secs / 60.0);
167
168        // Read frequency (diminishing returns via sqrt)
169        let freq_score = (self.read_count as f64).sqrt();
170
171        // Association boost (normalized)
172        let assoc_score = (self.association_strength as f64).min(5.0);
173
174        // Size penalty (large entries cost more to keep)
175        let size_penalty = 1.0 / (1.0 + (self.token_size as f64 / 5000.0));
176
177        // Graph centrality bonus
178        let centrality_score = self.graph_centrality as f64;
179
180        // Weighted combination
181        recency_score * 3.0
182            + freq_score * 2.0
183            + assoc_score * 1.5
184            + size_penalty * 1.0
185            + centrality_score * 1.0
186    }
187}
188
/// Boltzmann eviction decision.
///
/// Returns up to `num_to_evict` indices into `energies`, ordered from the most
/// to the least evictable entry (lowest energy first).
///
/// Conceptually each entry is weighted by the Boltzmann factor
/// `P(evict_i) ∝ exp(-E_i / T)`, where `T` is the normalized memory pressure
/// passed as `temperature`. For any positive `T` that factor is strictly
/// decreasing in `E_i`, so ranking by eviction probability is exactly ranking
/// by ascending energy. The ranking is therefore done on the energies directly.
/// Evaluating `exp` would overflow to `inf` for large energy gaps at low
/// temperature and collapse distinct entries into ties.
///
/// The selection is deterministic at every temperature. `temperature` does
/// not influence the result and is kept for API compatibility; stochastic
/// sampling is not implemented.
///
/// NaN energies are treated as the lowest possible value and evicted first.
/// Entries with equal energy keep their original index order.
pub fn boltzmann_select_evictions(
    energies: &[f64],
    num_to_evict: usize,
    temperature: f64,
) -> Vec<usize> {
    // Only the ordering of the Boltzmann weights matters, and that ordering
    // does not depend on the temperature (see above).
    let _ = temperature;

    if energies.is_empty() || num_to_evict == 0 {
        return Vec::new();
    }

    // Map NaN to -inf so the comparator is a total order. A non-total
    // comparator gives arbitrary results and may panic in recent std sorts.
    let rank_key = |i: usize| {
        let e = energies[i];
        if e.is_nan() {
            f64::NEG_INFINITY
        } else {
            e
        }
    };

    let mut order: Vec<usize> = (0..energies.len()).collect();
    // Stable sort: equal energies keep their original relative order.
    order.sort_by(|&a, &b| rank_key(a).total_cmp(&rank_key(b)));
    order.truncate(num_to_evict);
    order
}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn co_access_strengthens_pairs() {
        let mut matrix = CoAccessMatrix::new();
        let [main, lib, config] = ["src/main.rs", "src/lib.rs", "src/config.rs"].map(path_hash);

        // One burst: all three files are read together.
        for hash in [main, lib, config] {
            matrix.record_access(hash);
        }
        matrix.end_burst();

        // `main` is now associated with both of its burst companions.
        for other in [lib, config] {
            assert!(matrix.association_strength(main, &[other]) > 0.0);
        }
    }

    #[test]
    fn unrelated_files_have_zero_association() {
        let matrix = CoAccessMatrix::new();
        let strength =
            matrix.association_strength(path_hash("src/main.rs"), &[path_hash("src/lib.rs")]);
        assert_eq!(strength, 0.0);
    }

    #[test]
    fn boltzmann_low_temp_is_deterministic() {
        let energies = vec![10.0, 1.0, 5.0, 0.5, 8.0];
        let evictions = boltzmann_select_evictions(&energies, 2, 0.01);
        // The two lowest-energy entries: idx 3 (0.5) and idx 1 (1.0).
        for idx in [3, 1] {
            assert!(evictions.contains(&idx));
        }
    }

    #[test]
    fn boltzmann_high_temp_still_picks_n() {
        let energies = vec![10.0, 1.0, 5.0, 0.5, 8.0];
        assert_eq!(boltzmann_select_evictions(&energies, 2, 100.0).len(), 2);
    }

    #[test]
    fn entry_energy_compute_is_sane() {
        let fresh_hot_small = EntryEnergy {
            read_count: 10,
            recency_secs: 5.0,
            association_strength: 3.0,
            token_size: 500,
            graph_centrality: 0.8,
        };
        let stale_cold_large = EntryEnergy {
            read_count: 1,
            recency_secs: 3600.0,
            association_strength: 0.0,
            token_size: 50000,
            graph_centrality: 0.0,
        };
        assert!(fresh_hot_small.compute() > stale_cold_large.compute());
    }

    #[test]
    fn normalized_key_is_symmetric() {
        let (forward, backward) = (normalized_key(42, 99), normalized_key(99, 42));
        assert_eq!(forward, backward);
    }
}