use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::CorrectionPattern;
use crate::math::clamp;
/// Minimum number of stored corrections a cluster needs before
/// [`CorrectionStore::pattern_for`] reports a pattern for it.
pub const MIN_CORRECTIONS_FOR_PATTERN: usize = 3;
/// Maximum number of the most recent corrections copied into a pattern's
/// `example_corrections`.
pub const MAX_EXAMPLE_CORRECTIONS: usize = 3;
/// Per-cluster cap on stored corrections (oldest entries are evicted once it
/// is exceeded). Also used as the denominator of a pattern's confidence.
pub const MAX_CORRECTIONS_PER_CLUSTER: usize = 20;
/// Per-cluster history of user correction texts, from which
/// `CorrectionPattern`s are derived on demand.
///
/// Derives `Serialize`/`Deserialize`, so the whole history can be round-tripped
/// through a serde format.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorrectionStore {
// Cluster key -> correction texts in insertion order (oldest first).
// With derived serde impls this field name is the serialized key, so renaming
// it changes the serialized format.
records: HashMap<String, Vec<String>>,
}
impl CorrectionStore {
    /// Creates an empty store with no recorded corrections.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `correction_text` to the history of `cluster`.
    ///
    /// An empty `cluster` is ignored and nothing is recorded. Each cluster
    /// retains at most [`MAX_CORRECTIONS_PER_CLUSTER`] corrections; when the
    /// cap is exceeded, the oldest entries are evicted so only the most recent
    /// ones remain.
    pub fn record_correction(&mut self, cluster: &str, correction_text: impl Into<String>) {
        if cluster.is_empty() {
            return;
        }
        let list = self.records.entry(cluster.to_string()).or_default();
        list.push(correction_text.into());
        // Evict every surplus entry, not just one: a list restored through
        // `Deserialize` is not length-checked and may already exceed the cap,
        // and removing a single element per insert would never bring it back
        // within bounds.
        if list.len() > MAX_CORRECTIONS_PER_CLUSTER {
            let excess = list.len() - MAX_CORRECTIONS_PER_CLUSTER;
            list.drain(..excess);
        }
    }

    /// Returns how many corrections are currently stored for `cluster`
    /// (0 for a cluster that has never been recorded).
    pub fn count_for(&self, cluster: &str) -> usize {
        self.records.get(cluster).map_or(0, Vec::len)
    }

    /// Builds the [`CorrectionPattern`] for `cluster`, attributed to `user_id`.
    ///
    /// Returns `None` when the cluster is unknown or has fewer than
    /// [`MIN_CORRECTIONS_FOR_PATTERN`] stored corrections. Otherwise:
    /// - `example_corrections` holds up to [`MAX_EXAMPLE_CORRECTIONS`] of the
    ///   most recent corrections, newest first;
    /// - `confidence` is the stored count divided by
    ///   [`MAX_CORRECTIONS_PER_CLUSTER`], passed through `clamp(_, 0.0, 1.0)`;
    /// - `learned_from_turns` is the stored count;
    /// - `pattern_name` is `"corrections_on_<cluster>"`.
    pub fn pattern_for(&self, user_id: &str, cluster: &str) -> Option<CorrectionPattern> {
        let list = self.records.get(cluster)?;
        let turns = list.len();
        if turns < MIN_CORRECTIONS_FOR_PATTERN {
            return None;
        }
        // Lists are stored oldest-first (push appends, eviction drains the
        // front), so iterate in reverse to surface the newest corrections.
        let example_corrections: Vec<String> = list
            .iter()
            .rev()
            .take(MAX_EXAMPLE_CORRECTIONS)
            .cloned()
            .collect();
        let confidence = clamp(
            turns as f64 / MAX_CORRECTIONS_PER_CLUSTER as f64,
            0.0,
            1.0,
        );
        Some(CorrectionPattern {
            user_id: user_id.to_string(),
            topic_cluster: cluster.to_string(),
            pattern_name: format!("corrections_on_{cluster}"),
            learned_from_turns: turns,
            confidence,
            example_corrections,
        })
    }

