Skip to main content

heliosdb_proxy/rate_limit/
agent.rs

1//! Agent Token Budget & Workflow Quotas
2//!
3//! AI/Agent-specific rate limiting features:
4//! - Token budgets (daily/hourly allocations)
5//! - Workflow quotas (limit multi-step operations)
6//! - LLM-friendly error messages
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use dashmap::DashMap;
13use parking_lot::RwLock;
14
15/// Agent token budget
16///
17/// AI agents get token budgets instead of simple rate limits.
18/// This allows for burst operations while enforcing daily/hourly limits.
19#[derive(Debug)]
20pub struct AgentTokenBudget {
21    /// Budget identifier
22    agent_id: String,
23
24    /// Total token allocation for the period
25    total_tokens: u64,
26
27    /// Used tokens in current period
28    used_tokens: AtomicU64,
29
30    /// Token cost per operation type
31    operation_costs: HashMap<String, u64>,
32
33    /// Budget period (reset interval)
34    period: Duration,
35
36    /// Last reset time
37    last_reset: RwLock<Instant>,
38
39    /// Warning threshold (percentage)
40    warning_threshold: f64,
41
42    /// Hard limit enabled
43    hard_limit: bool,
44}
45
46impl AgentTokenBudget {
47    /// Create a new daily token budget
48    pub fn daily(agent_id: impl Into<String>, tokens: u64) -> Self {
49        Self::new(agent_id, tokens, Duration::from_secs(86400))
50    }
51
52    /// Create a new hourly token budget
53    pub fn hourly(agent_id: impl Into<String>, tokens: u64) -> Self {
54        Self::new(agent_id, tokens, Duration::from_secs(3600))
55    }
56
57    /// Create a new token budget with custom period
58    pub fn new(agent_id: impl Into<String>, tokens: u64, period: Duration) -> Self {
59        Self {
60            agent_id: agent_id.into(),
61            total_tokens: tokens,
62            used_tokens: AtomicU64::new(0),
63            operation_costs: Self::default_operation_costs(),
64            period,
65            last_reset: RwLock::new(Instant::now()),
66            warning_threshold: 0.8,
67            hard_limit: true,
68        }
69    }
70
71    /// Set warning threshold (0.0 - 1.0)
72    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
73        self.warning_threshold = threshold.clamp(0.0, 1.0);
74        self
75    }
76
77    /// Set hard limit behavior
78    pub fn with_hard_limit(mut self, hard: bool) -> Self {
79        self.hard_limit = hard;
80        self
81    }
82
83    /// Set operation costs
84    pub fn with_operation_costs(mut self, costs: HashMap<String, u64>) -> Self {
85        self.operation_costs = costs;
86        self
87    }
88
89    /// Add operation cost
90    pub fn add_operation_cost(&mut self, operation: impl Into<String>, cost: u64) {
91        self.operation_costs.insert(operation.into(), cost);
92    }
93
94    /// Consume tokens for an operation
95    pub fn consume(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
96        self.maybe_reset();
97
98        let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
99        let total_cost = cost.saturating_mul(estimated_tokens);
100
101        let used = self.used_tokens.fetch_add(total_cost, Ordering::SeqCst);
102
103        if self.hard_limit && used + total_cost > self.total_tokens {
104            // Rollback
105            self.used_tokens.fetch_sub(total_cost, Ordering::SeqCst);
106
107            return Err(BudgetExceeded {
108                agent_id: self.agent_id.clone(),
109                requested: total_cost,
110                remaining: self.total_tokens.saturating_sub(used),
111                total: self.total_tokens,
112                resets_in: self.time_until_reset(),
113            });
114        }
115
116        Ok(())
117    }
118
119    /// Check if budget is available (without consuming)
120    pub fn check(&self, operation: &str, estimated_tokens: u64) -> Result<(), BudgetExceeded> {
121        self.maybe_reset();
122
123        let cost = self.operation_costs.get(operation).copied().unwrap_or(1);
124        let total_cost = cost.saturating_mul(estimated_tokens);
125        let used = self.used_tokens.load(Ordering::SeqCst);
126
127        if used + total_cost > self.total_tokens {
128            return Err(BudgetExceeded {
129                agent_id: self.agent_id.clone(),
130                requested: total_cost,
131                remaining: self.total_tokens.saturating_sub(used),
132                total: self.total_tokens,
133                resets_in: self.time_until_reset(),
134            });
135        }
136
137        Ok(())
138    }
139
140    /// Get remaining tokens
141    pub fn remaining(&self) -> u64 {
142        self.maybe_reset();
143        let used = self.used_tokens.load(Ordering::SeqCst);
144        self.total_tokens.saturating_sub(used)
145    }
146
147    /// Get used tokens
148    pub fn used(&self) -> u64 {
149        self.maybe_reset();
150        self.used_tokens.load(Ordering::SeqCst)
151    }
152
153    /// Get usage percentage (0.0 - 1.0)
154    pub fn usage_percentage(&self) -> f64 {
155        self.maybe_reset();
156        let used = self.used_tokens.load(Ordering::SeqCst);
157        used as f64 / self.total_tokens as f64
158    }
159
160    /// Check if over warning threshold
161    pub fn is_warning(&self) -> bool {
162        self.usage_percentage() >= self.warning_threshold
163    }
164
165    /// Get time until reset
166    pub fn time_until_reset(&self) -> Duration {
167        let last = *self.last_reset.read();
168        let elapsed = last.elapsed();
169
170        if elapsed >= self.period {
171            Duration::ZERO
172        } else {
173            self.period - elapsed
174        }
175    }
176
177    /// Force reset
178    pub fn reset(&self) {
179        self.used_tokens.store(0, Ordering::SeqCst);
180        *self.last_reset.write() = Instant::now();
181    }
182
183    /// Maybe reset if period elapsed
184    fn maybe_reset(&self) {
185        let last = *self.last_reset.read();
186        if last.elapsed() >= self.period {
187            self.reset();
188        }
189    }
190
191    fn default_operation_costs() -> HashMap<String, u64> {
192        let mut costs = HashMap::new();
193        costs.insert("query".to_string(), 1);
194        costs.insert("embedding".to_string(), 5);
195        costs.insert("vector_search".to_string(), 10);
196        costs.insert("write".to_string(), 2);
197        costs.insert("transaction".to_string(), 3);
198        costs
199    }
200}
201
202impl Clone for AgentTokenBudget {
203    fn clone(&self) -> Self {
204        Self {
205            agent_id: self.agent_id.clone(),
206            total_tokens: self.total_tokens,
207            used_tokens: AtomicU64::new(self.used_tokens.load(Ordering::Relaxed)),
208            operation_costs: self.operation_costs.clone(),
209            period: self.period,
210            last_reset: RwLock::new(*self.last_reset.read()),
211            warning_threshold: self.warning_threshold,
212            hard_limit: self.hard_limit,
213        }
214    }
215}
216
/// Error returned when a charge does not fit into an agent's token budget.
#[derive(Debug, Clone)]
pub struct BudgetExceeded {
    /// Identifier of the agent whose budget was exceeded.
    pub agent_id: String,

