use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use anyhow::Result;
/// One learned call-resolution correction, stored as an element of
/// `LearnedStore::patterns` and (de)serialized with serde.
///
/// `LearnedStore::record_correction` and `LearnedStore::lookup` treat the pair
/// (`source_file`, `call_name`) as the key that identifies a pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedPattern {
/// File in which the call appears; first half of the lookup key.
pub source_file: String,
/// Name of the call being resolved; second half of the lookup key.
pub call_name: String,
/// File the call was resolved to. `LearnedStore::prune_stale` drops the
/// pattern when this value is absent from the set of existing files.
pub resolved_to_file: String,
/// Symbol the call was resolved to (the tests use the form `"auth.py::authenticate"`).
pub resolved_to_symbol: String,
/// Trust score for this resolution. `record_correction` starts it at 0.5 and
/// adds 0.1 per repeated correction, capped at 1.0; `lookup` ignores patterns
/// below 0.3.
pub confidence: f32,
/// Label for where the correction came from. `record_correction` always
/// writes `"scip"`; presumably the originating indexer/tool — confirm with callers.
pub source: String,
/// Time of the last insert/update, written by `record_correction` as whole
/// seconds since the Unix epoch rendered as a decimal string.
pub last_updated: String,
}
/// Collection of learned call-resolution patterns that can be persisted as JSON.
///
/// `Default` produces an empty store; `LearnedStore::load` falls back to that
/// default when no readable, parseable file exists.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LearnedStore {
/// All recorded patterns in insertion order. `record_correction` updates the
/// entry matching (`source_file`, `call_name`) in place when one exists, and
/// appends a new entry otherwise.
pub patterns: Vec<LearnedPattern>,
}
impl LearnedStore {
    /// Confidence assigned to a pattern the first time it is recorded.
    const INITIAL_CONFIDENCE: f32 = 0.5;
    /// Amount added to a pattern's confidence on every repeated correction.
    const CONFIDENCE_STEP: f32 = 0.1;
    /// Upper bound for a pattern's confidence.
    const MAX_CONFIDENCE: f32 = 1.0;
    /// Patterns below this confidence are ignored by [`LearnedStore::lookup`].
    const MIN_LOOKUP_CONFIDENCE: f32 = 0.3;

    /// Loads the store persisted under `root`, falling back to an empty store.
    ///
    /// Loading is deliberately best-effort: if the file cannot be read, or its
    /// content does not deserialize, `LearnedStore::default()` is returned
    /// instead of an error.
    pub fn load(root: &Path) -> Self {
        let path = Self::path(root);
        match std::fs::read_to_string(&path) {
            Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
            Err(_) => Self::default(),
        }
    }

