claude_agent/agent/
state.rs

1//! Agent state management.
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum AgentState {
8    #[default]
9    Initializing,
10    Running,
11    WaitingForToolResults,
12    WaitingForUserInput,
13    PlanMode,
14    Completed,
15    Failed,
16    Cancelled,
17}
18
19impl AgentState {
20    pub fn is_terminal(&self) -> bool {
21        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
22    }
23
24    pub fn is_waiting(&self) -> bool {
25        matches!(
26            self,
27            Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
28        )
29    }
30
31    pub fn can_continue(&self) -> bool {
32        matches!(self, Self::Running | Self::Initializing)
33    }
34}
35
36use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
37
38#[derive(Debug, Clone, Default)]
39pub struct AgentMetrics {
40    pub iterations: usize,
41    pub tool_calls: usize,
42    pub input_tokens: u32,
43    pub output_tokens: u32,
44    pub execution_time_ms: u64,
45    pub errors: usize,
46    pub compactions: usize,
47    pub api_calls: usize,
48    pub total_cost_usd: f64,
49    pub tool_stats: std::collections::HashMap<String, ToolStats>,
50    /// Per-model usage tracking (like CLI's modelUsage).
51    pub model_usage: std::collections::HashMap<String, ModelUsage>,
52    /// Server-side tool usage statistics.
53    pub server_tool_use: ServerToolUse,
54    /// List of permission denials that occurred.
55    pub permission_denials: Vec<PermissionDenial>,
56    /// Total API call time in milliseconds (duration_api_ms).
57    pub api_time_ms: u64,
58}
59
/// Aggregated statistics for a single tool.
///
/// Derives `PartialEq`/`Eq` so whole values can be compared directly
/// (for example with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ToolStats {
    /// Number of recorded invocations of the tool.
    pub calls: usize,
    /// Sum of the recorded invocation durations, in milliseconds.
    pub total_time_ms: u64,
    /// Number of recorded invocations that were flagged as errors.
    pub errors: usize,
}
66
67impl AgentMetrics {
68    pub fn total_tokens(&self) -> u32 {
69        self.input_tokens + self.output_tokens
70    }
71
72    pub fn add_usage(&mut self, input: u32, output: u32) {
73        self.input_tokens += input;
74        self.output_tokens += output;
75    }
76
77    pub fn add_cost(&mut self, cost: f64) {
78        self.total_cost_usd += cost;
79    }
80
81    pub fn record_tool(&mut self, name: &str, duration_ms: u64, is_error: bool) {
82        self.tool_calls += 1;
83        let stats = self.tool_stats.entry(name.to_string()).or_default();
84        stats.calls += 1;
85        stats.total_time_ms += duration_ms;
86        if is_error {
87            stats.errors += 1;
88            self.errors += 1;
89        }
90    }
91
92    pub fn record_api_call(&mut self) {
93        self.api_calls += 1;
94    }
95
96    pub fn record_compaction(&mut self) {
97        self.compactions += 1;
98    }
99
100    pub fn avg_tool_time_ms(&self) -> f64 {
101        if self.tool_calls == 0 {
102            return 0.0;
103        }
104        let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
105        total as f64 / self.tool_calls as f64
106    }
107
108    /// Record usage for a specific model.
109    ///
110    /// This enables per-model cost tracking like CLI's modelUsage field.
111    pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
112        let entry = self.model_usage.entry(model.to_string()).or_default();
113        entry.add_usage(usage, model);
114    }
115
116    /// Record an API call with timing information.
117    pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
118        self.api_calls += 1;
119        self.api_time_ms += duration_ms;
120    }
121
122    /// Update server_tool_use from API response.
123    ///
124    /// This is for server-side tools executed by the API (e.g., Anthropic's
125    /// server-side RAG). Not to be confused with local tool usage.
126    pub fn update_server_tool_use(&mut self, server_tool_use: ServerToolUse) {
127        self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
128        self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
129    }
130
131    /// Update server_tool_use from API response's usage.server_tool_use field.
132    ///
133    /// This parses the server tool usage directly from the API response.
134    pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
135        self.server_tool_use.add_from_usage(usage);
136    }
137
138    /// Record a permission denial.
139    pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
140        self.permission_denials.push(denial);
141    }
142
143    /// Get the total cost across all models.
144    pub fn total_model_cost(&self) -> f64 {
145        self.model_usage.values().map(|m| m.cost_usd).sum()
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `AgentState` variant, for the table-driven checks below.
    const ALL_STATES: [AgentState; 8] = [
        AgentState::Initializing,
        AgentState::Running,
        AgentState::WaitingForToolResults,
        AgentState::WaitingForUserInput,
        AgentState::PlanMode,
        AgentState::Completed,
        AgentState::Failed,
        AgentState::Cancelled,
    ];

    #[test]
    fn test_agent_state() {
        assert!(AgentState::Completed.is_terminal());
        assert!(AgentState::Failed.is_terminal());
        assert!(AgentState::Cancelled.is_terminal());
        assert!(!AgentState::Running.is_terminal());
        assert!(!AgentState::Initializing.is_terminal());

        assert!(AgentState::WaitingForToolResults.is_waiting());
        assert!(AgentState::WaitingForUserInput.is_waiting());
        assert!(AgentState::PlanMode.is_waiting());
        assert!(!AgentState::Running.is_waiting());
        assert!(!AgentState::Completed.is_waiting());

        assert!(AgentState::Running.can_continue());
        assert!(AgentState::Initializing.can_continue());
        assert!(!AgentState::Completed.can_continue());
        assert!(!AgentState::WaitingForUserInput.can_continue());
    }

    #[test]
    fn test_agent_state_default() {
        assert_eq!(AgentState::default(), AgentState::Initializing);
    }

    /// Each state must fall into exactly one of the three predicate groups.
    #[test]
    fn test_agent_state_predicates_partition() {
        for state in &ALL_STATES {
            let matching = [
                state.is_terminal(),
                state.is_waiting(),
                state.can_continue(),
            ]
            .iter()
            .filter(|&&flag| flag)
            .count();
            assert_eq!(matching, 1, "{state:?} should satisfy exactly one predicate");
        }
    }

    #[test]
    fn test_agent_metrics() {
        let mut metrics = AgentMetrics::default();
        metrics.add_usage(100, 50);
        metrics.add_usage(200, 100);

        assert_eq!(metrics.input_tokens, 300);
        assert_eq!(metrics.output_tokens, 150);
        assert_eq!(metrics.total_tokens(), 450);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut metrics = AgentMetrics::default();
        metrics.record_tool("Read", 50, false);
        metrics.record_tool("Read", 30, false);
        metrics.record_tool("Bash", 100, true);

        assert_eq!(metrics.tool_calls, 3);
        assert_eq!(metrics.errors, 1);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().calls, 2);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().total_time_ms, 80);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().errors, 0);
        assert_eq!(metrics.tool_stats.get("Bash").unwrap().calls, 1);
        assert_eq!(metrics.tool_stats.get("Bash").unwrap().errors, 1);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.avg_tool_time_ms(), 0.0);

        metrics.record_tool("Read", 100, false);
        metrics.record_tool("Write", 200, false);
        assert!((metrics.avg_tool_time_ms() - 150.0).abs() < 0.1);
    }

    #[test]
    fn test_agent_metrics_api_calls_and_compactions() {
        let mut metrics = AgentMetrics::default();
        metrics.record_api_call();
        metrics.record_api_call_with_timing(120);
        metrics.record_api_call_with_timing(30);
        metrics.record_compaction();

        assert_eq!(metrics.api_calls, 3);
        assert_eq!(metrics.api_time_ms, 150);
        assert_eq!(metrics.compactions, 1);
    }

    #[test]
    fn test_agent_metrics_cost() {
        let mut metrics = AgentMetrics::default();
        metrics.add_cost(0.25);
        metrics.add_cost(0.5);

        assert!((metrics.total_cost_usd - 0.75).abs() < 1e-9);
        // No per-model usage has been recorded, so the per-model sum is zero.
        assert_eq!(metrics.total_model_cost(), 0.0);
    }
}