1use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::warn;
9
/// Running totals of token consumption and spend, accumulated one session at a time.
#[derive(Debug, Clone, Default)]
pub struct TokenUsageTracker {
    /// Sum of the input tokens passed to [`TokenUsageTracker::update`].
    pub total_input_tokens: u64,
    /// Sum of the output tokens passed to [`TokenUsageTracker::update`].
    pub total_output_tokens: u64,
    /// Sum of the costs (in USD) passed to [`TokenUsageTracker::update`].
    pub total_cost_usd: f64,
    /// Number of `update` calls recorded since creation or the last `reset`.
    pub session_count: usize,
}

impl TokenUsageTracker {
    /// Creates a tracker with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns input plus output tokens.
    ///
    /// The sum saturates at `u64::MAX` instead of overflowing.
    pub fn total_tokens(&self) -> u64 {
        self.total_input_tokens
            .saturating_add(self.total_output_tokens)
    }

    /// Returns the mean number of tokens per recorded session, or `0.0` when
    /// no session has been recorded yet.
    pub fn avg_tokens_per_session(&self) -> f64 {
        if self.session_count == 0 {
            0.0
        } else {
            self.total_tokens() as f64 / self.session_count as f64
        }
    }

    /// Returns the mean cost (USD) per recorded session, or `0.0` when no
    /// session has been recorded yet.
    pub fn avg_cost_per_session(&self) -> f64 {
        if self.session_count == 0 {
            0.0
        } else {
            self.total_cost_usd / self.session_count as f64
        }
    }

    /// Adds one session's usage to the running totals.
    ///
    /// Integer counters saturate at their maximum value rather than panicking
    /// (debug builds) or wrapping around to a small number (release builds),
    /// which would make an exhausted budget look almost unused.
    pub fn update(&mut self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
        self.total_input_tokens = self.total_input_tokens.saturating_add(input_tokens);
        self.total_output_tokens = self.total_output_tokens.saturating_add(output_tokens);
        self.total_cost_usd += cost_usd;
        self.session_count = self.session_count.saturating_add(1);
    }

