1pub mod config;
10pub mod state;
11pub mod store;
12
13pub use config::RateLimitConfig;
14pub use state::{RateLimitDecision, RateLimitState, TokenBucket};
15pub use store::RateLimitStore;
16
17use ahash::AHashMap;
18use parking_lot::RwLock;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::Duration;
23
/// Point-in-time snapshot of the cache and ban statistics kept by
/// [`HybridRateLimiter`]; `HybridRateLimiter::stats` returns a clone of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRateLimiterStats {
    /// Number of entries in the in-memory L1 cache, as last recorded by the limiter.
    pub l1_size: usize,
    /// Maximum number of L1 entries the limiter was configured with.
    pub l1_capacity: usize,
    /// Rate-limit checks answered from an already-cached L1 entry.
    pub l1_hits: u64,
    /// Rate-limit checks that did not find the agent in L1.
    pub l1_misses: u64,
    /// Entries evicted from L1 (and written back to the store) to make room.
    pub l1_evictions: u64,
    /// Ids of the agents that are currently banned.
    pub banned_agents: Vec<String>,
}
34
/// One slot of the L1 cache: an agent's rate-limit state plus the time it was
/// last touched, which drives least-recently-used eviction.
#[derive(Debug, Clone)]
struct L1Entry {
    /// The agent's rate-limit state.
    state: RateLimitState,
    /// When the entry was last inserted or used; the oldest entry is evicted first.
    last_accessed: std::time::Instant,
}
41
/// Two-level rate limiter: a bounded in-memory LRU cache (L1) of per-agent
/// state in front of a [`RateLimitStore`], plus an in-memory ban list.
pub struct HybridRateLimiter {
    /// Backing store that evicted state and bans are written to and loaded from.
    store: RateLimitStore,
    /// Parameters (`max_tokens`, `replenish_rate`) used to build state for unknown agents.
    config: RateLimitConfig,
    /// In-memory cache of per-agent state, keyed by agent id.
    l1: Arc<RwLock<AHashMap<String, L1Entry>>>,
    /// L1 size at which inserting a new agent first evicts the LRU entry.
    l1_capacity: usize,
    /// Banned agent ids mapped to the reason given when banning.
    banned: Arc<RwLock<HashMap<String, String>>>,
    /// Running cache and ban statistics.
    stats: Arc<RwLock<HybridRateLimiterStats>>,
}
57
58impl HybridRateLimiter {
59 pub fn new(
66 graph: &sqlitegraph::SqliteGraph,
67 config: RateLimitConfig,
68 l1_capacity: usize,
69 ) -> crate::error::Result<Self> {
70 let store = RateLimitStore::new();
71
72 let mut l1_map = AHashMap::new();
74 if let Ok(states) = store.load_all(graph) {
75 for state in states {
76 l1_map.insert(
77 state.agent_id.clone(),
78 L1Entry {
79 state,
80 last_accessed: std::time::Instant::now(),
81 },
82 );
83 }
84 }
85 let l1 = Arc::new(RwLock::new(l1_map));
86
87 let stats = HybridRateLimiterStats {
88 l1_size: l1.read().len(),
89 l1_capacity,
90 l1_hits: 0,
91 l1_misses: 0,
92 l1_evictions: 0,
93 banned_agents: Vec::new(),
94 };
95
96 Ok(Self {
97 store,
98 config,
99 l1,
100 l1_capacity,
101 banned: Arc::new(RwLock::new(HashMap::new())),
102 stats: Arc::new(RwLock::new(stats)),
103 })
104 }
105
106 pub fn check_rate_limit(
110 &self,
111 graph: &sqlitegraph::SqliteGraph,
112 agent_id: &str,
113 ) -> RateLimitDecision {
114 if self.banned.read().contains_key(agent_id) {
116 return RateLimitDecision {
117 allowed: false,
118 retry_after: Some(Duration::from_secs(3600)), };
120 }
121
122 {
124 let mut l1 = self.l1.write();
125 if let Some(entry) = l1.get_mut(agent_id) {
126 entry.last_accessed = std::time::Instant::now();
127 self.stats.write().l1_hits += 1;
128 return entry.state.check(1);
129 }
130 }
131
132 self.stats.write().l1_misses += 1;
134 if let Ok(Some(state)) = self.store.load(graph, agent_id) {
135 let mut l1 = self.l1.write();
137 if l1.len() >= self.l1_capacity {
138 self.evict_lru(graph, &mut l1);
139 }
140 l1.insert(
141 agent_id.to_string(),
142 L1Entry {
143 state,
144 last_accessed: std::time::Instant::now(),
145 },
146 );
147 self.stats.write().l1_size = l1.len();
148
149 if let Some(entry) = l1.get_mut(agent_id) {
151 return entry.state.check(1);
152 }
153 }
154
155 let mut state =
157 RateLimitState::new(agent_id, self.config.max_tokens, self.config.replenish_rate);
158 let decision = state.check(1);
159
160 let mut l1 = self.l1.write();
161 if l1.len() >= self.l1_capacity {
162 self.evict_lru(graph, &mut l1);
163 }
164 l1.insert(
165 agent_id.to_string(),
166 L1Entry {
167 state,
168 last_accessed: std::time::Instant::now(),
169 },
170 );
171 self.stats.write().l1_size = l1.len();
172
173 decision
174 }
175
176 fn evict_lru(&self, graph: &sqlitegraph::SqliteGraph, l1: &mut AHashMap<String, L1Entry>) {
178 let lru_key = l1
179 .iter()
180 .min_by_key(|(_, v)| v.last_accessed)
181 .map(|(k, _)| k.clone());
182
183 if let Some(key) = lru_key {
184 if let Some(entry) = l1.remove(&key) {
185 let _ = self.store.persist(graph, &entry.state);
187 self.stats.write().l1_evictions += 1;
188 }
189 }
190 }
191
192 pub fn replenish_all(&self, elapsed: Duration) {
194 let mut l1 = self.l1.write();
195 for entry in l1.values_mut() {
196 entry.state.replenish(elapsed);
197 }
198 }
199
200 pub fn ban_agent(
202 &self,
203 graph: &sqlitegraph::SqliteGraph,
204 agent_id: &str,
205 reason: &str,
206 ) -> crate::error::Result<()> {
207 self.banned
208 .write()
209 .insert(agent_id.to_string(), reason.to_string());
210
211 self.l1.write().remove(agent_id);
213 let _ = self.store.persist_ban(graph, agent_id, reason);
214
215 let mut stats = self.stats.write();
217 stats.banned_agents = self.banned.read().keys().cloned().collect();
218 stats.l1_size = self.l1.read().len();
219
220 Ok(())
221 }
222
223 pub fn unban_agent(
225 &self,
226 graph: &sqlitegraph::SqliteGraph,
227 agent_id: &str,
228 ) -> crate::error::Result<()> {
229 self.banned.write().remove(agent_id);
230 let _ = self.store.remove_ban(graph, agent_id);
231
232 let mut stats = self.stats.write();
234 stats.banned_agents = self.banned.read().keys().cloned().collect();
235
236 Ok(())
237 }
238
239 pub fn stats(&self) -> HybridRateLimiterStats {
241 self.stats.read().clone()
242 }
243
244 pub fn store(&self) -> &RateLimitStore {
246 &self.store
247 }
248
249 #[doc(hidden)]
251 pub fn clear_l1_for_testing(&self) {
252 self.l1.write().clear();
253 self.stats.write().l1_size = 0;
254 }
255}