Skip to main content

envoy/rate_limit/
mod.rs

//! Hybrid rate limiter — L1 in-memory (AHashMap) + L2 (sqlitegraph) persistence.
//!
//! Per-agent token buckets are cached in an `ahash`-backed map for fast lookups.
//! L1 misses fall back to sqlitegraph, which also carries state across restarts.
//!
//! Follows the AgentRegistry pattern: an L1 in-memory cache backed by sqlitegraph
//! and warmed from the DB on startup. In this module, bucket state is flushed to
//! sqlitegraph when an entry is evicted from L1, and bans are persisted as soon as
//! they are issued. The graph is passed as a parameter to the methods that need it.
8
9pub mod config;
10pub mod state;
11pub mod store;
12
13pub use config::RateLimitConfig;
14pub use state::{RateLimitDecision, RateLimitState, TokenBucket};
15pub use store::RateLimitStore;
16
17use ahash::AHashMap;
18use parking_lot::RwLock;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::Duration;
23
24/// Statistics for the hybrid rate limiter.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct HybridRateLimiterStats {
27    pub l1_size: usize,
28    pub l1_capacity: usize,
29    pub l1_hits: u64,
30    pub l1_misses: u64,
31    pub l1_evictions: u64,
32    pub banned_agents: Vec<String>,
33}
34
35/// Per-agent state with LRU tracking.
36#[derive(Debug, Clone)]
37struct L1Entry {
38    state: RateLimitState,
39    last_accessed: std::time::Instant,
40}
41
42/// Hybrid rate limiter with L1 (AHashMap) + L2 (sqlitegraph).
43///
44/// L1: In-memory cache using ahash::AHashMap for fast lookups.
45/// L2: sqlitegraph persistence for durability and L1 misses.
46///
47/// Follows AgentRegistry pattern: methods take `&sqlitegraph::SqliteGraph` parameter
48/// for L2 operations, rather than storing the graph internally.
49pub struct HybridRateLimiter {
50    store: RateLimitStore,
51    config: RateLimitConfig,
52    l1: Arc<RwLock<AHashMap<String, L1Entry>>>,
53    l1_capacity: usize,
54    banned: Arc<RwLock<HashMap<String, String>>>, // agent_id -> reason
55    stats: Arc<RwLock<HybridRateLimiterStats>>,
56}
57
58impl HybridRateLimiter {
59    /// Create a new hybrid rate limiter.
60    ///
61    /// # Arguments
62    /// * `graph` — sqlitegraph database for L2 persistence (used during initialization)
63    /// * `config` — Rate limit configuration
64    /// * `l1_capacity` — Maximum number of agents to cache in L1
65    pub fn new(
66        graph: &sqlitegraph::SqliteGraph,
67        config: RateLimitConfig,
68        l1_capacity: usize,
69    ) -> crate::error::Result<Self> {
70        let store = RateLimitStore::new();
71
72        // Load existing states from L2 into L1 (cache warming)
73        let mut l1_map = AHashMap::new();
74        if let Ok(states) = store.load_all(graph) {
75            for state in states {
76                l1_map.insert(
77                    state.agent_id.clone(),
78                    L1Entry {
79                        state,
80                        last_accessed: std::time::Instant::now(),
81                    },
82                );
83            }
84        }
85        let l1 = Arc::new(RwLock::new(l1_map));
86
87        let stats = HybridRateLimiterStats {
88            l1_size: l1.read().len(),
89            l1_capacity,
90            l1_hits: 0,
91            l1_misses: 0,
92            l1_evictions: 0,
93            banned_agents: Vec::new(),
94        };
95
96        Ok(Self {
97            store,
98            config,
99            l1,
100            l1_capacity,
101            banned: Arc::new(RwLock::new(HashMap::new())),
102            stats: Arc::new(RwLock::new(stats)),
103        })
104    }
105
106    /// Check if a request is allowed for the given agent.
107    ///
108    /// Takes graph as parameter for L2 fallback.
109    pub fn check_rate_limit(
110        &self,
111        graph: &sqlitegraph::SqliteGraph,
112        agent_id: &str,
113    ) -> RateLimitDecision {
114        // Check ban list first
115        if self.banned.read().contains_key(agent_id) {
116            return RateLimitDecision {
117                allowed: false,
118                retry_after: Some(Duration::from_secs(3600)), // 1 hour
119            };
120        }
121
122        // Try L1 first
123        {
124            let mut l1 = self.l1.write();
125            if let Some(entry) = l1.get_mut(agent_id) {
126                entry.last_accessed = std::time::Instant::now();
127                self.stats.write().l1_hits += 1;
128                return entry.state.check(1);
129            }
130        }
131
132        // L1 miss — try L2
133        self.stats.write().l1_misses += 1;
134        if let Ok(Some(state)) = self.store.load(graph, agent_id) {
135            // Promote to L1
136            let mut l1 = self.l1.write();
137            if l1.len() >= self.l1_capacity {
138                self.evict_lru(graph, &mut l1);
139            }
140            l1.insert(
141                agent_id.to_string(),
142                L1Entry {
143                    state,
144                    last_accessed: std::time::Instant::now(),
145                },
146            );
147            self.stats.write().l1_size = l1.len();
148
149            // Check the newly loaded state
150            if let Some(entry) = l1.get_mut(agent_id) {
151                return entry.state.check(1);
152            }
153        }
154
155        // Create new state
156        let mut state =
157            RateLimitState::new(agent_id, self.config.max_tokens, self.config.replenish_rate);
158        let decision = state.check(1);
159
160        let mut l1 = self.l1.write();
161        if l1.len() >= self.l1_capacity {
162            self.evict_lru(graph, &mut l1);
163        }
164        l1.insert(
165            agent_id.to_string(),
166            L1Entry {
167                state,
168                last_accessed: std::time::Instant::now(),
169            },
170        );
171        self.stats.write().l1_size = l1.len();
172
173        decision
174    }
175
176    /// Evict the least-recently-used entry from L1, flushing to L2.
177    fn evict_lru(&self, graph: &sqlitegraph::SqliteGraph, l1: &mut AHashMap<String, L1Entry>) {
178        let lru_key = l1
179            .iter()
180            .min_by_key(|(_, v)| v.last_accessed)
181            .map(|(k, _)| k.clone());
182
183        if let Some(key) = lru_key {
184            if let Some(entry) = l1.remove(&key) {
185                // Flush to L2
186                let _ = self.store.persist(graph, &entry.state);
187                self.stats.write().l1_evictions += 1;
188            }
189        }
190    }
191
192    /// Replenish tokens for all cached agents.
193    pub fn replenish_all(&self, elapsed: Duration) {
194        let mut l1 = self.l1.write();
195        for entry in l1.values_mut() {
196            entry.state.replenish(elapsed);
197        }
198    }
199
200    /// Ban an agent (permanent until unbanned).
201    pub fn ban_agent(
202        &self,
203        graph: &sqlitegraph::SqliteGraph,
204        agent_id: &str,
205        reason: &str,
206    ) -> crate::error::Result<()> {
207        self.banned
208            .write()
209            .insert(agent_id.to_string(), reason.to_string());
210
211        // Remove from L1 and persist ban to L2
212        self.l1.write().remove(agent_id);
213        let _ = self.store.persist_ban(graph, agent_id, reason);
214
215        // Update stats
216        let mut stats = self.stats.write();
217        stats.banned_agents = self.banned.read().keys().cloned().collect();
218        stats.l1_size = self.l1.read().len();
219
220        Ok(())
221    }
222
223    /// Unban an agent.
224    pub fn unban_agent(
225        &self,
226        graph: &sqlitegraph::SqliteGraph,
227        agent_id: &str,
228    ) -> crate::error::Result<()> {
229        self.banned.write().remove(agent_id);
230        let _ = self.store.remove_ban(graph, agent_id);
231
232        // Update stats
233        let mut stats = self.stats.write();
234        stats.banned_agents = self.banned.read().keys().cloned().collect();
235
236        Ok(())
237    }
238
239    /// Get current statistics.
240    pub fn stats(&self) -> HybridRateLimiterStats {
241        self.stats.read().clone()
242    }
243
244    /// Access the underlying store (for testing).
245    pub fn store(&self) -> &RateLimitStore {
246        &self.store
247    }
248
249    /// Clear L1 cache (for testing only).
250    #[doc(hidden)]
251    pub fn clear_l1_for_testing(&self) {
252        self.l1.write().clear();
253        self.stats.write().l1_size = 0;
254    }
255}