Skip to main content

t_ron/
rate.rs

1//! Rate limiter — per-agent, per-tool token bucket.
2
3use dashmap::DashMap;
4use std::time::Instant;
5
6pub struct RateLimiter {
7    /// "agent_id\x1ftool_name" -> bucket (unit separator avoids tuple allocation)
8    buckets: DashMap<String, TokenBucket>,
9    /// Default calls per minute
10    default_rate: u64,
11    /// Per-agent rate overrides (from policy config).
12    agent_rates: DashMap<String, u64>,
13}
14
/// State of one (agent, tool) token bucket.
struct TokenBucket {
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
    /// Refill speed, in tokens per second.
    refill_rate: f64,
    /// Bucket capacity; `tokens` is kept at or below this value.
    max_tokens: f64,
    /// Tokens currently available (may be fractional).
    tokens: f64,
}
21
22impl Default for RateLimiter {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl RateLimiter {
29    pub fn new() -> Self {
30        Self {
31            buckets: DashMap::new(),
32            default_rate: 60, // 60 calls/minute default
33            agent_rates: DashMap::new(),
34        }
35    }
36
37    /// Check if a call is within rate limits. Consumes a token if allowed.
38    #[inline]
39    pub fn check(&self, agent_id: &str, tool_name: &str) -> bool {
40        let key = bucket_key(agent_id, tool_name);
41        let rate = self
42            .agent_rates
43            .get(agent_id)
44            .map(|r| *r)
45            .unwrap_or(self.default_rate);
46        let mut bucket = self.buckets.entry(key).or_insert_with(|| TokenBucket {
47            tokens: rate as f64,
48            max_tokens: rate as f64,
49            refill_rate: rate as f64 / 60.0,
50            last_refill: Instant::now(),
51        });
52
53        // Refill tokens based on elapsed time
54        let now = Instant::now();
55        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
56        bucket.tokens = (bucket.tokens + elapsed * bucket.refill_rate).min(bucket.max_tokens);
57        bucket.last_refill = now;
58
59        // Try to consume a token
60        if bucket.tokens >= 1.0 {
61            bucket.tokens -= 1.0;
62            true
63        } else {
64            false
65        }
66    }
67
68    /// Set rate limit for a specific agent (calls per minute).
69    ///
70    /// Applies immediately to existing buckets and is stored so that
71    /// future buckets for this agent also use the new rate.
72    pub fn set_rate(&self, agent_id: &str, calls_per_minute: u64) {
73        let new_max = calls_per_minute as f64;
74        // Store the per-agent rate for future bucket creation
75        self.agent_rates
76            .insert(agent_id.to_string(), calls_per_minute);
77        // Update all existing buckets for this agent
78        let prefix = format!("{agent_id}\x1f");
79        for mut entry in self.buckets.iter_mut() {
80            if entry.key().starts_with(&prefix) {
81                let bucket = entry.value_mut();
82                bucket.max_tokens = new_max;
83                bucket.refill_rate = new_max / 60.0;
84                // Clamp current tokens so a lowered limit takes effect immediately
85                bucket.tokens = bucket.tokens.min(new_max);
86            }
87        }
88    }
89}
90
/// Build a bucket key from agent + tool using ASCII unit separator.
///
/// Produces `"{agent_id}\x1f{tool_name}"`. The buffer is sized up front and
/// filled with infallible pushes, so building a key allocates exactly once
/// and there is no formatting `Result` to discard.
///
/// NOTE(review): the encoding is unambiguous only if neither id contains
/// `\x1f`; e.g. ("a\x1fb", "c") and ("a", "b\x1fc") yield the same key.
/// Confirm ids are validated upstream.
fn bucket_key(agent_id: &str, tool_name: &str) -> String {
    let mut key = String::with_capacity(agent_id.len() + 1 + tool_name.len());
    key.push_str(agent_id);
    key.push('\x1f');
    key.push_str(tool_name);
    key
}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn basic_rate_limit() {
        let limiter = RateLimiter::new();
        // Should allow 60 calls (default bucket)
        for _ in 0..60 {
            assert!(limiter.check("agent", "tool"));
        }
        // 61st should be denied
        assert!(!limiter.check("agent", "tool"));
    }

    #[test]
    fn different_agents_separate_buckets() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            assert!(limiter.check("agent-a", "tool"));
        }
        // agent-a is exhausted...
        assert!(!limiter.check("agent-a", "tool"));
        // ...but agent-b should still have tokens
        assert!(limiter.check("agent-b", "tool"));
    }

    #[test]
    fn different_tools_separate_buckets() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            limiter.check("agent", "tool_a");
        }
        assert!(!limiter.check("agent", "tool_a"));
        // Same agent, different tool still has tokens
        assert!(limiter.check("agent", "tool_b"));
    }

    #[test]
    fn set_rate_lowers_limit() {
        let limiter = RateLimiter::new();
        // Prime the bucket for agent — consumes 1 token, leaving 59
        assert!(limiter.check("agent", "tool"));
        // Lower rate to 10/min — clamps current tokens from 59 to 10
        limiter.set_rate("agent", 10);
        // Should allow exactly 10 more calls (tokens clamped to 10)
        let mut allowed = 0;
        for _ in 0..20 {
            if limiter.check("agent", "tool") {
                allowed += 1;
            } else {
                break;
            }
        }
        assert_eq!(allowed, 10);
    }

    #[test]
    fn set_rate_raise_does_not_grant_tokens_immediately() {
        let limiter = RateLimiter::new();
        for _ in 0..60 {
            assert!(limiter.check("agent", "tool"));
        }
        assert!(!limiter.check("agent", "tool"));
        // A higher limit raises capacity and refill speed, not current tokens.
        limiter.set_rate("agent", 120);
        assert!(!limiter.check("agent", "tool"));
    }

    #[test]
    fn set_rate_does_not_affect_other_agents() {
        let limiter = RateLimiter::new();
        // Prime both agents
        assert!(limiter.check("agent-a", "tool"));
        assert!(limiter.check("agent-b", "tool"));

        limiter.set_rate("agent-a", 5);

        // agent-b should still have default rate
        let mut count = 0;
        for _ in 0..59 {
            if limiter.check("agent-b", "tool") {
                count += 1;
            }
        }
        assert_eq!(count, 59); // 60 - 1 (initial) = 59 remaining
    }

    #[test]
    fn set_rate_before_any_check() {
        let limiter = RateLimiter::new();
        // set_rate stores the rate so future buckets use it
        limiter.set_rate("nobody", 5);
        // First check should create a bucket with rate=5, not the default 60
        let mut count = 0;
        for _ in 0..10 {
            if limiter.check("nobody", "tool") {
                count += 1;
            }
        }
        assert_eq!(count, 5);
    }

    #[test]
    fn token_refill_over_time() {
        let limiter = RateLimiter::new();
        // Exhaust all tokens
        for _ in 0..60 {
            assert!(limiter.check("agent", "tool"));
        }
        assert!(!limiter.check("agent", "tool"));

        // Simulate one second passing by back-dating the bucket's refill stamp
        // instead of sleeping. At the default 60 calls/min the bucket refills
        // one token per second. The guard lives in its own scope so the shard
        // lock is released before `check` needs it again.
        {
            let mut bucket = limiter
                .buckets
                .get_mut(&bucket_key("agent", "tool"))
                .expect("bucket was created by check");
            let earlier = bucket
                .last_refill
                .checked_sub(Duration::from_secs(1))
                .expect("monotonic clock is at least one second past its origin");
            bucket.last_refill = earlier;
        }

        // Exactly one whole token has been refilled.
        assert!(limiter.check("agent", "tool"));
        assert!(!limiter.check("agent", "tool"));
    }
}