use crate::dual_branch::{BranchLabel, PredictionReason, PredictionReasonKind};
use crate::models::{Finding, Severity};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
/// One human-labeled finding, stored as a single JSON line in the feedback training file.
///
/// Built by [`LabeledFinding::from_finding`]. Fields marked `#[serde(default)]` may be
/// absent from older rows; such rows still deserialize, taking the default value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledFinding {
/// Id of the labeled finding (copied from `Finding::id`).
pub finding_id: String,
/// Name of the detector that produced the finding.
pub detector: String,
/// Finding severity. Deserialized from a string through `Severity`'s `FromStr` impl (see
/// `deserialize_severity_compat`), so accepted spellings are whatever that impl accepts.
#[serde(deserialize_with = "deserialize_severity_compat")]
pub severity: Severity,
/// Finding title.
pub title: String,
/// Finding description; `from_finding` keeps only the first 500 characters.
pub description: String,
/// First entry of `Finding::affected_files` (lossy UTF-8), or empty when there is none.
pub file_path: String,
/// Copied from `Finding::line_start`.
pub line_start: Option<u32>,
/// The user's verdict: `true` for a true positive, `false` for a false positive.
pub is_true_positive: bool,
/// Optional free-text justification supplied with the verdict.
pub reason: Option<String>,
/// Time the label was recorded, as an RFC 3339 string in UTC.
pub timestamp: String,
/// Whether the finding carried an alternative (dual-branch) prediction.
#[serde(default)]
pub had_alternative_branch: bool,
/// Predicted label, `"real_bug"` or `"benign"`: the opposite of the alternative branch's
/// label. `None` for single-branch findings, and then omitted from the JSON.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub predicted_label: Option<String>,
/// Severity of the alternative branch, if any; omitted from the JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub alternative_severity: Option<Severity>,
/// Variant names of the finding's prediction reasons (as produced by `reason_kind`), in the
/// finding's order; omitted from the JSON when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub prediction_reason_kinds: Vec<String>,
/// Copied from `Finding::original_severity`. This is presumably the severity before any
/// adjustment (NOTE(review): confirm against `Finding`). Omitted from the JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub original_severity: Option<Severity>,
}
/// Returns the stable name of the `PredictionReasonKind` variant carried by `r`.
///
/// Only the variant is reported; any payload (names, values, scores) is ignored. The match
/// is exhaustive on purpose, so adding a variant forces this table to be extended.
pub(crate) fn reason_kind(r: &PredictionReason) -> &'static str {
    match r.kind {
        // Variants without payload.
        PredictionReasonKind::BundledCode => "BundledCode",
        PredictionReasonKind::NonProductionPath => "NonProductionPath",
        PredictionReasonKind::TestFixtureFile => "TestFixtureFile",
        // Variants whose payload is irrelevant for the discriminant name.
        PredictionReasonKind::MultiDetectorAgreement { .. } => "MultiDetectorAgreement",
        PredictionReasonKind::HierarchicalLevel { .. } => "HierarchicalLevel",
        PredictionReasonKind::KeywordArgument { .. } => "KeywordArgument",
        PredictionReasonKind::FirstArgIdentifier { .. } => "FirstArgIdentifier",
        PredictionReasonKind::EnclosingScope { .. } => "EnclosingScope",
        PredictionReasonKind::ImportPresence { .. } => "ImportPresence",
        PredictionReasonKind::FilePath { .. } => "FilePath",
        PredictionReasonKind::StructuralPattern { .. } => "StructuralPattern",
        PredictionReasonKind::Custom { .. } => "Custom",
    }
}
/// Returns the predicted label string for a finding whose alternative branch has `alt_label`.
///
/// The predicted label is the opposite of the alternative branch's label, rendered as
/// `"real_bug"` or `"benign"`.
pub(crate) fn predicted_label_from_alt(alt_label: BranchLabel) -> &'static str {
    let predicted = alt_label.opposite();
    match predicted {
        BranchLabel::Benign => "benign",
        BranchLabel::RealBug => "real_bug",
    }
}
impl LabeledFinding {
pub fn from_finding(finding: &Finding, is_tp: bool, reason: Option<String>) -> Self {
let had_alternative_branch = finding.alternative_branch.is_some();
let predicted_label = finding
.alternative_branch
.as_ref()
.map(|alt| predicted_label_from_alt(alt.label).to_string());
let alternative_severity = finding.alternative_branch.as_ref().map(|alt| alt.severity);
let prediction_reason_kinds: Vec<String> = finding
.prediction_reasons
.iter()
.map(|r| reason_kind(r).to_string())
.collect();
Self {
finding_id: finding.id.clone(),
detector: finding.detector.clone(),
severity: finding.severity,
title: finding.title.clone(),
description: finding.description.chars().take(500).collect(),
file_path: finding
.affected_files
.first()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_default(),
line_start: finding.line_start,
is_true_positive: is_tp,
reason,
timestamp: chrono::Utc::now().to_rfc3339(),
had_alternative_branch,
predicted_label,
alternative_severity,
prediction_reason_kinds,
original_severity: finding.original_severity,
}
}
}
/// Appends user true/false-positive verdicts on findings to a JSONL file and reads them
/// back for label lookups and training statistics.
pub struct FeedbackCollector {
/// Path of the JSONL file holding one serialized `LabeledFinding` per line.
data_path: PathBuf,
}
impl FeedbackCollector {
    /// Creates a collector backed by `<data dir>/repotoire/training_data.jsonl`.
    ///
    /// `<data dir>` comes from `dirs::data_dir()`. When that returns `None`, the path is
    /// resolved relative to the current directory (`.`) instead.
    pub fn new() -> Self {
        let base = dirs::data_dir().unwrap_or_else(|| PathBuf::from("."));
        Self::with_path(base.join("repotoire").join("training_data.jsonl"))
    }

