use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::types::world::LearnedState;
use super::CorrectionPattern;
/// Serialisable per-user state: the user's id, their [`LearnedState`], and a
/// map of [`CorrectionPattern`]s.
///
/// The type derives `Serialize` and `Deserialize`, so it can be written to and
/// restored from any serde-supported format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegulatorState {
    /// Identifier of the user this state belongs to.
    pub user_id: String,

    /// Learned state stored alongside the user id.
    pub learned: LearnedState,

    /// Correction patterns, looked up by a string key.
    ///
    /// NOTE(review): the key presumably matches each pattern's topic cluster;
    /// confirm against the code that populates this map.
    ///
    /// `#[serde(default)]` makes the field optional on input: data that omits
    /// it deserialises to an empty map instead of failing.
    #[serde(default)]
    pub correction_patterns: HashMap<String, CorrectionPattern>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `state` as JSON and decodes it again, panicking if either
    /// direction fails.
    fn json_round_trip(state: &RegulatorState) -> RegulatorState {
        let encoded = serde_json::to_string(state).expect("serialise");
        serde_json::from_str(&encoded).expect("deserialise")
    }

    /// A state without correction patterns keeps its user id and still has an
    /// empty pattern map after a JSON round trip.
    #[test]
    fn empty_state_round_trips() {
        let original = RegulatorState {
            user_id: String::from("alice"),
            learned: LearnedState::default(),
            correction_patterns: HashMap::default(),
        };

        let decoded = json_round_trip(&original);

        assert_eq!(decoded.user_id, "alice");
        assert!(decoded.correction_patterns.is_empty());
    }

    /// JSON that has no `correction_patterns` key must still load, with the
    /// map falling back to empty.
    #[test]
    fn pre_session_20_snapshot_deserialises_with_empty_patterns() {
        let legacy_json = r#"{
"user_id": "legacy_user",
"learned": {
"gain_mode": "neutral",
"tick": 0,
"response_success": {},
"response_strategies": {}
}
}"#;

        let loaded = serde_json::from_str::<RegulatorState>(legacy_json)
            .expect("legacy snapshot must load");

        assert_eq!(loaded.user_id, "legacy_user");
        assert!(
            loaded.correction_patterns.is_empty(),
            "missing field must default to empty map"
        );
    }

    /// A populated pattern map survives a JSON round trip: the entry is still
    /// found under its key, with its turn count and example corrections intact.
    #[test]
    fn correction_patterns_survive_round_trip() {
        // Used both as the map key and as the pattern's topic cluster.
        const TOPIC: &str = "refactor+async";

        let pattern = CorrectionPattern {
            user_id: "alice".into(),
            topic_cluster: TOPIC.into(),
            pattern_name: "corrections_on_refactor+async".into(),
            learned_from_turns: 3,
            confidence: 0.15,
            example_corrections: vec![
                "no more logs".into(),
                "stop adding logging".into(),
                "don't add logging".into(),
            ],
        };

        let mut correction_patterns = HashMap::new();
        correction_patterns.insert(TOPIC.to_owned(), pattern);

        let original = RegulatorState {
            user_id: "alice".into(),
            learned: LearnedState::default(),
            correction_patterns,
        };

        let decoded = json_round_trip(&original);
        let restored = decoded
            .correction_patterns
            .get(TOPIC)
            .expect("pattern should round-trip");

        assert_eq!(restored.learned_from_turns, 3);
        assert_eq!(restored.example_corrections.len(), 3);
        assert_eq!(restored.example_corrections[0], "no more logs");
    }
}