    /// Returns the pattern of every qualifying cluster for `user_id`, ordered
    /// by cluster key.
    ///
    /// Clusters with fewer than [`MIN_CORRECTIONS_FOR_PATTERN`] corrections
    /// are skipped.
    pub fn all_patterns(&self, user_id: &str) -> Vec<CorrectionPattern> {
        let mut clusters: Vec<&String> = self.records.keys().collect();
        // HashMap iteration order is unspecified; sorting makes the output
        // deterministic. Keys are unique, so an unstable sort is equivalent.
        clusters.sort_unstable();
        clusters
            .into_iter()
            .filter_map(|cluster| self.pattern_for(user_id, cluster))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    //! Behavioral tests for `CorrectionStore`: thresholds, example ordering,
    //! per-cluster bounding, deterministic listing and serde round-tripping.
    use super::*;

    #[test]
    fn empty_store_has_no_patterns() {
        let store = CorrectionStore::new();
        assert_eq!(store.count_for("any"), 0);
        assert!(store.pattern_for("u1", "any").is_none());
        assert!(store.all_patterns("u1").is_empty());
    }

    #[test]
    fn record_correction_accumulates_count() {
        let mut store = CorrectionStore::new();
        store.record_correction("refactor+async", "don't add logging");
        store.record_correction("refactor+async", "stop adding logging");
        assert_eq!(store.count_for("refactor+async"), 2);
        assert_eq!(store.count_for("other+cluster"), 0);
    }

    #[test]
    fn empty_cluster_drops_correction() {
        let mut store = CorrectionStore::new();
        store.record_correction("", "some correction text");
        assert_eq!(store.count_for(""), 0);
        assert!(store.all_patterns("u1").is_empty());
    }

    #[test]
    fn pattern_below_threshold_is_none() {
        let mut store = CorrectionStore::new();
        store.record_correction("cluster_a", "first correction");
        store.record_correction("cluster_a", "second correction");
        assert!(store.pattern_for("u1", "cluster_a").is_none());
        assert!(store.all_patterns("u1").is_empty());
    }

    #[test]
    fn pattern_at_threshold_emerges() {
        let mut store = CorrectionStore::new();
        store.record_correction("refactor+async", "don't add logging");
        store.record_correction("refactor+async", "stop adding logging please");
        store.record_correction("refactor+async", "no more logs");
        let pattern = store
            .pattern_for("user_42", "refactor+async")
            .expect("three corrections on same cluster must emerge as pattern");
        assert_eq!(pattern.user_id, "user_42");
        assert_eq!(pattern.topic_cluster, "refactor+async");
        assert_eq!(pattern.pattern_name, "corrections_on_refactor+async");
        assert_eq!(pattern.learned_from_turns, 3);
        assert!((pattern.confidence - 3.0 / MAX_CORRECTIONS_PER_CLUSTER as f64).abs() < 1e-9);
        // Pin the full newest-first order, not just the first example.
        assert_eq!(
            pattern.example_corrections,
            vec!["no more logs", "stop adding logging please", "don't add logging"]
        );
    }

    #[test]
    fn example_corrections_capped_at_max() {
        let mut store = CorrectionStore::new();
        for i in 0..10 {
            store.record_correction("cluster", format!("correction {i}"));
        }
        let pattern = store
            .pattern_for("u1", "cluster")
            .expect("10 corrections exceed threshold");
        assert_eq!(pattern.learned_from_turns, 10);
        assert_eq!(pattern.example_corrections.len(), MAX_EXAMPLE_CORRECTIONS);
        assert_eq!(pattern.example_corrections[0], "correction 9");
        // The window is the most recent corrections, newest first.
        let expected: Vec<String> = (0..10)
            .rev()
            .take(MAX_EXAMPLE_CORRECTIONS)
            .map(|i| format!("correction {i}"))
            .collect();
        assert_eq!(pattern.example_corrections, expected);
    }

    #[test]
    fn confidence_saturates_at_max_records() {
        let mut store = CorrectionStore::new();
        for i in 0..(MAX_CORRECTIONS_PER_CLUSTER + 5) {
            store.record_correction("cluster", format!("correction {i}"));
        }
        let pattern = store.pattern_for("u1", "cluster").expect("above threshold");
        assert!((pattern.confidence - 1.0).abs() < 1e-9);
        assert_eq!(pattern.learned_from_turns, MAX_CORRECTIONS_PER_CLUSTER);
    }

    #[test]
    fn all_patterns_returns_sorted_stable() {
        let mut store = CorrectionStore::new();
        for cluster in ["zulu", "alpha", "mike"] {
            for i in 0..3 {
                store.record_correction(cluster, format!("{cluster} correction {i}"));
            }
        }
        store.record_correction("below", "only once");
        let patterns = store.all_patterns("u1");
        let clusters: Vec<&str> = patterns.iter().map(|p| p.topic_cluster.as_str()).collect();
        assert_eq!(clusters, vec!["alpha", "mike", "zulu"]);
        assert!(patterns.iter().all(|p| p.user_id == "u1"));
    }

    #[test]
    fn records_per_cluster_bounded() {
        let mut store = CorrectionStore::new();
        for i in 0..(MAX_CORRECTIONS_PER_CLUSTER * 2) {
            store.record_correction("cluster", format!("correction {i}"));
        }
        assert_eq!(store.count_for("cluster"), MAX_CORRECTIONS_PER_CLUSTER);
        // The oldest entries are the ones evicted: the earliest survivor is
        // exactly MAX_CORRECTIONS_PER_CLUSTER insertions back from the newest.
        assert_eq!(
            store.records["cluster"][0],
            format!("correction {}", MAX_CORRECTIONS_PER_CLUSTER)
        );
        let pattern = store.pattern_for("u1", "cluster").unwrap();
        assert_eq!(pattern.learned_from_turns, MAX_CORRECTIONS_PER_CLUSTER);
        assert_eq!(
            pattern.example_corrections[0],
            format!("correction {}", MAX_CORRECTIONS_PER_CLUSTER * 2 - 1)
        );
    }

    #[test]
    fn round_trip_via_serde_json() {
        let mut store = CorrectionStore::new();
        store.record_correction("cluster_a", "c1");
        store.record_correction("cluster_a", "c2");
        store.record_correction("cluster_a", "c3");
        let json = serde_json::to_string(&store).expect("serialise");
        let restored: CorrectionStore =
            serde_json::from_str(&json).expect("deserialise");
        assert_eq!(restored.count_for("cluster_a"), 3);
        let before = store
            .pattern_for("u1", "cluster_a")
            .expect("original store has pattern");
        let after = restored
            .pattern_for("u1", "cluster_a")
            .expect("restored store keeps pattern");
        assert_eq!(after.example_corrections, before.example_corrections);
        assert_eq!(after.learned_from_turns, before.learned_from_turns);
    }
}