1use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct BudgetState {
15 pub total_input_tokens: u64,
16 pub total_output_tokens: u64,
17 pub total_cost_usd: f64,
18 pub iteration_count: u32,
20 pub tool_call_count: u32,
22 pub consecutive_error_count: u32,
24 pub circuit_breaker_tripped: bool,
26}
27
28impl BudgetState {
29 pub fn total_tokens(&self) -> u64 {
31 self.total_input_tokens + self.total_output_tokens
32 }
33
34 pub fn from_snapshot_state(state: &serde_json::Value) -> Self {
36 state
37 .get("__budget")
38 .and_then(|v| serde_json::from_value(v.clone()).ok())
39 .unwrap_or_default()
40 }
41
42 pub fn patch_into_snapshot_state(&self, state: &mut serde_json::Value) {
44 if let Some(obj) = state.as_object_mut() {
45 obj.insert(
46 "__budget".to_string(),
47 serde_json::to_value(self).unwrap_or(serde_json::Value::Null),
48 );
49 }
50 }
51
52 pub fn accumulate(
54 &mut self,
55 input_tokens: Option<u64>,
56 output_tokens: Option<u64>,
57 cost_usd: Option<f64>,
58 ) {
59 self.total_input_tokens += input_tokens.unwrap_or(0);
60 self.total_output_tokens += output_tokens.unwrap_or(0);
61 self.total_cost_usd += cost_usd.unwrap_or(0.0);
62 }
63
64 pub fn record_success(&mut self) {
66 self.consecutive_error_count = 0;
67 }
68
69 pub fn record_error(&mut self) {
71 self.consecutive_error_count += 1;
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn roundtrip_via_snapshot_state() {
        let mut budget = BudgetState::default();
        budget.accumulate(Some(100), Some(50), Some(0.002));
        budget.iteration_count = 2;

        let mut state = json!({ "answer": "hello" });
        budget.patch_into_snapshot_state(&mut state);

        let loaded = BudgetState::from_snapshot_state(&state);
        assert_eq!(loaded.total_input_tokens, 100);
        assert_eq!(loaded.total_output_tokens, 50);
        assert!((loaded.total_cost_usd - 0.002).abs() < 1e-9);
        assert_eq!(loaded.iteration_count, 2);

        // Patching must not disturb unrelated keys.
        assert_eq!(state["answer"], "hello");
    }

    #[test]
    fn total_tokens() {
        let mut b = BudgetState::default();
        b.accumulate(Some(300), Some(200), None);
        assert_eq!(b.total_tokens(), 500);
    }

    #[test]
    fn missing_or_malformed_budget_falls_back_to_default() {
        let states = [
            json!({}),
            json!({ "__budget": "not a budget" }),
            json!(42),
        ];
        for state in &states {
            let loaded = BudgetState::from_snapshot_state(state);
            assert_eq!(loaded.total_tokens(), 0);
            assert_eq!(loaded.iteration_count, 0);
            assert_eq!(loaded.tool_call_count, 0);
            assert_eq!(loaded.consecutive_error_count, 0);
            assert!(!loaded.circuit_breaker_tripped);
        }
    }

    #[test]
    fn patch_leaves_non_object_state_untouched() {
        let mut state = json!([1, 2, 3]);
        BudgetState::default().patch_into_snapshot_state(&mut state);
        assert_eq!(state, json!([1, 2, 3]));
    }

    #[test]
    fn patch_replaces_previous_budget() {
        let mut state = json!({});
        let mut budget = BudgetState::default();
        budget.accumulate(Some(1), None, None);
        budget.patch_into_snapshot_state(&mut state);
        budget.accumulate(Some(2), None, None);
        budget.patch_into_snapshot_state(&mut state);
        assert_eq!(
            BudgetState::from_snapshot_state(&state).total_input_tokens,
            3
        );
    }

    #[test]
    fn error_streak_resets_on_success() {
        let mut b = BudgetState::default();
        b.record_error();
        b.record_error();
        assert_eq!(b.consecutive_error_count, 2);
        b.record_success();
        assert_eq!(b.consecutive_error_count, 0);
    }

    #[test]
    fn accumulate_treats_none_as_zero() {
        let mut b = BudgetState::default();
        b.accumulate(Some(7), Some(3), Some(0.25));
        b.accumulate(None, None, None);
        assert_eq!(b.total_input_tokens, 7);
        assert_eq!(b.total_output_tokens, 3);
        assert!((b.total_cost_usd - 0.25).abs() < 1e-12);
    }
}