Skip to main content

heliosdb_proxy/rate_limit/
agent.rs

//! Agent Token Budget & Workflow Quotas
//!
//! AI/Agent-specific rate limiting features:
//! - Token budgets (daily/hourly allocations)
//! - Workflow quotas (limit multi-step operations)
//! - LLM-friendly error messages
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use parking_lot::RwLock;
14
15/// Agent token budget
16///
17/// AI agents get token budgets instead of simple rate limits.
18/// This allows for burst operations while enforcing daily/hourly limits.
19#[derive(Debug)]
20pub struct AgentTokenBudget {
21    /// Budget identifier
22    agent_id: String,
23
24    /// Total token allocation for the period
25    total_tokens: u64,
26
27    /// Used tokens in current period
28    used_tokens: AtomicU64,
29
30    /// Token cost per operation type
31    operation_costs: HashMap<String, u64>,
32
33    /// Budget period (reset interval)
34    period: Duration,
35
36    /// Last reset time
37    last_reset: RwLock<Instant>,
38
39    /// Warning threshold (percentage)
40    warning_threshold: f64,
41
42    /// Hard limit enabled
43    hard_limit: bool,
44}
45
46impl AgentTokenBudget {
47    /// Create a new daily token budget
48    pub fn daily(agent_id: impl Into<String>, tokens: u64) -> Self {
49        Self::new(agent_id, tokens, Duration::from_secs(86400))
50    }
51
52    /// Create a new hourly token budget
53    pub fn hourly(agent_id: impl Into<String>, tokens: u64) -> Self {
54        Self::new(agent_id, tokens, Duration::from_secs(3600))
55    }
56
57    /// Create a new token budget with custom period
58    pub fn new(agent_id: impl Into<String>, tokens: u64, period: Duration) -> Self {
59        Self {
60            agent_id: agent_id.into(),
61            total_tokens: tokens,
62            used_tokens: AtomicU64::new(0),
63            operation_costs: Self::default_operation_costs(),
64            period,
65            last_reset: RwLock::new(Instant::now()),
66            warning_threshold: 0.8,
67            hard_limit: true,
68        }
69    }
70
71    /// Set warning threshold (0.0 - 1.0)
72    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
73        self.warning_threshold = threshold.clamp(0.0, 1.0);
74        self
75    }
76
77    /// Set hard limit behavior
78    pub fn with_hard_limit(mut self, hard: bool) -> Self {
79        self.hard_limit = hard;
80        self
81    }
82
83    /// Set operation costs
84    pub fn with_operation_costs(mut self, costs: HashMap<String, u64>) -> Self {
85        self.operation_costs = costs;
86        self
87    }
88
89    /// Add operation cost
90    pub fn add_operation_cost(&mut self, operation: impl Into<String>, cost: u64) {
91        self.operation_costs.insert(operation.into(), cost);
92    }
93
94    /// Consume tokens for an operation
95    pub fn consume(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
96        self.maybe_reset();
97
98        let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
99        let total_cost = cost.saturating_mul(estimated_tokens);
100
101        let used = self.used_tokens.fetch_add(total_cost, Ordering::SeqCst);
102
103        if self.hard_limit && used + total_cost > self.total_tokens {
104            // Rollback
105            self.used_tokens.fetch_sub(total_cost, Ordering::SeqCst);
106
107            return Err(BudgetExceeded {
108                agent_id: self.agent_id.clone(),
109                requested: total_cost,
110                remaining: self.total_tokens.saturating_sub(used),
111                total: self.total_tokens,
112                resets_in: self.time_until_reset(),
113            });
114        }
115
116        Ok(())
117    }
118
119    /// Check if budget is available (without consuming)
120    pub fn check(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
121        self.maybe_reset();
122
123        let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
124        let total_cost = cost.saturating_mul(estimated_tokens);
125        let used = self.used_tokens.load(Ordering::SeqCst);
126
127        if used + total_cost > self.total_tokens {
128            return Err(BudgetExceeded {
129                agent_id: self.agent_id.clone(),
130                requested: total_cost,
131                remaining: self.total_tokens.saturating_sub(used),
132                total: self.total_tokens,
133                resets_in: self.time_until_reset(),
134            });
135        }
136
137        Ok(())
138    }
139
140    /// Get remaining tokens
141    pub fn remaining(&self) -> u64 {
142        self.maybe_reset();
143        let used = self.used_tokens.load(Ordering::SeqCst);
144        self.total_tokens.saturating_sub(used)
145    }
146
147    /// Get used tokens
148    pub fn used(&self) -> u64 {
149        self.maybe_reset();
150        self.used_tokens.load(Ordering::SeqCst)
151    }
152
153    /// Get usage percentage (0.0 - 1.0)
154    pub fn usage_percentage(&self) -> f64 {
155        self.maybe_reset();
156        let used = self.used_tokens.load(Ordering::SeqCst);
157        used as f64 / self.total_tokens as f64
158    }
159
160    /// Check if over warning threshold
161    pub fn is_warning(&self) -> bool {
162        self.usage_percentage() >= self.warning_threshold
163    }
164
165    /// Get time until reset
166    pub fn time_until_reset(&self) -> Duration {
167        let last = *self.last_reset.read();
168        let elapsed = last.elapsed();
169
170        if elapsed >= self.period {
171            Duration::ZERO
172        } else {
173            self.period - elapsed
174        }
175    }
176
177    /// Force reset
178    pub fn reset(&self) {
179        self.used_tokens.store(0, Ordering::SeqCst);
180        *self.last_reset.write() = Instant::now();
181    }
182
183    /// Maybe reset if period elapsed
184    fn maybe_reset(&self) {
185        let last = *self.last_reset.read();
186        if last.elapsed() >= self.period {
187            self.reset();
188        }
189    }
190
191    fn default_operation_costs() -> HashMap<String, u64> {
192        let mut costs = HashMap::new();
193        costs.insert("query".to_string(), 1);
194        costs.insert("embedding".to_string(), 5);
195        costs.insert("vector_search".to_string(), 10);
196        costs.insert("write".to_string(), 2);
197        costs.insert("transaction".to_string(), 3);
198        costs
199    }
200}
201
202impl Clone for AgentTokenBudget {
203    fn clone(&self) -> Self {
204        Self {
205            agent_id: self.agent_id.clone(),
206            total_tokens: self.total_tokens,
207            used_tokens: AtomicU64::new(self.used_tokens.load(Ordering::Relaxed)),
208            operation_costs: self.operation_costs.clone(),
209            period: self.period,
210            last_reset: RwLock::new(*self.last_reset.read()),
211            warning_threshold: self.warning_threshold,
212            hard_limit: self.hard_limit,
213        }
214    }
215}
216
/// Budget exceeded error, returned when a token charge does not fit.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Agent ID
    pub agent_id: String,

    /// Tokens requested
    pub requested: u64,

    /// Tokens remaining
    pub remaining: u64,

    /// Total budget
    pub total: u64,

    /// Time until budget resets
    pub resets_in: Duration,
}

