1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use parking_lot::RwLock;
14
15#[derive(Debug)]
20pub struct AgentTokenBudget {
21 agent_id: String,
23
24 total_tokens: u64,
26
27 used_tokens: AtomicU64,
29
30 operation_costs: HashMap<String, u64>,
32
33 period: Duration,
35
36 last_reset: RwLock<Instant>,
38
39 warning_threshold: f64,
41
42 hard_limit: bool,
44}
45
46impl AgentTokenBudget {
47 pub fn daily(agent_id: impl Into<String>, tokens: u64) -> Self {
49 Self::new(agent_id, tokens, Duration::from_secs(86400))
50 }
51
52 pub fn hourly(agent_id: impl Into<String>, tokens: u64) -> Self {
54 Self::new(agent_id, tokens, Duration::from_secs(3600))
55 }
56
57 pub fn new(agent_id: impl Into<String>, tokens: u64, period: Duration) -> Self {
59 Self {
60 agent_id: agent_id.into(),
61 total_tokens: tokens,
62 used_tokens: AtomicU64::new(0),
63 operation_costs: Self::default_operation_costs(),
64 period,
65 last_reset: RwLock::new(Instant::now()),
66 warning_threshold: 0.8,
67 hard_limit: true,
68 }
69 }
70
71 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
73 self.warning_threshold = threshold.clamp(0.0, 1.0);
74 self
75 }
76
77 pub fn with_hard_limit(mut self, hard: bool) -> Self {
79 self.hard_limit = hard;
80 self
81 }
82
83 pub fn with_operation_costs(mut self, costs: HashMap<String, u64>) -> Self {
85 self.operation_costs = costs;
86 self
87 }
88
89 pub fn add_operation_cost(&mut self, operation: impl Into<String>, cost: u64) {
91 self.operation_costs.insert(operation.into(), cost);
92 }
93
94 pub fn consume(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
96 self.maybe_reset();
97
98 let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
99 let total_cost = cost.saturating_mul(estimated_tokens);
100
101 let used = self.used_tokens.fetch_add(total_cost, Ordering::SeqCst);
102
103 if self.hard_limit && used + total_cost > self.total_tokens {
104 self.used_tokens.fetch_sub(total_cost, Ordering::SeqCst);
106
107 return Err(BudgetExceeded {
108 agent_id: self.agent_id.clone(),
109 requested: total_cost,
110 remaining: self.total_tokens.saturating_sub(used),
111 total: self.total_tokens,
112 resets_in: self.time_until_reset(),
113 });
114 }
115
116 Ok(())
117 }
118
119 pub fn check(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
121 self.maybe_reset();
122
123 let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
124 let total_cost = cost.saturating_mul(estimated_tokens);
125 let used = self.used_tokens.load(Ordering::SeqCst);
126
127 if used + total_cost > self.total_tokens {
128 return Err(BudgetExceeded {
129 agent_id: self.agent_id.clone(),
130 requested: total_cost,
131 remaining: self.total_tokens.saturating_sub(used),
132 total: self.total_tokens,
133 resets_in: self.time_until_reset(),
134 });
135 }
136
137 Ok(())
138 }
139
140 pub fn remaining(&self) -> u64 {
142 self.maybe_reset();
143 let used = self.used_tokens.load(Ordering::SeqCst);
144 self.total_tokens.saturating_sub(used)
145 }
146
147 pub fn used(&self) -> u64 {
149 self.maybe_reset();
150 self.used_tokens.load(Ordering::SeqCst)
151 }
152
153 pub fn usage_percentage(&self) -> f64 {
155 self.maybe_reset();
156 let used = self.used_tokens.load(Ordering::SeqCst);
157 used as f64 / self.total_tokens as f64
158 }
159
160 pub fn is_warning(&self) -> bool {
162 self.usage_percentage() >= self.warning_threshold
163 }
164
165 pub fn time_until_reset(&self) -> Duration {
167 let last = *self.last_reset.read();
168 let elapsed = last.elapsed();
169
170 if elapsed >= self.period {
171 Duration::ZERO
172 } else {
173 self.period - elapsed
174 }
175 }
176
177 pub fn reset(&self) {
179 self.used_tokens.store(0, Ordering::SeqCst);
180 *self.last_reset.write() = Instant::now();
181 }
182
183 fn maybe_reset(&self) {
185 let last = *self.last_reset.read();
186 if last.elapsed() >= self.period {
187 self.reset();
188 }
189 }
190
191 fn default_operation_costs() -> HashMap<String, u64> {
192 let mut costs = HashMap::new();
193 costs.insert("query".to_string(), 1);
194 costs.insert("embedding".to_string(), 5);
195 costs.insert("vector_search".to_string(), 10);
196 costs.insert("write".to_string(), 2);
197 costs.insert("transaction".to_string(), 3);
198 costs
199 }
200}
201
202impl Clone for AgentTokenBudget {
203 fn clone(&self) -> Self {
204 Self {
205 agent_id: self.agent_id.clone(),
206 total_tokens: self.total_tokens,
207 used_tokens: AtomicU64::new(self.used_tokens.load(Ordering::Relaxed)),
208 operation_costs: self.operation_costs.clone(),
209 period: self.period,
210 last_reset: RwLock::new(*self.last_reset.read()),
211 warning_threshold: self.warning_threshold,
212 hard_limit: self.hard_limit,
213 }
214 }
215}
216
/// Error describing a request that did not fit into an agent's token budget.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Agent whose budget was exhausted.
    pub agent_id: String,

    /// Tokens the rejected request asked for (after cost scaling).
    pub requested: u64,

    /// Tokens that were still available when the request was rejected.
    pub remaining: u64,

    /// Size of the budget per period.
    pub total: u64,

    /// Time until the budget refills.
    pub resets_in: Duration,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Token budget exceeded for agent '{}': requested {} tokens, {} remaining of {} total, resets in {}s",
            self.agent_id,
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }
}

