Skip to main content

cloudillo_core/rate_limit/
limiter.rs

1//! Rate Limit Manager
2//!
3//! Core rate limiting implementation using the governor crate's GCRA algorithm.
4//! Supports hierarchical address levels with dual-tier (short + long term) limits.
5
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::num::NonZeroU32;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use governor::clock::{Clock, DefaultClock};
14use governor::state::keyed::DashMapStateStore;
15use governor::{Quota, RateLimiter};
16use lru::LruCache;
17use parking_lot::RwLock;
18use std::num::NonZeroUsize;
19use tracing::{debug, warn};
20
21use super::api::{
22	BanEntry, PenaltyReason, PowPenaltyReason, RateLimitApi, RateLimitStatus, RateLimiterStats,
23};
24use super::config::{EndpointCategoryConfig, PowConfig, RateLimitConfig, RateLimitTierConfig};
25use super::error::{PowError, RateLimitError};
26use super::extractors::AddressKey;
27use super::pow::PowCounterStore;
28use crate::prelude::*;
29
30/// Type alias for a keyed rate limiter
31type KeyedLimiter = RateLimiter<AddressKey, DashMapStateStore<AddressKey>, DefaultClock>;
32
33/// Holds both short-term and long-term limiters for an address level
34struct TierLimiters {
35	short_term: Arc<KeyedLimiter>,
36	long_term: Arc<KeyedLimiter>,
37}
38
39impl TierLimiters {
40	fn new(config: &RateLimitTierConfig) -> Self {
41		// Short-term: per-second with burst
42		let short_quota =
43			Quota::per_second(config.short_term_rps).allow_burst(config.short_term_burst);
44		let short_term = Arc::new(RateLimiter::keyed(short_quota));
45
46		// Long-term: per-hour with burst
47		// Convert RPH to rate per second for governor
48		let rps = config.long_term_rph.get() as f64 / 3600.0;
49		let period_nanos = (1_000_000_000.0 / rps) as u64;
50		// SAFETY: 1 is non-zero
51		const ONE: NonZeroU32 = match NonZeroU32::new(1) {
52			Some(v) => v,
53			None => unreachable!(),
54		};
55		let long_quota = Quota::with_period(Duration::from_nanos(period_nanos))
56			.unwrap_or_else(|| Quota::per_second(ONE))
57			.allow_burst(config.long_term_burst);
58		let long_term = Arc::new(RateLimiter::keyed(long_quota));
59
60		Self { short_term, long_term }
61	}
62
63	/// Check if both short and long term limits allow the request
64	fn check(&self, key: &AddressKey) -> Result<(), Duration> {
65		// Check short-term first
66		if let Err(not_until) = self.short_term.check_key(key) {
67			return Err(not_until.wait_time_from(DefaultClock::default().now()));
68		}
69
70		// Check long-term
71		if let Err(not_until) = self.long_term.check_key(key) {
72			return Err(not_until.wait_time_from(DefaultClock::default().now()));
73		}
74
75		Ok(())
76	}
77}
78
79/// Per-category rate limiters (one for each hierarchical level)
80struct CategoryLimiters {
81	ipv4_individual: TierLimiters,
82	ipv4_network: TierLimiters,
83	ipv6_subnet: TierLimiters,
84	ipv6_provider: TierLimiters,
85}
86
87impl CategoryLimiters {
88	fn new(config: &EndpointCategoryConfig) -> Self {
89		Self {
90			ipv4_individual: TierLimiters::new(&config.ipv4_individual),
91			ipv4_network: TierLimiters::new(&config.ipv4_network),
92			ipv6_subnet: TierLimiters::new(&config.ipv6_subnet),
93			ipv6_provider: TierLimiters::new(&config.ipv6_provider),
94		}
95	}
96
97	/// Check all applicable limits for an address
98	fn check(&self, addr: &IpAddr) -> Result<(), RateLimitError> {
99		let keys = AddressKey::extract_all(addr);
100
101		for key in keys {
102			let limiter = self.get_limiter_for_key(&key);
103			if let Err(wait_time) = limiter.check(&key) {
104				return Err(RateLimitError::RateLimited {
105					level: key.level_name(),
106					retry_after: wait_time,
107				});
108			}
109		}
110
111		Ok(())
112	}
113
114	fn get_limiter_for_key(&self, key: &AddressKey) -> &TierLimiters {
115		match key {
116			AddressKey::Ipv4Individual(_) => &self.ipv4_individual,
117			AddressKey::Ipv4Network(_) => &self.ipv4_network,
118			AddressKey::Ipv6Subnet(_) => &self.ipv6_subnet,
119			AddressKey::Ipv6Provider(_) => &self.ipv6_provider,
120		}
121	}
122}
123
124/// Penalty tracking for an address
125#[derive(Debug, Clone, Default)]
126struct PenaltyEntry {
127	count: u32,
128	last_penalty: Option<Instant>,
129	reason: Option<PenaltyReason>,
130}
131
132/// Main rate limit manager
133pub struct RateLimitManager {
134	/// Per-category limiters
135	categories: HashMap<String, CategoryLimiters>,
136	/// Global ban list
137	bans: RwLock<LruCache<AddressKey, BanEntry>>,
138	/// Penalty tracking per address
139	penalties: RwLock<LruCache<AddressKey, PenaltyEntry>>,
140	/// Proof-of-work counter store
141	pow_store: PowCounterStore,
142	/// Statistics
143	total_limited: AtomicU64,
144	total_bans: AtomicU64,
145}
146
147impl RateLimitManager {
148	/// Create a new rate limit manager
149	pub fn new(config: RateLimitConfig) -> Self {
150		let mut categories = HashMap::new();
151
152		// Initialize category limiters
153		categories.insert("auth".to_string(), CategoryLimiters::new(&config.auth));
154		categories.insert("federation".to_string(), CategoryLimiters::new(&config.federation));
155		categories.insert("general".to_string(), CategoryLimiters::new(&config.general));
156		categories.insert("websocket".to_string(), CategoryLimiters::new(&config.websocket));
157
158		// SAFETY: These are non-zero constants
159		const TEN_THOUSAND: NonZeroUsize = match NonZeroUsize::new(10_000) {
160			Some(v) => v,
161			None => unreachable!(),
162		};
163		const TWENTY_THOUSAND: NonZeroUsize = match NonZeroUsize::new(20_000) {
164			Some(v) => v,
165			None => unreachable!(),
166		};
167		let ban_cap = NonZeroUsize::new(config.max_tracked_ips / 10).unwrap_or(TEN_THOUSAND);
168		let penalty_cap = NonZeroUsize::new(config.max_tracked_ips / 5).unwrap_or(TWENTY_THOUSAND);
169
170		Self {
171			categories,
172			bans: RwLock::new(LruCache::new(ban_cap)),
173			penalties: RwLock::new(LruCache::new(penalty_cap)),
174			pow_store: PowCounterStore::new(PowConfig::default()),
175			total_limited: AtomicU64::new(0),
176			total_bans: AtomicU64::new(0),
177		}
178	}
179
180	/// Create with custom PoW config
181	pub fn with_pow_config(config: RateLimitConfig, pow_config: PowConfig) -> Self {
182		let mut manager = Self::new(config);
183		manager.pow_store = PowCounterStore::new(pow_config);
184		manager
185	}
186
187	/// Check if a request should be rate limited
188	pub fn check(&self, addr: &IpAddr, category: &str) -> Result<(), RateLimitError> {
189		// Check ban list first
190		if let Some(ban) = self.check_ban(addr) {
191			return Err(RateLimitError::Banned { remaining: ban.remaining_duration() });
192		}
193
194		// Check rate limits
195		let cat_limiters = self
196			.categories
197			.get(category)
198			.ok_or_else(|| RateLimitError::UnknownCategory(category.to_string()))?;
199
200		if let Err(e) = cat_limiters.check(addr) {
201			self.total_limited.fetch_add(1, Ordering::Relaxed);
202			return Err(e);
203		}
204
205		Ok(())
206	}
207
208	/// Check if address is banned
209	fn check_ban(&self, addr: &IpAddr) -> Option<BanEntry> {
210		let keys = AddressKey::extract_all(addr);
211		let mut bans = self.bans.write();
212
213		for key in keys {
214			if let Some(ban) = bans.get(&key) {
215				if ban.is_expired() {
216					bans.pop(&key);
217				} else {
218					return Some(ban.clone());
219				}
220			}
221		}
222
223		None
224	}
225
226	/// Record a penalty for an address
227	fn record_penalty(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) {
228		let key = AddressKey::from_ip_individual(addr);
229		let mut penalties = self.penalties.write();
230
231		let entry = penalties.get_or_insert_mut(key.clone(), PenaltyEntry::default);
232		entry.count = entry.count.saturating_add(amount);
233		entry.last_penalty = Some(Instant::now());
234		entry.reason = Some(reason);
235
236		// Check for auto-ban
237		if entry.count >= reason.failures_to_ban() {
238			drop(penalties);
239			if let Err(e) = self.ban(addr, reason.ban_duration(), reason) {
240				warn!("Failed to auto-ban address: {}", e);
241			}
242		}
243	}
244}
245
246impl Default for RateLimitManager {
247	fn default() -> Self {
248		Self::new(RateLimitConfig::default())
249	}
250}
251
252impl RateLimitApi for RateLimitManager {
253	fn get_status(
254		&self,
255		addr: &IpAddr,
256		category: &str,
257	) -> ClResult<Vec<(AddressKey, RateLimitStatus)>> {
258		let _cat_limiters = self.categories.get(category).ok_or(Error::NotFound)?;
259
260		let keys = AddressKey::extract_all(addr);
261		let bans = self.bans.read();
262
263		let statuses =
264			keys.into_iter()
265				.map(|key| {
266					let is_banned = bans.peek(&key).is_some_and(|b| !b.is_expired());
267					let ban_expires = bans.peek(&key).and_then(|b| {
268						if b.is_expired() {
269							None
270						} else {
271							Some(b.expires_at.unwrap_or_else(|| {
272								Instant::now() + Duration::from_secs(86400 * 365)
273							}))
274						}
275					});
276
277					let status = RateLimitStatus {
278						is_limited: false, // Would need to check governor state
279						remaining: None,
280						reset_at: None,
281						quota: 0,
282						is_banned,
283						ban_expires_at: ban_expires,
284					};
285
286					(key, status)
287				})
288				.collect();
289
290		Ok(statuses)
291	}
292
293	fn penalize(&self, addr: &IpAddr, reason: PenaltyReason, amount: u32) -> ClResult<()> {
294		debug!("Penalizing {:?} for {:?} (amount: {})", addr, reason, amount);
295		self.record_penalty(addr, reason, amount);
296		Ok(())
297	}
298
299	fn grant(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
300		let key = AddressKey::from_ip_individual(addr);
301		let mut penalties = self.penalties.write();
302
303		if let Some(entry) = penalties.get_mut(&key) {
304			entry.count = entry.count.saturating_sub(amount);
305			if entry.count == 0 {
306				penalties.pop(&key);
307			}
308		}
309
310		Ok(())
311	}
312
313	fn reset(&self, addr: &IpAddr) -> ClResult<()> {
314		let keys = AddressKey::extract_all(addr);
315
316		// Clear penalties
317		let mut penalties = self.penalties.write();
318		for key in &keys {
319			penalties.pop(key);
320		}
321		drop(penalties);
322
323		// Clear bans
324		let mut bans = self.bans.write();
325		for key in &keys {
326			bans.pop(key);
327		}
328
329		// Clear PoW counters
330		self.pow_store.decrement(addr, u32::MAX);
331
332		Ok(())
333	}
334
335	fn ban(&self, addr: &IpAddr, duration: Duration, reason: PenaltyReason) -> ClResult<()> {
336		let keys = AddressKey::extract_all(addr);
337		let now = Instant::now();
338		let expires_at = Some(now + duration);
339
340		let mut bans = self.bans.write();
341		for key in keys {
342			let entry = BanEntry { key: key.clone(), reason, created_at: now, expires_at };
343			bans.put(key, entry);
344		}
345
346		self.total_bans.fetch_add(1, Ordering::Relaxed);
347		debug!("Banned {:?} for {:?} due to {:?}", addr, duration, reason);
348
349		Ok(())
350	}
351
352	fn unban(&self, addr: &IpAddr) -> ClResult<()> {
353		let keys = AddressKey::extract_all(addr);
354		let mut bans = self.bans.write();
355
356		for key in keys {
357			bans.pop(&key);
358		}
359
360		Ok(())
361	}
362
363	fn is_banned(&self, addr: &IpAddr) -> bool {
364		self.check_ban(addr).is_some()
365	}
366
367	fn list_bans(&self) -> Vec<BanEntry> {
368		self.bans
369			.read()
370			.iter()
371			.filter(|(_, b)| !b.is_expired())
372			.map(|(_, b)| b.clone())
373			.collect()
374	}
375
376	fn stats(&self) -> RateLimiterStats {
377		// Count tracked addresses across all categories
378		let tracked = self
379			.categories
380			.values()
381			.map(|c| {
382				c.ipv4_individual.short_term.len()
383					+ c.ipv4_network.short_term.len()
384					+ c.ipv6_subnet.short_term.len()
385					+ c.ipv6_provider.short_term.len()
386			})
387			.sum();
388
389		RateLimiterStats {
390			tracked_addresses: tracked,
391			active_bans: self.bans.read().len(),
392			total_requests_limited: self.total_limited.load(Ordering::Relaxed),
393			total_bans_issued: self.total_bans.load(Ordering::Relaxed),
394			pow_individual_entries: self.pow_store.individual_count(),
395			pow_network_entries: self.pow_store.network_count(),
396		}
397	}
398
399	fn get_pow_requirement(&self, addr: &IpAddr) -> u32 {
400		self.pow_store.get_requirement(addr)
401	}
402
403	fn increment_pow_counter(&self, addr: &IpAddr, reason: PowPenaltyReason) -> ClResult<()> {
404		self.pow_store.increment(addr, reason);
405		Ok(())
406	}
407
408	fn decrement_pow_counter(&self, addr: &IpAddr, amount: u32) -> ClResult<()> {
409		self.pow_store.decrement(addr, amount);
410		Ok(())
411	}
412
413	fn verify_pow(&self, addr: &IpAddr, token: &str) -> Result<(), PowError> {
414		self.pow_store.verify(addr, token)
415	}
416}
417
#[cfg(test)]
mod tests {
	use super::*;
	use std::net::Ipv4Addr;