impl std::fmt::Display for BudgetExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Token budget exceeded for agent '{}': requested {} tokens, {} remaining of {} total, resets in {}s",
            self.agent_id,
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }
}

impl std::error::Error for BudgetExceeded {}

impl BudgetExceeded {
    /// Get an LLM-friendly error message as a single-line JSON object.
    ///
    /// `agent_id` is JSON-escaped, so quotes, backslashes or control
    /// characters in the id cannot produce malformed JSON.
    pub fn to_llm_message(&self) -> String {
        format!(
            "{{\"error\": \"budget_exceeded\", \"message\": \"Token budget exceeded\", \
             \"details\": {{\"agent_id\": \"{}\", \"requested\": {}, \"remaining\": {}, \
             \"total\": {}, \"resets_in_seconds\": {}}}, \
             \"suggestion\": \"Wait for budget reset or request a higher allocation\"}}",
            Self::escape_json(&self.agent_id),
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }

    /// Escape `s` for use inside a JSON string literal (RFC 8259 section 7).
    fn escape_json(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for ch in s.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Remaining control characters must use the \uXXXX form.
                control if u32::from(control) < 0x20 => {
                    out.push_str(&format!("\\u{:04x}", u32::from(control)));
                }
                other => out.push(other),
            }
        }
        out
    }
}
268
269/// Workflow quota
270///
271/// Tracks and limits agent workflow executions (multi-step operations).
272#[derive(Debug)]
273pub struct WorkflowQuota {
274    /// Maximum workflows per period
275    max_workflows: u32,
276
277    /// Maximum steps per workflow
278    max_steps: u32,
279
280    /// Current period's workflow count
281    workflow_count: AtomicU32,
282
283    /// Quota period
284    period: Duration,
285
286    /// Last reset
287    last_reset: RwLock<Instant>,
288
289    /// Active workflows
290    active_workflows: DashMap<String, WorkflowToken>,
291}
292
293impl WorkflowQuota {
294    /// Create a new hourly workflow quota
295    pub fn hourly(max_workflows: u32, max_steps: u32) -> Self {
296        Self::new(max_workflows, max_steps, Duration::from_secs(3600))
297    }
298
299    /// Create a new workflow quota
300    pub fn new(max_workflows: u32, max_steps: u32, period: Duration) -> Self {
301        Self {
302            max_workflows,
303            max_steps,
304            workflow_count: AtomicU32::new(0),
305            period,
306            last_reset: RwLock::new(Instant::now()),
307            active_workflows: DashMap::new(),
308        }
309    }
310
311    /// Begin a new workflow
312    pub fn begin_workflow(&self, workflow_id: impl Into<String>) -> Result<WorkflowToken, QuotaExceeded> {
313        self.maybe_reset();
314
315        let count = self.workflow_count.fetch_add(1, Ordering::SeqCst);
316        if count >= self.max_workflows {
317            self.workflow_count.fetch_sub(1, Ordering::SeqCst);
318            return Err(QuotaExceeded::HourlyLimit {
319                current: count,
320                limit: self.max_workflows,
321                resets_in: self.time_until_reset(),
322            });
323        }
324
325        let id = workflow_id.into();
326        let token = WorkflowToken::new(id.clone(), self.max_steps);
327        self.active_workflows.insert(id, token.clone());
328
329        Ok(token)
330    }
331
332    /// End a workflow
333    pub fn end_workflow(&self, workflow_id: &str) {
334        self.active_workflows.remove(workflow_id);
335    }
336
337    /// Get active workflow count
338    pub fn active_count(&self) -> usize {
339        self.active_workflows.len()
340    }
341
342    /// Get workflow count in period
343    pub fn period_count(&self) -> u32 {
344        self.maybe_reset();
345        self.workflow_count.load(Ordering::SeqCst)
346    }
347
348    /// Get remaining workflows
349    pub fn remaining(&self) -> u32 {
350        self.maybe_reset();
351        let count = self.workflow_count.load(Ordering::SeqCst);
352        self.max_workflows.saturating_sub(count)
353    }
354
355    /// Get time until reset
356    pub fn time_until_reset(&self) -> Duration {
357        let last = *self.last_reset.read();
358        let elapsed = last.elapsed();
359
360        if elapsed >= self.period {
361            Duration::ZERO
362        } else {
363            self.period - elapsed
364        }
365    }
366
367    /// Force reset
368    pub fn reset(&self) {
369        self.workflow_count.store(0, Ordering::SeqCst);
370        *self.last_reset.write() = Instant::now();
371    }
372
373    fn maybe_reset(&self) {
374        let last = *self.last_reset.read();
375        if last.elapsed() >= self.period {
376            self.reset();
377        }
378    }
379}
380
381impl Clone for WorkflowQuota {
382    fn clone(&self) -> Self {
383        Self {
384            max_workflows: self.max_workflows,
385            max_steps: self.max_steps,
386            workflow_count: AtomicU32::new(self.workflow_count.load(Ordering::Relaxed)),
387            period: self.period,
388            last_reset: RwLock::new(*self.last_reset.read()),
389            active_workflows: DashMap::new(),
390        }
391    }
392}
393
394/// Workflow token
395///
396/// Tracks a single workflow's step usage.
397#[derive(Debug)]
398pub struct WorkflowToken {
399    /// Workflow ID
400    pub id: String,
401
402    /// Remaining steps
403    remaining_steps: AtomicU32,
404
405    /// Total steps allowed
406    max_steps: u32,
407
408    /// Steps executed
409    steps_executed: AtomicU32,
410
411    /// Created at
412    created_at: Instant,
413}
414
415impl Clone for WorkflowToken {
416    fn clone(&self) -> Self {
417        Self {
418            id: self.id.clone(),
419            remaining_steps: AtomicU32::new(self.remaining_steps.load(Ordering::Relaxed)),
420            max_steps: self.max_steps,
421            steps_executed: AtomicU32::new(self.steps_executed.load(Ordering::Relaxed)),
422            created_at: self.created_at,
423        }
424    }
425}
426
427impl WorkflowToken {
428    fn new(id: String, max_steps: u32) -> Self {
429        Self {
430            id,
431            remaining_steps: AtomicU32::new(max_steps),
432            max_steps,
433            steps_executed: AtomicU32::new(0),
434            created_at: Instant::now(),
435        }
436    }
437
438    /// Execute a step
439    pub fn execute_step(&self) -> Result<(), QuotaExceeded> {
440        let remaining = self.remaining_steps.fetch_sub(1, Ordering::SeqCst);
441
442        if remaining == 0 {
443            self.remaining_steps.fetch_add(1, Ordering::SeqCst); // Rollback
444            return Err(QuotaExceeded::StepLimit {
445                workflow_id: self.id.clone(),
446                steps_executed: self.steps_executed.load(Ordering::SeqCst),
447                max_steps: self.max_steps,
448            });
449        }
450
451        self.steps_executed.fetch_add(1, Ordering::SeqCst);
452        Ok(())
453    }
454
455    /// Get remaining steps
456    pub fn remaining_steps(&self) -> u32 {
457        self.remaining_steps.load(Ordering::SeqCst)
458    }
459
460    /// Get executed steps
461    pub fn steps_executed(&self) -> u32 {
462        self.steps_executed.load(Ordering::SeqCst)
463    }
464
465    /// Get workflow duration
466    pub fn duration(&self) -> Duration {
467        self.created_at.elapsed()
468    }
469
470    /// Check if can execute more steps
471    pub fn can_continue(&self) -> bool {
472        self.remaining_steps.load(Ordering::SeqCst) > 0
473    }
474}
475
/// Quota exceeded error, returned by workflow quota and step checks.
#[derive(Debug, Clone)]
pub enum QuotaExceeded {
    /// Per-period workflow limit reached (named "hourly", but used for
    /// whatever period the quota was configured with)
    HourlyLimit {
        current: u32,
        limit: u32,
        resets_in: Duration,
    },

