1use std::path::{Path, PathBuf};
2
3use serde::{Deserialize, Serialize};
4
5use anyhow::Result;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LearnedPattern {
12 pub source_file: String,
14 pub call_name: String,
16 pub resolved_to_file: String,
18 pub resolved_to_symbol: String,
20 pub confidence: f32,
22 pub source: String,
24 pub last_updated: String,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, Default)]
33pub struct LearnedStore {
34 pub patterns: Vec<LearnedPattern>,
35}
36
37impl LearnedStore {
38 pub fn load(root: &Path) -> Self {
41 let path = Self::path(root);
42 match std::fs::read_to_string(&path) {
43 Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
44 Err(_) => Self::default(),
45 }
46 }
47
48 pub fn save(&self, root: &Path) -> Result<()> {
50 let path = Self::path(root);
51 std::fs::create_dir_all(path.parent().unwrap())?;
52 let json = serde_json::to_string_pretty(self)?;
53 std::fs::write(&path, json)?;
54 Ok(())
55 }
56
57 fn path(root: &Path) -> PathBuf {
58 root.join(".infigraph")
59 .join("learned")
60 .join("patterns.json")
61 }
62
63 pub fn record_correction(
69 &mut self,
70 source_file: &str,
71 call_name: &str,
72 resolved_to_file: &str,
73 resolved_to_symbol: &str,
74 ) {
75 if let Some(existing) = self
76 .patterns
77 .iter_mut()
78 .find(|p| p.source_file == source_file && p.call_name == call_name)
79 {
80 existing.resolved_to_file = resolved_to_file.to_string();
81 existing.resolved_to_symbol = resolved_to_symbol.to_string();
82 existing.confidence = (existing.confidence + 0.1).min(1.0);
83 existing.last_updated = epoch_now();
84 } else {
85 self.patterns.push(LearnedPattern {
86 source_file: source_file.to_string(),
87 call_name: call_name.to_string(),
88 resolved_to_file: resolved_to_file.to_string(),
89 resolved_to_symbol: resolved_to_symbol.to_string(),
90 confidence: 0.5,
91 source: "scip".to_string(),
92 last_updated: epoch_now(),
93 });
94 }
95 }
96
97 pub fn lookup(&self, source_file: &str, call_name: &str) -> Option<&LearnedPattern> {
100 self.patterns.iter().find(|p| {
101 p.source_file == source_file && p.call_name == call_name && p.confidence >= 0.3
102 })
103 }
104
105 pub fn prune_stale(&mut self, existing_files: &std::collections::HashSet<String>) {
107 self.patterns
108 .retain(|p| existing_files.contains(&p.resolved_to_file));
109 }
110
111 pub fn clear(&mut self) {
113 self.patterns.clear();
114 }
115
116 pub fn len(&self) -> usize {
118 self.patterns.len()
119 }
120
121 pub fn is_empty(&self) -> bool {
123 self.patterns.is_empty()
124 }
125}
126
/// Returns the current time as whole seconds since the Unix epoch, rendered as a
/// decimal string.
///
/// If the system clock reports a time before the epoch, `"0"` is returned rather than
/// panicking.
fn epoch_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use tempfile::TempDir;

    #[test]
    fn test_record_and_lookup() {
        let mut store = LearnedStore::default();
        store.record_correction("main.py", "authenticate", "auth.py", "auth.py::authenticate");

        let p = store
            .lookup("main.py", "authenticate")
            .expect("freshly recorded pattern should be found");
        assert_eq!(p.resolved_to_symbol, "auth.py::authenticate");
        assert!((p.confidence - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_repeated_correction_increases_confidence() {
        let mut store = LearnedStore::default();
        for _ in 0..3 {
            store.record_correction(
                "main.py",
                "authenticate",
                "auth.py",
                "auth.py::authenticate",
            );
        }

        // 0.5 for the first record, +0.1 for each of the two repeats.
        let p = store.lookup("main.py", "authenticate").unwrap();
        assert!((p.confidence - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_confidence_caps_at_one() {
        let mut store = LearnedStore::default();
        for _ in 0..20 {
            store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        }
        let p = store.lookup("a.py", "foo").unwrap();
        assert!((p.confidence - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_low_confidence_not_returned() {
        let mut store = LearnedStore::default();
        store.patterns.push(LearnedPattern {
            source_file: "a.py".into(),
            call_name: "foo".into(),
            resolved_to_file: "b.py".into(),
            resolved_to_symbol: "b.py::foo".into(),
            confidence: 0.1,
            source: "scip".into(),
            last_updated: "0".into(),
        });
        assert!(store.lookup("a.py", "foo").is_none());
    }

    #[test]
    fn test_lookup_unknown_key_returns_none() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        assert!(store.lookup("a.py", "bar").is_none());
        assert!(store.lookup("other.py", "foo").is_none());
    }

    #[test]
    fn test_save_and_load() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        let mut store = LearnedStore::default();
        store.record_correction("main.py", "auth", "auth.py", "auth.py::auth");
        store.save(root).unwrap();

        let loaded = LearnedStore::load(root);
        assert_eq!(loaded.len(), 1);
        let p = &loaded.patterns[0];
        assert_eq!(p.source_file, "main.py");
        assert_eq!(p.call_name, "auth");
        assert_eq!(p.resolved_to_file, "auth.py");
        assert_eq!(p.resolved_to_symbol, "auth.py::auth");
        assert_eq!(p.source, "scip");
        assert!((p.confidence - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_save_overwrites_previous_contents() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        store.save(root).unwrap();

        LearnedStore::default().save(root).unwrap();
        assert!(LearnedStore::load(root).is_empty());
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let tmp = TempDir::new().unwrap();
        assert!(LearnedStore::load(tmp.path()).is_empty());
    }

    #[test]
    fn test_load_corrupt_file_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let path = LearnedStore::path(tmp.path());
        std::fs::create_dir_all(path.parent().unwrap()).unwrap();
        std::fs::write(&path, "{ not valid json").unwrap();
        assert!(LearnedStore::load(tmp.path()).is_empty());
    }

    #[test]
    fn test_prune_stale() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        store.record_correction("a.py", "bar", "c.py", "c.py::bar");

        let existing: HashSet<String> = ["b.py".to_string()].into_iter().collect();
        store.prune_stale(&existing);

        assert_eq!(store.len(), 1);
        assert_eq!(store.patterns[0].call_name, "foo");
    }

    #[test]
    fn test_clear() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        assert_eq!(store.len(), 1);
        store.clear();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }
}