    /// Tokens the rejected operation would have cost.
    pub requested: u64,

    /// Tokens that were still available when the request was rejected.
    pub remaining: u64,

    /// Total budget for the period.
    pub total: u64,

    /// Time until the budget resets.
    pub resets_in: Duration,
}

impl std::fmt::Display for BudgetExceeded {
    /// Writes a one-line, human-readable description of the rejection.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            agent_id,
            requested,
            remaining,
            total,
            resets_in,
        } = self;
        write!(
            f,
            "Token budget exceeded for agent '{}': requested {} tokens, {} remaining of {} total, resets in {}s",
            agent_id,
            requested,
            remaining,
            total,
            resets_in.as_secs()
        )
    }
}

impl std::error::Error for BudgetExceeded {}

impl BudgetExceeded {
    /// Renders the error as a single-line JSON object meant for an LLM.
    ///
    /// The agent id is escaped as a JSON string, so ids containing quotes,
    /// backslashes or control characters still yield well-formed JSON and
    /// cannot inject additional fields.
    pub fn to_llm_message(&self) -> String {
        /// Escapes `raw` for use inside a double-quoted JSON string.
        fn escape_json(raw: &str) -> String {
            let mut escaped = String::with_capacity(raw.len());
            for ch in raw.chars() {
                match ch {
                    '"' => escaped.push_str("\\\""),
                    '\\' => escaped.push_str("\\\\"),
                    '\n' => escaped.push_str("\\n"),
                    '\r' => escaped.push_str("\\r"),
                    '\t' => escaped.push_str("\\t"),
                    // Remaining control characters have no short escape form.
                    c if (c as u32) < 0x20 => {
                        escaped.push_str(&format!("\\u{:04x}", c as u32));
                    }
                    c => escaped.push(c),
                }
            }
            escaped
        }

        format!(
            "{{\"error\": \"budget_exceeded\", \"message\": \"Token budget exceeded\", \
             \"details\": {{\"agent_id\": \"{}\", \"requested\": {}, \"remaining\": {}, \
             \"total\": {}, \"resets_in_seconds\": {}}}, \
             \"suggestion\": \"Wait for budget reset or request a higher allocation\"}}",
            escape_json(&self.agent_id),
            self.requested,
            self.remaining,
            self.total,
            self.resets_in.as_secs()
        )
    }
}
268
269/// Workflow quota
270///
271/// Tracks and limits agent workflow executions (multi-step operations).
272#[derive(Debug)]
273pub struct WorkflowQuota {
274    /// Maximum workflows per period
275    max_workflows: u32,
276
277    /// Maximum steps per workflow
278    max_steps: u32,
279
280    /// Current period's workflow count
281    workflow_count: AtomicU32,
282
283    /// Quota period
284    period: Duration,
285
286    /// Last reset
287    last_reset: RwLock<Instant>,
288
289    /// Active workflows
290    active_workflows: DashMap<String, WorkflowToken>,
291}
292
293impl WorkflowQuota {
294    /// Create a new hourly workflow quota
295    pub fn hourly(max_workflows: u32, max_steps: u32) -> Self {
296        Self::new(max_workflows, max_steps, Duration::from_secs(3600))
297    }
298
299    /// Create a new workflow quota
300    pub fn new(max_workflows: u32, max_steps: u32, period: Duration) -> Self {
301        Self {
302            max_workflows,
303            max_steps,
304            workflow_count: AtomicU32::new(0),
305            period,
306            last_reset: RwLock::new(Instant::now()),
307            active_workflows: DashMap::new(),
308        }
309    }
310
311    /// Begin a new workflow
312    pub fn begin_workflow(
313        &self,
314        workflow_id: impl Into<String>,
315    ) -> Result<WorkflowToken, QuotaExceeded> {
316        self.maybe_reset();
317
318        let count = self.workflow_count.fetch_add(1, Ordering::SeqCst);
319        if count >= self.max_workflows {
320            self.workflow_count.fetch_sub(1, Ordering::SeqCst);
321            return Err(QuotaExceeded::HourlyLimit {
322                current: count,
323                limit: self.max_workflows,
324                resets_in: self.time_until_reset(),
325            });
326        }
327
328        let id = workflow_id.into();
329        let token = WorkflowToken::new(id.clone(), self.max_steps);
330        self.active_workflows.insert(id, token.clone());
331
332        Ok(token)
333    }
334
335    /// End a workflow
336    pub fn end_workflow(&self, workflow_id: &str) {
337        self.active_workflows.remove(workflow_id);
338    }
339
340    /// Get active workflow count
341    pub fn active_count(&self) -> usize {
342        self.active_workflows.len()
343    }
344
345    /// Get workflow count in period
346    pub fn period_count(&self) -> u32 {
347        self.maybe_reset();
348        self.workflow_count.load(Ordering::SeqCst)
349    }
350
351    /// Get remaining workflows
352    pub fn remaining(&self) -> u32 {
353        self.maybe_reset();
354        let count = self.workflow_count.load(Ordering::SeqCst);
355        self.max_workflows.saturating_sub(count)
356    }
357
358    /// Get time until reset
359    pub fn time_until_reset(&self) -> Duration {
360        let last = *self.last_reset.read();
361        let elapsed = last.elapsed();
362
363        if elapsed >= self.period {
364            Duration::ZERO
365        } else {
366            self.period - elapsed
367        }
368    }
369
370    /// Force reset
371    pub fn reset(&self) {
372        self.workflow_count.store(0, Ordering::SeqCst);
373        *self.last_reset.write() = Instant::now();
374    }
375
376    fn maybe_reset(&self) {
377        let last = *self.last_reset.read();
378        if last.elapsed() >= self.period {
379            self.reset();
380        }
381    }
382}
383
384impl Clone for WorkflowQuota {
385    fn clone(&self) -> Self {
386        Self {
387            max_workflows: self.max_workflows,
388            max_steps: self.max_steps,
389            workflow_count: AtomicU32::new(self.workflow_count.load(Ordering::Relaxed)),
390            period: self.period,
391            last_reset: RwLock::new(*self.last_reset.read()),
392            active_workflows: DashMap::new(),
393        }
394    }
395}
396
397/// Workflow token
398///
399/// Tracks a single workflow's step usage.
400#[derive(Debug)]
401pub struct WorkflowToken {
402    /// Workflow ID
403    pub id: String,
404
405    /// Remaining steps
406    remaining_steps: AtomicU32,
407
408    /// Total steps allowed
409    max_steps: u32,
410
411    /// Steps executed
412    steps_executed: AtomicU32,
413
414    /// Created at
415    created_at: Instant,
416}
417
418impl Clone for WorkflowToken {
419    fn clone(&self) -> Self {
420        Self {
421            id: self.id.clone(),
422            remaining_steps: AtomicU32::new(self.remaining_steps.load(Ordering::Relaxed)),
423            max_steps: self.max_steps,
424            steps_executed: AtomicU32::new(self.steps_executed.load(Ordering::Relaxed)),
425            created_at: self.created_at,
426        }
427    }
428}
429
430impl WorkflowToken {
431    fn new(id: String, max_steps: u32) -> Self {
432        Self {
433            id,
434            remaining_steps: AtomicU32::new(max_steps),
435            max_steps,
436            steps_executed: AtomicU32::new(0),
437            created_at: Instant::now(),
438        }
439    }
440
441    /// Execute a step
442    pub fn execute_step(&self) -> Result<(), QuotaExceeded> {
443        let remaining = self.remaining_steps.fetch_sub(1, Ordering::SeqCst);
444
445        if remaining == 0 {
446            self.remaining_steps.fetch_add(1, Ordering::SeqCst); // Rollback
447            return Err(QuotaExceeded::StepLimit {
448                workflow_id: self.id.clone(),
449                steps_executed: self.steps_executed.load(Ordering::SeqCst),
450                max_steps: self.max_steps,
451            });
452        }
453
454        self.steps_executed.fetch_add(1, Ordering::SeqCst);
455        Ok(())
456    }
457
458    /// Get remaining steps
459    pub fn remaining_steps(&self) -> u32 {
460        self.remaining_steps.load(Ordering::SeqCst)
461    }
462
463    /// Get executed steps
464    pub fn steps_executed(&self) -> u32 {
465        self.steps_executed.load(Ordering::SeqCst)
466    }
467
468    /// Get workflow duration
469    pub fn duration(&self) -> Duration {
470        self.created_at.elapsed()
471    }
472
473    /// Check if can execute more steps
474    pub fn can_continue(&self) -> bool {
475        self.remaining_steps.load(Ordering::SeqCst) > 0
476    }
477}
478
/// Error returned when a workflow quota is exhausted.
#[derive(Debug, Clone)]
pub enum QuotaExceeded {
    /// The per-period workflow limit has been reached.
    HourlyLimit {
        /// Workflows already started in the period.
        current: u32,
        /// Maximum workflows allowed per period.
        limit: u32,
        /// Time until the quota resets.
        resets_in: Duration,
    },