impl std::error::Error for BudgetExceeded {}

impl BudgetExceeded {
    /// Renders the error as a single-line JSON object.
    ///
    /// The agent id is escaped, so the result stays valid JSON even when the
    /// id contains quotes, backslashes or control characters.
    pub fn to_llm_message(&self) -> String {
        format!(
            "{{\"error\": \"budget_exceeded\", \"message\": \"Token budget exceeded\", \
             \"details\": {{\"agent_id\": \"{}\", \"requested\": {}, \"remaining\": {}, \
             \"total\": {}, \"resets_in_seconds\": {}}}, \
             \"suggestion\": \"Wait for budget reset or request a higher allocation\"}}",
            Self::escape_json(&self.agent_id),
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }

    /// Escapes `raw` for use inside a JSON string literal.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must be written as \u00XX.
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }
}
268
269#[derive(Debug)]
273pub struct WorkflowQuota {
274 max_workflows: u32,
276
277 max_steps: u32,
279
280 workflow_count: AtomicU32,
282
283 period: Duration,
285
286 last_reset: RwLock<Instant>,
288
289 active_workflows: DashMap<String, WorkflowToken>,
291}
292
293impl WorkflowQuota {
294 pub fn hourly(max_workflows: u32, max_steps: u32) -> Self {
296 Self::new(max_workflows, max_steps, Duration::from_secs(3600))
297 }
298
299 pub fn new(max_workflows: u32, max_steps: u32, period: Duration) -> Self {
301 Self {
302 max_workflows,
303 max_steps,
304 workflow_count: AtomicU32::new(0),
305 period,
306 last_reset: RwLock::new(Instant::now()),
307 active_workflows: DashMap::new(),
308 }
309 }
310
311 pub fn begin_workflow(&self, workflow_id: impl Into<String>) -> Result<WorkflowToken, QuotaExceeded> {
313 self.maybe_reset();
314
315 let count = self.workflow_count.fetch_add(1, Ordering::SeqCst);
316 if count >= self.max_workflows {
317 self.workflow_count.fetch_sub(1, Ordering::SeqCst);
318 return Err(QuotaExceeded::HourlyLimit {
319 current: count,
320 limit: self.max_workflows,
321 resets_in: self.time_until_reset(),
322 });
323 }
324
325 let id = workflow_id.into();
326 let token = WorkflowToken::new(id.clone(), self.max_steps);
327 self.active_workflows.insert(id, token.clone());
328
329 Ok(token)
330 }
331
332 pub fn end_workflow(&self, workflow_id: &str) {
334 self.active_workflows.remove(workflow_id);
335 }
336
337 pub fn active_count(&self) -> usize {
339 self.active_workflows.len()
340 }
341
342 pub fn period_count(&self) -> u32 {
344 self.maybe_reset();
345 self.workflow_count.load(Ordering::SeqCst)
346 }
347
348 pub fn remaining(&self) -> u32 {
350 self.maybe_reset();
351 let count = self.workflow_count.load(Ordering::SeqCst);
352 self.max_workflows.saturating_sub(count)
353 }
354
355 pub fn time_until_reset(&self) -> Duration {
357 let last = *self.last_reset.read();
358 let elapsed = last.elapsed();
359
360 if elapsed >= self.period {
361 Duration::ZERO
362 } else {
363 self.period - elapsed
364 }
365 }
366
367 pub fn reset(&self) {
369 self.workflow_count.store(0, Ordering::SeqCst);
370 *self.last_reset.write() = Instant::now();
371 }
372
373 fn maybe_reset(&self) {
374 let last = *self.last_reset.read();
375 if last.elapsed() >= self.period {
376 self.reset();
377 }
378 }
379}
380
381impl Clone for WorkflowQuota {
382 fn clone(&self) -> Self {
383 Self {
384 max_workflows: self.max_workflows,
385 max_steps: self.max_steps,
386 workflow_count: AtomicU32::new(self.workflow_count.load(Ordering::Relaxed)),
387 period: self.period,
388 last_reset: RwLock::new(*self.last_reset.read()),
389 active_workflows: DashMap::new(),
390 }
391 }
392}
393
394#[derive(Debug)]
398pub struct WorkflowToken {
399 pub id: String,
401
402 remaining_steps: AtomicU32,
404
405 max_steps: u32,
407
408 steps_executed: AtomicU32,
410
411 created_at: Instant,
413}
414
415impl Clone for WorkflowToken {
416 fn clone(&self) -> Self {
417 Self {
418 id: self.id.clone(),
419 remaining_steps: AtomicU32::new(self.remaining_steps.load(Ordering::Relaxed)),
420 max_steps: self.max_steps,
421 steps_executed: AtomicU32::new(self.steps_executed.load(Ordering::Relaxed)),
422 created_at: self.created_at,
423 }
424 }
425}
426
427impl WorkflowToken {
428 fn new(id: String, max_steps: u32) -> Self {
429 Self {
430 id,
431 remaining_steps: AtomicU32::new(max_steps),
432 max_steps,
433 steps_executed: AtomicU32::new(0),
434 created_at: Instant::now(),
435 }
436 }
437
438 pub fn execute_step(&self) -> Result<(), QuotaExceeded> {
440 let remaining = self.remaining_steps.fetch_sub(1, Ordering::SeqCst);
441
442 if remaining == 0 {
443 self.remaining_steps.fetch_add(1, Ordering::SeqCst); return Err(QuotaExceeded::StepLimit {
445 workflow_id: self.id.clone(),
446 steps_executed: self.steps_executed.load(Ordering::SeqCst),
447 max_steps: self.max_steps,
448 });
449 }
450
451 self.steps_executed.fetch_add(1, Ordering::SeqCst);
452 Ok(())
453 }
454
455 pub fn remaining_steps(&self) -> u32 {
457 self.remaining_steps.load(Ordering::SeqCst)
458 }
459
460 pub fn steps_executed(&self) -> u32 {
462 self.steps_executed.load(Ordering::SeqCst)
463 }
464
465 pub fn duration(&self) -> Duration {
467 self.created_at.elapsed()
468 }
469
470 pub fn can_continue(&self) -> bool {
472 self.remaining_steps.load(Ordering::SeqCst) > 0
473 }
474}
475
/// Error returned when a workflow quota is exhausted.
#[derive(Debug, Clone)]
pub enum QuotaExceeded {
    /// The per-period workflow limit has been reached.
    HourlyLimit {
        /// Workflows already started in the current period.
        current: u32,
        /// Maximum number of workflows per period.
        limit: u32,
        /// Time until the quota refills.
        resets_in: Duration,
    },