    /// Creates a collector backed by an explicit JSONL file.
    pub fn with_path(path: impl Into<PathBuf>) -> Self {
        Self {
            data_path: path.into(),
        }
    }

    /// Appends one labeled example for `finding` to the data file.
    ///
    /// `is_tp` is the user's verdict (`true` = true positive) and `reason` an optional
    /// justification. The parent directory and the file are created when missing.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, or if creating the directory, opening the
    /// file, or writing to it fails.
    pub fn record(
        &self,
        finding: &Finding,
        is_tp: bool,
        reason: Option<String>,
    ) -> std::io::Result<()> {
        let line = Self::encode_line(&LabeledFinding::from_finding(finding, is_tp, reason))?;
        self.open_for_append()?.write_all(line.as_bytes())
    }

    /// Appends one example per finding, all with the same verdict `is_tp` and no reason.
    ///
    /// Returns the number of rows written. The file is opened once for the whole batch, and
    /// an empty batch touches nothing on disk. Rows are written as they are produced, so rows
    /// written before an error remain in the file.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first serialization or I/O error.
    pub fn record_batch(&self, findings: &[Finding], is_tp: bool) -> std::io::Result<usize> {
        if findings.is_empty() {
            return Ok(0);
        }
        let mut file = self.open_for_append()?;
        for finding in findings {
            let line = Self::encode_line(&LabeledFinding::from_finding(finding, is_tp, None))?;
            file.write_all(line.as_bytes())?;
        }
        Ok(findings.len())
    }

    /// Reads every labeled example from the data file, in file order.
    ///
    /// A missing file yields an empty list and blank lines are ignored. Lines that do not
    /// parse as a [`LabeledFinding`] are skipped with a warning instead of failing the load.
    ///
    /// # Errors
    ///
    /// Returns an error if the file exists but cannot be opened or read. This includes
    /// lines that are not valid UTF-8.
    pub fn load_all(&self) -> std::io::Result<Vec<LabeledFinding>> {
        let file = match File::open(&self.data_path) {
            Ok(file) => file,
            // Nothing recorded yet. Matching on the error, rather than checking `exists()`
            // first, avoids racing with the file being created or removed in between.
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => return Err(e),
        };
        let mut examples = Vec::new();
        for (index, line) in BufReader::new(file).lines().enumerate() {
            let line = line?;
            if line.trim().is_empty() {
                continue;
            }
            match serde_json::from_str::<LabeledFinding>(&line) {
                Ok(labeled) => examples.push(labeled),
                // Keep loading: one corrupt row (such as a truncated line) must not hide the
                // rest of the training data, but it should not disappear silently either.
                Err(e) => tracing::warn!(
                    "Skipping malformed feedback entry at {}:{}: {}",
                    self.data_path.display(),
                    index + 1,
                    e
                ),
            }
        }
        Ok(examples)
    }