    /// A workflow has used up its step allowance.
    StepLimit {
        /// Identifier of the workflow that hit the limit.
        workflow_id: String,
        /// Steps the workflow has executed.
        steps_executed: u32,
        /// Step allowance of the workflow.
        max_steps: u32,
    },
}

impl std::fmt::Display for QuotaExceeded {
    /// Writes a one-line, human-readable description of the rejection.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HourlyLimit {
                current,
                limit,
                resets_in,
            } => write!(
                f,
                "Hourly workflow limit exceeded: {}/{} workflows, resets in {}s",
                current,
                limit,
                resets_in.as_secs()
            ),
            Self::StepLimit {
                workflow_id,
                steps_executed,
                max_steps,
            } => write!(
                f,
                "Workflow '{}' step limit exceeded: {}/{} steps",
                workflow_id, steps_executed, max_steps
            ),
        }
    }
}

impl std::error::Error for QuotaExceeded {}

impl QuotaExceeded {
    /// Renders the error as a single-line JSON object meant for an LLM.
    ///
    /// The workflow id is escaped as a JSON string, so ids containing
    /// quotes, backslashes or control characters still yield well-formed
    /// JSON and cannot inject additional fields.
    pub fn to_llm_message(&self) -> String {
        /// Escapes `raw` for use inside a double-quoted JSON string.
        fn escape_json(raw: &str) -> String {
            let mut escaped = String::with_capacity(raw.len());
            for ch in raw.chars() {
                match ch {
                    '"' => escaped.push_str("\\\""),
                    '\\' => escaped.push_str("\\\\"),
                    '\n' => escaped.push_str("\\n"),
                    '\r' => escaped.push_str("\\r"),
                    '\t' => escaped.push_str("\\t"),
                    // Remaining control characters have no short escape form.
                    c if (c as u32) < 0x20 => {
                        escaped.push_str(&format!("\\u{:04x}", c as u32));
                    }
                    c => escaped.push(c),
                }
            }
            escaped
        }

        match self {
            Self::HourlyLimit {
                current,
                limit,
                resets_in,
            } => format!(
                "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"hourly_limit\", \
                 \"current\": {}, \"limit\": {}, \"resets_in_seconds\": {}, \
                 \"suggestion\": \"Wait for quota reset or optimize workflow count\"}}",
                current,
                limit,
                resets_in.as_secs()
            ),
            Self::StepLimit {
                workflow_id,
                steps_executed,
                max_steps,
            } => format!(
                "{{\"error\": \"workflow_quota_exceeded\", \"type\": \"step_limit\", \
                 \"workflow_id\": \"{}\", \"steps_executed\": {}, \"max_steps\": {}, \
                 \"suggestion\": \"Complete current workflow before starting more steps\"}}",
                escape_json(workflow_id),
                steps_executed,
                max_steps
            ),
        }
    }
}
563
#[cfg(test)]
mod tests {
    use super::*;

