lean_ctx/core/
cooccurrence.rs1use std::collections::HashMap;
21use std::path::PathBuf;
22
23use serde::{Deserialize, Serialize};
24
// --- Edge-weight dynamics ---------------------------------------------------

/// Amount added to an edge (in each direction) every time two files are
/// recorded together; see `CoAccessGraph::bump`.
const LTP_INCREMENT: f64 = 1.0;
/// Factor every stored edge weight is multiplied by at the start of each
/// effective `CoAccessGraph::record` call; see `CoAccessGraph::decay_all`.
const DECAY: f64 = 0.98;
/// Edges whose weight has fallen below this value are dropped by
/// `CoAccessGraph::prune`.
const MIN_WEIGHT: f64 = 0.08;

// --- Size limits ------------------------------------------------------------

/// Upper bound on the number of distinct files taken from a single
/// `CoAccessGraph::record` call; later entries are ignored.
const MAX_RECORD_FILES: usize = 24;
/// Upper bound on the number of neighbours kept per file after pruning.
const MAX_NEIGHBORS: usize = 32;
/// Upper bound on the number of files that keep an adjacency entry after
/// pruning.
const MAX_FILES: usize = 5_000;
40
41#[derive(Debug, Default, Clone, Serialize, Deserialize)]
43pub struct CoAccessGraph {
44 edges: HashMap<String, HashMap<String, f64>>,
46}
47
48impl CoAccessGraph {
49 pub fn record(&mut self, files: &[String]) {
53 let mut uniq: Vec<&String> = Vec::new();
55 for f in files {
56 if !f.is_empty() && !uniq.contains(&f) {
57 uniq.push(f);
58 if uniq.len() >= MAX_RECORD_FILES {
59 break;
60 }
61 }
62 }
63 if uniq.len() < 2 {
64 return; }
66
67 self.decay_all();
68
69 for i in 0..uniq.len() {
70 for j in (i + 1)..uniq.len() {
71 self.bump(uniq[i], uniq[j]);
72 self.bump(uniq[j], uniq[i]);
73 }
74 }
75
76 self.prune();
77 }
78
79 pub fn related(&self, file: &str, top_k: usize) -> Vec<(String, f64)> {
81 let Some(neighbours) = self.edges.get(file) else {
82 return Vec::new();
83 };
84 let mut v: Vec<(String, f64)> = neighbours.iter().map(|(k, &w)| (k.clone(), w)).collect();
85 v.sort_by(|a, b| b.1.total_cmp(&a.1));
86 v.truncate(top_k);
87 v
88 }
89
90 fn bump(&mut self, from: &str, to: &str) {
91 let entry = self.edges.entry(from.to_string()).or_default();
92 *entry.entry(to.to_string()).or_insert(0.0) += LTP_INCREMENT;
93 }
94
95 fn decay_all(&mut self) {
96 for neighbours in self.edges.values_mut() {
97 for w in neighbours.values_mut() {
98 *w *= DECAY;
99 }
100 }
101 }
102
103 fn prune(&mut self) {
104 for neighbours in self.edges.values_mut() {
105 neighbours.retain(|_, &mut w| w >= MIN_WEIGHT);
106 if neighbours.len() > MAX_NEIGHBORS {
107 let mut kept: Vec<(String, f64)> =
108 neighbours.iter().map(|(k, &w)| (k.clone(), w)).collect();
109 kept.sort_by(|a, b| b.1.total_cmp(&a.1));
110 kept.truncate(MAX_NEIGHBORS);
111 *neighbours = kept.into_iter().collect();
112 }
113 }
114 self.edges.retain(|_, neighbours| !neighbours.is_empty());
115
116 if self.edges.len() > MAX_FILES {
117 let mut by_degree: Vec<(String, usize)> = self
119 .edges
120 .iter()
121 .map(|(k, n)| (k.clone(), n.len()))
122 .collect();
123 by_degree.sort_by_key(|(_, d)| *d);
124 let evict = self.edges.len() - MAX_FILES;
125 for (file, _) in by_degree.into_iter().take(evict) {
126 self.edges.remove(&file);
127 }
128 }
129 }
130}
131
132fn store_path(project_root: &str) -> Option<PathBuf> {
135 let normalized = crate::core::graph_index::normalize_project_root(project_root);
136 let hash = crate::core::project_hash::hash_project_root(&normalized);
137 crate::core::data_dir::lean_ctx_data_dir()
138 .ok()
139 .map(|d| d.join("cooccurrence").join(format!("{hash}.json")))
140}
141
142pub fn load(project_root: &str) -> CoAccessGraph {
144 let Some(path) = store_path(project_root) else {
145 return CoAccessGraph::default();
146 };
147 std::fs::read_to_string(&path)
148 .ok()
149 .and_then(|s| serde_json::from_str(&s).ok())
150 .unwrap_or_default()
151}
152
153fn save(project_root: &str, graph: &CoAccessGraph) {
154 let Some(path) = store_path(project_root) else {
155 return;
156 };
157 if let Some(parent) = path.parent() {
158 let _ = std::fs::create_dir_all(parent);
159 }
160 if let Ok(json) = serde_json::to_string(graph) {
161 let _ = std::fs::write(&path, json);
162 }
163}
164
165pub fn record_access(project_root: &str, files: &[String]) {
168 if files.len() < 2 {
169 return;
170 }
171 let mut graph = load(project_root);
172 graph.record(files);
173 save(project_root, &graph);
174}
175
176pub fn related(project_root: &str, file: &str, top_k: usize) -> Vec<(String, f64)> {
178 load(project_root).related(file, top_k)
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned paths `record` expects.
    fn paths(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// One co-access creates a single positive-weight edge.
    #[test]
    fn co_access_strengthens_association() {
        let mut graph = CoAccessGraph::default();
        graph.record(&paths(&["a.rs", "b.rs"]));

        let neighbours = graph.related("a.rs", 5);
        assert_eq!(neighbours.len(), 1);
        let (name, weight) = &neighbours[0];
        assert_eq!(name.as_str(), "b.rs");
        assert!(*weight > 0.0);
    }

    /// A pair seen five times ranks above a pair seen once.
    #[test]
    fn repeated_co_access_outweighs_single() {
        let mut graph = CoAccessGraph::default();
        let frequent = paths(&["x.rs", "y.rs"]);
        for _round in 0..5 {
            graph.record(&frequent);
        }
        graph.record(&paths(&["x.rs", "z.rs"]));

        let ranked = graph.related("x.rs", 5);
        assert_eq!(ranked[0].0, "y.rs");
        let rare = ranked.iter().find(|(name, _)| name == "z.rs");
        assert!(rare.is_some());
        assert!(ranked[0].1 > rare.unwrap().1);
    }

    /// An edge that is never refreshed decays below the floor and vanishes.
    #[test]
    fn weak_associations_are_pruned_by_decay() {
        let mut graph = CoAccessGraph::default();
        graph.record(&paths(&["a.rs", "b.rs"]));

        let unrelated = paths(&["c.rs", "d.rs"]);
        for _round in 0..400 {
            graph.record(&unrelated);
        }

        let stale = graph.related("a.rs", 5);
        assert!(stale.is_empty(), "decayed association should be pruned");
        let fresh = graph.related("c.rs", 5);
        assert!(!fresh.is_empty());
    }

    /// Recording a single file leaves the graph untouched.
    #[test]
    fn single_file_record_is_noop() {
        let mut graph = CoAccessGraph::default();
        graph.record(&paths(&["lonely.rs"]));
        assert!(graph.related("lonely.rs", 5).is_empty());
    }

    /// Each file of a recorded pair lists the other as its neighbour.
    #[test]
    fn association_is_symmetric() {
        let mut graph = CoAccessGraph::default();
        graph.record(&paths(&["one.rs", "two.rs"]));

        let from_one = graph.related("one.rs", 5);
        let from_two = graph.related("two.rs", 5);
        assert_eq!(from_one[0].0, "two.rs");
        assert_eq!(from_two[0].0, "one.rs");
    }

    /// The graph survives a JSON serialize/deserialize round trip.
    #[test]
    fn serializes_round_trip() {
        let mut graph = CoAccessGraph::default();
        graph.record(&paths(&["alpha.rs", "beta.rs"]));

        let encoded = serde_json::to_string(&graph).unwrap();
        let decoded: CoAccessGraph = serde_json::from_str(&encoded).unwrap();

        let neighbours = decoded.related("alpha.rs", 5);
        assert_eq!(neighbours.len(), 1);
        assert_eq!(neighbours[0].0, "beta.rs");
    }

    /// A hub file never keeps more than `MAX_NEIGHBORS` neighbours.
    #[test]
    fn neighbours_are_capped() {
        let mut graph = CoAccessGraph::default();
        for idx in 0..(MAX_NEIGHBORS + 20) {
            let spoke = format!("f{idx}.rs");
            graph.record(&["hub.rs".to_string(), spoke]);
        }

        let degree = graph.related("hub.rs", 1000).len();
        assert!(degree <= MAX_NEIGHBORS);
    }
}