// spec-ai 0.7.0
//
// A framework for building AI agents with structured outputs, policy enforcement, and execution tracking.
// Documentation
use crate::spec_ai_core::agent::model::{GenerationConfig, TokenUsage};
use crate::spec_ai_core::config::SafetyConfig;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// A concrete safety limit that stopped or constrained a run.
///
/// Values are produced by the `record_*` / `check_*` methods of
/// `RunSafetyBudget` when an observed value exceeds its configured ceiling.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SafetyLimitHit {
    /// Name of the limit that was exceeded, e.g. `"max_tool_calls_per_run"`.
    pub limit: String,
    /// The ceiling taken from the safety configuration.
    pub configured: u64,
    /// The value that exceeded `configured`.
    pub observed: u64,
    /// Human-readable explanation assembled from the three fields above.
    pub message: String,
}

impl SafetyLimitHit {
    /// Builds the record for an exceeded limit.
    ///
    /// `limit` is the limit's name, `configured` its ceiling and `observed`
    /// the value that went past it. The human-readable `message` is derived
    /// from those three values.
    fn new(limit: impl Into<String>, configured: u64, observed: u64) -> Self {
        let limit: String = limit.into();
        // Format the message before `limit` is moved into the struct.
        let message = format!(
            "Stopped because recursion/cost safety limit was reached: {} (configured {}, observed {}).",
            limit, configured, observed
        );
        Self {
            limit,
            configured,
            observed,
            message,
        }
    }

    /// Returns the finish-reason tag for this hit: `safety_limit:` followed
    /// by the limit name.
    pub fn finish_reason(&self) -> String {
        let limit_name = &self.limit;
        format!("safety_limit:{}", limit_name)
    }
}

/// Safety counters reported with each agent output.
///
/// The counters are only advanced by `RunSafetyBudget` while the safety
/// configuration is enabled; with safety disabled they stay at their defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct SafetyStats {
    /// Copy of the configuration's `enabled` flag taken when the budget was created.
    pub enabled: bool,
    /// Number of model calls recorded so far.
    pub model_calls: usize,
    /// Number of tool calls recorded so far.
    pub tool_calls: usize,
    /// Number of tool-loop iterations recorded so far.
    pub loop_iterations: usize,
    /// Highest number of times any single (tool name, arguments) fingerprint
    /// has been recorded; this is a maximum, not a sum.
    pub repeated_tool_calls: usize,
    /// Cumulative byte length of all prompts recorded with model calls.
    pub prompt_bytes: usize,
    /// Cumulative byte length of all recorded tool outputs.
    pub tool_output_bytes: usize,
    /// Cumulative total tokens reported through recorded token usage.
    pub total_tokens: u64,
    /// The first limit that was hit, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub limit_hit: Option<SafetyLimitHit>,
}

/// Mutable state behind a `RunSafetyBudget`, guarded by its mutex.
#[derive(Debug)]
struct SafetyState {
    /// The limits this run is checked against.
    config: SafetyConfig,
    /// Running counters plus the first limit hit, if any.
    stats: SafetyStats,
    /// How many times each tool-call fingerprint (see `tool_fingerprint`) has been recorded.
    repeated_tool_calls: HashMap<String, usize>,
}

/// Shared, cloneable safety budget for all cost-bearing work in one run.
///
/// Cloning copies only the `Arc`, so every clone reads and updates the same
/// counters and observes the same limit hit.
#[derive(Debug, Clone)]
pub struct RunSafetyBudget {
    /// The shared state; every accessor locks this mutex for the duration of the call.
    inner: Arc<Mutex<SafetyState>>,
}

