pub mod correction;
pub mod cost;
pub mod otel;
pub mod scope;
pub mod state;
pub mod token_stats;
pub mod tools;
use serde::{Deserialize, Serialize};
use crate::session::CognitiveSession;
use self::correction::CorrectionStore;
use self::cost::{
normalize_cost, CostAccumulator, POOR_QUALITY_MEAN, QUALITY_DECLINE_MIN_DELTA,
QUALITY_DECLINE_WINDOW,
};
use self::scope::{ScopeTracker, DRIFT_WARN_THRESHOLD};
use self::token_stats::{confidence_with_fallback, TokenStatsAccumulator};
use self::tools::ToolStatsAccumulator;
pub use self::state::RegulatorState;
/// Telemetry events the host application feeds into [`Regulator::on_event`].
///
/// Events cover the lifecycle of one LLM turn (start, streamed tokens,
/// completion) plus out-of-band signals: spend, explicit user corrections,
/// quality feedback, and tool activity.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum LLMEvent {
    /// A new user turn begins; downstream per-turn accumulators are reset.
    TurnStart {
        user_message: String,
    },
    /// One streamed token with its log-probability and stream index.
    /// NOTE(review): `logprob == 0.0` appears to be treated as "logprobs
    /// unavailable" by the accumulator (see the fallback tests) — confirm
    /// against `TokenStatsAccumulator`.
    Token {
        token: String,
        logprob: f64,
        index: usize,
    },
    /// The assistant's full response for the turn; buffered until quality
    /// feedback arrives.
    TurnComplete {
        full_response: String,
    },
    /// Spend report for a model call; `provider` is currently ignored.
    Cost {
        tokens_in: u32,
        tokens_out: u32,
        wallclock_ms: u32,
        provider: Option<String>,
    },
    /// Explicit user correction; only recorded when `corrects_last` is true
    /// and a topic cluster is active.
    UserCorrection {
        correction_message: String,
        corrects_last: bool,
    },
    /// Quality score in [0, 1] for the last completed turn.
    /// `fragment_spans` is currently unused by the regulator.
    QualityFeedback {
        quality: f64,
        fragment_spans: Option<Vec<(usize, usize)>>,
    },
    /// The agent invoked a tool (loop detection input).
    ToolCall {
        tool_name: String,
        args_json: Option<String>,
    },
    /// Outcome of a tool invocation (observability counters).
    ToolResult {
        tool_name: String,
        success: bool,
        duration_ms: u64,
        error_summary: Option<String>,
    },
}
/// Outcome of [`Regulator::decide`]; variants are listed in descending
/// priority (CircuitBreak dominates ScopeDriftWarn dominates
/// ProceduralWarning — see `decide` and the priority tests).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
#[must_use]
pub enum Decision {
    /// Nothing of note; let the agent proceed.
    Continue,
    /// Hard stop recommended; `suggestion` is human-readable advice.
    CircuitBreak {
        reason: CircuitBreakReason,
        suggestion: String,
    },
    /// The response drifted away from the stated task.
    ScopeDriftWarn {
        drift_tokens: Vec<String>,
        drift_score: f64,
        task_tokens: Vec<String>,
    },
    /// NOTE(review): currently never produced by `decide` in this file —
    /// presumably emitted by another code path; confirm before relying on it.
    LowConfidenceSpans {
        spans: Vec<ConfidenceSpan>,
    },
    /// The user has repeatedly corrected responses on the active topic.
    ProceduralWarning {
        patterns: Vec<CorrectionPattern>,
    },
}
/// Why a [`Decision::CircuitBreak`] was issued.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum CircuitBreakReason {
    /// Token budget exhausted while recent quality stayed poor.
    CostCapReached {
        tokens_spent: u32,
        tokens_cap: u32,
        mean_quality_last_n: f64,
    },
    /// Quality fell across the last `turns` feedbacks without recovering.
    QualityDeclineNoRecovery {
        turns: usize,
        mean_delta: f64,
    },
    /// NOTE(review): not produced by `decide` in this file; presumably
    /// reserved for a failure-clustering path — confirm.
    RepeatedFailurePattern {
        cluster: String,
        failure_count: usize,
    },
    /// The same tool was called consecutively past the loop threshold.
    RepeatedToolCallLoop {
        tool_name: String,
        consecutive_count: usize,
    },
}
/// A character range of the response with its confidence estimate,
/// derived from mean token log-probability over that span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceSpan {
    pub start_char: usize,
    pub end_char: usize,
    pub confidence: f64,
    pub mean_token_logprob: f64,
}

