1use chrono::{DateTime, Utc};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use std::time::Duration;
13
14use crate::types::AgentId;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct BudgetLimit {
19 pub agent_id: AgentId,
21 pub token_budget: u64,
23 pub calls_budget: u64,
25 pub window_secs: u64,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Usage {
32 pub tokens_used: u64,
34 pub calls_used: u64,
36 pub window_start: DateTime<Utc>,
38}
39
/// Read-only snapshot of how much of an agent's budget is left.
#[derive(Debug, Clone)]
pub struct BudgetInfo {
    /// Tokens that can still be reserved in the current window.
    pub tokens_remaining: u64,
    /// Calls that can still be tracked in the current window.
    pub calls_remaining: u64,
    /// Seconds left before the current window ends.
    pub window_remaining_secs: u64,
    /// `true` when no further work can be charged to the agent right now.
    pub is_exhausted: bool,
}
52
/// Which budget dimension a [`BudgetExceeded`] error refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetKind {
    /// The per-window token allowance.
    Token,
    /// The per-window call allowance.
    Call,
}
61
62#[derive(Debug, Clone)]
64pub struct BudgetExceeded {
65 pub agent_id: AgentId,
67 pub kind: BudgetKind,
69 pub message: String,
71}
72
73impl std::fmt::Display for BudgetExceeded {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 write!(f, "{}", self.message)
76 }
77}
78
79impl std::error::Error for BudgetExceeded {}
80
81pub struct BudgetManager {
83 budgets: RwLock<HashMap<AgentId, BudgetLimit>>,
84 usage: RwLock<HashMap<AgentId, Usage>>,
85}
86
87impl BudgetManager {
88 pub fn new() -> Self {
90 Self {
91 budgets: RwLock::new(HashMap::new()),
92 usage: RwLock::new(HashMap::new()),
93 }
94 }
95
96 pub fn set_budget(&self, limit: BudgetLimit) {
98 let agent_id = limit.agent_id;
99 let now = Utc::now();
100
101 {
102 let mut budgets = self.budgets.write();
103 budgets.insert(agent_id, limit);
104 }
105
106 let mut usage = self.usage.write();
108 usage.entry(agent_id).or_insert(Usage {
109 tokens_used: 0,
110 calls_used: 0,
111 window_start: now,
112 });
113 }
114
115 pub fn remove_budget(&self, agent_id: &AgentId) {
117 let mut budgets = self.budgets.write();
118 let mut usage = self.usage.write();
119 budgets.remove(agent_id);
120 usage.remove(agent_id);
121 }
122
123 pub fn reserve(&self, agent_id: &AgentId, tokens: u64) -> Result<(), BudgetExceeded> {
130 let limit = {
131 let budgets = self.budgets.read();
132 budgets.get(agent_id).cloned()
133 };
134
135 let limit = match limit {
136 Some(l) => l,
137 None => {
138 return Err(BudgetExceeded {
139 agent_id: *agent_id,
140 kind: BudgetKind::Token,
141 message: format!("No budget configured for agent {}", agent_id),
142 });
143 }
144 };
145
146 {
147 let mut usage = self.usage.write();
148 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
149 tokens_used: 0,
150 calls_used: 0,
151 window_start: Utc::now(),
152 });
153
154 reset_if_expired(usage_entry, limit.window_secs);
155
156 if usage_entry.tokens_used + tokens > limit.token_budget {
157 return Err(BudgetExceeded {
158 agent_id: *agent_id,
159 kind: BudgetKind::Token,
160 message: format!(
161 "Token budget exceeded: requested {} but only {} remaining",
162 tokens,
163 limit.token_budget.saturating_sub(usage_entry.tokens_used)
164 ),
165 });
166 }
167
168 usage_entry.tokens_used += tokens;
169 }
170
171 Ok(())
172 }
173
174 pub fn release(&self, agent_id: &AgentId, tokens_used: u64) {
178 let mut usage = self.usage.write();
179 if let Some(entry) = usage.get_mut(agent_id) {
180 entry.tokens_used = entry.tokens_used.saturating_sub(tokens_used);
181 }
182 }
183
184 pub fn track_call(&self, agent_id: &AgentId) -> Result<(), BudgetExceeded> {
188 let limit = {
189 let budgets = self.budgets.read();
190 budgets.get(agent_id).cloned()
191 };
192
193 let limit = match limit {
194 Some(l) => l,
195 None => {
196 return Err(BudgetExceeded {
197 agent_id: *agent_id,
198 kind: BudgetKind::Call,
199 message: format!("No budget configured for agent {}", agent_id),
200 });
201 }
202 };
203
204 {
205 let mut usage = self.usage.write();
206 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
207 tokens_used: 0,
208 calls_used: 0,
209 window_start: Utc::now(),
210 });
211
212 reset_if_expired(usage_entry, limit.window_secs);
213
214 if usage_entry.calls_used >= limit.calls_budget {
215 return Err(BudgetExceeded {
216 agent_id: *agent_id,
217 kind: BudgetKind::Call,
218 message: format!(
219 "Call budget exceeded: {} calls used, limit is {}",
220 usage_entry.calls_used, limit.calls_budget
221 ),
222 });
223 }
224
225 usage_entry.calls_used += 1;
226 }
227
228 Ok(())
229 }
230
231 pub fn remaining(&self, agent_id: &AgentId) -> BudgetInfo {
233 let limit = {
234 let budgets = self.budgets.read();
235 budgets.get(agent_id).cloned()
236 };
237
238 match limit {
239 Some(limit) => {
240 let usage = self.usage.read();
241 let usage_entry = usage.get(agent_id);
242
243 if let Some(entry) = usage_entry {
244 let elapsed = Utc::now()
245 .signed_duration_since(entry.window_start)
246 .to_std()
247 .unwrap_or(Duration::ZERO);
248 let window_remaining = Duration::from_secs(limit.window_secs)
249 .saturating_sub(elapsed)
250 .as_secs();
251
252 let tokens_remaining = limit.token_budget.saturating_sub(entry.tokens_used);
253 let calls_remaining = limit.calls_budget.saturating_sub(entry.calls_used);
254 let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
255
256 BudgetInfo {
257 tokens_remaining,
258 calls_remaining,
259 window_remaining_secs: window_remaining,
260 is_exhausted,
261 }
262 } else {
263 BudgetInfo {
264 tokens_remaining: limit.token_budget,
265 calls_remaining: limit.calls_budget,
266 window_remaining_secs: limit.window_secs,
267 is_exhausted: false,
268 }
269 }
270 }
271 None => BudgetInfo {
272 tokens_remaining: 0,
273 calls_remaining: 0,
274 window_remaining_secs: 0,
275 is_exhausted: true,
276 },
277 }
278 }
279
280 pub fn can_schedule(&self, agent_id: &AgentId) -> bool {
282 !self.remaining(agent_id).is_exhausted
283 }
284
285 pub fn reset_window(&self, agent_id: &AgentId) {
287 let mut usage = self.usage.write();
288 if let Some(entry) = usage.get_mut(agent_id) {
289 entry.tokens_used = 0;
290 entry.calls_used = 0;
291 entry.window_start = Utc::now();
292 }
293 }
294
295 pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
297 let budgets = self.budgets.read();
298 let usage = self.usage.read();
299 let data = PersistedBudgets {
300 budgets: budgets.clone(),
301 usage: usage.clone(),
302 };
303 if let Some(parent) = path.parent() {
304 std::fs::create_dir_all(parent)?;
305 }
306 let json = serde_json::to_string_pretty(&data)?;
307 std::fs::write(path, json)?;
308 Ok(())
309 }
310
311 pub fn restore(&self, path: &Path) -> anyhow::Result<()> {
316 if !path.exists() {
317 return Ok(());
318 }
319 let json = std::fs::read_to_string(path)?;
320 let data: PersistedBudgets = serde_json::from_str(&json)?;
321 {
322 let mut budgets = self.budgets.write();
323 *budgets = data.budgets;
324 }
325 {
326 let mut usage = self.usage.write();
327 *usage = data.usage;
328 }
329 Ok(())
330 }
331}
332
333#[derive(Serialize, Deserialize)]
335struct PersistedBudgets {
336 budgets: HashMap<AgentId, BudgetLimit>,
337 usage: HashMap<AgentId, Usage>,
338}
339
340fn reset_if_expired(usage: &mut Usage, window_secs: u64) {
342 let window_duration = chrono::Duration::seconds(window_secs as i64);
343 let elapsed = Utc::now().signed_duration_since(usage.window_start);
344 if elapsed >= window_duration {
345 usage.tokens_used = 0;
346 usage.calls_used = 0;
347 usage.window_start = Utc::now();
348 }
349}
350
351impl Default for BudgetManager {
352 fn default() -> Self {
353 Self::new()
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// A fresh random agent id, so tests never share state.
    fn test_agent_id() -> AgentId {
        uuid::Uuid::new_v4()
    }

    /// Builds a `BudgetLimit` for `agent_id` with the given allowances.
    fn make_limit(
        agent_id: AgentId,
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> BudgetLimit {
        BudgetLimit {
            agent_id,
            token_budget,
            calls_budget,
            window_secs,
        }
    }

    /// A new manager with one freshly created agent configured with the given limits.
    fn manager_with(
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> (BudgetManager, AgentId) {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(make_limit(agent_id, token_budget, calls_budget, window_secs));
        (manager, agent_id)
    }

    #[test]
    fn test_budget_creation() {
        let (manager, agent_id) = manager_with(1000, 10, 60);

        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
        assert!(!info.is_exhausted);
    }

    #[test]
    fn test_reserve_success() {
        let (manager, agent_id) = manager_with(1000, 10, 60);

        assert!(manager.reserve(&agent_id, 500).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
    }

    #[test]
    fn test_exhaust_tokens() {
        let (manager, agent_id) = manager_with(1000, 10, 60);

        assert!(manager.reserve(&agent_id, 1000).is_ok());

        let err = manager
            .reserve(&agent_id, 1)
            .expect_err("reserving past the token budget must fail");
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Token);
    }

    #[test]
    fn test_exhaust_calls() {
        let (manager, agent_id) = manager_with(1000, 3, 60);

        for _ in 0..3 {
            assert!(manager.track_call(&agent_id).is_ok());
        }

        let err = manager
            .track_call(&agent_id)
            .expect_err("a fourth call must exceed a budget of three");
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Call);
    }

    #[test]
    fn test_window_reset() {
        // A one-second window, so the test can simply wait it out.
        let (manager, agent_id) = manager_with(100, 5, 1);

        manager.reserve(&agent_id, 100).unwrap();
        assert!(manager.reserve(&agent_id, 1).is_err());

        thread::sleep(Duration::from_secs(2));

        assert!(manager.reserve(&agent_id, 50).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 50);
    }

    #[test]
    fn test_can_schedule() {
        let (manager, agent_id) = manager_with(1000, 10, 60);
        assert!(manager.can_schedule(&agent_id));

        (0..10).for_each(|_| manager.track_call(&agent_id).unwrap());

        assert!(!manager.can_schedule(&agent_id));
    }

    #[test]
    fn test_no_budget_configured() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();

        let err = manager
            .reserve(&agent_id, 100)
            .expect_err("an unconfigured agent cannot reserve tokens");
        assert!(err.message.contains("No budget configured"));

        assert!(manager.track_call(&agent_id).is_err());
    }

    #[test]
    fn test_remove_budget() {
        let (manager, agent_id) = manager_with(1000, 10, 60);
        manager.reserve(&agent_id, 100).unwrap();

        manager.remove_budget(&agent_id);

        assert!(manager.reserve(&agent_id, 100).is_err());
    }

    #[test]
    fn test_release_tokens() {
        let (manager, agent_id) = manager_with(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.release(&agent_id, 200);

        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 700);
    }

    #[test]
    fn test_reset_window() {
        let (manager, agent_id) = manager_with(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.reset_window(&agent_id);

        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
    }

    #[test]
    fn test_multiple_agents() {
        let manager = BudgetManager::new();
        let (agent1, agent2) = (test_agent_id(), test_agent_id());

        manager.set_budget(make_limit(agent1, 1000, 10, 60));
        manager.set_budget(make_limit(agent2, 500, 5, 60));

        manager.reserve(&agent1, 300).unwrap();
        manager.reserve(&agent2, 200).unwrap();

        assert_eq!(manager.remaining(&agent1).tokens_remaining, 700);
        assert_eq!(manager.remaining(&agent2).tokens_remaining, 300);
    }
}