    /// A single workflow has used up its step allowance.
    StepLimit {
        /// Workflow that hit the limit.
        workflow_id: String,
        /// Steps the workflow has executed.
        steps_executed: u32,
        /// Step allowance of the workflow.
        max_steps: u32,
    },
}

impl std::fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuotaExceeded::HourlyLimit { current, limit, resets_in } => {
                write!(
                    f,
                    "Hourly workflow limit exceeded: {}/{} workflows, resets in {}s",
                    current,
                    limit,
                    resets_in.as_secs()
                )
            }
            QuotaExceeded::StepLimit { workflow_id, steps_executed, max_steps } => {
                write!(
                    f,
                    "Workflow '{}' step limit exceeded: {}/{} steps",
                    workflow_id, steps_executed, max_steps
                )
            }
        }
    }
}

impl std::error::Error for QuotaExceeded {}

impl QuotaExceeded {
    /// Renders the error as a single-line JSON object.
    ///
    /// The workflow id is escaped, so the result stays valid JSON even when
    /// the id contains quotes, backslashes or control characters.
    pub fn to_llm_message(&self) -> String {
        match self {
            QuotaExceeded::HourlyLimit { current, limit, resets_in } => {
                format!(
                    "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"hourly_limit\", \
                     \"current\": {}, \"limit\": {}, \"resets_in_seconds\": {}, \
                     \"suggestion\": \"Wait for quota reset or optimize workflow count\"}}",
                    current,
                    limit,
                    resets_in.as_secs()
                )
            }
            QuotaExceeded::StepLimit { workflow_id, steps_executed, max_steps } => {
                format!(
                    "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"step_limit\", \
                     \"workflow_id\": \"{}\", \"steps_executed\": {}, \"max_steps\": {}, \
                     \"suggestion\": \"Complete current workflow before starting more steps\"}}",
                    Self::escape_json(workflow_id),
                    steps_executed,
                    max_steps
                )
            }
        }
    }