/// A learned "this user keeps correcting me about X" pattern, emitted via
/// [`Decision::ProceduralWarning`] and persisted in [`RegulatorState`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CorrectionPattern {
    pub user_id: String,
    pub topic_cluster: String,
    pub pattern_name: String,
    // Number of correction turns this pattern was learned from.
    pub learned_from_turns: usize,
    pub confidence: f64,
    // Raw correction messages, most-recent-first (see the threshold test);
    // defaulted on deserialize so older snapshots still load.
    #[serde(default)]
    pub example_corrections: Vec<String>,
}
/// Event-driven supervisor for a single user's LLM conversation: consumes
/// [`LLMEvent`]s, accumulates per-turn and cross-turn signals, and answers
/// [`Regulator::decide`] with the highest-priority [`Decision`].
pub struct Regulator {
    // Cognitive pipeline that learns from messages/feedback.
    session: CognitiveSession,
    user_id: String,
    // Response buffered at TurnComplete, drained by QualityFeedback.
    pending_response: Option<String>,
    // Per-turn token logprob window (reset on TurnStart).
    token_stats: TokenStatsAccumulator,
    // Task-vs-response drift tracking.
    scope: ScopeTracker,
    // Token/wallclock spend plus quality history.
    cost: CostAccumulator,
    // Per-topic correction history → CorrectionPattern.
    correction: CorrectionStore,
    // Cluster of the current task; empty when none detected.
    current_topic_cluster: String,
    // Tool call/result counters and loop detection (reset on TurnStart).
    tools: ToolStatsAccumulator,
    // When set, a same-cluster TurnStart within this window after a
    // TurnComplete counts as an implicit correction.
    implicit_correction_window: Option<std::time::Duration>,
    last_turn_complete_at: Option<std::time::Instant>,
    implicit_corrections_count: usize,
}
impl Regulator {
/// Creates a fresh regulator for `user_id` with all accumulators empty and
/// no cost cap or implicit-correction window configured.
pub fn for_user(user_id: impl Into<String>) -> Self {
    Self {
        session: CognitiveSession::new(),
        user_id: user_id.into(),
        pending_response: None,
        token_stats: TokenStatsAccumulator::new(),
        scope: ScopeTracker::new(),
        cost: CostAccumulator::new(),
        correction: CorrectionStore::new(),
        current_topic_cluster: String::new(),
        tools: ToolStatsAccumulator::new(),
        implicit_correction_window: None,
        last_turn_complete_at: None,
        implicit_corrections_count: 0,
    }
}
/// Sets a token budget. `decide` may circuit-break once it is exceeded,
/// but only when recent quality is also poor. Does not reset any
/// already-accumulated counters (see the cap-preservation test).
pub fn with_cost_cap(mut self, cap_tokens: u32) -> Self {
    self.cost.set_cap(cap_tokens);
    self
}

/// Enables implicit-correction detection: a follow-up `TurnStart` on the
/// same topic cluster within `window` of the previous `TurnComplete` is
/// recorded as a correction.
pub fn with_implicit_correction_window(mut self, window: std::time::Duration) -> Self {
    self.implicit_correction_window = Some(window);
    self
}
/// Routes one [`LLMEvent`] into the appropriate accumulator(s).
///
/// Mutates internal state only; decisions are computed separately by
/// [`Self::decide`]. Statement order in the `TurnStart` arm matters:
/// the previous cluster and the implicit-correction window are sampled
/// BEFORE per-turn state is reset and the new cluster is computed.
pub fn on_event(&mut self, event: LLMEvent) {
    match event {
        LLMEvent::TurnStart { user_message } => {
            // Sample pre-reset state for implicit-correction detection.
            let previous_cluster = self.current_topic_cluster.clone();
            let within_window = match (
                self.implicit_correction_window,
                self.last_turn_complete_at,
            ) {
                (Some(window), Some(last_complete_at)) => {
                    last_complete_at.elapsed() <= window
                }
                _ => false,
            };
            // Reset per-turn accumulators and derive the new topic cluster.
            self.token_stats.begin_turn();
            self.tools.reset_turn();
            self.scope.set_task(&user_message);
            self.current_topic_cluster = crate::cognition::detector::build_topic_cluster(
                self.scope.task_tokens(),
            );
            // A quick same-topic follow-up is treated as an implicit
            // correction of the previous response.
            if within_window
                && !self.current_topic_cluster.is_empty()
                && self.current_topic_cluster == previous_cluster
            {
                self.correction.record_correction(
                    &self.current_topic_cluster,
                    user_message.clone(),
                );
                self.implicit_corrections_count += 1;
            }
            // Result intentionally ignored: regulation must not fail the turn.
            let _ = self.session.process_message(&user_message);
        }
        LLMEvent::Token { logprob, .. } => {
            self.token_stats.on_token(logprob);
        }
        LLMEvent::TurnComplete { full_response } => {
            // Buffer the response; it is consolidated only when quality
            // feedback arrives (see QualityFeedback arm).
            self.scope.set_response(&full_response);
            self.pending_response = Some(full_response);
            if self.implicit_correction_window.is_some() {
                self.last_turn_complete_at = Some(std::time::Instant::now());
            }
        }
        LLMEvent::Cost {
            tokens_in,
            tokens_out,
            wallclock_ms,
            provider: _,
        } => {
            self.cost.record_cost(tokens_in, tokens_out, wallclock_ms);
            // Also deplete the session's body budget via normalised cost.
            let normalised = normalize_cost(tokens_out, wallclock_ms);
            self.session.track_cost(normalised);
        }
        LLMEvent::UserCorrection {
            correction_message,
            corrects_last,
        } => {
            // Only count corrections that target the last response and
            // have an active topic cluster to attach to.
            if !corrects_last {
                return;
            }
            if self.current_topic_cluster.is_empty() {
                return;
            }
            self.correction
                .record_correction(&self.current_topic_cluster, correction_message);
        }
        LLMEvent::QualityFeedback { quality, .. } => {
            self.cost.record_quality(quality);
            // Drain the buffered response exactly once; a second feedback
            // with no pending response is a no-op (tested).
            if let Some(response) = self.pending_response.take() {
                self.session.process_response(&response, quality);
            }
        }
        LLMEvent::ToolCall {
            tool_name,
            args_json,
        } => {
            self.tools.record_call(tool_name, args_json);
        }
        LLMEvent::ToolResult {
            tool_name,
            success,
            duration_ms,
            error_summary,
        } => {
            self.tools
                .record_result(tool_name, success, duration_ms, error_summary);
        }
    }
}
/// Turn confidence from token logprobs, falling back to a structural
/// estimate of the pending response when logprobs are unavailable.
pub fn confidence(&self) -> f64 {
    confidence_with_fallback(&self.token_stats, self.pending_response.as_deref())
}

/// Fraction of observed tokens that carried usable logprobs (0.0 when none).
pub fn logprob_coverage(&self) -> f64 {
    self.token_stats.logprob_coverage()
}

/// Total output tokens accumulated across all `Cost` events.
pub fn total_tokens_out(&self) -> u32 {
    self.cost.total_tokens_out()
}

/// Configured token budget (see [`Self::with_cost_cap`]).
pub fn cost_cap_tokens(&self) -> u32 {
    self.cost.cap_tokens()
}

/// Tool calls recorded in the current turn.
pub fn tool_total_calls(&self) -> usize {
    self.tools.total_calls()
}

/// Per-tool call counts for the current turn.
pub fn tool_counts_by_name(&self) -> std::collections::HashMap<String, usize> {
    self.tools.counts_by_tool()
}

/// Summed `duration_ms` over all tool results.
pub fn tool_total_duration_ms(&self) -> u64 {
    self.tools.total_duration_ms()
}

/// Number of tool results reported with `success == false`.
pub fn tool_failure_count(&self) -> usize {
    self.tools.failure_count()
}

/// Implicit corrections detected via the correction window.
pub fn implicit_corrections_count(&self) -> usize {
    self.implicit_corrections_count
}
/// Flat observability snapshot of all regulator gauges, keyed with the
/// `noos.` prefix. Integer counters are widened to `f64` for a uniform
/// metrics surface.
pub fn metrics_snapshot(&self) -> std::collections::HashMap<String, f64> {
    let gauges = [
        ("noos.confidence", self.confidence()),
        ("noos.logprob_coverage", self.logprob_coverage()),
        ("noos.total_tokens_out", f64::from(self.total_tokens_out())),
        ("noos.cost_cap_tokens", f64::from(self.cost_cap_tokens())),
        ("noos.tool_total_calls", self.tool_total_calls() as f64),
        (
            "noos.tool_total_duration_ms",
            self.tool_total_duration_ms() as f64,
        ),
        ("noos.tool_failure_count", self.tool_failure_count() as f64),
        (
            "noos.implicit_corrections_count",
            self.implicit_corrections_count() as f64,
        ),
    ];
    gauges
        .into_iter()
        .map(|(name, value)| (name.to_string(), value))
        .collect()
}
/// Evaluates all regulation signals and returns the highest-priority
/// decision. First match wins, in this order:
/// 1. CircuitBreak — cost cap reached with poor recent quality;
/// 2. CircuitBreak — sustained quality decline without recovery;
/// 3. CircuitBreak — repeated same-tool call loop;
/// 4. ScopeDriftWarn — response drifted from the task;
/// 5. ProceduralWarning — learned correction pattern for the active topic;
/// 6. Continue.
pub fn decide(&self) -> Decision {
    // (1) Over the cap AND recent quality is poor. A high-quality agent
    // over cap is deliberately allowed to continue (tested).
    if self.cost.cap_reached() {
        let mean_quality_last_n = self
            .cost
            .mean_quality_last_n(QUALITY_DECLINE_WINDOW)
            .unwrap_or(1.0); // no feedback yet => assume fine, don't halt
        if mean_quality_last_n < POOR_QUALITY_MEAN {
            return Decision::CircuitBreak {
                reason: CircuitBreakReason::CostCapReached {
                    tokens_spent: self.cost.total_tokens_out(),
                    tokens_cap: self.cost.cap_tokens(),
                    mean_quality_last_n,
                },
                suggestion:
                    "Cost cap reached with poor recent quality. Ask the user to clarify scope or abandon this task."
                        .into(),
            };
        }
    }
    // (2) Declining quality across the window, with the windowed mean
    // itself below the poor-quality threshold.
    if let Some(delta) = self
        .cost
        .quality_decline_over_n(QUALITY_DECLINE_WINDOW, QUALITY_DECLINE_MIN_DELTA)
    {
        let mean = self
            .cost
            .mean_quality_last_n(QUALITY_DECLINE_WINDOW)
            .unwrap_or(1.0);
        if mean < POOR_QUALITY_MEAN {
            return Decision::CircuitBreak {
                reason: CircuitBreakReason::QualityDeclineNoRecovery {
                    turns: QUALITY_DECLINE_WINDOW,
                    mean_delta: delta,
                },
                suggestion:
                    "Response quality is declining without recovery. Consider redirecting or simplifying the task."
                        .into(),
            };
        }
    }
    // (3) Same tool called consecutively past the loop threshold.
    if let Some((tool_name, consecutive_count)) = self.tools.detected_loop() {
        return Decision::CircuitBreak {
            reason: CircuitBreakReason::RepeatedToolCallLoop {
                tool_name,
                consecutive_count,
            },
            suggestion: "Agent repeatedly called the same tool without progress. \
                Break the loop: re-prompt with a different approach, \
                escalate to the user, or mark the task unresolved."
                .into(),
        };
    }
    // (4) Response drifted from the task past the warning threshold.
    // drift_score() is None until a response has been observed this turn.
    if let Some(drift) = self.scope.drift_score() {
        if drift >= DRIFT_WARN_THRESHOLD {
            return Decision::ScopeDriftWarn {
                drift_tokens: self.scope.drift_tokens(),
                drift_score: drift,
                task_tokens: self.scope.task_tokens().to_vec(),
            };
        }
    }
    // (5) A learned correction pattern exists for the active topic.
    if !self.current_topic_cluster.is_empty() {
        if let Some(pattern) = self
            .correction
            .pattern_for(&self.user_id, &self.current_topic_cluster)
        {
            return Decision::ProceduralWarning {
                patterns: vec![pattern],
            };
        }
    }
    Decision::Continue
}
/// Snapshots everything worth persisting for this user: the session's
/// learned state plus at-threshold correction patterns keyed by their
/// topic cluster.
pub fn export(&self) -> RegulatorState {
    RegulatorState {
        user_id: self.user_id.clone(),
        learned: self.session.export_learned(),
        // Key each exported pattern by its own topic cluster.
        correction_patterns: self
            .correction
            .all_patterns(&self.user_id)
            .into_iter()
            .map(|pattern| (pattern.topic_cluster.clone(), pattern))
            .collect(),
    }
}
/// Rebuilds a regulator from a persisted [`RegulatorState`].
///
/// Correction patterns are restored by replaying their example messages
/// into a fresh store. Per-turn state, cost cap, and the implicit
/// correction window are NOT restored — callers must reapply builders.
pub fn import(state: RegulatorState) -> Self {
    let mut session = CognitiveSession::new();
    session.import_learned(state.learned);
    let mut correction = CorrectionStore::new();
    for (cluster, pattern) in &state.correction_patterns {
        // Replayed in reverse: exported examples are most-recent-first
        // (see the threshold test), so reversing restores insertion order
        // in the store — TODO confirm against CorrectionStore internals.
        for text in pattern.example_corrections.iter().rev() {
            correction.record_correction(cluster, text.clone());
        }
    }
    Self {
        session,
        user_id: state.user_id,
        pending_response: None,
        token_stats: TokenStatsAccumulator::new(),
        scope: ScopeTracker::new(),
        cost: CostAccumulator::new(),
        correction,
        current_topic_cluster: String::new(),
        tools: ToolStatsAccumulator::new(),
        implicit_correction_window: None,
        last_turn_complete_at: None,
        implicit_corrections_count: 0,
    }
}
/// The user this regulator was created for.
pub fn user_id(&self) -> &str {
    &self.user_id
}

/// Read-only access to the underlying cognitive session.
pub fn session(&self) -> &CognitiveSession {
    &self.session
}

/// Mutable escape hatch for driving the session directly, bypassing the
/// event API (used by tests and advanced integrations).
pub fn session_mut(&mut self) -> &mut CognitiveSession {
    &mut self.session
}
#[must_use]
/// Renders the active [`Decision::ProceduralWarning`] (if any) as a
/// bullet-list prompt prelude. Returns `None` when a higher-priority
/// decision dominates, or when the pattern carries no examples.
pub fn corrections_prelude(&self) -> Option<String> {
    match self.decide() {
        Decision::ProceduralWarning { patterns } => {
            // Flatten every example correction into "- <example>" bullets.
            let mut lines = Vec::new();
            for pattern in &patterns {
                for example in &pattern.example_corrections {
                    lines.push(format!("- {example}"));
                }
            }
            if lines.is_empty() {
                None
            } else {
                Some(format!(
                    "User has previously corrected responses on this topic with:\n{}",
                    lines.join("\n")
                ))
            }
        }
        _ => None,
    }
}
#[must_use]
/// Prepends [`Self::corrections_prelude`] to `user_prompt` when a
/// correction pattern is active; otherwise returns the prompt unchanged.
pub fn inject_corrections(&self, user_prompt: &str) -> String {
    if let Some(prelude) = self.corrections_prelude() {
        format!("{prelude}\n\nCurrent request: {user_prompt}")
    } else {
        user_prompt.to_string()
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn for_user_starts_with_fresh_session() {
let reg = Regulator::for_user("user_42");
assert_eq!(reg.user_id(), "user_42");
assert_eq!(reg.session().turn_count(), 0);
assert!(reg.session().world_model().last_response_strategy.is_none());
}
#[test]
fn turn_start_runs_cognitive_pipeline() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "Explain async Rust".into(),
});
assert_eq!(reg.session().turn_count(), 1);
}
#[test]
fn turn_complete_without_feedback_does_not_learn() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "How to use async?".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response:
"Here's a step-by-step guide:\n1. Add tokio\n2. Write async fn\n3. Await"
.into(),
});
assert!(reg.session().world_model().last_response_strategy.is_none());
}
#[test]
fn quality_feedback_consolidates_buffered_response() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "How do I use async in Rust?".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response:
"Here's a step-by-step guide:\n1. Add tokio\n2. Write async fn\n3. Await"
.into(),
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.85,
fragment_spans: None,
});
assert!(reg
.session()
.world_model()
.last_response_strategy
.is_some());
}
#[test]
fn second_quality_feedback_after_drain_is_noop() {
    // The pending response buffer is drained by the first QualityFeedback;
    // a stray second feedback must not re-consolidate learning.
    let mut reg = Regulator::for_user("user_a");
    reg.on_event(LLMEvent::TurnStart {
        user_message: "How to async?".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "Step 1: Add tokio\nStep 2: async fn\nStep 3: await".into(),
    });
    reg.on_event(LLMEvent::QualityFeedback {
        quality: 0.85,
        fragment_spans: None,
    });
    // FIX: the source contained `®` (U+00AE) here — mojibake of `&reg`,
    // most likely an HTML-entity decode of the literal text "&reg" — which
    // does not compile. Restored to `&reg`.
    let learned_after_first =
        serde_json::to_string(&reg.session().world_model().learned)
            .expect("serialise LearnedState");
    reg.on_event(LLMEvent::QualityFeedback {
        quality: 0.1,
        fragment_spans: None,
    });
    let learned_after_second =
        serde_json::to_string(&reg.session().world_model().learned)
            .expect("serialise LearnedState");
    assert_eq!(
        learned_after_first, learned_after_second,
        "drained buffer must not be re-consolidated by a stray feedback"
    );
}
#[test]
fn quality_feedback_without_turn_complete_is_noop() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "How to async?".into(),
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.5,
fragment_spans: None,
});
assert!(reg.session().world_model().last_response_strategy.is_none());
reg.on_event(LLMEvent::TurnComplete {
full_response: "Step 1: First\nStep 2: Second\nStep 3: Third".into(),
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.8,
fragment_spans: None,
});
assert!(reg.session().world_model().last_response_strategy.is_some());
}
#[test]
fn inert_events_do_not_panic_or_mutate_turn_count() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart { user_message: "hi".into() });
let before = reg.session().turn_count();
reg.on_event(LLMEvent::Token {
token: "hello".into(),
logprob: -0.5,
index: 0,
});
reg.on_event(LLMEvent::Cost {
tokens_in: 10,
tokens_out: 20,
wallclock_ms: 500,
provider: Some("anthropic".into()),
});
reg.on_event(LLMEvent::UserCorrection {
correction_message: "don't add docstrings".into(),
corrects_last: true,
});
assert_eq!(reg.session().turn_count(), before);
}
#[test]
fn decide_returns_continue_by_default() {
let reg = Regulator::for_user("user_a");
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn export_import_roundtrip_preserves_learning() {
let mut reg = Regulator::for_user("user_persist");
for i in 0..10 {
reg.on_event(LLMEvent::TurnStart {
user_message: format!("Rust question {i}"),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "Step 1: First\nStep 2: Second\nStep 3: Third".into(),
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.85,
fragment_spans: None,
});
}
let snapshot = reg.export();
assert_eq!(snapshot.user_id, "user_persist");
assert!(
!snapshot.learned.response_strategies.is_empty(),
"training loop should have populated strategy EMA"
);
let restored = Regulator::import(snapshot.clone());
assert_eq!(restored.user_id(), "user_persist");
assert_eq!(
restored.session().world_model().learned.response_strategies.len(),
snapshot.learned.response_strategies.len(),
);
}
#[test]
fn roundtrip_via_serde_json() {
let mut reg = Regulator::for_user("user_json");
reg.on_event(LLMEvent::TurnStart { user_message: "hi".into() });
reg.on_event(LLMEvent::TurnComplete { full_response: "hello".into() });
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.7,
fragment_spans: None,
});
let snapshot = reg.export();
let json = serde_json::to_string(&snapshot).expect("serialise");
let decoded: RegulatorState =
serde_json::from_str(&json).expect("deserialise");
assert_eq!(decoded.user_id, snapshot.user_id);
assert_eq!(decoded.learned.tick, snapshot.learned.tick);
}
#[test]
fn session_mut_exposes_path1_escape_hatch() {
let mut reg = Regulator::for_user("user_a");
let initial = reg.session().world_model().body_budget;
reg.session_mut().track_cost(1.0);
let after = reg.session().world_model().body_budget;
assert!(after < initial, "track_cost via session_mut should deplete budget");
}
#[test]
fn confidence_starts_neutral() {
let reg = Regulator::for_user("user_a");
assert!((reg.confidence() - 0.5).abs() < 1e-9);
assert_eq!(reg.logprob_coverage(), 0.0);
}
#[test]
fn confident_token_stream_raises_confidence() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart { user_message: "explain async".into() });
for i in 0..15 {
reg.on_event(LLMEvent::Token {
token: "tok".into(),
logprob: -0.2,
index: i,
});
}
assert!(
reg.confidence() > 0.8,
"confident token stream should drive confidence >0.8 (got {})",
reg.confidence()
);
assert!((reg.logprob_coverage() - 1.0).abs() < 1e-9);
}
#[test]
fn gibberish_token_stream_lowers_confidence() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "asdfkjh qwer zxcvb".into(),
});
for i in 0..15 {
reg.on_event(LLMEvent::Token {
token: "?".into(),
logprob: -6.5,
index: i,
});
}
assert!(
reg.confidence() < 0.2,
"high-NLL stream should drive confidence <0.2 (got {})",
reg.confidence()
);
}
#[test]
fn turn_start_resets_token_window() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart { user_message: "q1".into() });
for i in 0..10 {
reg.on_event(LLMEvent::Token {
token: "t".into(),
logprob: -0.1,
index: i,
});
}
let confident = reg.confidence();
assert!(confident > 0.8);
reg.on_event(LLMEvent::TurnStart { user_message: "q2".into() });
assert!(
(reg.confidence() - 0.5).abs() < 1e-9,
"new turn should reset confidence to neutral (got {})",
reg.confidence()
);
}
#[test]
fn unavailable_logprobs_fall_back_to_structural_signal() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "how to refactor?".into(),
});
for i in 0..5 {
reg.on_event(LLMEvent::Token {
token: "t".into(),
logprob: 0.0, index: i,
});
}
reg.on_event(LLMEvent::TurnComplete {
full_response:
"Here's the refactored function. It preserves the original signature."
.into(),
});
assert!(
(reg.confidence() - 0.7).abs() < 0.01,
"structural fallback on unremarkable response should be ~0.7 (got {})",
reg.confidence()
);
assert_eq!(reg.logprob_coverage(), 0.0);
}
#[test]
fn unavailable_logprobs_short_response_fallback_is_low() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "can you do X?".into(),
});
reg.on_event(LLMEvent::Token {
token: "No.".into(),
logprob: 0.0,
index: 0,
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "No.".into(),
});
assert!(
reg.confidence() < 0.5,
"short unremarkable response should fall in low band (got {})",
reg.confidence()
);
}
#[test]
fn decide_emits_scope_drift_warn_on_plan_example() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "add logging and error handling".into(),
});
match reg.decide() {
Decision::ScopeDriftWarn {
drift_score,
drift_tokens,
task_tokens,
} => {
assert!(
drift_score > 0.3,
"plan target requires drift > 0.3 (got {drift_score})"
);
assert!(!drift_tokens.is_empty(), "drift_tokens must be populated");
assert!(!task_tokens.is_empty(), "task_tokens must be populated");
}
other => panic!(
"plan example must emit ScopeDriftWarn, got {other:?}"
),
}
}
#[test]
fn decide_continues_when_response_stays_on_task() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor the async function".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "refactor async function".into(),
});
assert!(
matches!(reg.decide(), Decision::Continue),
"on-task response must not emit ScopeDriftWarn"
);
}
#[test]
fn decide_continues_before_turn_complete() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn turn_start_resets_scope_state() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "add logging and telemetry".into(),
});
assert!(matches!(reg.decide(), Decision::ScopeDriftWarn { .. }));
reg.on_event(LLMEvent::TurnStart {
user_message: "explain tokio runtime".into(),
});
assert!(
matches!(reg.decide(), Decision::Continue),
"new turn must clear stale response before new response arrives"
);
}
#[test]
fn cost_event_accumulates_totals() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::Cost {
tokens_in: 100,
tokens_out: 400,
wallclock_ms: 2_000,
provider: None,
});
reg.on_event(LLMEvent::Cost {
tokens_in: 50,
tokens_out: 200,
wallclock_ms: 1_500,
provider: Some("anthropic".into()),
});
assert_eq!(reg.total_tokens_out(), 600);
}
#[test]
fn cost_event_feeds_track_cost_via_normalisation() {
let mut reg = Regulator::for_user("user_a");
let initial_budget = reg.session().world_model().body_budget;
reg.on_event(LLMEvent::Cost {
tokens_in: 0,
tokens_out: cost::TYPICAL_TURN_TOKENS_OUT,
wallclock_ms: cost::TYPICAL_TURN_WALLCLOCK_MS,
provider: None,
});
let after_budget = reg.session().world_model().body_budget;
assert!(
after_budget < initial_budget,
"Cost event must deplete body_budget via track_cost (before {initial_budget}, after {after_budget})"
);
}
#[test]
fn quality_feedback_records_in_cost_accumulator() {
let mut reg = Regulator::for_user("user_a");
for q in [0.9, 0.8, 0.7] {
reg.on_event(LLMEvent::QualityFeedback {
quality: q,
fragment_spans: None,
});
}
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn decide_emits_circuit_break_on_cost_cap_with_poor_quality() {
let mut reg = Regulator::for_user("user_a").with_cost_cap(1_000);
for _ in 0..3 {
reg.on_event(LLMEvent::Cost {
tokens_in: 0,
tokens_out: 400,
wallclock_ms: 1_000,
provider: None,
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.3, fragment_spans: None,
});
}
match reg.decide() {
Decision::CircuitBreak { reason, .. } => match reason {
CircuitBreakReason::CostCapReached {
tokens_spent,
tokens_cap,
mean_quality_last_n,
} => {
assert_eq!(tokens_spent, 1_200);
assert_eq!(tokens_cap, 1_000);
assert!(mean_quality_last_n < 0.5);
}
other => panic!("expected CostCapReached, got {other:?}"),
},
other => panic!("expected CircuitBreak, got {other:?}"),
}
}
#[test]
fn decide_does_not_emit_circuit_break_when_quality_recovers() {
let mut reg = Regulator::for_user("user_a").with_cost_cap(1_000);
for _ in 0..3 {
reg.on_event(LLMEvent::Cost {
tokens_in: 0,
tokens_out: 400,
wallclock_ms: 1_000,
provider: None,
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.9, fragment_spans: None,
});
}
assert!(
matches!(reg.decide(), Decision::Continue),
"high-quality agent at over-cap must not be halted"
);
}
#[test]
fn decide_emits_circuit_break_on_quality_decline_no_recovery() {
let mut reg = Regulator::for_user("user_a").with_cost_cap(u32::MAX);
for q in [0.8, 0.5, 0.3, 0.2, 0.15] {
reg.on_event(LLMEvent::QualityFeedback {
quality: q,
fragment_spans: None,
});
}
match reg.decide() {
Decision::CircuitBreak {
reason:
CircuitBreakReason::QualityDeclineNoRecovery { turns, mean_delta },
..
} => {
assert_eq!(turns, cost::QUALITY_DECLINE_WINDOW);
assert!(mean_delta >= cost::QUALITY_DECLINE_MIN_DELTA);
}
other => panic!("expected QualityDeclineNoRecovery, got {other:?}"),
}
}
#[test]
fn decide_priority_circuit_break_dominates_scope_drift() {
let mut reg = Regulator::for_user("user_a").with_cost_cap(500);
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "add logging and error handling".into(),
});
for _ in 0..3 {
reg.on_event(LLMEvent::Cost {
tokens_in: 0,
tokens_out: 400,
wallclock_ms: 0,
provider: None,
});
reg.on_event(LLMEvent::QualityFeedback {
quality: 0.2,
fragment_spans: None,
});
}
match reg.decide() {
Decision::CircuitBreak {
reason: CircuitBreakReason::CostCapReached { .. },
..
} => {}
other => panic!(
"priority rule must emit CircuitBreak CostCapReached, got {other:?}"
),
}
}
#[test]
fn with_cost_cap_preserves_prior_accumulation() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::Cost {
tokens_in: 100,
tokens_out: 200,
wallclock_ms: 500,
provider: None,
});
assert_eq!(reg.total_tokens_out(), 200);
let reg = reg.with_cost_cap(10_000);
assert_eq!(reg.total_tokens_out(), 200, "cap change must not reset counters");
assert_eq!(reg.cost_cap_tokens(), 10_000);
}
fn drive_three_corrections_on(reg: &mut Regulator, task_message: &str) {
reg.on_event(LLMEvent::TurnStart {
user_message: task_message.into(),
});
for msg in [
"don't add logging",
"stop adding logging please",
"no more logs",
] {
reg.on_event(LLMEvent::UserCorrection {
correction_message: msg.into(),
corrects_last: true,
});
}
}
#[test]
fn user_correction_requires_corrects_last_true() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
for _ in 0..5 {
reg.on_event(LLMEvent::UserCorrection {
correction_message: "something entirely different".into(),
corrects_last: false,
});
}
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn user_correction_dropped_when_no_active_cluster() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "is it ok?".into(),
});
for _ in 0..5 {
reg.on_event(LLMEvent::UserCorrection {
correction_message: "never do X".into(),
corrects_last: true,
});
}
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn decide_emits_procedural_warning_at_pattern_threshold() {
let mut reg = Regulator::for_user("user_42");
drive_three_corrections_on(&mut reg, "refactor this function to be async");
match reg.decide() {
Decision::ProceduralWarning { patterns } => {
assert_eq!(patterns.len(), 1);
let pattern = &patterns[0];
assert_eq!(pattern.user_id, "user_42");
assert_eq!(pattern.learned_from_turns, 3);
assert!(pattern
.pattern_name
.starts_with("corrections_on_"));
assert_eq!(pattern.example_corrections.len(), 3);
assert_eq!(pattern.example_corrections[0], "no more logs");
}
other => panic!("expected ProceduralWarning, got {other:?}"),
}
}
#[test]
fn decide_does_not_emit_procedural_warning_below_threshold() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
for msg in ["don't add logging", "no more logs"] {
reg.on_event(LLMEvent::UserCorrection {
correction_message: msg.into(),
corrects_last: true,
});
}
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn procedural_warning_fires_only_on_matching_cluster() {
let mut reg = Regulator::for_user("user_a");
drive_three_corrections_on(&mut reg, "refactor this function to be async");
assert!(matches!(reg.decide(), Decision::ProceduralWarning { .. }));
reg.on_event(LLMEvent::TurnStart {
user_message: "explain docker containers to me".into(),
});
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn decide_priority_scope_drift_dominates_procedural_warning() {
let mut reg = Regulator::for_user("user_a");
drive_three_corrections_on(&mut reg, "refactor this function to be async");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
reg.on_event(LLMEvent::TurnComplete {
full_response: "add logging and error handling".into(),
});
assert!(matches!(reg.decide(), Decision::ScopeDriftWarn { .. }));
}
#[test]
fn corrections_prelude_none_when_no_pattern() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
assert!(reg.corrections_prelude().is_none());
}
#[test]
fn corrections_prelude_returns_formatted_block_after_threshold() {
let mut reg = Regulator::for_user("user_a");
drive_three_corrections_on(&mut reg, "refactor this function to be async");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
let prelude = reg.corrections_prelude().expect("pattern should apply");
assert!(
prelude.starts_with("User has previously corrected responses on this topic with:"),
"unexpected header: {prelude}"
);
for expected in &["no more logs", "stop adding logging please", "don't add logging"] {
assert!(
prelude.contains(&format!("- {expected}")),
"prelude missing correction {expected:?}: {prelude}"
);
}
}
#[test]
fn inject_corrections_returns_unchanged_when_no_pattern() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
let prompt = "write me a hello world";
assert_eq!(reg.inject_corrections(prompt), prompt);
}
#[test]
fn inject_corrections_wraps_prompt_after_threshold() {
let mut reg = Regulator::for_user("user_a");
drive_three_corrections_on(&mut reg, "refactor this function to be async");
reg.on_event(LLMEvent::TurnStart {
user_message: "refactor this function to be async".into(),
});
let prompt = "refactor this function to be async";
let injected = reg.inject_corrections(prompt);
assert!(
injected.contains("User has previously corrected responses"),
"injected prompt missing prelude header"
);
assert!(
injected.contains(&format!("Current request: {prompt}")),
"injected prompt missing current-request marker"
);
assert!(injected.len() > prompt.len(), "expected expansion");
}
#[test]
fn tool_call_event_accumulates_stats() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "do a search".into(),
});
for _ in 0..3 {
reg.on_event(LLMEvent::ToolCall {
tool_name: "search".into(),
args_json: None,
});
}
assert_eq!(reg.tool_total_calls(), 3);
let counts = reg.tool_counts_by_name();
assert_eq!(counts.get("search"), Some(&3));
}
#[test]
fn tool_call_loop_fires_circuit_break() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "do a search".into(),
});
for _ in 0..5 {
reg.on_event(LLMEvent::ToolCall {
tool_name: "search".into(),
args_json: None,
});
}
match reg.decide() {
Decision::CircuitBreak {
reason:
CircuitBreakReason::RepeatedToolCallLoop {
tool_name,
consecutive_count,
},
..
} => {
assert_eq!(tool_name, "search");
assert_eq!(consecutive_count, 5);
}
other => panic!("expected RepeatedToolCallLoop, got {other:?}"),
}
}
#[test]
fn tool_call_interleaved_does_not_fire_loop() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "do a search".into(),
});
for _ in 0..4 {
reg.on_event(LLMEvent::ToolCall {
tool_name: "search".into(),
args_json: None,
});
}
reg.on_event(LLMEvent::ToolCall {
tool_name: "db.query".into(),
args_json: None,
});
for _ in 0..4 {
reg.on_event(LLMEvent::ToolCall {
tool_name: "search".into(),
args_json: None,
});
}
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn turn_start_resets_tool_stats() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "turn 1".into(),
});
for _ in 0..5 {
reg.on_event(LLMEvent::ToolCall {
tool_name: "search".into(),
args_json: None,
});
}
assert!(matches!(reg.decide(), Decision::CircuitBreak { .. }));
reg.on_event(LLMEvent::TurnStart {
user_message: "turn 2".into(),
});
assert_eq!(reg.tool_total_calls(), 0);
assert!(!matches!(reg.decide(), Decision::CircuitBreak { .. }));
}
#[test]
fn tool_result_observability_counters() {
let mut reg = Regulator::for_user("user_a");
reg.on_event(LLMEvent::TurnStart {
user_message: "do a search".into(),
});
reg.on_event(LLMEvent::ToolResult {
tool_name: "search".into(),
success: true,
duration_ms: 120,
error_summary: None,
});
reg.on_event(LLMEvent::ToolResult {
tool_name: "db.query".into(),
success: false,
duration_ms: 250,
error_summary: Some("timeout".into()),
});
assert_eq!(reg.tool_total_duration_ms(), 370);
assert_eq!(reg.tool_failure_count(), 1);
}
#[test]
fn tool_loop_has_higher_priority_than_scope_drift() {
    // Arrange both triggers in a single turn — a repeated-tool-call loop and
    // an off-topic response — and check the circuit break wins over the
    // scope-drift warning.
    let mut reg = Regulator::for_user("user_a");
    reg.on_event(LLMEvent::TurnStart {
        user_message: "refactor this async function".into(),
    });
    let mut remaining = 5;
    while remaining > 0 {
        reg.on_event(LLMEvent::ToolCall {
            tool_name: "search".into(),
            args_json: None,
        });
        remaining -= 1;
    }
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "added logging database migration new schema totally unrelated".into(),
    });
    let decision = reg.decide();
    assert!(matches!(
        decision,
        Decision::CircuitBreak {
            reason: CircuitBreakReason::RepeatedToolCallLoop { .. },
            ..
        }
    ));
}
#[test]
fn corrections_prelude_none_when_higher_priority_decision_dominates() {
    // When scope drift fires, the lower-priority corrections prelude must be
    // suppressed entirely rather than attached alongside the warning.
    let task = "refactor this function to be async";
    let mut reg = Regulator::for_user("user_a");
    drive_three_corrections_on(&mut reg, task);
    reg.on_event(LLMEvent::TurnStart {
        user_message: task.into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "add logging plus database migration new schema".into(),
    });
    assert!(matches!(reg.decide(), Decision::ScopeDriftWarn { .. }));
    assert!(reg.corrections_prelude().is_none());
}
#[test]
fn export_includes_correction_patterns() {
    // Three corrections reach the pattern threshold; the exported snapshot
    // must carry that cluster with its turn count and example list intact.
    let mut reg = Regulator::for_user("user_a");
    drive_three_corrections_on(&mut reg, "refactor this function to be async");
    let snapshot = reg.export();
    assert!(
        !snapshot.correction_patterns.is_empty(),
        "at-threshold cluster must appear in exported patterns"
    );
    // Fetch the pattern value directly: the original `keys().next()` followed
    // by indexing looked the same entry up twice (and indexing carries a
    // panic path the `expect` already covers).
    let pattern = snapshot
        .correction_patterns
        .values()
        .next()
        .expect("one pattern");
    assert_eq!(pattern.learned_from_turns, 3);
    assert_eq!(pattern.example_corrections.len(), 3);
}
#[test]
fn import_restores_patterns_via_example_replay() {
    // Round-trip the exported state through JSON and confirm the restored
    // regulator still fires a ProceduralWarning rebuilt from the examples.
    let task = "refactor this function to be async";
    let mut source = Regulator::for_user("user_persist");
    drive_three_corrections_on(&mut source, task);
    let json = serde_json::to_string(&source.export()).expect("serialise");
    let decoded: RegulatorState = serde_json::from_str(&json).expect("deserialise");
    let mut restored = Regulator::import(decoded);
    restored.on_event(LLMEvent::TurnStart {
        user_message: task.into(),
    });
    match restored.decide() {
        Decision::ProceduralWarning { patterns } => {
            assert_eq!(patterns.len(), 1);
            assert_eq!(patterns[0].learned_from_turns, 3);
            assert_eq!(patterns[0].example_corrections.len(), 3);
        }
        other => panic!("restored regulator must fire ProceduralWarning, got {other:?}"),
    }
}
#[test]
fn import_preserves_example_corrections_order() {
    // The stored example corrections must come back in the same order after a
    // serialise → deserialise → re-export round trip.
    let mut source = Regulator::for_user("user_order");
    drive_three_corrections_on(&mut source, "refactor this function to be async");
    let exported = source.export();
    let (cluster, original_pattern) = exported
        .correction_patterns
        .iter()
        .next()
        .expect("one pattern after 3 corrections");
    let expected_order = original_pattern.example_corrections.clone();
    let json = serde_json::to_string(&exported).expect("serialise");
    let decoded: RegulatorState = serde_json::from_str(&json).expect("deserialise");
    // Re-export immediately; the restored regulator is only a pass-through here.
    let reexported = Regulator::import(decoded).export();
    let restored_pattern = reexported
        .correction_patterns
        .get(cluster)
        .expect("pattern restored under same cluster key");
    assert_eq!(restored_pattern.example_corrections, expected_order);
}
#[test]
fn legacy_snapshot_loads_without_correction_patterns() {
// Backward-compatibility guard: a pre-correction-patterns snapshot (no
// `correction_patterns` field at all) must still deserialise into
// RegulatorState and import into a regulator that simply continues.
// NOTE: the fixture bytes are runtime data — do not reformat this literal.
let legacy = r#"{
"user_id": "legacy",
"learned": {
"gain_mode": "neutral",
"tick": 0,
"response_success": {},
"response_strategies": {}
}
}"#;
// Missing fields presumably fall back to serde defaults — confirmed only
// insofar as this deserialisation succeeds.
let state: RegulatorState =
serde_json::from_str(legacy).expect("legacy snapshot must load");
let reg = Regulator::import(state);
// The user id round-trips, and with no learned patterns the default
// decision is Continue.
assert_eq!(reg.user_id(), "legacy");
assert!(matches!(reg.decide(), Decision::Continue));
}
#[test]
fn implicit_correction_off_by_default_no_synthetic_records() {
    // Implicit-correction detection is opt-in: with no window configured, an
    // immediate same-topic retry synthesises no correction records.
    let mut reg = Regulator::for_user("u");
    let start_turn = |r: &mut Regulator, msg: &str| {
        r.on_event(LLMEvent::TurnStart {
            user_message: msg.into(),
        });
    };
    start_turn(&mut reg, "Refactor fetch_user to be async");
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "resp".into(),
    });
    start_turn(&mut reg, "Fix the fetch_user async refactoring");
    assert_eq!(reg.implicit_corrections_count(), 0);
}
#[test]
fn implicit_correction_fires_on_fast_same_cluster_retry() {
    use std::time::Duration;
    // A same-topic retry well inside the 500 ms window must be recorded as
    // exactly one synthetic correction.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(500));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "(unsatisfactory response)".into(),
    });
    // 20 ms elapsed — comfortably inside the window.
    std::thread::sleep(Duration::from_millis(20));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Fix the fetch_user async refactoring".into(),
    });
    assert_eq!(reg.implicit_corrections_count(), 1);
}
#[test]
fn implicit_correction_skips_when_window_expires() {
    use std::time::Duration;
    // Sleeping well past the 50 ms window means the follow-up is treated as
    // a fresh request rather than a correction of the previous turn.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(50));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "resp".into(),
    });
    // 150 ms elapsed — three times the configured window.
    std::thread::sleep(Duration::from_millis(150));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Fix the fetch_user async refactoring".into(),
    });
    assert_eq!(
        reg.implicit_corrections_count(),
        0,
        "retry outside window must not synthesise a correction"
    );
}
#[test]
fn implicit_correction_skips_on_different_cluster() {
    use std::time::Duration;
    // A fast follow-up on an unrelated topic must not be mistaken for a
    // retry of the previous turn, even inside the window.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(500));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "resp".into(),
    });
    std::thread::sleep(Duration::from_millis(10));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Explain tokio channels and scheduling".into(),
    });
    assert_eq!(
        reg.implicit_corrections_count(),
        0,
        "different-topic follow-up must not count as a correction"
    );
}
#[test]
fn implicit_correction_skips_on_first_ever_turn() {
    // There is no previous turn to correct, so the very first TurnStart can
    // never synthesise a correction even with the window enabled.
    let window = std::time::Duration::from_millis(500);
    let mut reg = Regulator::for_user("u").with_implicit_correction_window(window);
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    assert_eq!(reg.implicit_corrections_count(), 0);
}
#[test]
fn implicit_correction_skips_when_cluster_is_empty() {
    use std::time::Duration;
    // "ok hi" / "hi ok" presumably reduce to an empty topic cluster (per the
    // test name); turns with no cluster must never match as retries.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(500));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "ok hi".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "greetings".into(),
    });
    std::thread::sleep(Duration::from_millis(10));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "hi ok".into(),
    });
    assert_eq!(reg.implicit_corrections_count(), 0);
}
#[test]
fn implicit_correction_accumulates_to_pattern_at_threshold() {
    use std::time::Duration;
    // Three fast same-cluster retries synthesise three corrections, which
    // reaches the pattern threshold and fires a ProceduralWarning.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(500));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    // Each entry: (response that gets retried, the retry message).
    let retries = [
        ("try 1", "Fix the fetch_user async refactoring"),
        ("try 2", "Make fetch_user async properly"),
        ("try 3", "Update fetch_user to async version"),
    ];
    for (response, retry) in retries {
        reg.on_event(LLMEvent::TurnComplete {
            full_response: response.into(),
        });
        std::thread::sleep(Duration::from_millis(10));
        reg.on_event(LLMEvent::TurnStart {
            user_message: retry.into(),
        });
    }
    assert_eq!(
        reg.implicit_corrections_count(),
        3,
        "3 fast same-cluster retries → 3 synthesised corrections"
    );
    match reg.decide() {
        Decision::ProceduralWarning { patterns } => {
            assert!(!patterns.is_empty());
            assert_eq!(patterns[0].learned_from_turns, 3);
        }
        other => panic!(
            "expected ProceduralWarning after 3 implicit corrections, got {other:?}"
        ),
    }
}
#[test]
fn metrics_snapshot_exposes_stable_keys() {
    // Every published metric key must be present and finite on a freshly
    // constructed regulator, and the configured cost cap is echoed exactly.
    let snap = Regulator::for_user("m")
        .with_cost_cap(5_000)
        .metrics_snapshot();
    let expected_keys = [
        "noos.confidence",
        "noos.logprob_coverage",
        "noos.total_tokens_out",
        "noos.cost_cap_tokens",
        "noos.tool_total_calls",
        "noos.tool_total_duration_ms",
        "noos.tool_failure_count",
        "noos.implicit_corrections_count",
    ];
    for key in expected_keys {
        assert!(snap.contains_key(key), "missing metric key {key:?}");
        let v = snap[key];
        assert!(v.is_finite(), "metric {key} produced non-finite {v}");
    }
    assert!((snap["noos.cost_cap_tokens"] - 5_000.0).abs() < 1e-9);
}
#[test]
fn metrics_snapshot_tracks_state_changes() {
    // total_tokens_out starts at zero and reflects a Cost event afterwards.
    let mut reg = Regulator::for_user("m").with_cost_cap(5_000);
    let initial = reg.metrics_snapshot()["noos.total_tokens_out"];
    assert!((initial - 0.0).abs() < 1e-9);
    reg.on_event(LLMEvent::TurnStart {
        user_message: "hello".into(),
    });
    reg.on_event(LLMEvent::Cost {
        tokens_in: 10,
        tokens_out: 150,
        wallclock_ms: 500,
        provider: None,
    });
    let updated = reg.metrics_snapshot()["noos.total_tokens_out"];
    assert!((updated - 150.0).abs() < 1e-9);
}
#[test]
fn implicit_counter_not_persisted_across_import() {
    use std::time::Duration;
    // The implicit-correction counter is per-process observability state: it
    // is not carried in RegulatorState, so an imported regulator starts at 0.
    let mut reg =
        Regulator::for_user("u").with_implicit_correction_window(Duration::from_millis(500));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Refactor fetch_user to be async".into(),
    });
    reg.on_event(LLMEvent::TurnComplete {
        full_response: "r1".into(),
    });
    std::thread::sleep(Duration::from_millis(10));
    reg.on_event(LLMEvent::TurnStart {
        user_message: "Fix the fetch_user async refactoring".into(),
    });
    assert_eq!(reg.implicit_corrections_count(), 1);
    let restored = Regulator::import(reg.export());
    assert_eq!(
        restored.implicit_corrections_count(),
        0,
        "counter is per-process; not in RegulatorState"
    );
}
}