    /// A new budget exposes its full allocation and no usage.
    #[test]
    fn test_token_budget_creation() {
        let fresh = AgentTokenBudget::daily("agent-1", 10000);
        assert_eq!((fresh.remaining(), fresh.used()), (10000, 0));
    }

    /// Consuming moves tokens from "remaining" to "used".
    #[test]
    fn test_token_budget_consume() {
        let budget = AgentTokenBudget::daily("agent-1", 100);

        assert!(matches!(budget.consume("query", 10), Ok(())));
        assert_eq!((budget.used(), budget.remaining()), (10, 90));
    }

    /// Once the allocation is spent, further charges are rejected.
    #[test]
    fn test_token_budget_exceeded() {
        let budget = AgentTokenBudget::daily("agent-1", 10);

        for _ in 0..2 {
            assert!(matches!(budget.consume("query", 5), Ok(())));
        }

        let outcome = budget.consume("query", 1);
        assert!(outcome.is_err());

        let denied = outcome.unwrap_err();
        assert_eq!(denied.agent_id, "agent-1");
        assert_eq!(denied.remaining, 0);
    }

    /// The default cost table charges embeddings at five tokens each.
    #[test]
    fn test_token_budget_operation_costs() {
        let budget = AgentTokenBudget::daily("agent-1", 1000);

        assert!(matches!(budget.consume("embedding", 10), Ok(())));
        // 5 tokens per unit * 10 units
        assert_eq!(budget.used(), 50);
    }

