use crate::models::{Finding, Severity};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
/// A finding together with a true-positive / false-positive label.
///
/// The type derives `Serialize` and `Deserialize`; `FeedbackCollector::record`
/// writes one JSON-encoded value per line and `FeedbackCollector::load_all`
/// parses them back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledFinding {
/// Identifier of the labeled finding (`Finding::id` when built by `from_finding`).
pub finding_id: String,
/// Detector name, copied from the finding by `from_finding`.
pub detector: String,
/// Severity of the finding.
///
/// Deserialized through `deserialize_severity_compat`, which reads a string
/// and parses it with `str::parse::<Severity>()`.
#[serde(deserialize_with = "deserialize_severity_compat")]
pub severity: Severity,
/// Title, copied from the finding by `from_finding`.
pub title: String,
/// Description text; `from_finding` keeps at most the first 500 characters.
pub description: String,
/// First entry of the finding's `affected_files` (lossily converted to a
/// string), or an empty string when that list is empty.
pub file_path: String,
/// Value of `Finding::line_start`, if any.
pub line_start: Option<u32>,
/// `true` for a true-positive label, `false` for a false-positive label.
pub is_true_positive: bool,
/// Optional reason supplied by the caller together with the label.
pub reason: Option<String>,
/// RFC 3339 timestamp (UTC) taken when `from_finding` built this value.
pub timestamp: String,
}
impl LabeledFinding {
    /// Builds a labeled record from `finding`.
    ///
    /// * `is_tp` - `true` to label the finding a true positive, `false` for a
    ///   false positive.
    /// * `reason` - optional explanation stored alongside the label.
    ///
    /// The description is cut to its first 500 characters, the file path is the
    /// first affected file (empty when there is none), and the timestamp is the
    /// current UTC time in RFC 3339 form.
    pub fn from_finding(finding: &Finding, is_tp: bool, reason: Option<String>) -> Self {
        // Truncate by characters, not bytes, so multi-byte text is never split.
        let description: String = finding.description.chars().take(500).collect();

        let file_path = match finding.affected_files.first() {
            Some(path) => path.to_string_lossy().into_owned(),
            None => String::new(),
        };

        Self {
            finding_id: finding.id.clone(),
            detector: finding.detector.clone(),
            severity: finding.severity,
            title: finding.title.clone(),
            description,
            file_path,
            line_start: finding.line_start,
            is_true_positive: is_tp,
            reason,
            timestamp: chrono::Utc::now().to_rfc3339(),
        }
    }
}
/// Appends labeled findings to a JSON-lines file and reads them back.
pub struct FeedbackCollector {
/// Path of the JSON-lines file that records are appended to and loaded from.
data_path: PathBuf,
}
impl FeedbackCollector {
/// Creates a collector that stores its data in
/// `<data dir>/repotoire/training_data.jsonl`.
///
/// `<data dir>` is whatever `dirs::data_dir()` returns; when it returns
/// `None` the current directory (`.`) is used instead.
pub fn new() -> Self {
    let mut data_path = match dirs::data_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    data_path.push("repotoire");
    data_path.push("training_data.jsonl");
    Self { data_path }
}
/// Creates a collector backed by the file at `path`.
pub fn with_path(path: impl Into<PathBuf>) -> Self {
    let data_path: PathBuf = path.into();
    Self { data_path }
}
/// Appends one labeled record for `finding` to the data file.
///
/// * `is_tp` - `true` to label the finding a true positive, `false` for a
///   false positive.
/// * `reason` - optional explanation stored with the label.
///
/// The parent directory of the data file is created if it is missing, and the
/// file itself is created on first use.
///
/// # Errors
///
/// Returns an `io::Error` when the directory cannot be created, the record
/// cannot be serialized, or the file cannot be opened or written.
pub fn record(
    &self,
    finding: &Finding,
    is_tp: bool,
    reason: Option<String>,
) -> std::io::Result<()> {
    if let Some(parent) = self.data_path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    let labeled = LabeledFinding::from_finding(finding, is_tp, reason);

    // Build the complete line (payload + newline) up front so it reaches the
    // file through a single `write_all`. Formatting straight into the
    // unbuffered file would emit the JSON and the newline as separate writes,
    // letting another appender slip in between and corrupt both records.
    let mut line = serde_json::to_string(&labeled)?;
    line.push('\n');

    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&self.data_path)?;
    file.write_all(line.as_bytes())?;
    Ok(())
}
/// Records every finding in `findings` with the same label and no reason.
///
/// Returns the number of findings recorded. Stops at the first failure and
/// returns that error; records written before the failure stay in the file.
pub fn record_batch(&self, findings: &[Finding], is_tp: bool) -> std::io::Result<usize> {
    findings
        .iter()
        .try_fold(0usize, |recorded, finding| -> std::io::Result<usize> {
            self.record(finding, is_tp, None)?;
            Ok(recorded + 1)
        })
}
pub fn load_all(&self) -> std::io::Result<Vec<LabeledFinding>> {
if !self.data_path.exists() {
return Ok(Vec::new());
}
let file = File::open(&self.data_path)?;
let reader = BufReader::new(file);
let mut examples = Vec::new();
for line in reader.lines() {
let line = line?;
if line.trim().is_empty() {
continue;
}
if let Ok(labeled) = serde_json::from_str::<LabeledFinding>(&line) {
examples.push(labeled);
}
}
Ok(examples)
}
/// Returns a map from finding id to its label (`true` = true positive).
///
/// Records are applied in file order, so when an id appears more than once
/// the last record wins. If loading fails, a warning is logged and an empty
/// map is returned.
pub fn load_label_map(&self) -> HashMap<String, bool> {
    match self.load_all() {
        Ok(entries) => entries
            .into_iter()
            .map(|entry| (entry.finding_id, entry.is_true_positive))
            .collect(),
        Err(e) => {
            tracing::warn!("Failed to load feedback labels: {}", e);
            HashMap::new()
        }
    }
}
pub fn stats(&self) -> std::io::Result<TrainingStats> {
let examples = self.load_all()?;
let tp_count = examples.iter().filter(|e| e.is_true_positive).count();
let fp_count = examples.iter().filter(|e| !e.is_true_positive).count();
let mut by_detector: std::collections::HashMap<String, (usize, usize)> =
std::collections::HashMap::new();
for ex in &examples {
let entry = by_detector.entry(ex.detector.clone()).or_insert((0, 0));
if ex.is_true_positive {
entry.0 += 1;
} else {
entry.1 += 1;
}
}
Ok(TrainingStats {
total: examples.len(),
true_positives: tp_count,
false_positives: fp_count,
by_detector,
})
}
/// Returns the path of the backing data file.
pub fn data_path(&self) -> &Path {
    self.data_path.as_path()
}
}
impl Default for FeedbackCollector {
fn default() -> Self {
Self::new()
}
}
/// Summary counts over a set of labeled findings.
#[derive(Debug)]
pub struct TrainingStats {
    /// Total number of labeled records.
    pub total: usize,
    /// Number of records labeled as true positives.
    pub true_positives: usize,
    /// Number of records labeled as false positives.
    pub false_positives: usize,
    /// Per-detector counts, keyed by detector name, as
    /// `(true positives, false positives)`.
    pub by_detector: std::collections::HashMap<String, (usize, usize)>,
}

