1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum AgentState {
8 #[default]
9 Initializing,
10 Running,
11 WaitingForToolResults,
12 WaitingForUserInput,
13 PlanMode,
14 Completed,
15 Failed,
16 Cancelled,
17}
18
19impl AgentState {
20 pub fn is_terminal(&self) -> bool {
21 matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
22 }
23
24 pub fn is_waiting(&self) -> bool {
25 matches!(
26 self,
27 Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
28 )
29 }
30
31 pub fn can_continue(&self) -> bool {
32 matches!(self, Self::Running | Self::Initializing)
33 }
34}
35
36use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
37
38#[derive(Debug, Clone, Default)]
39pub struct AgentMetrics {
40 pub iterations: usize,
41 pub tool_calls: usize,
42 pub input_tokens: u32,
43 pub output_tokens: u32,
44 pub cache_read_tokens: u32,
45 pub cache_creation_tokens: u32,
46 pub execution_time_ms: u64,
47 pub errors: usize,
48 pub compactions: usize,
49 pub api_calls: usize,
50 pub total_cost_usd: f64,
51 pub tool_stats: std::collections::HashMap<String, ToolStats>,
52 pub tool_call_records: Vec<ToolCallRecord>,
53 pub model_usage: std::collections::HashMap<String, ModelUsage>,
54 pub server_tool_use: ServerToolUse,
55 pub permission_denials: Vec<PermissionDenial>,
56 pub api_time_ms: u64,
57}
58
/// Aggregated statistics for a single tool name.
#[derive(Default, Clone, Debug)]
pub struct ToolStats {
    /// Number of invocations of this tool.
    pub calls: usize,
    /// Sum of the recorded durations, in milliseconds.
    pub total_time_ms: u64,
    /// Number of invocations that were flagged as errors.
    pub errors: usize,
}
65
/// A single recorded tool invocation.
#[derive(Clone, Debug)]
pub struct ToolCallRecord {
    /// Identifier of the tool-use request this call answered.
    pub tool_use_id: String,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Duration of the call, in milliseconds.
    pub duration_ms: u64,
    /// Whether the call was flagged as an error.
    pub is_error: bool,
}
73
74impl AgentMetrics {
75 pub fn total_tokens(&self) -> u32 {
76 self.input_tokens + self.output_tokens
77 }
78
79 pub fn add_usage(&mut self, input: u32, output: u32) {
80 self.input_tokens += input;
81 self.output_tokens += output;
82 }
83
84 pub fn add_usage_with_cache(&mut self, usage: &Usage) {
85 self.input_tokens += usage.input_tokens;
86 self.output_tokens += usage.output_tokens;
87 self.cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
88 self.cache_creation_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
89 }
90
91 pub fn cache_hit_rate(&self) -> f64 {
96 if self.input_tokens == 0 {
97 return 0.0;
98 }
99 self.cache_read_tokens as f64 / self.input_tokens as f64
100 }
101
102 pub fn cache_efficiency(&self) -> f64 {
113 let total = self.cache_read_tokens + self.cache_creation_tokens;
114 if total == 0 {
115 return 0.0;
116 }
117 self.cache_read_tokens as f64 / total as f64
118 }
119
120 pub fn cache_tokens_saved(&self) -> u32 {
124 (self.cache_read_tokens as f64 * 0.9) as u32
125 }
126
127 pub fn cache_cost_savings(&self, input_price_per_mtok: f64) -> f64 {
136 let read_tokens = self.cache_read_tokens as f64 / 1_000_000.0;
137 let write_tokens = self.cache_creation_tokens as f64 / 1_000_000.0;
138
139 let read_savings = read_tokens * input_price_per_mtok * 0.9;
141 let write_overhead = write_tokens * input_price_per_mtok * 0.25;
143
144 read_savings - write_overhead
145 }
146
147 pub fn add_cost(&mut self, cost: f64) {
148 self.total_cost_usd += cost;
149 }
150
151 pub fn record_tool(&mut self, tool_use_id: &str, name: &str, duration_ms: u64, is_error: bool) {
152 self.tool_calls += 1;
153 let stats = self.tool_stats.entry(name.to_string()).or_default();
154 stats.calls += 1;
155 stats.total_time_ms += duration_ms;
156 if is_error {
157 stats.errors += 1;
158 self.errors += 1;
159 }
160 self.tool_call_records.push(ToolCallRecord {
161 tool_use_id: tool_use_id.to_string(),
162 tool_name: name.to_string(),
163 duration_ms,
164 is_error,
165 });
166 }
167
168 pub fn record_api_call(&mut self) {
169 self.api_calls += 1;
170 }
171
172 pub fn record_compaction(&mut self) {
173 self.compactions += 1;
174 }
175
176 pub fn avg_tool_time_ms(&self) -> f64 {
177 if self.tool_calls == 0 {
178 return 0.0;
179 }
180 let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
181 total as f64 / self.tool_calls as f64
182 }
183
184 pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
188 let entry = self.model_usage.entry(model.to_string()).or_default();
189 entry.add_usage(usage, model);
190 }
191
192 pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
194 self.api_calls += 1;
195 self.api_time_ms += duration_ms;
196 }
197
198 pub fn update_server_tool_use(&mut self, server_tool_use: ServerToolUse) {
203 self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
204 self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
205 }
206
207 pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
211 self.server_tool_use.add_from_usage(usage);
212 }
213
214 pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
216 self.permission_denials.push(denial);
217 }
218
219 pub fn total_model_cost(&self) -> f64 {
221 self.model_usage.values().map(|m| m.cost_usd).sum()
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {} +/- {}, got {}",
            expected,
            tolerance,
            actual
        );
    }

    #[test]
    fn test_agent_state() {
        for state in [AgentState::Completed, AgentState::Failed] {
            assert!(state.is_terminal());
        }
        assert!(!AgentState::Running.is_terminal());

        for state in [AgentState::WaitingForUserInput, AgentState::PlanMode] {
            assert!(state.is_waiting());
        }
        assert!(!AgentState::Running.is_waiting());

        assert!(AgentState::Running.can_continue());
        assert!(!AgentState::Completed.can_continue());
    }

    #[test]
    fn test_agent_metrics() {
        let mut m = AgentMetrics::default();
        for (input, output) in [(100, 50), (200, 100)] {
            m.add_usage(input, output);
        }

        assert_eq!((m.input_tokens, m.output_tokens), (300, 150));
        assert_eq!(m.total_tokens(), 450);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut m = AgentMetrics::default();
        let calls = [
            ("tu_1", "Read", 50, false),
            ("tu_2", "Read", 30, false),
            ("tu_3", "Bash", 100, true),
        ];
        for (id, name, ms, failed) in calls {
            m.record_tool(id, name, ms, failed);
        }

        assert_eq!(m.tool_calls, 3);
        assert_eq!(m.errors, 1);

        let read = &m.tool_stats["Read"];
        assert_eq!(read.calls, 2);
        assert_eq!(read.total_time_ms, 80);
        assert_eq!(m.tool_stats["Bash"].errors, 1);

        let records = &m.tool_call_records;
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].tool_use_id, "tu_1");
        assert!(records[2].is_error);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut m = AgentMetrics::default();
        assert_eq!(m.avg_tool_time_ms(), 0.0);

        m.record_tool("tu_1", "Read", 100, false);
        m.record_tool("tu_2", "Write", 200, false);
        assert_close(m.avg_tool_time_ms(), 150.0, 0.1);
    }

    #[test]
    fn test_cache_efficiency_no_activity() {
        assert_eq!(AgentMetrics::default().cache_efficiency(), 0.0);
    }

    #[test]
    fn test_cache_efficiency_all_reads() {
        let m = AgentMetrics {
            cache_read_tokens: 1000,
            cache_creation_tokens: 0,
            ..Default::default()
        };
        assert_close(m.cache_efficiency(), 1.0, 0.001);
    }

    #[test]
    fn test_cache_efficiency_mixed() {
        // 900 reads out of 1000 total cache tokens.
        let m = AgentMetrics {
            cache_read_tokens: 900,
            cache_creation_tokens: 100,
            ..Default::default()
        };
        assert_close(m.cache_efficiency(), 0.9, 0.001);
    }

    #[test]
    fn test_cache_cost_savings() {
        let m = AgentMetrics {
            cache_read_tokens: 1_000_000,
            cache_creation_tokens: 100_000,
            ..Default::default()
        };

        // Reads:  1.0 MTok * $3.00 * 0.90 = $2.700 saved
        // Writes: 0.1 MTok * $3.00 * 0.25 = $0.075 overhead
        // Net:    $2.625
        let price_per_mtok = 3.0;
        assert_close(m.cache_cost_savings(price_per_mtok), 2.625, 0.001);
    }

    #[test]
    fn test_cache_hit_rate() {
        let m = AgentMetrics {
            input_tokens: 1000,
            cache_read_tokens: 800,
            ..Default::default()
        };
        assert_close(m.cache_hit_rate(), 0.8, 0.001);
    }

    #[test]
    fn test_cache_tokens_saved() {
        let m = AgentMetrics {
            cache_read_tokens: 1000,
            ..Default::default()
        };
        assert_eq!(m.cache_tokens_saved(), 900);
    }
}
349}