    /// Escapes `raw` for use inside a JSON string literal.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must be written as \u00XX.
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }
}
540
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh budget has everything available and nothing used.
    #[test]
    fn test_token_budget_creation() {
        let fresh = AgentTokenBudget::daily("agent-1", 10000);
        assert_eq!(fresh.remaining(), 10000);
        assert_eq!(fresh.used(), 0);
    }

    /// A successful charge moves tokens from remaining to used.
    #[test]
    fn test_token_budget_consume() {
        let budget = AgentTokenBudget::daily("agent-1", 100);

        let outcome = budget.consume("query", 10);
        assert!(outcome.is_ok());
        assert_eq!(budget.used(), 10);
        assert_eq!(budget.remaining(), 90);
    }

    /// Once the budget is used up, further charges fail with a populated error.
    #[test]
    fn test_token_budget_exceeded() {
        let budget = AgentTokenBudget::daily("agent-1", 10);

        for _ in 0..2 {
            assert!(budget.consume("query", 5).is_ok());
        }

        let rejected = budget.consume("query", 1);
        assert!(rejected.is_err());

        let details = rejected.unwrap_err();
        assert_eq!(details.agent_id, "agent-1");
        assert_eq!(details.remaining, 0);
    }

    /// The per-operation multiplier scales the charge (embedding costs 5x).
    #[test]
    fn test_token_budget_operation_costs() {
        let budget = AgentTokenBudget::daily("agent-1", 1000);

        let outcome = budget.consume("embedding", 10);
        assert!(outcome.is_ok());
        assert_eq!(budget.used(), 50);
    }

    /// The warning flag flips once usage crosses the configured threshold.
    #[test]
    fn test_token_budget_warning() {
        let budget = AgentTokenBudget::daily("agent-1", 100).with_warning_threshold(0.8);

        assert!(!budget.is_warning());

        let outcome = budget.consume("query", 85);
        assert!(outcome.is_ok());
        assert!(budget.is_warning());
    }

    /// After the period has elapsed the budget is available again.
    #[test]
    fn test_token_budget_reset() {
        let short_period = Duration::from_millis(50);
        let budget = AgentTokenBudget::new("agent-1", 100, short_period);

        assert!(budget.consume("query", 100).is_ok());
        assert_eq!(budget.remaining(), 0);

        // Wait slightly longer than the period so the next query rolls it over.
        std::thread::sleep(Duration::from_millis(60));

        assert_eq!(budget.remaining(), 100);
    }

    /// The LLM-facing message names the error kind and the agent.
    #[test]
    fn test_budget_exceeded_llm_message() {
        let rendered = BudgetExceeded {
            agent_id: "agent-1".to_string(),
            requested: 100,
            remaining: 50,
            total: 1000,
            resets_in: Duration::from_secs(3600),
        }
        .to_llm_message();

        assert!(rendered.contains("budget_exceeded"));
        assert!(rendered.contains("agent-1"));
    }

    /// A fresh quota exposes its full allowance.
    #[test]
    fn test_workflow_quota_creation() {
        let fresh = WorkflowQuota::hourly(10, 100);
        assert_eq!(fresh.remaining(), 10);
    }

    /// Beginning a workflow yields a full token and uses one quota slot.
    #[test]
    fn test_workflow_quota_begin() {
        let quota = WorkflowQuota::hourly(10, 100);

        let issued = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(issued.remaining_steps(), 100);
        assert_eq!(quota.remaining(), 9);
    }

    /// Workflows beyond the quota are rejected.
    #[test]
    fn test_workflow_quota_exceeded() {
        let quota = WorkflowQuota::hourly(2, 100);

        for id in ["wf-1", "wf-2"].iter() {
            assert!(quota.begin_workflow(*id).is_ok());
        }

        let rejected = quota.begin_workflow("wf-3");
        assert!(rejected.is_err());
    }

    /// A token allows exactly `max_steps` steps and then refuses.
    #[test]
    fn test_workflow_token_steps() {
        let quota = WorkflowQuota::hourly(10, 5);
        let issued = quota.begin_workflow("wf-1").unwrap();

        let mut step = 0;
        while step < 5 {
            assert!(issued.execute_step().is_ok());
            step += 1;
        }

        let rejected = issued.execute_step();
        assert!(rejected.is_err());
    }

    /// `can_continue` stays true until the last step has been taken.
    #[test]
    fn test_workflow_token_can_continue() {
        let quota = WorkflowQuota::hourly(10, 2);
        let issued = quota.begin_workflow("wf-1").unwrap();

        assert!(issued.can_continue());

        assert!(issued.execute_step().is_ok());
        assert!(issued.can_continue());

        assert!(issued.execute_step().is_ok());
        assert!(!issued.can_continue());
    }

    /// Both quota error variants render their type tag in the LLM message.
    #[test]
    fn test_quota_exceeded_llm_message() {
        let hourly = QuotaExceeded::HourlyLimit {
            current: 10,
            limit: 10,
            resets_in: Duration::from_secs(1800),
        };
        let hourly_text = hourly.to_llm_message();
        assert!(hourly_text.contains("workflow_quota_exceeded"));
        assert!(hourly_text.contains("hourly_limit"));

        let steps = QuotaExceeded::StepLimit {
            workflow_id: "wf-1".to_string(),
            steps_executed: 100,
            max_steps: 100,
        };
        let steps_text = steps.to_llm_message();
        assert!(steps_text.contains("step_limit"));
    }

    /// Ending a workflow removes it from the active set.
    #[test]
    fn test_workflow_end() {
        let quota = WorkflowQuota::hourly(10, 100);

        let _issued = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(quota.active_count(), 1);

        quota.end_workflow("wf-1");
        assert_eq!(quota.active_count(), 0);
    }
}