    /// Writes the store as pretty-printed JSON to
    /// `<root>/.infigraph/learned/patterns.json`, creating parent directories
    /// as needed.
    ///
    /// The JSON is written to a sibling `patterns.json.tmp` file first and then
    /// renamed over the target. Because [`LearnedStore::load`] treats
    /// unparsable content as an empty store, overwriting the target in place
    /// could silently discard every pattern if the write were interrupted
    /// half-way; the rename keeps the previous file intact until the new one is
    /// complete.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created, serialization
    /// fails, or the temporary file cannot be written or renamed.
    pub fn save(&self, root: &Path) -> Result<()> {
        let path = Self::path(root);
        // `path` always has a parent because of the joins in `Self::path`, but
        // avoid a panic path on filesystem-facing code regardless.
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir)?;
        }
        let json = serde_json::to_string_pretty(self)?;

        // Same directory as the target, so the rename stays on one filesystem.
        let tmp_path = path.with_extension("json.tmp");
        let written =
            std::fs::write(&tmp_path, json).and_then(|()| std::fs::rename(&tmp_path, &path));
        if let Err(err) = written {
            // Best-effort cleanup; the original error is the one worth reporting.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(err.into());
        }
        Ok(())
    }

    /// Location of the persisted store relative to `root`.
    fn path(root: &Path) -> PathBuf {
        root.join(".infigraph")
            .join("learned")
            .join("patterns.json")
    }

    /// Records that `call_name` inside `source_file` resolves to
    /// `resolved_to_symbol` in `resolved_to_file`.
    ///
    /// If a pattern with the same (`source_file`, `call_name`) already exists,
    /// its target is overwritten, its confidence is raised by 0.1 (capped at
    /// 1.0) and its timestamp is refreshed. Otherwise a new pattern is appended
    /// with confidence 0.5 and source `"scip"`.
    pub fn record_correction(
        &mut self,
        source_file: &str,
        call_name: &str,
        resolved_to_file: &str,
        resolved_to_symbol: &str,
    ) {
        let existing = self
            .patterns
            .iter_mut()
            .find(|p| p.source_file == source_file && p.call_name == call_name);

        match existing {
            Some(pattern) => {
                // NOTE(review): confidence also increases when the correction points
                // at a *different* target than the stored one — confirm that is intended.
                pattern.resolved_to_file = resolved_to_file.to_string();
                pattern.resolved_to_symbol = resolved_to_symbol.to_string();
                pattern.confidence =
                    (pattern.confidence + Self::CONFIDENCE_STEP).min(Self::MAX_CONFIDENCE);
                pattern.last_updated = epoch_now();
            }
            None => self.patterns.push(LearnedPattern {
                source_file: source_file.to_string(),
                call_name: call_name.to_string(),
                resolved_to_file: resolved_to_file.to_string(),
                resolved_to_symbol: resolved_to_symbol.to_string(),
                confidence: Self::INITIAL_CONFIDENCE,
                source: "scip".to_string(),
                last_updated: epoch_now(),
            }),
        }
    }

    /// Returns the first pattern matching (`source_file`, `call_name`) whose
    /// confidence is at least 0.3, or `None` if there is no such pattern.
    pub fn lookup(&self, source_file: &str, call_name: &str) -> Option<&LearnedPattern> {
        self.patterns.iter().find(|p| {
            p.source_file == source_file
                && p.call_name == call_name
                && p.confidence >= Self::MIN_LOOKUP_CONFIDENCE
        })
    }

    /// Removes every pattern whose `resolved_to_file` is not in `existing_files`.
    ///
    /// NOTE(review): only the resolution target is checked; a pattern whose
    /// `source_file` no longer exists is kept — confirm that is intended.
    pub fn prune_stale(&mut self, existing_files: &std::collections::HashSet<String>) {
        self.patterns
            .retain(|p| existing_files.contains(&p.resolved_to_file));
    }

    /// Removes all patterns from the in-memory store (nothing is written to disk).
    pub fn clear(&mut self) {
        self.patterns.clear();
    }

    /// Number of patterns currently held.
    pub fn len(&self) -> usize {
        self.patterns.len()
    }

    /// Returns `true` when the store holds no patterns.
    pub fn is_empty(&self) -> bool {
        self.patterns.is_empty()
    }
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// formatted as a decimal string.
///
/// If the system clock reports a time earlier than the epoch, the elapsed
/// duration is treated as zero and `"0"` is returned.
fn epoch_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let seconds = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // Clock is set before 1970: same result as a zero-length duration.
        Err(_) => 0,
    };
    seconds.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a store by recording each
    /// `(source_file, call_name, resolved_to_file, resolved_to_symbol)` tuple in order.
    fn store_from(corrections: &[(&str, &str, &str, &str)]) -> LearnedStore {
        let mut store = LearnedStore::default();
        for &(source_file, call_name, target_file, target_symbol) in corrections {
            store.record_correction(source_file, call_name, target_file, target_symbol);
        }
        store
    }

    /// Float comparison with the same tolerance used throughout these tests.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    /// A freshly recorded correction is retrievable and starts at confidence 0.5.
    #[test]
    fn test_record_and_lookup() {
        let store = store_from(&[("main.py", "authenticate", "auth.py", "auth.py::authenticate")]);

        let pattern = store
            .lookup("main.py", "authenticate")
            .expect("recorded pattern should be found");
        assert_eq!(pattern.resolved_to_symbol, "auth.py::authenticate");
        assert!(approx_eq(pattern.confidence, 0.5));
    }

    /// Recording the same correction three times yields 0.5 + 0.1 + 0.1.
    #[test]
    fn test_repeated_correction_increases_confidence() {
        let correction = ("main.py", "authenticate", "auth.py", "auth.py::authenticate");
        let store = store_from(&[correction, correction, correction]);

        let pattern = store.lookup("main.py", "authenticate").unwrap();
        assert!(approx_eq(pattern.confidence, 0.7));
    }

    /// Confidence never exceeds 1.0 no matter how often a correction repeats.
    #[test]
    fn test_confidence_caps_at_one() {
        let mut store = LearnedStore::default();
        (0..20).for_each(|_| store.record_correction("a.py", "foo", "b.py", "b.py::foo"));

        let pattern = store.lookup("a.py", "foo").unwrap();
        assert!(approx_eq(pattern.confidence, 1.0));
    }

    /// Patterns below the lookup threshold are hidden from `lookup`.
    #[test]
    fn test_low_confidence_not_returned() {
        let weak = LearnedPattern {
            source_file: "a.py".into(),
            call_name: "foo".into(),
            resolved_to_file: "b.py".into(),
            resolved_to_symbol: "b.py::foo".into(),
            confidence: 0.1,
            source: "scip".into(),
            last_updated: "0".into(),
        };
        let mut store = LearnedStore::default();
        store.patterns.push(weak);

        assert!(store.lookup("a.py", "foo").is_none());
    }

    /// A saved store round-trips through disk.
    #[test]
    fn test_save_and_load() {
        let tmp = TempDir::new().unwrap();
        let root = tmp.path();

        let store = store_from(&[("main.py", "auth", "auth.py", "auth.py::auth")]);
        store.save(root).unwrap();

        let loaded = LearnedStore::load(root);
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded.patterns[0].call_name, "auth");
    }

    /// Only patterns whose resolved file still exists survive pruning.
    #[test]
    fn test_prune_stale() {
        let mut store = store_from(&[
            ("a.py", "foo", "b.py", "b.py::foo"),
            ("a.py", "bar", "c.py", "c.py::bar"),
        ]);
        let existing: std::collections::HashSet<String> =
            std::iter::once("b.py".to_string()).collect();

        store.prune_stale(&existing);

        assert_eq!(store.len(), 1);
        assert_eq!(store.patterns[0].call_name, "foo");
    }

    /// `clear` empties the store.
    #[test]
    fn test_clear() {
        let mut store = store_from(&[("a.py", "foo", "b.py", "b.py::foo")]);
        assert_eq!(store.len(), 1);

        store.clear();
        assert_eq!(store.len(), 0);
    }
}