1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use parking_lot::RwLock;
14
15#[derive(Debug)]
20pub struct AgentTokenBudget {
21 agent_id: String,
23
24 total_tokens: u64,
26
27 used_tokens: AtomicU64,
29
30 operation_costs: HashMap<String, u64>,
32
33 period: Duration,
35
36 last_reset: RwLock<Instant>,
38
39 warning_threshold: f64,
41
42 hard_limit: bool,
44}
45
46impl AgentTokenBudget {
47 pub fn daily(agent_id: impl Into<String>, tokens: u64) -> Self {
49 Self::new(agent_id, tokens, Duration::from_secs(86400))
50 }
51
52 pub fn hourly(agent_id: impl Into<String>, tokens: u64) -> Self {
54 Self::new(agent_id, tokens, Duration::from_secs(3600))
55 }
56
57 pub fn new(agent_id: impl Into<String>, tokens: u64, period: Duration) -> Self {
59 Self {
60 agent_id: agent_id.into(),
61 total_tokens: tokens,
62 used_tokens: AtomicU64::new(0),
63 operation_costs: Self::default_operation_costs(),
64 period,
65 last_reset: RwLock::new(Instant::now()),
66 warning_threshold: 0.8,
67 hard_limit: true,
68 }
69 }
70
71 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
73 self.warning_threshold = threshold.clamp(0.0, 1.0);
74 self
75 }
76
77 pub fn with_hard_limit(mut self, hard: bool) -> Self {
79 self.hard_limit = hard;
80 self
81 }
82
83 pub fn with_operation_costs(mut self, costs: HashMap<String, u64>) -> Self {
85 self.operation_costs = costs;
86 self
87 }
88
89 pub fn add_operation_cost(&mut self, operation: impl Into<String>, cost: u64) {
91 self.operation_costs.insert(operation.into(), cost);
92 }
93
94 pub fn consume(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
96 self.maybe_reset();
97
98 let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
99 let total_cost = cost.saturating_mul(estimated_tokens);
100
101 let used = self.used_tokens.fetch_add(total_cost, Ordering::SeqCst);
102
103 if self.hard_limit && used + total_cost > self.total_tokens {
104 self.used_tokens.fetch_sub(total_cost, Ordering::SeqCst);
106
107 return Err(BudgetExceeded {
108 agent_id: self.agent_id.clone(),
109 requested: total_cost,
110 remaining: self.total_tokens.saturating_sub(used),
111 total: self.total_tokens,
112 resets_in: self.time_until_reset(),
113 });
114 }
115
116 Ok(())
117 }
118
119 pub fn check(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
121 self.maybe_reset();
122
123 let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
124 let total_cost = cost.saturating_mul(estimated_tokens);
125 let used = self.used_tokens.load(Ordering::SeqCst);
126
127 if used + total_cost > self.total_tokens {
128 return Err(BudgetExceeded {
129 agent_id: self.agent_id.clone(),
130 requested: total_cost,
131 remaining: self.total_tokens.saturating_sub(used),
132 total: self.total_tokens,
133 resets_in: self.time_until_reset(),
134 });
135 }
136
137 Ok(())
138 }
139
140 pub fn remaining(&self) -> u64 {
142 self.maybe_reset();
143 let used = self.used_tokens.load(Ordering::SeqCst);
144 self.total_tokens.saturating_sub(used)
145 }
146
147 pub fn used(&self) -> u64 {
149 self.maybe_reset();
150 self.used_tokens.load(Ordering::SeqCst)
151 }
152
153 pub fn usage_percentage(&self) -> f64 {
155 self.maybe_reset();
156 let used = self.used_tokens.load(Ordering::SeqCst);
157 used as f64 / self.total_tokens as f64
158 }
159
160 pub fn is_warning(&self) -> bool {
162 self.usage_percentage() >= self.warning_threshold
163 }
164
165 pub fn time_until_reset(&self) -> Duration {
167 let last = *self.last_reset.read();
168 let elapsed = last.elapsed();
169
170 if elapsed >= self.period {
171 Duration::ZERO
172 } else {
173 self.period - elapsed
174 }
175 }
176
177 pub fn reset(&self) {
179 self.used_tokens.store(0, Ordering::SeqCst);
180 *self.last_reset.write() = Instant::now();
181 }
182
183 fn maybe_reset(&self) {
185 let last = *self.last_reset.read();
186 if last.elapsed() >= self.period {
187 self.reset();
188 }
189 }
190
191 fn default_operation_costs() -> HashMap<String, u64> {
192 let mut costs = HashMap::new();
193 costs.insert("query".to_string(), 1);
194 costs.insert("embedding".to_string(), 5);
195 costs.insert("vector_search".to_string(), 10);
196 costs.insert("write".to_string(), 2);
197 costs.insert("transaction".to_string(), 3);
198 costs
199 }
200}
201
202impl Clone for AgentTokenBudget {
203 fn clone(&self) -> Self {
204 Self {
205 agent_id: self.agent_id.clone(),
206 total_tokens: self.total_tokens,
207 used_tokens: AtomicU64::new(self.used_tokens.load(Ordering::Relaxed)),
208 operation_costs: self.operation_costs.clone(),
209 period: self.period,
210 last_reset: RwLock::new(*self.last_reset.read()),
211 warning_threshold: self.warning_threshold,
212 hard_limit: self.hard_limit,
213 }
214 }
215}
216
/// Error returned when a token charge does not fit in an agent's budget.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Agent whose budget was exceeded.
    pub agent_id: String,

    /// Tokens the rejected request would have charged.
    pub requested: u64,

    /// Tokens still available when the request was rejected.
    pub remaining: u64,

    /// Budget size per period.
    pub total: u64,

    /// Time until the budget period resets.
    pub resets_in: Duration,
}

