pub mod correction;
pub mod cost;
pub mod scope;
pub mod state;
pub mod token_stats;
use serde::{Deserialize, Serialize};
use crate::session::CognitiveSession;
use self::correction::CorrectionStore;
use self::cost::{
normalize_cost, CostAccumulator, POOR_QUALITY_MEAN, QUALITY_DECLINE_MIN_DELTA,
QUALITY_DECLINE_WINDOW,
};
use self::scope::{ScopeTracker, DRIFT_WARN_THRESHOLD};
use self::token_stats::{confidence_with_fallback, TokenStatsAccumulator};
pub use self::state::RegulatorState;
/// An observation about an LLM interaction, consumed by [`Regulator::on_event`].
///
/// The sequence exercised by the tests is `TurnStart` → `Token`* →
/// `TurnComplete` → `QualityFeedback`, with `Cost` and `UserCorrection`
/// interleaved as they become available.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum LLMEvent {
    /// A new user turn begins. The regulator resets its per-turn token window,
    /// sets the scope task from `user_message`, recomputes the active topic
    /// cluster from the task tokens and runs the message through the session.
    TurnStart {
        /// The user's message that opens this turn.
        user_message: String,
    },
    /// One generated token. Only `logprob` is consumed by the regulator;
    /// `token` and `index` are carried along but ignored by `on_event`.
    ///
    /// NOTE: the module tests treat a `logprob` of exactly `0.0` as
    /// "logprobs unavailable": coverage stays at 0 and confidence falls back
    /// to a structural estimate of the completed response.
    Token {
        token: String,
        logprob: f64,
        index: usize,
    },
    /// The model finished its response. The text becomes the scope tracker's
    /// response and is buffered until a `QualityFeedback` consolidates it
    /// into the session.
    TurnComplete {
        full_response: String,
    },
    /// Resource usage report. `tokens_in`, `tokens_out` and `wallclock_ms` are
    /// accumulated by the cost tracker. A normalised cost derived from
    /// `tokens_out` and `wallclock_ms` is passed to the session's
    /// `track_cost`. `provider` is currently ignored.
    Cost {
        tokens_in: u32,
        tokens_out: u32,
        wallclock_ms: u32,
        provider: Option<String>,
    },
    /// The user corrected the assistant. Ignored unless `corrects_last` is
    /// true and the current turn has a non-empty topic cluster; otherwise it
    /// is recorded against that cluster. Repeated corrections on one cluster
    /// surface as `Decision::ProceduralWarning` (three in the tests).
    UserCorrection {
        correction_message: String,
        corrects_last: bool,
    },
    /// A quality score for the latest response (the tests use values in
    /// `[0, 1]`). It is always recorded for the cost/decline checks. If a
    /// response is buffered, it is consolidated into the session exactly
    /// once. `fragment_spans` is currently ignored.
    QualityFeedback {
        quality: f64,
        fragment_spans: Option<Vec<(usize, usize)>>,
    },
}
/// The regulator's recommendation, produced by [`Regulator::decide`].
///
/// `decide` evaluates conditions in priority order and returns the first match:
/// `CircuitBreak` (cost cap, then quality decline) → `ScopeDriftWarn` →
/// `ProceduralWarning` → `Continue`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[must_use]
pub enum Decision {
    /// Nothing to flag; proceed normally.
    Continue,
    /// Stop the current task.
    CircuitBreak {
        /// Which breaker fired, with the numbers that triggered it.
        reason: CircuitBreakReason,
        /// Human-readable advice for the caller.
        suggestion: String,
    },
    /// The latest response drifted away from the task.
    ScopeDriftWarn {
        /// Response tokens the scope tracker counts as drift.
        drift_tokens: Vec<String>,
        /// Drift score; at least `DRIFT_WARN_THRESHOLD` when emitted.
        drift_score: f64,
        /// Tokens extracted from the turn's task message.
        task_tokens: Vec<String>,
    },
    /// Spans of the response with low confidence.
    /// Not currently produced by `Regulator::decide`.
    LowConfidenceSpans {
        spans: Vec<ConfidenceSpan>,
    },
    /// The user has repeatedly corrected responses on the current topic
    /// cluster. `Regulator::decide` currently returns a single pattern here.
    ProceduralWarning {
        patterns: Vec<CorrectionPattern>,
    },
}
/// Why [`Decision::CircuitBreak`] was emitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum CircuitBreakReason {
    /// The output-token cap was reached while recent quality was below
    /// `POOR_QUALITY_MEAN`.
    CostCapReached {
        /// Total output tokens recorded so far.
        tokens_spent: u32,
        /// The configured cap (see `Regulator::with_cost_cap`).
        tokens_cap: u32,
        /// Mean quality over the last `QUALITY_DECLINE_WINDOW` feedbacks.
        mean_quality_last_n: f64,
    },
    /// Quality declined across the window and the recent mean stayed below
    /// `POOR_QUALITY_MEAN`.
    QualityDeclineNoRecovery {
        /// Window size used for the check (`QUALITY_DECLINE_WINDOW`).
        turns: usize,
        /// Decline reported by the cost accumulator's decline check.
        mean_delta: f64,
    },
    /// A failure repeated within one topic cluster.
    /// Not currently produced by `Regulator::decide`.
    RepeatedFailurePattern {
        cluster: String,
        failure_count: usize,
    },
}
/// A region of a response together with its confidence, carried by
/// [`Decision::LowConfidenceSpans`].
///
/// NOTE(review): not constructed in this module. `start_char`/`end_char` are
/// presumably offsets into the response text; confirm whether they are byte
/// or character offsets and whether `end_char` is exclusive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceSpan {
    pub start_char: usize,
    pub end_char: usize,
    /// Confidence score for the span.
    pub confidence: f64,
    /// Mean logprob of the tokens inside the span.
    pub mean_token_logprob: f64,
}
/// A behaviour the user has repeatedly corrected on one topic cluster.
///
/// Returned inside [`Decision::ProceduralWarning`] and persisted in
/// `RegulatorState::correction_patterns`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CorrectionPattern {
    /// User the pattern was learned for.
    pub user_id: String,
    /// Topic cluster the corrections were recorded against. Also used as the
    /// map key in exported state.
    pub topic_cluster: String,
    /// Generated name (the tests expect a `corrections_on_` prefix).
    pub pattern_name: String,
    /// Number of corrections recorded on the cluster.
    pub learned_from_turns: usize,
    /// Confidence in the pattern, as computed by the correction store.
    pub confidence: f64,
    /// Sample correction messages, most recent first. `serde(default)` lets
    /// snapshots that lack this field deserialize with an empty list.
    #[serde(default)]
    pub example_corrections: Vec<String>,
}
/// Per-user regulator that observes [`LLMEvent`]s and recommends a [`Decision`].
///
/// Wraps a [`CognitiveSession`] plus accumulators for token confidence, scope
/// drift, cost/quality and user corrections. Only the session's learned state
/// and the correction patterns survive `export`/`import`.
pub struct Regulator {
    /// Underlying cognitive session; also exposed via `session`/`session_mut`.
    session: CognitiveSession,
    /// User this regulator learns for.
    user_id: String,
    /// Response from the latest `TurnComplete`, held until a `QualityFeedback`
    /// consolidates it into the session.
    pending_response: Option<String>,
    /// Per-turn logprob statistics behind `confidence()`.
    token_stats: TokenStatsAccumulator,
    /// Task/response tokens used for scope-drift detection.
    scope: ScopeTracker,
    /// Cumulative cost totals, cost cap and quality history.
    cost: CostAccumulator,
    /// Correction history keyed by topic cluster.
    correction: CorrectionStore,
    /// Topic cluster built from the current turn's task tokens. Empty before
    /// the first `TurnStart`, or when the message yields no cluster.
    current_topic_cluster: String,
}
impl Regulator {
pub fn for_user(user_id: impl Into<String>) -> Self {
Self {
session: CognitiveSession::new(),
user_id: user_id.into(),
pending_response: None,
token_stats: TokenStatsAccumulator::new(),
scope: ScopeTracker::new(),
cost: CostAccumulator::new(),
correction: CorrectionStore::new(),
current_topic_cluster: String::new(),
}
}
pub fn with_cost_cap(mut self, cap_tokens: u32) -> Self {
self.cost.set_cap(cap_tokens);
self
}
pub fn on_event(&mut self, event: LLMEvent) {
match event {
LLMEvent::TurnStart { user_message } => {
self.token_stats.begin_turn();
self.scope.set_task(&user_message);
self.current_topic_cluster = crate::cognition::detector::build_topic_cluster(
self.scope.task_tokens(),
);
let _ = self.session.process_message(&user_message);
}
LLMEvent::Token { logprob, .. } => {
self.token_stats.on_token(logprob);
}
LLMEvent::TurnComplete { full_response } => {
self.scope.set_response(&full_response);
self.pending_response = Some(full_response);
}
LLMEvent::Cost {
tokens_in,
tokens_out,
wallclock_ms,
provider: _,
} => {
self.cost.record_cost(tokens_in, tokens_out, wallclock_ms);
let normalised = normalize_cost(tokens_out, wallclock_ms);
self.session.track_cost(normalised);
}
LLMEvent::UserCorrection {
correction_message,
corrects_last,
} => {
if !corrects_last {
return;
}
if self.current_topic_cluster.is_empty() {
return;
}
self.correction
.record_correction(&self.current_topic_cluster, correction_message);
}
LLMEvent::QualityFeedback { quality, .. } => {
self.cost.record_quality(quality);
if let Some(response) = self.pending_response.take() {
self.session.process_response(&response, quality);
}
}
}
}
pub fn confidence(&self) -> f64 {
confidence_with_fallback(&self.token_stats, self.pending_response.as_deref())
}
pub fn logprob_coverage(&self) -> f64 {
self.token_stats.logprob_coverage()
}
pub fn total_tokens_out(&self) -> u32 {
self.cost.total_tokens_out()
}
pub fn cost_cap_tokens(&self) -> u32 {
self.cost.cap_tokens()
}
pub fn decide(&self) -> Decision {
if self.cost.cap_reached() {
let mean_quality_last_n = self
.cost
.mean_quality_last_n(QUALITY_DECLINE_WINDOW)
.unwrap_or(1.0);
if mean_quality_last_n < POOR_QUALITY_MEAN {
return Decision::CircuitBreak {
reason: CircuitBreakReason::CostCapReached {
tokens_spent: self.cost.total_tokens_out(),
tokens_cap: self.cost.cap_tokens(),
mean_quality_last_n,
},
suggestion:
"Cost cap reached with poor recent quality. Ask the user to clarify scope or abandon this task."
.into(),
};
}
}
if let Some(delta) = self
.cost
.quality_decline_over_n(QUALITY_DECLINE_WINDOW, QUALITY_DECLINE_MIN_DELTA)
{
let mean = self
.cost
.mean_quality_last_n(QUALITY_DECLINE_WINDOW)
.unwrap_or(1.0);
if mean < POOR_QUALITY_MEAN {
return Decision::CircuitBreak {
reason: CircuitBreakReason::QualityDeclineNoRecovery {
turns: QUALITY_DECLINE_WINDOW,
mean_delta: delta,
},
suggestion:
"Response quality is declining without recovery. Consider redirecting or simplifying the task."
.into(),
};
}
}
if let Some(drift) = self.scope.drift_score() {
if drift >= DRIFT_WARN_THRESHOLD {
return Decision::ScopeDriftWarn {
drift_tokens: self.scope.drift_tokens(),
drift_score: drift,
task_tokens: self.scope.task_tokens().to_vec(),
};
}
}
if !self.current_topic_cluster.is_empty() {
if let Some(pattern) = self
.correction
.pattern_for(&self.user_id, &self.current_topic_cluster)
{
return Decision::ProceduralWarning {
patterns: vec![pattern],
};
}
}
Decision::Continue
}
pub fn export(&self) -> RegulatorState {
let correction_patterns = self
.correction
.all_patterns(&self.user_id)
.into_iter()
.map(|p| (p.topic_cluster.clone(), p))
.collect();
RegulatorState {
user_id: self.user_id.clone(),
learned: self.session.export_learned(),
correction_patterns,
}
}
pub fn import(state: RegulatorState) -> Self {
let mut session = CognitiveSession::new();
session.import_learned(state.learned);
let mut correction = CorrectionStore::new();
for (cluster, pattern) in &state.correction_patterns {
for text in pattern.example_corrections.iter().rev() {
correction.record_correction(cluster, text.clone());
}
}
Self {
session,
user_id: state.user_id,
pending_response: None,
token_stats: TokenStatsAccumulator::new(),
scope: ScopeTracker::new(),
cost: CostAccumulator::new(),
correction,
current_topic_cluster: String::new(),
}
}
pub fn user_id(&self) -> &str {
&self.user_id
}
pub fn session(&self) -> &CognitiveSession {
&self.session
}
pub fn session_mut(&mut self) -> &mut CognitiveSession {
&mut self.session
}
}
// Behavioural tests for `Regulator`: event handling, `decide` priority rules,
// confidence signals and export/import round-trips.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn for_user_starts_with_fresh_session() {
        let reg = Regulator::for_user("user_42");
        assert_eq!(reg.user_id(), "user_42");
        assert_eq!(reg.session().turn_count(), 0);
        assert!(reg.session().world_model().last_response_strategy.is_none());
    }

    #[test]
    fn turn_start_runs_cognitive_pipeline() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "Explain async Rust".into(),
        });
        assert_eq!(reg.session().turn_count(), 1);
    }

    #[test]
    fn turn_complete_without_feedback_does_not_learn() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "How to use async?".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response:
                "Here's a step-by-step guide:\n1. Add tokio\n2. Write async fn\n3. Await"
                    .into(),
        });
        assert!(reg.session().world_model().last_response_strategy.is_none());
    }

    #[test]
    fn quality_feedback_consolidates_buffered_response() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "How do I use async in Rust?".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response:
                "Here's a step-by-step guide:\n1. Add tokio\n2. Write async fn\n3. Await"
                    .into(),
        });
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.85,
            fragment_spans: None,
        });
        assert!(reg
            .session()
            .world_model()
            .last_response_strategy
            .is_some());
    }

    #[test]
    fn second_quality_feedback_after_drain_is_noop() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "How to async?".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "Step 1: Add tokio\nStep 2: async fn\nStep 3: await".into(),
        });
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.85,
            fragment_spans: None,
        });
        let learned_after_first = serde_json::to_string(&reg.session().world_model().learned)
            .expect("serialise LearnedState");
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.1,
            fragment_spans: None,
        });
        let learned_after_second = serde_json::to_string(&reg.session().world_model().learned)
            .expect("serialise LearnedState");
        assert_eq!(
            learned_after_first, learned_after_second,
            "drained buffer must not be re-consolidated by a stray feedback"
        );
    }

    #[test]
    fn quality_feedback_without_turn_complete_is_noop() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "How to async?".into(),
        });
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.5,
            fragment_spans: None,
        });
        assert!(reg.session().world_model().last_response_strategy.is_none());
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "Step 1: First\nStep 2: Second\nStep 3: Third".into(),
        });
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.8,
            fragment_spans: None,
        });
        assert!(reg.session().world_model().last_response_strategy.is_some());
    }

    #[test]
    fn inert_events_do_not_panic_or_mutate_turn_count() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart { user_message: "hi".into() });
        let before = reg.session().turn_count();
        reg.on_event(LLMEvent::Token {
            token: "hello".into(),
            logprob: -0.5,
            index: 0,
        });
        reg.on_event(LLMEvent::Cost {
            tokens_in: 10,
            tokens_out: 20,
            wallclock_ms: 500,
            provider: Some("anthropic".into()),
        });
        reg.on_event(LLMEvent::UserCorrection {
            correction_message: "don't add docstrings".into(),
            corrects_last: true,
        });
        assert_eq!(reg.session().turn_count(), before);
    }

    #[test]
    fn decide_returns_continue_by_default() {
        let reg = Regulator::for_user("user_a");
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn export_import_roundtrip_preserves_learning() {
        let mut reg = Regulator::for_user("user_persist");
        for i in 0..10 {
            reg.on_event(LLMEvent::TurnStart {
                user_message: format!("Rust question {i}"),
            });
            reg.on_event(LLMEvent::TurnComplete {
                full_response: "Step 1: First\nStep 2: Second\nStep 3: Third".into(),
            });
            reg.on_event(LLMEvent::QualityFeedback {
                quality: 0.85,
                fragment_spans: None,
            });
        }
        let snapshot = reg.export();
        assert_eq!(snapshot.user_id, "user_persist");
        assert!(
            !snapshot.learned.response_strategies.is_empty(),
            "training loop should have populated strategy EMA"
        );
        let restored = Regulator::import(snapshot.clone());
        assert_eq!(restored.user_id(), "user_persist");
        assert_eq!(
            restored.session().world_model().learned.response_strategies.len(),
            snapshot.learned.response_strategies.len(),
        );
    }

    #[test]
    fn roundtrip_via_serde_json() {
        let mut reg = Regulator::for_user("user_json");
        reg.on_event(LLMEvent::TurnStart { user_message: "hi".into() });
        reg.on_event(LLMEvent::TurnComplete { full_response: "hello".into() });
        reg.on_event(LLMEvent::QualityFeedback {
            quality: 0.7,
            fragment_spans: None,
        });
        let snapshot = reg.export();
        let json = serde_json::to_string(&snapshot).expect("serialise");
        let decoded: RegulatorState = serde_json::from_str(&json).expect("deserialise");
        assert_eq!(decoded.user_id, snapshot.user_id);
        assert_eq!(decoded.learned.tick, snapshot.learned.tick);
    }

    #[test]
    fn session_mut_exposes_path1_escape_hatch() {
        let mut reg = Regulator::for_user("user_a");
        let initial = reg.session().world_model().body_budget;
        reg.session_mut().track_cost(1.0);
        let after = reg.session().world_model().body_budget;
        assert!(after < initial, "track_cost via session_mut should deplete budget");
    }

    #[test]
    fn confidence_starts_neutral() {
        let reg = Regulator::for_user("user_a");
        assert!((reg.confidence() - 0.5).abs() < 1e-9);
        assert_eq!(reg.logprob_coverage(), 0.0);
    }

    #[test]
    fn confident_token_stream_raises_confidence() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart { user_message: "explain async".into() });
        for i in 0..15 {
            reg.on_event(LLMEvent::Token {
                token: "tok".into(),
                logprob: -0.2,
                index: i,
            });
        }
        assert!(
            reg.confidence() > 0.8,
            "confident token stream should drive confidence >0.8 (got {})",
            reg.confidence()
        );
        assert!((reg.logprob_coverage() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn gibberish_token_stream_lowers_confidence() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "asdfkjh qwer zxcvb".into(),
        });
        for i in 0..15 {
            reg.on_event(LLMEvent::Token {
                token: "?".into(),
                logprob: -6.5,
                index: i,
            });
        }
        assert!(
            reg.confidence() < 0.2,
            "high-NLL stream should drive confidence <0.2 (got {})",
            reg.confidence()
        );
    }

    #[test]
    fn turn_start_resets_token_window() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart { user_message: "q1".into() });
        for i in 0..10 {
            reg.on_event(LLMEvent::Token {
                token: "t".into(),
                logprob: -0.1,
                index: i,
            });
        }
        let confident = reg.confidence();
        assert!(confident > 0.8);
        reg.on_event(LLMEvent::TurnStart { user_message: "q2".into() });
        assert!(
            (reg.confidence() - 0.5).abs() < 1e-9,
            "new turn should reset confidence to neutral (got {})",
            reg.confidence()
        );
    }

    #[test]
    fn unavailable_logprobs_fall_back_to_structural_signal() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "how to refactor?".into(),
        });
        for i in 0..5 {
            reg.on_event(LLMEvent::Token {
                token: "t".into(),
                logprob: 0.0,
                index: i,
            });
        }
        reg.on_event(LLMEvent::TurnComplete {
            full_response:
                "Here's the refactored function. It preserves the original signature."
                    .into(),
        });
        assert!(
            (reg.confidence() - 0.7).abs() < 0.01,
            "structural fallback on unremarkable response should be ~0.7 (got {})",
            reg.confidence()
        );
        assert_eq!(reg.logprob_coverage(), 0.0);
    }

    #[test]
    fn unavailable_logprobs_short_response_fallback_is_low() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "can you do X?".into(),
        });
        reg.on_event(LLMEvent::Token {
            token: "No.".into(),
            logprob: 0.0,
            index: 0,
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "No.".into(),
        });
        assert!(
            reg.confidence() < 0.5,
            "short unremarkable response should fall in low band (got {})",
            reg.confidence()
        );
    }

    #[test]
    fn decide_emits_scope_drift_warn_on_plan_example() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "add logging and error handling".into(),
        });
        match reg.decide() {
            Decision::ScopeDriftWarn {
                drift_score,
                drift_tokens,
                task_tokens,
            } => {
                assert!(
                    drift_score > 0.3,
                    "plan target requires drift > 0.3 (got {drift_score})"
                );
                assert!(!drift_tokens.is_empty(), "drift_tokens must be populated");
                assert!(!task_tokens.is_empty(), "task_tokens must be populated");
            }
            other => panic!("plan example must emit ScopeDriftWarn, got {other:?}"),
        }
    }

    #[test]
    fn decide_continues_when_response_stays_on_task() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor the async function".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "refactor async function".into(),
        });
        assert!(
            matches!(reg.decide(), Decision::Continue),
            "on-task response must not emit ScopeDriftWarn"
        );
    }

    #[test]
    fn decide_continues_before_turn_complete() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn turn_start_resets_scope_state() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "add logging and telemetry".into(),
        });
        assert!(matches!(reg.decide(), Decision::ScopeDriftWarn { .. }));
        reg.on_event(LLMEvent::TurnStart {
            user_message: "explain tokio runtime".into(),
        });
        assert!(
            matches!(reg.decide(), Decision::Continue),
            "new turn must clear stale response before new response arrives"
        );
    }

    #[test]
    fn cost_event_accumulates_totals() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::Cost {
            tokens_in: 100,
            tokens_out: 400,
            wallclock_ms: 2_000,
            provider: None,
        });
        reg.on_event(LLMEvent::Cost {
            tokens_in: 50,
            tokens_out: 200,
            wallclock_ms: 1_500,
            provider: Some("anthropic".into()),
        });
        assert_eq!(reg.total_tokens_out(), 600);
    }

    #[test]
    fn cost_event_feeds_track_cost_via_normalisation() {
        let mut reg = Regulator::for_user("user_a");
        let initial_budget = reg.session().world_model().body_budget;
        reg.on_event(LLMEvent::Cost {
            tokens_in: 0,
            tokens_out: cost::TYPICAL_TURN_TOKENS_OUT,
            wallclock_ms: cost::TYPICAL_TURN_WALLCLOCK_MS,
            provider: None,
        });
        let after_budget = reg.session().world_model().body_budget;
        assert!(
            after_budget < initial_budget,
            "Cost event must deplete body_budget via track_cost (before {initial_budget}, after {after_budget})"
        );
    }

    #[test]
    fn quality_feedback_records_in_cost_accumulator() {
        let mut reg = Regulator::for_user("user_a");
        for q in [0.9, 0.8, 0.7] {
            reg.on_event(LLMEvent::QualityFeedback {
                quality: q,
                fragment_spans: None,
            });
        }
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn decide_emits_circuit_break_on_cost_cap_with_poor_quality() {
        let mut reg = Regulator::for_user("user_a").with_cost_cap(1_000);
        for _ in 0..3 {
            reg.on_event(LLMEvent::Cost {
                tokens_in: 0,
                tokens_out: 400,
                wallclock_ms: 1_000,
                provider: None,
            });
            reg.on_event(LLMEvent::QualityFeedback {
                quality: 0.3,
                fragment_spans: None,
            });
        }
        match reg.decide() {
            Decision::CircuitBreak { reason, .. } => match reason {
                CircuitBreakReason::CostCapReached {
                    tokens_spent,
                    tokens_cap,
                    mean_quality_last_n,
                } => {
                    assert_eq!(tokens_spent, 1_200);
                    assert_eq!(tokens_cap, 1_000);
                    assert!(mean_quality_last_n < 0.5);
                }
                other => panic!("expected CostCapReached, got {other:?}"),
            },
            other => panic!("expected CircuitBreak, got {other:?}"),
        }
    }

    #[test]
    fn decide_does_not_emit_circuit_break_when_quality_recovers() {
        let mut reg = Regulator::for_user("user_a").with_cost_cap(1_000);
        for _ in 0..3 {
            reg.on_event(LLMEvent::Cost {
                tokens_in: 0,
                tokens_out: 400,
                wallclock_ms: 1_000,
                provider: None,
            });
            reg.on_event(LLMEvent::QualityFeedback {
                quality: 0.9,
                fragment_spans: None,
            });
        }
        assert!(
            matches!(reg.decide(), Decision::Continue),
            "high-quality agent at over-cap must not be halted"
        );
    }

    #[test]
    fn decide_emits_circuit_break_on_quality_decline_no_recovery() {
        let mut reg = Regulator::for_user("user_a").with_cost_cap(u32::MAX);
        for q in [0.8, 0.5, 0.3, 0.2, 0.15] {
            reg.on_event(LLMEvent::QualityFeedback {
                quality: q,
                fragment_spans: None,
            });
        }
        match reg.decide() {
            Decision::CircuitBreak {
                reason: CircuitBreakReason::QualityDeclineNoRecovery { turns, mean_delta },
                ..
            } => {
                assert_eq!(turns, cost::QUALITY_DECLINE_WINDOW);
                assert!(mean_delta >= cost::QUALITY_DECLINE_MIN_DELTA);
            }
            other => panic!("expected QualityDeclineNoRecovery, got {other:?}"),
        }
    }

    #[test]
    fn decide_priority_circuit_break_dominates_scope_drift() {
        let mut reg = Regulator::for_user("user_a").with_cost_cap(500);
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "add logging and error handling".into(),
        });
        for _ in 0..3 {
            reg.on_event(LLMEvent::Cost {
                tokens_in: 0,
                tokens_out: 400,
                wallclock_ms: 0,
                provider: None,
            });
            reg.on_event(LLMEvent::QualityFeedback {
                quality: 0.2,
                fragment_spans: None,
            });
        }
        match reg.decide() {
            Decision::CircuitBreak {
                reason: CircuitBreakReason::CostCapReached { .. },
                ..
            } => {}
            other => panic!(
                "priority rule must emit CircuitBreak CostCapReached, got {other:?}"
            ),
        }
    }

    #[test]
    fn with_cost_cap_preserves_prior_accumulation() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::Cost {
            tokens_in: 100,
            tokens_out: 200,
            wallclock_ms: 500,
            provider: None,
        });
        assert_eq!(reg.total_tokens_out(), 200);
        let reg = reg.with_cost_cap(10_000);
        assert_eq!(reg.total_tokens_out(), 200, "cap change must not reset counters");
        assert_eq!(reg.cost_cap_tokens(), 10_000);
    }

    /// Starts a turn on `task_message` and records three on-target corrections.
    /// Three is the count the tests below use to reach the procedural-warning
    /// threshold (two stays below it).
    fn drive_three_corrections_on(reg: &mut Regulator, task_message: &str) {
        reg.on_event(LLMEvent::TurnStart {
            user_message: task_message.into(),
        });
        for msg in [
            "don't add logging",
            "stop adding logging please",
            "no more logs",
        ] {
            reg.on_event(LLMEvent::UserCorrection {
                correction_message: msg.into(),
                corrects_last: true,
            });
        }
    }

    #[test]
    fn user_correction_requires_corrects_last_true() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        for _ in 0..5 {
            reg.on_event(LLMEvent::UserCorrection {
                correction_message: "something entirely different".into(),
                corrects_last: false,
            });
        }
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn user_correction_dropped_when_no_active_cluster() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "is it ok?".into(),
        });
        for _ in 0..5 {
            reg.on_event(LLMEvent::UserCorrection {
                correction_message: "never do X".into(),
                corrects_last: true,
            });
        }
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn decide_emits_procedural_warning_at_pattern_threshold() {
        let mut reg = Regulator::for_user("user_42");
        drive_three_corrections_on(&mut reg, "refactor this function to be async");
        match reg.decide() {
            Decision::ProceduralWarning { patterns } => {
                assert_eq!(patterns.len(), 1);
                let pattern = &patterns[0];
                assert_eq!(pattern.user_id, "user_42");
                assert_eq!(pattern.learned_from_turns, 3);
                assert!(pattern.pattern_name.starts_with("corrections_on_"));
                assert_eq!(pattern.example_corrections.len(), 3);
                assert_eq!(pattern.example_corrections[0], "no more logs");
            }
            other => panic!("expected ProceduralWarning, got {other:?}"),
        }
    }

    #[test]
    fn decide_does_not_emit_procedural_warning_below_threshold() {
        let mut reg = Regulator::for_user("user_a");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        for msg in ["don't add logging", "no more logs"] {
            reg.on_event(LLMEvent::UserCorrection {
                correction_message: msg.into(),
                corrects_last: true,
            });
        }
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn procedural_warning_fires_only_on_matching_cluster() {
        let mut reg = Regulator::for_user("user_a");
        drive_three_corrections_on(&mut reg, "refactor this function to be async");
        assert!(matches!(reg.decide(), Decision::ProceduralWarning { .. }));
        reg.on_event(LLMEvent::TurnStart {
            user_message: "explain docker containers to me".into(),
        });
        assert!(matches!(reg.decide(), Decision::Continue));
    }

    #[test]
    fn decide_priority_scope_drift_dominates_procedural_warning() {
        let mut reg = Regulator::for_user("user_a");
        drive_three_corrections_on(&mut reg, "refactor this function to be async");
        reg.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        reg.on_event(LLMEvent::TurnComplete {
            full_response: "add logging and error handling".into(),
        });
        assert!(matches!(reg.decide(), Decision::ScopeDriftWarn { .. }));
    }

    #[test]
    fn export_includes_correction_patterns() {
        let mut reg = Regulator::for_user("user_a");
        drive_three_corrections_on(&mut reg, "refactor this function to be async");
        let snapshot = reg.export();
        assert!(
            !snapshot.correction_patterns.is_empty(),
            "at-threshold cluster must appear in exported patterns"
        );
        let cluster_key = snapshot
            .correction_patterns
            .keys()
            .next()
            .expect("one pattern");
        let pattern = &snapshot.correction_patterns[cluster_key];
        assert_eq!(pattern.learned_from_turns, 3);
        assert_eq!(pattern.example_corrections.len(), 3);
    }

    #[test]
    fn import_restores_patterns_via_example_replay() {
        let mut source = Regulator::for_user("user_persist");
        drive_three_corrections_on(&mut source, "refactor this function to be async");
        let snapshot = source.export();
        let json = serde_json::to_string(&snapshot).expect("serialise");
        let decoded: RegulatorState = serde_json::from_str(&json).expect("deserialise");
        let mut restored = Regulator::import(decoded);
        restored.on_event(LLMEvent::TurnStart {
            user_message: "refactor this function to be async".into(),
        });
        match restored.decide() {
            Decision::ProceduralWarning { patterns } => {
                assert_eq!(patterns.len(), 1);
                assert_eq!(patterns[0].learned_from_turns, 3);
                assert_eq!(patterns[0].example_corrections.len(), 3);
            }
            other => panic!("restored regulator must fire ProceduralWarning, got {other:?}"),
        }
    }

    #[test]
    fn import_preserves_example_corrections_order() {
        let mut source = Regulator::for_user("user_order");
        drive_three_corrections_on(&mut source, "refactor this function to be async");
        let snapshot_before = source.export();
        let (cluster, pattern_before) = snapshot_before
            .correction_patterns
            .iter()
            .next()
            .expect("one pattern after 3 corrections");
        let before = pattern_before.example_corrections.clone();
        let json = serde_json::to_string(&snapshot_before).expect("serialise");
        let decoded: RegulatorState = serde_json::from_str(&json).expect("deserialise");
        let restored = Regulator::import(decoded);
        let snapshot_after = restored.export();
        let pattern_after = snapshot_after
            .correction_patterns
            .get(cluster)
            .expect("pattern restored under same cluster key");
        assert_eq!(pattern_after.example_corrections, before);
    }

    #[test]
    fn legacy_snapshot_loads_without_correction_patterns() {
        let legacy = r#"{
"user_id": "legacy",
"learned": {
"gain_mode": "neutral",
"tick": 0,
"response_success": {},
"response_strategies": {}
}
}"#;
        let state: RegulatorState =
            serde_json::from_str(legacy).expect("legacy snapshot must load");
        let reg = Regulator::import(state);
        assert_eq!(reg.user_id(), "legacy");
        assert!(matches!(reg.decide(), Decision::Continue));
    }
}