    /// Returns the latest verdict recorded for each finding id.
    ///
    /// When a finding was labeled more than once, the entry appearing last in the file wins.
    /// Load failures are logged and yield an empty map.
    pub fn load_label_map(&self) -> HashMap<String, bool> {
        match self.load_all() {
            // Collecting inserts in file order, so later labels overwrite earlier ones.
            Ok(entries) => entries
                .into_iter()
                .map(|entry| (entry.finding_id, entry.is_true_positive))
                .collect(),
            Err(e) => {
                tracing::warn!("Failed to load feedback labels: {}", e);
                HashMap::new()
            }
        }
    }

    /// Summarizes the stored examples. It reports overall true/false-positive counts plus
    /// per-detector `(true positives, false positives)` counts.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`FeedbackCollector::load_all`].
    pub fn stats(&self) -> std::io::Result<TrainingStats> {
        let examples = self.load_all()?;
        let total = examples.len();
        let mut true_positives = 0;
        let mut by_detector: HashMap<String, (usize, usize)> = HashMap::new();
        for example in examples {
            // Move the detector name into the map instead of cloning it per example.
            let counts = by_detector.entry(example.detector).or_default();
            if example.is_true_positive {
                counts.0 += 1;
                true_positives += 1;
            } else {
                counts.1 += 1;
            }
        }
        Ok(TrainingStats {
            total,
            true_positives,
            false_positives: total - true_positives,
            by_detector,
        })
    }

    /// Path of the JSONL file this collector reads and appends to.
    pub fn data_path(&self) -> &Path {
        &self.data_path
    }

    /// Creates the data file's parent directory if needed. Then opens the file for
    /// appending, creating it when absent.
    fn open_for_append(&self) -> std::io::Result<File> {
        if let Some(parent) = self.data_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        OpenOptions::new()
            .create(true)
            .append(true)
            .open(&self.data_path)
    }

    /// Serializes `labeled` as one JSON line, including the trailing newline.
    ///
    /// The newline is part of the buffer, so each row is handed to a single `write_all`.
    /// This replaces separate writes for the JSON text and the line break.
    fn encode_line(labeled: &LabeledFinding) -> std::io::Result<String> {
        let mut line = serde_json::to_string(labeled)?;
        line.push('\n');
        Ok(line)
    }
}
impl Default for FeedbackCollector {
fn default() -> Self {
Self::new()
}
}
/// Aggregate counts over the labeled examples in a feedback file.
#[derive(Debug)]
pub struct TrainingStats {
    /// Number of labeled examples.
    pub total: usize,
    /// Examples labeled as true positives.
    pub true_positives: usize,
    /// Examples labeled as false positives.
    pub false_positives: usize,
    /// Per-detector `(true positives, false positives)` counts, keyed by detector name.
    pub by_detector: HashMap<String, (usize, usize)>,
}

impl TrainingStats {
    /// Returns `count` as a percentage of `total`, or `0.0` when there are no examples.
    fn percent_of_total(&self, count: usize) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            count as f64 / self.total as f64 * 100.0
        }
    }
}

impl std::fmt::Display for TrainingStats {
    /// Renders a human-readable summary: overall counts with percentages, followed by the ten
    /// detectors with the most labels. Detectors are listed most-labeled first, with ties
    /// broken by detector name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const TOP_DETECTORS: usize = 10;