// Human-readable, single-line description of the error.
impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Token budget exceeded for agent '{}': requested {} tokens, {} remaining of {} total, resets in {}s",
            self.agent_id,
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }
}

impl std::error::Error for BudgetExceeded {}

impl BudgetExceeded {
    /// Renders the error as a compact JSON object string.
    ///
    /// `agent_id` is escaped, so ids containing quotes, backslashes or control
    /// characters still produce valid JSON.
    pub fn to_llm_message(&self) -> String {
        format!(
            "{{\"error\": \"budget_exceeded\", \"message\": \"Token budget exceeded\", \
             \"details\": {{\"agent_id\": \"{}\", \"requested\": {}, \"remaining\": {}, \
             \"total\": {}, \"resets_in_seconds\": {}}}, \
             \"suggestion\": \"Wait for budget reset or request a higher allocation\"}}",
            Self::json_escape(&self.agent_id),
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }

    /// Escapes `raw` for embedding inside a JSON string literal (RFC 8259 §7).
    fn json_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Remaining C0 control characters must use the \uXXXX form.
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }
}
268
269#[derive(Debug)]
273pub struct WorkflowQuota {
274 max_workflows: u32,
276
277 max_steps: u32,
279
280 workflow_count: AtomicU32,
282
283 period: Duration,
285
286 last_reset: RwLock<Instant>,
288
289 active_workflows: DashMap<String, WorkflowToken>,
291}
292
293impl WorkflowQuota {
294 pub fn hourly(max_workflows: u32, max_steps: u32) -> Self {
296 Self::new(max_workflows, max_steps, Duration::from_secs(3600))
297 }
298
299 pub fn new(max_workflows: u32, max_steps: u32, period: Duration) -> Self {
301 Self {
302 max_workflows,
303 max_steps,
304 workflow_count: AtomicU32::new(0),
305 period,
306 last_reset: RwLock::new(Instant::now()),
307 active_workflows: DashMap::new(),
308 }
309 }
310
311 pub fn begin_workflow(
313 &self,
314 workflow_id: impl Into<String>,
315 ) -> Result<WorkflowToken, QuotaExceeded> {
316 self.maybe_reset();
317
318 let count = self.workflow_count.fetch_add(1, Ordering::SeqCst);
319 if count >= self.max_workflows {
320 self.workflow_count.fetch_sub(1, Ordering::SeqCst);
321 return Err(QuotaExceeded::HourlyLimit {
322 current: count,
323 limit: self.max_workflows,
324 resets_in: self.time_until_reset(),
325 });
326 }
327
328 let id = workflow_id.into();
329 let token = WorkflowToken::new(id.clone(), self.max_steps);
330 self.active_workflows.insert(id, token.clone());
331
332 Ok(token)
333 }
334
335 pub fn end_workflow(&self, workflow_id: &str) {
337 self.active_workflows.remove(workflow_id);
338 }
339
340 pub fn active_count(&self) -> usize {
342 self.active_workflows.len()
343 }
344
345 pub fn period_count(&self) -> u32 {
347 self.maybe_reset();
348 self.workflow_count.load(Ordering::SeqCst)
349 }
350
351 pub fn remaining(&self) -> u32 {
353 self.maybe_reset();
354 let count = self.workflow_count.load(Ordering::SeqCst);
355 self.max_workflows.saturating_sub(count)
356 }
357
358 pub fn time_until_reset(&self) -> Duration {
360 let last = *self.last_reset.read();
361 let elapsed = last.elapsed();
362
363 if elapsed >= self.period {
364 Duration::ZERO
365 } else {
366 self.period - elapsed
367 }
368 }
369
370 pub fn reset(&self) {
372 self.workflow_count.store(0, Ordering::SeqCst);
373 *self.last_reset.write() = Instant::now();
374 }
375
376 fn maybe_reset(&self) {
377 let last = *self.last_reset.read();
378 if last.elapsed() >= self.period {
379 self.reset();
380 }
381 }
382}
383
384impl Clone for WorkflowQuota {
385 fn clone(&self) -> Self {
386 Self {
387 max_workflows: self.max_workflows,
388 max_steps: self.max_steps,
389 workflow_count: AtomicU32::new(self.workflow_count.load(Ordering::Relaxed)),
390 period: self.period,
391 last_reset: RwLock::new(*self.last_reset.read()),
392 active_workflows: DashMap::new(),
393 }
394 }
395}
396
397#[derive(Debug)]
401pub struct WorkflowToken {
402 pub id: String,
404
405 remaining_steps: AtomicU32,
407
408 max_steps: u32,
410
411 steps_executed: AtomicU32,
413
414 created_at: Instant,
416}
417
418impl Clone for WorkflowToken {
419 fn clone(&self) -> Self {
420 Self {
421 id: self.id.clone(),
422 remaining_steps: AtomicU32::new(self.remaining_steps.load(Ordering::Relaxed)),
423 max_steps: self.max_steps,
424 steps_executed: AtomicU32::new(self.steps_executed.load(Ordering::Relaxed)),
425 created_at: self.created_at,
426 }
427 }
428}
429
430impl WorkflowToken {
431 fn new(id: String, max_steps: u32) -> Self {
432 Self {
433 id,
434 remaining_steps: AtomicU32::new(max_steps),
435 max_steps,
436 steps_executed: AtomicU32::new(0),
437 created_at: Instant::now(),
438 }
439 }
440
441 pub fn execute_step(&self) -> Result<(), QuotaExceeded> {
443 let remaining = self.remaining_steps.fetch_sub(1, Ordering::SeqCst);
444
445 if remaining == 0 {
446 self.remaining_steps.fetch_add(1, Ordering::SeqCst); return Err(QuotaExceeded::StepLimit {
448 workflow_id: self.id.clone(),
449 steps_executed: self.steps_executed.load(Ordering::SeqCst),
450 max_steps: self.max_steps,
451 });
452 }
453
454 self.steps_executed.fetch_add(1, Ordering::SeqCst);
455 Ok(())
456 }
457
458 pub fn remaining_steps(&self) -> u32 {
460 self.remaining_steps.load(Ordering::SeqCst)
461 }
462
463 pub fn steps_executed(&self) -> u32 {
465 self.steps_executed.load(Ordering::SeqCst)
466 }
467
468 pub fn duration(&self) -> Duration {
470 self.created_at.elapsed()
471 }
472
473 pub fn can_continue(&self) -> bool {
475 self.remaining_steps.load(Ordering::SeqCst) > 0
476 }
477}
478
/// Error returned when a workflow quota or a per-workflow step limit is hit.
#[derive(Debug, Clone)]
pub enum QuotaExceeded {
    /// Too many workflows were started in the current quota period. Despite the
    /// name, the period is whatever the quota was configured with.
    HourlyLimit {
        /// Workflows already started in the current period.
        current: u32,
        /// Maximum workflows per period.
        limit: u32,
        /// Time until the period resets.
        resets_in: Duration,
    },

