Skip to main content

openauth_core/
rate_limit.rs

1//! Router-level rate limiting.
2
3use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{
7    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitRule,
8    RateLimitStorage, RateLimitStorageOption, RateLimitStore,
9};
10use crate::utils::ip::{
11    create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
12};
13use crate::utils::url::normalize_pathname;
14use http::Request;
15use std::collections::HashMap;
16use std::net::IpAddr;
17use std::sync::{Arc, Mutex};
18use std::time::{Duration, Instant};
19
/// Request body type accepted by the rate-limiting entry points: the raw body bytes.
pub type Body = Vec<u8>;
21
/// Rejection produced when a request exceeds its rate limit.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitRejection {
    /// Whole seconds the client should wait before retrying.
    pub retry_after: u64,
}
26
/// Framework-neutral client IP resolved by an HTTP adapter.
///
/// The rate limiter reads this from the request extensions when no configured IP header
/// yields a valid address.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequestClientIp(pub IpAddr);
30
31#[derive(Default)]
32pub struct GovernorMemoryRateLimitStore {
33    records: Mutex<HashMap<String, MemoryRateLimitRecord>>,
34    cleanup_interval: Option<Duration>,
35    last_cleanup: Mutex<Option<Instant>>,
36}
37
38#[derive(Debug, Clone)]
39struct MemoryRateLimitRecord {
40    count: u64,
41    last_request: i64,
42    window_ms: i64,
43}
44
45impl GovernorMemoryRateLimitStore {
46    pub fn new() -> Self {
47        Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
48    }
49
50    pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
51        Self {
52            records: Mutex::new(HashMap::new()),
53            cleanup_interval,
54            last_cleanup: Mutex::new(None),
55        }
56    }
57
58    fn cleanup_if_due(&self, now_ms: i64) -> Result<(), OpenAuthError> {
59        let Some(interval) = self.cleanup_interval else {
60            return Ok(());
61        };
62
63        let mut last_cleanup =
64            self.last_cleanup
65                .lock()
66                .map_err(|_| OpenAuthError::LockPoisoned {
67                    context: "rate limit cleanup",
68                })?;
69        let now = Instant::now();
70        if last_cleanup
71            .as_ref()
72            .is_some_and(|last| last.elapsed() < interval)
73        {
74            return Ok(());
75        }
76        *last_cleanup = Some(now);
77        drop(last_cleanup);
78
79        self.records
80            .lock()
81            .map_err(|_| OpenAuthError::LockPoisoned {
82                context: "rate limit store",
83            })?
84            .retain(|_, record| now_ms.saturating_sub(record.last_request) <= record.window_ms);
85        Ok(())
86    }
87}
88
89impl RateLimitStore for GovernorMemoryRateLimitStore {
90    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
91        Box::pin(async move {
92            validate_rule(&input.rule)?;
93            self.cleanup_if_due(input.now_ms)?;
94            let window_ms = rule_window_ms(&input.rule)?;
95            let mut records = self
96                .records
97                .lock()
98                .map_err(|_| OpenAuthError::LockPoisoned {
99                    context: "rate limit store",
100                })?;
101            let decision = match records.get_mut(&input.key) {
102                Some(record)
103                    if input.now_ms.saturating_sub(record.last_request) <= window_ms
104                        && record.count >= input.rule.max =>
105                {
106                    denied_decision(&input, record.last_request)
107                }
108                Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
109                    record.count = record.count.saturating_add(1);
110                    record.last_request = input.now_ms;
111                    record.window_ms = window_ms;
112                    allowed_decision(&input, record.count)
113                }
114                _ => {
115                    records.insert(
116                        input.key.clone(),
117                        MemoryRateLimitRecord {
118                            count: 1,
119                            last_request: input.now_ms,
120                            window_ms,
121                        },
122                    );
123                    allowed_decision(&input, 1)
124                }
125            };
126            Ok(decision)
127        })
128    }
129}
130
131pub struct LegacyRateLimitStorageAdapter {
132    storage: Arc<dyn RateLimitStorage>,
133}
134
135impl LegacyRateLimitStorageAdapter {
136    pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
137        Self { storage }
138    }
139}
140
141impl RateLimitStore for LegacyRateLimitStorageAdapter {
142    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
143        Box::pin(async move {
144            validate_rule(&input.rule)?;
145            let window_ms = rule_window_ms(&input.rule)?;
146            let existing = self.storage.get(&input.key)?;
147            let decision = match existing {
148                Some(record)
149                    if input.now_ms.saturating_sub(record.last_request) <= window_ms
150                        && record.count >= input.rule.max =>
151                {
152                    denied_decision(&input, record.last_request)
153                }
154                Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
155                    let next_count = record.count.saturating_add(1);
156                    self.storage.set(
157                        &input.key,
158                        RateLimitRecord {
159                            key: input.key.clone(),
160                            count: next_count,
161                            last_request: input.now_ms,
162                        },
163                        input.rule.window,
164                        true,
165                    )?;
166                    allowed_decision(&input, next_count)
167                }
168                _ => {
169                    self.storage.set(
170                        &input.key,
171                        RateLimitRecord {
172                            key: input.key.clone(),
173                            count: 1,
174                            last_request: input.now_ms,
175                        },
176                        input.rule.window,
177                        existing.is_some(),
178                    )?;
179                    allowed_decision(&input, 1)
180                }
181            };
182            Ok(decision)
183        })
184    }
185}
186
187pub struct HybridRateLimitStore {
188    local: Arc<GovernorMemoryRateLimitStore>,
189    global: Arc<dyn RateLimitStore>,
190    local_multiplier: u64,
191}
192
193impl HybridRateLimitStore {
194    pub fn new(
195        local: Arc<GovernorMemoryRateLimitStore>,
196        global: Arc<dyn RateLimitStore>,
197        local_multiplier: u64,
198    ) -> Self {
199        Self {
200            local,
201            global,
202            local_multiplier: local_multiplier.max(1),
203        }
204    }
205}
206
207impl RateLimitStore for HybridRateLimitStore {
208    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
209        Box::pin(async move {
210            let local_input = RateLimitConsumeInput {
211                key: input.key.clone(),
212                rule: RateLimitRule {
213                    window: input.rule.window,
214                    max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
215                },
216                now_ms: input.now_ms,
217            };
218            let local = self.local.consume(local_input).await?;
219            if !local.permitted {
220                return Ok(local);
221            }
222            self.global.consume(input).await
223        })
224    }
225}
226
227pub async fn consume_rate_limit(
228    context: &AuthContext,
229    request: &Request<Body>,
230) -> Result<Option<RateLimitRejection>, OpenAuthError> {
231    if !context.rate_limit.enabled {
232        return Ok(None);
233    }
234    let Some(config) = resolve_config(context, request)? else {
235        return Ok(None);
236    };
237    let store = store(context)?;
238    let decision = store
239        .consume(RateLimitConsumeInput {
240            key: config.key,
241            rule: config.rule,
242            now_ms: now_millis(),
243        })
244        .await?;
245    if decision.permitted {
246        return Ok(None);
247    }
248    Ok(Some(RateLimitRejection {
249        retry_after: decision.retry_after,
250    }))
251}
252
253pub fn on_request_rate_limit(
254    context: &AuthContext,
255    request: &Request<Body>,
256) -> Result<Option<RateLimitRejection>, OpenAuthError> {
257    if !context.rate_limit.enabled {
258        return Ok(None);
259    }
260    if resolve_config(context, request)?.is_none() {
261        return Ok(None);
262    }
263    Err(OpenAuthError::Api(
264        "async rate limit storage requires AuthRouter::handle_async".to_owned(),
265    ))
266}
267
268pub fn on_response_rate_limit(
269    _context: &AuthContext,
270    _request: &Request<Body>,
271) -> Result<(), OpenAuthError> {
272    Ok(())
273}
274
275#[derive(Debug)]
276struct ResolvedRateLimit {
277    key: String,
278    rule: RateLimitRule,
279}
280
281fn resolve_config(
282    context: &AuthContext,
283    request: &Request<Body>,
284) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
285    let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
286    let Some(ip) = request_ip(context, request) else {
287        return Ok(None);
288    };
289    let Some(rule) = resolve_rule(context, request, &path)? else {
290        return Ok(None);
291    };
292    Ok(Some(ResolvedRateLimit {
293        key: create_rate_limit_key(&ip, &path),
294        rule,
295    }))
296}
297
298fn resolve_rule(
299    context: &AuthContext,
300    request: &Request<Body>,
301    path: &str,
302) -> Result<Option<RateLimitRule>, OpenAuthError> {
303    let mut rule = default_rule(context);
304    if let Some(special_rule) = default_special_rule(path) {
305        rule = special_rule;
306    }
307    for plugin_rule in &context.rate_limit.plugin_rules {
308        if path_matches(&plugin_rule.path, path) {
309            rule = plugin_rule.rule.clone();
310            break;
311        }
312    }
313    for custom_rule in &context.rate_limit.custom_rules {
314        if path_matches(&custom_rule.path, path) {
315            return Ok(custom_rule.rule.clone());
316        }
317    }
318    for dynamic_rule in &context.rate_limit.dynamic_rules {
319        if path_matches(&dynamic_rule.path, path) {
320            return dynamic_rule.provider.resolve(request, &rule);
321        }
322    }
323    Ok(Some(rule))
324}
325
326fn default_rule(context: &AuthContext) -> RateLimitRule {
327    RateLimitRule {
328        window: context.rate_limit.window,
329        max: context.rate_limit.max,
330    }
331}
332
333fn default_special_rule(path: &str) -> Option<RateLimitRule> {
334    if path.starts_with("/sign-in")
335        || path.starts_with("/sign-up")
336        || path.starts_with("/change-password")
337        || path.starts_with("/change-email")
338    {
339        return Some(RateLimitRule { window: 10, max: 3 });
340    }
341    if path == "/request-password-reset"
342        || path == "/send-verification-email"
343        || path.starts_with("/forget-password")
344        || path == "/email-otp/send-verification-otp"
345        || path == "/email-otp/request-password-reset"
346    {
347        return Some(RateLimitRule { window: 60, max: 3 });
348    }
349    None
350}
351
352fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
353    if context.options.advanced.ip_address.disable_ip_tracking {
354        return None;
355    }
356
357    for header_name in &context.options.advanced.ip_address.headers {
358        if let Some(value) = request
359            .headers()
360            .get(header_name)
361            .and_then(|value| value.to_str().ok())
362        {
363            let Some(candidate) = value.split(',').next().map(str::trim) else {
364                continue;
365            };
366            if is_valid_ip(candidate) {
367                return Some(normalize_ip_with_options(
368                    candidate,
369                    NormalizeIpOptions {
370                        ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
371                    },
372                ));
373            }
374        }
375    }
376
377    if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
378        return Some(normalize_ip_with_options(
379            &client_ip.0.to_string(),
380            NormalizeIpOptions {
381                ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
382            },
383        ));
384    }
385
386    if !context.options.production && !is_production() {
387        return Some("127.0.0.1".to_owned());
388    }
389
390    None
391}
392
393fn store(context: &AuthContext) -> Result<Arc<dyn RateLimitStore>, OpenAuthError> {
394    if let Some(store) = &context.rate_limit.custom_store {
395        if context.rate_limit.hybrid.enabled {
396            return Ok(Arc::new(HybridRateLimitStore::new(
397                Arc::clone(&context.rate_limit.memory_store),
398                Arc::clone(store),
399                context.rate_limit.hybrid.local_multiplier,
400            )));
401        }
402        return Ok(Arc::clone(store));
403    }
404    match context.rate_limit.storage {
405        RateLimitStorageOption::Memory => Ok(context.rate_limit.memory_store.clone()),
406        RateLimitStorageOption::Database => Err(OpenAuthError::InvalidConfig(
407            "database rate limit storage requires a concrete RateLimitStore".to_owned(),
408        )),
409        RateLimitStorageOption::SecondaryStorage => Err(OpenAuthError::InvalidConfig(
410            "secondary-storage rate limit storage requires a concrete RateLimitStore".to_owned(),
411        )),
412    }
413}
414
415fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
416    RateLimitDecision {
417        permitted: true,
418        retry_after: 0,
419        limit: input.rule.max,
420        remaining: input.rule.max.saturating_sub(count),
421        reset_after: input.rule.window,
422    }
423}
424
425fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
426    let window_ms = i64::try_from(input.rule.window.saturating_mul(1000)).unwrap_or(i64::MAX);
427    let retry_after = last_request
428        .saturating_add(window_ms)
429        .saturating_sub(input.now_ms)
430        .max(0);
431    RateLimitDecision {
432        permitted: false,
433        retry_after: ceil_millis_to_seconds(retry_after),
434        limit: input.rule.max,
435        remaining: 0,
436        reset_after: ceil_millis_to_seconds(retry_after),
437    }
438}
439
440fn validate_rule(rule: &RateLimitRule) -> Result<(), OpenAuthError> {
441    if rule.window == 0 {
442        return Err(OpenAuthError::InvalidConfig(
443            "rate limit window must be greater than zero".to_owned(),
444        ));
445    }
446    if rule.max == 0 {
447        return Err(OpenAuthError::InvalidConfig(
448            "rate limit max must be greater than zero".to_owned(),
449        ));
450    }
451    Ok(())
452}
453
454fn rule_window_ms(rule: &RateLimitRule) -> Result<i64, OpenAuthError> {
455    let milliseconds = rule
456        .window
457        .checked_mul(1000)
458        .ok_or_else(|| OpenAuthError::InvalidConfig("rate limit window is too large".to_owned()))?;
459    i64::try_from(milliseconds)
460        .map_err(|_| OpenAuthError::InvalidConfig("rate limit window is too large".to_owned()))
461}
462
/// Rounds a millisecond duration up to whole seconds; non-positive input yields 0.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        // A partial trailing second counts as a full one.
        Ok(ms) => ms / 1000 + u64::from(ms % 1000 != 0),
        Err(_) => 0,
    }
}
469
/// Matches `path` against a rule pattern in which each `*` matches any (possibly empty) run
/// of characters; a pattern without `*` must equal `path` exactly.
///
/// Fixes two defects of the previous single-`split_once` version:
/// - the literal text before and after a `*` could overlap, so `/ab*ba` matched `/aba`;
/// - only the first `*` acted as a wildcard, while later ones were compared literally.
fn path_matches(pattern: &str, path: &str) -> bool {
    let mut pieces = pattern.split('*');
    // `split` always yields at least one piece: the literal text before the first `*`.
    let head = pieces.next().unwrap_or("");
    let Some(mut rest) = path.strip_prefix(head) else {
        return false;
    };

    let mut pieces = pieces.peekable();
    while let Some(piece) = pieces.next() {
        if pieces.peek().is_none() {
            // The last literal piece must end the path; checking it against the unconsumed
            // remainder guarantees it cannot overlap text already matched.
            return rest.ends_with(piece);
        }
        // Middle pieces: take the leftmost occurrence, leaving the most room for later pieces.
        match rest.find(piece) {
            Some(index) => rest = &rest[index + piece.len()..],
            None => return false,
        }
    }
    // No `*` in the pattern: require an exact match.
    rest.is_empty()
}
476
477fn now_millis() -> i64 {
478    time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000_000
479}