1use chrono::{DateTime, Utc};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use std::time::Duration;
13
14use crate::types::AgentId;
15
/// Spending limits for a single agent, applied per accounting window.
///
/// Once `window_secs` seconds have elapsed since the agent's window started,
/// its usage counters are reset to zero on the next `reserve`/`track_call`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetLimit {
    /// Agent this limit applies to; also the key it is stored under by
    /// `BudgetManager::set_budget`.
    pub agent_id: AgentId,
    /// Maximum total tokens `BudgetManager::reserve` allows within one window.
    pub token_budget: u64,
    /// Maximum number of calls `BudgetManager::track_call` allows within one window.
    pub calls_budget: u64,
    /// Window length in seconds.
    pub window_secs: u64,
}
28
/// Consumption recorded for one agent in its current window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens reserved in this window, minus any returned via `release`.
    pub tokens_used: u64,
    /// Calls recorded by `track_call` in this window.
    pub calls_used: u64,
    /// Instant (UTC) at which the current window began.
    pub window_start: DateTime<Utc>,
}
39
/// Remaining allowance for one agent, as returned by `BudgetManager::remaining`.
#[derive(Debug, Clone)]
pub struct BudgetInfo {
    /// Tokens that can still be reserved in the current window.
    pub tokens_remaining: u64,
    /// Calls that can still be tracked in the current window.
    pub calls_remaining: u64,
    /// Whole seconds (truncated) until the current window expires.
    pub window_remaining_secs: u64,
    /// `true` when either allowance is zero, or when no budget is configured.
    pub is_exhausted: bool,
}
52
/// Complete budget report for one agent: limits, usage and what is left.
///
/// Produced by `BudgetManager::full_info`. When the window has already
/// expired, the usage fields read 0 because the next reservation starts a
/// fresh window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullBudgetInfo {
    /// Agent the report describes.
    pub agent_id: AgentId,
    /// Configured token budget per window.
    pub token_limit: u64,
    /// Tokens consumed in the current window.
    pub tokens_used: u64,
    /// `token_limit - tokens_used`, saturating at zero.
    pub tokens_remaining: u64,
    /// Configured call budget per window.
    pub calls_limit: u64,
    /// Calls consumed in the current window.
    pub calls_used: u64,
    /// `calls_limit - calls_used`, saturating at zero.
    pub calls_remaining: u64,
    /// Configured window length in seconds.
    pub window_secs: u64,
    /// Whole seconds (truncated) until the current window expires.
    pub window_remaining_secs: u64,
    /// `true` when either remaining allowance is zero.
    pub is_exhausted: bool,
}
78
/// Which budget a [`BudgetExceeded`] error refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetKind {
    /// Token budget; reported by `BudgetManager::reserve`, including its
    /// "no budget configured" error.
    Token,
    /// Call budget; reported by `BudgetManager::track_call`, including its
    /// "no budget configured" error.
    Call,
}
87
/// Error returned when a request would exceed an agent's budget, or when the
/// agent has no budget configured at all.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Agent whose request was rejected.
    pub agent_id: AgentId,
    /// Which budget was involved.
    pub kind: BudgetKind,
    /// Human-readable explanation; this is also the `Display` output.
    pub message: String,
}
98
99impl std::fmt::Display for BudgetExceeded {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "{}", self.message)
102 }
103}
104
105impl std::error::Error for BudgetExceeded {}
106
/// Registry of per-agent budgets and their current-window usage.
///
/// All methods take `&self`; shared state is guarded by two `RwLock`s.
pub struct BudgetManager {
    /// Configured limits, keyed by agent.
    budgets: RwLock<HashMap<AgentId, BudgetLimit>>,
    /// Current-window usage, keyed by agent. Code that holds both locks at once
    /// acquires `budgets` first, then `usage`.
    usage: RwLock<HashMap<AgentId, Usage>>,
}
112
113impl BudgetManager {
114 pub fn new() -> Self {
116 Self {
117 budgets: RwLock::new(HashMap::new()),
118 usage: RwLock::new(HashMap::new()),
119 }
120 }
121
122 pub fn set_budget(&self, limit: BudgetLimit) {
124 let agent_id = limit.agent_id;
125 let now = Utc::now();
126
127 {
128 let mut budgets = self.budgets.write();
129 budgets.insert(agent_id, limit);
130 }
131
132 let mut usage = self.usage.write();
134 usage.entry(agent_id).or_insert(Usage {
135 tokens_used: 0,
136 calls_used: 0,
137 window_start: now,
138 });
139 }
140
141 pub fn remove_budget(&self, agent_id: &AgentId) {
143 let mut budgets = self.budgets.write();
144 let mut usage = self.usage.write();
145 budgets.remove(agent_id);
146 usage.remove(agent_id);
147 }
148
149 pub fn reserve(&self, agent_id: &AgentId, tokens: u64) -> Result<(), BudgetExceeded> {
156 let limit = {
157 let budgets = self.budgets.read();
158 budgets.get(agent_id).cloned()
159 };
160
161 let limit = match limit {
162 Some(l) => l,
163 None => {
164 return Err(BudgetExceeded {
165 agent_id: *agent_id,
166 kind: BudgetKind::Token,
167 message: format!("No budget configured for agent {agent_id}"),
168 });
169 }
170 };
171
172 {
173 let mut usage = self.usage.write();
174 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
175 tokens_used: 0,
176 calls_used: 0,
177 window_start: Utc::now(),
178 });
179
180 reset_if_expired(usage_entry, limit.window_secs);
181
182 if usage_entry.tokens_used + tokens > limit.token_budget {
183 return Err(BudgetExceeded {
184 agent_id: *agent_id,
185 kind: BudgetKind::Token,
186 message: format!(
187 "Token budget exceeded: requested {} but only {} remaining",
188 tokens,
189 limit.token_budget.saturating_sub(usage_entry.tokens_used)
190 ),
191 });
192 }
193
194 usage_entry.tokens_used += tokens;
195 }
196
197 Ok(())
198 }
199
200 pub fn release(&self, agent_id: &AgentId, tokens_used: u64) {
204 let mut usage = self.usage.write();
205 if let Some(entry) = usage.get_mut(agent_id) {
206 entry.tokens_used = entry.tokens_used.saturating_sub(tokens_used);
207 }
208 }
209
210 pub fn track_call(&self, agent_id: &AgentId) -> Result<(), BudgetExceeded> {
214 let limit = {
215 let budgets = self.budgets.read();
216 budgets.get(agent_id).cloned()
217 };
218
219 let limit = match limit {
220 Some(l) => l,
221 None => {
222 return Err(BudgetExceeded {
223 agent_id: *agent_id,
224 kind: BudgetKind::Call,
225 message: format!("No budget configured for agent {agent_id}"),
226 });
227 }
228 };
229
230 {
231 let mut usage = self.usage.write();
232 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
233 tokens_used: 0,
234 calls_used: 0,
235 window_start: Utc::now(),
236 });
237
238 reset_if_expired(usage_entry, limit.window_secs);
239
240 if usage_entry.calls_used >= limit.calls_budget {
241 return Err(BudgetExceeded {
242 agent_id: *agent_id,
243 kind: BudgetKind::Call,
244 message: format!(
245 "Call budget exceeded: {} calls used, limit is {}",
246 usage_entry.calls_used, limit.calls_budget
247 ),
248 });
249 }
250
251 usage_entry.calls_used += 1;
252 }
253
254 Ok(())
255 }
256
257 pub fn remaining(&self, agent_id: &AgentId) -> BudgetInfo {
259 let limit = {
260 let budgets = self.budgets.read();
261 budgets.get(agent_id).cloned()
262 };
263
264 match limit {
265 Some(limit) => {
266 let usage = self.usage.read();
267 let usage_entry = usage.get(agent_id);
268
269 if let Some(entry) = usage_entry {
270 let elapsed = Utc::now()
271 .signed_duration_since(entry.window_start)
272 .to_std()
273 .unwrap_or(Duration::ZERO);
274 let window_expired = elapsed.as_secs() >= limit.window_secs;
275 let window_remaining = Duration::from_secs(limit.window_secs)
276 .saturating_sub(elapsed)
277 .as_secs();
278
279 let (tokens_remaining, calls_remaining) = if window_expired {
284 (limit.token_budget, limit.calls_budget)
285 } else {
286 (
287 limit.token_budget.saturating_sub(entry.tokens_used),
288 limit.calls_budget.saturating_sub(entry.calls_used),
289 )
290 };
291 let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
292
293 BudgetInfo {
294 tokens_remaining,
295 calls_remaining,
296 window_remaining_secs: window_remaining,
297 is_exhausted,
298 }
299 } else {
300 BudgetInfo {
301 tokens_remaining: limit.token_budget,
302 calls_remaining: limit.calls_budget,
303 window_remaining_secs: limit.window_secs,
304 is_exhausted: false,
305 }
306 }
307 }
308 None => BudgetInfo {
309 tokens_remaining: 0,
310 calls_remaining: 0,
311 window_remaining_secs: 0,
312 is_exhausted: true,
313 },
314 }
315 }
316
317 pub fn can_schedule(&self, agent_id: &AgentId) -> bool {
319 !self.remaining(agent_id).is_exhausted
320 }
321
322 pub fn reset_window(&self, agent_id: &AgentId) {
324 let mut usage = self.usage.write();
325 if let Some(entry) = usage.get_mut(agent_id) {
326 entry.tokens_used = 0;
327 entry.calls_used = 0;
328 entry.window_start = Utc::now();
329 }
330 }
331
332 pub fn full_info(&self, agent_id: &AgentId) -> Option<FullBudgetInfo> {
336 let limit = self.budgets.read().get(agent_id).cloned()?;
337
338 let usage = self.usage.read().get(agent_id).cloned();
339 let (tokens_used, calls_used, window_remaining_secs) = if let Some(entry) = usage {
340 let elapsed = Utc::now()
341 .signed_duration_since(entry.window_start)
342 .to_std()
343 .unwrap_or(Duration::ZERO);
344 let window_duration = Duration::from_secs(limit.window_secs);
345 let window_remaining = window_duration.saturating_sub(elapsed).as_secs();
346 let elapsed_secs = elapsed.as_secs();
347
348 if window_remaining == 0 && elapsed_secs >= limit.window_secs {
349 (0u64, 0u64, 0u64)
350 } else {
351 (entry.tokens_used, entry.calls_used, window_remaining)
352 }
353 } else {
354 (0u64, 0u64, limit.window_secs)
355 };
356
357 let tokens_remaining = limit.token_budget.saturating_sub(tokens_used);
358 let calls_remaining = limit.calls_budget.saturating_sub(calls_used);
359 let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
360
361 Some(FullBudgetInfo {
362 agent_id: *agent_id,
363 token_limit: limit.token_budget,
364 tokens_used,
365 tokens_remaining,
366 calls_limit: limit.calls_budget,
367 calls_used,
368 calls_remaining,
369 window_secs: limit.window_secs,
370 window_remaining_secs,
371 is_exhausted,
372 })
373 }
374
375 pub fn all_full_info(&self) -> Vec<FullBudgetInfo> {
377 let budgets = self.budgets.read();
378 budgets.keys().filter_map(|id| self.full_info(id)).collect()
379 }
380
381 pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
383 let budgets = self.budgets.read();
384 let usage = self.usage.read();
385 let data = PersistedBudgets {
386 budgets: budgets.clone(),
387 usage: usage.clone(),
388 };
389 if let Some(parent) = path.parent() {
390 std::fs::create_dir_all(parent)?;
391 }
392 let json = serde_json::to_string_pretty(&data)?;
393 std::fs::write(path, json)?;
394 Ok(())
395 }
396
397 pub fn restore(&self, path: &Path) -> anyhow::Result<()> {
402 if !path.exists() {
403 return Ok(());
404 }
405 let json = std::fs::read_to_string(path)?;
406 let data: PersistedBudgets = serde_json::from_str(&json)?;
407 {
408 let mut budgets = self.budgets.write();
409 *budgets = data.budgets;
410 }
411 {
412 let mut usage = self.usage.write();
413 *usage = data.usage;
414 }
415 Ok(())
416 }
417}
418
/// On-disk JSON shape written by `BudgetManager::persist` and read back by
/// `BudgetManager::restore`.
#[derive(Serialize, Deserialize)]
struct PersistedBudgets {
    /// Snapshot of the configured limits.
    budgets: HashMap<AgentId, BudgetLimit>,
    /// Snapshot of current-window usage.
    usage: HashMap<AgentId, Usage>,
}
425
426fn reset_if_expired(usage: &mut Usage, window_secs: u64) {
428 let window_duration = chrono::Duration::seconds(window_secs as i64);
429 let elapsed = Utc::now().signed_duration_since(usage.window_start);
430 if elapsed >= window_duration {
431 usage.tokens_used = 0;
432 usage.calls_used = 0;
433 usage.window_start = Utc::now();
434 }
435}
436
437impl Default for BudgetManager {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Fresh random agent id, so tests never share state.
    fn test_agent_id() -> AgentId {
        uuid::Uuid::new_v4()
    }

    /// Builds a manager with exactly one budget, for a brand-new agent, and
    /// returns both.
    fn manager_with_budget(
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> (BudgetManager, AgentId) {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();
        manager.set_budget(BudgetLimit {
            agent_id,
            token_budget,
            calls_budget,
            window_secs,
        });
        (manager, agent_id)
    }

    #[test]
    fn test_budget_creation() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
        assert!(!info.is_exhausted);
    }

    #[test]
    fn test_reserve_success() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        assert!(manager.reserve(&agent_id, 500).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
    }

    #[test]
    fn test_exhaust_tokens() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        // Spending exactly the budget is fine; a single extra token is not.
        assert!(manager.reserve(&agent_id, 1000).is_ok());
        let err = manager
            .reserve(&agent_id, 1)
            .expect_err("reserving past the token budget must fail");
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Token);
    }

    #[test]
    fn test_exhaust_calls() {
        let (manager, agent_id) = manager_with_budget(1000, 3, 60);

        for _ in 0..3 {
            assert!(manager.track_call(&agent_id).is_ok());
        }
        let err = manager
            .track_call(&agent_id)
            .expect_err("a fourth call must exceed a budget of three");
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Call);
    }

    #[test]
    fn test_window_reset() {
        // One-second window keeps the wait short.
        let (manager, agent_id) = manager_with_budget(100, 5, 1);

        manager.reserve(&agent_id, 100).unwrap();
        assert!(manager.reserve(&agent_id, 1).is_err());

        // Sleep past the window so the next reservation starts a new one.
        thread::sleep(Duration::from_millis(1_100));

        assert!(manager.reserve(&agent_id, 50).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 50);
    }

    #[test]
    fn test_can_schedule() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        assert!(manager.can_schedule(&agent_id));

        (0..10).for_each(|_| manager.track_call(&agent_id).unwrap());

        // Call budget used up: not schedulable even though tokens remain.
        assert!(!manager.can_schedule(&agent_id));
    }

    #[test]
    fn test_no_budget_configured() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();

        let err = manager
            .reserve(&agent_id, 100)
            .expect_err("reserve without a budget must fail");
        assert!(err.message.contains("No budget configured"));

        assert!(manager.track_call(&agent_id).is_err());
    }

    #[test]
    fn test_remove_budget() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 100).unwrap();

        manager.remove_budget(&agent_id);

        assert!(manager.reserve(&agent_id, 100).is_err());
    }

    #[test]
    fn test_release_tokens() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.release(&agent_id, 200);

        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 700);
    }

    #[test]
    fn test_reset_window() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.reset_window(&agent_id);

        let after = manager.remaining(&agent_id);
        assert_eq!(after.tokens_remaining, 1000);
        assert_eq!(after.calls_remaining, 10);
    }

    #[test]
    fn test_multiple_agents() {
        let manager = BudgetManager::new();
        let agent1 = test_agent_id();
        let agent2 = test_agent_id();

        manager.set_budget(BudgetLimit {
            agent_id: agent1,
            token_budget: 1000,
            calls_budget: 10,
            window_secs: 60,
        });
        manager.set_budget(BudgetLimit {
            agent_id: agent2,
            token_budget: 500,
            calls_budget: 5,
            window_secs: 60,
        });

        manager.reserve(&agent1, 300).unwrap();
        manager.reserve(&agent2, 200).unwrap();

        // Each agent is charged only against its own budget.
        assert_eq!(manager.remaining(&agent1).tokens_remaining, 700);
        assert_eq!(manager.remaining(&agent2).tokens_remaining, 300);
    }
}