claude_agent/agent/
state.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum AgentState {
8 #[default]
9 Initializing,
10 Running,
11 WaitingForToolResults,
12 WaitingForUserInput,
13 PlanMode,
14 Completed,
15 Failed,
16 Cancelled,
17}
18
19impl AgentState {
20 pub fn is_terminal(&self) -> bool {
21 matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
22 }
23
24 pub fn is_waiting(&self) -> bool {
25 matches!(
26 self,
27 Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
28 )
29 }
30
31 pub fn can_continue(&self) -> bool {
32 matches!(self, Self::Running | Self::Initializing)
33 }
34}
35
36use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
37
38#[derive(Debug, Clone, Default)]
39pub struct AgentMetrics {
40 pub iterations: usize,
41 pub tool_calls: usize,
42 pub input_tokens: u32,
43 pub output_tokens: u32,
44 pub execution_time_ms: u64,
45 pub errors: usize,
46 pub compactions: usize,
47 pub api_calls: usize,
48 pub total_cost_usd: f64,
49 pub tool_stats: std::collections::HashMap<String, ToolStats>,
50 pub model_usage: std::collections::HashMap<String, ModelUsage>,
52 pub server_tool_use: ServerToolUse,
54 pub permission_denials: Vec<PermissionDenial>,
56 pub api_time_ms: u64,
58}
59
/// Aggregate counters for a single tool, stored per tool name in
/// `AgentMetrics::tool_stats`.
#[derive(Default, Clone, Debug)]
pub struct ToolStats {
    /// How many invocations of this tool were recorded.
    pub calls: usize,
    /// Sum of the recorded invocation durations, in milliseconds.
    pub total_time_ms: u64,
    /// How many of those invocations were recorded as errors.
    pub errors: usize,
}
66
67impl AgentMetrics {
68 pub fn total_tokens(&self) -> u32 {
69 self.input_tokens + self.output_tokens
70 }
71
72 pub fn add_usage(&mut self, input: u32, output: u32) {
73 self.input_tokens += input;
74 self.output_tokens += output;
75 }
76
77 pub fn add_cost(&mut self, cost: f64) {
78 self.total_cost_usd += cost;
79 }
80
81 pub fn record_tool(&mut self, name: &str, duration_ms: u64, is_error: bool) {
82 self.tool_calls += 1;
83 let stats = self.tool_stats.entry(name.to_string()).or_default();
84 stats.calls += 1;
85 stats.total_time_ms += duration_ms;
86 if is_error {
87 stats.errors += 1;
88 self.errors += 1;
89 }
90 }
91
92 pub fn record_api_call(&mut self) {
93 self.api_calls += 1;
94 }
95
96 pub fn record_compaction(&mut self) {
97 self.compactions += 1;
98 }
99
100 pub fn avg_tool_time_ms(&self) -> f64 {
101 if self.tool_calls == 0 {
102 return 0.0;
103 }
104 let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
105 total as f64 / self.tool_calls as f64
106 }
107
108 pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
112 let entry = self.model_usage.entry(model.to_string()).or_default();
113 entry.add_usage(usage, model);
114 }
115
116 pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
118 self.api_calls += 1;
119 self.api_time_ms += duration_ms;
120 }
121
122 pub fn update_server_tool_use(&mut self, server_tool_use: ServerToolUse) {
127 self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
128 self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
129 }
130
131 pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
135 self.server_tool_use.add_from_usage(usage);
136 }
137
138 pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
140 self.permission_denials.push(denial);
141 }
142
143 pub fn total_model_cost(&self) -> f64 {
145 self.model_usage.values().map(|m| m.cost_usd).sum()
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `AgentState` variant, so classification checks cover the whole enum.
    const ALL_STATES: [AgentState; 8] = [
        AgentState::Initializing,
        AgentState::Running,
        AgentState::WaitingForToolResults,
        AgentState::WaitingForUserInput,
        AgentState::PlanMode,
        AgentState::Completed,
        AgentState::Failed,
        AgentState::Cancelled,
    ];

    #[test]
    fn test_agent_state() {
        assert!(AgentState::Completed.is_terminal());
        assert!(AgentState::Failed.is_terminal());
        assert!(!AgentState::Running.is_terminal());

        assert!(AgentState::WaitingForUserInput.is_waiting());
        assert!(AgentState::PlanMode.is_waiting());
        assert!(!AgentState::Running.is_waiting());

        assert!(AgentState::Running.can_continue());
        assert!(!AgentState::Completed.can_continue());
    }

    #[test]
    fn test_agent_state_covers_every_variant() {
        // Variants the original test never touched.
        assert!(AgentState::Cancelled.is_terminal());
        assert!(AgentState::WaitingForToolResults.is_waiting());
        assert!(AgentState::Initializing.can_continue());
        assert_eq!(AgentState::default(), AgentState::Initializing);

        // Each state belongs to exactly one of terminal / waiting / can-continue.
        for state in ALL_STATES {
            let categories = [state.is_terminal(), state.is_waiting(), state.can_continue()];
            let hits = categories.iter().filter(|&&flag| flag).count();
            assert_eq!(hits, 1, "{state:?} falls into {hits} categories");
        }
    }

    #[test]
    fn test_agent_metrics() {
        let mut metrics = AgentMetrics::default();
        metrics.add_usage(100, 50);
        metrics.add_usage(200, 100);

        assert_eq!(metrics.input_tokens, 300);
        assert_eq!(metrics.output_tokens, 150);
        assert_eq!(metrics.total_tokens(), 450);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut metrics = AgentMetrics::default();
        metrics.record_tool("Read", 50, false);
        metrics.record_tool("Read", 30, false);
        metrics.record_tool("Bash", 100, true);

        assert_eq!(metrics.tool_calls, 3);
        assert_eq!(metrics.errors, 1);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().calls, 2);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().total_time_ms, 80);
        assert_eq!(metrics.tool_stats.get("Bash").unwrap().errors, 1);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.avg_tool_time_ms(), 0.0);

        metrics.record_tool("Read", 100, false);
        metrics.record_tool("Write", 200, false);
        assert!((metrics.avg_tool_time_ms() - 150.0).abs() < 0.1);
    }

    #[test]
    fn test_agent_metrics_counters() {
        let mut metrics = AgentMetrics::default();
        metrics.record_api_call();
        metrics.record_api_call_with_timing(250);
        metrics.record_compaction();
        metrics.add_cost(0.5);
        metrics.add_cost(0.25);

        // Both API-call entry points increment the same counter.
        assert_eq!(metrics.api_calls, 2);
        assert_eq!(metrics.api_time_ms, 250);
        assert_eq!(metrics.compactions, 1);
        assert!((metrics.total_cost_usd - 0.75).abs() < 1e-12);
    }
}