    /// Step limit for workflow reached
    StepLimit {
        workflow_id: String,
        steps_executed: u32,
        max_steps: u32,
    },
}

impl std::fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuotaExceeded::HourlyLimit { current, limit, resets_in } => {
                write!(
                    f,
                    "Hourly workflow limit exceeded: {}/{} workflows, resets in {}s",
                    current, limit, resets_in.as_secs()
                )
            }
            QuotaExceeded::StepLimit { workflow_id, steps_executed, max_steps } => {
                write!(
                    f,
                    "Workflow '{}' step limit exceeded: {}/{} steps",
                    workflow_id, steps_executed, max_steps
                )
            }
        }
    }
}

impl std::error::Error for QuotaExceeded {}

impl QuotaExceeded {
    /// Get an LLM-friendly error message as a single-line JSON object.
    ///
    /// `workflow_id` is JSON-escaped, so quotes, backslashes or control
    /// characters in the id cannot produce malformed JSON.
    pub fn to_llm_message(&self) -> String {
        match self {
            QuotaExceeded::HourlyLimit { current, limit, resets_in } => {
                format!(
                    "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"hourly_limit\", \
                     \"current\": {}, \"limit\": {}, \"resets_in_seconds\": {}, \
                     \"suggestion\": \"Wait for quota reset or optimize workflow count\"}}",
                    current, limit, resets_in.as_secs()
                )
            }
            QuotaExceeded::StepLimit { workflow_id, steps_executed, max_steps } => {
                format!(
                    "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"step_limit\", \
                     \"workflow_id\": \"{}\", \"steps_executed\": {}, \"max_steps\": {}, \
                     \"suggestion\": \"Complete current workflow before starting more steps\"}}",
                    Self::escape_json(workflow_id),
                    steps_executed,
                    max_steps
                )
            }
        }
    }