    /// The warning flag flips once usage passes the threshold.
    #[test]
    fn test_token_budget_warning() {
        let budget = AgentTokenBudget::daily("agent-1", 100).with_warning_threshold(0.8);
        assert!(!budget.is_warning());

        assert!(matches!(budget.consume("query", 85), Ok(())));
        assert!(budget.is_warning());
    }

    /// Usage is cleared automatically once the period has elapsed.
    #[test]
    fn test_token_budget_reset() {
        let short_lived = AgentTokenBudget::new("agent-1", 100, Duration::from_millis(50));

        assert!(matches!(short_lived.consume("query", 100), Ok(())));
        assert_eq!(short_lived.remaining(), 0);

        // Wait past the 50 ms period so the next read triggers the reset.
        std::thread::sleep(Duration::from_millis(60));

        assert_eq!(short_lived.remaining(), 100);
    }

    /// The LLM message names the error kind and the agent.
    #[test]
    fn test_budget_exceeded_llm_message() {
        let rendered = BudgetExceeded {
            agent_id: "agent-1".to_string(),
            requested: 100,
            remaining: 50,
            total: 1000,
            resets_in: Duration::from_secs(3600),
        }
        .to_llm_message();

        for needle in ["budget_exceeded", "agent-1"].iter() {
            assert!(rendered.contains(needle));
        }
    }

