1use std::collections::HashMap;
11use std::net::{IpAddr, SocketAddr};
12use std::sync::{Arc, Mutex};
13use std::time::Instant;
14
15use axum::body::Body;
16use axum::extract::ConnectInfo;
17use axum::http::{HeaderValue, Method, Request};
18use axum::response::{IntoResponse, Response};
19
20use riley_auth_core::config::RateLimitTiersConfig;
21use riley_auth_core::error::Error;
22
/// Rate-limit bucket a request is charged against; every tier keeps its own
/// independent budget (see `classify_path` for how paths are assigned).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RateLimitTier {
    /// OAuth token/authorize/revoke/introspect, setup/refresh, provider
    /// sign-in and account-linking routes.
    Auth,
    /// Anything not classified as `Auth` or `Public`.
    Standard,
    /// `/health` and `/.well-known` discovery documents.
    Public,
}

impl RateLimitTier {
    /// Short lowercase label for the tier (used as the `tier` metrics label).
    pub fn as_str(self) -> &'static str {
        match self {
            RateLimitTier::Auth => "auth",
            RateLimitTier::Standard => "standard",
            RateLimitTier::Public => "public",
        }
    }
}
45
46pub fn classify_path(path: &str) -> RateLimitTier {
51 let path = if path.len() > 1 && path.ends_with('/') {
53 &path[..path.len() - 1]
54 } else {
55 path
56 };
57
58 if matches!(
60 path,
61 "/oauth/token"
62 | "/oauth/authorize"
63 | "/oauth/revoke"
64 | "/oauth/introspect"
65 | "/auth/setup"
66 | "/auth/refresh"
67 ) {
68 return RateLimitTier::Auth;
69 }
70
71 if let Some(rest) = path.strip_prefix("/auth/link/") {
73 if !rest.is_empty() {
74 return RateLimitTier::Auth;
75 }
76 }
77
78 if let Some(rest) = path.strip_prefix("/auth/") {
81 let segment = rest.split('/').next().unwrap_or("");
82 if !matches!(
83 segment,
84 "me" | "logout" | "logout-all" | "sessions" | "link"
85 | "setup" | "refresh" | ""
86 ) {
87 return RateLimitTier::Auth;
88 }
89 }
90
91 if path == "/health" || path.starts_with("/.well-known/") || path == "/.well-known" {
93 return RateLimitTier::Public;
94 }
95
96 RateLimitTier::Standard
98}
99
100pub fn extract_ip<B>(req: &Request<B>, behind_proxy: bool) -> Option<IpAddr> {
106 let peer_ip = req.extensions()
107 .get::<ConnectInfo<SocketAddr>>()
108 .map(|ci| ci.0.ip());
109 crate::routes::extract_client_ip(req.headers(), peer_ip, behind_proxy)
110}
111
/// Fixed-window counter for one client.
struct WindowEntry {
    /// Requests seen in the current window (keeps counting past the limit).
    count: u64,
    /// When the current window began.
    window_start: Instant,
}

/// Mutable state of one tier, guarded by `InMemoryTierLimiter::state`.
struct TierState {
    /// Per-client windows; may contain expired entries until the next prune.
    windows: HashMap<IpAddr, WindowEntry>,
    /// Last time expired entries were removed from `windows`.
    last_prune: Instant,
}

/// Cap on tracked clients per tier so the map cannot grow without bound.
const MAX_ENTRIES_PER_TIER: usize = 100_000;

/// Minimum spacing, in seconds, between prunes forced by a full table. This
/// keeps a flood of new clients from triggering an O(n) sweep per request.
const MIN_FORCED_PRUNE_INTERVAL_SECS: u64 = 1;

/// In-process fixed-window rate limiter for a single tier.
struct InMemoryTierLimiter {
    state: Mutex<TierState>,
    /// Requests allowed per client per window.
    burst_size: u32,
    /// Window length in seconds.
    window_secs: u64,
}

impl InMemoryTierLimiter {
    fn new(burst_size: u32, window_secs: u64) -> Self {
        Self {
            state: Mutex::new(TierState {
                windows: HashMap::new(),
                last_prune: Instant::now(),
            }),
            burst_size,
            window_secs,
        }
    }

