Skip to main content

lean_ctx/core/
cooccurrence.rs

1//! Hebbian file co-access graph — "files that fire together, wire together".
2//!
3//! ## The idea (neuroscience → retrieval)
4//!
5//! Hebbian theory: synapses between co-active neurons strengthen (long-term
6//! potentiation, LTP), while unused ones weaken (long-term depression / the
7//! Ebbinghaus forgetting curve). We apply the same rule to files: whenever a
8//! task surfaces a set of files *together*, we strengthen the association
9//! between every pair; on each update all weights decay slightly, so stale
10//! associations fade. Over time the graph learns the project's real working
11//! paths — which the static import/AST graph cannot capture.
12//!
13//! The learned association is an additive retrieval signal: given a file the
14//! agent is looking at, [`related`] returns the files history says tend to be
15//! touched alongside it.
16//!
17//! The store is a small per-project JSON file; reads/writes are whole-file and
18//! cheap because the graph is pruned aggressively (decay + min-weight + caps).
19
20use std::collections::HashMap;
21use std::path::PathBuf;
22
23use serde::{Deserialize, Serialize};
24
// ── Weight dynamics ────────────────────────────────────────────────────────

/// Amount added to an edge each time its two files are recorded together
/// (the LTP step).
const LTP_INCREMENT: f64 = 1.0;
/// Factor every edge weight is multiplied by on each effective `record` — the
/// forgetting curve. At 0.98 an edge that is never reinforced loses about half
/// its weight every ~34 updates.
const DECAY: f64 = 0.98;
/// Pruning threshold: edges whose weight falls below this are dropped, which
/// keeps the graph small and relevant.
const MIN_WEIGHT: f64 = 0.08;

// ── Size bounds ────────────────────────────────────────────────────────────

