1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::num::NonZeroU32;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use governor::clock::{Clock, DefaultClock};
14use governor::state::keyed::DashMapStateStore;
15use governor::{Quota, RateLimiter};
16use lru::LruCache;
17use parking_lot::RwLock;
18use std::num::NonZeroUsize;
19use tracing::{debug, warn};
20
21use super::api::{
22 BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
23};
24use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
25use super::error::{PowError, RateLimitError};
26use super::extractors::AddressKey;
27use super::pow::PowCounterStore;
28use crate::prelude::*;
29
30type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
32
33struct TierLimiters {
35 short_term: Arc<KeyedLimiter>,
36 long_term: Arc<KeyedLimiter>,
37}
38
39impl TierLimiters {
40 fn new(config: &RateLimitTierConfig) -> Self {
41 let short_quota =
43 Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
44 let short_term = Arc::new(RateLimiter::keyed(short_quota));
45
46 let rps = config.long_term_rph.get() as f64 / 3600.0;
49 let period_nanos = (1_000_000_000.0 / rps) as u64;
50 const ONE: NonZeroU32 = match NonZeroU32::new(1) {
52 Some(v) => v,
53 None => unreachable!(),
54 };
55 let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
56 .unwrap_or_else(|| Quota::per_second(ONE))
57 .allow_burst(config.long_term_burst);
58 let long_term = Arc::new(RateLimiter::keyed(long_quota));
59
60 Self { short_term, long_term }
61 }
62
63 fn check(&self, key: &AddressKey) -> Result<(), Duration> {
65 if let Err(not_until) = self.short_term.check_key(key) {
67 return Err(not_until.wait_time_from(DefaultClock::default().now()));
68 }
69
70 if let Err(not_until) = self.long_term.check_key(key) {
72 return Err(not_until.wait_time_from(DefaultClock::default().now()));
73 }
74
75 Ok(())
76 }
77}
78
79struct CategoryLimiters {
81 ipv4_individual: TierLimiters,
82 ipv4_network: TierLimiters,
83 ipv6_subnet: TierLimiters,
84 ipv6_provider: TierLimiters,
85}
86
87impl CategoryLimiters {
88 fn new(config: &EndpointCategoryConfig) -> Self {
89 Self {
90 ipv4_individual: TierLimiters::new(&config.ipv4_individual),
91 ipv4_network: TierLimiters::new(&config.ipv4_network),
92 ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
93 ipv6_provider: TierLimiters::new(&config.ipv6_provider),
94 }
95 }
96
97 fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
99 let keys = AddressKey::extract_all(addr);
100
101 for key in keys {
102 let limiter = self.get_limiter_for_key(&key);
103 if let Err(wait_time) = limiter.check(&key) {
104 return Err(RateLimitError::RateLimited {
105 level: key.level_name(),
106 retry_after: wait_time,
107 });
108 }
109 }
110
111 Ok(())
112 }
113
114 fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
115 match key {
116 AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
117 AddressKey::Ipv4Network(_) => &self.ipv4_network,
118 AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
119 AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
120 }
121 }
122}
123
124#[derive(Debug, Clone, Default)]
126struct PenaltyEntry {
127 count: u32,
128 last_penalty: Option<Instant>,
129 reason: Option<PenaltyReason>,
130}
131
132pub struct RateLimitManager {
134 categories: HashMap<String, CategoryLimiters>,
136 bans: RwLock<LruCache<AddressKey, BanEntry>>,
138 penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
140 pow_store: PowCounterStore,
142 total_limited: AtomicU64,
144 total_bans: AtomicU64,
145}
146
147impl RateLimitManager {
148 pub fn new(config: RateLimitConfig) -> Self {
150 let mut categories = HashMap::new();
151
152 categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
154 categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
155 categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
156 categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
157
158 const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
160 Some(v) => v,
161 None => unreachable!(),
162 };
163 const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
164 Some(v) => v,
165 None => unreachable!(),
166 };
167 let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(TEN_THOUSAND);
168 let penalty_cap = NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(TWENTY_THOUSAND);
169
170 Self {
171 categories,
172 bans: RwLock::new(LruCache::new(ban_cap)),
173 penalties: RwLock::new(LruCache::new(penalty_cap)),
174 pow_store: PowCounterStore::new(PowConfig::default()),
175 total_limited: AtomicU64::new(0),
176 total_bans: AtomicU64::new(0),
177 }
178 }
179
180 pub fn with_pow_config(config: RateLimitConfig, pow_config: PowConfig) -> Self {
182 let mut manager = Self::new(config);
183 manager.pow_store = PowCounterStore::new(pow_config);
184 manager
185 }
186
187 pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
189 if let Some(ban) = self.check_ban(addr) {
191 return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
192 }
193
194 let cat_limiters = self
196 .categories
197 .get(category)
198 .ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
199
200 if let Err(e) = cat_limiters.check(addr) {
201 self.total_limited.fetch_add(1, Ordering::Relaxed);
202 return Err(e);
203 }
204
205 Ok(())
206 }
207
208 fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
210 let keys = AddressKey::extract_all(addr);
211 let mut bans = self.bans.write();
212
213 for key in keys {
214 if let Some(ban) = bans.get(&key) {
215 if ban.is_expired() {
216 bans.pop(&key);
217 } else {
218 return Some(ban.clone());
219 }
220 }
221 }
222
223 None
224 }
225
226 fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
228 let key = AddressKey::from_ip_individual(addr);
229 let mut penalties = self.penalties.write();
230
231 let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
232 entry.count = entry.count.saturating_add(amount);
233 entry.last_penalty = Some(Instant::now());
234 entry.reason = Some(reason);
235
236 if entry.count >= reason.failures_to_ban() {
238 drop(penalties);
239 if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
240 warn!("Failed to auto-ban address: {}", e);
241 }
242 }
243 }
244}
245
246impl Default for RateLimitManager {
247 fn default() -> Self {
248 Self::new(RateLimitConfig::default())
249 }
250}
251
252impl RateLimitApi for RateLimitManager {
253 fn get_status(
254 &self,
255 addr: &IpAddr,
256 category: &str,
257 ) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
258 let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
259
260 let keys = AddressKey::extract_all(addr);
261 let bans = self.bans.read();
262
263 let statuses =
264 keys.into_iter()
265 .map(|key| {
266 let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
267 let ban_expires = bans.peek(&key).and_then(|b| {
268 if b.is_expired() {
269 None
270 } else {
271 Some(b.expires_at.unwrap_or_else(|| {
272 Instant::now() + Duration::from_secs(86400 * 365)
273 }))
274 }
275 });
276
277 let status = RateLimitStatus {
278 is_limited: false, remaining: None,
280 reset_at: None,
281 quota: 0,
282 is_banned,
283 ban_expires_at: ban_expires,
284 };
285
286 (key, status)
287 })
288 .collect();
289
290 Ok(statuses)
291 }
292
293 fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
294 debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
295 self.record_penalty(addr, reason, amount);
296 Ok(())
297 }
298
299 fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
300 let key = AddressKey::from_ip_individual(addr);
301 let mut penalties = self.penalties.write();
302
303 if let Some(entry) = penalties.get_mut(&key) {
304 entry.count = entry.count.saturating_sub(amount);
305 if entry.count == 0 {
306 penalties.pop(&key);
307 }
308 }
309
310 Ok(())
311 }
312
313 fn reset(&self, addr: &IpAddr) -> ClResult<()> {
314 let keys = AddressKey::extract_all(addr);
315
316 let mut penalties = self.penalties.write();
318 for key in &keys {
319 penalties.pop(key);
320 }
321 drop(penalties);
322
323 let mut bans = self.bans.write();
325 for key in &keys {
326 bans.pop(key);
327 }
328
329 self.pow_store.decrement(addr, u32::MAX);
331
332 Ok(())
333 }
334
335 fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
336 let keys = AddressKey::extract_all(addr);
337 let now = Instant::now();
338 let expires_at = Some(now + duration);
339
340 let mut bans = self.bans.write();
341 for key in keys {
342 let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
343 bans.put(key, entry);
344 }
345
346 self.total_bans.fetch_add(1, Ordering::Relaxed);
347 debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
348
349 Ok(())
350 }
351
352 fn unban(&self, addr: &IpAddr) -> ClResult<()> {
353 let keys = AddressKey::extract_all(addr);
354 let mut bans = self.bans.write();
355
356 for key in keys {
357 bans.pop(&key);
358 }
359
360 Ok(())
361 }
362
363 fn is_banned(&self, addr: &IpAddr) -> bool {
364 self.check_ban(addr).is_some()
365 }
366
367 fn list_bans(&self) -> Vec<BanEntry> {
368 self.bans
369 .read()
370 .iter()
371 .filter(|(_, b)| !b.is_expired())
372 .map(|(_, b)| b.clone())
373 .collect()
374 }
375
376 fn stats(&self) -> RateLimiterStats {
377 let tracked = self
379 .categories
380 .values()
381 .map(|c| {
382 c.ipv4_individual.short_term.len()
383 + c.ipv4_network.short_term.len()
384 + c.ipv6_subnet.short_term.len()
385 + c.ipv6_provider.short_term.len()
386 })
387 .sum();
388
389 RateLimiterStats {
390 tracked_addresses: tracked,
391 active_bans: self.bans.read().len(),
392 total_requests_limited: self.total_limited.load(Ordering::Relaxed),
393 total_bans_issued: self.total_bans.load(Ordering::Relaxed),
394 pow_individual_entries: self.pow_store.individual_count(),
395 pow_network_entries: self.pow_store.network_count(),
396 }
397 }
398
399 fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
400 self.pow_store.get_requirement(addr)
401 }
402
403 fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
404 self.pow_store.increment(addr, reason);
405 Ok(())
406 }
407
408 fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
409 self.pow_store.decrement(addr, amount);
410 Ok(())
411 }
412
413 fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
414 self.pow_store.verify(addr, token)
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// The client address exercised by every test in this module.
    fn client() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
    }

    #[test]
    fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::default();
        for name in ["auth", "federation", "general", "websocket"] {
            assert!(manager.categories.contains_key(name));
        }
    }

    #[test]
    fn test_rate_limit_check() {
        let manager = RateLimitManager::default();
        let ip = client();

        // Five consecutive requests must all be admitted under the default config.
        (0..5).for_each(|_| assert!(manager.check(&ip, "general").is_ok()));
    }

    #[test]
    fn test_unknown_category() {
        let outcome = RateLimitManager::default().check(&client(), "nonexistent");
        assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
    }

    #[test]
    fn test_ban_functionality() {
        let manager = RateLimitManager::default();
        let ip = client();
        assert!(!manager.is_banned(&ip));

        manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
        assert!(manager.is_banned(&ip));
        assert!(matches!(manager.check(&ip, "general"), Err(RateLimitError::Banned { .. })));

        manager.unban(&ip).unwrap();
        assert!(!manager.is_banned(&ip));
    }

    #[test]
    fn test_penalty_auto_ban() {
        let manager = RateLimitManager::default();
        let ip = client();

        // The first four auth failures must not trigger a ban...
        for _attempt in 1..=4 {
            manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
            assert!(!manager.is_banned(&ip));
        }

        // ...but the fifth one must.
        manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
        assert!(manager.is_banned(&ip));
    }

    #[test]
    fn test_pow_integration() {
        let manager = RateLimitManager::default();
        let ip = client();

        // No requirement yet: any token is accepted.
        assert_eq!(manager.get_pow_requirement(&ip), 0);
        assert!(manager.verify_pow(&ip, "any_token").is_ok());

        manager.increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure).unwrap();
        assert_eq!(manager.get_pow_requirement(&ip), 1);

        // With a requirement of 1 the plain token fails and the suffixed one passes.
        assert!(manager.verify_pow(&ip, "any_token").is_err());
        assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
    }

    #[test]
    fn test_stats() {
        let manager = RateLimitManager::default();

        let before = manager.stats();
        assert_eq!(before.active_bans, 0);
        assert_eq!(before.total_bans_issued, 0);

        manager.ban(&client(), Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

        let after = manager.stats();
        assert!(after.active_bans > 0);
        assert_eq!(after.total_bans_issued, 1);
    }
}
521
522