    /// Records one request from `ip` and returns
    /// `(allowed, remaining, retry_after_secs)`.
    ///
    /// `retry_after_secs` is the time left in the client's current window
    /// (at least 1). When the table already tracks `MAX_ENTRIES_PER_TIER`
    /// clients and `ip` is not one of them, the request is rejected with
    /// `remaining == 0` and `retry_after_secs == max(window_secs, 1)`.
    fn check(&self, ip: &IpAddr) -> (bool, u64, u64) {
        // The guarded data is plain counters that stay consistent even if a
        // previous holder panicked, so recover the guard instead of turning
        // one panic into a panic on every later request.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();

        let secs_since_prune = now.duration_since(state.last_prune).as_secs();
        let full_for_new_client =
            state.windows.len() >= MAX_ENTRIES_PER_TIER && !state.windows.contains_key(ip);
        // Prune on the regular schedule, or early (throttled) when a full
        // table would otherwise reject a new client while holding expired
        // entries that just haven't been swept yet.
        if secs_since_prune >= self.window_secs
            || (full_for_new_client && secs_since_prune >= MIN_FORCED_PRUNE_INTERVAL_SECS)
        {
            self.prune_expired(&mut state, now);
        }

        if state.windows.len() >= MAX_ENTRIES_PER_TIER && !state.windows.contains_key(ip) {
            return (false, 0, self.window_secs.max(1));
        }

        let entry = state.windows.entry(*ip).or_insert(WindowEntry {
            count: 0,
            window_start: now,
        });

        // The client's window has elapsed: start a fresh one.
        if now.duration_since(entry.window_start).as_secs() >= self.window_secs {
            entry.count = 0;
            entry.window_start = now;
        }

        entry.count = entry.count.saturating_add(1);
        let limit = u64::from(self.burst_size);
        let allowed = entry.count <= limit;
        let remaining = if allowed { limit - entry.count } else { 0 };

        let elapsed = now.duration_since(entry.window_start).as_secs();
        let retry_after = self.window_secs.saturating_sub(elapsed).max(1);

        (allowed, remaining, retry_after)
    }

    /// Drops every window that has fully elapsed and records the sweep time.
    fn prune_expired(&self, state: &mut TierState, now: Instant) {
        let window_secs = self.window_secs;
        state
            .windows
            .retain(|_, entry| now.duration_since(entry.window_start).as_secs() < window_secs);
        state.last_prune = now;
    }
}
195
196pub struct InMemoryRateLimiter {
198 auth: InMemoryTierLimiter,
199 standard: InMemoryTierLimiter,
200 public: InMemoryTierLimiter,
201}
202
203impl InMemoryRateLimiter {
204 pub fn new(tiers: &RateLimitTiersConfig) -> Self {
205 Self {
206 auth: InMemoryTierLimiter::new(tiers.auth.requests, tiers.auth.window_secs),
207 standard: InMemoryTierLimiter::new(tiers.standard.requests, tiers.standard.window_secs),
208 public: InMemoryTierLimiter::new(tiers.public.requests, tiers.public.window_secs),
209 }
210 }
211
212 fn tier_limiter(&self, tier: RateLimitTier) -> &InMemoryTierLimiter {
213 match tier {
214 RateLimitTier::Auth => &self.auth,
215 RateLimitTier::Standard => &self.standard,
216 RateLimitTier::Public => &self.public,
217 }
218 }
219
220 fn burst_size(&self, tier: RateLimitTier) -> u32 {
221 self.tier_limiter(tier).burst_size
222 }
223
224 fn check(&self, tier: RateLimitTier, ip: &IpAddr) -> (bool, u64, u64) {
225 self.tier_limiter(tier).check(ip)
226 }
227}
228
229pub async fn memory_rate_limit_middleware(
231 limiter: Arc<InMemoryRateLimiter>,
232 behind_proxy: bool,
233 req: Request<Body>,
234 next: axum::middleware::Next,
235) -> Response {
236 if req.method() == Method::OPTIONS {
238 return next.run(req).await;
239 }
240
241 let ip = match extract_ip(&req, behind_proxy) {
242 Some(ip) => ip,
243 None => {
244 tracing::warn!("rate limiter: could not extract client IP, bypassing rate limit");
245 return next.run(req).await;
246 }
247 };
248
249 let tier = classify_path(req.uri().path());
250 let (allowed, remaining, retry_after) = limiter.check(tier, &ip);
251
252 if allowed {
253 let mut response = next.run(req).await;
254 let headers = response.headers_mut();
255 headers.insert("x-ratelimit-remaining", HeaderValue::from(remaining));
256 headers.insert("x-ratelimit-limit", HeaderValue::from(limiter.burst_size(tier)));
257 headers_insert_reset(headers, retry_after);
258 response
259 } else {
260 metrics::counter!("riley_auth_rate_limit_hits_total", "tier" => tier.as_str()).increment(1);
261 let mut response = Error::RateLimited.into_response();
262 let headers = response.headers_mut();
263 headers_insert_retry_after(headers, retry_after);
264 headers.insert("x-ratelimit-remaining", HeaderValue::from(0u64));
265 headers.insert("x-ratelimit-limit", HeaderValue::from(limiter.burst_size(tier)));
266 headers_insert_reset(headers, retry_after);
267 response
268 }
269}
270
271fn headers_insert_retry_after(headers: &mut axum::http::HeaderMap, retry_after: u64) {
273 headers.insert("retry-after", HeaderValue::from(retry_after));
274}
275
276fn headers_insert_reset(headers: &mut axum::http::HeaderMap, retry_after: u64) {
278 let reset = std::time::SystemTime::now()
279 .duration_since(std::time::UNIX_EPOCH)
280 .unwrap_or_default()
281 .as_secs()
282 + retry_after;
283 headers.insert("x-ratelimit-reset", HeaderValue::from(reset));
284}
285
#[cfg(feature = "redis")]
mod redis_impl {
    use super::*;
    use redis::aio::ConnectionManager;
    use redis::Script;

