1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::OnceLock;
8
9use serde::Serialize;
10
11use crate::core::roles::{self, RoleLimits};
12
/// Process-wide tracker instance, lazily initialised by
/// [`BudgetTracker::global`] on its first call.
static TRACKER: OnceLock<BudgetTracker> = OnceLock::new();
14
/// Running usage counters for the current process.
///
/// Every field is an atomic, so all recording and reading methods take
/// `&self`.
pub struct BudgetTracker {
    /// Sum of all token counts passed to `record_tokens`.
    context_tokens: AtomicU64,
    /// Number of `record_shell` calls.
    shell_invocations: AtomicUsize,
    /// Accumulated cost as an integer; `record_cost_usd` scales US dollars by
    /// 100_000 before adding, and `cost_usd` divides by the same factor.
    cost_millicents: AtomicU64,
    /// Number of `record_tool_call` calls.
    tool_calls: AtomicUsize,
}
21
22impl BudgetTracker {
23 fn new() -> Self {
24 Self {
25 context_tokens: AtomicU64::new(0),
26 shell_invocations: AtomicUsize::new(0),
27 cost_millicents: AtomicU64::new(0),
28 tool_calls: AtomicUsize::new(0),
29 }
30 }
31
32 pub fn global() -> &'static BudgetTracker {
33 TRACKER.get_or_init(BudgetTracker::new)
34 }
35
36 pub fn record_tokens(&self, tokens: u64) {
37 self.context_tokens.fetch_add(tokens, Ordering::Relaxed);
38 }
39
40 pub fn record_shell(&self) {
41 self.shell_invocations.fetch_add(1, Ordering::Relaxed);
42 }
43
44 pub fn record_tool_call(&self) {
45 self.tool_calls.fetch_add(1, Ordering::Relaxed);
46 }
47
48 pub fn tool_calls_count(&self) -> usize {
49 self.tool_calls.load(Ordering::Relaxed)
50 }
51
52 pub fn record_cost_usd(&self, usd: f64) {
53 let mc = (usd * 100_000.0) as u64;
54 self.cost_millicents.fetch_add(mc, Ordering::Relaxed);
55 }
56
57 pub fn tokens_used(&self) -> u64 {
58 self.context_tokens.load(Ordering::Relaxed)
59 }
60
61 pub fn shell_used(&self) -> usize {
62 self.shell_invocations.load(Ordering::Relaxed)
63 }
64
65 pub fn cost_usd(&self) -> f64 {
66 self.cost_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
67 }
68
69 pub fn reset(&self) {
70 self.context_tokens.store(0, Ordering::Relaxed);
71 self.shell_invocations.store(0, Ordering::Relaxed);
72 self.cost_millicents.store(0, Ordering::Relaxed);
73 self.tool_calls.store(0, Ordering::Relaxed);
74 }
75
76 pub fn check(&self) -> BudgetSnapshot {
77 let limits = roles::active_role().limits;
78 let role_name = roles::active_role_name();
79
80 let tokens = self.tokens_used();
81 let shell = self.shell_used();
82 let cost = self.cost_usd();
83
84 BudgetSnapshot {
85 role: role_name,
86 tokens: DimensionStatus::evaluate(tokens as usize, limits.max_context_tokens, &limits),
87 shell: DimensionStatus::evaluate(shell, limits.max_shell_invocations, &limits),
88 cost: CostStatus::evaluate(cost, limits.max_cost_usd, &limits),
89 }
90 }
91}
92
/// Severity assigned to a single budget dimension or to a whole snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum BudgetLevel {
    /// Usage is below the warning threshold.
    Ok,
    /// Usage percent reached `warn_at_percent`, or usage was recorded
    /// against a limit of zero.
    Warning,
    /// Usage percent reached `block_at_percent` while blocking is enabled
    /// (`block_at_percent < 255`).
    Exhausted,
}
99
100impl std::fmt::Display for BudgetLevel {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 match self {
103 Self::Ok => write!(f, "OK"),
104 Self::Warning => write!(f, "WARNING"),
105 Self::Exhausted => write!(f, "EXHAUSTED"),
106 }
107 }
108}
109
/// Usage of one integer-counted budget dimension measured against its limit.
#[derive(Debug, Clone, Serialize)]
pub struct DimensionStatus {
    /// Amount consumed so far.
    pub used: usize,
    /// Configured maximum; `0` is handled as a special case by `evaluate`.
    pub limit: usize,
    /// `used / limit` as a whole percentage, truncated and capped at 254;
    /// `0` when `limit` is `0`.
    pub percent: u8,
    /// Severity derived from `percent` and the role's thresholds.
    pub level: BudgetLevel,
}
117
118impl DimensionStatus {
119 fn evaluate(used: usize, limit: usize, limits: &RoleLimits) -> Self {
120 if limit == 0 {
121 return Self {
123 used,
124 limit,
125 percent: 0,
126 level: if used > 0 {
127 BudgetLevel::Warning
128 } else {
129 BudgetLevel::Ok
130 },
131 };
132 }
133 let percent = ((used as f64 / limit as f64) * 100.0).min(254.0) as u8;
134 let level = if limits.block_at_percent < 255 && percent >= limits.block_at_percent {
136 BudgetLevel::Exhausted
137 } else if percent >= limits.warn_at_percent {
138 BudgetLevel::Warning
139 } else {
140 BudgetLevel::Ok
141 };
142 Self {
143 used,
144 limit,
145 percent,
146 level,
147 }
148 }
149}
150
/// Monetary usage measured against a cost limit, both in US dollars.
#[derive(Debug, Clone, Serialize)]
pub struct CostStatus {
    /// Amount spent so far.
    pub used_usd: f64,
    /// Configured maximum; values `<= 0.0` are handled as a special case by
    /// `evaluate`.
    pub limit_usd: f64,
    /// `used_usd / limit_usd` as a whole percentage, truncated and capped at
    /// 254; `0` when the limit is not positive.
    pub percent: u8,
    /// Severity derived from `percent` and the role's thresholds.
    pub level: BudgetLevel,
}
158
159impl CostStatus {
160 fn evaluate(used: f64, limit: f64, limits: &RoleLimits) -> Self {
161 if limit <= 0.0 {
162 return Self {
164 used_usd: used,
165 limit_usd: limit,
166 percent: 0,
167 level: if used > 0.0 {
168 BudgetLevel::Warning
169 } else {
170 BudgetLevel::Ok
171 },
172 };
173 }
174 let pct = ((used / limit) * 100.0).min(254.0) as u8;
175 let level = if limits.block_at_percent < 255 && pct >= limits.block_at_percent {
177 BudgetLevel::Exhausted
178 } else if pct >= limits.warn_at_percent {
179 BudgetLevel::Warning
180 } else {
181 BudgetLevel::Ok
182 };
183 Self {
184 used_usd: used,
185 limit_usd: limit,
186 percent: pct,
187 level,
188 }
189 }
190}
191
/// Point-in-time evaluation of every tracked budget dimension for one role.
#[derive(Debug, Clone, Serialize)]
pub struct BudgetSnapshot {
    /// Name of the role the limits were taken from.
    pub role: String,
    /// Context-token usage versus its limit.
    pub tokens: DimensionStatus,
    /// Shell-invocation usage versus its limit.
    pub shell: DimensionStatus,
    /// Dollar cost versus its limit.
    pub cost: CostStatus,
}
199
200impl BudgetSnapshot {
201 pub fn worst_level(&self) -> &BudgetLevel {
202 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
203 if *level == BudgetLevel::Exhausted {
204 return level;
205 }
206 }
207 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
208 if *level == BudgetLevel::Warning {
209 return level;
210 }
211 }
212 &BudgetLevel::Ok
213 }
214
215 pub fn format_compact(&self) -> String {
216 format!(
217 "Budget[{}]: tokens {}/{} ({}%) | shell {}/{} ({}%) | cost ${:.2}/${:.2} ({}%) → {}",
218 self.role,
219 self.tokens.used,
220 self.tokens.limit,
221 self.tokens.percent,
222 self.shell.used,
223 self.shell.limit,
224 self.shell.percent,
225 self.cost.used_usd,
226 self.cost.limit_usd,
227 self.cost.percent,
228 self.worst_level(),
229 )
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot whose three dimensions carry the given levels, so
    /// `worst_level` can be tested without depending on `RoleLimits`.
    fn snapshot_with_levels(
        tokens: BudgetLevel,
        shell: BudgetLevel,
        cost: BudgetLevel,
    ) -> BudgetSnapshot {
        BudgetSnapshot {
            role: "test".into(),
            tokens: DimensionStatus {
                used: 0,
                limit: 0,
                percent: 0,
                level: tokens,
            },
            shell: DimensionStatus {
                used: 0,
                limit: 0,
                percent: 0,
                level: shell,
            },
            cost: CostStatus {
                used_usd: 0.0,
                limit_usd: 0.0,
                percent: 0,
                level: cost,
            },
        }
    }

    #[test]
    fn tracker_starts_at_zero() {
        let t = BudgetTracker::new();
        assert_eq!(t.tokens_used(), 0);
        assert_eq!(t.shell_used(), 0);
        assert_eq!(t.tool_calls_count(), 0);
        assert!((t.cost_usd() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn record_and_read() {
        let t = BudgetTracker::new();
        t.record_tokens(5000);
        t.record_tokens(3000);
        t.record_shell();
        t.record_shell();
        t.record_cost_usd(0.50);
        assert_eq!(t.tokens_used(), 8000);
        assert_eq!(t.shell_used(), 2);
        assert!((t.cost_usd() - 0.50).abs() < 0.001);
    }

    #[test]
    fn tool_calls_are_counted() {
        let t = BudgetTracker::new();
        t.record_tool_call();
        t.record_tool_call();
        t.record_tool_call();
        assert_eq!(t.tool_calls_count(), 3);
        // Tool calls are tracked independently of the other counters.
        assert_eq!(t.shell_used(), 0);
        assert_eq!(t.tokens_used(), 0);
    }

    #[test]
    fn reset_clears_all() {
        let t = BudgetTracker::new();
        t.record_tokens(10_000);
        t.record_shell();
        t.record_tool_call();
        t.record_cost_usd(1.0);
        t.reset();
        assert_eq!(t.tokens_used(), 0);
        assert_eq!(t.shell_used(), 0);
        assert_eq!(t.tool_calls_count(), 0);
        assert!((t.cost_usd() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn dimension_status_ok() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(50_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Ok);
        assert_eq!(s.percent, 25);
    }

    #[test]
    fn dimension_status_warning() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(170_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
        assert_eq!(s.percent, 85);
    }

    #[test]
    fn dimension_status_at_100_percent_is_warning_by_default() {
        let limits = RoleLimits::default();
        // Blocking is disabled by default, so full usage only warns.
        assert_eq!(limits.block_at_percent, 255);
        let s = DimensionStatus::evaluate(200_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
        assert_eq!(s.percent, 100);
    }

    #[test]
    fn dimension_status_exhausted_when_blocking_enabled() {
        let limits = RoleLimits {
            block_at_percent: 100,
            ..Default::default()
        };
        let s = DimensionStatus::evaluate(200_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Exhausted);
    }

    #[test]
    fn dimension_status_percent_is_capped() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(1_000, 100, &limits);
        assert_eq!(s.percent, 254);
        assert_eq!(s.level, BudgetLevel::Warning);
    }

    #[test]
    fn zero_limit_warns_usage() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(1, 0, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
    }

    #[test]
    fn zero_limit_without_usage_is_ok() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(0, 0, &limits);
        assert_eq!(s.level, BudgetLevel::Ok);
        assert_eq!(s.percent, 0);
    }

    #[test]
    fn cost_status_warning() {
        let limits = RoleLimits::default();
        let s = CostStatus::evaluate(4.5, 5.0, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
    }

    #[test]
    fn cost_status_zero_limit() {
        let limits = RoleLimits::default();
        let spent = CostStatus::evaluate(0.25, 0.0, &limits);
        assert_eq!(spent.level, BudgetLevel::Warning);
        assert_eq!(spent.percent, 0);
        let idle = CostStatus::evaluate(0.0, 0.0, &limits);
        assert_eq!(idle.level, BudgetLevel::Ok);
    }

    #[test]
    fn cost_status_exhausted_when_blocking_enabled() {
        let limits = RoleLimits {
            block_at_percent: 100,
            ..Default::default()
        };
        let s = CostStatus::evaluate(5.0, 5.0, &limits);
        assert_eq!(s.level, BudgetLevel::Exhausted);
    }

    #[test]
    fn snapshot_worst_level() {
        let limits = RoleLimits::default();
        let snap = BudgetSnapshot {
            role: "test".into(),
            tokens: DimensionStatus::evaluate(50_000, 200_000, &limits),
            shell: DimensionStatus::evaluate(90, 100, &limits),
            cost: CostStatus::evaluate(1.0, 5.0, &limits),
        };
        assert_eq!(*snap.worst_level(), BudgetLevel::Warning);
    }

    #[test]
    fn snapshot_worst_level_prefers_exhausted() {
        let snap = snapshot_with_levels(
            BudgetLevel::Warning,
            BudgetLevel::Ok,
            BudgetLevel::Exhausted,
        );
        assert_eq!(*snap.worst_level(), BudgetLevel::Exhausted);
    }

    #[test]
    fn snapshot_worst_level_all_ok() {
        let snap = snapshot_with_levels(BudgetLevel::Ok, BudgetLevel::Ok, BudgetLevel::Ok);
        assert_eq!(*snap.worst_level(), BudgetLevel::Ok);
    }

    #[test]
    fn format_compact_includes_all() {
        let s = BudgetSnapshot {
            role: "coder".into(),
            tokens: DimensionStatus {
                used: 1000,
                limit: 200_000,
                percent: 0,
                level: BudgetLevel::Ok,
            },
            shell: DimensionStatus {
                used: 5,
                limit: 100,
                percent: 5,
                level: BudgetLevel::Ok,
            },
            cost: CostStatus {
                used_usd: 0.25,
                limit_usd: 5.0,
                percent: 5,
                level: BudgetLevel::Ok,
            },
        };
        let out = s.format_compact();
        assert!(out.contains("coder"));
        assert!(out.contains("tokens"));
        assert!(out.contains("shell"));
        assert!(out.contains("cost"));
        assert!(out.contains("OK"));
    }
}