Skip to main content

cloudillo_core/rate_limit/
limiter.rs

1//! Rate Limit Manager
2//!
3//! Core rate limiting implementation using the governor crate's GCRA algorithm.
4//! Supports hierarchical address levels with dual-tier (short + long term) limits.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::num::NonZeroU32;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use governor::clock::{Clock, DefaultClock};
14use governor::state::keyed::DashMapStateStore;
15use governor::{Quota, RateLimiter};
16use lru::LruCache;
17use parking_lot::RwLock;
18use std::num::NonZeroUsize;
19use tracing::{debug, warn};
20
21use super::api::{
22	BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
23};
24use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
25use super::error::{PowError, RateLimitError};
26use super::extractors::AddressKey;
27use super::pow::PowCounterStore;
28use crate::prelude::*;
29
30/// Type alias for a keyed rate limiter
31type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
32
33/// Holds both short-term and long-term limiters for an address level
34struct TierLimiters {
35	short_term: Arc<KeyedLimiter>,
36	long_term: Arc<KeyedLimiter>,
37}
38
39impl TierLimiters {
40	// SAFETY: 1 is non-zero
41	const ONE: NonZeroU32 = match NonZeroU32::new(1) {
42		Some(v) => v,
43		None => unreachable!(),
44	};
45
46	fn new(config: &RateLimitTierConfig) -> Self {
47		// Short-term: per-second with burst
48		let short_quota =
49			Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
50		let short_term = Arc::new(RateLimiter::keyed(short_quota));
51
52		// Long-term: per-hour with burst
53		// Convert RPH to nanosecond period using integer math:
54		// period_nanos = 3_600_000_000_000 / rph
55		let period_nanos = 3_600_000_000_000_u64 / u64::from(config.long_term_rph.get());
56		let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
57			.unwrap_or_else(|| Quota::per_second(Self::ONE))
58			.allow_burst(config.long_term_burst);
59		let long_term = Arc::new(RateLimiter::keyed(long_quota));
60
61		Self { short_term, long_term }
62	}
63
64	/// Check if both short and long term limits allow the request
65	fn check(&self, key: &AddressKey) -> Result<(), Duration> {
66		// Check short-term first
67		if let Err(not_until) = self.short_term.check_key(key) {
68			return Err(not_until.wait_time_from(DefaultClock::default().now()));
69		}
70
71		// Check long-term
72		if let Err(not_until) = self.long_term.check_key(key) {
73			return Err(not_until.wait_time_from(DefaultClock::default().now()));
74		}
75
76		Ok(())
77	}
78}
79
80/// Per-category rate limiters (one for each hierarchical level)
81struct CategoryLimiters {
82	ipv4_individual: TierLimiters,
83	ipv4_network: TierLimiters,
84	ipv6_subnet: TierLimiters,
85	ipv6_provider: TierLimiters,
86}
87
88impl CategoryLimiters {
89	fn new(config: &EndpointCategoryConfig) -> Self {
90		Self {
91			ipv4_individual: TierLimiters::new(&config.ipv4_individual),
92			ipv4_network: TierLimiters::new(&config.ipv4_network),
93			ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
94			ipv6_provider: TierLimiters::new(&config.ipv6_provider),
95		}
96	}
97
98	/// Check all applicable limits for an address
99	fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
100		let keys = AddressKey::extract_all(addr);
101
102		for key in keys {
103			let limiter = self.get_limiter_for_key(&key);
104			if let Err(wait_time) = limiter.check(&key) {
105				return Err(RateLimitError::RateLimited {
106					level: key.level_name(),
107					retry_after: wait_time,
108				});
109			}
110		}
111
112		Ok(())
113	}
114
115	fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
116		match key {
117			AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
118			AddressKey::Ipv4Network(_) => &self.ipv4_network,
119			AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
120			AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
121		}
122	}
123}
124
125/// Penalty tracking for an address
126#[derive(Debug, Clone, Default)]
127struct PenaltyEntry {
128	count: u32,
129	last_penalty: Option<Instant>,
130	reason: Option<PenaltyReason>,
131}
132
133/// Main rate limit manager
134pub struct RateLimitManager {
135	/// Per-category limiters
136	categories: HashMap<String, CategoryLimiters>,
137	/// Global ban list
138	bans: RwLock<LruCache<AddressKey, BanEntry>>,
139	/// Penalty tracking per address
140	penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
141	/// Proof-of-work counter store
142	pow_store: PowCounterStore,
143	/// Statistics
144	total_limited: AtomicU64,
145	total_bans: AtomicU64,
146}
147
148impl RateLimitManager {
149	// SAFETY: These are non-zero constants
150	const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
151		Some(v) => v,
152		None => unreachable!(),
153	};
154	const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
155		Some(v) => v,
156		None => unreachable!(),
157	};
158
159	/// Create a new rate limit manager
160	pub fn new(config: &RateLimitConfig) -> Self {
161		let mut categories = HashMap::new();
162
163		// Initialize category limiters
164		categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
165		categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
166		categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
167		categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
168
169		let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(Self::TEN_THOUSAND);
170		let penalty_cap =
171			NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(Self::TWENTY_THOUSAND);
172
173		Self {
174			categories,
175			bans: RwLock::new(LruCache::new(ban_cap)),
176			penalties: RwLock::new(LruCache::new(penalty_cap)),
177			pow_store: PowCounterStore::new(PowConfig::default()),
178			total_limited: AtomicU64::new(0),
179			total_bans: AtomicU64::new(0),
180		}
181	}
182
183	/// Create with custom PoW config
184	pub fn with_pow_config(config: &RateLimitConfig, pow_config: PowConfig) -> Self {
185		let mut manager = Self::new(config);
186		manager.pow_store = PowCounterStore::new(pow_config);
187		manager
188	}
189
190	/// Check if a request should be rate limited
191	pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
192		// Check ban list first
193		if let Some(ban) = self.check_ban(addr) {
194			return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
195		}
196
197		// Check rate limits
198		let cat_limiters = self
199			.categories
200			.get(category)
201			.ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
202
203		if let Err(e) = cat_limiters.check(addr) {
204			self.total_limited.fetch_add(1, Ordering::Relaxed);
205			return Err(e);
206		}
207
208		Ok(())
209	}
210
211	/// Check if address is banned
212	fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
213		let keys = AddressKey::extract_all(addr);
214		let mut bans = self.bans.write();
215
216		for key in keys {
217			if let Some(ban) = bans.get(&key) {
218				if ban.is_expired() {
219					bans.pop(&key);
220				} else {
221					return Some(ban.clone());
222				}
223			}
224		}
225
226		None
227	}
228
229	/// Record a penalty for an address
230	fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
231		let key = AddressKey::from_ip_individual(addr);
232		let mut penalties = self.penalties.write();
233
234		let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
235		entry.count = entry.count.saturating_add(amount);
236		entry.last_penalty = Some(Instant::now());
237		entry.reason = Some(reason);
238
239		// Check for auto-ban
240		if entry.count >= reason.failures_to_ban() {
241			drop(penalties);
242			if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
243				warn!("Failed to auto-ban address: {}", e);
244			}
245		}
246	}
247}
248
249impl Default for RateLimitManager {
250	fn default() -> Self {
251		Self::new(&RateLimitConfig::default())
252	}
253}
254
255impl RateLimitApi for RateLimitManager {
256	fn get_status(
257		&self,
258		addr: &IpAddr,
259		category: &str,
260	) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
261		let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
262
263		let keys = AddressKey::extract_all(addr);
264		let bans = self.bans.read();
265
266		let statuses =
267			keys.into_iter()
268				.map(|key| {
269					let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
270					let ban_expires = bans.peek(&key).and_then(|b| {
271						if b.is_expired() {
272							None
273						} else {
274							Some(b.expires_at.unwrap_or_else(|| {
275								Instant::now() + Duration::from_secs(86400 * 365)
276							}))
277						}
278					});
279
280					let status = RateLimitStatus {
281						is_limited: false, // Would need to check governor state
282						remaining: None,
283						reset_at: None,
284						quota: 0,
285						is_banned,
286						ban_expires_at: ban_expires,
287					};
288
289					(key, status)
290				})
291				.collect();
292
293		Ok(statuses)
294	}
295
296	fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
297		debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
298		self.record_penalty(addr, reason, amount);
299		Ok(())
300	}
301
302	fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
303		let key = AddressKey::from_ip_individual(addr);
304		let mut penalties = self.penalties.write();
305
306		if let Some(entry) = penalties.get_mut(&key) {
307			entry.count = entry.count.saturating_sub(amount);
308			if entry.count == 0 {
309				penalties.pop(&key);
310			}
311		}
312
313		Ok(())
314	}
315
316	fn reset(&self, addr: &IpAddr) -> ClResult<()> {
317		let keys = AddressKey::extract_all(addr);
318
319		// Clear penalties
320		let mut penalties = self.penalties.write();
321		for key in &keys {
322			penalties.pop(key);
323		}
324		drop(penalties);
325
326		// Clear bans
327		let mut bans = self.bans.write();
328		for key in &keys {
329			bans.pop(key);
330		}
331
332		// Clear PoW counters
333		self.pow_store.decrement(addr, u32::MAX);
334
335		Ok(())
336	}
337
338	fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
339		let keys = AddressKey::extract_all(addr);
340		let now = Instant::now();
341		let expires_at = Some(now + duration);
342
343		let mut bans = self.bans.write();
344		for key in keys {
345			let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
346			bans.put(key, entry);
347		}
348
349		self.total_bans.fetch_add(1, Ordering::Relaxed);
350		debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
351
352		Ok(())
353	}
354
355	fn unban(&self, addr: &IpAddr) -> ClResult<()> {
356		let keys = AddressKey::extract_all(addr);
357		let mut bans = self.bans.write();
358
359		for key in keys {
360			bans.pop(&key);
361		}
362
363		Ok(())
364	}
365
366	fn is_banned(&self, addr: &IpAddr) -> bool {
367		self.check_ban(addr).is_some()
368	}
369
370	fn list_bans(&self) -> Vec<BanEntry> {
371		self.bans
372			.read()
373			.iter()
374			.filter(|(_, b)| !b.is_expired())
375			.map(|(_, b)| b.clone())
376			.collect()
377	}
378
379	fn stats(&self) -> RateLimiterStats {
380		// Count tracked addresses across all categories
381		let tracked = self
382			.categories
383			.values()
384			.map(|c| {
385				c.ipv4_individual.short_term.len()
386					+ c.ipv4_network.short_term.len()
387					+ c.ipv6_subnet.short_term.len()
388					+ c.ipv6_provider.short_term.len()
389			})
390			.sum();
391
392		RateLimiterStats {
393			tracked_addresses: tracked,
394			active_bans: self.bans.read().len(),
395			total_requests_limited: self.total_limited.load(Ordering::Relaxed),
396			total_bans_issued: self.total_bans.load(Ordering::Relaxed),
397			pow_individual_entries: self.pow_store.individual_count(),
398			pow_network_entries: self.pow_store.network_count(),
399		}
400	}
401
402	fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
403		self.pow_store.get_requirement(addr)
404	}
405
406	fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
407		self.pow_store.increment(addr, reason);
408		Ok(())
409	}
410
411	fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
412		self.pow_store.decrement(addr, amount);
413		Ok(())
414	}
415
416	fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
417		self.pow_store.verify(addr, token)
418	}
419}
420
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv4Addr;

	/// Fresh manager with the default configuration.
	fn new_manager() -> RateLimitManager {
		RateLimitManager::default()
	}

	/// The single client address every test uses.
	fn client_ip() -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
	}

	#[test]
	fn test_rate_limit_manager_creation() {
		let manager = new_manager();
		for category in ["auth", "federation", "general", "websocket"] {
			assert!(manager.categories.contains_key(category));
		}
	}

	#[test]
	fn test_rate_limit_check() {
		let manager = new_manager();
		let ip = client_ip();

		// The first handful of requests must be admitted.
		(0..5).for_each(|_| assert!(manager.check(&ip, "general").is_ok()));
	}

	#[test]
	fn test_unknown_category() {
		let manager = new_manager();
		let outcome = manager.check(&client_ip(), "nonexistent");
		assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
	}

	#[test]
	fn test_ban_functionality() {
		let manager = new_manager();
		let ip = client_ip();

		assert!(!manager.is_banned(&ip));

		manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
		assert!(manager.is_banned(&ip));
		assert!(matches!(manager.check(&ip, "general"), Err(RateLimitError::Banned { .. })));

		manager.unban(&ip).unwrap();
		assert!(!manager.is_banned(&ip));
	}

	#[test]
	fn test_penalty_auto_ban() {
		let manager = new_manager();
		let ip = client_ip();

		// AuthFailure needs five strikes: the first four leave the address unbanned...
		for _ in 0..4 {
			manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
			assert!(!manager.is_banned(&ip));
		}

		// ...and the fifth triggers the auto-ban.
		manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
		assert!(manager.is_banned(&ip));
	}

	#[test]
	fn test_pow_integration() {
		let manager = new_manager();
		let ip = client_ip();

		// No work is demanded up front, so any token is accepted.
		assert_eq!(manager.get_pow_requirement(&ip), 0);
		assert!(manager.verify_pow(&ip, "any_token").is_ok());

		manager
			.increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure)
			.unwrap();
		assert_eq!(manager.get_pow_requirement(&ip), 1);

		// One level of work is now required.
		assert!(manager.verify_pow(&ip, "any_token").is_err());
		assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
	}

	#[test]
	fn test_stats() {
		let manager = new_manager();
		let ip = client_ip();

		let before = manager.stats();
		assert_eq!(before.active_bans, 0);
		assert_eq!(before.total_bans_issued, 0);

		manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

		let after = manager.stats();
		assert!(after.active_bans > 0);
		assert_eq!(after.total_bans_issued, 1);
	}
}
524
525// vim: ts=4