Skip to main content

lean_ctx/core/
budget_tracker.rs

//! Runtime budget tracking against role limits.
//!
//! Compares the session counters accumulated by `BudgetTracker` with the active
//! role's `RoleLimits` and produces a `BudgetSnapshot` whose per-dimension
//! `BudgetLevel` verdicts are Ok, Warning, or Exhausted.

6use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::sync::OnceLock;
8
9use serde::Serialize;
10
11use crate::core::roles::{self, RoleLimits};
12
13static TRACKER: OnceLock<BudgetTracker> = OnceLock::new();
14
15pub struct BudgetTracker {
16    context_tokens: AtomicU64,
17    shell_invocations: AtomicUsize,
18    cost_millicents: AtomicU64,
19}
20
21impl BudgetTracker {
22    fn new() -> Self {
23        Self {
24            context_tokens: AtomicU64::new(0),
25            shell_invocations: AtomicUsize::new(0),
26            cost_millicents: AtomicU64::new(0),
27        }
28    }
29
30    pub fn global() -> &'static BudgetTracker {
31        TRACKER.get_or_init(BudgetTracker::new)
32    }
33
34    pub fn record_tokens(&self, tokens: u64) {
35        self.context_tokens.fetch_add(tokens, Ordering::Relaxed);
36    }
37
38    pub fn record_shell(&self) {
39        self.shell_invocations.fetch_add(1, Ordering::Relaxed);
40    }
41
42    pub fn record_cost_usd(&self, usd: f64) {
43        let mc = (usd * 100_000.0) as u64;
44        self.cost_millicents.fetch_add(mc, Ordering::Relaxed);
45    }
46
47    pub fn tokens_used(&self) -> u64 {
48        self.context_tokens.load(Ordering::Relaxed)
49    }
50
51    pub fn shell_used(&self) -> usize {
52        self.shell_invocations.load(Ordering::Relaxed)
53    }
54
55    pub fn cost_usd(&self) -> f64 {
56        self.cost_millicents.load(Ordering::Relaxed) as f64 / 100_000.0
57    }
58
59    pub fn reset(&self) {
60        self.context_tokens.store(0, Ordering::Relaxed);
61        self.shell_invocations.store(0, Ordering::Relaxed);
62        self.cost_millicents.store(0, Ordering::Relaxed);
63    }
64
65    pub fn check(&self) -> BudgetSnapshot {
66        let limits = roles::active_role().limits;
67        let role_name = roles::active_role_name();
68
69        let tokens = self.tokens_used();
70        let shell = self.shell_used();
71        let cost = self.cost_usd();
72
73        BudgetSnapshot {
74            role: role_name,
75            tokens: DimensionStatus::evaluate(tokens as usize, limits.max_context_tokens, &limits),
76            shell: DimensionStatus::evaluate(shell, limits.max_shell_invocations, &limits),
77            cost: CostStatus::evaluate(cost, limits.max_cost_usd, &limits),
78        }
79    }
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
83pub enum BudgetLevel {
84    Ok,
85    Warning,
86    Exhausted,
87}
88
89impl std::fmt::Display for BudgetLevel {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        match self {
92            Self::Ok => write!(f, "OK"),
93            Self::Warning => write!(f, "WARNING"),
94            Self::Exhausted => write!(f, "EXHAUSTED"),
95        }
96    }
97}
98
99#[derive(Debug, Clone, Serialize)]
100pub struct DimensionStatus {
101    pub used: usize,
102    pub limit: usize,
103    pub percent: u8,
104    pub level: BudgetLevel,
105}
106
107impl DimensionStatus {
108    fn evaluate(used: usize, limit: usize, limits: &RoleLimits) -> Self {
109        if limit == 0 {
110            return Self {
111                used,
112                limit,
113                percent: 0,
114                level: if used > 0 {
115                    BudgetLevel::Exhausted
116                } else {
117                    BudgetLevel::Ok
118                },
119            };
120        }
121        let percent = ((used as f64 / limit as f64) * 100.0).min(255.0) as u8;
122        let level = if percent >= limits.block_at_percent {
123            BudgetLevel::Exhausted
124        } else if percent >= limits.warn_at_percent {
125            BudgetLevel::Warning
126        } else {
127            BudgetLevel::Ok
128        };
129        Self {
130            used,
131            limit,
132            percent,
133            level,
134        }
135    }
136}
137
138#[derive(Debug, Clone, Serialize)]
139pub struct CostStatus {
140    pub used_usd: f64,
141    pub limit_usd: f64,
142    pub percent: u8,
143    pub level: BudgetLevel,
144}
145
146impl CostStatus {
147    fn evaluate(used: f64, limit: f64, limits: &RoleLimits) -> Self {
148        if limit <= 0.0 {
149            return Self {
150                used_usd: used,
151                limit_usd: limit,
152                percent: 0,
153                level: if used > 0.0 {
154                    BudgetLevel::Exhausted
155                } else {
156                    BudgetLevel::Ok
157                },
158            };
159        }
160        let pct = ((used / limit) * 100.0).min(255.0) as u8;
161        let level = if pct >= limits.block_at_percent {
162            BudgetLevel::Exhausted
163        } else if pct >= limits.warn_at_percent {
164            BudgetLevel::Warning
165        } else {
166            BudgetLevel::Ok
167        };
168        Self {
169            used_usd: used,
170            limit_usd: limit,
171            percent: pct,
172            level,
173        }
174    }
175}
176
177#[derive(Debug, Clone, Serialize)]
178pub struct BudgetSnapshot {
179    pub role: String,
180    pub tokens: DimensionStatus,
181    pub shell: DimensionStatus,
182    pub cost: CostStatus,
183}
184
185impl BudgetSnapshot {
186    pub fn worst_level(&self) -> &BudgetLevel {
187        for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
188            if *level == BudgetLevel::Exhausted {
189                return level;
190            }
191        }
192        for level in [&self.tokens.level, &self.shell.level, &self.cost.level] {
193            if *level == BudgetLevel::Warning {
194                return level;
195            }
196        }
197        &BudgetLevel::Ok
198    }
199
200    pub fn format_compact(&self) -> String {
201        format!(
202            "Budget[{}]: tokens {}/{} ({}%) | shell {}/{} ({}%) | cost ${:.2}/${:.2} ({}%) → {}",
203            self.role,
204            self.tokens.used,
205            self.tokens.limit,
206            self.tokens.percent,
207            self.shell.used,
208            self.shell.limit,
209            self.shell.percent,
210            self.cost.used_usd,
211            self.cost.limit_usd,
212            self.cost.percent,
213            self.worst_level(),
214        )
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    /// Evaluates an integer dimension under the default role limits.
    fn dimension(used: usize, limit: usize) -> DimensionStatus {
        DimensionStatus::evaluate(used, limit, &RoleLimits::default())
    }

    #[test]
    fn tracker_starts_at_zero() {
        let tracker = BudgetTracker::new();
        assert_eq!(tracker.tokens_used(), 0);
        assert_eq!(tracker.shell_used(), 0);
        assert!(close_to(tracker.cost_usd(), 0.0, f64::EPSILON));
    }

    #[test]
    fn record_and_read() {
        let tracker = BudgetTracker::new();
        for chunk in [5000, 3000] {
            tracker.record_tokens(chunk);
        }
        (0..2).for_each(|_| tracker.record_shell());
        tracker.record_cost_usd(0.50);

        assert_eq!(tracker.tokens_used(), 8000);
        assert_eq!(tracker.shell_used(), 2);
        assert!(close_to(tracker.cost_usd(), 0.50, 0.001));
    }

    #[test]
    fn reset_clears_all() {
        let tracker = BudgetTracker::new();
        tracker.record_tokens(10_000);
        tracker.record_shell();
        tracker.record_cost_usd(1.0);

        tracker.reset();

        assert_eq!(tracker.tokens_used(), 0);
        assert_eq!(tracker.shell_used(), 0);
        assert!(close_to(tracker.cost_usd(), 0.0, f64::EPSILON));
    }

    #[test]
    fn dimension_status_ok() {
        let status = dimension(50_000, 200_000);
        assert_eq!(status.level, BudgetLevel::Ok);
        assert_eq!(status.percent, 25);
    }

    #[test]
    fn dimension_status_warning() {
        let status = dimension(170_000, 200_000);
        assert_eq!(status.level, BudgetLevel::Warning);
        assert_eq!(status.percent, 85);
    }

    #[test]
    fn dimension_status_exhausted() {
        let status = dimension(200_000, 200_000);
        assert_eq!(status.level, BudgetLevel::Exhausted);
        assert_eq!(status.percent, 100);
    }

    #[test]
    fn zero_limit_blocks_usage() {
        assert_eq!(dimension(1, 0).level, BudgetLevel::Exhausted);
    }

    #[test]
    fn cost_status_warning() {
        let limits = RoleLimits::default();
        let status = CostStatus::evaluate(4.5, 5.0, &limits);
        assert_eq!(status.level, BudgetLevel::Warning);
    }

    #[test]
    fn snapshot_worst_level() {
        let limits = RoleLimits::default();
        let snapshot = BudgetSnapshot {
            role: "test".into(),
            tokens: DimensionStatus::evaluate(50_000, 200_000, &limits),
            shell: DimensionStatus::evaluate(90, 100, &limits),
            cost: CostStatus::evaluate(1.0, 5.0, &limits),
        };
        assert_eq!(*snapshot.worst_level(), BudgetLevel::Warning);
    }

    #[test]
    fn format_compact_includes_all() {
        let ok_dimension = |used: usize, limit: usize, percent: u8| DimensionStatus {
            used,
            limit,
            percent,
            level: BudgetLevel::Ok,
        };
        let snapshot = BudgetSnapshot {
            role: "coder".into(),
            tokens: ok_dimension(1000, 200_000, 0),
            shell: ok_dimension(5, 100, 5),
            cost: CostStatus {
                used_usd: 0.25,
                limit_usd: 5.0,
                percent: 5,
                level: BudgetLevel::Ok,
            },
        };

        let rendered = snapshot.format_compact();
        for needle in ["coder", "tokens", "shell", "cost", "OK"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }
}