/// Upper bound on the number of distinct files a single `record` associates;
/// later entries are ignored so one huge file set cannot cause an O(n²)
/// blow-up.
const MAX_RECORD_FILES: usize = 24;
/// Upper bound on neighbours stored per file; only the strongest are retained.
const MAX_NEIGHBORS: usize = 32;
/// Upper bound on files with an adjacency entry. New files are still recorded
/// past this point, but pruning then evicts the lowest-degree files until the
/// bound holds again.
const MAX_FILES: usize = 5_000;
40
41/// Persistent, decaying co-access graph for one project.
42#[derive(Debug, Default, Clone, Serialize, Deserialize)]
43pub struct CoAccessGraph {
44    /// `file → (neighbour → weight)`. Symmetric by construction.
45    edges: HashMap<String, HashMap<String, f64>>,
46}
47
48impl CoAccessGraph {
49    /// Reinforce the mutual association of every pair in `files` (LTP) after
50    /// decaying the whole graph one step (global forgetting). Self-pairs and
51    /// duplicates are ignored. Bounded work: at most `MAX_RECORD_FILES²` pairs.
52    pub fn record(&mut self, files: &[String]) {
53        // Distinct, capped input.
54        let mut uniq: Vec<&String> = Vec::new();
55        for f in files {
56            if !f.is_empty() && !uniq.contains(&f) {
57                uniq.push(f);
58                if uniq.len() >= MAX_RECORD_FILES {
59                    break;
60                }
61            }
62        }
63        if uniq.len() < 2 {
64            return; // nothing to associate
65        }
66
67        self.decay_all();
68
69        for i in 0..uniq.len() {
70            for j in (i + 1)..uniq.len() {
71                self.bump(uniq[i], uniq[j]);
72                self.bump(uniq[j], uniq[i]);
73            }
74        }
75
76        self.prune();
77    }
78
79    /// Files most strongly associated with `file`, strongest first.
80    pub fn related(&self, file: &str, top_k: usize) -> Vec<(String, f64)> {
81        let Some(neighbours) = self.edges.get(file) else {
82            return Vec::new();
83        };
84        let mut v: Vec<(String, f64)> = neighbours.iter().map(|(k, &w)| (k.clone(), w)).collect();
85        v.sort_by(|a, b| b.1.total_cmp(&a.1));
86        v.truncate(top_k);
87        v
88    }
89
90    fn bump(&mut self, from: &str, to: &str) {
91        let entry = self.edges.entry(from.to_string()).or_default();
92        *entry.entry(to.to_string()).or_insert(0.0) += LTP_INCREMENT;
93    }
94
95    fn decay_all(&mut self) {
96        for neighbours in self.edges.values_mut() {
97            for w in neighbours.values_mut() {
98                *w *= DECAY;
99            }
100        }
101    }
102
103    fn prune(&mut self) {
104        for neighbours in self.edges.values_mut() {
105            neighbours.retain(|_, &mut w| w >= MIN_WEIGHT);
106            if neighbours.len() > MAX_NEIGHBORS {
107                let mut kept: Vec<(String, f64)> =
108                    neighbours.iter().map(|(k, &w)| (k.clone(), w)).collect();
109                kept.sort_by(|a, b| b.1.total_cmp(&a.1));
110                kept.truncate(MAX_NEIGHBORS);
111                *neighbours = kept.into_iter().collect();
112            }
113        }
114        self.edges.retain(|_, neighbours| !neighbours.is_empty());
115
116        if self.edges.len() > MAX_FILES {
117            // Evict the lowest-degree files (least-connected memories).
118            let mut by_degree: Vec<(String, usize)> = self
119                .edges
120                .iter()
121                .map(|(k, n)| (k.clone(), n.len()))
122                .collect();
123            by_degree.sort_by_key(|(_, d)| *d);
124            let evict = self.edges.len() - MAX_FILES;
125            for (file, _) in by_degree.into_iter().take(evict) {
126                self.edges.remove(&file);
127            }
128        }
129    }
130}
131
132// ── Persistence (one small JSON file per project) ──────────────────────────
133
134fn store_path(project_root: &str) -> Option<PathBuf> {
135    let normalized = crate::core::graph_index::normalize_project_root(project_root);
136    let hash = crate::core::project_hash::hash_project_root(&normalized);
137    crate::core::data_dir::lean_ctx_data_dir()
138        .ok()
139        .map(|d| d.join("cooccurrence").join(format!("{hash}.json")))
140}
141
142/// Load the co-access graph for `project_root` (empty if none / unreadable).
143pub fn load(project_root: &str) -> CoAccessGraph {
144    let Some(path) = store_path(project_root) else {
145        return CoAccessGraph::default();
146    };
147    std::fs::read_to_string(&path)
148        .ok()
149        .and_then(|s| serde_json::from_str(&s).ok())
150        .unwrap_or_default()
151}
152
153fn save(project_root: &str, graph: &CoAccessGraph) {
154    let Some(path) = store_path(project_root) else {
155        return;
156    };
157    if let Some(parent) = path.parent() {
158        let _ = std::fs::create_dir_all(parent);
159    }
160    if let Ok(json) = serde_json::to_string(graph) {
161        let _ = std::fs::write(&path, json);
162    }
163}
164
165/// Record that `files` were accessed together for one task, persisting the
166/// reinforced graph. No-op for fewer than two distinct files.
167pub fn record_access(project_root: &str, files: &[String]) {
168    if files.len() < 2 {
169        return;
170    }
171    let mut graph = load(project_root);
172    graph.record(files);
173    save(project_root, &graph);
174}
175
176/// Files historically co-accessed with `file`, strongest association first.
177pub fn related(project_root: &str, file: &str, top_k: usize) -> Vec<(String, f64)> {
178    load(project_root).related(file, top_k)
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned file list built from string literals, in the shape `record` takes.
    fn files(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| (*n).to_owned()).collect()
    }

    /// One co-access creates exactly one positive edge.
    #[test]
    fn co_access_strengthens_association() {
        let mut graph = CoAccessGraph::default();
        graph.record(&files(&["a.rs", "b.rs"]));

        let hits = graph.related("a.rs", 5);
        assert_eq!(hits.len(), 1);
        let (name, weight) = &hits[0];
        assert_eq!(name.as_str(), "b.rs");
        assert!(*weight > 0.0);
    }

    /// A pair reinforced five times outranks a pair seen once.
    #[test]
    fn repeated_co_access_outweighs_single() {
        let mut graph = CoAccessGraph::default();
        let frequent = files(&["x.rs", "y.rs"]);
        for _ in 0..5 {
            graph.record(&frequent);
        }
        graph.record(&files(&["x.rs", "z.rs"]));

        let hits = graph.related("x.rs", 5);
        // y was reinforced 5×, z once → y must rank first.
        assert_eq!(hits[0].0, "y.rs");
        let z_weight = hits.iter().find(|(f, _)| f == "z.rs").map(|(_, w)| *w);
        assert!(z_weight.is_some());
        assert!(hits[0].1 > z_weight.unwrap());
    }

    /// An edge that is never reinforced eventually falls below `MIN_WEIGHT`.
    #[test]
    fn weak_associations_are_pruned_by_decay() {
        let mut graph = CoAccessGraph::default();
        graph.record(&files(&["a.rs", "b.rs"]));

        // Hammer an unrelated pair so the a–b edge decays below MIN_WEIGHT.
        let unrelated = files(&["c.rs", "d.rs"]);
        for _ in 0..400 {
            graph.record(&unrelated);
        }

        assert!(
            graph.related("a.rs", 5).is_empty(),
            "decayed association should be pruned"
        );
        assert!(!graph.related("c.rs", 5).is_empty());
    }

    /// A lone file has nothing to pair with.
    #[test]
    fn single_file_record_is_noop() {
        let mut graph = CoAccessGraph::default();
        graph.record(&files(&["lonely.rs"]));
        assert!(graph.related("lonely.rs", 5).is_empty());
    }

    /// Both endpoints of a recorded pair see each other.
    #[test]
    fn association_is_symmetric() {
        let mut graph = CoAccessGraph::default();
        graph.record(&files(&["one.rs", "two.rs"]));

        let from_one = graph.related("one.rs", 5);
        let from_two = graph.related("two.rs", 5);
        assert_eq!(from_one[0].0, "two.rs");
        assert_eq!(from_two[0].0, "one.rs");
    }

    /// The graph survives a JSON round trip.
    #[test]
    fn serializes_round_trip() {
        // Deterministic: exercises the persistence *format* (the on-disk path
        // uses this same serde round-trip) without touching the process-global
        // data-dir env var, which other tests mutate concurrently.
        let mut graph = CoAccessGraph::default();
        graph.record(&files(&["alpha.rs", "beta.rs"]));

        let json = serde_json::to_string(&graph).unwrap();
        let restored: CoAccessGraph = serde_json::from_str(&json).unwrap();

        let hits = restored.related("alpha.rs", 5);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, "beta.rs");
    }

    /// A hub file never keeps more than `MAX_NEIGHBORS` neighbours.
    #[test]
    fn neighbours_are_capped() {
        let mut graph = CoAccessGraph::default();
        // Pair one hub file with many distinct others across separate records
        // so its neighbour set exceeds the cap before pruning.
        for i in 0..(MAX_NEIGHBORS + 20) {
            graph.record(&["hub.rs".to_owned(), format!("f{i}.rs")]);
        }
        let hub_neighbours = graph.related("hub.rs", 1000);
        assert!(hub_neighbours.len() <= MAX_NEIGHBORS);
    }
}