use std::collections::HashMap;
use std::sync::Mutex;
/// Environment variable overriding the primary nag threshold (read by `new_from_env`).
pub const NAG_THRESHOLD_ENV: &str = "AI_MEMORY_CAPTURE_NAG_THRESHOLD";
/// Primary threshold used when [`NAG_THRESHOLD_ENV`] is unset or does not parse.
pub const DEFAULT_NAG_THRESHOLD: u32 = 5;
/// Environment variable overriding the escalation nag threshold (read by `new_from_env`).
pub const NAG_ESCALATE_THRESHOLD_ENV: &str = "AI_MEMORY_CAPTURE_NAG_ESCALATE_THRESHOLD";
/// Escalation threshold used when [`NAG_ESCALATE_THRESHOLD_ENV`] is unset or does not parse.
pub const DEFAULT_NAG_ESCALATE_THRESHOLD: u32 = 20;
/// Nag bookkeeping for a single `(agent, session)` pair.
///
/// The `Default` value is the fresh state: an empty streak with neither warning issued.
#[derive(Clone, Copy, Debug, Default)]
struct SessionCounter {
    /// Consecutive non-write tool calls since the last memory write.
    non_store_streak: u32,
    /// Whether the primary-level warning has been delivered for the current streak.
    primary_warned: bool,
    /// Whether the escalation-level warning has been delivered for the current streak.
    escalation_warned: bool,
}
/// What a caller should do after a tool call was observed by the watcher.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NagAction {
    /// No nag is due for this call.
    None,
    /// The non-write streak reached the primary threshold.
    Warn,
    /// The non-write streak reached the escalation threshold.
    WarnAndEscalate,
}
/// Tracks, per `(agent_id, session_id)`, how many tool calls have happened since the
/// last memory write, and decides when that streak warrants a nag.
///
/// All per-session state sits behind a single `Mutex`, so every operation takes `&self`.
pub struct CaptureNagWatcher {
    /// Session counters keyed by `(agent_id, session_id)`.
    inner: Mutex<HashMap<(String, String), SessionCounter>>,
    /// Streak length that triggers `NagAction::Warn`; `0` disables this level.
    primary_threshold: u32,
    /// Streak length that triggers `NagAction::WarnAndEscalate`; `0` disables this level.
    escalation_threshold: u32,
}
/// Coarse classification of a tool call, as seen by the nag watcher.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolKind {
    /// A call that writes memory; it resets the session's streak.
    MemoryWrite,
    /// Any other call; it extends the session's streak.
    Other,
}
impl CaptureNagWatcher {
#[must_use]
pub fn new_from_env() -> Self {
let primary = parse_threshold_env(NAG_THRESHOLD_ENV, DEFAULT_NAG_THRESHOLD);
let escalation =
parse_threshold_env(NAG_ESCALATE_THRESHOLD_ENV, DEFAULT_NAG_ESCALATE_THRESHOLD);
Self::new(primary, escalation)
}
#[must_use]
pub fn new(primary_threshold: u32, escalation_threshold: u32) -> Self {
Self {
inner: Mutex::new(HashMap::new()),
primary_threshold,
escalation_threshold,
}
}
pub fn observe_tool_call(
&self,
agent_id: &str,
session_id: &str,
tool_kind: ToolKind,
) -> NagAction {
let key = (agent_id.to_string(), session_id.to_string());
let Ok(mut state) = self.inner.lock() else {
return NagAction::None;
};
let entry = state.entry(key).or_default();
match tool_kind {
ToolKind::MemoryWrite => {
*entry = SessionCounter::default();
NagAction::None
}
ToolKind::Other => {
entry.non_store_streak = entry.non_store_streak.saturating_add(1);
if self.escalation_threshold > 0
&& entry.non_store_streak >= self.escalation_threshold
&& !entry.escalation_warned
{
entry.escalation_warned = true;
return NagAction::WarnAndEscalate;
}
if self.primary_threshold > 0
&& entry.non_store_streak >= self.primary_threshold
&& !entry.primary_warned
{
entry.primary_warned = true;
return NagAction::Warn;
}
NagAction::None
}
}
}
#[must_use]
pub fn streak_for(&self, agent_id: &str, session_id: &str) -> u32 {
let key = (agent_id.to_string(), session_id.to_string());
let Ok(state) = self.inner.lock() else {
return 0;
};
state.get(&key).map_or(0, |c| c.non_store_streak)
}
pub fn drop_session(&self, agent_id: &str, session_id: &str) {
let key = (agent_id.to_string(), session_id.to_string());
if let Ok(mut state) = self.inner.lock() {
state.remove(&key);
}
}
#[must_use]
pub fn primary_threshold(&self) -> u32 {
self.primary_threshold
}
#[must_use]
pub fn escalation_threshold(&self) -> u32 {
self.escalation_threshold
}
}
impl Default for CaptureNagWatcher {
fn default() -> Self {
Self::new_from_env()
}
}
/// Buckets a tool name (as listed in `crate::mcp::registry::tool_names`) for the watcher.
///
/// Names in the memory-write set map to [`ToolKind::MemoryWrite`]; every other name,
/// including ones this function has never heard of, maps to [`ToolKind::Other`].
#[must_use]
pub fn classify_tool(tool_name: &str) -> ToolKind {
    use crate::mcp::registry::tool_names as tn;
    let writes_memory = matches!(
        tool_name,
        tn::MEMORY_STORE
            | tn::MEMORY_UPDATE
            | tn::MEMORY_LINK
            | tn::MEMORY_ATOMISE
            | tn::MEMORY_INGEST_MULTISTEP
            | tn::MEMORY_CONSOLIDATE
            | tn::MEMORY_PROMOTE
            | tn::MEMORY_REFLECT
            | tn::MEMORY_PERSONA_GENERATE
            | tn::MEMORY_ENTITY_REGISTER
            | tn::MEMORY_SHARE
            | tn::MEMORY_SUBSCRIBE
            | tn::MEMORY_NOTIFY
            | tn::MEMORY_SKILL_REGISTER
            | tn::MEMORY_SKILL_PROMOTE_FROM_REFLECTION
            | tn::MEMORY_NAMESPACE_SET_STANDARD
            | tn::MEMORY_KG_INVALIDATE
            | tn::MEMORY_CAPTURE_TURN
    );
    if writes_memory {
        ToolKind::MemoryWrite
    } else {
        ToolKind::Other
    }
}
/// Reads a nag threshold from the environment variable `name`.
///
/// Leading/trailing whitespace is ignored, because values sourced from files or
/// generated config often carry a stray space or newline that would otherwise make
/// parsing fail silently. Returns `default_value` when the variable is unset, is not
/// valid Unicode, or is not a `u32` after trimming; an explicit `0` is returned as-is.
fn parse_threshold_env(name: &str, default_value: u32) -> u32 {
    std::env::var(name)
        .ok()
        .and_then(|raw| raw.trim().parse::<u32>().ok())
        .unwrap_or(default_value)
}
#[cfg(test)]
mod tests {
    use super::*;