        writeln!(f, "Training Data Statistics:")?;
        writeln!(f, " Total examples: {}", self.total)?;
        writeln!(
            f,
            " True positives: {} ({:.1}%)",
            self.true_positives,
            self.percent_of_total(self.true_positives)
        )?;
        writeln!(
            f,
            " False positives: {} ({:.1}%)",
            self.false_positives,
            self.percent_of_total(self.false_positives)
        )?;
        writeln!(f, "\n By detector:")?;
        let mut detectors: Vec<(&String, &(usize, usize))> = self.by_detector.iter().collect();
        // HashMap iteration order is arbitrary; the name tie-break makes the output stable.
        detectors.sort_unstable_by_key(|&(name, &(tp, fp))| (std::cmp::Reverse(tp + fp), name));
        for (detector, (tp, fp)) in detectors.into_iter().take(TOP_DETECTORS) {
            writeln!(f, " {}: {} TP, {} FP", detector, tp, fp)?;
        }
        Ok(())
    }
}
fn deserialize_severity_compat<'de, D>(deserializer: D) -> Result<Severity, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse::<Severity>().map_err(serde::de::Error::custom)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dual_branch::{
        AlternativeBranch, BranchLabel, PredictionReason, PredictionReasonKind,
    };
    use tempfile::TempDir;

    /// Wraps `kind` in a `PredictionReason` with zero weight and an empty note.
    fn make_reason(kind: PredictionReasonKind) -> PredictionReason {
        PredictionReason {
            kind,
            weight: 0.0,
            note: String::new(),
        }
    }

    /// Minimal single-branch finding for storage-level tests.
    fn sample_finding(id: &str, detector: &str) -> Finding {
        Finding {
            id: id.into(),
            detector: detector.into(),
            severity: Severity::High,
            title: "t".into(),
            ..Default::default()
        }
    }

    #[test]
    fn test_record_and_load() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("test_feedback.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        let finding = Finding {
            id: "test-123".into(),
            detector: "TestDetector".into(),
            severity: crate::models::Severity::High,
            title: "Test finding".into(),
            description: "A test finding for testing".into(),
            ..Default::default()
        };
        collector
            .record(&finding, true, Some("Real issue".into()))
            .expect("record true positive");
        collector
            .record(&finding, false, Some("Not a problem".into()))
            .expect("record false positive");
        let loaded = collector.load_all().expect("load feedback records");
        assert_eq!(loaded.len(), 2);
        assert!(loaded[0].is_true_positive);
        assert!(!loaded[1].is_true_positive);
    }

    #[test]
    fn test_load_label_map_last_writer_wins() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("test_labels.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        let finding = Finding {
            id: "abc-123".into(),
            detector: "TestDetector".into(),
            severity: crate::models::Severity::High,
            title: "Test".into(),
            ..Default::default()
        };
        collector.record(&finding, true, None).unwrap();
        collector
            .record(&finding, false, Some("Actually not a bug".into()))
            .unwrap();
        let map = collector.load_label_map();
        assert_eq!(map.len(), 1);
        assert_eq!(
            map.get("abc-123"),
            Some(&false),
            "Last entry (FP) should win"
        );
    }

    #[test]
    fn test_load_label_map_empty_file() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("nonexistent.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        let map = collector.load_label_map();
        assert!(map.is_empty());
    }

    #[test]
    fn load_all_skips_blank_and_malformed_lines() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("mixed.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        collector
            .record(&sample_finding("a", "D"), true, None)
            .expect("record a");
        {
            // Simulate blank padding and a corrupt (e.g. truncated) row between valid ones.
            let mut file = OpenOptions::new()
                .append(true)
                .open(&path)
                .expect("open for append");
            file.write_all(b"\n   \n{not valid json\n")
                .expect("write junk");
        }
        collector
            .record(&sample_finding("b", "D"), false, None)
            .expect("record b");
        let loaded = collector.load_all().expect("load");
        let ids: Vec<&str> = loaded.iter().map(|e| e.finding_id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b"]);
    }

    #[test]
    fn record_batch_counts_entries_and_skips_io_when_empty() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("nested").join("batch.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        assert_eq!(collector.record_batch(&[], true).expect("empty batch"), 0);
        assert!(!path.exists(), "empty batch must not create the data file");
        let findings = vec![sample_finding("a", "D1"), sample_finding("b", "D2")];
        assert_eq!(collector.record_batch(&findings, false).expect("batch"), 2);
        let loaded = collector.load_all().expect("load");
        assert_eq!(loaded.len(), 2);
        assert!(loaded.iter().all(|e| !e.is_true_positive));
        assert!(loaded.iter().all(|e| e.reason.is_none()));
    }

    #[test]
    fn stats_aggregates_per_detector() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("stats.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        collector.record(&sample_finding("1", "A"), true, None).unwrap();
        collector.record(&sample_finding("2", "A"), false, None).unwrap();
        collector.record(&sample_finding("3", "B"), true, None).unwrap();
        let stats = collector.stats().expect("stats");
        assert_eq!(stats.total, 3);
        assert_eq!(stats.true_positives, 2);
        assert_eq!(stats.false_positives, 1);
        assert_eq!(stats.by_detector.len(), 2);
        assert_eq!(stats.by_detector.get("A"), Some(&(1, 1)));
        assert_eq!(stats.by_detector.get("B"), Some(&(1, 0)));
    }

    #[test]
    fn stats_on_missing_file_is_empty() {
        let dir = TempDir::new().expect("create temp dir");
        let collector = FeedbackCollector::with_path(dir.path().join("missing.jsonl"));
        let stats = collector.stats().expect("stats");
        assert_eq!(stats.total, 0);
        assert_eq!(stats.true_positives, 0);
        assert_eq!(stats.false_positives, 0);
        assert!(stats.by_detector.is_empty());
    }

    #[test]
    fn dual_branch_fields_roundtrip_through_jsonl() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("dual.jsonl");
        let collector = FeedbackCollector::with_path(&path);
        let finding = Finding {
            id: "dual-1".into(),
            detector: "JwtWeakDetector".into(),
            severity: Severity::High,
            title: "Test dual-branch".into(),
            description: "Test".into(),
            alternative_branch: Some(AlternativeBranch {
                label: BranchLabel::Benign,
                severity: Severity::Info,
                title: "Hardened JWT".into(),
                description: "Algorithms allowlist present".into(),
                suggested_fix: None,
            }),
            prediction_reasons: vec![
                make_reason(PredictionReasonKind::EnclosingScope {
                    scope_kind: "function".into(),
                    name: "decode_token".into(),
                }),
                make_reason(PredictionReasonKind::KeywordArgument {
                    name: "algorithms".into(),
                    value: "[\"RS256\"]".into(),
                }),
            ],
            original_severity: Some(Severity::Critical),
            ..Default::default()
        };
        collector
            .record(&finding, false, Some("predictor mistake".into()))
            .expect("record");
        let loaded = collector.load_all().expect("load");
        assert_eq!(loaded.len(), 1);
        let row = &loaded[0];
        assert!(row.had_alternative_branch, "dual-branch flag set");
        assert_eq!(row.predicted_label.as_deref(), Some("real_bug"));
        assert_eq!(row.alternative_severity, Some(Severity::Info));
        assert_eq!(
            row.prediction_reason_kinds,
            vec!["EnclosingScope".to_string(), "KeywordArgument".to_string()]
        );
        assert_eq!(row.original_severity, Some(Severity::Critical));
    }

    #[test]
    fn predicted_label_inverts_alt_realbug_to_benign() {
        let finding = Finding {
            id: "x".into(),
            detector: "D".into(),
            severity: Severity::Info,
            title: "t".into(),
            description: String::new(),
            alternative_branch: Some(AlternativeBranch {
                label: BranchLabel::RealBug,
                severity: Severity::Critical,
                title: "alt".into(),
                description: String::new(),
                suggested_fix: None,
            }),
            ..Default::default()
        };
        let labeled = LabeledFinding::from_finding(&finding, true, None);
        assert_eq!(labeled.predicted_label.as_deref(), Some("benign"));
    }

    #[test]
    fn predicted_label_inverts_alt_benign_to_realbug() {
        let finding = Finding {
            id: "x".into(),
            detector: "D".into(),
            severity: Severity::High,
            title: "t".into(),
            description: String::new(),
            alternative_branch: Some(AlternativeBranch {
                label: BranchLabel::Benign,
                severity: Severity::Info,
                title: "alt".into(),
                description: String::new(),
                suggested_fix: None,
            }),
            ..Default::default()
        };
        let labeled = LabeledFinding::from_finding(&finding, true, None);
        assert_eq!(labeled.predicted_label.as_deref(), Some("real_bug"));
    }

    #[test]
    fn single_branch_finding_has_no_predicted_label() {
        let finding = Finding {
            id: "x".into(),
            detector: "D".into(),
            severity: Severity::High,
            title: "t".into(),
            description: String::new(),
            ..Default::default()
        };
        let labeled = LabeledFinding::from_finding(&finding, false, None);
        assert!(!labeled.had_alternative_branch);
        assert!(labeled.predicted_label.is_none());
        assert!(labeled.alternative_severity.is_none());
        assert!(labeled.prediction_reason_kinds.is_empty());
        assert!(labeled.original_severity.is_none());
    }

    #[test]
    fn reason_kind_covers_every_variant_exhaustively() {
        let all_variants: Vec<PredictionReasonKind> = vec![
            PredictionReasonKind::BundledCode,
            PredictionReasonKind::NonProductionPath,
            PredictionReasonKind::MultiDetectorAgreement { count: 2 },
            PredictionReasonKind::TestFixtureFile,
            PredictionReasonKind::HierarchicalLevel {
                level_name: "L1 Token".into(),
                z_score: 0.0,
            },
            PredictionReasonKind::KeywordArgument {
                name: "verify".into(),
                value: "False".into(),
            },
            PredictionReasonKind::FirstArgIdentifier {
                name: "password".into(),
            },
            PredictionReasonKind::EnclosingScope {
                scope_kind: "function".into(),
                name: "f".into(),
            },
            PredictionReasonKind::ImportPresence {
                module: "jwt".into(),
            },
            PredictionReasonKind::FilePath {
                hint: "/scripts".into(),
            },
            PredictionReasonKind::StructuralPattern {
                description: "x[:N]".into(),
            },
            PredictionReasonKind::Custom {
                description: "legacy".into(),
            },
        ];
        let reasons: Vec<PredictionReason> = all_variants.into_iter().map(make_reason).collect();
        let n_variants = reasons.len();
        let finding = Finding {
            id: "x".into(),
            detector: "D".into(),
            severity: Severity::Info,
            title: "t".into(),
            description: String::new(),
            prediction_reasons: reasons,
            ..Default::default()
        };
        let labeled = LabeledFinding::from_finding(&finding, true, None);
        assert_eq!(
            labeled.prediction_reason_kinds.len(),
            n_variants,
            "every variant should map to a discriminant string"
        );
        let mut sorted = labeled.prediction_reason_kinds.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(
            sorted.len(),
            n_variants,
            "discriminant strings should be unique per variant; got {:?}",
            labeled.prediction_reason_kinds
        );
    }

    #[test]
    fn legacy_jsonl_without_dual_branch_fields_deserializes() {
        let legacy_json = r#"{
"finding_id": "old-1",
"detector": "TestDetector",
"severity": "high",
"title": "Old finding",
"description": "From before Phase 3 prep",
"file_path": "/tmp/x.py",
"line_start": 10,
"is_true_positive": true,
"reason": null,
"timestamp": "2026-01-01T00:00:00Z"
}"#;
        let parsed: LabeledFinding =
            serde_json::from_str(legacy_json).expect("legacy entry must parse");
        assert!(!parsed.had_alternative_branch);
        assert!(parsed.predicted_label.is_none());
        assert!(parsed.alternative_severity.is_none());
        assert!(parsed.prediction_reason_kinds.is_empty());
        assert!(parsed.original_severity.is_none());
    }
}