1use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::time::{Duration, Instant};
15
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use lru::LruCache;
20use parking_lot::RwLock;
21use std::num::NonZeroUsize;
22use tracing::{debug, warn};
23
24use super::api::{
25 BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
26};
27use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
28use super::error::{PowError, RateLimitError};
29use super::extractors::AddressKey;
30use super::pow::PowCounterStore;
31use crate::prelude::*;
32
/// A `governor` rate limiter keyed by [`AddressKey`], using the `DashMap`-backed
/// keyed state store and governor's default clock.
type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
35
/// The two keyed limiters that guard a single address tier.
///
/// `TierLimiters::new` builds `short_term` from the tier's per-second settings and
/// `long_term` from its per-hour settings; `TierLimiters::check` consults them in
/// that order.
struct TierLimiters {
    /// Limiter built from `short_term_rps` / `short_term_burst`.
    short_term: Arc<KeyedLimiter>,
    /// Limiter built from `long_term_rph` / `long_term_burst`.
    long_term: Arc<KeyedLimiter>,
}
41
42impl TierLimiters {
43 const ONE: NonZeroU32 = match NonZeroU32::new(1) {
45 Some(v) => v,
46 None => unreachable!(),
47 };
48
49 fn new(config: &RateLimitTierConfig) -> Self {
50 let short_quota =
52 Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
53 let short_term = Arc::new(RateLimiter::keyed(short_quota));
54
55 let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
59 let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
60 .unwrap_or_else(|| Quota::per_second(Self::ONE))
61 .allow_burst(config.long_term_burst);
62 let long_term = Arc::new(RateLimiter::keyed(long_quota));
63
64 Self { short_term, long_term }
65 }
66
67 fn check(&self, key: &AddressKey) -> Result<(), Duration> {
69 if let Err(not_until) = self.short_term.check_key(key) {
71 return Err(not_until.wait_time_from(DefaultClock::default().now()));
72 }
73
74 if let Err(not_until) = self.long_term.check_key(key) {
76 return Err(not_until.wait_time_from(DefaultClock::default().now()));
77 }
78
79 Ok(())
80 }
81}
82
/// Per-category set of tier limiters, one for each [`AddressKey`] variant.
struct CategoryLimiters {
    /// Limiters used for `AddressKey::Ipv4Individual` keys.
    ipv4_individual: TierLimiters,
    /// Limiters used for `AddressKey::Ipv4Network` keys.
    ipv4_network: TierLimiters,
    /// Limiters used for `AddressKey::Ipv6Subnet` keys.
    ipv6_subnet: TierLimiters,
    /// Limiters used for `AddressKey::Ipv6Provider` keys.
    ipv6_provider: TierLimiters,
}
90
91impl CategoryLimiters {
92 fn new(config: &EndpointCategoryConfig) -> Self {
93 Self {
94 ipv4_individual: TierLimiters::new(&config.ipv4_individual),
95 ipv4_network: TierLimiters::new(&config.ipv4_network),
96 ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
97 ipv6_provider: TierLimiters::new(&config.ipv6_provider),
98 }
99 }
100
101 fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
103 let keys = AddressKey::extract_all(addr);
104
105 for key in keys {
106 let limiter = self.get_limiter_for_key(&key);
107 if let Err(wait_time) = limiter.check(&key) {
108 return Err(RateLimitError::RateLimited {
109 level: key.level_name(),
110 retry_after: wait_time,
111 });
112 }
113 }
114
115 Ok(())
116 }
117
118 fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
119 match key {
120 AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
121 AddressKey::Ipv4Network(_) => &self.ipv4_network,
122 AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
123 AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
124 }
125 }
126}
127
/// Accumulated penalty state for one address key.
#[derive(Debug, Clone, Default)]
struct PenaltyEntry {
    /// Running penalty total; raised by `record_penalty` and lowered by `grant`.
    count: u32,
    /// When the most recent penalty was recorded, if any.
    last_penalty: Option<Instant>,
    /// Reason attached to the most recent penalty, if any.
    reason: Option<PenaltyReason>,
}
135
/// Central rate-limiting state: per-category limiters, bans, penalties, the
/// proof-of-work counter store, and running counters.
pub struct RateLimitManager {
    /// Limiters per endpoint category, keyed by category name (e.g. `"auth"`).
    categories: HashMap<String, CategoryLimiters>,
    /// Bans per address key, held in a bounded LRU cache.
    bans: RwLock<LruCache<AddressKey, BanEntry>>,
    /// Penalty counters per address key, held in a bounded LRU cache.
    penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
    /// Proof-of-work counter store backing the `*_pow_*` API methods.
    pow_store: PowCounterStore,
    /// Number of requests rejected by a limiter in `check`.
    total_limited: AtomicU64,
    /// Number of `ban` calls made so far.
    total_bans: AtomicU64,
}
150
151impl RateLimitManager {
152 const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
154 Some(v) => v,
155 None => unreachable!(),
156 };
157 const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
158 Some(v) => v,
159 None => unreachable!(),
160 };
161
162 pub fn new(config: &RateLimitConfig) -> Self {
164 let mut categories = HashMap::new();
165
166 categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
168 categories.insert("dav".to_string(), CategoryLimiters::new(&config.dav));
169 categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
170 categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
171 categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
172
173 let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
174 let penalty_cap =
175 NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
176
177 Self {
178 categories,
179 bans: RwLock::new(LruCache::new(ban_cap)),
180 penalties: RwLock::new(LruCache::new(penalty_cap)),
181 pow_store: PowCounterStore::new(PowConfig::default()),
182 total_limited: AtomicU64::new(0),
183 total_bans: AtomicU64::new(0),
184 }
185 }
186
187 pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
189 let mut manager = Self::new(config);
190 manager.pow_store = PowCounterStore::new(pow_config);
191 manager
192 }
193
194 pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
196 if let Some(ban) = self.check_ban(addr) {
198 return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
199 }
200
201 let cat_limiters = self
203 .categories
204 .get(category)
205 .ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
206
207 if let Err(e) = cat_limiters.check(addr) {
208 self.total_limited.fetch_add(1, Ordering::Relaxed);
209 return Err(e);
210 }
211
212 Ok(())
213 }
214
215 fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
217 let keys = AddressKey::extract_all(addr);
218 let mut bans = self.bans.write();
219
220 for key in keys {
221 if let Some(ban) = bans.get(&key) {
222 if ban.is_expired() {
223 bans.pop(&key);
224 } else {
225 return Some(ban.clone());
226 }
227 }
228 }
229
230 None
231 }
232
233 fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
235 let key = AddressKey::from_ip_individual(addr);
236 let mut penalties = self.penalties.write();
237
238 let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
239 entry.count = entry.count.saturating_add(amount);
240 entry.last_penalty = Some(Instant::now());
241 entry.reason = Some(reason);
242
243 if entry.count >= reason.failures_to_ban() {
245 drop(penalties);
246 if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
247 warn!("Failed to auto-ban address: {}", e);
248 }
249 }
250 }
251}
252
253impl Default for RateLimitManager {
254 fn default() -> Self {
255 Self::new(&RateLimitConfig::default())
256 }
257}
258
259impl RateLimitApi for RateLimitManager {
260 fn get_status(
261 &self,
262 addr: &IpAddr,
263 category: &str,
264 ) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
265 let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
266
267 let keys = AddressKey::extract_all(addr);
268 let bans = self.bans.read();
269
270 let statuses = keys
271 .into_iter()
272 .map(|key| {
273 let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
274 let ban_expires = bans.peek(&key).and_then(|b| {
275 if b.is_expired() {
276 None
277 } else {
278 Some(
279 b.expires_at
280 .unwrap_or_else(|| Instant::now() + Duration::from_hours(24 * 365)),
281 )
282 }
283 });
284
285 let status = RateLimitStatus {
286 is_limited: false, remaining: None,
288 reset_at: None,
289 quota: 0,
290 is_banned,
291 ban_expires_at: ban_expires,
292 };
293
294 (key, status)
295 })
296 .collect();
297
298 Ok(statuses)
299 }
300
301 fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
302 debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
303 self.record_penalty(addr, reason, amount);
304 Ok(())
305 }
306
307 fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
308 let key = AddressKey::from_ip_individual(addr);
309 let mut penalties = self.penalties.write();
310
311 if let Some(entry) = penalties.get_mut(&key) {
312 entry.count = entry.count.saturating_sub(amount);
313 if entry.count == 0 {
314 penalties.pop(&key);
315 }
316 }
317
318 Ok(())
319 }
320
321 fn reset(&self, addr: &IpAddr) -> ClResult<()> {
322 let keys = AddressKey::extract_all(addr);
323
324 let mut penalties = self.penalties.write();
326 for key in &keys {
327 penalties.pop(key);
328 }
329 drop(penalties);
330
331 let mut bans = self.bans.write();
333 for key in &keys {
334 bans.pop(key);
335 }
336
337 self.pow_store.decrement(addr, u32::MAX);
339
340 Ok(())
341 }
342
343 fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
344 let keys = AddressKey::extract_all(addr);
345 let now = Instant::now();
346 let expires_at = Some(now + duration);
347
348 let mut bans = self.bans.write();
349 for key in keys {
350 let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
351 bans.put(key, entry);
352 }
353
354 self.total_bans.fetch_add(1, Ordering::Relaxed);
355 debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
356
357 Ok(())
358 }
359
360 fn unban(&self, addr: &IpAddr) -> ClResult<()> {
361 let keys = AddressKey::extract_all(addr);
362 let mut bans = self.bans.write();
363
364 for key in keys {
365 bans.pop(&key);
366 }
367
368 Ok(())
369 }
370
371 fn is_banned(&self, addr: &IpAddr) -> bool {
372 self.check_ban(addr).is_some()
373 }
374
375 fn list_bans(&self) -> Vec<BanEntry> {
376 self.bans
377 .read()
378 .iter()
379 .filter(|(_, b)| !b.is_expired())
380 .map(|(_, b)| b.clone())
381 .collect()
382 }
383
384 fn stats(&self) -> RateLimiterStats {
385 let tracked = self
387 .categories
388 .values()
389 .map(|c| {
390 c.ipv4_individual.short_term.len()
391 + c.ipv4_network.short_term.len()
392 + c.ipv6_subnet.short_term.len()
393 + c.ipv6_provider.short_term.len()
394 })
395 .sum();
396
397 RateLimiterStats {
398 tracked_addresses: tracked,
399 active_bans: self.bans.read().len(),
400 total_requests_limited: self.total_limited.load(Ordering::Relaxed),
401 total_bans_issued: self.total_bans.load(Ordering::Relaxed),
402 pow_individual_entries: self.pow_store.individual_count(),
403 pow_network_entries: self.pow_store.network_count(),
404 }
405 }
406
407 fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
408 self.pow_store.get_requirement(addr)
409 }
410
411 fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
412 self.pow_store.increment(addr, reason);
413 Ok(())
414 }
415
416 fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
417 self.pow_store.decrement(addr, amount);
418 Ok(())
419 }
420
421 fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
422 self.pow_store.verify(addr, token)
423 }
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Every category registered by the constructor must be present.
    #[test]
    fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::default();
        for category in ["auth", "dav", "federation", "general", "websocket"] {
            assert!(manager.categories.contains_key(category), "missing category {category}");
        }
    }

    /// A handful of requests from one address passes under the default limits.
    #[test]
    fn test_rate_limit_check() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        for _ in 0..5 {
            assert!(manager.check(&ip, "general").is_ok());
        }
    }

    /// Checking an unregistered category yields `UnknownCategory`.
    #[test]
    fn test_unknown_category() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        let result = manager.check(&ip, "nonexistent");
        assert!(matches!(result, Err(RateLimitError::UnknownCategory(_))));
    }

    /// Ban, observe the rejection, then unban.
    #[test]
    fn test_ban_functionality() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        assert!(!manager.is_banned(&ip));

        manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();
        assert!(manager.is_banned(&ip));

        let result = manager.check(&ip, "general");
        assert!(matches!(result, Err(RateLimitError::Banned { .. })));

        manager.unban(&ip).unwrap();
        assert!(!manager.is_banned(&ip));
    }

    /// Four auth failures leave the address unbanned; the fifth bans it.
    #[test]
    fn test_penalty_auto_ban() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        for _ in 0..4 {
            manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
            assert!(!manager.is_banned(&ip));
        }

        manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
        assert!(manager.is_banned(&ip));
    }

    /// With no requirement any token verifies; after one increment the plain
    /// token is rejected and the `A`-suffixed one is accepted.
    // NOTE(review): the token rule itself lives in `PowCounterStore::verify`.
    #[test]
    fn test_pow_integration() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        assert_eq!(manager.get_pow_requirement(&ip), 0);
        assert!(manager.verify_pow(&ip, "any_token").is_ok());

        manager
            .increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure)
            .unwrap();
        assert_eq!(manager.get_pow_requirement(&ip), 1);

        assert!(manager.verify_pow(&ip, "any_token").is_err());
        assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
    }

    /// Ban counters start at zero and reflect a ban once one is issued.
    #[test]
    fn test_stats() {
        let manager = RateLimitManager::default();
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));

        let stats = manager.stats();
        assert_eq!(stats.active_bans, 0);
        assert_eq!(stats.total_bans_issued, 0);

        manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();

        let stats = manager.stats();
        assert!(stats.active_bans > 0);
        assert_eq!(stats.total_bans_issued, 1);
    }
}
530
531