1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::OnceLock;
8
9use serde::Serialize;
10
11use crate::core::roles::{self, RoleLimits};
12
/// Process-wide tracker instance, created lazily by [`BudgetTracker::global`].
static TRACKER: OnceLock<BudgetTracker> = OnceLock::new();

/// Running totals of resource consumption (context tokens, shell invocations
/// and USD cost) that [`BudgetTracker::check`] compares against the active
/// role's limits.
///
/// Every counter is an atomic, so a shared `&BudgetTracker` can be updated
/// from several threads without a lock.
#[derive(Debug)]
pub struct BudgetTracker {
    /// Total context tokens recorded so far.
    context_tokens: AtomicU64,
    /// Number of shell invocations recorded so far.
    shell_invocations: AtomicUsize,
    /// Accumulated spend in millicents (1 USD = 100_000 millicents).
    cost_millicents: AtomicU64,
}
20
21impl BudgetTracker {
22 fn new() -> Self {
23 Self {
24 context_tokens: AtomicU64::new(0),
25 shell_invocations: AtomicUsize::new(0),
26 cost_millicents: AtomicU64::new(0),
27 }
28 }
29
30 pub fn global() -> &'static BudgetTracker {
31 TRACKER.get_or_init(BudgetTracker::new)
32 }
33
34 pub fn record_tokens(&self, tokens: u64) {
35 self.context_tokens.fetch_add(tokens, Ordering::Relaxed);
36 }
37
38 pub fn record_shell(&self) {
39 self.shell_invocations.fetch_add(1, Ordering::Relaxed);
40 }
41
42 pub fn record_cost_usd(&self, usd: f64) {
43 let mc = (usd * 100_000.0) as u64;
44 self.cost_millicents.fetch_add(mc, Ordering::Relaxed);
45 }
46
47 pub fn tokens_used(&self) -> u64 {
48 self.context_tokens.load(Ordering::Relaxed)
49 }
50
51 pub fn shell_used(&self) -> usize {
52 self.shell_invocations.load(Ordering::Relaxed)
53 }
54
55 pub fn cost_usd(&self) -> f64 {
56 self.cost_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
57 }
58
59 pub fn reset(&self) {
60 self.context_tokens.store(0, Ordering::Relaxed);
61 self.shell_invocations.store(0, Ordering::Relaxed);
62 self.cost_millicents.store(0, Ordering::Relaxed);
63 }
64
65 pub fn check(&self) -> BudgetSnapshot {
66 let limits = roles::active_role().limits;
67 let role_name = roles::active_role_name();
68
69 let tokens = self.tokens_used();
70 let shell = self.shell_used();
71 let cost = self.cost_usd();
72
73 BudgetSnapshot {
74 role: role_name,
75 tokens: DimensionStatus::evaluate(tokens as usize, limits.max_context_tokens, &limits),
76 shell: DimensionStatus::evaluate(shell, limits.max_shell_invocations, &limits),
77 cost: CostStatus::evaluate(cost, limits.max_cost_usd, &limits),
78 }
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
83pub enum BudgetLevel {
84 Ok,
85 Warning,
86 Exhausted,
87}
88
89impl std::fmt::Display for BudgetLevel {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 match self {
92 Self::Ok => write!(f, "OK"),
93 Self::Warning => write!(f, "WARNING"),
94 Self::Exhausted => write!(f, "EXHAUSTED"),
95 }
96 }
97}
98
99#[derive(Debug, Clone, Serialize)]
100pub struct DimensionStatus {
101 pub used: usize,
102 pub limit: usize,
103 pub percent: u8,
104 pub level: BudgetLevel,
105}
106
107impl DimensionStatus {
108 fn evaluate(used: usize, limit: usize, limits: &RoleLimits) -> Self {
109 if limit == 0 {
110 return Self {
111 used,
112 limit,
113 percent: 0,
114 level: if used > 0 {
115 BudgetLevel::Exhausted
116 } else {
117 BudgetLevel::Ok
118 },
119 };
120 }
121 let percent = ((used as f64 / limit as f64) * 100.0).min(255.0) as u8;
122 let level = if percent >= limits.block_at_percent {
123 BudgetLevel::Exhausted
124 } else if percent >= limits.warn_at_percent {
125 BudgetLevel::Warning
126 } else {
127 BudgetLevel::Ok
128 };
129 Self {
130 used,
131 limit,
132 percent,
133 level,
134 }
135 }
136}
137
138#[derive(Debug, Clone, Serialize)]
139pub struct CostStatus {
140 pub used_usd: f64,
141 pub limit_usd: f64,
142 pub percent: u8,
143 pub level: BudgetLevel,
144}
145
146impl CostStatus {
147 fn evaluate(used: f64, limit: f64, limits: &RoleLimits) -> Self {
148 if limit <= 0.0 {
149 return Self {
150 used_usd: used,
151 limit_usd: limit,
152 percent: 0,
153 level: if used > 0.0 {
154 BudgetLevel::Exhausted
155 } else {
156 BudgetLevel::Ok
157 },
158 };
159 }
160 let pct = ((used / limit) * 100.0).min(255.0) as u8;
161 let level = if pct >= limits.block_at_percent {
162 BudgetLevel::Exhausted
163 } else if pct >= limits.warn_at_percent {
164 BudgetLevel::Warning
165 } else {
166 BudgetLevel::Ok
167 };
168 Self {
169 used_usd: used,
170 limit_usd: limit,
171 percent: pct,
172 level,
173 }
174 }
175}
176
177#[derive(Debug, Clone, Serialize)]
178pub struct BudgetSnapshot {
179 pub role: String,
180 pub tokens: DimensionStatus,
181 pub shell: DimensionStatus,
182 pub cost: CostStatus,
183}
184
185impl BudgetSnapshot {
186 pub fn worst_level(&self) -> &BudgetLevel {
187 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
188 if *level == BudgetLevel::Exhausted {
189 return level;
190 }
191 }
192 for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
193 if *level == BudgetLevel::Warning {
194 return level;
195 }
196 }
197 &BudgetLevel::Ok
198 }
199
200 pub fn format_compact(&self) -> String {
201 format!(
202 "Budget[{}]: tokens {}/{} ({}%) | shell {}/{} ({}%) | cost ${:.2}/${:.2} ({}%) → {}",
203 self.role,
204 self.tokens.used,
205 self.tokens.limit,
206 self.tokens.percent,
207 self.shell.used,
208 self.shell.limit,
209 self.shell.percent,
210 self.cost.used_usd,
211 self.cost.limit_usd,
212 self.cost.percent,
213 self.worst_level(),
214 )
215 }
216}
217
#[cfg(test)]
mod tests {
    //! Unit tests for the tracker counters, per-dimension classification and
    //! snapshot rendering.
    //!
    //! Threshold-dependent assertions rely on `RoleLimits::default()`, whose
    //! warn threshold lies in (25, 85] and block threshold in (90, 100].

