use crate::spec_ai_core::agent::model::{GenerationConfig, TokenUsage};
use crate::spec_ai_core::config::SafetyConfig;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Record of a safety limit that was exceeded during a run.
///
/// Produced by `RunSafetyBudget` whenever an observed value goes past its configured
/// maximum; `message` is pre-rendered text suitable for returning to the user.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SafetyLimitHit {
    /// Name of the limit that was exceeded (e.g. `"max_tool_calls_per_run"`).
    pub limit: String,
    /// The configured maximum for `limit`.
    pub configured: u64,
    /// The value that was observed when the limit check failed.
    pub observed: u64,
    /// Human-readable explanation of why the run was stopped.
    pub message: String,
}
impl SafetyLimitHit {
    /// Creates a hit for `limit`, rendering the stop message from the configured
    /// maximum and the value that exceeded it.
    fn new(limit: impl Into<String>, configured: u64, observed: u64) -> Self {
        let limit: String = limit.into();
        let message = format!(
            "Stopped because recursion/cost safety limit was reached: {} (configured {}, observed {}).",
            limit, configured, observed
        );
        Self {
            limit,
            configured,
            observed,
            message,
        }
    }

    /// Returns the finish-reason tag for this hit, of the form `safety_limit:<limit>`.
    pub fn finish_reason(&self) -> String {
        ["safety_limit:", self.limit.as_str()].concat()
    }
}
/// Snapshot of the counters tracked by a `RunSafetyBudget`.
///
/// Counters are only advanced while the budget is enabled; they keep advancing
/// even after a limit has been hit.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct SafetyStats {
    /// Mirrors `SafetyConfig::enabled` at construction time.
    pub enabled: bool,
    /// Number of `record_model_call` invocations.
    pub model_calls: usize,
    /// Number of `record_tool_call` invocations.
    pub tool_calls: usize,
    /// Number of `record_loop_iteration` invocations.
    pub loop_iterations: usize,
    /// Highest number of times any single (tool name, arguments) fingerprint has been
    /// recorded — a maximum, not a total.
    pub repeated_tool_calls: usize,
    /// Cumulative byte length of all prompts recorded so far.
    pub prompt_bytes: usize,
    /// Cumulative byte length of all tool outputs recorded so far.
    pub tool_output_bytes: usize,
    /// Cumulative `total_tokens` reported via `record_token_usage`.
    pub total_tokens: u64,
    /// The first limit that was exceeded, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit_hit: Option<SafetyLimitHit>,
}
/// Mutable state shared by every clone of a `RunSafetyBudget`.
#[derive(Debug)]
struct SafetyState {
    /// Limits being enforced; only read after construction.
    config: SafetyConfig,
    /// Counters exposed through `RunSafetyBudget::stats`.
    stats: SafetyStats,
    /// How many times each tool-call fingerprint (see `tool_fingerprint`) has been recorded.
    repeated_tool_calls: HashMap<String, usize>,
}
/// Per-run recursion/cost budget.
///
/// Cloning is cheap and yields a handle to the SAME underlying state (the state lives
/// behind an `Arc<Mutex<_>>`), so every clone observes and advances the same counters.
#[derive(Debug, Clone)]
pub struct RunSafetyBudget {
    inner: Arc<Mutex<SafetyState>>,
}
impl RunSafetyBudget {
    /// Creates a budget that enforces `config`, with every counter starting at zero.
    ///
    /// `stats().enabled` mirrors `config.enabled`, so consumers of [`SafetyStats`] can
    /// tell whether limits were actually enforced for the run.
    pub fn new(config: SafetyConfig) -> Self {
        let stats = SafetyStats {
            enabled: config.enabled,
            ..SafetyStats::default()
        };
        Self {
            inner: Arc::new(Mutex::new(SafetyState {
                config,
                stats,
                repeated_tool_calls: HashMap::new(),
            })),
        }
    }