    /// A workflow ran out of steps.
    StepLimit {
        /// Workflow whose allowance is exhausted.
        workflow_id: String,
        /// Steps executed before the limit was hit.
        steps_executed: u32,
        /// Step allowance of the workflow.
        max_steps: u32,
    },
}

// Human-readable, single-line description of the error.
impl std::fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuotaExceeded::HourlyLimit {
                current,
                limit,
                resets_in,
            } => write!(
                f,
                "Hourly workflow limit exceeded: {}/{} workflows, resets in {}s",
                current,
                limit,
                resets_in.as_secs()
            ),
            QuotaExceeded::StepLimit {
                workflow_id,
                steps_executed,
                max_steps,
            } => write!(
                f,
                "Workflow '{}' step limit exceeded: {}/{} steps",
                workflow_id, steps_executed, max_steps
            ),
        }
    }
}

impl std::error::Error for QuotaExceeded {}

impl QuotaExceeded {
    /// Renders the error as a compact JSON object string.
    ///
    /// `workflow_id` is escaped, so ids containing quotes, backslashes or control
    /// characters still produce valid JSON.
    pub fn to_llm_message(&self) -> String {
        match self {
            QuotaExceeded::HourlyLimit {
                current,
                limit,
                resets_in,
            } => format!(
                "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"hourly_limit\", \
                 \"current\": {}, \"limit\": {}, \"resets_in_seconds\": {}, \
                 \"suggestion\": \"Wait for quota reset or optimize workflow count\"}}",
                current,
                limit,
                resets_in.as_secs()
            ),
            QuotaExceeded::StepLimit {
                workflow_id,
                steps_executed,
                max_steps,
            } => format!(
                "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"step_limit\", \
                 \"workflow_id\": \"{}\", \"steps_executed\": {}, \"max_steps\": {}, \
                 \"suggestion\": \"Complete current workflow before starting more steps\"}}",
                Self::json_escape(workflow_id),
                steps_executed,
                max_steps
            ),
        }
    }

