claude_agent/agent/state.rs

//! Agent state management: the `AgentState` lifecycle enum and the `AgentMetrics`
//! counters accumulated over an agent run.
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum AgentState {
8    #[default]
9    Initializing,
10    Running,
11    WaitingForToolResults,
12    WaitingForUserInput,
13    PlanMode,
14    Completed,
15    Failed,
16    Cancelled,
17}
18
19impl AgentState {
20    pub fn is_terminal(&self) -> bool {
21        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
22    }
23
24    pub fn is_waiting(&self) -> bool {
25        matches!(
26            self,
27            Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
28        )
29    }
30
31    pub fn can_continue(&self) -> bool {
32        matches!(self, Self::Running | Self::Initializing)
33    }
34}
35
36use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
37
38#[derive(Debug, Clone, Default)]
39pub struct AgentMetrics {
40    pub iterations: usize,
41    pub tool_calls: usize,
42    pub input_tokens: u32,
43    pub output_tokens: u32,
44    pub cache_read_tokens: u32,
45    pub cache_creation_tokens: u32,
46    pub execution_time_ms: u64,
47    pub errors: usize,
48    pub compactions: usize,
49    pub api_calls: usize,
50    pub total_cost_usd: f64,
51    pub tool_stats: std::collections::HashMap<String, ToolStats>,
52    pub tool_call_records: Vec<ToolCallRecord>,
53    pub model_usage: std::collections::HashMap<String, ModelUsage>,
54    pub server_tool_use: ServerToolUse,
55    pub permission_denials: Vec<PermissionDenial>,
56    pub api_time_ms: u64,
57}
58
/// Aggregated statistics for one tool, stored per tool name in
/// `AgentMetrics::tool_stats` and updated by `AgentMetrics::record_tool`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared by value.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ToolStats {
    /// Number of recorded invocations of this tool.
    pub calls: usize,
    /// Sum of the recorded `duration_ms` values, in milliseconds.
    pub total_time_ms: u64,
    /// Number of invocations recorded with `is_error == true`.
    pub errors: usize,
}
65
/// A single tool invocation, appended by `AgentMetrics::record_tool`.
///
/// Derives `PartialEq`/`Eq` so records can be compared by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallRecord {
    /// The `tool_use_id` passed to `record_tool` (presumably the API's tool-use
    /// block id — confirm at call sites).
    pub tool_use_id: String,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Recorded duration in milliseconds.
    pub duration_ms: u64,
    /// Whether the invocation was recorded as an error.
    pub is_error: bool,
}
73
74impl AgentMetrics {
75    pub fn total_tokens(&self) -> u32 {
76        self.input_tokens + self.output_tokens
77    }
78
79    pub fn add_usage(&mut self, input: u32, output: u32) {
80        self.input_tokens += input;
81        self.output_tokens += output;
82    }
83
84    pub fn add_usage_with_cache(&mut self, usage: &Usage) {
85        self.input_tokens += usage.input_tokens;
86        self.output_tokens += usage.output_tokens;
87        self.cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
88        self.cache_creation_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
89    }
90
91    pub fn cache_hit_rate(&self) -> f64 {
92        if self.input_tokens == 0 {
93            return 0.0;
94        }
95        self.cache_read_tokens as f64 / self.input_tokens as f64
96    }
97
98    pub fn cache_tokens_saved(&self) -> u32 {
99        (self.cache_read_tokens as f64 * 0.9) as u32
100    }
101
102    pub fn add_cost(&mut self, cost: f64) {
103        self.total_cost_usd += cost;
104    }
105
106    pub fn record_tool(&mut self, tool_use_id: &str, name: &str, duration_ms: u64, is_error: bool) {
107        self.tool_calls += 1;
108        let stats = self.tool_stats.entry(name.to_string()).or_default();
109        stats.calls += 1;
110        stats.total_time_ms += duration_ms;
111        if is_error {
112            stats.errors += 1;
113            self.errors += 1;
114        }
115        self.tool_call_records.push(ToolCallRecord {
116            tool_use_id: tool_use_id.to_string(),
117            tool_name: name.to_string(),
118            duration_ms,
119            is_error,
120        });
121    }
122
123    pub fn record_api_call(&mut self) {
124        self.api_calls += 1;
125    }
126
127    pub fn record_compaction(&mut self) {
128        self.compactions += 1;
129    }
130
131    pub fn avg_tool_time_ms(&self) -> f64 {
132        if self.tool_calls == 0 {
133            return 0.0;
134        }
135        let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
136        total as f64 / self.tool_calls as f64
137    }
138
139    /// Record usage for a specific model.
140    ///
141    /// This enables per-model cost tracking like CLI's modelUsage field.
142    pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
143        let entry = self.model_usage.entry(model.to_string()).or_default();
144        entry.add_usage(usage, model);
145    }
146
147    /// Record an API call with timing information.
148    pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
149        self.api_calls += 1;
150        self.api_time_ms += duration_ms;
151    }
152
153    /// Update server_tool_use from API response.
154    ///
155    /// This is for server-side tools executed by the API (e.g., Anthropic's
156    /// server-side RAG). Not to be confused with local tool usage.
157    pub fn update_server_tool_use(&mut self, server_tool_use: ServerToolUse) {
158        self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
159        self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
160    }
161
162    /// Update server_tool_use from API response's usage.server_tool_use field.
163    ///
164    /// This parses the server tool usage directly from the API response.
165    pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
166        self.server_tool_use.add_from_usage(usage);
167    }
168
169    /// Record a permission denial.
170    pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
171        self.permission_denials.push(denial);
172    }
173
174    /// Get the total cost across all models.
175    pub fn total_model_cost(&self) -> f64 {
176        self.model_usage.values().map(|m| m.cost_usd).sum()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// All variants of `AgentState`, used to check the predicates exhaustively.
    const ALL_STATES: [AgentState; 8] = [
        AgentState::Initializing,
        AgentState::Running,
        AgentState::WaitingForToolResults,
        AgentState::WaitingForUserInput,
        AgentState::PlanMode,
        AgentState::Completed,
        AgentState::Failed,
        AgentState::Cancelled,
    ];

    #[test]
    fn test_agent_state() {
        assert!(AgentState::Completed.is_terminal());
        assert!(AgentState::Failed.is_terminal());
        assert!(AgentState::Cancelled.is_terminal());
        assert!(!AgentState::Running.is_terminal());

        assert!(AgentState::WaitingForToolResults.is_waiting());
        assert!(AgentState::WaitingForUserInput.is_waiting());
        assert!(AgentState::PlanMode.is_waiting());
        assert!(!AgentState::Running.is_waiting());

        assert!(AgentState::Running.can_continue());
        assert!(AgentState::Initializing.can_continue());
        assert!(!AgentState::Completed.can_continue());
        assert!(!AgentState::PlanMode.can_continue());
    }

    #[test]
    fn test_agent_state_default() {
        assert_eq!(AgentState::default(), AgentState::Initializing);
    }

    #[test]
    fn test_agent_state_predicates_partition_variants() {
        // Every state is exactly one of: terminal, waiting, or able to continue.
        for state in ALL_STATES.iter() {
            let flags = [state.is_terminal(), state.is_waiting(), state.can_continue()];
            assert_eq!(flags.iter().filter(|&&f| f).count(), 1, "{:?}", state);
        }
    }

    #[test]
    fn test_agent_metrics() {
        let mut metrics = AgentMetrics::default();
        metrics.add_usage(100, 50);
        metrics.add_usage(200, 100);

        assert_eq!(metrics.input_tokens, 300);
        assert_eq!(metrics.output_tokens, 150);
        assert_eq!(metrics.total_tokens(), 450);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut metrics = AgentMetrics::default();
        metrics.record_tool("tu_1", "Read", 50, false);
        metrics.record_tool("tu_2", "Read", 30, false);
        metrics.record_tool("tu_3", "Bash", 100, true);

        assert_eq!(metrics.tool_calls, 3);
        assert_eq!(metrics.errors, 1);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().calls, 2);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().total_time_ms, 80);
        assert_eq!(metrics.tool_stats.get("Bash").unwrap().errors, 1);
        assert_eq!(metrics.tool_call_records.len(), 3);
        assert_eq!(metrics.tool_call_records[0].tool_use_id, "tu_1");
        assert!(metrics.tool_call_records[2].is_error);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.avg_tool_time_ms(), 0.0);

        metrics.record_tool("tu_1", "Read", 100, false);
        metrics.record_tool("tu_2", "Write", 200, false);
        assert!((metrics.avg_tool_time_ms() - 150.0).abs() < 0.1);
    }

    #[test]
    fn test_agent_metrics_api_and_compaction_counters() {
        let mut metrics = AgentMetrics::default();
        metrics.record_api_call();
        metrics.record_api_call_with_timing(120);
        metrics.record_compaction();

        assert_eq!(metrics.api_calls, 2);
        assert_eq!(metrics.api_time_ms, 120);
        assert_eq!(metrics.compactions, 1);
    }

    #[test]
    fn test_agent_metrics_cost() {
        let mut metrics = AgentMetrics::default();
        metrics.add_cost(0.25);
        metrics.add_cost(0.5);

        assert!((metrics.total_cost_usd - 0.75).abs() < 1e-12);
        // No per-model entries recorded yet.
        assert_eq!(metrics.total_model_cost(), 0.0);
    }

    #[test]
    fn test_agent_metrics_cache_helpers() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.cache_hit_rate(), 0.0);

        // Uncached input only: nothing was served from the cache.
        metrics.add_usage(100, 0);
        assert_eq!(metrics.cache_hit_rate(), 0.0);

        metrics.cache_read_tokens = 1000;
        assert_eq!(metrics.cache_tokens_saved(), 900);
    }
}