use serde::{Deserialize, Serialize};
use crate::cognition::detector;
/// Drift-score cut-off used for warnings.
///
/// The unit tests in this file treat a response as flagged when its
/// `ScopeTracker::drift_score` is greater than or equal to this value.
pub const DRIFT_WARN_THRESHOLD: f64 = 0.5;
/// Holds the topic keywords of the current task and of the latest response so
/// that the two can be compared for scope drift.
///
/// A default-constructed tracker has both keyword lists empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScopeTracker {
/// Topics produced by `detector::extract_topics` for the user message
/// passed to `set_task`.
task_keywords: Vec<String>,
/// Topics produced by `detector::extract_topics` for the response passed
/// to `set_response`; cleared whenever a new task is set.
response_keywords: Vec<String>,
}
impl ScopeTracker {
    /// Creates a tracker with no task keywords and no response keywords.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a new task from `user_message`.
    ///
    /// The task keywords are replaced by the topics that
    /// `detector::extract_topics` returns for the message, and any response
    /// keywords from a previous task are discarded so a stale response is
    /// never scored against the new task.
    pub fn set_task(&mut self, user_message: &str) {
        self.task_keywords = detector::extract_topics(user_message);
        self.response_keywords.clear();
    }

    /// Records the response to the current task, replacing any previously
    /// stored response keywords with the topics extracted from
    /// `full_response`.
    pub fn set_response(&mut self, full_response: &str) {
        self.response_keywords = detector::extract_topics(full_response);
    }

    /// Returns the keywords extracted from the current task message.
    pub fn task_tokens(&self) -> &[String] {
        &self.task_keywords
    }

    /// Returns the keywords extracted from the latest response.
    pub fn response_tokens(&self) -> &[String] {
        &self.response_keywords
    }

    /// Returns the fraction of response keywords that are not counted as
    /// overlapping with the task keywords.
    ///
    /// Returns `None` when either the task or the response has no keywords,
    /// because no meaningful ratio exists in that case. Otherwise the result
    /// is `Some(score)` with `score` in `0.0..=1.0`, where `0.0` means every
    /// response keyword overlaps with the task.
    #[must_use]
    pub fn drift_score(&self) -> Option<f64> {
        let response_total = self.response_keywords.len();
        if self.task_keywords.is_empty() || response_total == 0 {
            return None;
        }

        let task_set = detector::to_topic_set(&self.task_keywords);
        let overlap = detector::count_topic_overlap(&self.response_keywords, &task_set);

        // Saturate instead of subtracting directly: should the overlap helper
        // ever report more matches than there are response keywords, a plain
        // `usize` subtraction would panic in debug builds and wrap in release
        // builds, producing a score far outside the documented range.
        let non_task = response_total.saturating_sub(overlap);
        Some(non_task as f64 / response_total as f64)
    }

