1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::OnceLock;
8
9use serde::Serialize;
10
11use crate::core::roles::{self, RoleLimits};
12
/// Process-wide tracker instance; starts empty and is filled in the first
/// time [`BudgetTracker::global`] is called.
static TRACKER: OnceLock<BudgetTracker> = OnceLock::new();

/// Running usage totals for the current process.
///
/// Every counter is an atomic, so the totals can be updated through a shared
/// reference.
pub struct BudgetTracker {
    /// Context tokens recorded so far.
    context_tokens: AtomicU64,
    /// Shell invocations recorded so far.
    shell_invocations: AtomicUsize,
    /// Accumulated cost in millicents (1/100_000 of a US dollar).
    cost_millicents: AtomicU64,
}
20
21impl BudgetTracker {
22 fn new() -> Self {
23 Self {
24 context_tokens: AtomicU64::new(0),
25 shell_invocations: AtomicUsize::new(0),
26 cost_millicents: AtomicU64::new(0),
27 }
28 }
29
30 pub fn global() -> &'static BudgetTracker {
31 TRACKER.get_or_init(BudgetTracker::new)
32 }
33
34 pub fn record_tokens(&self, tokens: u64) {
35 self.context_tokens.fetch_add(tokens, Ordering::Relaxed);
36 }
37
38 pub fn record_shell(&self) {
39 self.shell_invocations.fetch_add(1, Ordering::Relaxed);
40 }
41
42 pub fn record_cost_usd(&self, usd: f64) {
43 let mc = (usd * 100_000.0) as u64;
44 self.cost_millicents.fetch_add(mc, Ordering::Relaxed);
45 }
46
47 pub fn tokens_used(&self) -> u64 {
48 self.context_tokens.load(Ordering::Relaxed)
49 }
50
51 pub fn shell_used(&self) -> usize {
52 self.shell_invocations.load(Ordering::Relaxed)
53 }
54
55 pub fn cost_usd(&self) -> f64 {
56 self.cost_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
57 }
58
59 pub fn reset(&self) {
60 self.context_tokens.store(0, Ordering::Relaxed);
61 self.shell_invocations.store(0, Ordering::Relaxed);
62 self.cost_millicents.store(0, Ordering::Relaxed);
63 }
64
65 pub fn check(&self) -> BudgetSnapshot {
66 let limits = roles::active_role().limits;
67 let role_name = roles::active_role_name();
68
69 let tokens = self.tokens_used();
70 let shell = self.shell_used();
71 let cost = self.cost_usd();
72
73 BudgetSnapshot {
74 role: role_name,
75 tokens: DimensionStatus::evaluate(tokens as usize, limits.max_context_tokens, &limits),
76 shell: DimensionStatus::evaluate(shell, limits.max_shell_invocations, &limits),
77 cost: CostStatus::evaluate(cost, limits.max_cost_usd, &limits),
78 }
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
83pub enum BudgetLevel {
84 Ok,
85 Warning,
86 Exhausted,
87}
88
89impl std::fmt::Display for BudgetLevel {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 Self::Ok => write!(f, "OK"),
93 Self::Warning => write!(f, "WARNING"),
94 Self::Exhausted => write!(f, "EXHAUSTED"),
95 }
96 }
97}
98
99#[derive(Debug, Clone, Serialize)]
100pub struct DimensionStatus {
101 pub used: usize,
102 pub limit: usize,
103 pub percent: u8,
104 pub level: BudgetLevel,
105}
106
107impl DimensionStatus {
108 fn evaluate(used: usize, limit: usize, limits: &RoleLimits) -> Self {
109 if limit == 0 {
110 return Self {
112 used,
113 limit,
114 percent: 0,
115 level: if used > 0 {
116 BudgetLevel::Warning
117 } else {
118 BudgetLevel::Ok
119 },
120 };
121 }
122 let percent = ((used as f64 / limit as f64) * 100.0).min(254.0) as u8;
123 let level = if limits.block_at_percent < 255 && percent >= limits.block_at_percent {
125 BudgetLevel::Exhausted
126 } else if percent >= limits.warn_at_percent {
127 BudgetLevel::Warning
128 } else {
129 BudgetLevel::Ok
130 };
131 Self {
132 used,
133 limit,
134 percent,
135 level,
136 }
137 }
138}
139
140#[derive(Debug, Clone, Serialize)]
141pub struct CostStatus {
142 pub used_usd: f64,
143 pub limit_usd: f64,
144 pub percent: u8,
145 pub level: BudgetLevel,
146}
147
148impl CostStatus {
149 fn evaluate(used: f64, limit: f64, limits: &RoleLimits) -> Self {
150 if limit <= 0.0 {
151 return Self {
153 used_usd: used,
154 limit_usd: limit,
155 percent: 0,
156 level: if used > 0.0 {
157 BudgetLevel::Warning
158 } else {
159 BudgetLevel::Ok
160 },
161 };
162 }
163 let pct = ((used / limit) * 100.0).min(254.0) as u8;
164 let level = if limits.block_at_percent < 255 && pct >= limits.block_at_percent {
166 BudgetLevel::Exhausted
167 } else if pct >= limits.warn_at_percent {
168 BudgetLevel::Warning
169 } else {
170 BudgetLevel::Ok
171 };
172 Self {
173 used_usd: used,
174 limit_usd: limit,
175 percent: pct,
176 level,
177 }
178 }
179}
180
181#[derive(Debug, Clone, Serialize)]
182pub struct BudgetSnapshot {
183 pub role: String,
184 pub tokens: DimensionStatus,
185 pub shell: DimensionStatus,
186 pub cost: CostStatus,
187}
188
189impl BudgetSnapshot {
190 pub fn worst_level(&self) -> &BudgetLevel {
191 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
192 if *level == BudgetLevel::Exhausted {
193 return level;
194 }
195 }
196 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
197 if *level == BudgetLevel::Warning {
198 return level;
199 }
200 }
201 &BudgetLevel::Ok
202 }
203
204 pub fn format_compact(&self) -> String {
205 format!(
206 "Budget[{}]: tokens {}/{} ({}%) | shell {}/{} ({}%) | cost ${:.2}/${:.2} ({}%) → {}",
207 self.role,
208 self.tokens.used,
209 self.tokens.limit,
210 self.tokens.percent,
211 self.shell.used,
212 self.shell.limit,
213 self.shell.percent,
214 self.cost.used_usd,
215 self.cost.limit_usd,
216 self.cost.percent,
217 self.worst_level(),
218 )
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn tracker_starts_at_zero() {
        let tracker = BudgetTracker::new();
        assert_eq!(tracker.tokens_used(), 0);
        assert_eq!(tracker.shell_used(), 0);
        assert!(close_to(tracker.cost_usd(), 0.0, f64::EPSILON));
    }

    #[test]
    fn record_and_read() {
        let tracker = BudgetTracker::new();
        for batch in [5000, 3000] {
            tracker.record_tokens(batch);
        }
        for _ in 0..2 {
            tracker.record_shell();
        }
        tracker.record_cost_usd(0.50);

        assert_eq!(tracker.tokens_used(), 8000);
        assert_eq!(tracker.shell_used(), 2);
        assert!(close_to(tracker.cost_usd(), 0.50, 0.001));
    }

    #[test]
    fn reset_clears_all() {
        let tracker = BudgetTracker::new();
        tracker.record_tokens(10_000);
        tracker.record_shell();
        tracker.record_cost_usd(1.0);

        tracker.reset();

        assert_eq!(tracker.tokens_used(), 0);
        assert_eq!(tracker.shell_used(), 0);
        assert!(close_to(tracker.cost_usd(), 0.0, f64::EPSILON));
    }

    #[test]
    fn dimension_status_ok() {
        let defaults = RoleLimits::default();
        let status = DimensionStatus::evaluate(50_000, 200_000, &defaults);
        assert_eq!(status.level, BudgetLevel::Ok);
        assert_eq!(status.percent, 25);
    }

    #[test]
    fn dimension_status_warning() {
        let defaults = RoleLimits::default();
        let status = DimensionStatus::evaluate(170_000, 200_000, &defaults);
        assert_eq!(status.level, BudgetLevel::Warning);
        assert_eq!(status.percent, 85);
    }

    #[test]
    fn dimension_status_at_100_percent_is_warning_by_default() {
        let defaults = RoleLimits::default();
        // The default threshold of 255 means blocking is switched off.
        assert_eq!(defaults.block_at_percent, 255);

        let status = DimensionStatus::evaluate(200_000, 200_000, &defaults);
        assert_eq!(status.level, BudgetLevel::Warning);
        assert_eq!(status.percent, 100);
    }

    #[test]
    fn dimension_status_exhausted_when_blocking_enabled() {
        let blocking = RoleLimits {
            block_at_percent: 100,
            ..Default::default()
        };
        let status = DimensionStatus::evaluate(200_000, 200_000, &blocking);
        assert_eq!(status.level, BudgetLevel::Exhausted);
    }

    #[test]
    fn zero_limit_warns_usage() {
        let defaults = RoleLimits::default();
        let status = DimensionStatus::evaluate(1, 0, &defaults);
        assert_eq!(status.level, BudgetLevel::Warning);
    }

    #[test]
    fn cost_status_warning() {
        let defaults = RoleLimits::default();
        let status = CostStatus::evaluate(4.5, 5.0, &defaults);
        assert_eq!(status.level, BudgetLevel::Warning);
    }

    #[test]
    fn snapshot_worst_level() {
        let defaults = RoleLimits::default();
        let tokens = DimensionStatus::evaluate(50_000, 200_000, &defaults);
        let shell = DimensionStatus::evaluate(90, 100, &defaults);
        let cost = CostStatus::evaluate(1.0, 5.0, &defaults);
        let snapshot = BudgetSnapshot {
            role: "test".into(),
            tokens,
            shell,
            cost,
        };
        assert_eq!(*snapshot.worst_level(), BudgetLevel::Warning);
    }

    #[test]
    fn format_compact_includes_all() {
        let snapshot = BudgetSnapshot {
            role: "coder".into(),
            tokens: DimensionStatus {
                used: 1000,
                limit: 200_000,
                percent: 0,
                level: BudgetLevel::Ok,
            },
            shell: DimensionStatus {
                used: 5,
                limit: 100,
                percent: 5,
                level: BudgetLevel::Ok,
            },
            cost: CostStatus {
                used_usd: 0.25,
                limit_usd: 5.0,
                percent: 5,
                level: BudgetLevel::Ok,
            },
        };

        let line = snapshot.format_compact();
        for needle in ["coder", "tokens", "shell", "cost", "OK"] {
            assert!(line.contains(needle));
        }
    }
}