    /// Escape `s` for use inside a JSON string literal (RFC 8259 section 7).
    fn escape_json(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        for ch in s.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                // Remaining control characters must use the \uXXXX form.
                control if u32::from(control) < 0x20 => {
                    out.push_str(&format!("\\u{:04x}", u32::from(control)));
                }
                other => out.push(other),
            }
        }
        out
    }
}
540
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // AgentTokenBudget
    // ---------------------------------------------------------------------

    #[test]
    fn test_token_budget_creation() {
        let fresh = AgentTokenBudget::daily("agent-1", 10000);
        assert_eq!(fresh.remaining(), 10000);
        assert_eq!(fresh.used(), 0);
    }

    #[test]
    fn test_token_budget_consume() {
        let budget = AgentTokenBudget::daily("agent-1", 100);

        budget
            .consume("query", 10)
            .expect("10 tokens fit in a budget of 100");
        assert_eq!(budget.used(), 10);
        assert_eq!(budget.remaining(), 90);
    }

    #[test]
    fn test_token_budget_exceeded() {
        let budget = AgentTokenBudget::daily("agent-1", 10);

        for _ in 0..2 {
            budget
                .consume("query", 5)
                .expect("half of the budget should fit");
        }

        let err = budget
            .consume("query", 1)
            .expect_err("the budget is already fully used");
        assert_eq!(err.agent_id, "agent-1");
        assert_eq!(err.remaining, 0);
    }

    #[test]
    fn test_token_budget_operation_costs() {
        let budget = AgentTokenBudget::daily("agent-1", 1000);

        // The default cost table charges 5 per "embedding" token.
        budget
            .consume("embedding", 10)
            .expect("50 tokens fit in a budget of 1000");
        assert_eq!(budget.used(), 5 * 10);
    }

    #[test]
    fn test_token_budget_warning() {
        let budget = AgentTokenBudget::daily("agent-1", 100).with_warning_threshold(0.8);
        assert!(!budget.is_warning(), "an unused budget must not warn");

        budget
            .consume("query", 85)
            .expect("85 tokens fit in a budget of 100");
        assert!(budget.is_warning(), "85% usage crosses the 80% threshold");
    }

    #[test]
    fn test_token_budget_reset() {
        let period = Duration::from_millis(50);
        let budget = AgentTokenBudget::new("agent-1", 100, period);

        budget
            .consume("query", 100)
            .expect("the whole budget can be consumed at once");
        assert_eq!(budget.remaining(), 0);

        std::thread::sleep(Duration::from_millis(60));

        // The period has elapsed, so the next query resets usage.
        assert_eq!(budget.remaining(), 100);
    }

    #[test]
    fn test_budget_exceeded_llm_message() {
        let message = BudgetExceeded {
            agent_id: "agent-1".to_string(),
            requested: 100,
            remaining: 50,
            total: 1000,
            resets_in: Duration::from_secs(3600),
        }
        .to_llm_message();

        for needle in ["budget_exceeded", "agent-1"] {
            assert!(message.contains(needle));
        }
    }

    // ---------------------------------------------------------------------
    // WorkflowQuota / WorkflowToken
    // ---------------------------------------------------------------------

    #[test]
    fn test_workflow_quota_creation() {
        assert_eq!(WorkflowQuota::hourly(10, 100).remaining(), 10);
    }

    #[test]
    fn test_workflow_quota_begin() {
        let quota = WorkflowQuota::hourly(10, 100);

        let token = quota
            .begin_workflow("wf-1")
            .expect("the first workflow is within quota");
        assert_eq!(token.remaining_steps(), 100);
        assert_eq!(quota.remaining(), 9);
    }

    #[test]
    fn test_workflow_quota_exceeded() {
        let quota = WorkflowQuota::hourly(2, 100);

        for id in ["wf-1", "wf-2"] {
            assert!(quota.begin_workflow(id).is_ok());
        }

        assert!(quota.begin_workflow("wf-3").is_err());
    }

    #[test]
    fn test_workflow_token_steps() {
        let quota = WorkflowQuota::hourly(10, 5);
        let token = quota
            .begin_workflow("wf-1")
            .expect("the first workflow is within quota");

        (0..5).for_each(|_| assert!(token.execute_step().is_ok()));

        assert!(token.execute_step().is_err());
    }

    #[test]
    fn test_workflow_token_can_continue() {
        let quota = WorkflowQuota::hourly(10, 2);
        let token = quota
            .begin_workflow("wf-1")
            .expect("the first workflow is within quota");

        assert!(token.can_continue());

        token.execute_step().expect("first of two steps");
        assert!(token.can_continue());

        token.execute_step().expect("second of two steps");
        assert!(!token.can_continue());
    }

    #[test]
    fn test_quota_exceeded_llm_message() {
        let hourly = QuotaExceeded::HourlyLimit {
            current: 10,
            limit: 10,
            resets_in: Duration::from_secs(1800),
        }
        .to_llm_message();
        assert!(hourly.contains("workflow_quota_exceeded"));
        assert!(hourly.contains("hourly_limit"));

        let steps = QuotaExceeded::StepLimit {
            workflow_id: "wf-1".to_string(),
            steps_executed: 100,
            max_steps: 100,
        }
        .to_llm_message();
        assert!(steps.contains("step_limit"));
    }

    #[test]
    fn test_workflow_end() {
        let quota = WorkflowQuota::hourly(10, 100);

        let _token = quota
            .begin_workflow("wf-1")
            .expect("the first workflow is within quota");
        assert_eq!(quota.active_count(), 1);

        quota.end_workflow("wf-1");
        assert_eq!(quota.active_count(), 0);
    }
}