1use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use lru::LruCache;
20use parking_lot::RwLock;
21use std::num::NonZeroUsize;
22use tracing::{debug, warn};
23
24use super::api::{
25 BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
26};
27use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
28use super::error::{PowError, RateLimitError};
29use super::extractors::AddressKey;
30use super::pow::PowCounterStore;
31use crate::prelude::*;
32
/// A keyed `governor` rate limiter that keeps one state entry per [`AddressKey`] in a
/// `DashMapStateStore` and reads time from governor's `DefaultClock`.
type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
35
/// The pair of limiters applied to a single address tier.
///
/// The pair is a per-second limiter for bursts and an hourly-rate limiter for sustained volume
/// (see `TierLimiters::new`).
struct TierLimiters {
    /// Built from `short_term_rps` / `short_term_burst`; consulted first by `check`.
    short_term: Arc<KeyedLimiter>,
    /// Built from `long_term_rph` / `long_term_burst`; consulted only if `short_term` admits.
    long_term: Arc<KeyedLimiter>,
}
41
42impl TierLimiters {
43 const ONE: NonZeroU32 = match NonZeroU32::new(1) {
45 Some(v) => v,
46 None => unreachable!(),
47 };
48
49 fn new(config: &RateLimitTierConfig) -> Self {
50 let short_quota =
52 Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
53 let short_term = Arc::new(RateLimiter::keyed(short_quota));
54
55 let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
59 let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
60 .unwrap_or_else(|| Quota::per_second(Self::ONE))
61 .allow_burst(config.long_term_burst);
62 let long_term = Arc::new(RateLimiter::keyed(long_quota));
63
64 Self { short_term, long_term }
65 }
66
67 fn check(&self, key: &AddressKey) -> Result<(), Duration> {
69 if let Err(not_until) = self.short_term.check_key(key) {
71 return Err(not_until.wait_time_from(DefaultClock::default().now()));
72 }
73
74 if let Err(not_until) = self.long_term.check_key(key) {
76 return Err(not_until.wait_time_from(DefaultClock::default().now()));
77 }
78
79 Ok(())
80 }
81}
82
/// Limiters for one endpoint category, with one [`TierLimiters`] per address-key tier.
///
/// `check` consults the tier matching each key returned by `AddressKey::extract_all`.
struct CategoryLimiters {
    /// Used for `AddressKey::Ipv4Individual` keys.
    ipv4_individual: TierLimiters,
    /// Used for `AddressKey::Ipv4Network` keys.
    ipv4_network: TierLimiters,
    /// Used for `AddressKey::Ipv6Subnet` keys.
    ipv6_subnet: TierLimiters,
    /// Used for `AddressKey::Ipv6Provider` keys.
    ipv6_provider: TierLimiters,
}
90
91impl CategoryLimiters {
92 fn new(config: &EndpointCategoryConfig) -> Self {
93 Self {
94 ipv4_individual: TierLimiters::new(&config.ipv4_individual),
95 ipv4_network: TierLimiters::new(&config.ipv4_network),
96 ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
97 ipv6_provider: TierLimiters::new(&config.ipv6_provider),
98 }
99 }
100
101 fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
103 let keys = AddressKey::extract_all(addr);
104
105 for key in keys {
106 let limiter = self.get_limiter_for_key(&key);
107 if let Err(wait_time) = limiter.check(&key) {
108 return Err(RateLimitError::RateLimited {
109 level: key.level_name(),
110 retry_after: wait_time,
111 });
112 }
113 }
114
115 Ok(())
116 }
117
118 fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
119 match key {
120 AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
121 AddressKey::Ipv4Network(_) => &self.ipv4_network,
122 AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
123 AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
124 }
125 }
126}
127
/// The accumulated penalty score for a single address key.
#[derive(Debug, Clone, Default)]
struct PenaltyEntry {
    /// Saturating sum of penalty amounts, lowered by `grant`.
    count: u32,
    /// When the most recent penalty was recorded; `None` until the first one.
    /// NOTE(review): written by `record_penalty` but not read anywhere in this file.
    last_penalty: Option<Instant>,
    /// Reason attached to the most recent penalty.
    reason: Option<PenaltyReason>,
}
135
/// Provides per-address rate limiting, banning, penalty scoring and proof-of-work requirements.
///
/// Requests are checked against the limiters of a named endpoint category. Bans and penalty
/// scores live in bounded `LruCache`s behind `RwLock`s.
pub struct RateLimitManager {
    /// Limiters per endpoint category name (`auth`, `federation`, `general`, `websocket`).
    categories: HashMap<String, CategoryLimiters>,
    /// Bans, one entry per banned address key; expired entries linger until popped or evicted.
    bans: RwLock<LruCache<AddressKey, BanEntry>>,
    /// Penalty scores, keyed by the individual-address key.
    penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
    /// Proof-of-work counters and token verification.
    pow_store: PowCounterStore,
    /// Number of `check` calls rejected by a rate limiter (ban rejections are not counted).
    total_limited: AtomicU64,
    /// Number of `ban` calls, including auto-bans issued by penalty scoring.
    total_bans: AtomicU64,
}
150
151impl RateLimitManager {
152 const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
154 Some(v) => v,
155 None => unreachable!(),
156 };
157 const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
158 Some(v) => v,
159 None => unreachable!(),
160 };
161
162 pub fn new(config: &RateLimitConfig) -> Self {
164 let mut categories = HashMap::new();
165
166 categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
168 categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
169 categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
170 categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
171
172 let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
173 let penalty_cap =
174 NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
175
176 Self {
177 categories,
178 bans: RwLock::new(LruCache::new(ban_cap)),
179 penalties: RwLock::new(LruCache::new(penalty_cap)),
180 pow_store: PowCounterStore::new(PowConfig::default()),
181 total_limited: AtomicU64::new(0),
182 total_bans: AtomicU64::new(0),
183 }
184 }
185
186 pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
188 let mut manager = Self::new(config);
189 manager.pow_store = PowCounterStore::new(pow_config);
190 manager
191 }
192
193 pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
195 if let Some(ban) = self.check_ban(addr) {
197 return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
198 }
199
200 let cat_limiters = self
202 .categories
203 .get(category)
204 .ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
205
206 if let Err(e) = cat_limiters.check(addr) {
207 self.total_limited.fetch_add(1, Ordering::Relaxed);
208 return Err(e);
209 }
210
211 Ok(())
212 }
213
214 fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
216 let keys = AddressKey::extract_all(addr);
217 let mut bans = self.bans.write();
218
219 for key in keys {
220 if let Some(ban) = bans.get(&key) {
221 if ban.is_expired() {
222 bans.pop(&key);
223 } else {
224 return Some(ban.clone());
225 }
226 }
227 }
228
229 None
230 }
231
232 fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
234 let key = AddressKey::from_ip_individual(addr);
235 let mut penalties = self.penalties.write();
236
237 let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
238 entry.count = entry.count.saturating_add(amount);
239 entry.last_penalty = Some(Instant::now());
240 entry.reason = Some(reason);
241
242 if entry.count >= reason.failures_to_ban() {
244 drop(penalties);
245 if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
246 warn!("Failed to auto-ban address: {}", e);
247 }
248 }
249 }
250}
251
252impl Default for RateLimitManager {
253 fn default() -> Self {
254 Self::new(&RateLimitConfig::default())
255 }
256}
257
258impl RateLimitApi for RateLimitManager {
259 fn get_status(
260 &self,
261 addr: &IpAddr,
262 category: &str,
263 ) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
264 let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
265
266 let keys = AddressKey::extract_all(addr);
267 let bans = self.bans.read();
268
269 let statuses =
270 keys.into_iter()
271 .map(|key| {
272 let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
273 let ban_expires = bans.peek(&key).and_then(|b| {
274 if b.is_expired() {
275 None
276 } else {
277 Some(b.expires_at.unwrap_or_else(|| {
278 Instant::now() + Duration::from_secs(86400 * 365)
279 }))
280 }
281 });
282
283 let status = RateLimitStatus {
284 is_limited: false, remaining: None,
286 reset_at: None,
287 quota: 0,
288 is_banned,
289 ban_expires_at: ban_expires,
290 };
291
292 (key, status)
293 })
294 .collect();
295
296 Ok(statuses)
297 }
298
299 fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
300 debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
301 self.record_penalty(addr, reason, amount);
302 Ok(())
303 }
304
305 fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
306 let key = AddressKey::from_ip_individual(addr);
307 let mut penalties = self.penalties.write();
308
309 if let Some(entry) = penalties.get_mut(&key) {
310 entry.count = entry.count.saturating_sub(amount);
311 if entry.count == 0 {
312 penalties.pop(&key);
313 }
314 }
315
316 Ok(())
317 }
318
319 fn reset(&self, addr: &IpAddr) -> ClResult<()> {
320 let keys = AddressKey::extract_all(addr);
321
322 let mut penalties = self.penalties.write();
324 for key in &keys {
325 penalties.pop(key);
326 }
327 drop(penalties);
328
329 let mut bans = self.bans.write();
331 for key in &keys {
332 bans.pop(key);
333 }
334
335 self.pow_store.decrement(addr, u32::MAX);
337
338 Ok(())
339 }
340
341 fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
342 let keys = AddressKey::extract_all(addr);
343 let now = Instant::now();
344 let expires_at = Some(now + duration);
345
346 let mut bans = self.bans.write();
347 for key in keys {
348 let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
349 bans.put(key, entry);
350 }
351
352 self.total_bans.fetch_add(1, Ordering::Relaxed);
353 debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
354
355 Ok(())
356 }
357
358 fn unban(&self, addr: &IpAddr) -> ClResult<()> {
359 let keys = AddressKey::extract_all(addr);
360 let mut bans = self.bans.write();
361
362 for key in keys {
363 bans.pop(&key);
364 }
365
366 Ok(())
367 }
368
369 fn is_banned(&self, addr: &IpAddr) -> bool {
370 self.check_ban(addr).is_some()
371 }
372
373 fn list_bans(&self) -> Vec<BanEntry> {
374 self.bans
375 .read()
376 .iter()
377 .filter(|(_, b)| !b.is_expired())
378 .map(|(_, b)| b.clone())
379 .collect()
380 }
381
382 fn stats(&self) -> RateLimiterStats {
383 let tracked = self
385 .categories
386 .values()
387 .map(|c| {
388 c.ipv4_individual.short_term.len()
389 + c.ipv4_network.short_term.len()
390 + c.ipv6_subnet.short_term.len()
391 + c.ipv6_provider.short_term.len()
392 })
393 .sum();
394
395 RateLimiterStats {
396 tracked_addresses: tracked,
397 active_bans: self.bans.read().len(),
398 total_requests_limited: self.total_limited.load(Ordering::Relaxed),
399 total_bans_issued: self.total_bans.load(Ordering::Relaxed),
400 pow_individual_entries: self.pow_store.individual_count(),
401 pow_network_entries: self.pow_store.network_count(),
402 }
403 }
404
405 fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
406 self.pow_store.get_requirement(addr)
407 }
408
409 fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
410 self.pow_store.increment(addr, reason);
411 Ok(())
412 }
413
414 fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
415 self.pow_store.decrement(addr, amount);
416 Ok(())
417 }
418
419 fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
420 self.pow_store.verify(addr, token)
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// The single client address used by every test in this module.
    fn client_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
    }

    /// The default manager registers all four endpoint categories.
    #[test]
    fn test_rate_limit_manager_creation() {
        let limiter = RateLimitManager::default();
        for category in ["auth", "federation", "general", "websocket"] {
            assert!(limiter.categories.contains_key(category));
        }
    }

    /// A handful of requests fits within the default `general` quota.
    #[test]
    fn test_rate_limit_check() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        (0..5).for_each(|_| assert!(limiter.check(&addr, "general").is_ok()));
    }

    /// Unregistered categories are rejected explicitly.
    #[test]
    fn test_unknown_category() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        let outcome = limiter.check(&addr, "nonexistent");
        assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
    }

    /// Banning blocks `check`, and unbanning lifts the block.
    #[test]
    fn test_ban_functionality() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        assert!(!limiter.is_banned(&addr));

        limiter.ban(&addr, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
        assert!(limiter.is_banned(&addr));
        assert!(matches!(limiter.check(&addr, "general"), Err(RateLimitError::Banned { .. })));

        limiter.unban(&addr).unwrap();
        assert!(!limiter.is_banned(&addr));
    }

    /// The fifth auth-failure penalty triggers an automatic ban.
    #[test]
    fn test_penalty_auto_ban() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        (0..4).for_each(|_| {
            limiter.penalize(&addr, PenaltyReason::AuthFailure, 1).unwrap();
            assert!(!limiter.is_banned(&addr));
        });

        limiter.penalize(&addr, PenaltyReason::AuthFailure, 1).unwrap();
        assert!(limiter.is_banned(&addr));
    }

    /// PoW is not required until the counter is bumped, after which tokens are checked.
    #[test]
    fn test_pow_integration() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        assert_eq!(limiter.get_pow_requirement(&addr), 0);
        assert!(limiter.verify_pow(&addr, "any_token").is_ok());

        limiter.increment_pow_counter(&addr, PowPenaltyReason::ConnSignatureFailure).unwrap();
        assert_eq!(limiter.get_pow_requirement(&addr), 1);

        assert!(limiter.verify_pow(&addr, "any_token").is_err());
        assert!(limiter.verify_pow(&addr, "any_tokenA").is_ok());
    }

    /// Ban counters start at zero and reflect a single issued ban.
    #[test]
    fn test_stats() {
        let limiter = RateLimitManager::default();
        let addr = client_ip();

        let before = limiter.stats();
        assert_eq!(before.active_bans, 0);
        assert_eq!(before.total_bans_issued, 0);

        limiter.ban(&addr, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

        let after = limiter.stats();
        assert!(after.active_bans > 0);
        assert_eq!(after.total_bans_issued, 1);
    }
}
527
528