    /// Escapes `raw` for embedding inside a JSON string literal (RFC 8259 §7).
    fn json_escape(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Remaining C0 control characters must use the \uXXXX form.
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }
}
563
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_budget_creation() {
        let budget = AgentTokenBudget::daily("agent-1", 10000);
        assert_eq!(budget.remaining(), 10000);
        assert_eq!(budget.used(), 0);
    }

    #[test]
    fn test_token_budget_consume() {
        let budget = AgentTokenBudget::daily("agent-1", 100);

        assert!(budget.consume("query", 10).is_ok());
        assert_eq!(budget.used(), 10);
        assert_eq!(budget.remaining(), 90);
    }

    #[test]
    fn test_token_budget_exceeded() {
        let budget = AgentTokenBudget::daily("agent-1", 10);

        assert!(budget.consume("query", 5).is_ok());
        assert!(budget.consume("query", 5).is_ok());

        let result = budget.consume("query", 1);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(err.agent_id, "agent-1");
        assert_eq!(err.remaining, 0);
        // A rejected request must not be charged.
        assert_eq!(budget.used(), 10);
    }

    #[test]
    fn test_token_budget_operation_costs() {
        let budget = AgentTokenBudget::daily("agent-1", 1000);

        // "embedding" has a default multiplier of 5.
        assert!(budget.consume("embedding", 10).is_ok());
        assert_eq!(budget.used(), 50);
    }

    #[test]
    fn test_token_budget_soft_limit() {
        let budget = AgentTokenBudget::daily("agent-1", 100).with_hard_limit(false);

        assert!(budget.consume("query", 150).is_ok());
        assert_eq!(budget.used(), 150);
        assert_eq!(budget.remaining(), 0);
        assert!(budget.is_warning());
    }

    #[test]
    fn test_token_budget_check_does_not_consume() {
        let budget = AgentTokenBudget::daily("agent-1", 100);

        // "vector_search" has a default multiplier of 10.
        assert!(budget.check("vector_search", 5).is_ok());
        assert_eq!(budget.used(), 0);

        let err = budget.check("vector_search", 11).unwrap_err();
        assert_eq!(err.requested, 110);
        assert_eq!(err.remaining, 100);
        assert_eq!(budget.used(), 0);
    }

    #[test]
    fn test_token_budget_warning() {
        let budget = AgentTokenBudget::daily("agent-1", 100).with_warning_threshold(0.8);

        assert!(!budget.is_warning());

        assert!(budget.consume("query", 85).is_ok());
        assert!(budget.is_warning());
    }

    #[test]
    fn test_token_budget_reset() {
        // An explicit reset restores the full budget without waiting for the
        // period. A long period keeps the assertions independent of timing.
        let budget = AgentTokenBudget::hourly("agent-1", 100);

        assert!(budget.consume("query", 100).is_ok());
        assert_eq!(budget.remaining(), 0);

        budget.reset();
        assert_eq!(budget.remaining(), 100);
    }

    #[test]
    fn test_token_budget_auto_reset() {
        // A zero-length period has always elapsed, so every accessor starts a new
        // period. This exercises the automatic reset deterministically, instead of
        // racing a short period against sleeps on a possibly slow machine.
        let budget = AgentTokenBudget::new("agent-1", 100, Duration::ZERO);

        assert!(budget.consume("query", 100).is_ok());
        assert_eq!(budget.remaining(), 100);
    }

    #[test]
    fn test_budget_exceeded_llm_message() {
        let err = BudgetExceeded {
            agent_id: "agent-1".to_string(),
            requested: 100,
            remaining: 50,
            total: 1000,
            resets_in: Duration::from_secs(3600),
        };

        let msg = err.to_llm_message();
        assert!(msg.contains("budget_exceeded"));
        assert!(msg.contains("agent-1"));
    }

    #[test]
    fn test_workflow_quota_creation() {
        let quota = WorkflowQuota::hourly(10, 100);
        assert_eq!(quota.remaining(), 10);
    }

    #[test]
    fn test_workflow_quota_begin() {
        let quota = WorkflowQuota::hourly(10, 100);

        let token = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(token.remaining_steps(), 100);
        assert_eq!(quota.remaining(), 9);
    }

    #[test]
    fn test_workflow_quota_exceeded() {
        let quota = WorkflowQuota::hourly(2, 100);

        assert!(quota.begin_workflow("wf-1").is_ok());
        assert!(quota.begin_workflow("wf-2").is_ok());

        let result = quota.begin_workflow("wf-3");
        assert!(matches!(
            result,
            Err(QuotaExceeded::HourlyLimit {
                current: 2,
                limit: 2,
                ..
            })
        ));
        // A rejected request must not consume a slot.
        assert_eq!(quota.period_count(), 2);
    }

    #[test]
    fn test_workflow_quota_reset() {
        let quota = WorkflowQuota::hourly(1, 10);

        assert!(quota.begin_workflow("wf-1").is_ok());
        assert!(quota.begin_workflow("wf-2").is_err());

        quota.reset();
        assert_eq!(quota.remaining(), 1);
        assert!(quota.begin_workflow("wf-2").is_ok());
    }

    #[test]
    fn test_workflow_token_steps() {
        let quota = WorkflowQuota::hourly(10, 5);
        let token = quota.begin_workflow("wf-1").unwrap();

        for _ in 0..5 {
            assert!(token.execute_step().is_ok());
        }

        let result = token.execute_step();
        assert!(matches!(
            result,
            Err(QuotaExceeded::StepLimit {
                steps_executed: 5,
                max_steps: 5,
                ..
            })
        ));
        // The failed attempt must leave the allowance exhausted, not refilled.
        assert_eq!(token.remaining_steps(), 0);
        assert_eq!(token.steps_executed(), 5);
    }

    #[test]
    fn test_workflow_token_can_continue() {
        let quota = WorkflowQuota::hourly(10, 2);
        let token = quota.begin_workflow("wf-1").unwrap();

        assert!(token.can_continue());

        assert!(token.execute_step().is_ok());
        assert!(token.can_continue());

        assert!(token.execute_step().is_ok());
        assert!(!token.can_continue());
    }

    #[test]
    fn test_quota_exceeded_llm_message() {
        let err = QuotaExceeded::HourlyLimit {
            current: 10,
            limit: 10,
            resets_in: Duration::from_secs(1800),
        };

        let msg = err.to_llm_message();
        assert!(msg.contains("workflow_quota_exceeded"));
        assert!(msg.contains("hourly_limit"));

        let err2 = QuotaExceeded::StepLimit {
            workflow_id: "wf-1".to_string(),
            steps_executed: 100,
            max_steps: 100,
        };

        let msg2 = err2.to_llm_message();
        assert!(msg2.contains("step_limit"));
    }

    #[test]
    fn test_workflow_end() {
        let quota = WorkflowQuota::hourly(10, 100);

        let _token = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(quota.active_count(), 1);

        quota.end_workflow("wf-1");
        assert_eq!(quota.active_count(), 0);
    }
}