    /// A new quota has its full workflow allowance available.
    #[test]
    fn test_workflow_quota_creation() {
        assert_eq!(WorkflowQuota::hourly(10, 100).remaining(), 10);
    }

    /// Beginning a workflow hands out a full token and uses one slot.
    #[test]
    fn test_workflow_quota_begin() {
        let quota = WorkflowQuota::hourly(10, 100);

        let wf = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(wf.remaining_steps(), 100);
        assert_eq!(quota.remaining(), 9);
    }

    /// Starting more workflows than allowed is rejected.
    #[test]
    fn test_workflow_quota_exceeded() {
        let quota = WorkflowQuota::hourly(2, 100);

        for id in ["wf-1", "wf-2"].iter() {
            assert!(matches!(quota.begin_workflow(*id), Ok(_)));
        }

        assert!(matches!(quota.begin_workflow("wf-3"), Err(_)));
    }

    /// A token allows exactly `max_steps` steps.
    #[test]
    fn test_workflow_token_steps() {
        let quota = WorkflowQuota::hourly(10, 5);
        let wf = quota.begin_workflow("wf-1").unwrap();

        let mut step = 0;
        while step < 5 {
            assert!(matches!(wf.execute_step(), Ok(())));
            step += 1;
        }

        assert!(matches!(wf.execute_step(), Err(_)));
    }

    /// `can_continue` turns false when the last step has been used.
    #[test]
    fn test_workflow_token_can_continue() {
        let quota = WorkflowQuota::hourly(10, 2);
        let wf = quota.begin_workflow("wf-1").unwrap();
        assert!(wf.can_continue());

        assert!(matches!(wf.execute_step(), Ok(())));
        assert!(wf.can_continue());

        assert!(matches!(wf.execute_step(), Ok(())));
        assert!(!wf.can_continue());
    }

    /// Both quota error variants identify themselves in the LLM message.
    #[test]
    fn test_quota_exceeded_llm_message() {
        let hourly = QuotaExceeded::HourlyLimit {
            current: 10,
            limit: 10,
            resets_in: Duration::from_secs(1800),
        }
        .to_llm_message();
        assert!(hourly.contains("workflow_quota_exceeded"));
        assert!(hourly.contains("hourly_limit"));

        let steps = QuotaExceeded::StepLimit {
            workflow_id: "wf-1".to_string(),
            steps_executed: 100,
            max_steps: 100,
        }
        .to_llm_message();
        assert!(steps.contains("step_limit"));
    }

    /// Ending a workflow removes it from the active set.
    #[test]
    fn test_workflow_end() {
        let quota = WorkflowQuota::hourly(10, 100);

        let _wf = quota.begin_workflow("wf-1").unwrap();
        assert_eq!(quota.active_count(), 1);

        quota.end_workflow("wf-1");
        assert_eq!(quota.active_count(), 0);
    }
}