    /// Returns the response keywords whose lowercase form is not contained in
    /// the task topic set, in their original order.
    ///
    /// Unlike [`ScopeTracker::drift_score`], this does not short-circuit on an
    /// empty task: the filter is always applied to the response keywords.
    #[must_use]
    pub fn drift_tokens(&self) -> Vec<String> {
        let task_set = detector::to_topic_set(&self.task_keywords);
        self.response_keywords
            .iter()
            // NOTE(review): lowercasing here presumably mirrors how
            // `detector::to_topic_set` normalises its entries — confirm.
            .filter(|k| !task_set.contains(&k.to_lowercase()))
            .cloned()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker with `task` recorded first and `response` second,
    /// which is the order every two-sided test below relies on.
    fn tracker_for(task: &str, response: &str) -> ScopeTracker {
        let mut tracker = ScopeTracker::new();
        tracker.set_task(task);
        tracker.set_response(response);
        tracker
    }

    /// A fresh tracker exposes no tokens and no score.
    #[test]
    fn empty_tracker_has_no_drift_score() {
        let fresh = ScopeTracker::new();
        assert!(fresh.drift_score().is_none());
        assert!(fresh.task_tokens().is_empty());
        assert!(fresh.response_tokens().is_empty());
        assert!(fresh.drift_tokens().is_empty());
    }

    /// Without a response there is nothing to score.
    #[test]
    fn task_only_has_no_drift_score() {
        let mut only_task = ScopeTracker::new();
        only_task.set_task("refactor this function to be async");
        assert!(!only_task.task_tokens().is_empty());
        assert!(only_task.drift_score().is_none());
    }

    /// Without a task there is nothing to score against.
    #[test]
    fn response_only_has_no_drift_score() {
        let mut only_response = ScopeTracker::new();
        only_response.set_response("Here is the refactored async function.");
        assert!(!only_response.response_tokens().is_empty());
        assert!(only_response.drift_score().is_none());
    }

    /// The reference example must score above 0.3 and reach the warning bar.
    #[test]
    fn plan_example_flags_high_drift() {
        let drift = tracker_for(
            "refactor this function to be async",
            "add logging and error handling",
        )
        .drift_score()
        .expect("both sides populated");
        assert!(
            drift > 0.3,
            "plan test target must produce drift > 0.3 (got {drift})"
        );
        assert!(drift >= DRIFT_WARN_THRESHOLD);
    }

    /// A response that repeats the task wording stays below the warning bar.
    #[test]
    fn on_task_response_keeps_drift_low() {
        let drift = tracker_for("refactor the async function", "refactor async function")
            .drift_score()
            .expect("both sides populated");
        assert!(
            drift < DRIFT_WARN_THRESHOLD,
            "minimal on-task response should stay under warning threshold (got {drift})"
        );
    }

    /// No task keyword may show up among the drift tokens.
    #[test]
    fn drift_tokens_only_contains_non_task_keywords() {
        let tracker = tracker_for(
            "refactor the async function",
            "add logging telemetry for the async function",
        );
        let drift_tokens = tracker.drift_tokens();
        tracker.task_tokens().iter().for_each(|token| {
            assert!(
                !drift_tokens.contains(token),
                "drift_tokens must not contain task keyword {token:?}"
            );
        });
        assert!(!drift_tokens.is_empty());
    }

    /// Setting a new task wipes the response recorded for the old one.
    #[test]
    fn set_task_resets_previous_response() {
        let mut tracker = tracker_for(
            "refactor the async function",
            "add logging and error handling",
        );
        assert!(tracker.drift_score().is_some());

        tracker.set_task("explain tokio runtime");
        assert!(tracker.response_tokens().is_empty());
        assert!(
            tracker.drift_score().is_none(),
            "new task must clear stale response"
        );
    }

    /// The score is a ratio and must stay inside `0.0..=1.0`.
    #[test]
    fn drift_score_bounded_zero_to_one() {
        let drift = tracker_for("task keyword", "completely different response content")
            .drift_score()
            .expect("both sides populated");
        assert!((0.0..=1.0).contains(&drift));
    }

    /// Classifies ten hand-written (task, response, should_flag) cases and
    /// requires that at most 20% of them are mis-classified.
    #[test]
    fn decision_checkpoint_fpr_on_hand_crafted_cases() {
        let cases: &[(&str, &str, bool)] = &[
            // Off-task responses: expected to be flagged.
            (
                "refactor this function to be async",
                "add logging and error handling",
                true,
            ),
            (
                "explain tokio runtime",
                "here is a recipe for chocolate cake with frosting",
                true,
            ),
            (
                "help me with SQL queries",
                "JavaScript frameworks overview: React, Vue, Angular",
                true,
            ),
            (
                "fix the authentication bug",
                "my thoughts on microservice architecture patterns",
                true,
            ),
            (
                "explain docker containers",
                "chocolate cake baking instructions with butter",
                true,
            ),
            // On-task responses: expected to pass unflagged.
            ("refactor async function", "refactor async function", false),
            (
                "refactor the async function",
                "refactored async function returned",
                false,
            ),
            (
                "tokio async runtime rust",
                "tokio async runtime rust futures scheduling",
                false,
            ),
            (
                "fix error authentication rust",
                "fix error authentication rust verify",
                false,
            ),
            (
                "jwt token format explain",
                "jwt token format explain signature",
                false,
            ),
        ];

        let mut mis_classifications = 0usize;
        let mut report = String::new();
        for &(task, response, should_flag) in cases {
            // A missing score is treated as "no drift" for classification.
            let drift = tracker_for(task, response).drift_score().unwrap_or(0.0);
            let flagged = drift >= DRIFT_WARN_THRESHOLD;
            let (mark, note) = if flagged == should_flag {
                ("✓", "")
            } else {
                mis_classifications += 1;
                ("✗", "← MIS")
            };
            report.push_str(&format!(
                " [{}] task={:?} resp={:?} drift={:.2} flag={} expected={} {}\n",
                mark, task, response, drift, flagged, should_flag, note,
            ));
        }

        let total = cases.len();
        let error_rate = mis_classifications as f64 / total as f64;
        assert!(
            error_rate <= 0.2,
            "decision checkpoint failed: {mis_classifications}/{total} \
             mis-classified ({:.0}% error rate, bar ≤ 20%). Report:\n{report}",
            error_rate * 100.0
        );
    }
}