1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::num::NonZeroU32;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use governor::clock::{Clock, DefaultClock};
14use governor::state::keyed::DashMapStateStore;
15use governor::{Quota, RateLimiter};
16use lru::LruCache;
17use parking_lot::RwLock;
18use std::num::NonZeroUsize;
19use tracing::{debug, warn};
20
21use super::api::{
22 BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
23};
24use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
25use super::error::{PowError, RateLimitError};
26use super::extractors::AddressKey;
27use super::pow::PowCounterStore;
28use crate::prelude::*;
29
/// Keyed `governor` rate limiter whose keys are [`AddressKey`]s, backed by a
/// `DashMapStateStore` and timed by governor's `DefaultClock`.
type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
32
33struct TierLimiters {
35 short_term: Arc<KeyedLimiter>,
36 long_term: Arc<KeyedLimiter>,
37}
38
39impl TierLimiters {
40 const ONE: NonZeroU32 = match NonZeroU32::new(1) {
42 Some(v) => v,
43 None => unreachable!(),
44 };
45
46 fn new(config: &RateLimitTierConfig) -> Self {
47 let short_quota =
49 Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
50 let short_term = Arc::new(RateLimiter::keyed(short_quota));
51
52 let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
56 let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
57 .unwrap_or_else(|| Quota::per_second(Self::ONE))
58 .allow_burst(config.long_term_burst);
59 let long_term = Arc::new(RateLimiter::keyed(long_quota));
60
61 Self { short_term, long_term }
62 }
63
64 fn check(&self, key: &AddressKey) -> Result<(), Duration> {
66 if let Err(not_until) = self.short_term.check_key(key) {
68 return Err(not_until.wait_time_from(DefaultClock::default().now()));
69 }
70
71 if let Err(not_until) = self.long_term.check_key(key) {
73 return Err(not_until.wait_time_from(DefaultClock::default().now()));
74 }
75
76 Ok(())
77 }
78}
79
80struct CategoryLimiters {
82 ipv4_individual: TierLimiters,
83 ipv4_network: TierLimiters,
84 ipv6_subnet: TierLimiters,
85 ipv6_provider: TierLimiters,
86}
87
88impl CategoryLimiters {
89 fn new(config: &EndpointCategoryConfig) -> Self {
90 Self {
91 ipv4_individual: TierLimiters::new(&config.ipv4_individual),
92 ipv4_network: TierLimiters::new(&config.ipv4_network),
93 ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
94 ipv6_provider: TierLimiters::new(&config.ipv6_provider),
95 }
96 }
97
98 fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
100 let keys = AddressKey::extract_all(addr);
101
102 for key in keys {
103 let limiter = self.get_limiter_for_key(&key);
104 if let Err(wait_time) = limiter.check(&key) {
105 return Err(RateLimitError::RateLimited {
106 level: key.level_name(),
107 retry_after: wait_time,
108 });
109 }
110 }
111
112 Ok(())
113 }
114
115 fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
116 match key {
117 AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
118 AddressKey::Ipv4Network(_) => &self.ipv4_network,
119 AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
120 AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
121 }
122 }
123}
124
/// Running penalty tally for one address key.
#[derive(Debug, Clone, Default)]
struct PenaltyEntry {
    /// Accumulated penalty amount; callers adjust it with saturating arithmetic.
    count: u32,
    /// When the most recent penalty was recorded.
    // NOTE(review): nothing visible in this file reads this field, so penalties never decay
    // with time — confirm whether decay was intended.
    last_penalty: Option<Instant>,
    /// Reason given for the most recent penalty.
    reason: Option<PenaltyReason>,
}
132
133pub struct RateLimitManager {
135 categories: HashMap<String, CategoryLimiters>,
137 bans: RwLock<LruCache<AddressKey, BanEntry>>,
139 penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
141 pow_store: PowCounterStore,
143 total_limited: AtomicU64,
145 total_bans: AtomicU64,
146}
147
148impl RateLimitManager {
149 const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
151 Some(v) => v,
152 None => unreachable!(),
153 };
154 const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
155 Some(v) => v,
156 None => unreachable!(),
157 };
158
159 pub fn new(config: &RateLimitConfig) -> Self {
161 let mut categories = HashMap::new();
162
163 categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
165 categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
166 categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
167 categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
168
169 let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
170 let penalty_cap =
171 NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
172
173 Self {
174 categories,
175 bans: RwLock::new(LruCache::new(ban_cap)),
176 penalties: RwLock::new(LruCache::new(penalty_cap)),
177 pow_store: PowCounterStore::new(PowConfig::default()),
178 total_limited: AtomicU64::new(0),
179 total_bans: AtomicU64::new(0),
180 }
181 }
182
183 pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
185 let mut manager = Self::new(config);
186 manager.pow_store = PowCounterStore::new(pow_config);
187 manager
188 }
189
190 pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
192 if let Some(ban) = self.check_ban(addr) {
194 return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
195 }
196
197 let cat_limiters = self
199 .categories
200 .get(category)
201 .ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
202
203 if let Err(e) = cat_limiters.check(addr) {
204 self.total_limited.fetch_add(1, Ordering::Relaxed);
205 return Err(e);
206 }
207
208 Ok(())
209 }
210
211 fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
213 let keys = AddressKey::extract_all(addr);
214 let mut bans = self.bans.write();
215
216 for key in keys {
217 if let Some(ban) = bans.get(&key) {
218 if ban.is_expired() {
219 bans.pop(&key);
220 } else {
221 return Some(ban.clone());
222 }
223 }
224 }
225
226 None
227 }
228
229 fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
231 let key = AddressKey::from_ip_individual(addr);
232 let mut penalties = self.penalties.write();
233
234 let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
235 entry.count = entry.count.saturating_add(amount);
236 entry.last_penalty = Some(Instant::now());
237 entry.reason = Some(reason);
238
239 if entry.count >= reason.failures_to_ban() {
241 drop(penalties);
242 if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
243 warn!("Failed to auto-ban address: {}", e);
244 }
245 }
246 }
247}
248
249impl Default for RateLimitManager {
250 fn default() -> Self {
251 Self::new(&RateLimitConfig::default())
252 }
253}
254
255impl RateLimitApi for RateLimitManager {
256 fn get_status(
257 &self,
258 addr: &IpAddr,
259 category: &str,
260 ) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
261 let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
262
263 let keys = AddressKey::extract_all(addr);
264 let bans = self.bans.read();
265
266 let statuses =
267 keys.into_iter()
268 .map(|key| {
269 let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
270 let ban_expires = bans.peek(&key).and_then(|b| {
271 if b.is_expired() {
272 None
273 } else {
274 Some(b.expires_at.unwrap_or_else(|| {
275 Instant::now() + Duration::from_secs(86400 * 365)
276 }))
277 }
278 });
279
280 let status = RateLimitStatus {
281 is_limited: false, remaining: None,
283 reset_at: None,
284 quota: 0,
285 is_banned,
286 ban_expires_at: ban_expires,
287 };
288
289 (key, status)
290 })
291 .collect();
292
293 Ok(statuses)
294 }
295
296 fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
297 debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
298 self.record_penalty(addr, reason, amount);
299 Ok(())
300 }
301
302 fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
303 let key = AddressKey::from_ip_individual(addr);
304 let mut penalties = self.penalties.write();
305
306 if let Some(entry) = penalties.get_mut(&key) {
307 entry.count = entry.count.saturating_sub(amount);
308 if entry.count == 0 {
309 penalties.pop(&key);
310 }
311 }
312
313 Ok(())
314 }
315
316 fn reset(&self, addr: &IpAddr) -> ClResult<()> {
317 let keys = AddressKey::extract_all(addr);
318
319 let mut penalties = self.penalties.write();
321 for key in &keys {
322 penalties.pop(key);
323 }
324 drop(penalties);
325
326 let mut bans = self.bans.write();
328 for key in &keys {
329 bans.pop(key);
330 }
331
332 self.pow_store.decrement(addr, u32::MAX);
334
335 Ok(())
336 }
337
338 fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
339 let keys = AddressKey::extract_all(addr);
340 let now = Instant::now();
341 let expires_at = Some(now + duration);
342
343 let mut bans = self.bans.write();
344 for key in keys {
345 let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
346 bans.put(key, entry);
347 }
348
349 self.total_bans.fetch_add(1, Ordering::Relaxed);
350 debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
351
352 Ok(())
353 }
354
355 fn unban(&self, addr: &IpAddr) -> ClResult<()> {
356 let keys = AddressKey::extract_all(addr);
357 let mut bans = self.bans.write();
358
359 for key in keys {
360 bans.pop(&key);
361 }
362
363 Ok(())
364 }
365
366 fn is_banned(&self, addr: &IpAddr) -> bool {
367 self.check_ban(addr).is_some()
368 }
369
370 fn list_bans(&self) -> Vec<BanEntry> {
371 self.bans
372 .read()
373 .iter()
374 .filter(|(_, b)| !b.is_expired())
375 .map(|(_, b)| b.clone())
376 .collect()
377 }
378
379 fn stats(&self) -> RateLimiterStats {
380 let tracked = self
382 .categories
383 .values()
384 .map(|c| {
385 c.ipv4_individual.short_term.len()
386 + c.ipv4_network.short_term.len()
387 + c.ipv6_subnet.short_term.len()
388 + c.ipv6_provider.short_term.len()
389 })
390 .sum();
391
392 RateLimiterStats {
393 tracked_addresses: tracked,
394 active_bans: self.bans.read().len(),
395 total_requests_limited: self.total_limited.load(Ordering::Relaxed),
396 total_bans_issued: self.total_bans.load(Ordering::Relaxed),
397 pow_individual_entries: self.pow_store.individual_count(),
398 pow_network_entries: self.pow_store.network_count(),
399 }
400 }
401
402 fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
403 self.pow_store.get_requirement(addr)
404 }
405
406 fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
407 self.pow_store.increment(addr, reason);
408 Ok(())
409 }
410
411 fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
412 self.pow_store.decrement(addr, amount);
413 Ok(())
414 }
415
416 fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
417 self.pow_store.verify(addr, token)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Address shared by every test case.
    fn sample_ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
    }

    #[test]
    fn test_rate_limit_manager_creation() {
        let manager = RateLimitManager::default();
        for name in ["auth", "federation", "general", "websocket"] {
            assert!(manager.categories.contains_key(name));
        }
    }

    #[test]
    fn test_rate_limit_check() {
        let manager = RateLimitManager::default();
        let ip = sample_ip();

        // Five back-to-back requests must all be admitted under the default configuration.
        (0..5).for_each(|_| assert!(manager.check(&ip, "general").is_ok()));
    }

    #[test]
    fn test_unknown_category() {
        let manager = RateLimitManager::default();
        let outcome = manager.check(&sample_ip(), "nonexistent");
        assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
    }

    #[test]
    fn test_ban_functionality() {
        let manager = RateLimitManager::default();
        let ip = sample_ip();
        assert!(!manager.is_banned(&ip));

        manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
        assert!(manager.is_banned(&ip));
        assert!(matches!(manager.check(&ip, "general"), Err(RateLimitError::Banned { .. })));

        manager.unban(&ip).unwrap();
        assert!(!manager.is_banned(&ip));
    }

    #[test]
    fn test_penalty_auto_ban() {
        let manager = RateLimitManager::default();
        let ip = sample_ip();

        // The first four failures must not trigger a ban; the fifth must.
        for _ in 0..4 {
            manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
            assert!(!manager.is_banned(&ip));
        }
        manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
        assert!(manager.is_banned(&ip));
    }

    #[test]
    fn test_pow_integration() {
        let manager = RateLimitManager::default();
        let ip = sample_ip();

        // With no PoW requirement, any token is accepted.
        assert_eq!(manager.get_pow_requirement(&ip), 0);
        assert!(manager.verify_pow(&ip, "any_token").is_ok());

        manager.increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure).unwrap();
        assert_eq!(manager.get_pow_requirement(&ip), 1);

        // With a requirement of 1, only a token meeting it verifies.
        assert!(manager.verify_pow(&ip, "any_token").is_err());
        assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
    }

    #[test]
    fn test_stats() {
        let manager = RateLimitManager::default();
        let ip = sample_ip();

        let before = manager.stats();
        assert_eq!(before.active_bans, 0);
        assert_eq!(before.total_bans_issued, 0);

        manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

        let after = manager.stats();
        assert!(after.active_bans > 0);
        assert_eq!(after.total_bans_issued, 1);
    }
}
524
525