	/// Address used by every test below.
	fn sample_ip() -> IpAddr {
		IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100))
	}

	#[test]
	fn test_rate_limit_manager_creation() {
		let manager = RateLimitManager::default();
		for category in ["auth", "federation", "general", "websocket"] {
			assert!(manager.categories.contains_key(category), "missing category {category}");
		}
	}

	#[test]
	fn test_rate_limit_check() {
		let manager = RateLimitManager::default();
		let ip = sample_ip();

		// A handful of requests must be admitted.
		(0..5).for_each(|_| assert!(manager.check(&ip, "general").is_ok()));
	}

	#[test]
	fn test_unknown_category() {
		let manager = RateLimitManager::default();
		let outcome = manager.check(&sample_ip(), "nonexistent");
		assert!(matches!(outcome, Err(RateLimitError::UnknownCategory(_))));
	}

	#[test]
	fn test_ban_functionality() {
		let manager = RateLimitManager::default();
		let ip = sample_ip();
		assert!(!manager.is_banned(&ip));

		manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();
		assert!(manager.is_banned(&ip));
		assert!(matches!(manager.check(&ip, "general"), Err(RateLimitError::Banned { .. })));

		manager.unban(&ip).unwrap();
		assert!(!manager.is_banned(&ip));
	}

	#[test]
	fn test_penalty_auto_ban() {
		let manager = RateLimitManager::default();
		let ip = sample_ip();

		// AuthFailure auto-bans on the fifth failure; the first four must not.
		for attempt in 1..=4 {
			manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
			assert!(!manager.is_banned(&ip), "banned too early after {attempt} failures");
		}

		manager.penalize(&ip, PenaltyReason::AuthFailure, 1).unwrap();
		assert!(manager.is_banned(&ip));
	}

	#[test]
	fn test_pow_integration() {
		let manager = RateLimitManager::default();
		let ip = sample_ip();

		// No proof of work is demanded before any counter increment.
		assert_eq!(manager.get_pow_requirement(&ip), 0);
		assert!(manager.verify_pow(&ip, "any_token").is_ok());

		manager.increment_pow_counter(&ip, PowPenaltyReason::ConnSignatureFailure).unwrap();
		assert_eq!(manager.get_pow_requirement(&ip), 1);

		// With a requirement of 1 the token must now pass the PoW check.
		assert!(manager.verify_pow(&ip, "any_token").is_err());
		assert!(manager.verify_pow(&ip, "any_tokenA").is_ok());
	}

	#[test]
	fn test_stats() {
		let manager = RateLimitManager::default();
		let ip = sample_ip();

		let before = manager.stats();
		assert_eq!(before.active_bans, 0);
		assert_eq!(before.total_bans_issued, 0);

		manager.ban(&ip, Duration::from_secs(60), PenaltyReason::AuthFailure).unwrap();

		let after = manager.stats();
		assert!(after.active_bans > 0);
		assert_eq!(after.total_bans_issued, 1);
	}
}
521
522// vim: ts=4