    /// Clears every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
68
69#[derive(Debug, Clone)]
71pub struct BudgetLimit {
72 pub max_cost_usd: Option<f64>,
74 pub max_tokens: Option<u64>,
76 pub warning_threshold: f64,
78}
79
80impl Default for BudgetLimit {
81 fn default() -> Self {
82 Self {
83 max_cost_usd: None,
84 max_tokens: None,
85 warning_threshold: 0.8,
86 }
87 }
88}
89
90impl BudgetLimit {
91 pub fn with_cost(max_cost_usd: f64) -> Self {
93 Self {
94 max_cost_usd: Some(max_cost_usd),
95 ..Default::default()
96 }
97 }
98
99 pub fn with_tokens(max_tokens: u64) -> Self {
101 Self {
102 max_tokens: Some(max_tokens),
103 ..Default::default()
104 }
105 }
106
107 pub fn with_both(max_cost_usd: f64, max_tokens: u64) -> Self {
109 Self {
110 max_cost_usd: Some(max_cost_usd),
111 max_tokens: Some(max_tokens),
112 warning_threshold: 0.8,
113 }
114 }
115
116 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
118 self.warning_threshold = threshold.clamp(0.0, 1.0);
119 self
120 }
121
122 pub fn check_limits(&self, usage: &TokenUsageTracker) -> BudgetStatus {
124 let mut status = BudgetStatus::Ok;
125
126 if let Some(max_cost) = self.max_cost_usd {
128 let cost_ratio = usage.total_cost_usd / max_cost;
129
130 if cost_ratio >= 1.0 {
131 status = BudgetStatus::Exceeded;
132 } else if cost_ratio >= self.warning_threshold {
133 status = BudgetStatus::Warning {
134 current_ratio: cost_ratio,
135 message: format!(
136 "Cost usage at {:.1}% (${:.2}/${:.2})",
137 cost_ratio * 100.0,
138 usage.total_cost_usd,
139 max_cost
140 ),
141 };
142 }
143 }
144
145 if let Some(max_tokens) = self.max_tokens {
147 let token_ratio = usage.total_tokens() as f64 / max_tokens as f64;
148
149 if token_ratio >= 1.0 {
150 status = BudgetStatus::Exceeded;
151 } else if token_ratio >= self.warning_threshold {
152 if !matches!(status, BudgetStatus::Exceeded) {
154 status = BudgetStatus::Warning {
155 current_ratio: token_ratio,
156 message: format!(
157 "Token usage at {:.1}% ({}/{})",
158 token_ratio * 100.0,
159 usage.total_tokens(),
160 max_tokens
161 ),
162 };
163 }
164 }
165 }
166
167 status
168 }
169}
170
171#[derive(Debug, Clone, PartialEq)]
173pub enum BudgetStatus {
174 Ok,
176 Warning {
178 current_ratio: f64,
180 message: String,
182 },
183 Exceeded,
185}
186
187pub type BudgetWarningCallback = Arc<dyn Fn(&str) + Send + Sync>;
189
190#[derive(Clone)]
192pub struct BudgetManager {
193 tracker: Arc<RwLock<TokenUsageTracker>>,
194 limit: Arc<RwLock<Option<BudgetLimit>>>,
195 on_warning: Arc<RwLock<Option<BudgetWarningCallback>>>,
196 warning_fired: Arc<RwLock<bool>>,
197}
198
199impl BudgetManager {
200 pub fn new() -> Self {
202 Self {
203 tracker: Arc::new(RwLock::new(TokenUsageTracker::new())),
204 limit: Arc::new(RwLock::new(None)),
205 on_warning: Arc::new(RwLock::new(None)),
206 warning_fired: Arc::new(RwLock::new(false)),
207 }
208 }
209
210 pub async fn set_limit(&self, limit: BudgetLimit) {
212 *self.limit.write().await = Some(limit);
213 *self.warning_fired.write().await = false;
214 }
215
216 pub async fn set_warning_callback(&self, callback: BudgetWarningCallback) {
218 *self.on_warning.write().await = Some(callback);
219 }
220
221 pub async fn clear_limit(&self) {
223 *self.limit.write().await = None;
224 *self.warning_fired.write().await = false;
225 }
226
227 pub async fn get_usage(&self) -> TokenUsageTracker {
229 self.tracker.read().await.clone()
230 }
231
232 pub async fn update_usage(&self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
234 self.tracker.write().await.update(input_tokens, output_tokens, cost_usd);
236
237 if let Some(limit) = self.limit.read().await.as_ref() {
239 let usage = self.tracker.read().await.clone();
240 let status = limit.check_limits(&usage);
241
242 match status {
243 BudgetStatus::Warning { message, .. } => {
244 let mut fired = self.warning_fired.write().await;
245 if !*fired {
246 *fired = true;
247 warn!("Budget warning: {}", message);
248
249 if let Some(callback) = self.on_warning.read().await.as_ref() {
250 callback(&message);
251 }
252 }
253 }
254 BudgetStatus::Exceeded => {
255 warn!("Budget exceeded! Current usage: {} tokens, ${:.2}",
256 usage.total_tokens(), usage.total_cost_usd);
257
258 if let Some(callback) = self.on_warning.read().await.as_ref() {
259 callback("Budget limit exceeded");
260 }
261 }
262 BudgetStatus::Ok => {
263 *self.warning_fired.write().await = false;
265 }
266 }
267 }
268 }
269
270 pub async fn reset_usage(&self) {
272 self.tracker.write().await.reset();
273 *self.warning_fired.write().await = false;
274 }
275
276 pub async fn is_exceeded(&self) -> bool {
278 if let Some(limit) = self.limit.read().await.as_ref() {
279 let usage = self.tracker.read().await.clone();
280 matches!(limit.check_limits(&usage), BudgetStatus::Exceeded)
281 } else {
282 false
283 }
284 }
285}
286
287impl Default for BudgetManager {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing accumulated dollar amounts.
    const COST_EPSILON: f64 = 1e-9;

    /// Asserts two costs are equal up to floating-point rounding; sums of
    /// decimal fractions such as 0.05 + 0.02 are not guaranteed to be
    /// bit-identical to the literal 0.07.
    fn assert_cost_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < COST_EPSILON,
            "expected cost {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_tracker_basics() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.total_tokens(), 0);
        assert_cost_eq(tracker.total_cost_usd, 0.0);

        tracker.update(100, 200, 0.05);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.total_tokens(), 300);
        assert_cost_eq(tracker.total_cost_usd, 0.05);
        assert_eq!(tracker.session_count, 1);

        tracker.update(50, 100, 0.02);
        assert_eq!(tracker.total_tokens(), 450);
        assert_cost_eq(tracker.total_cost_usd, 0.07);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_budget_limits() {
        let limit = BudgetLimit::with_cost(1.0).with_warning_threshold(0.8);

        let mut tracker = TokenUsageTracker::new();
        tracker.update(100, 200, 0.5);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Ok));

        tracker.update(100, 200, 0.35);
        assert!(matches!(
            limit.check_limits(&tracker),
            BudgetStatus::Warning { .. }
        ));

        tracker.update(100, 200, 0.2);
        assert!(matches!(
            limit.check_limits(&tracker),
            BudgetStatus::Exceeded
        ));
    }

    #[tokio::test]
    async fn test_budget_manager() {
        let manager = BudgetManager::new();

        manager.set_limit(BudgetLimit::with_tokens(1000)).await;
        manager.update_usage(300, 200, 0.05).await;

        let usage = manager.get_usage().await;
        assert_eq!(usage.total_tokens(), 500);

        assert!(!manager.is_exceeded().await);

        manager.update_usage(300, 300, 0.05).await;
        assert!(manager.is_exceeded().await);
    }
}