Skip to main content

claude_agent/agent/
state.rs

1//! Agent state management.
2
3use rust_decimal::Decimal;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "snake_case")]
8pub enum AgentState {
9    #[default]
10    Initializing,
11    Running,
12    WaitingForToolResults,
13    WaitingForUserInput,
14    PlanMode,
15    Completed,
16    Failed,
17    Cancelled,
18}
19
20impl AgentState {
21    pub fn is_terminal(&self) -> bool {
22        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
23    }
24
25    pub fn is_waiting(&self) -> bool {
26        matches!(
27            self,
28            Self::WaitingForToolResults | Self::WaitingForUserInput | Self::PlanMode
29        )
30    }
31
32    pub fn can_continue(&self) -> bool {
33        matches!(self, Self::Running | Self::Initializing)
34    }
35}
36
37use crate::types::{ModelUsage, PermissionDenial, ServerToolUse, ServerToolUseUsage, Usage};
38
39#[derive(Debug, Clone, Default)]
40pub struct AgentMetrics {
41    pub iterations: usize,
42    pub tool_calls: usize,
43    pub input_tokens: u32,
44    pub output_tokens: u32,
45    pub cache_read_tokens: u32,
46    pub cache_creation_tokens: u32,
47    pub execution_time_ms: u64,
48    pub errors: usize,
49    pub compactions: usize,
50    pub api_calls: usize,
51    pub total_cost_usd: Decimal,
52    pub tool_stats: std::collections::HashMap<String, ToolStats>,
53    pub tool_call_records: Vec<ToolCallRecord>,
54    pub model_usage: std::collections::HashMap<String, ModelUsage>,
55    pub server_tool_use: ServerToolUse,
56    pub permission_denials: Vec<PermissionDenial>,
57    pub api_time_ms: u64,
58}
59
/// Aggregated statistics for a single tool.
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct ToolStats {
    /// Number of recorded invocations of the tool.
    pub calls: usize,
    /// Sum of the recorded invocation durations, in milliseconds.
    pub total_time_ms: u64,
    /// Number of recorded invocations that were flagged as errors.
    pub errors: usize,
}
66
/// Record of one individual tool invocation.
#[derive(Debug, Clone)]
pub struct ToolCallRecord {
    /// Identifier of the tool-use request this record belongs to.
    pub tool_use_id: String,
    /// Name of the invoked tool.
    pub tool_name: String,
    /// Duration of the invocation in milliseconds.
    pub duration_ms: u64,
    /// Whether the invocation was flagged as an error.
    pub is_error: bool,
}
74
75impl AgentMetrics {
76    pub fn total_tokens(&self) -> u32 {
77        self.input_tokens.saturating_add(self.output_tokens)
78    }
79
80    pub fn add_usage_with_cache(&mut self, usage: &Usage) {
81        self.input_tokens = self.input_tokens.saturating_add(usage.input_tokens);
82        self.output_tokens = self.output_tokens.saturating_add(usage.output_tokens);
83        self.cache_read_tokens = self
84            .cache_read_tokens
85            .saturating_add(usage.cache_read_input_tokens.unwrap_or(0));
86        self.cache_creation_tokens = self
87            .cache_creation_tokens
88            .saturating_add(usage.cache_creation_input_tokens.unwrap_or(0));
89    }
90
91    /// Calculate cache hit rate as a proportion of input tokens.
92    ///
93    /// Returns the ratio of cache_read_tokens to input_tokens.
94    /// A higher value means more tokens were served from cache.
95    pub fn cache_hit_rate(&self) -> f64 {
96        if self.input_tokens == 0 {
97            return 0.0;
98        }
99        self.cache_read_tokens as f64 / self.input_tokens as f64
100    }
101
102    /// Calculate cache efficiency (reads vs total cache operations).
103    ///
104    /// Returns 1.0 for perfect cache reuse (all reads, no writes),
105    /// and 0.0 when there's no cache activity.
106    ///
107    /// Per Anthropic pricing:
108    /// - Cache reads cost 10% of input tokens
109    /// - Cache writes cost 125% of input tokens
110    ///
111    /// Higher efficiency = better cost savings.
112    pub fn cache_efficiency(&self) -> f64 {
113        let total = self.cache_read_tokens + self.cache_creation_tokens;
114        if total == 0 {
115            return 0.0;
116        }
117        self.cache_read_tokens as f64 / total as f64
118    }
119
120    /// Estimate tokens saved through caching.
121    ///
122    /// Cache reads are billed at 10%, so 90% of read tokens are "saved".
123    pub fn cache_tokens_saved(&self) -> u32 {
124        (self.cache_read_tokens as f64 * 0.9) as u32
125    }
126
127    /// Calculate estimated cost savings from caching in USD.
128    ///
129    /// Per Anthropic pricing:
130    /// - Normal input: full price
131    /// - Cache read: 10% of normal price (90% savings)
132    /// - Cache write: 125% of normal price (25% overhead)
133    ///
134    /// Net savings = (cache_read * 0.9 * price) - (cache_write * 0.25 * price)
135    pub fn cache_cost_savings(&self, input_price_per_mtok: Decimal) -> Decimal {
136        let mtok_divisor = Decimal::from(1_000_000);
137        let read_tokens = Decimal::from(self.cache_read_tokens) / mtok_divisor;
138        let write_tokens = Decimal::from(self.cache_creation_tokens) / mtok_divisor;
139
140        // Savings from reading cached content (90% discount)
141        let read_savings = read_tokens * input_price_per_mtok * Decimal::new(9, 1); // 0.9
142        // Overhead from writing to cache (25% extra cost)
143        let write_overhead = write_tokens * input_price_per_mtok * Decimal::new(25, 2); // 0.25
144
145        read_savings - write_overhead
146    }
147
148    pub fn add_cost(&mut self, cost: Decimal) {
149        self.total_cost_usd += cost;
150    }
151
152    pub fn record_tool(&mut self, tool_use_id: &str, name: &str, duration_ms: u64, is_error: bool) {
153        self.tool_calls += 1;
154        let stats = self.tool_stats.entry(name.to_string()).or_default();
155        stats.calls += 1;
156        stats.total_time_ms += duration_ms;
157        if is_error {
158            stats.errors += 1;
159            self.errors += 1;
160        }
161        self.tool_call_records.push(ToolCallRecord {
162            tool_use_id: tool_use_id.to_string(),
163            tool_name: name.to_string(),
164            duration_ms,
165            is_error,
166        });
167    }
168
169    pub fn record_api_call(&mut self) {
170        self.api_calls += 1;
171    }
172
173    pub fn record_compaction(&mut self) {
174        self.compactions += 1;
175    }
176
177    pub fn avg_tool_time_ms(&self) -> f64 {
178        if self.tool_calls == 0 {
179            return 0.0;
180        }
181        let total: u64 = self.tool_stats.values().map(|s| s.total_time_ms).sum();
182        total as f64 / self.tool_calls as f64
183    }
184
185    /// Record usage for a specific model.
186    ///
187    /// This enables per-model cost tracking like CLI's modelUsage field.
188    pub fn record_model_usage(&mut self, model: &str, usage: &Usage) {
189        let entry = self.model_usage.entry(model.to_string()).or_default();
190        entry.add_usage(usage, model);
191    }
192
193    /// Record an API call with timing information.
194    pub fn record_api_call_with_timing(&mut self, duration_ms: u64) {
195        self.api_calls += 1;
196        self.api_time_ms += duration_ms;
197    }
198
199    /// Update server_tool_use from API response.
200    ///
201    /// This is for server-side tools executed by the API (e.g., Anthropic's
202    /// server-side RAG). Not to be confused with local tool usage.
203    pub fn update_server_tool_use(&mut self, server_tool_use: &ServerToolUse) {
204        self.server_tool_use.web_search_requests += server_tool_use.web_search_requests;
205        self.server_tool_use.web_fetch_requests += server_tool_use.web_fetch_requests;
206    }
207
208    /// Update server_tool_use from API response's usage.server_tool_use field.
209    ///
210    /// This parses the server tool usage directly from the API response.
211    pub fn update_server_tool_use_from_api(&mut self, usage: &ServerToolUseUsage) {
212        self.server_tool_use.add_from_usage(usage);
213    }
214
215    /// Record a permission denial.
216    pub fn record_permission_denial(&mut self, denial: PermissionDenial) {
217        self.permission_denials.push(denial);
218    }
219
220    /// Get the total cost across all models.
221    pub fn total_model_cost(&self) -> Decimal {
222        self.model_usage.values().map(|m| m.cost_usd).sum()
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `actual` is within `tolerance` of `expected`.
    fn approx(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Builds metrics in which only the two cache counters are set.
    fn cache_only(read: u32, created: u32) -> AgentMetrics {
        AgentMetrics {
            cache_read_tokens: read,
            cache_creation_tokens: created,
            ..Default::default()
        }
    }

    #[test]
    fn test_agent_state() {
        // Terminal classification.
        assert!(AgentState::Completed.is_terminal());
        assert!(AgentState::Failed.is_terminal());
        assert!(!AgentState::Running.is_terminal());

        // Waiting classification.
        assert!(AgentState::WaitingForUserInput.is_waiting());
        assert!(AgentState::PlanMode.is_waiting());
        assert!(!AgentState::Running.is_waiting());

        // Continuation classification.
        assert!(AgentState::Running.can_continue());
        assert!(!AgentState::Completed.can_continue());
    }

    #[test]
    fn test_agent_metrics() {
        let first = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        let second = Usage {
            input_tokens: 200,
            output_tokens: 100,
            cache_read_input_tokens: Some(30),
            ..Default::default()
        };

        let mut metrics = AgentMetrics::default();
        metrics.add_usage_with_cache(&first);
        metrics.add_usage_with_cache(&second);

        assert_eq!(metrics.input_tokens, 300);
        assert_eq!(metrics.output_tokens, 150);
        assert_eq!(metrics.total_tokens(), 450);
        assert_eq!(metrics.cache_read_tokens, 30);
    }

    #[test]
    fn test_agent_metrics_tool_recording() {
        let mut metrics = AgentMetrics::default();
        metrics.record_tool("tu_1", "Read", 50, false);
        metrics.record_tool("tu_2", "Read", 30, false);
        metrics.record_tool("tu_3", "Bash", 100, true);

        assert_eq!(metrics.tool_calls, 3);
        assert_eq!(metrics.errors, 1);

        let read = metrics.tool_stats.get("Read").unwrap();
        assert_eq!(read.calls, 2);
        assert_eq!(read.total_time_ms, 80);

        let bash = metrics.tool_stats.get("Bash").unwrap();
        assert_eq!(bash.errors, 1);

        let records = &metrics.tool_call_records;
        assert_eq!(records.len(), 3);
        assert_eq!(records[0].tool_use_id, "tu_1");
        assert!(records[2].is_error);
    }

    #[test]
    fn test_agent_metrics_avg_time() {
        let mut metrics = AgentMetrics::default();
        assert_eq!(metrics.avg_tool_time_ms(), 0.0);

        metrics.record_tool("tu_1", "Read", 100, false);
        metrics.record_tool("tu_2", "Write", 200, false);
        assert!(approx(metrics.avg_tool_time_ms(), 150.0, 0.1));
    }

    #[test]
    fn test_cache_efficiency_no_activity() {
        assert_eq!(AgentMetrics::default().cache_efficiency(), 0.0);
    }

    #[test]
    fn test_cache_efficiency_all_reads() {
        let metrics = cache_only(1000, 0);

        assert!(approx(metrics.cache_efficiency(), 1.0, 0.001));
    }

    #[test]
    fn test_cache_efficiency_mixed() {
        let metrics = cache_only(900, 100);

        // 900 / (900 + 100) = 0.9
        assert!(approx(metrics.cache_efficiency(), 0.9, 0.001));
    }

    #[test]
    fn test_cache_cost_savings() {
        use rust_decimal_macros::dec;

        // 1M tokens read from cache, 100K tokens written to it.
        let metrics = cache_only(1_000_000, 100_000);

        // $3 per MTok
        let price_per_mtok = dec!(3);

        // Read savings: 1.0 * 3.0 * 0.9 = $2.70
        // Write overhead: 0.1 * 3.0 * 0.25 = $0.075
        // Net savings: $2.70 - $0.075 = $2.625
        assert_eq!(metrics.cache_cost_savings(price_per_mtok), dec!(2.625));
    }

    #[test]
    fn test_cache_hit_rate() {
        let metrics = AgentMetrics {
            input_tokens: 1000,
            cache_read_tokens: 800,
            ..Default::default()
        };

        // 800 / 1000 = 0.8
        assert!(approx(metrics.cache_hit_rate(), 0.8, 0.001));
    }

    #[test]
    fn test_cache_tokens_saved() {
        let metrics = cache_only(1000, 0);

        // 1000 * 0.9 = 900
        assert_eq!(metrics.cache_tokens_saved(), 900);
    }
}