Skip to main content

atomr_agents_core/
budget.rs

1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
4use crate::error::{AgentError, Result};
5
6/// Token budget threaded through every strategy resolution. Strategies
7/// `consume` from a shared budget; the `ContextAssembler` honors the
8/// final cap when packing fragments.
9#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
10pub struct TokenBudget {
11    pub remaining: u32,
12    pub reserved: u32,
13}
14
15impl TokenBudget {
16    pub fn new(total: u32) -> Self {
17        Self {
18            remaining: total,
19            reserved: 0,
20        }
21    }
22
23    pub fn consume(&mut self, n: u32) -> Result<()> {
24        if n > self.remaining {
25            return Err(AgentError::BudgetExceeded("tokens"));
26        }
27        self.remaining -= n;
28        Ok(())
29    }
30
31    pub fn reserve(&mut self, n: u32) -> Result<()> {
32        if n > self.remaining {
33            return Err(AgentError::BudgetExceeded("tokens"));
34        }
35        self.remaining -= n;
36        self.reserved += n;
37        Ok(())
38    }
39
40    pub fn release(&mut self, n: u32) {
41        let n = n.min(self.reserved);
42        self.reserved -= n;
43        self.remaining += n;
44    }
45
46    /// Split the *current* remaining budget into `n` equal slices for
47    /// cooperative parallel resolution. Each slice is independent;
48    /// after the parallel join, the caller sums what was actually used
49    /// and updates the parent.
50    pub fn split(&self, n: u32) -> Vec<TokenBudget> {
51        if n == 0 {
52            return Vec::new();
53        }
54        let per = self.remaining / n;
55        (0..n).map(|_| TokenBudget::new(per)).collect()
56    }
57}
58
59#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
60pub struct TimeBudget {
61    /// Wall-clock budget remaining as milliseconds.
62    pub remaining_ms: u64,
63}
64
65impl TimeBudget {
66    pub fn new(d: Duration) -> Self {
67        Self {
68            remaining_ms: d.as_millis().min(u64::MAX as u128) as u64,
69        }
70    }
71
72    pub fn consume(&mut self, d: Duration) -> Result<()> {
73        let ms = d.as_millis() as u64;
74        if ms > self.remaining_ms {
75            return Err(AgentError::BudgetExceeded("time"));
76        }
77        self.remaining_ms -= ms;
78        Ok(())
79    }
80}
81
82/// Money budget. Stored as integer micro-USD to avoid float drift.
83#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
84pub struct MoneyBudget {
85    pub remaining_micro_usd: u64,
86}
87
88impl MoneyBudget {
89    pub fn from_usd(usd: f64) -> Self {
90        Self {
91            remaining_micro_usd: (usd * 1_000_000.0) as u64,
92        }
93    }
94
95    pub fn consume_micro(&mut self, micro: u64) -> Result<()> {
96        if micro > self.remaining_micro_usd {
97            return Err(AgentError::BudgetExceeded("money"));
98        }
99        self.remaining_micro_usd -= micro;
100        Ok(())
101    }
102}
103
104#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
105pub struct IterationBudget {
106    pub remaining: u32,
107}
108
109impl IterationBudget {
110    pub fn new(n: u32) -> Self {
111        Self { remaining: n }
112    }
113
114    pub fn consume_one(&mut self) -> Result<()> {
115        if self.remaining == 0 {
116            return Err(AgentError::BudgetExceeded("iterations"));
117        }
118        self.remaining -= 1;
119        Ok(())
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Consuming lowers `remaining`, and `split` divides what is left
    /// into the requested number of slices.
    #[test]
    fn token_budget_consume_and_split() {
        let mut budget = TokenBudget::new(1000);
        budget.consume(100).unwrap();
        assert_eq!(budget.remaining, 900);

        let slices = budget.split(3);
        assert_eq!(slices.len(), 3);
        assert_eq!(slices[0].remaining, 300);
    }

    /// Asking for more tokens than are available is an error.
    #[test]
    fn budget_exceeded() {
        let mut budget = TokenBudget::new(10);
        let outcome = budget.consume(11);
        assert!(outcome.is_err());
    }

    /// A budget of two allows exactly two iterations; the third fails.
    #[test]
    fn iteration_budget() {
        let mut budget = IterationBudget::new(2);
        for _ in 0..2 {
            assert!(budget.consume_one().is_ok());
        }
        assert!(budget.consume_one().is_err());
    }
}