use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BudgetState {
pub total_input_tokens: u64,
pub total_output_tokens: u64,
pub total_cost_usd: f64,
pub iteration_count: u32,
pub tool_call_count: u32,
pub consecutive_error_count: u32,
pub circuit_breaker_tripped: bool,
}
impl BudgetState {
    /// Key under which the budget is stored inside a snapshot-state object.
    const SNAPSHOT_KEY: &'static str = "__budget";

    /// Returns the combined input and output token count.
    ///
    /// The sum saturates at `u64::MAX` rather than overflowing.
    pub fn total_tokens(&self) -> u64 {
        self.total_input_tokens
            .saturating_add(self.total_output_tokens)
    }

    /// Reads the budget stored under `"__budget"` in `state`.
    ///
    /// Returns `BudgetState::default()` when the key is absent or its value
    /// cannot be deserialized into a `BudgetState`.
    pub fn from_snapshot_state(state: &serde_json::Value) -> Self {
        state
            .get(Self::SNAPSHOT_KEY)
            // Deserialize straight from the borrowed value; no deep clone of
            // the JSON subtree is needed.
            .and_then(|raw| Self::deserialize(raw).ok())
            .unwrap_or_default()
    }

    /// Writes this budget under `"__budget"` in `state`, replacing any
    /// previous value at that key.
    ///
    /// Does nothing when `state` is not a JSON object. If serialization
    /// fails, `null` is stored instead.
    pub fn patch_into_snapshot_state(&self, state: &mut serde_json::Value) {
        if let Some(fields) = state.as_object_mut() {
            let encoded = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
            fields.insert(Self::SNAPSHOT_KEY.to_string(), encoded);
        }
    }

    /// Adds one usage sample to the running totals.
    ///
    /// `None` contributes nothing to the corresponding total. Token totals
    /// saturate at `u64::MAX` rather than overflowing. A non-finite
    /// `cost_usd` (NaN or infinity) is ignored so that a single bad sample
    /// cannot permanently poison `total_cost_usd`.
    pub fn accumulate(
        &mut self,
        input_tokens: Option<u64>,
        output_tokens: Option<u64>,
        cost_usd: Option<f64>,
    ) {
        self.total_input_tokens = self
            .total_input_tokens
            .saturating_add(input_tokens.unwrap_or(0));
        self.total_output_tokens = self
            .total_output_tokens
            .saturating_add(output_tokens.unwrap_or(0));
        if let Some(cost) = cost_usd.filter(|c| c.is_finite()) {
            self.total_cost_usd += cost;
        }
    }

    /// Records a success, resetting the consecutive-error counter to zero.
    pub fn record_success(&mut self) {
        self.consecutive_error_count = 0;
    }

    /// Records an error, incrementing the consecutive-error counter
    /// (saturating at `u32::MAX`).
    pub fn record_error(&mut self) {
        self.consecutive_error_count = self.consecutive_error_count.saturating_add(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A budget patched into a snapshot object loads back with the same
    /// values, and unrelated keys in the object are left untouched.
    #[test]
    fn roundtrip_via_snapshot_state() {
        let mut original = BudgetState {
            iteration_count: 2,
            ..BudgetState::default()
        };
        original.accumulate(Some(100), Some(50), Some(0.002));

        let mut snapshot = json!({ "answer": "hello" });
        original.patch_into_snapshot_state(&mut snapshot);
        let restored = BudgetState::from_snapshot_state(&snapshot);

        assert_eq!(restored.total_input_tokens, 100);
        assert_eq!(restored.total_output_tokens, 50);
        // Floating-point cost is compared with a tolerance.
        let cost_drift = (restored.total_cost_usd - 0.002).abs();
        assert!(cost_drift < 1e-9);
        assert_eq!(restored.iteration_count, 2);
        assert_eq!(snapshot["answer"], "hello");
    }

    /// `total_tokens` is the sum of accumulated input and output tokens.
    #[test]
    fn total_tokens() {
        let mut usage = BudgetState::default();
        usage.accumulate(Some(300), Some(200), None);

        let combined = usage.total_tokens();
        assert_eq!(combined, 500);
    }
}