Skip to main content

openauth_core/
rate_limit.rs

1//! Router-level rate limiting.
2
3use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{
7    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitRule,
8    RateLimitStorage, RateLimitStorageOption, RateLimitStore,
9};
10use crate::utils::ip::{
11    create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
12};
13use crate::utils::url::normalize_pathname;
14use governor::clock::Clock;
15use governor::middleware::StateInformationMiddleware;
16use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
17use http::Request;
18use std::collections::HashMap;
19use std::net::IpAddr;
20use std::num::NonZeroU32;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
/// Raw byte body carried by router requests.
pub type Body = Vec<u8>;

/// Outcome handed back to the router when a request exceeds its rate limit.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Seconds the client should wait before retrying, as reported by the store.
    pub retry_after: u64,
}

/// Framework-neutral client IP resolved by an HTTP adapter.
///
/// Adapters place this in the request extensions; the rate limiter falls back to it when
/// none of the configured IP headers yields a valid address.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RequestClientIp(pub IpAddr);
34
35type GovernorLimiter = DefaultKeyedRateLimiter<String, StateInformationMiddleware>;
36
37#[derive(Default)]
38pub struct GovernorMemoryRateLimitStore {
39    limiters: Mutex<HashMap<(u64, u64), Arc<GovernorLimiter>>>,
40    cleanup_interval: Option<Duration>,
41    last_cleanup: Mutex<Option<Instant>>,
42}
43
44impl GovernorMemoryRateLimitStore {
45    pub fn new() -> Self {
46        Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
47    }
48
49    pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
50        Self {
51            limiters: Mutex::new(HashMap::new()),
52            cleanup_interval,
53            last_cleanup: Mutex::new(None),
54        }
55    }
56
57    fn limiter(&self, rule: &RateLimitRule) -> Result<Arc<GovernorLimiter>, OpenAuthError> {
58        self.cleanup_if_due()?;
59        let quota = quota(rule)?;
60        let mut limiters = self
61            .limiters
62            .lock()
63            .map_err(|_| OpenAuthError::Api("rate limit store lock poisoned".to_owned()))?;
64        let key = (rule.window, rule.max);
65        if let Some(limiter) = limiters.get(&key) {
66            return Ok(Arc::clone(limiter));
67        }
68        let limiter =
69            Arc::new(RateLimiter::keyed(quota).with_middleware::<StateInformationMiddleware>());
70        limiters.insert(key, Arc::clone(&limiter));
71        Ok(limiter)
72    }
73
74    fn cleanup_if_due(&self) -> Result<(), OpenAuthError> {
75        let Some(interval) = self.cleanup_interval else {
76            return Ok(());
77        };
78
79        let mut last_cleanup = self
80            .last_cleanup
81            .lock()
82            .map_err(|_| OpenAuthError::Api("rate limit cleanup lock poisoned".to_owned()))?;
83        let now = Instant::now();
84        if last_cleanup
85            .as_ref()
86            .is_some_and(|last| last.elapsed() < interval)
87        {
88            return Ok(());
89        }
90        *last_cleanup = Some(now);
91        drop(last_cleanup);
92
93        let limiters = self
94            .limiters
95            .lock()
96            .map_err(|_| OpenAuthError::Api("rate limit store lock poisoned".to_owned()))?
97            .values()
98            .cloned()
99            .collect::<Vec<_>>();
100        for limiter in limiters {
101            limiter.retain_recent();
102            limiter.shrink_to_fit();
103        }
104        Ok(())
105    }
106}
107
108impl RateLimitStore for GovernorMemoryRateLimitStore {
109    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
110        Box::pin(async move {
111            let limiter = self.limiter(&input.rule)?;
112            match limiter.check_key(&input.key) {
113                Ok(snapshot) => Ok(RateLimitDecision {
114                    permitted: true,
115                    retry_after: 0,
116                    limit: input.rule.max,
117                    remaining: u64::from(snapshot.remaining_burst_capacity()),
118                    reset_after: input.rule.window,
119                }),
120                Err(not_until) => {
121                    let retry_after =
122                        ceil_duration_to_seconds(not_until.wait_time_from(limiter.clock().now()));
123                    Ok(RateLimitDecision {
124                        permitted: false,
125                        retry_after,
126                        limit: input.rule.max,
127                        remaining: 0,
128                        reset_after: retry_after,
129                    })
130                }
131            }
132        })
133    }
134}
135
136pub struct LegacyRateLimitStorageAdapter {
137    storage: Arc<dyn RateLimitStorage>,
138}
139
140impl LegacyRateLimitStorageAdapter {
141    pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
142        Self { storage }
143    }
144}
145
146impl RateLimitStore for LegacyRateLimitStorageAdapter {
147    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
148        Box::pin(async move {
149            let window_ms = input.rule.window.saturating_mul(1000) as i64;
150            let existing = self.storage.get(&input.key)?;
151            let decision = match existing {
152                Some(record)
153                    if input.now_ms.saturating_sub(record.last_request) <= window_ms
154                        && record.count >= input.rule.max =>
155                {
156                    denied_decision(&input, record.last_request)
157                }
158                Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
159                    let next_count = record.count.saturating_add(1);
160                    self.storage.set(
161                        &input.key,
162                        RateLimitRecord {
163                            key: input.key.clone(),
164                            count: next_count,
165                            last_request: input.now_ms,
166                        },
167                        input.rule.window,
168                        true,
169                    )?;
170                    allowed_decision(&input, next_count)
171                }
172                _ => {
173                    self.storage.set(
174                        &input.key,
175                        RateLimitRecord {
176                            key: input.key.clone(),
177                            count: 1,
178                            last_request: input.now_ms,
179                        },
180                        input.rule.window,
181                        existing.is_some(),
182                    )?;
183                    allowed_decision(&input, 1)
184                }
185            };
186            Ok(decision)
187        })
188    }
189}
190
191pub struct HybridRateLimitStore {
192    local: Arc<GovernorMemoryRateLimitStore>,
193    global: Arc<dyn RateLimitStore>,
194    local_multiplier: u64,
195}
196
197impl HybridRateLimitStore {
198    pub fn new(
199        local: Arc<GovernorMemoryRateLimitStore>,
200        global: Arc<dyn RateLimitStore>,
201        local_multiplier: u64,
202    ) -> Self {
203        Self {
204            local,
205            global,
206            local_multiplier: local_multiplier.max(1),
207        }
208    }
209}
210
211impl RateLimitStore for HybridRateLimitStore {
212    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
213        Box::pin(async move {
214            let local_input = RateLimitConsumeInput {
215                key: input.key.clone(),
216                rule: RateLimitRule {
217                    window: input.rule.window,
218                    max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
219                },
220                now_ms: input.now_ms,
221            };
222            let local = self.local.consume(local_input).await?;
223            if !local.permitted {
224                return Ok(local);
225            }
226            self.global.consume(input).await
227        })
228    }
229}
230
231pub async fn consume_rate_limit(
232    context: &AuthContext,
233    request: &Request<Body>,
234) -> Result<Option<RateLimitRejection>, OpenAuthError> {
235    if !context.rate_limit.enabled {
236        return Ok(None);
237    }
238    let Some(config) = resolve_config(context, request)? else {
239        return Ok(None);
240    };
241    let store = store(context);
242    let decision = store
243        .consume(RateLimitConsumeInput {
244            key: config.key,
245            rule: config.rule,
246            now_ms: now_millis(),
247        })
248        .await?;
249    if decision.permitted {
250        return Ok(None);
251    }
252    Ok(Some(RateLimitRejection {
253        retry_after: decision.retry_after,
254    }))
255}
256
257pub fn on_request_rate_limit(
258    context: &AuthContext,
259    request: &Request<Body>,
260) -> Result<Option<RateLimitRejection>, OpenAuthError> {
261    if !context.rate_limit.enabled {
262        return Ok(None);
263    }
264    if resolve_config(context, request)?.is_none() {
265        return Ok(None);
266    }
267    Err(OpenAuthError::Api(
268        "async rate limit storage requires AuthRouter::handle_async".to_owned(),
269    ))
270}
271
272pub fn on_response_rate_limit(
273    _context: &AuthContext,
274    _request: &Request<Body>,
275) -> Result<(), OpenAuthError> {
276    Ok(())
277}
278
279#[derive(Debug)]
280struct ResolvedRateLimit {
281    key: String,
282    rule: RateLimitRule,
283}
284
285fn resolve_config(
286    context: &AuthContext,
287    request: &Request<Body>,
288) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
289    let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
290    let Some(ip) = request_ip(context, request) else {
291        return Ok(None);
292    };
293    let Some(rule) = resolve_rule(context, request, &path)? else {
294        return Ok(None);
295    };
296    Ok(Some(ResolvedRateLimit {
297        key: create_rate_limit_key(&ip, &path),
298        rule,
299    }))
300}
301
302fn resolve_rule(
303    context: &AuthContext,
304    request: &Request<Body>,
305    path: &str,
306) -> Result<Option<RateLimitRule>, OpenAuthError> {
307    let mut rule = default_rule(context);
308    if let Some(special_rule) = default_special_rule(path) {
309        rule = special_rule;
310    }
311    for plugin_rule in &context.rate_limit.plugin_rules {
312        if path_matches(&plugin_rule.path, path) {
313            rule = plugin_rule.rule.clone();
314            break;
315        }
316    }
317    for custom_rule in &context.rate_limit.custom_rules {
318        if path_matches(&custom_rule.path, path) {
319            return Ok(custom_rule.rule.clone());
320        }
321    }
322    for dynamic_rule in &context.rate_limit.dynamic_rules {
323        if path_matches(&dynamic_rule.path, path) {
324            return dynamic_rule.provider.resolve(request, &rule);
325        }
326    }
327    Ok(Some(rule))
328}
329
330fn default_rule(context: &AuthContext) -> RateLimitRule {
331    RateLimitRule {
332        window: context.rate_limit.window,
333        max: context.rate_limit.max,
334    }
335}
336
337fn default_special_rule(path: &str) -> Option<RateLimitRule> {
338    if path.starts_with("/sign-in")
339        || path.starts_with("/sign-up")
340        || path.starts_with("/change-password")
341        || path.starts_with("/change-email")
342    {
343        return Some(RateLimitRule { window: 10, max: 3 });
344    }
345    if path == "/request-password-reset"
346        || path == "/send-verification-email"
347        || path.starts_with("/forget-password")
348        || path == "/email-otp/send-verification-otp"
349        || path == "/email-otp/request-password-reset"
350    {
351        return Some(RateLimitRule { window: 60, max: 3 });
352    }
353    None
354}
355
356fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
357    if context.options.advanced.ip_address.disable_ip_tracking {
358        return None;
359    }
360
361    for header_name in &context.options.advanced.ip_address.headers {
362        if let Some(value) = request
363            .headers()
364            .get(header_name)
365            .and_then(|value| value.to_str().ok())
366        {
367            let Some(candidate) = value.split(',').next().map(str::trim) else {
368                continue;
369            };
370            if is_valid_ip(candidate) {
371                return Some(normalize_ip_with_options(
372                    candidate,
373                    NormalizeIpOptions {
374                        ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
375                    },
376                ));
377            }
378        }
379    }
380
381    if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
382        return Some(normalize_ip_with_options(
383            &client_ip.0.to_string(),
384            NormalizeIpOptions {
385                ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
386            },
387        ));
388    }
389
390    if !context.options.production && !is_production() {
391        return Some("127.0.0.1".to_owned());
392    }
393
394    None
395}
396
397fn store(context: &AuthContext) -> Arc<dyn RateLimitStore> {
398    if let Some(store) = &context.rate_limit.custom_store {
399        if context.rate_limit.hybrid.enabled {
400            return Arc::new(HybridRateLimitStore::new(
401                Arc::clone(&context.rate_limit.memory_store),
402                Arc::clone(store),
403                context.rate_limit.hybrid.local_multiplier,
404            ));
405        }
406        return Arc::clone(store);
407    }
408    match context.rate_limit.storage {
409        RateLimitStorageOption::Memory
410        | RateLimitStorageOption::Database
411        | RateLimitStorageOption::SecondaryStorage => context.rate_limit.memory_store.clone(),
412    }
413}
414
415fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
416    RateLimitDecision {
417        permitted: true,
418        retry_after: 0,
419        limit: input.rule.max,
420        remaining: input.rule.max.saturating_sub(count),
421        reset_after: input.rule.window,
422    }
423}
424
425fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
426    let retry_after = last_request
427        .saturating_add(input.rule.window.saturating_mul(1000) as i64)
428        .saturating_sub(input.now_ms)
429        .max(0);
430    RateLimitDecision {
431        permitted: false,
432        retry_after: ceil_millis_to_seconds(retry_after),
433        limit: input.rule.max,
434        remaining: 0,
435        reset_after: ceil_millis_to_seconds(retry_after),
436    }
437}
438
/// Rounds `duration` up to whole seconds, saturating at `u64::MAX`.
fn ceil_duration_to_seconds(duration: std::time::Duration) -> u64 {
    let whole_seconds = duration.as_secs();
    if duration.subsec_nanos() == 0 {
        whole_seconds
    } else {
        whole_seconds.saturating_add(1)
    }
}
447
448fn quota(rule: &RateLimitRule) -> Result<Quota, OpenAuthError> {
449    if rule.window == 0 {
450        return Err(OpenAuthError::InvalidConfig(
451            "rate limit window must be greater than zero".to_owned(),
452        ));
453    }
454    let max =
455        NonZeroU32::new(u32::try_from(rule.max).map_err(|_| {
456            OpenAuthError::InvalidConfig("rate limit max must fit in u32".to_owned())
457        })?)
458        .ok_or_else(|| {
459            OpenAuthError::InvalidConfig("rate limit max must be greater than zero".to_owned())
460        })?;
461    let mut replenish_interval = Duration::from_secs(rule.window) / max.get();
462    if replenish_interval.is_zero() {
463        replenish_interval = Duration::from_nanos(1);
464    }
465    Quota::with_period(replenish_interval)
466        .map(|quota| quota.allow_burst(max))
467        .ok_or_else(|| {
468            OpenAuthError::InvalidConfig(
469                "rate limit replenish interval must be greater than zero".to_owned(),
470            )
471        })
472}
473
/// Rounds a millisecond count up to whole seconds; zero or negative input yields 0.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        Ok(0) | Err(_) => 0,
        Ok(millis) => millis / 1000 + u64::from(millis % 1000 != 0),
    }
}
480
/// Matches `path` against a rule `pattern` in which every `*` matches any run of
/// characters, including `/` and the empty string. A pattern without `*` must equal
/// `path` exactly.
///
/// The text before the first `*` must be a prefix of `path`. The text after the last `*`
/// must be a suffix of what remains after that prefix, so prefix and suffix never
/// overlap (e.g. `/ab*b` does not match `/ab`). Any segments between wildcards must
/// appear in order in the middle.
fn path_matches(pattern: &str, path: &str) -> bool {
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == path;
    };
    // `middle` holds everything between the first and last `*` (possibly more `*`s).
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    let Some(mut rest) = path
        .strip_prefix(head)
        .and_then(|remainder| remainder.strip_suffix(last))
    else {
        return false;
    };
    // Leftmost matching of each middle segment leaves the most room for the next one.
    for segment in middle.split('*').filter(|segment| !segment.is_empty()) {
        match rest.find(segment) {
            Some(index) => rest = &rest[index + segment.len()..],
            None => return false,
        }
    }
    true
}
487
488fn now_millis() -> i64 {
489    time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000_000
490}