    /// Lua script implementing an atomic fixed-window counter.
    ///
    /// `KEYS[1]` is the counter key and `ARGV[1]` the window length in seconds.
    /// It returns `{count, ttl_secs}`. The expiry is applied when a window
    /// starts (`count == 1`) and also whenever the key is found without one.
    /// A counter that somehow lost its TTL therefore heals itself instead of
    /// locking the client out permanently.
    const FIXED_WINDOW_SCRIPT: &str = r"
        local current = redis.call('INCR', KEYS[1])
        local ttl = redis.call('TTL', KEYS[1])
        if current == 1 or ttl < 0 then
            redis.call('EXPIRE', KEYS[1], ARGV[1])
            ttl = tonumber(ARGV[1])
        end
        return {current, ttl}
    ";

    /// Redis-backed fixed-window rate limiter for a single budget.
    /// Counters live under `"<key_prefix>:<ip>"`.
    #[derive(Clone)]
    pub struct RedisRateLimiter {
        conn: ConnectionManager,
        burst_size: u32,
        window_secs: u64,
        key_prefix: String,
    }

    impl RedisRateLimiter {
        /// Opens a connection to `redis_url` and creates a limiter using the
        /// `rate_limit` key prefix.
        ///
        /// # Errors
        /// Returns the Redis error if the client cannot be created from
        /// `redis_url` or the connection manager fails to connect.
        pub async fn new(
            redis_url: &str,
            burst_size: u32,
            window_secs: u64,
        ) -> Result<Self, redis::RedisError> {
            let client = redis::Client::open(redis_url)?;
            let conn = ConnectionManager::new(client).await?;
            Ok(Self {
                conn,
                burst_size,
                window_secs,
                key_prefix: "rate_limit".to_string(),
            })
        }

        /// Creates a limiter over an existing connection with a custom key prefix.
        pub fn with_prefix(
            conn: ConnectionManager,
            burst_size: u32,
            window_secs: u64,
            key_prefix: String,
        ) -> Self {
            Self {
                conn,
                burst_size,
                window_secs,
                key_prefix,
            }
        }

        /// Atomically increments the counter for `ip` and returns
        /// `(count_in_window, ttl_secs)`.
        ///
        /// # Errors
        /// Propagates any Redis error from running the script.
        pub async fn check(&self, ip: &IpAddr) -> Result<(u64, u64), redis::RedisError> {
            let script = Script::new(FIXED_WINDOW_SCRIPT);
            let key = format!("{}:{}", self.key_prefix, ip);
            let mut conn = self.conn.clone();
            let (count, ttl): (u64, u64) = script
                .key(&key)
                .arg(self.window_secs)
                .invoke_async(&mut conn)
                .await?;
            Ok((count, ttl))
        }