    /// Returns a copy of the configuration this budget enforces.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn config(&self) -> SafetyConfig {
        self.lock().config.clone()
    }

    /// Returns a snapshot of the current counters.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn stats(&self) -> SafetyStats {
        self.lock().stats.clone()
    }

    /// Returns the first limit that was exceeded during this run, if any.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn limit_hit(&self) -> Option<SafetyLimitHit> {
        self.lock().stats.limit_hit.clone()
    }

    /// Returns the user-facing text to emit when a run stops because of `hit`.
    pub fn safety_response(hit: &SafetyLimitHit) -> String {
        hit.message.clone()
    }

    /// Returns a copy of `config` whose `max_tokens` is capped at
    /// `max_output_tokens_per_call`.
    ///
    /// When the budget is enabled and `config.max_tokens` is `None`, the cap itself is
    /// used so the model call is always bounded. When disabled, `config` is returned
    /// unchanged.
    pub fn clamp_generation_config(&self, config: &GenerationConfig) -> GenerationConfig {
        // Read what we need and release the lock before cloning the caller's config.
        let max = {
            let state = self.lock();
            if !state.config.enabled {
                return config.clone();
            }
            state.config.max_output_tokens_per_call
        };
        let mut clamped = config.clone();
        clamped.max_tokens = Some(config.max_tokens.map_or(max, |requested| requested.min(max)));
        clamped
    }

    /// Counts one iteration of the tool loop.
    ///
    /// # Errors
    /// Returns the hit when the iteration count exceeds `max_tool_loop_iterations`.
    pub fn record_loop_iteration(&self) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }
        state.stats.loop_iterations = state.stats.loop_iterations.saturating_add(1);
        let observed = state.stats.loop_iterations as u64;
        let configured = state.config.max_tool_loop_iterations as u64;
        Self::check_limit(&mut state, "max_tool_loop_iterations", configured, observed)
    }

    /// Counts one model call and adds `prompt.len()` bytes to the prompt total.
    ///
    /// The call count is checked first; if it is already over budget the prompt bytes
    /// are not added.
    ///
    /// # Errors
    /// Returns the hit when `max_model_calls_per_run` or `max_prompt_bytes_per_run` is
    /// exceeded.
    pub fn record_model_call(&self, prompt: &str) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.model_calls = state.stats.model_calls.saturating_add(1);
        let model_calls = state.stats.model_calls as u64;
        let max_calls = state.config.max_model_calls_per_run as u64;
        Self::check_limit(&mut state, "max_model_calls_per_run", max_calls, model_calls)?;

        state.stats.prompt_bytes = state.stats.prompt_bytes.saturating_add(prompt.len());
        let prompt_bytes = state.stats.prompt_bytes as u64;
        let max_bytes = state.config.max_prompt_bytes_per_run as u64;
        Self::check_limit(&mut state, "max_prompt_bytes_per_run", max_bytes, prompt_bytes)
    }

    /// Adds `usage.total_tokens` to the run's token total.
    ///
    /// # Errors
    /// Returns the hit when the total exceeds `max_total_tokens_per_run`.
    pub fn record_token_usage(&self, usage: &TokenUsage) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }
        state.stats.total_tokens = state
            .stats
            .total_tokens
            .saturating_add(usage.total_tokens as u64);
        let total_tokens = state.stats.total_tokens;
        let configured = state.config.max_total_tokens_per_run;
        Self::check_limit(&mut state, "max_total_tokens_per_run", configured, total_tokens)
    }

    /// Counts one tool call and how often this exact (tool, arguments) pair has been seen.
    ///
    /// The total call count is checked first; if it is already over budget the call is
    /// not fingerprinted.
    ///
    /// # Errors
    /// Returns the hit when `max_tool_calls_per_run` or `max_repeated_tool_calls` is
    /// exceeded.
    pub fn record_tool_call(&self, tool_name: &str, args: &Value) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.tool_calls = state.stats.tool_calls.saturating_add(1);
        let tool_calls = state.stats.tool_calls as u64;
        let max_calls = state.config.max_tool_calls_per_run as u64;
        Self::check_limit(&mut state, "max_tool_calls_per_run", max_calls, tool_calls)?;

        let count = state
            .repeated_tool_calls
            .entry(tool_fingerprint(tool_name, args))
            .or_default();
        *count = count.saturating_add(1);
        let repeat_count = *count;
        // The stat tracks the worst offender, not the number of repeated calls.
        state.stats.repeated_tool_calls = state.stats.repeated_tool_calls.max(repeat_count);

        let max_repeats = state.config.max_repeated_tool_calls as u64;
        Self::check_limit(
            &mut state,
            "max_repeated_tool_calls",
            max_repeats,
            repeat_count as u64,
        )
    }

    /// Adds `tool_output.len()` bytes to the run's tool-output total.
    ///
    /// # Errors
    /// Returns the hit when the total exceeds `max_tool_output_bytes`.
    pub fn record_tool_output(&self, tool_output: &str) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }
        state.stats.tool_output_bytes = state
            .stats
            .tool_output_bytes
            .saturating_add(tool_output.len());
        let output_bytes = state.stats.tool_output_bytes as u64;
        let configured = state.config.max_tool_output_bytes as u64;
        Self::check_limit(&mut state, "max_tool_output_bytes", configured, output_bytes)
    }

    /// Checks a delegation `depth` against `max_delegation_depth` without updating any
    /// counter.
    ///
    /// # Errors
    /// Returns the hit when `depth` exceeds the configured maximum.
    pub fn check_delegation_depth(&self, depth: usize) -> Result<(), SafetyLimitHit> {
        let mut state = self.lock();
        if !state.config.enabled {
            return Ok(());
        }
        let configured = state.config.max_delegation_depth as u64;
        Self::check_limit(&mut state, "max_delegation_depth", configured, depth as u64)
    }

    /// Acquires the shared state.
    ///
    /// # Panics
    /// Panics if a previous holder of the lock panicked while holding it.
    fn lock(&self) -> std::sync::MutexGuard<'_, SafetyState> {
        self.inner.lock().expect("safety budget poisoned")
    }

    /// Returns `Err` (after recording the hit via [`Self::set_hit`]) when `observed`
    /// strictly exceeds `configured`; reaching the limit exactly is still allowed.
    fn check_limit(
        state: &mut SafetyState,
        limit: &str,
        configured: u64,
        observed: u64,
    ) -> Result<(), SafetyLimitHit> {
        if observed > configured {
            Err(Self::set_hit(state, limit, configured, observed))
        } else {
            Ok(())
        }
    }

    /// Builds a hit for `limit` and stores it in the stats unless an earlier hit is
    /// already recorded (the first hit wins). The newly built hit is always returned.
    fn set_hit(
        state: &mut SafetyState,
        limit: &str,
        configured: u64,
        observed: u64,
    ) -> SafetyLimitHit {
        let hit = SafetyLimitHit::new(limit, configured, observed);
        if state.stats.limit_hit.is_none() {
            state.stats.limit_hit = Some(hit.clone());
        }
        hit
    }
}
/// Builds the key used to detect repeated tool calls: the tool name, a `:`, then the
/// hex-encoded BLAKE3 digest of the compact JSON encoding of `args`.
fn tool_fingerprint(tool_name: &str, args: &Value) -> String {
    let encoded = match serde_json::to_string(args) {
        Ok(json) => json,
        Err(_) => args.to_string(),
    };
    let digest_hex = blake3::hash(encoded.as_bytes()).to_hex();
    format!("{}:{}", tool_name, digest_hex.as_str())
}