impl std::fmt::Display for TrainingStats {
    /// Writes a multi-line, human-readable report.
    ///
    /// The report lists the totals with percentages (0.0% when there are no
    /// records), followed by at most ten detectors ordered by record count,
    /// largest first. Detectors with equal counts are ordered by name so the
    /// output is stable from run to run even though the counts live in a
    /// `HashMap`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Share of `count` in the total, guarding against division by zero.
        let percent = |count: usize| {
            if self.total > 0 {
                count as f64 / self.total as f64 * 100.0
            } else {
                0.0
            }
        };

        writeln!(f, "Training Data Statistics:")?;
        writeln!(f, " Total examples: {}", self.total)?;
        writeln!(
            f,
            " True positives: {} ({:.1}%)",
            self.true_positives,
            percent(self.true_positives)
        )?;
        writeln!(
            f,
            " False positives: {} ({:.1}%)",
            self.false_positives,
            percent(self.false_positives)
        )?;
        writeln!(f, "\n By detector:")?;

        let mut detectors: Vec<_> = self.by_detector.iter().collect();
        // Largest total first; the name tie-break makes the order (and the
        // membership of the top ten) independent of hash-map iteration order.
        detectors.sort_unstable_by(|a, b| {
            let total_a = a.1 .0 + a.1 .1;
            let total_b = b.1 .0 + b.1 .1;
            total_b.cmp(&total_a).then_with(|| a.0.cmp(b.0))
        });
        for (detector, (tp, fp)) in detectors.iter().take(10) {
            writeln!(f, " {}: {} TP, {} FP", detector, tp, fp)?;
        }
        Ok(())
    }
}
fn deserialize_severity_compat<'de, D>(deserializer: D) -> Result<Severity, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse::<Severity>().map_err(serde::de::Error::custom)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
/// Records a true-positive and a false-positive label for the same finding
/// and checks that both come back from `load_all` in the order written.
#[test]
fn test_record_and_load() {
    let tmp = TempDir::new().expect("create temp dir");
    let collector = FeedbackCollector::with_path(tmp.path().join("test_feedback.jsonl"));
    let finding = Finding {
        id: "test-123".into(),
        detector: "TestDetector".into(),
        severity: Severity::High,
        title: "Test finding".into(),
        description: "A test finding for testing".into(),
        ..Default::default()
    };

    collector
        .record(&finding, true, Some("Real issue".into()))
        .expect("record true positive");
    collector
        .record(&finding, false, Some("Not a problem".into()))
        .expect("record false positive");

    let loaded = collector.load_all().expect("load feedback records");
    let labels: Vec<bool> = loaded.iter().map(|rec| rec.is_true_positive).collect();
    // Exactly two records: the true positive first, then the false positive.
    assert_eq!(labels, [true, false]);
}
/// A later label for the same finding id replaces the earlier one in the map
/// returned by `load_label_map`.
#[test]
fn test_load_label_map_last_writer_wins() {
    let tmp = TempDir::new().expect("create temp dir");
    let collector = FeedbackCollector::with_path(tmp.path().join("test_labels.jsonl"));
    let finding = Finding {
        id: "abc-123".into(),
        detector: "TestDetector".into(),
        severity: Severity::High,
        title: "Test".into(),
        ..Default::default()
    };

    // First label: true positive, no reason.
    collector.record(&finding, true, None).unwrap();
    // Second label for the same id: false positive; this one must win.
    collector
        .record(&finding, false, Some("Actually not a bug".into()))
        .unwrap();

    let labels = collector.load_label_map();
    assert_eq!(labels.len(), 1);
    assert_eq!(
        labels.get("abc-123"),
        Some(&false),
        "Last entry (FP) should win"
    );
}
/// `load_label_map` returns an empty map both when the data file does not
/// exist and when it exists but contains no records.
#[test]
fn test_load_label_map_empty_file() {
    let dir = TempDir::new().expect("create temp dir");

    // Case 1: the data file was never created.
    let missing = FeedbackCollector::with_path(dir.path().join("nonexistent.jsonl"));
    assert!(missing.load_label_map().is_empty());

    // Case 2: the data file exists but is zero bytes long.
    let empty_path = dir.path().join("empty.jsonl");
    std::fs::File::create(&empty_path).expect("create empty data file");
    let empty = FeedbackCollector::with_path(&empty_path);
    assert!(empty.load_label_map().is_empty());
}
}