        /// Returns whether `ip` is still within budget. Fails open: if Redis
        /// errors, the request is allowed and a warning is logged.
        pub async fn is_allowed(&self, ip: &IpAddr) -> bool {
            match self.check(ip).await {
                Ok((count, _)) => count <= self.burst_size as u64,
                Err(e) => {
                    tracing::warn!(error = %e, "Redis rate limiter unavailable, allowing request");
                    true
                }
            }
        }

        /// Returns `(allowed, remaining, retry_after)`.
        ///
        /// `retry_after` is `Some(ttl.max(1))` only for rejected requests. On a
        /// Redis error the request is allowed and `(true, None, None)` is returned.
        pub async fn check_with_headers(
            &self,
            ip: &IpAddr,
        ) -> (bool, Option<u64>, Option<u64>) {
            match self.check(ip).await {
                Ok((count, ttl)) => {
                    let allowed = count <= self.burst_size as u64;
                    let remaining = if allowed {
                        Some(self.burst_size as u64 - count)
                    } else {
                        Some(0)
                    };
                    let retry_after = if allowed { None } else { Some(ttl.max(1)) };
                    (allowed, remaining, retry_after)
                }
                Err(e) => {
                    tracing::warn!(error = %e, "Redis rate limiter unavailable, allowing request");
                    (true, None, None)
                }
            }
        }
    }

    /// One [`RedisRateLimiter`] per [`RateLimitTier`], all sharing a single
    /// connection and keyed under `rate:auth`, `rate:standard` and `rate:public`.
    #[derive(Clone)]
    pub struct TieredRedisRateLimiter {
        auth: RedisRateLimiter,
        standard: RedisRateLimiter,
        public: RedisRateLimiter,
    }

    impl TieredRedisRateLimiter {
        /// Connects to `redis_url` once and builds the three tier limiters.
        ///
        /// # Errors
        /// Returns the Redis error if the client cannot be created or the
        /// connection manager fails to connect.
        pub async fn new(
            redis_url: &str,
            tiers: &RateLimitTiersConfig,
        ) -> Result<Self, redis::RedisError> {
            let client = redis::Client::open(redis_url)?;
            let conn = ConnectionManager::new(client).await?;
            Ok(Self {
                auth: RedisRateLimiter::with_prefix(
                    conn.clone(),
                    tiers.auth.requests,
                    tiers.auth.window_secs,
                    "rate:auth".to_string(),
                ),
                standard: RedisRateLimiter::with_prefix(
                    conn.clone(),
                    tiers.standard.requests,
                    tiers.standard.window_secs,
                    "rate:standard".to_string(),
                ),
                public: RedisRateLimiter::with_prefix(
                    conn,
                    tiers.public.requests,
                    tiers.public.window_secs,
                    "rate:public".to_string(),
                ),
            })
        }

        /// Selects the limiter that owns `tier`'s budget.
        fn tier_limiter(&self, tier: RateLimitTier) -> &RedisRateLimiter {
            match tier {
                RateLimitTier::Auth => &self.auth,
                RateLimitTier::Standard => &self.standard,
                RateLimitTier::Public => &self.public,
            }
        }

        pub fn auth(&self) -> &RedisRateLimiter { &self.auth }
        pub fn standard(&self) -> &RedisRateLimiter { &self.standard }
        pub fn public(&self) -> &RedisRateLimiter { &self.public }
    }

