Skip to main content

cloudillo_core/rate_limit/
limiter.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Rate Limit Manager
5//!
6//! Core rate limiting implementation using the governor crate's GCRA algorithm.
7//! Supports hierarchical address levels with dual-tier (short + long term) limits.
8
9use std::collections::HashMap;
10use std::net::IpAddr;
11use std::num::NonZeroU32;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use governor::clock::{Clock, DefaultClock};
17use governor::state::keyed::DashMapStateStore;
18use governor::{Quota, RateLimiter};
19use lru::LruCache;
20use parking_lot::RwLock;
21use std::num::NonZeroUsize;
22use tracing::{debug, warn};
23
24use super::api::{
25	BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
26};
27use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
28use super::error::{PowError, RateLimitError};
29use super::extractors::AddressKey;
30use super::pow::PowCounterStore;
31use crate::prelude::*;
32
33/// Type alias for a keyed rate limiter
34type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
35
36/// Holds both short-term and long-term limiters for an address level
37struct TierLimiters {
38	short_term: Arc<KeyedLimiter>,
39	long_term: Arc<KeyedLimiter>,
40}
41
42impl TierLimiters {
43	// SAFETY: 1 is non-zero
44	const ONE: NonZeroU32 = match NonZeroU32::new(1) {
45		Some(v) => v,
46		None => unreachable!(),
47	};
48
49	fn new(config: &RateLimitTierConfig) -> Self {
50		// Short-term: per-second with burst
51		let short_quota =
52			Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
53		let short_term = Arc::new(RateLimiter::keyed(short_quota));
54
55		// Long-term: per-hour with burst
56		// Convert RPH to nanosecond period using integer math:
57		// period_nanos = 3_600_000_000_000 / rph
58		let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
59		let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
60			.unwrap_or_else(|| Quota::per_second(Self::ONE))
61			.allow_burst(config.long_term_burst);
62		let long_term = Arc::new(RateLimiter::keyed(long_quota));
63
64		Self { short_term, long_term }
65	}
66
67	/// Check if both short and long term limits allow the request
68	fn check(&self, key: &AddressKey) -> Result<(), Duration> {
69		// Check short-term first
70		if let Err(not_until) = self.short_term.check_key(key) {
71			return Err(not_until.wait_time_from(DefaultClock::default().now()));
72		}
73
74		// Check long-term
75		if let Err(not_until) = self.long_term.check_key(key) {
76			return Err(not_until.wait_time_from(DefaultClock::default().now()));
77		}
78
79		Ok(())
80	}
81}
82
83/// Per-category rate limiters (one for each hierarchical level)
84struct CategoryLimiters {
85	ipv4_individual: TierLimiters,
86	ipv4_network: TierLimiters,
87	ipv6_subnet: TierLimiters,
88	ipv6_provider: TierLimiters,
89}
90
91impl CategoryLimiters {
92	fn new(config: &EndpointCategoryConfig) -> Self {
93		Self {
94			ipv4_individual: TierLimiters::new(&config.ipv4_individual),
95			ipv4_network: TierLimiters::new(&config.ipv4_network),
96			ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
97			ipv6_provider: TierLimiters::new(&config.ipv6_provider),
98		}
99	}
100
101	/// Check all applicable limits for an address
102	fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
103		let keys = AddressKey::extract_all(addr);
104
105		for key in keys {
106			let limiter = self.get_limiter_for_key(&key);
107			if let Err(wait_time) = limiter.check(&key) {
108				return Err(RateLimitError::RateLimited {
109					level: key.level_name(),
110					retry_after: wait_time,
111				});
112			}
113		}
114
115		Ok(())
116	}
117
118	fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
119		match key {
120			AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
121			AddressKey::Ipv4Network(_) => &self.ipv4_network,
122			AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
123			AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
124		}
125	}
126}
127
128/// Penalty tracking for an address
129#[derive(Debug, Clone, Default)]
130struct PenaltyEntry {
131	count: u32,
132	last_penalty: Option<Instant>,
133	reason: Option<PenaltyReason>,
134}
135
136/// Main rate limit manager
137pub struct RateLimitManager {
138	/// Per-category limiters
139	categories: HashMap<String, CategoryLimiters>,
140	/// Global ban list
141	bans: RwLock<LruCache<AddressKey, BanEntry>>,
142	/// Penalty tracking per address
143	penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
144	/// Proof-of-work counter store
145	pow_store: PowCounterStore,
146	/// Statistics
147	total_limited: AtomicU64,
148	total_bans: AtomicU64,
149}
150
151impl RateLimitManager {
152	// SAFETY: These are non-zero constants
153	const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
154		Some(v) => v,
155		None => unreachable!(),
156	};
157	const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
158		Some(v) => v,
159		None => unreachable!(),
160	};
161
162	/// Create a new rate limit manager
163	pub fn new(config: &RateLimitConfig) -> Self {
164		let mut categories = HashMap::new();
165
166		// Initialize category limiters
167		categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
168		categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
169		categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
170		categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
171
172		let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
173		let penalty_cap =
174			NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
175
176		Self {
177			categories,
178			bans: RwLock::new(LruCache::new(ban_cap)),
179			penalties: RwLock::new(LruCache::new(penalty_cap)),
180			pow_store: PowCounterStore::new(PowConfig::default()),
181			total_limited: AtomicU64::new(0),
182			total_bans: AtomicU64::new(0),
183		}
184	}
185
186	/// Create with custom PoW config
187	pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
188		let mut manager = Self::new(config);
189		manager.pow_store = PowCounterStore::new(pow_config);
190		manager
191	}
192
193	/// Check if a request should be rate limited
194	pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
195		// Check ban list first
196		if let Some(ban) = self.check_ban(addr) {
197			return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
198		}
199
200		// Check rate limits
201		let cat_limiters = self
202			.categories
203			.get(category)
204			.ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
205
206		if let Err(e) = cat_limiters.check(addr) {
207			self.total_limited.fetch_add(1, Ordering::Relaxed);
208			return Err(e);
209		}
210
211		Ok(())
212	}
213
214	/// Check if address is banned
215	fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
216		let keys = AddressKey::extract_all(addr);
217		let mut bans = self.bans.write();
218
219		for key in keys {
220			if let Some(ban) = bans.get(&key) {
221				if ban.is_expired() {
222					bans.pop(&key);
223				} else {
224					return Some(ban.clone());
225				}
226			}
227		}
228
229		None
230	}
231
232	/// Record a penalty for an address
233	fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
234		let key = AddressKey::from_ip_individual(addr);
235		let mut penalties = self.penalties.write();
236
237		let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
238		entry.count = entry.count.saturating_add(amount);
239		entry.last_penalty = Some(Instant::now());
240		entry.reason = Some(reason);
241
242		// Check for auto-ban
243		if entry.count >= reason.failures_to_ban() {
244			drop(penalties);
245			if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
246				warn!("Failed to auto-ban address: {}", e);
247			}
248		}
249	}
250}
251
252impl Default for RateLimitManager {
253	fn default() -> Self {
254		Self::new(&RateLimitConfig::default())
255	}
256}
257
258impl RateLimitApi for RateLimitManager {
259	fn get_status(
260		&self,
261		addr: &IpAddr,
262		category: &str,
263	) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
264		let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
265
266		let keys = AddressKey::extract_all(addr);
267		let bans = self.bans.read();
268
269		let statuses =
270			keys.into_iter()
271				.map(|key| {
272					let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
273					let ban_expires = bans.peek(&key).and_then(|b| {
274						if b.is_expired() {
275							None
276						} else {
277							Some(b.expires_at.unwrap_or_else(|| {
278								Instant::now() + Duration::from_secs(86400 * 365)
279							}))
280						}
281					});
282
283					let status = RateLimitStatus {
284						is_limited: false, // Would need to check governor state
285						remaining: None,
286						reset_at: None,
287						quota: 0,
288						is_banned,
289						ban_expires_at: ban_expires,
290					};
291
292					(key, status)
293				})
294				.collect();
295
296		Ok(statuses)
297	}
298
299	fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
300		debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
301		self.record_penalty(addr, reason, amount);
302		Ok(())
303	}
304
305	fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
306		let key = AddressKey::from_ip_individual(addr);
307		let mut penalties = self.penalties.write();
308
309		if let Some(entry) = penalties.get_mut(&key) {
310			entry.count = entry.count.saturating_sub(amount);
311			if entry.count == 0 {
312				penalties.pop(&key);
313			}
314		}
315
316		Ok(())
317	}
318
319	fn reset(&self, addr: &IpAddr) -> ClResult<()> {
320		let keys = AddressKey::extract_all(addr);
321
322		// Clear penalties
323		let mut penalties = self.penalties.write();
324		for key in &keys {
325			penalties.pop(key);
326		}
327		drop(penalties);
328
329		// Clear bans
330		let mut bans = self.bans.write();
331		for key in &keys {
332			bans.pop(key);
333		}
334
335		// Clear PoW counters
336		self.pow_store.decrement(addr, u32::MAX);
337
338		Ok(())
339	}
340
341	fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
342		let keys = AddressKey::extract_all(addr);
343		let now = Instant::now();
344		let expires_at = Some(now + duration);
345
346		let mut bans = self.bans.write();
347		for key in keys {
348			let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
349			bans.put(key, entry);
350		}
351
352		self.total_bans.fetch_add(1, Ordering::Relaxed);
353		debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
354
355		Ok(())
356	}
357
358	fn unban(&self, addr: &IpAddr) -> ClResult<()> {
359		let keys = AddressKey::extract_all(addr);
360		let mut bans = self.bans.write();
361
362		for key in keys {
363			bans.pop(&key);
364		}
365
366		Ok(())
367	}
368
369	fn is_banned(&self, addr: &IpAddr) -> bool {
370		self.check_ban(addr).is_some()
371	}
372
373	fn list_bans(&self) -> Vec<BanEntry> {
374		self.bans
375			.read()
376			.iter()
377			.filter(|(_, b)| !b.is_expired())
378			.map(|(_, b)| b.clone())
379			.collect()
380	}
381
382	fn stats(&self) -> RateLimiterStats {
383		// Count tracked addresses across all categories
384		let tracked = self
385			.categories
386			.values()
387			.map(|c| {
388				c.ipv4_individual.short_term.len()
389					+ c.ipv4_network.short_term.len()
390					+ c.ipv6_subnet.short_term.len()
391					+ c.ipv6_provider.short_term.len()
392			})
393			.sum();
394
395		RateLimiterStats {
396			tracked_addresses: tracked,
397			active_bans: self.bans.read().len(),
398			total_requests_limited: self.total_limited.load(Ordering::Relaxed),
399			total_bans_issued: self.total_bans.load(Ordering::Relaxed),
400			pow_individual_entries: self.pow_store.individual_count(),
401			pow_network_entries: self.pow_store.network_count(),
402		}
403	}
404
405	fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
406		self.pow_store.get_requirement(addr)
407	}
408
409	fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
410		self.pow_store.increment(addr, reason);
411		Ok(())
412	}
413
414	fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
415		self.pow_store.decrement(addr, amount);
416		Ok(())
417	}
418
419	fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
420		self.pow_store.verify(addr, token)
421	}
422}
423
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv4Addr;

	/// Address shared by every test in this module.
	fn sample_addr() -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
	}

	#[test]
	fn test_rate_limit_manager_creation() {
		let limiter = RateLimitManager::default();
		for category in ["auth", "federation", "general", "websocket"] {
			assert!(limiter.categories.contains_key(category));
		}
	}

	#[test]
	fn test_rate_limit_check() {
		let limiter = RateLimitManager::default();
		let addr = sample_addr();

		// A handful of requests from a fresh address must be admitted.
		(0..5).for_each(|_| assert!(limiter.check(&addr, "general").is_ok()));
	}

	#[test]
	fn test_unknown_category() {
		let limiter = RateLimitManager::default();

		let outcome = limiter.check(&sample_addr(), "nonexistent");
		assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
	}

	#[test]
	fn test_ban_functionality() {
		let limiter = RateLimitManager::default();
		let addr = sample_addr();
		assert!(!limiter.is_banned(&addr));

		limiter.ban(&addr, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
		assert!(limiter.is_banned(&addr));
		assert!(matches!(limiter.check(&addr, "general"), Err(RateLimitError::Banned { .. })));

		limiter.unban(&addr).unwrap();
		assert!(!limiter.is_banned(&addr));
	}

	#[test]
	fn test_penalty_auto_ban() {
		let limiter = RateLimitManager::default();
		let addr = sample_addr();

		// AuthFailure bans on the 5th failure, so the first four must not.
		for _ in 1..=4 {
			limiter.penalize(&addr, PenaltyReason::AuthFailure, 1).unwrap();
			assert!(!limiter.is_banned(&addr));
		}

		limiter.penalize(&addr, PenaltyReason::AuthFailure, 1).unwrap();
		assert!(limiter.is_banned(&addr));
	}

	#[test]
	fn test_pow_integration() {
		let limiter = RateLimitManager::default();
		let addr = sample_addr();

		// No PoW is demanded before any PoW penalty is recorded.
		assert_eq!(limiter.get_pow_requirement(&addr), 0);
		assert!(limiter.verify_pow(&addr, "any_token").is_ok());

		limiter.increment_pow_counter(&addr, PowPenaltyReason::ConnSignatureFailure).unwrap();
		assert_eq!(limiter.get_pow_requirement(&addr), 1);

		// With a non-zero requirement, only a token satisfying it is accepted.
		assert!(limiter.verify_pow(&addr, "any_token").is_err());
		assert!(limiter.verify_pow(&addr, "any_tokenA").is_ok());
	}

	#[test]
	fn test_stats() {
		let limiter = RateLimitManager::default();
		let addr = sample_addr();

		let before = limiter.stats();
		assert_eq!(before.active_bans, 0);
		assert_eq!(before.total_bans_issued, 0);

		limiter.ban(&addr, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

		let after = limiter.stats();
		assert!(after.active_bans > 0);
		assert_eq!(after.total_bans_issued, 1);
	}
}
527
528// vim: ts=4