impl RunSafetyBudget {
    /// Creates a fresh budget for one run.
    ///
    /// All counters start at zero and `stats.enabled` mirrors `config.enabled`.
    pub fn new(config: SafetyConfig) -> Self {
        let stats = SafetyStats {
            enabled: config.enabled,
            ..SafetyStats::default()
        };
        let state = SafetyState {
            config,
            stats,
            repeated_tool_calls: HashMap::new(),
        };
        Self {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// Returns a copy of the safety configuration this budget enforces.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned (as does every method that
    /// touches the shared state).
    pub fn config(&self) -> SafetyConfig {
        self.locked_state().config.clone()
    }

    /// Returns a snapshot of the current counters.
    pub fn stats(&self) -> SafetyStats {
        self.locked_state().stats.clone()
    }

    /// Returns the first limit that was hit during this run, if any.
    pub fn limit_hit(&self) -> Option<SafetyLimitHit> {
        self.locked_state().stats.limit_hit.clone()
    }

    /// Returns the text to present when a run is stopped by `hit`
    /// (currently the hit's own message).
    pub fn safety_response(hit: &SafetyLimitHit) -> String {
        hit.message.clone()
    }

    /// Returns a copy of `config` whose `max_tokens` never exceeds the
    /// configured per-call output ceiling.
    ///
    /// When safety is disabled the copy is unchanged. When enabled, a missing
    /// `max_tokens` is filled in with the ceiling and a larger value is
    /// lowered to it.
    pub fn clamp_generation_config(&self, config: &GenerationConfig) -> GenerationConfig {
        let state = self.locked_state();
        let mut clamped = config.clone();
        if state.config.enabled {
            let ceiling = state.config.max_output_tokens_per_call;
            clamped.max_tokens = Some(
                config
                    .max_tokens
                    .map_or(ceiling, |requested| requested.min(ceiling)),
            );
        }
        clamped
    }

    /// Counts one tool-loop iteration.
    ///
    /// # Errors
    ///
    /// Returns the hit when the iteration count exceeds
    /// `max_tool_loop_iterations`.
    pub fn record_loop_iteration(&self) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        // Saturate rather than `+= 1`: an overflowing counter must neither
        // panic while the lock is held nor wrap to zero and reset the budget.
        state.stats.loop_iterations = state.stats.loop_iterations.saturating_add(1);
        let observed = state.stats.loop_iterations as u64;
        let configured = state.config.max_tool_loop_iterations as u64;
        Self::enforce_limit(&mut state, "max_tool_loop_iterations", configured, observed)
    }

    /// Counts one model call and adds `prompt`'s byte length to the prompt total.
    ///
    /// The call-count limit is checked first; if it is exceeded the prompt
    /// bytes are not added.
    ///
    /// # Errors
    ///
    /// Returns the hit when either `max_model_calls_per_run` or
    /// `max_prompt_bytes_per_run` is exceeded.
    pub fn record_model_call(&self, prompt: &str) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.model_calls = state.stats.model_calls.saturating_add(1);
        let calls = state.stats.model_calls as u64;
        let call_ceiling = state.config.max_model_calls_per_run as u64;
        Self::enforce_limit(&mut state, "max_model_calls_per_run", call_ceiling, calls)?;

        state.stats.prompt_bytes = state.stats.prompt_bytes.saturating_add(prompt.len());
        let bytes = state.stats.prompt_bytes as u64;
        let byte_ceiling = state.config.max_prompt_bytes_per_run as u64;
        Self::enforce_limit(&mut state, "max_prompt_bytes_per_run", byte_ceiling, bytes)
    }

