Skip to main content

infigraph_core/learned/
mod.rs

1use std::path::{Path, PathBuf};
2
3use serde::{Deserialize, Serialize};
4
5use anyhow::Result;
6
7/// A single learned resolution pattern — records how SCIP corrected a
8/// tree-sitter heuristic so the correction can be replayed in future
9/// indexes even without SCIP data.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LearnedPattern {
12    /// File where the call occurs (relative path).
13    pub source_file: String,
14    /// Name of the called function/method.
15    pub call_name: String,
16    /// Correct target file (relative path).
17    pub resolved_to_file: String,
18    /// Correct target symbol id.
19    pub resolved_to_symbol: String,
20    /// Confidence score in 0.0..=1.0 — increases with repeated corrections.
21    pub confidence: f32,
22    /// Origin of the pattern: "scip" or "user".
23    pub source: String,
24    /// Unix-epoch seconds when this pattern was last updated.
25    pub last_updated: String,
26}
27
28/// Persistent store for learned resolution patterns.
29///
30/// Stored as `.infigraph/learned/patterns.json`, separate from the graph DB
31/// so it survives `infigraph index --full` and graph rebuilds.
32#[derive(Debug, Clone, Serialize, Deserialize, Default)]
33pub struct LearnedStore {
34    pub patterns: Vec<LearnedPattern>,
35}
36
37impl LearnedStore {
38    /// Load from `.infigraph/learned/patterns.json`.
39    /// Returns an empty store if the file doesn't exist or is malformed.
40    pub fn load(root: &Path) -> Self {
41        let path = Self::path(root);
42        match std::fs::read_to_string(&path) {
43            Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
44            Err(_) => Self::default(),
45        }
46    }
47
48    /// Persist to `.infigraph/learned/patterns.json`.
49    pub fn save(&self, root: &Path) -> Result<()> {
50        let path = Self::path(root);
51        std::fs::create_dir_all(path.parent().unwrap())?;
52        let json = serde_json::to_string_pretty(self)?;
53        std::fs::write(&path, json)?;
54        Ok(())
55    }
56
57    fn path(root: &Path) -> PathBuf {
58        root.join(".infigraph")
59            .join("learned")
60            .join("patterns.json")
61    }
62
63    /// Record a correction: SCIP resolved a call differently than tree-sitter.
64    ///
65    /// If a pattern for the same (source_file, call_name) already exists its
66    /// confidence is bumped by 0.1 (capped at 1.0). Otherwise a new pattern
67    /// is created with confidence 0.5.
68    pub fn record_correction(
69        &mut self,
70        source_file: &str,
71        call_name: &str,
72        resolved_to_file: &str,
73        resolved_to_symbol: &str,
74    ) {
75        if let Some(existing) = self
76            .patterns
77            .iter_mut()
78            .find(|p| p.source_file == source_file && p.call_name == call_name)
79        {
80            existing.resolved_to_file = resolved_to_file.to_string();
81            existing.resolved_to_symbol = resolved_to_symbol.to_string();
82            existing.confidence = (existing.confidence + 0.1).min(1.0);
83            existing.last_updated = epoch_now();
84        } else {
85            self.patterns.push(LearnedPattern {
86                source_file: source_file.to_string(),
87                call_name: call_name.to_string(),
88                resolved_to_file: resolved_to_file.to_string(),
89                resolved_to_symbol: resolved_to_symbol.to_string(),
90                confidence: 0.5,
91                source: "scip".to_string(),
92                last_updated: epoch_now(),
93            });
94        }
95    }
96
97    /// Look up a learned resolution for a call site.
98    /// Only returns patterns with confidence >= 0.3.
99    pub fn lookup(&self, source_file: &str, call_name: &str) -> Option<&LearnedPattern> {
100        self.patterns.iter().find(|p| {
101            p.source_file == source_file && p.call_name == call_name && p.confidence >= 0.3
102        })
103    }
104
105    /// Remove patterns pointing to files that no longer exist in the index.
106    pub fn prune_stale(&mut self, existing_files: &std::collections::HashSet<String>) {
107        self.patterns
108            .retain(|p| existing_files.contains(&p.resolved_to_file));
109    }
110
111    /// Clear all learned data.
112    pub fn clear(&mut self) {
113        self.patterns.clear();
114    }
115
116    /// Number of stored patterns.
117    pub fn len(&self) -> usize {
118        self.patterns.len()
119    }
120
121    /// Whether the store is empty.
122    pub fn is_empty(&self) -> bool {
123        self.patterns.is_empty()
124    }
125}
126
/// Current time as whole seconds since the Unix epoch, formatted as a
/// decimal string (avoids pulling in the `chrono` crate).
///
/// If the system clock reports a time before the epoch, `"0"` is returned
/// rather than an error.
fn epoch_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use tempfile::TempDir;

    #[test]
    fn test_record_and_lookup() {
        let mut store = LearnedStore::default();
        store.record_correction(
            "main.py",
            "authenticate",
            "auth.py",
            "auth.py::authenticate",
        );

        let p = store
            .lookup("main.py", "authenticate")
            .expect("freshly recorded pattern should be found");
        assert_eq!(p.resolved_to_file, "auth.py");
        assert_eq!(p.resolved_to_symbol, "auth.py::authenticate");
        assert_eq!(p.source, "scip");
        assert!((p.confidence - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_lookup_unknown_returns_none() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");

        assert!(store.lookup("a.py", "bar").is_none());
        assert!(store.lookup("other.py", "foo").is_none());
        assert!(LearnedStore::default().lookup("a.py", "foo").is_none());
    }

    #[test]
    fn test_repeated_correction_increases_confidence() {
        let mut store = LearnedStore::default();
        for _ in 0..3 {
            store.record_correction(
                "main.py",
                "authenticate",
                "auth.py",
                "auth.py::authenticate",
            );
        }

        let p = store.lookup("main.py", "authenticate").unwrap();
        assert!((p.confidence - 0.7).abs() < f32::EPSILON);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_confidence_caps_at_one() {
        let mut store = LearnedStore::default();
        for _ in 0..20 {
            store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        }
        let p = store.lookup("a.py", "foo").unwrap();
        assert!((p.confidence - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_low_confidence_not_returned() {
        let mut store = LearnedStore::default();
        store.patterns.push(LearnedPattern {
            source_file: "a.py".into(),
            call_name: "foo".into(),
            resolved_to_file: "b.py".into(),
            resolved_to_symbol: "b.py::foo".into(),
            confidence: 0.1,
            source: "scip".into(),
            last_updated: "0".into(),
        });
        assert!(store.lookup("a.py", "foo").is_none());
    }

    #[test]
    fn test_save_and_load() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        let mut store = LearnedStore::default();
        store.record_correction("main.py", "auth", "auth.py", "auth.py::auth");
        store.save(root).unwrap();

        let loaded = LearnedStore::load(root);
        assert_eq!(loaded.len(), 1);
        let p = &loaded.patterns[0];
        assert_eq!(p.source_file, "main.py");
        assert_eq!(p.call_name, "auth");
        assert_eq!(p.resolved_to_file, "auth.py");
        assert_eq!(p.resolved_to_symbol, "auth.py::auth");
    }

    #[test]
    fn test_save_overwrites_previous_file() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        store.save(root).unwrap();
        store.record_correction("a.py", "bar", "c.py", "c.py::bar");
        store.save(root).unwrap();

        let loaded = LearnedStore::load(root);
        let names: Vec<&str> = loaded.patterns.iter().map(|p| p.call_name.as_str()).collect();
        assert_eq!(names, ["foo", "bar"]);

        let tmp_file = root
            .join(".infigraph")
            .join("learned")
            .join("patterns.json.tmp");
        assert!(!tmp_file.exists(), "no temp file should be left behind");
    }

    #[test]
    fn test_load_missing_file_returns_empty() {
        let tmp = TempDir::new().unwrap();
        assert!(LearnedStore::load(tmp.path()).is_empty());
    }

    #[test]
    fn test_load_malformed_file_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let dir = tmp.path().join(".infigraph").join("learned");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(dir.join("patterns.json"), "{ not valid json").unwrap();

        assert!(LearnedStore::load(tmp.path()).is_empty());
    }

    #[test]
    fn test_prune_stale() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        store.record_correction("a.py", "bar", "c.py", "c.py::bar");

        let existing: HashSet<String> = ["b.py".to_string()].into_iter().collect();
        store.prune_stale(&existing);

        assert_eq!(store.len(), 1);
        assert_eq!(store.patterns[0].call_name, "foo");
    }

    #[test]
    fn test_clear() {
        let mut store = LearnedStore::default();
        store.record_correction("a.py", "foo", "b.py", "b.py::foo");
        assert_eq!(store.len(), 1);
        store.clear();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }
}