1use chrono::{DateTime, Utc};
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use std::time::Duration;
13
14use crate::types::AgentId;
15
/// Budget configuration for a single agent: how many tokens and calls it may
/// consume within one accounting window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetLimit {
    /// Agent this limit applies to; also used as the key in the manager's maps.
    pub agent_id: AgentId,
    /// Maximum number of tokens that may be reserved per window.
    pub token_budget: u64,
    /// Maximum number of calls that may be tracked per window.
    pub calls_budget: u64,
    /// Length of the accounting window, in seconds.
    pub window_secs: u64,
}
28
/// Consumption counters for one agent within its current accounting window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens reserved so far in the current window.
    pub tokens_used: u64,
    /// Calls tracked so far in the current window.
    pub calls_used: u64,
    /// UTC timestamp at which the current window began.
    pub window_start: DateTime<Utc>,
}
39
/// Summary of what an agent has left, as returned by `BudgetManager::remaining`.
#[derive(Debug, Clone)]
pub struct BudgetInfo {
    /// Token limit minus tokens used (saturating at zero).
    pub tokens_remaining: u64,
    /// Call limit minus calls used (saturating at zero).
    pub calls_remaining: u64,
    /// Whole seconds left before the current window ends.
    pub window_remaining_secs: u64,
    /// `true` when either the token or the call allowance has reached zero.
    pub is_exhausted: bool,
}
52
/// Serializable report combining an agent's configured limits with its
/// current usage, as returned by `BudgetManager::full_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FullBudgetInfo {
    /// Agent the report describes.
    pub agent_id: AgentId,
    /// Configured token limit per window.
    pub token_limit: u64,
    /// Tokens consumed in the current window.
    pub tokens_used: u64,
    /// `token_limit` minus `tokens_used` (saturating at zero).
    pub tokens_remaining: u64,
    /// Configured call limit per window.
    pub calls_limit: u64,
    /// Calls consumed in the current window.
    pub calls_used: u64,
    /// `calls_limit` minus `calls_used` (saturating at zero).
    pub calls_remaining: u64,
    /// Configured window length, in seconds.
    pub window_secs: u64,
    /// Whole seconds left before the current window ends.
    pub window_remaining_secs: u64,
    /// `true` when either remaining allowance is zero.
    pub is_exhausted: bool,
}
78
/// Identifies which budget dimension a `BudgetExceeded` error refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BudgetKind {
    /// The token allowance (reported by `reserve`).
    Token,
    /// The call allowance (reported by `track_call`).
    Call,
}
87
/// Error returned when a reservation or call cannot be admitted, either
/// because the budget is used up or because no budget is configured.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Agent whose request was rejected.
    pub agent_id: AgentId,
    /// Budget dimension the rejection relates to.
    pub kind: BudgetKind,
    /// Human-readable explanation; this is what `Display` prints.
    pub message: String,
}
98
99impl std::fmt::Display for BudgetExceeded {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 write!(f, "{}", self.message)
102 }
103}
104
// Empty impl: relies on the trait's default methods so `BudgetExceeded` can be
// used wherever a `std::error::Error` is expected.
impl std::error::Error for BudgetExceeded {}
106
/// Tracks per-agent budget limits and usage, each in its own map guarded by a
/// read-write lock.
pub struct BudgetManager {
    /// Configured limits, keyed by agent id.
    budgets: RwLock<HashMap<AgentId, BudgetLimit>>,
    /// Usage counters for the current window, keyed by agent id.
    usage: RwLock<HashMap<AgentId, Usage>>,
}
112
113impl BudgetManager {
114 pub fn new() -> Self {
116 Self {
117 budgets: RwLock::new(HashMap::new()),
118 usage: RwLock::new(HashMap::new()),
119 }
120 }
121
122 pub fn set_budget(&self, limit: BudgetLimit) {
124 let agent_id = limit.agent_id;
125 let now = Utc::now();
126
127 {
128 let mut budgets = self.budgets.write();
129 budgets.insert(agent_id, limit);
130 }
131
132 let mut usage = self.usage.write();
134 usage.entry(agent_id).or_insert(Usage {
135 tokens_used: 0,
136 calls_used: 0,
137 window_start: now,
138 });
139 }
140
141 pub fn remove_budget(&self, agent_id: &AgentId) {
143 let mut budgets = self.budgets.write();
144 let mut usage = self.usage.write();
145 budgets.remove(agent_id);
146 usage.remove(agent_id);
147 }
148
149 pub fn reserve(&self, agent_id: &AgentId, tokens: u64) -> Result<(), BudgetExceeded> {
156 let limit = {
157 let budgets = self.budgets.read();
158 budgets.get(agent_id).cloned()
159 };
160
161 let limit = match limit {
162 Some(l) => l,
163 None => {
164 return Err(BudgetExceeded {
165 agent_id: *agent_id,
166 kind: BudgetKind::Token,
167 message: format!("No budget configured for agent {agent_id}"),
168 });
169 }
170 };
171
172 {
173 let mut usage = self.usage.write();
174 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
175 tokens_used: 0,
176 calls_used: 0,
177 window_start: Utc::now(),
178 });
179
180 reset_if_expired(usage_entry, limit.window_secs);
181
182 if usage_entry.tokens_used + tokens > limit.token_budget {
183 return Err(BudgetExceeded {
184 agent_id: *agent_id,
185 kind: BudgetKind::Token,
186 message: format!(
187 "Token budget exceeded: requested {} but only {} remaining",
188 tokens,
189 limit.token_budget.saturating_sub(usage_entry.tokens_used)
190 ),
191 });
192 }
193
194 usage_entry.tokens_used += tokens;
195 }
196
197 Ok(())
198 }
199
200 pub fn release(&self, agent_id: &AgentId, tokens_used: u64) {
204 let mut usage = self.usage.write();
205 if let Some(entry) = usage.get_mut(agent_id) {
206 entry.tokens_used = entry.tokens_used.saturating_sub(tokens_used);
207 }
208 }
209
210 pub fn track_call(&self, agent_id: &AgentId) -> Result<(), BudgetExceeded> {
214 let limit = {
215 let budgets = self.budgets.read();
216 budgets.get(agent_id).cloned()
217 };
218
219 let limit = match limit {
220 Some(l) => l,
221 None => {
222 return Err(BudgetExceeded {
223 agent_id: *agent_id,
224 kind: BudgetKind::Call,
225 message: format!("No budget configured for agent {agent_id}"),
226 });
227 }
228 };
229
230 {
231 let mut usage = self.usage.write();
232 let usage_entry = usage.entry(*agent_id).or_insert_with(|| Usage {
233 tokens_used: 0,
234 calls_used: 0,
235 window_start: Utc::now(),
236 });
237
238 reset_if_expired(usage_entry, limit.window_secs);
239
240 if usage_entry.calls_used >= limit.calls_budget {
241 return Err(BudgetExceeded {
242 agent_id: *agent_id,
243 kind: BudgetKind::Call,
244 message: format!(
245 "Call budget exceeded: {} calls used, limit is {}",
246 usage_entry.calls_used, limit.calls_budget
247 ),
248 });
249 }
250
251 usage_entry.calls_used += 1;
252 }
253
254 Ok(())
255 }
256
257 pub fn remaining(&self, agent_id: &AgentId) -> BudgetInfo {
259 let limit = {
260 let budgets = self.budgets.read();
261 budgets.get(agent_id).cloned()
262 };
263
264 match limit {
265 Some(limit) => {
266 let usage = self.usage.read();
267 let usage_entry = usage.get(agent_id);
268
269 if let Some(entry) = usage_entry {
270 let elapsed = Utc::now()
271 .signed_duration_since(entry.window_start)
272 .to_std()
273 .unwrap_or(Duration::ZERO);
274 let window_remaining = Duration::from_secs(limit.window_secs)
275 .saturating_sub(elapsed)
276 .as_secs();
277
278 let tokens_remaining = limit.token_budget.saturating_sub(entry.tokens_used);
279 let calls_remaining = limit.calls_budget.saturating_sub(entry.calls_used);
280 let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
281
282 BudgetInfo {
283 tokens_remaining,
284 calls_remaining,
285 window_remaining_secs: window_remaining,
286 is_exhausted,
287 }
288 } else {
289 BudgetInfo {
290 tokens_remaining: limit.token_budget,
291 calls_remaining: limit.calls_budget,
292 window_remaining_secs: limit.window_secs,
293 is_exhausted: false,
294 }
295 }
296 }
297 None => BudgetInfo {
298 tokens_remaining: 0,
299 calls_remaining: 0,
300 window_remaining_secs: 0,
301 is_exhausted: true,
302 },
303 }
304 }
305
306 pub fn can_schedule(&self, agent_id: &AgentId) -> bool {
308 !self.remaining(agent_id).is_exhausted
309 }
310
311 pub fn reset_window(&self, agent_id: &AgentId) {
313 let mut usage = self.usage.write();
314 if let Some(entry) = usage.get_mut(agent_id) {
315 entry.tokens_used = 0;
316 entry.calls_used = 0;
317 entry.window_start = Utc::now();
318 }
319 }
320
321 pub fn full_info(&self, agent_id: &AgentId) -> Option<FullBudgetInfo> {
325 let limit = self.budgets.read().get(agent_id).cloned()?;
326
327 let usage = self.usage.read().get(agent_id).cloned();
328 let (tokens_used, calls_used, window_remaining_secs) = if let Some(entry) = usage {
329 let elapsed = Utc::now()
330 .signed_duration_since(entry.window_start)
331 .to_std()
332 .unwrap_or(Duration::ZERO);
333 let window_duration = Duration::from_secs(limit.window_secs);
334 let window_remaining = window_duration.saturating_sub(elapsed).as_secs();
335 let elapsed_secs = elapsed.as_secs();
336
337 if window_remaining == 0 && elapsed_secs >= limit.window_secs {
338 (0u64, 0u64, 0u64)
339 } else {
340 (entry.tokens_used, entry.calls_used, window_remaining)
341 }
342 } else {
343 (0u64, 0u64, limit.window_secs)
344 };
345
346 let tokens_remaining = limit.token_budget.saturating_sub(tokens_used);
347 let calls_remaining = limit.calls_budget.saturating_sub(calls_used);
348 let is_exhausted = tokens_remaining == 0 || calls_remaining == 0;
349
350 Some(FullBudgetInfo {
351 agent_id: *agent_id,
352 token_limit: limit.token_budget,
353 tokens_used,
354 tokens_remaining,
355 calls_limit: limit.calls_budget,
356 calls_used,
357 calls_remaining,
358 window_secs: limit.window_secs,
359 window_remaining_secs,
360 is_exhausted,
361 })
362 }
363
364 pub fn all_full_info(&self) -> Vec<FullBudgetInfo> {
366 let budgets = self.budgets.read();
367 budgets.keys().filter_map(|id| self.full_info(id)).collect()
368 }
369
370 pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
372 let budgets = self.budgets.read();
373 let usage = self.usage.read();
374 let data = PersistedBudgets {
375 budgets: budgets.clone(),
376 usage: usage.clone(),
377 };
378 if let Some(parent) = path.parent() {
379 std::fs::create_dir_all(parent)?;
380 }
381 let json = serde_json::to_string_pretty(&data)?;
382 std::fs::write(path, json)?;
383 Ok(())
384 }
385
386 pub fn restore(&self, path: &Path) -> anyhow::Result<()> {
391 if !path.exists() {
392 return Ok(());
393 }
394 let json = std::fs::read_to_string(path)?;
395 let data: PersistedBudgets = serde_json::from_str(&json)?;
396 {
397 let mut budgets = self.budgets.write();
398 *budgets = data.budgets;
399 }
400 {
401 let mut usage = self.usage.write();
402 *usage = data.usage;
403 }
404 Ok(())
405 }
406}
407
/// On-disk representation written by `BudgetManager::persist` and read back
/// by `BudgetManager::restore`.
#[derive(Serialize, Deserialize)]
struct PersistedBudgets {
    /// Copy of the manager's limit map.
    budgets: HashMap<AgentId, BudgetLimit>,
    /// Copy of the manager's usage map.
    usage: HashMap<AgentId, Usage>,
}
414
415fn reset_if_expired(usage: &mut Usage, window_secs: u64) {
417 let window_duration = chrono::Duration::seconds(window_secs as i64);
418 let elapsed = Utc::now().signed_duration_since(usage.window_start);
419 if elapsed >= window_duration {
420 usage.tokens_used = 0;
421 usage.calls_used = 0;
422 usage.window_start = Utc::now();
423 }
424}
425
426impl Default for BudgetManager {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Produces a fresh v4 UUID to use as an agent id.
    fn test_agent_id() -> AgentId {
        uuid::Uuid::new_v4()
    }

    /// Registers a new agent with the given limits on `manager` and returns
    /// its id.
    fn register(
        manager: &BudgetManager,
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> AgentId {
        let agent_id = test_agent_id();
        manager.set_budget(BudgetLimit {
            agent_id,
            token_budget,
            calls_budget,
            window_secs,
        });
        agent_id
    }

    /// Builds a manager that already holds one agent with the given limits.
    fn manager_with_budget(
        token_budget: u64,
        calls_budget: u64,
        window_secs: u64,
    ) -> (BudgetManager, AgentId) {
        let manager = BudgetManager::new();
        let agent_id = register(&manager, token_budget, calls_budget, window_secs);
        (manager, agent_id)
    }

    #[test]
    fn test_budget_creation() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        let info = manager.remaining(&agent_id);
        assert_eq!(info.tokens_remaining, 1000);
        assert_eq!(info.calls_remaining, 10);
        assert!(!info.is_exhausted);
    }

    #[test]
    fn test_reserve_success() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        assert!(manager.reserve(&agent_id, 500).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);
    }

    #[test]
    fn test_exhaust_tokens() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        // The whole budget can be taken in one reservation...
        assert!(manager.reserve(&agent_id, 1000).is_ok());

        // ...after which even a single token is refused.
        let rejected = manager.reserve(&agent_id, 1);
        assert!(rejected.is_err());

        let err = rejected.unwrap_err();
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Token);
    }

    #[test]
    fn test_exhaust_calls() {
        let (manager, agent_id) = manager_with_budget(1000, 3, 60);

        for _ in 0..3 {
            assert!(manager.track_call(&agent_id).is_ok());
        }

        let rejected = manager.track_call(&agent_id);
        assert!(rejected.is_err());

        let err = rejected.unwrap_err();
        assert_eq!(err.agent_id, agent_id);
        assert_eq!(err.kind, BudgetKind::Call);
    }

    #[test]
    fn test_window_reset() {
        // One-second window so the test only needs a short sleep.
        let (manager, agent_id) = manager_with_budget(100, 5, 1);

        manager.reserve(&agent_id, 100).unwrap();
        assert!(manager.reserve(&agent_id, 1).is_err());

        thread::sleep(Duration::from_millis(1_100));

        // The window has expired, so the reservation succeeds again.
        assert!(manager.reserve(&agent_id, 50).is_ok());
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 50);
    }

    #[test]
    fn test_can_schedule() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);

        assert!(manager.can_schedule(&agent_id));

        (0..10).for_each(|_| manager.track_call(&agent_id).unwrap());

        assert!(!manager.can_schedule(&agent_id));
    }

    #[test]
    fn test_no_budget_configured() {
        let manager = BudgetManager::new();
        let agent_id = test_agent_id();

        let reserve_result = manager.reserve(&agent_id, 100);
        assert!(reserve_result.is_err());

        let err = reserve_result.unwrap_err();
        assert!(err.message.contains("No budget configured"));

        assert!(manager.track_call(&agent_id).is_err());
    }

    #[test]
    fn test_remove_budget() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 100).unwrap();

        manager.remove_budget(&agent_id);

        assert!(manager.reserve(&agent_id, 100).is_err());
    }

    #[test]
    fn test_release_tokens() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.release(&agent_id, 200);

        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 700);
    }

    #[test]
    fn test_reset_window() {
        let (manager, agent_id) = manager_with_budget(1000, 10, 60);
        manager.reserve(&agent_id, 500).unwrap();
        assert_eq!(manager.remaining(&agent_id).tokens_remaining, 500);

        manager.reset_window(&agent_id);

        let info_after = manager.remaining(&agent_id);
        assert_eq!(info_after.tokens_remaining, 1000);
        assert_eq!(info_after.calls_remaining, 10);
    }

    #[test]
    fn test_multiple_agents() {
        let manager = BudgetManager::new();
        let agent1 = register(&manager, 1000, 10, 60);
        let agent2 = register(&manager, 500, 5, 60);

        manager.reserve(&agent1, 300).unwrap();
        manager.reserve(&agent2, 200).unwrap();

        // Each agent's usage is tracked independently.
        assert_eq!(manager.remaining(&agent1).tokens_remaining, 700);
        assert_eq!(manager.remaining(&agent2).tokens_remaining, 300);
    }
}