    /// Adds `usage.total_tokens` to the run's token total.
    ///
    /// # Errors
    ///
    /// Returns the hit when the total exceeds `max_total_tokens_per_run`.
    pub fn record_token_usage(&self, usage: &TokenUsage) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.total_tokens = state
            .stats
            .total_tokens
            .saturating_add(usage.total_tokens as u64);
        let observed = state.stats.total_tokens;
        let configured = state.config.max_total_tokens_per_run;
        Self::enforce_limit(&mut state, "max_total_tokens_per_run", configured, observed)
    }

    /// Counts one tool call and tracks how often this exact call
    /// (tool name plus arguments) has been made.
    ///
    /// The overall tool-call limit is checked first; if it is exceeded the
    /// repeat counter for this call is not advanced.
    ///
    /// # Errors
    ///
    /// Returns the hit when either `max_tool_calls_per_run` or
    /// `max_repeated_tool_calls` is exceeded.
    pub fn record_tool_call(&self, tool_name: &str, args: &Value) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.tool_calls = state.stats.tool_calls.saturating_add(1);
        let calls = state.stats.tool_calls as u64;
        let call_ceiling = state.config.max_tool_calls_per_run as u64;
        Self::enforce_limit(&mut state, "max_tool_calls_per_run", call_ceiling, calls)?;

        let key = tool_fingerprint(tool_name, args);
        let repeats = {
            let slot = state.repeated_tool_calls.entry(key).or_insert(0);
            *slot = slot.saturating_add(1);
            *slot
        };
        // The reported stat is the worst repeat count seen for any single call.
        state.stats.repeated_tool_calls = state.stats.repeated_tool_calls.max(repeats);
        let repeat_ceiling = state.config.max_repeated_tool_calls as u64;
        Self::enforce_limit(
            &mut state,
            "max_repeated_tool_calls",
            repeat_ceiling,
            repeats as u64,
        )
    }

    /// Adds `tool_output`'s byte length to the run's tool-output total.
    ///
    /// # Errors
    ///
    /// Returns the hit when the cumulative total exceeds `max_tool_output_bytes`.
    pub fn record_tool_output(&self, tool_output: &str) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        state.stats.tool_output_bytes = state
            .stats
            .tool_output_bytes
            .saturating_add(tool_output.len());
        let observed = state.stats.tool_output_bytes as u64;
        let configured = state.config.max_tool_output_bytes as u64;
        Self::enforce_limit(&mut state, "max_tool_output_bytes", configured, observed)
    }

    /// Checks a delegation depth against `max_delegation_depth`.
    ///
    /// Unlike the `record_*` methods this does not change any counter.
    ///
    /// # Errors
    ///
    /// Returns the hit when `depth` exceeds the configured maximum.
    pub fn check_delegation_depth(&self, depth: usize) -> Result<(), SafetyLimitHit> {
        let mut state = self.locked_state();
        if !state.config.enabled {
            return Ok(());
        }

        let configured = state.config.max_delegation_depth as u64;
        Self::enforce_limit(&mut state, "max_delegation_depth", configured, depth as u64)
    }

    /// Locks the shared state, panicking with a fixed message if the mutex is poisoned.
    fn locked_state(&self) -> std::sync::MutexGuard<'_, SafetyState> {
        self.inner.lock().expect("safety budget poisoned")
    }

    /// Returns `Ok(())` while `observed` is within `configured`; otherwise
    /// records and returns the hit for `limit`.
    fn enforce_limit(
        state: &mut SafetyState,
        limit: &str,
        configured: u64,
        observed: u64,
    ) -> Result<(), SafetyLimitHit> {
        if observed > configured {
            Err(Self::set_hit(state, limit, configured, observed))
        } else {
            Ok(())
        }
    }

    /// Builds the hit for `limit` and stores it in the stats unless an
    /// earlier hit is already stored; the first hit of a run is kept.
    fn set_hit(
        state: &mut SafetyState,
        limit: &str,
        configured: u64,
        observed: u64,
    ) -> SafetyLimitHit {
        let hit = SafetyLimitHit::new(limit, configured, observed);
        if state.stats.limit_hit.is_none() {
            state.stats.limit_hit = Some(hit.clone());
        }
        hit
    }
}

/// Builds the key used to recognise repeated tool calls: the tool name, a
/// colon, and the hex BLAKE3 digest of the JSON-serialized arguments.
///
/// NOTE(review): the digest is taken over the serialized text, so two
/// argument objects with the same entries in a different key order would only
/// share a fingerprint if serde_json emits keys in a canonical order —
/// presumably true with its default map type, but not with the
/// `preserve_order` feature; confirm against the crate's feature set.
fn tool_fingerprint(tool_name: &str, args: &Value) -> String {
    // Fall back to the `Display` rendering if serialization reports an error.
    let serialized = match serde_json::to_string(args) {
        Ok(json) => json,
        Err(_) => args.to_string(),
    };
    let digest = blake3::hash(serialized.as_bytes());
    format!("{}:{}", tool_name, digest.to_hex())
}