    const AGENT: &str = "agent";
    const SESSION: &str = "session";

    /// One non-write tool call for the default `(AGENT, SESSION)` pair.
    fn other(w: &CaptureNagWatcher) -> NagAction {
        w.observe_tool_call(AGENT, SESSION, ToolKind::Other)
    }

    /// `calls` consecutive non-write tool calls, returning each call's action in order.
    fn drive(w: &CaptureNagWatcher, calls: usize) -> Vec<NagAction> {
        (0..calls).map(|_| other(w)).collect()
    }

    #[test]
    fn classify_tool_recognizes_writes() {
        for name in ["memory_store", "memory_update", "memory_link", "memory_atomise"] {
            assert_eq!(classify_tool(name), ToolKind::MemoryWrite);
        }
        assert_eq!(
            classify_tool("memory_capture_turn"),
            ToolKind::MemoryWrite,
            "L4 surface MUST reset the nag counter"
        );
    }

    #[test]
    fn classify_tool_defaults_to_other() {
        for name in ["memory_recall", "memory_get", "bash", "unknown_future_tool"] {
            assert_eq!(classify_tool(name), ToolKind::Other);
        }
    }

    #[test]
    fn primary_threshold_fires_exactly_once() {
        let w = CaptureNagWatcher::new(3, 10);
        assert_eq!(
            drive(&w, 4),
            [
                NagAction::None,
                NagAction::None,
                NagAction::Warn,
                NagAction::None
            ]
        );
    }

    #[test]
    fn memory_write_resets_streak() {
        let w = CaptureNagWatcher::new(3, 10);
        drive(&w, 2);
        assert_eq!(w.streak_for(AGENT, SESSION), 2);
        let reset = w.observe_tool_call(AGENT, SESSION, ToolKind::MemoryWrite);
        assert_eq!(reset, NagAction::None);
        assert_eq!(w.streak_for(AGENT, SESSION), 0);
        drive(&w, 2);
        assert_eq!(other(&w), NagAction::Warn, "re-armed WARN after reset");
    }

    #[test]
    fn escalation_threshold_fires_after_sustained_drift() {
        let w = CaptureNagWatcher::new(2, 4);
        assert_eq!(
            drive(&w, 5),
            [
                NagAction::None,
                NagAction::Warn,
                NagAction::None,
                NagAction::WarnAndEscalate,
                NagAction::None
            ]
        );
    }

    #[test]
    fn per_session_counters_are_independent() {
        let w = CaptureNagWatcher::new(2, 10);
        let call_a = || w.observe_tool_call(AGENT, "session-a", ToolKind::Other);
        call_a();
        assert_eq!(w.streak_for(AGENT, "session-b"), 0);
        assert_eq!(call_a(), NagAction::Warn);
        assert_eq!(
            w.observe_tool_call(AGENT, "session-b", ToolKind::Other),
            NagAction::None
        );
    }

    #[test]
    fn per_agent_counters_are_independent() {
        let w = CaptureNagWatcher::new(2, 10);
        w.observe_tool_call("agent-a", SESSION, ToolKind::Other);
        assert_eq!(w.streak_for("agent-b", SESSION), 0);
    }

    #[test]
    fn drop_session_clears_counter() {
        let w = CaptureNagWatcher::new(2, 10);
        other(&w);
        assert_eq!(w.streak_for(AGENT, SESSION), 1);
        w.drop_session(AGENT, SESSION);
        assert_eq!(w.streak_for(AGENT, SESSION), 0);
    }

    #[test]
    fn disabled_thresholds_never_fire() {
        let w = CaptureNagWatcher::new(0, 0);
        assert_eq!(drive(&w, 100), vec![NagAction::None; 100]);
    }

    #[test]
    fn streak_saturates_instead_of_overflowing() {
        let w = CaptureNagWatcher::new(5, 10);
        // Seed a nearly-saturated streak directly; both warnings are already spent.
        w.inner.lock().unwrap().insert(
            (AGENT.to_string(), SESSION.to_string()),
            SessionCounter {
                non_store_streak: u32::MAX - 1,
                primary_warned: true,
                escalation_warned: true,
            },
        );
        drive(&w, 2);
        assert_eq!(w.streak_for(AGENT, SESSION), u32::MAX);
    }
}