Skip to main content

jamjet_state/
budget.rs

//! Budget state accumulator — tracks token and cost usage per execution.
//!
//! `BudgetState` is stored as a JSON sub-object under the `__budget` key
//! in the execution's workflow state (inside `Snapshot.state`). This avoids
//! a separate table or schema migration.
6
7use serde::{Deserialize, Serialize};
8
9/// Accumulated runtime budget for a single workflow execution.
10///
11/// Embedded in `Snapshot.state` under `"__budget"`. Updated after every
12/// `NodeCompleted` event that carries token / cost telemetry.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct BudgetState {
15    pub total_input_tokens: u64,
16    pub total_output_tokens: u64,
17    pub total_cost_usd: f64,
18    /// Number of agent reasoning iterations across all agent nodes.
19    pub iteration_count: u32,
20    /// Number of tool calls across all nodes.
21    pub tool_call_count: u32,
22    /// Consecutive error count for circuit-breaker logic.
23    pub consecutive_error_count: u32,
24    /// True once the circuit breaker has fired for this execution.
25    pub circuit_breaker_tripped: bool,
26}
27
28impl BudgetState {
29    /// Total tokens (input + output).
30    pub fn total_tokens(&self) -> u64 {
31        self.total_input_tokens + self.total_output_tokens
32    }
33
34    /// Load `BudgetState` from the `__budget` key of a snapshot state value.
35    pub fn from_snapshot_state(state: &serde_json::Value) -> Self {
36        state
37            .get("__budget")
38            .and_then(|v| serde_json::from_value(v.clone()).ok())
39            .unwrap_or_default()
40    }
41
42    /// Merge this `BudgetState` back into the snapshot state JSON as `__budget`.
43    pub fn patch_into_snapshot_state(&self, state: &mut serde_json::Value) {
44        if let Some(obj) = state.as_object_mut() {
45            obj.insert(
46                "__budget".to_string(),
47                serde_json::to_value(self).unwrap_or(serde_json::Value::Null),
48            );
49        }
50    }
51
52    /// Accumulate token / cost values from a completed model node.
53    pub fn accumulate(
54        &mut self,
55        input_tokens: Option<u64>,
56        output_tokens: Option<u64>,
57        cost_usd: Option<f64>,
58    ) {
59        self.total_input_tokens += input_tokens.unwrap_or(0);
60        self.total_output_tokens += output_tokens.unwrap_or(0);
61        self.total_cost_usd += cost_usd.unwrap_or(0.0);
62    }
63
64    /// Record a successful node — resets the consecutive error counter.
65    pub fn record_success(&mut self) {
66        self.consecutive_error_count = 0;
67    }
68
69    /// Record a failed node — increments the consecutive error counter.
70    pub fn record_error(&mut self) {
71        self.consecutive_error_count += 1;
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A budget written with `patch_into_snapshot_state` reads back unchanged
    /// via `from_snapshot_state`, and unrelated state keys survive the patch.
    #[test]
    fn roundtrip_via_snapshot_state() {
        let mut budget = BudgetState::default();
        budget.accumulate(Some(100), Some(50), Some(0.002));
        budget.iteration_count = 2;

        let mut state = json!({ "answer": "hello" });
        budget.patch_into_snapshot_state(&mut state);

        let loaded = BudgetState::from_snapshot_state(&state);
        assert_eq!(loaded.total_input_tokens, 100);
        assert_eq!(loaded.total_output_tokens, 50);
        assert!((loaded.total_cost_usd - 0.002).abs() < 1e-9);
        assert_eq!(loaded.iteration_count, 2);

        // The original state key is preserved.
        assert_eq!(state["answer"], "hello");
    }

    /// `total_tokens` is input + output; a `None` cost leaves the cost at zero.
    #[test]
    fn total_tokens() {
        let mut b = BudgetState::default();
        b.accumulate(Some(300), Some(200), None);
        assert_eq!(b.total_tokens(), 500);
        assert_eq!(b.total_cost_usd, 0.0);
    }

    /// With no `__budget` key the loader falls back to a zeroed budget.
    #[test]
    fn missing_budget_key_yields_default() {
        let loaded = BudgetState::from_snapshot_state(&json!({ "answer": "hello" }));
        assert_eq!(loaded.total_tokens(), 0);
        assert_eq!(loaded.total_cost_usd, 0.0);
        assert_eq!(loaded.iteration_count, 0);
        assert_eq!(loaded.consecutive_error_count, 0);
        assert!(!loaded.circuit_breaker_tripped);
    }

    /// A `__budget` value of the wrong shape is ignored rather than panicking.
    #[test]
    fn malformed_budget_yields_default() {
        let loaded = BudgetState::from_snapshot_state(&json!({ "__budget": "oops" }));
        assert_eq!(loaded.total_tokens(), 0);
        assert_eq!(loaded.total_cost_usd, 0.0);
    }

    /// Successive `accumulate` calls add up, and `None` fields contribute nothing.
    #[test]
    fn accumulate_is_additive() {
        let mut b = BudgetState::default();
        b.accumulate(Some(10), None, Some(0.5));
        b.accumulate(None, Some(5), Some(0.25));
        assert_eq!(b.total_input_tokens, 10);
        assert_eq!(b.total_output_tokens, 5);
        assert!((b.total_cost_usd - 0.75).abs() < 1e-12);
    }

    /// Errors count up consecutively and a success resets the streak.
    #[test]
    fn consecutive_error_counter() {
        let mut b = BudgetState::default();
        b.record_error();
        b.record_error();
        assert_eq!(b.consecutive_error_count, 2);
        b.record_success();
        assert_eq!(b.consecutive_error_count, 0);
    }

    /// Patching a non-object state (here an array) leaves it untouched.
    #[test]
    fn patch_ignores_array_state() {
        let mut state = json!([1, 2, 3]);
        BudgetState::default().patch_into_snapshot_state(&mut state);
        assert_eq!(state, json!([1, 2, 3]));
    }
}