    /// Axum middleware that enforces the tiered limits stored in Redis.
    ///
    /// `OPTIONS` requests and requests without a resolvable client IP are not
    /// counted. If Redis is unavailable, the request is allowed without
    /// rate-limit headers (fail open). Counted responses carry
    /// `x-ratelimit-remaining`, `x-ratelimit-limit` and `x-ratelimit-reset`,
    /// matching the in-memory middleware. Rejections also carry `retry-after`.
    pub async fn redis_rate_limit_middleware(
        limiter: Arc<TieredRedisRateLimiter>,
        behind_proxy: bool,
        req: Request<Body>,
        next: axum::middleware::Next,
    ) -> Response {
        if req.method() == Method::OPTIONS {
            return next.run(req).await;
        }

        let ip = match extract_ip(&req, behind_proxy) {
            Some(ip) => ip,
            None => {
                tracing::warn!("rate limiter: could not extract client IP, bypassing rate limit");
                return next.run(req).await;
            }
        };

        let tier = classify_path(req.uri().path());
        let tier_limiter = limiter.tier_limiter(tier);
        let limit = tier_limiter.burst_size;

        // Resolve the Redis result (and drop any error) before awaiting the
        // downstream handler.
        let counts = match tier_limiter.check(&ip).await {
            Ok(counts) => Some(counts),
            Err(e) => {
                tracing::warn!(error = %e, "Redis rate limiter unavailable, allowing request");
                None
            }
        };
        let (count, ttl) = match counts {
            Some(counts) => counts,
            None => return next.run(req).await,
        };

        // Seconds until the current window ends; never advertise zero.
        let reset_in = ttl.max(1);

        if count <= u64::from(limit) {
            let mut response = next.run(req).await;
            let headers = response.headers_mut();
            headers.insert(
                "x-ratelimit-remaining",
                HeaderValue::from(u64::from(limit) - count),
            );
            headers.insert("x-ratelimit-limit", HeaderValue::from(limit));
            headers_insert_reset(headers, reset_in);
            response
        } else {
            metrics::counter!("riley_auth_rate_limit_hits_total", "tier" => tier.as_str()).increment(1);
            let mut response = Error::RateLimited.into_response();
            let headers = response.headers_mut();
            headers_insert_retry_after(headers, reset_in);
            headers_insert_reset(headers, reset_in);
            headers.insert("x-ratelimit-remaining", HeaderValue::from(0u32));
            headers.insert("x-ratelimit-limit", HeaderValue::from(limit));
            response
        }
    }
}
498
499#[cfg(feature = "redis")]
500pub use redis_impl::{RedisRateLimiter, TieredRedisRateLimiter, redis_rate_limit_middleware};
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends `requests` requests from `ip` in `tier`, discarding the outcomes.
    fn exhaust(limiter: &InMemoryRateLimiter, tier: RateLimitTier, ip: &IpAddr, requests: u32) {
        for _ in 0..requests {
            limiter.check(tier, ip);
        }
    }

    #[test]
    fn classify_auth_exact_matches() {
        assert_eq!(classify_path("/oauth/token"), RateLimitTier::Auth);
        assert_eq!(classify_path("/oauth/authorize"), RateLimitTier::Auth);
        assert_eq!(classify_path("/oauth/revoke"), RateLimitTier::Auth);
        assert_eq!(classify_path("/oauth/introspect"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/setup"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/refresh"), RateLimitTier::Auth);
    }

    #[test]
    fn classify_auth_provider_routes() {
        assert_eq!(classify_path("/auth/google"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/github"), RateLimitTier::Auth);

        assert_eq!(classify_path("/auth/google/callback"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/github/callback"), RateLimitTier::Auth);
    }

    #[test]
    fn classify_auth_link_routes() {
        assert_eq!(classify_path("/auth/link/google"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/link/github"), RateLimitTier::Auth);

        assert_eq!(classify_path("/auth/link/google/callback"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/link/github/callback"), RateLimitTier::Auth);
    }

    #[test]
    fn classify_auth_trailing_slash() {
        assert_eq!(classify_path("/oauth/token/"), RateLimitTier::Auth);
        assert_eq!(classify_path("/oauth/introspect/"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/setup/"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/google/"), RateLimitTier::Auth);
        assert_eq!(classify_path("/auth/link/google/"), RateLimitTier::Auth);
    }

    #[test]
    fn classify_public_endpoints() {
        assert_eq!(classify_path("/health"), RateLimitTier::Public);
        assert_eq!(classify_path("/.well-known"), RateLimitTier::Public);
        assert_eq!(classify_path("/.well-known/jwks.json"), RateLimitTier::Public);
        assert_eq!(classify_path("/.well-known/openid-configuration"), RateLimitTier::Public);
    }

    #[test]
    fn classify_standard_endpoints() {
        assert_eq!(classify_path("/auth/me"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/logout"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/logout-all"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/sessions"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/sessions/some-id"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/me/links"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/me/username"), RateLimitTier::Standard);
        assert_eq!(classify_path("/admin/users"), RateLimitTier::Standard);
        assert_eq!(classify_path("/admin/webhooks"), RateLimitTier::Standard);
        assert_eq!(classify_path("/oauth/consent"), RateLimitTier::Standard);
    }

    #[test]
    fn classify_root_and_bare_prefixes_are_standard() {
        assert_eq!(classify_path("/"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/link"), RateLimitTier::Standard);
        assert_eq!(classify_path("/auth/link/"), RateLimitTier::Standard);
    }

    #[test]
    fn extract_ip_direct() {
        // No ConnectInfo extension and no trusted proxy headers => no IP.
        let req = Request::builder()
            .uri("/test")
            .body(())
            .unwrap();
        assert!(extract_ip(&req, false).is_none());
    }

    #[test]
    fn extract_ip_from_xff_header() {
        let req = Request::builder()
            .uri("/test")
            .header("x-forwarded-for", "203.0.113.50, 70.41.3.18, 150.172.238.178")
            .body(())
            .unwrap();
        let ip = extract_ip(&req, true).unwrap();
        assert_eq!(ip, "203.0.113.50".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn extract_ip_from_real_ip_header() {
        let req = Request::builder()
            .uri("/test")
            .header("x-real-ip", "10.0.0.1")
            .body(())
            .unwrap();
        let ip = extract_ip(&req, true).unwrap();
        assert_eq!(ip, "10.0.0.1".parse::<IpAddr>().unwrap());
    }

    #[test]
    fn extract_ip_ignores_proxy_headers_when_not_behind_proxy() {
        let req = Request::builder()
            .uri("/test")
            .header("x-forwarded-for", "203.0.113.50")
            .body(())
            .unwrap();
        assert!(extract_ip(&req, false).is_none());
    }

    // The in-memory tests read the budget and window from the default config
    // instead of hard-coding them, so they keep passing if the defaults change.

    #[test]
    fn in_memory_limiter_allows_within_burst() {
        let tiers = RateLimitTiersConfig::default();
        let limiter = InMemoryRateLimiter::new(&tiers);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let burst = tiers.auth.requests;

        for i in 1..=burst {
            let (allowed, _, _) = limiter.check(RateLimitTier::Auth, &ip);
            assert!(allowed, "request {i} should be allowed");
        }
        let (allowed, remaining, _) = limiter.check(RateLimitTier::Auth, &ip);
        assert!(!allowed, "request {} should be rate-limited", burst + 1);
        assert_eq!(remaining, 0);
    }

    #[test]
    fn in_memory_limiter_tiers_are_independent() {
        let tiers = RateLimitTiersConfig::default();
        let limiter = InMemoryRateLimiter::new(&tiers);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();

        exhaust(&limiter, RateLimitTier::Auth, &ip, tiers.auth.requests);
        let (allowed, _, _) = limiter.check(RateLimitTier::Auth, &ip);
        assert!(!allowed, "auth should be exhausted");

        let (allowed, _, _) = limiter.check(RateLimitTier::Standard, &ip);
        assert!(allowed, "standard should still be available");
    }

    #[test]
    fn in_memory_limiter_different_ips_independent() {
        let tiers = RateLimitTiersConfig::default();
        let limiter = InMemoryRateLimiter::new(&tiers);
        let ip1: IpAddr = "10.0.0.1".parse().unwrap();
        let ip2: IpAddr = "10.0.0.2".parse().unwrap();

        exhaust(&limiter, RateLimitTier::Auth, &ip1, tiers.auth.requests);
        let (allowed, _, _) = limiter.check(RateLimitTier::Auth, &ip1);
        assert!(!allowed);

        let (allowed, _, _) = limiter.check(RateLimitTier::Auth, &ip2);
        assert!(allowed);
    }

    #[test]
    fn in_memory_retry_after_less_than_window() {
        let tiers = RateLimitTiersConfig::default();
        let limiter = InMemoryRateLimiter::new(&tiers);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let window = tiers.auth.window_secs;

        exhaust(&limiter, RateLimitTier::Auth, &ip, tiers.auth.requests);
        let (allowed, _, retry_after) = limiter.check(RateLimitTier::Auth, &ip);
        assert!(!allowed);
        assert!(retry_after <= window, "retry_after {retry_after} should be <= {window}");
        assert!(retry_after >= 1, "retry_after should be at least 1");
    }
}