Skip to main content

cloudillo_core/rate_limit/
limiter.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Rate Limit Manager
5//!
6//! Core rate limiting implementation using the governor crate's GCRA algorithm.
7//! Supports hierarchical address levels with dual-tier (short + long term) limits.
8
9use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::time::{Duration, Instant};
15
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use lru::LruCache;
20use parking_lot::RwLock;
21use std::num::NonZeroUsize;
22use tracing::{debug, warn};
23
24use super::api::{
25	BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
26};
27use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
28use super::error::{PowError, RateLimitError};
29use super::extractors::AddressKey;
30use super::pow::PowCounterStore;
31use crate::prelude::*;
32
33/// Type alias for a keyed rate limiter
34type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
35
36/// Holds both short-term and long-term limiters for an address level
37struct TierLimiters {
38	short_term: Arc<KeyedLimiter>,
39	long_term: Arc<KeyedLimiter>,
40}
41
42impl TierLimiters {
43	// SAFETY: 1 is non-zero
44	const ONE: NonZeroU32 = match NonZeroU32::new(1) {
45		Some(v) => v,
46		None => unreachable!(),
47	};
48
49	fn new(config: &RateLimitTierConfig) -> Self {
50		// Short-term: per-second with burst
51		let short_quota =
52			Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
53		let short_term = Arc::new(RateLimiter::keyed(short_quota));
54
55		// Long-term: per-hour with burst
56		// Convert RPH to nanosecond period using integer math:
57		// period_nanos = 3_600_000_000_000 / rph
58		let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
59		let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
60			.unwrap_or_else(|| Quota::per_second(Self::ONE))
61			.allow_burst(config.long_term_burst);
62		let long_term = Arc::new(RateLimiter::keyed(long_quota));
63
64		Self { short_term, long_term }
65	}
66
67	/// Check if both short and long term limits allow the request
68	fn check(&self, key: &AddressKey) -> Result<(), Duration> {
69		// Check short-term first
70		if let Err(not_until) = self.short_term.check_key(key) {
71			return Err(not_until.wait_time_from(DefaultClock::default().now()));
72		}
73
74		// Check long-term
75		if let Err(not_until) = self.long_term.check_key(key) {
76			return Err(not_until.wait_time_from(DefaultClock::default().now()));
77		}
78
79		Ok(())
80	}
81}
82
83/// Per-category rate limiters (one for each hierarchical level)
84struct CategoryLimiters {
85	ipv4_individual: TierLimiters,
86	ipv4_network: TierLimiters,
87	ipv6_subnet: TierLimiters,
88	ipv6_provider: TierLimiters,
89}
90
91impl CategoryLimiters {
92	fn new(config: &EndpointCategoryConfig) -> Self {
93		Self {
94			ipv4_individual: TierLimiters::new(&config.ipv4_individual),
95			ipv4_network: TierLimiters::new(&config.ipv4_network),
96			ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
97			ipv6_provider: TierLimiters::new(&config.ipv6_provider),
98		}
99	}
100
101	/// Check all applicable limits for an address
102	fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
103		let keys = AddressKey::extract_all(addr);
104
105		for key in keys {
106			let limiter = self.get_limiter_for_key(&key);
107			if let Err(wait_time) = limiter.check(&key) {
108				return Err(RateLimitError::RateLimited {
109					level: key.level_name(),
110					retry_after: wait_time,
111				});
112			}
113		}
114
115		Ok(())
116	}
117
118	fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
119		match key {
120			AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
121			AddressKey::Ipv4Network(_) => &self.ipv4_network,
122			AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
123			AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
124		}
125	}
126}
127
128/// Penalty tracking for an address
129#[derive(Debug, Clone, Default)]
130struct PenaltyEntry {
131	count: u32,
132	last_penalty: Option<Instant>,
133	reason: Option<PenaltyReason>,
134}
135
136/// Main rate limit manager
137pub struct RateLimitManager {
138	/// Per-category limiters
139	categories: HashMap<String, CategoryLimiters>,
140	/// Global ban list
141	bans: RwLock<LruCache<AddressKey, BanEntry>>,
142	/// Penalty tracking per address
143	penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
144	/// Proof-of-work counter store
145	pow_store: PowCounterStore,
146	/// Statistics
147	total_limited: AtomicU64,
148	total_bans: AtomicU64,
149}
150
151impl RateLimitManager {
152	// SAFETY: These are non-zero constants
153	const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
154		Some(v) => v,
155		None => unreachable!(),
156	};
157	const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
158		Some(v) => v,
159		None => unreachable!(),
160	};
161
162	/// Create a new rate limit manager
163	pub fn new(config: &RateLimitConfig) -> Self {
164		let mut categories = HashMap::new();
165
166		// Initialize category limiters
167		categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
168		categories.insert("dav".to_string(), CategoryLimiters::new(&config.dav));
169		categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
170		categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
171		categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
172
173		let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
174		let penalty_cap =
175			NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
176
177		Self {
178			categories,
179			bans: RwLock::new(LruCache::new(ban_cap)),
180			penalties: RwLock::new(LruCache::new(penalty_cap)),
181			pow_store: PowCounterStore::new(PowConfig::default()),
182			total_limited: AtomicU64::new(0),
183			total_bans: AtomicU64::new(0),
184		}
185	}
186
187	/// Create with custom PoW config
188	pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
189		let mut manager = Self::new(config);
190		manager.pow_store = PowCounterStore::new(pow_config);
191		manager
192	}
193
194	/// Check if a request should be rate limited
195	pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
196		// Check ban list first
197		if let Some(ban) = self.check_ban(addr) {
198			return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
199		}
200
201		// Check rate limits
202		let cat_limiters = self
203			.categories
204			.get(category)
205			.ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
206
207		if let Err(e) = cat_limiters.check(addr) {
208			self.total_limited.fetch_add(1, Ordering::Relaxed);
209			return Err(e);
210		}
211
212		Ok(())
213	}
214
215	/// Check if address is banned
216	fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
217		let keys = AddressKey::extract_all(addr);
218		let mut bans = self.bans.write();
219
220		for key in keys {
221			if let Some(ban) = bans.get(&key) {
222				if ban.is_expired() {
223					bans.pop(&key);
224				} else {
225					return Some(ban.clone());
226				}
227			}
228		}
229
230		None
231	}
232
233	/// Record a penalty for an address
234	fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
235		let key = AddressKey::from_ip_individual(addr);
236		let mut penalties = self.penalties.write();
237
238		let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
239		entry.count = entry.count.saturating_add(amount);
240		entry.last_penalty = Some(Instant::now());
241		entry.reason = Some(reason);
242
243		// Check for auto-ban
244		if entry.count >= reason.failures_to_ban() {
245			drop(penalties);
246			if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
247				warn!("Failed to auto-ban address: {}", e);
248			}
249		}
250	}
251}
252
253impl Default for RateLimitManager {
254	fn default() -> Self {
255		Self::new(&RateLimitConfig::default())
256	}
257}
258
259impl RateLimitApi for RateLimitManager {
260	fn get_status(
261		&self,
262		addr: &IpAddr,
263		category: &str,
264	) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
265		let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
266
267		let keys = AddressKey::extract_all(addr);
268		let bans = self.bans.read();
269
270		let statuses = keys
271			.into_iter()
272			.map(|key| {
273				let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
274				let ban_expires = bans.peek(&key).and_then(|b| {
275					if b.is_expired() {
276						None
277					} else {
278						Some(
279							b.expires_at
280								.unwrap_or_else(|| Instant::now() + Duration::from_hours(24 * 365)),
281						)
282					}
283				});
284
285				let status = RateLimitStatus {
286					is_limited: false, // Would need to check governor state
287					remaining: None,
288					reset_at: None,
289					quota: 0,
290					is_banned,
291					ban_expires_at: ban_expires,
292				};
293
294				(key, status)
295			})
296			.collect();
297
298		Ok(statuses)
299	}
300
301	fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
302		debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
303		self.record_penalty(addr, reason, amount);
304		Ok(())
305	}
306
307	fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
308		let key = AddressKey::from_ip_individual(addr);
309		let mut penalties = self.penalties.write();
310
311		if let Some(entry) = penalties.get_mut(&key) {
312			entry.count = entry.count.saturating_sub(amount);
313			if entry.count == 0 {
314				penalties.pop(&key);
315			}
316		}
317
318		Ok(())
319	}
320
321	fn reset(&self, addr: &IpAddr) -> ClResult<()> {
322		let keys = AddressKey::extract_all(addr);
323
324		// Clear penalties
325		let mut penalties = self.penalties.write();
326		for key in &keys {
327			penalties.pop(key);
328		}
329		drop(penalties);
330
331		// Clear bans
332		let mut bans = self.bans.write();
333		for key in &keys {
334			bans.pop(key);
335		}
336
337		// Clear PoW counters
338		self.pow_store.decrement(addr, u32::MAX);
339
340		Ok(())
341	}
342
343	fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
344		let keys = AddressKey::extract_all(addr);
345		let now = Instant::now();
346		let expires_at = Some(now + duration);
347
348		let mut bans = self.bans.write();
349		for key in keys {
350			let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
351			bans.put(key, entry);
352		}
353
354		self.total_bans.fetch_add(1, Ordering::Relaxed);
355		debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
356
357		Ok(())
358	}
359
360	fn unban(&self, addr: &IpAddr) -> ClResult<()> {
361		let keys = AddressKey::extract_all(addr);
362		let mut bans = self.bans.write();
363
364		for key in keys {
365			bans.pop(&key);
366		}
367
368		Ok(())
369	}
370
371	fn is_banned(&self, addr: &IpAddr) -> bool {
372		self.check_ban(addr).is_some()
373	}
374
375	fn list_bans(&self) -> Vec<BanEntry> {
376		self.bans
377			.read()
378			.iter()
379			.filter(|(_, b)| !b.is_expired())
380			.map(|(_, b)| b.clone())
381			.collect()
382	}
383
384	fn stats(&self) -> RateLimiterStats {
385		// Count tracked addresses across all categories
386		let tracked = self
387			.categories
388			.values()
389			.map(|c| {
390				c.ipv4_individual.short_term.len()
391					+ c.ipv4_network.short_term.len()
392					+ c.ipv6_subnet.short_term.len()
393					+ c.ipv6_provider.short_term.len()
394			})
395			.sum();
396
397		RateLimiterStats {
398			tracked_addresses: tracked,
399			active_bans: self.bans.read().len(),
400			total_requests_limited: self.total_limited.load(Ordering::Relaxed),
401			total_bans_issued: self.total_bans.load(Ordering::Relaxed),
402			pow_individual_entries: self.pow_store.individual_count(),
403			pow_network_entries: self.pow_store.network_count(),
404		}
405	}
406
407	fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
408		self.pow_store.get_requirement(addr)
409	}
410
411	fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
412		self.pow_store.increment(addr, reason);
413		Ok(())
414	}
415
416	fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
417		self.pow_store.decrement(addr, amount);
418		Ok(())
419	}
420
421	fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
422		self.pow_store.verify(addr, token)
423	}
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
	use super::*;
	use std::net::Ipv4Addr;

	/// Address shared by the tests below.
	fn test_ip() -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
	}

	#[test]
	fn test_rate_limit_manager_creation() {
		let manager = RateLimitManager::default();
		// Every category registered by the constructor must be present.
		for category in ["auth", "dav", "federation", "general", "websocket"] {
			assert!(manager.categories.contains_key(category), "missing category {category}");
		}
	}

	#[test]
	fn test_rate_limit_check() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		// First few requests should pass
		for _ in 0..5 {
			assert!(manager.check(&ip, "general").is_ok());
		}
	}

	#[test]
	fn test_unknown_category() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		let result = manager.check(&ip, "nonexistent");
		assert!(matches!(result, Err(RateLimitError::UnknownCategory(_))));
		assert!(manager.get_status(&ip, "nonexistent").is_err());
	}

	#[test]
	fn test_ban_functionality() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		assert!(!manager.is_banned(&ip));

		manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();
		assert!(manager.is_banned(&ip));

		let result = manager.check(&ip, "general");
		assert!(matches!(result, Err(RateLimitError::Banned { .. })));

		manager.unban(&ip).unwrap();
		assert!(!manager.is_banned(&ip));
	}

	#[test]
	fn test_ban_reported_in_status() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();

		let statuses = manager.get_status(&ip, "general").unwrap();
		assert!(!statuses.is_empty());
		assert!(statuses.iter().all(|(_, s)| s.is_banned && s.ban_expires_at.is_some()));
	}

	#[test]
	fn test_penalty_auto_ban() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		// AuthFailure requires 5 failures for auto-ban
		for _ in 0..4 {
			manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
			assert!(!manager.is_banned(&ip));
		}

		// 5th failure should trigger auto-ban
		manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
		assert!(manager.is_banned(&ip));
	}

	#[test]
	fn test_grant_clears_penalties() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		manager.penalize(&ip, PenaltyReason::AuthFailure, 2).unwrap();
		assert_eq!(manager.penalties.read().len(), 1);

		// Granting back the full amount drops the entry entirely.
		manager.grant(&ip, 2).unwrap();
		assert_eq!(manager.penalties.read().len(), 0);
	}

	#[test]
	fn test_reset_lifts_ban() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();
		assert!(manager.is_banned(&ip));

		manager.reset(&ip).unwrap();
		assert!(!manager.is_banned(&ip));
		assert!(manager.check(&ip, "general").is_ok());
	}

	#[test]
	fn test_pow_integration() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		// Initially no PoW required
		assert_eq!(manager.get_pow_requirement(&ip), 0);
		assert!(manager.verify_pow(&ip, "any_token").is_ok());

		// Increment counter
		manager
			.increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure)
			.unwrap();
		assert_eq!(manager.get_pow_requirement(&ip), 1);

		// Now need PoW
		assert!(manager.verify_pow(&ip, "any_token").is_err());
		assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
	}

	#[test]
	fn test_stats() {
		let manager = RateLimitManager::default();
		let ip = test_ip();

		let stats = manager.stats();
		assert_eq!(stats.active_bans, 0);
		assert_eq!(stats.total_bans_issued, 0);

		manager.ban(&ip, Duration::from_mins(1), PenaltyReason::AuthFailure).unwrap();

		let stats = manager.stats();
		assert!(stats.active_bans > 0);
		assert_eq!(stats.total_bans_issued, 1);
	}
}
530
531// vim: ts=4