    use super::*;

    #[test]
    fn tracker_starts_at_zero() {
        let t = BudgetTracker::new();
        assert_eq!(t.tokens_used(), 0);
        assert_eq!(t.shell_used(), 0);
        assert!((t.cost_usd() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn record_and_read() {
        let t = BudgetTracker::new();
        t.record_tokens(5000);
        t.record_tokens(3000);
        t.record_shell();
        t.record_shell();
        t.record_cost_usd(0.50);
        assert_eq!(t.tokens_used(), 8000);
        assert_eq!(t.shell_used(), 2);
        assert!((t.cost_usd() - 0.50).abs() < 0.001);
    }

    #[test]
    fn reset_clears_all() {
        let t = BudgetTracker::new();
        t.record_tokens(10_000);
        t.record_shell();
        t.record_cost_usd(1.0);
        t.reset();
        assert_eq!(t.tokens_used(), 0);
        assert_eq!(t.shell_used(), 0);
        assert!((t.cost_usd() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn dimension_status_ok() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(50_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Ok);
        assert_eq!(s.percent, 25);
    }

    #[test]
    fn dimension_status_warning() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(170_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
        assert_eq!(s.percent, 85);
    }

    #[test]
    fn dimension_status_exhausted() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(200_000, 200_000, &limits);
        assert_eq!(s.level, BudgetLevel::Exhausted);
        assert_eq!(s.percent, 100);
    }

    #[test]
    fn zero_limit_blocks_usage() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(1, 0, &limits);
        assert_eq!(s.level, BudgetLevel::Exhausted);
    }

    /// A zero limit with zero usage is not a violation.
    #[test]
    fn zero_limit_without_usage_is_ok() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(0, 0, &limits);
        assert_eq!(s.level, BudgetLevel::Ok);
        assert_eq!(s.percent, 0);
    }

    /// Overruns beyond 255% clamp instead of wrapping the `u8` percentage.
    #[test]
    fn percent_saturates_at_u8_max() {
        let limits = RoleLimits::default();
        let s = DimensionStatus::evaluate(1_000, 1, &limits);
        assert_eq!(s.percent, 255);
        assert_eq!(s.level, BudgetLevel::Exhausted);
    }

    #[test]
    fn cost_status_warning() {
        let limits = RoleLimits::default();
        let s = CostStatus::evaluate(4.5, 5.0, &limits);
        assert_eq!(s.level, BudgetLevel::Warning);
    }

    #[test]
    fn cost_status_exhausted() {
        let limits = RoleLimits::default();
        let s = CostStatus::evaluate(5.0, 5.0, &limits);
        assert_eq!(s.level, BudgetLevel::Exhausted);
        assert_eq!(s.percent, 100);
    }

    /// A zero cost limit blocks any positive spend but not zero spend.
    #[test]
    fn cost_zero_limit_blocks_any_spend() {
        let limits = RoleLimits::default();
        assert_eq!(
            CostStatus::evaluate(0.01, 0.0, &limits).level,
            BudgetLevel::Exhausted
        );
        assert_eq!(CostStatus::evaluate(0.0, 0.0, &limits).level, BudgetLevel::Ok);
    }

    #[test]
    fn snapshot_worst_level() {
        let limits = RoleLimits::default();
        let snap = BudgetSnapshot {
            role: "test".into(),
            tokens: DimensionStatus::evaluate(50_000, 200_000, &limits),
            shell: DimensionStatus::evaluate(90, 100, &limits),
            cost: CostStatus::evaluate(1.0, 5.0, &limits),
        };
        assert_eq!(*snap.worst_level(), BudgetLevel::Warning);
    }

    /// An exhausted dimension outranks a warning in an earlier dimension.
    #[test]
    fn snapshot_exhausted_outranks_warning() {
        let limits = RoleLimits::default();
        let snap = BudgetSnapshot {
            role: "test".into(),
            tokens: DimensionStatus::evaluate(170_000, 200_000, &limits),
            shell: DimensionStatus::evaluate(100, 100, &limits),
            cost: CostStatus::evaluate(1.0, 5.0, &limits),
        };
        assert_eq!(*snap.worst_level(), BudgetLevel::Exhausted);
    }

    #[test]
    fn budget_level_display_labels() {
        assert_eq!(BudgetLevel::Ok.to_string(), "OK");
        assert_eq!(BudgetLevel::Warning.to_string(), "WARNING");
        assert_eq!(BudgetLevel::Exhausted.to_string(), "EXHAUSTED");
    }

    #[test]
    fn format_compact_includes_all() {
        let s = BudgetSnapshot {
            role: "coder".into(),
            tokens: DimensionStatus {
                used: 1000,
                limit: 200_000,
                percent: 0,
                level: BudgetLevel::Ok,
            },
            shell: DimensionStatus {
                used: 5,
                limit: 100,
                percent: 5,
                level: BudgetLevel::Ok,
            },
            cost: CostStatus {
                used_usd: 0.25,
                limit_usd: 5.0,
                percent: 5,
                level: BudgetLevel::Ok,
            },
        };
        let out = s.format_compact();
        assert!(out.contains("coder"));
        assert!(out.contains("tokens"));
        assert!(out.contains("shell"));
        assert!(out.contains("cost"));
        assert!(out.contains("OK"));
        // Pin the full layout so accidental format-string changes are caught.
        assert_eq!(
            out,
            "Budget[coder]: tokens 1000/200000 (0%) | shell 5/100 (5%) | cost $0.25/$5.00 (5%) → OK"
        );
    }
}