1use rust_decimal::Decimal;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "snake_case")]
8pub enum AgentState {
9 #[default]
10 Initializing,
11 Running,
12 WaitingForToolResults,
13 WaitingForUserInput,
14 PlanMode,
15 Completed,
16 Failed,
17 Cancelled,
18}
19
20impl AgentState {
21 pub fn is_terminal(&self) -> bool {
22 matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
23 }
24
25 pub fn is_waiting(&self) -> bool {
26 matches!(
27 self,
28 Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
29 )
30 }
31
32 pub fn can_continue(&self) -> bool {
33 matches!(self, Self::Running | Self::Initializing)
34 }
35}
36
37use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
38
39#[derive(Debug, Clone, Default)]
40pub struct AgentMetrics {
41 pub iterations: usize,
42 pub tool_calls: usize,
43 pub input_tokens: u32,
44 pub output_tokens: u32,
45 pub cache_read_tokens: u32,
46 pub cache_creation_tokens: u32,
47 pub execution_time_ms: u64,
48 pub errors: usize,
49 pub compactions: usize,
50 pub api_calls: usize,
51 pub total_cost_usd: Decimal,
52 pub tool_stats: std::collections::HashMap<String, ToolStats>,
53 pub tool_call_records: Vec<ToolCallRecord>,
54 pub model_usage: std::collections::HashMap<String, ModelUsage>,
55 pub server_tool_use: ServerToolUse,
56 pub permission_denials: Vec<PermissionDenial>,
57 pub api_time_ms: u64,
58}
59
/// Aggregate statistics for a single tool, keyed by tool name in
/// `AgentMetrics::tool_stats`.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ToolStats {
    /// Number of invocations of this tool.
    pub calls: usize,
    /// Sum of the durations (ms) of all invocations.
    pub total_time_ms: u64,
    /// Number of invocations that reported an error.
    pub errors: usize,
}
66
/// A single recorded tool invocation, appended to
/// `AgentMetrics::tool_call_records` in call order.
///
/// `PartialEq`/`Eq` are derived so records can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallRecord {
    /// Identifier of the tool-use request this record belongs to.
    pub tool_use_id: String,
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Duration of the invocation in milliseconds.
    pub duration_ms: u64,
    /// Whether the invocation reported an error.
    pub is_error: bool,
}
74
75impl AgentMetrics {
76 pub fn total_tokens(&self) -> u32 {
77 self.input_tokens.saturating_add(self.output_tokens)
78 }
79
80 pub fn add_usage_with_cache(&mut self, usage: &Usage) {
81 self.input_tokens = self.input_tokens.saturating_add(usage.input_tokens);
82 self.output_tokens = self.output_tokens.saturating_add(usage.output_tokens);
83 self.cache_read_tokens = self
84 .cache_read_tokens
85 .saturating_add(usage.cache_read_input_tokens.unwrap_or(0));
86 self.cache_creation_tokens = self
87 .cache_creation_tokens
88 .saturating_add(usage.cache_creation_input_tokens.unwrap_or(0));
89 }
90
91 pub fn cache_hit_rate(&self) -> f64 {
96 if self.input_tokens == 0 {
97 return 0.0;
98 }
99 self.cache_read_tokens as f64 / self.input_tokens as f64
100 }
101
102 pub fn cache_efficiency(&self) -> f64 {
113 let total = self.cache_read_tokens + self.cache_creation_tokens;
114 if total == 0 {
115 return 0.0;
116 }
117 self.cache_read_tokens as f64 / total as f64
118 }
119
120 pub fn cache_tokens_saved(&self) -> u32 {
124 (self.cache_read_tokens as f64 * 0.9) as u32
125 }
126
127 pub fn cache_cost_savings(&self, input_price_per_mtok: Decimal) -> Decimal {
136 let mtok_divisor = Decimal::from(1_000_000);
137 let read_tokens = Decimal::from(self.cache_read_tokens) / mtok_divisor;
138 let write_tokens = Decimal::from(self.cache_creation_tokens) / mtok_divisor;
139
140 let read_savings = read_tokens * input_price_per_mtok * Decimal::new(9, 1); let write_overhead = write_tokens * input_price_per_mtok * Decimal::new(25, 2); read_savings - write_overhead
146 }
147
148 pub fn add_cost(&mut self, cost: Decimal) {
149 self.total_cost_usd += cost;
150 }
151
152 pub fn record_tool(&mut self, tool_use_id: &str, name: &str, duration_ms: u64, is_error: bool) {
153 self.tool_calls += 1;
154 let stats = self.tool_stats.entry(name.to_string()).or_default();
155 stats.calls += 1;
156 stats.total_time_ms += duration_ms;
157 if is_error {
158 stats.errors += 1;
159 self.errors += 1;
160 }
161 self.tool_call_records.push(ToolCallRecord {
162 tool_use_id: tool_use_id.to_string(),
163 tool_name: name.to_string(),
164 duration_ms,
165 is_error,
166 });
167 }
168
169 pub fn record_api_call(&mut self) {
170 self.api_calls += 1;
171 }
172
173 pub fn record_compaction(&mut self) {
174 self.compactions += 1;
175 }
176
177 pub fn avg_tool_time_ms(&self) -> f64 {
178 if self.tool_calls == 0 {
179 return 0.0;
180 }
181 let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
182 total as f64 / self.tool_calls as f64
183 }
184
185 pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
189 let entry = self.model_usage.entry(model.to_string()).or_default();
190 entry.add_usage(usage, model);
191 }
192
193 pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
195 self.api_calls += 1;
196 self.api_time_ms += duration_ms;
197 }
198
199 pub fn update_server_tool_use(&mut self, server_tool_use: &ServerToolUse) {
204 self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
205 self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
206 }
207
208 pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
212 self.server_tool_use.add_from_usage(usage);
213 }
214
215 pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
217 self.permission_denials.push(denial);
218 }
219
220 pub fn total_model_cost(&self) -> Decimal {
222 self.model_usage.values().map(|m| m.cost_usd).sum()
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_state() {
        assert!(AgentState::Completed.is_terminal());
        assert!(AgentState::Failed.is_terminal());
        assert!(!AgentState::Running.is_terminal());

        assert!(AgentState::WaitingForUserInput.is_waiting());
        assert!(AgentState::PlanMode.is_waiting());
        assert!(!AgentState::Running.is_waiting());

        assert!(AgentState::Running.can_continue());
        assert!(!AgentState::Completed.can_continue());
    }

    /// Covers the variants and the `Default` value that `test_agent_state` leaves out.
    #[test]
    fn test_agent_state_remaining_variants() {
        assert_eq!(AgentState::default(), AgentState::Initializing);

        assert!(AgentState::Cancelled.is_terminal());
        assert!(!AgentState::Initializing.is_terminal());
        assert!(!AgentState::PlanMode.is_terminal());

        assert!(AgentState::WaitingForToolResults.is_waiting());
        assert!(!AgentState::Completed.is_waiting());

        assert!(AgentState::Initializing.can_continue());
        assert!(!AgentState::WaitingForToolResults.can_continue());
        assert!(!AgentState::PlanMode.can_continue());
        assert!(!AgentState::Cancelled.can_continue());
    }

    #[test]
    fn test_agent_metrics() {
        let mut metrics = AgentMetrics::default();
        metrics.add_usage_with_cache(&Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        });
        metrics.add_usage_with_cache(&Usage {
            input_tokens: 200,
            output_tokens: 100,
            cache_read_input_tokens: Some(30),
            ..Default::default()
        });

        assert_eq!(metrics.input_tokens, 300);
        assert_eq!(metrics.output_tokens, 150);
        assert_eq!(metrics.total_tokens(), 450);
        assert_eq!(metrics.cache_read_tokens, 30);
    }

    /// Token counters must clamp at `u32::MAX` rather than overflow.
    #[test]
    fn test_agent_metrics_token_counters_saturate() {
        let mut metrics = AgentMetrics::default();
        let huge = Usage {
            input_tokens: u32::MAX,
            output_tokens: u32::MAX,
            cache_read_input_tokens: Some(u32::MAX),
            cache_creation_input_tokens: Some(u32::MAX),
            ..Default::default()
        };
        metrics.add_usage_with_cache(&huge);
        metrics.add_usage_with_cache(&huge);

        assert_eq!(metrics.input_tokens, u32::MAX);
        assert_eq!(metrics.output_tokens, u32::MAX);
        assert_eq!(metrics.total_tokens(), u32::MAX);
        assert_eq!(metrics.cache_read_tokens, u32::MAX);
        assert_eq!(metrics.cache_creation_tokens, u32::MAX);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut metrics = AgentMetrics::default();
        metrics.record_tool("tu_1", "Read", 50, false);
        metrics.record_tool("tu_2", "Read", 30, false);
        metrics.record_tool("tu_3", "Bash", 100, true);

        assert_eq!(metrics.tool_calls, 3);
        assert_eq!(metrics.errors, 1);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().calls, 2);
        assert_eq!(metrics.tool_stats.get("Read").unwrap().total_time_ms, 80);
        assert_eq!(metrics.tool_stats.get("Bash").unwrap().errors, 1);
        assert_eq!(metrics.tool_call_records.len(), 3);
        assert_eq!(metrics.tool_call_records[0].tool_use_id, "tu_1");
        assert!(metrics.tool_call_records[2].is_error);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.avg_tool_time_ms(), 0.0);

        metrics.record_tool("tu_1", "Read", 100, false);
        metrics.record_tool("tu_2", "Write", 200, false);
        assert!((metrics.avg_tool_time_ms() - 150.0).abs() < 0.1);
    }

    #[test]
    fn test_api_call_and_compaction_counters() {
        let mut metrics = AgentMetrics::default();
        metrics.record_api_call();
        metrics.record_api_call_with_timing(120);
        metrics.record_api_call_with_timing(80);
        metrics.record_compaction();

        assert_eq!(metrics.api_calls, 3);
        assert_eq!(metrics.api_time_ms, 200);
        assert_eq!(metrics.compactions, 1);
    }

    #[test]
    fn test_add_cost_accumulates() {
        use rust_decimal_macros::dec;

        let mut metrics = AgentMetrics::default();
        metrics.add_cost(dec!(0.5));
        metrics.add_cost(dec!(1.25));

        assert_eq!(metrics.total_cost_usd, dec!(1.75));
    }

    #[test]
    fn test_cache_efficiency_no_activity() {
        let metrics = AgentMetrics::default();
        assert_eq!(metrics.cache_efficiency(), 0.0);
    }

    #[test]
    fn test_cache_efficiency_all_reads() {
        let metrics = AgentMetrics {
            cache_read_tokens: 1000,
            cache_creation_tokens: 0,
            ..Default::default()
        };

        assert!((metrics.cache_efficiency() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cache_efficiency_mixed() {
        let metrics = AgentMetrics {
            cache_read_tokens: 900,
            cache_creation_tokens: 100,
            ..Default::default()
        };

        assert!((metrics.cache_efficiency() - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_cache_cost_savings() {
        use rust_decimal_macros::dec;

        let metrics = AgentMetrics {
            cache_read_tokens: 1_000_000,
            cache_creation_tokens: 100_000,
            ..Default::default()
        };

        // 1 MTok read * $3 * 0.9 = $2.70 saved; 0.1 MTok written * $3 * 0.25 = $0.075 extra.
        let price_per_mtok = dec!(3);
        let savings = metrics.cache_cost_savings(price_per_mtok);
        assert_eq!(savings, dec!(2.625));
    }

    #[test]
    fn test_cache_cost_savings_can_be_negative() {
        use rust_decimal_macros::dec;

        let metrics = AgentMetrics {
            cache_creation_tokens: 1_000_000,
            ..Default::default()
        };

        // Writes only: 1 MTok * $4 * 0.25 = $1 of overhead and nothing saved.
        assert_eq!(metrics.cache_cost_savings(dec!(4)), dec!(-1));
    }

    #[test]
    fn test_cache_hit_rate() {
        let metrics = AgentMetrics {
            input_tokens: 1000,
            cache_read_tokens: 800,
            ..Default::default()
        };

        assert!((metrics.cache_hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cache_hit_rate_no_input() {
        let metrics = AgentMetrics {
            cache_read_tokens: 500,
            ..Default::default()
        };

        assert_eq!(metrics.cache_hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_tokens_saved() {
        let metrics = AgentMetrics {
            cache_read_tokens: 1000,
            ..Default::default()
        };
        assert_eq!(metrics.cache_tokens_saved(), 900);

        // Fractional savings are truncated, not rounded: 15 * 0.9 = 13.5 -> 13.
        let odd = AgentMetrics {
            cache_read_tokens: 15,
            ..Default::default()
        };
        assert_eq!(odd.cache_tokens_saved(), 13);
    }
}