Skip to main content

openauth_core/
rate_limit.rs

1//! Router-level rate limiting.
2
3use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{RateLimitRecord, RateLimitRule, RateLimitStorage, RateLimitStorageOption};
7use crate::utils::ip::{
8    create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
9};
10use crate::utils::url::normalize_pathname;
11use http::Request;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
/// Request body representation used by the rate-limit hooks: an owned byte buffer.
pub type Body = Vec<u8>;
16
/// Outcome reported by [`on_request_rate_limit`] when a request is refused.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitRejection {
    /// Value produced by `retry_after`: the time remaining until
    /// `last_request + window`, floored at zero.
    pub retry_after: u64,
}
21
22#[derive(Debug, Default)]
23pub struct MemoryRateLimitStorage {
24    entries: Mutex<HashMap<String, MemoryRateLimitEntry>>,
25}
26
27#[derive(Debug, Clone)]
28struct MemoryRateLimitEntry {
29    record: RateLimitRecord,
30    expires_at: i64,
31}
32
33impl MemoryRateLimitStorage {
34    pub fn new() -> Self {
35        Self::default()
36    }
37}
38
39impl RateLimitStorage for MemoryRateLimitStorage {
40    fn get(&self, key: &str) -> Result<Option<RateLimitRecord>, OpenAuthError> {
41        let now = now_seconds();
42        let mut entries = self
43            .entries
44            .lock()
45            .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
46        let Some(entry) = entries.get(key) else {
47            return Ok(None);
48        };
49        if now >= entry.expires_at {
50            entries.remove(key);
51            return Ok(None);
52        }
53        Ok(Some(entry.record.clone()))
54    }
55
56    fn set(
57        &self,
58        key: &str,
59        value: RateLimitRecord,
60        ttl_seconds: u64,
61        _update: bool,
62    ) -> Result<(), OpenAuthError> {
63        let expires_at = now_seconds().saturating_add(ttl_seconds as i64);
64        let mut entries = self
65            .entries
66            .lock()
67            .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
68        entries.insert(
69            key.to_owned(),
70            MemoryRateLimitEntry {
71                record: value,
72                expires_at,
73            },
74        );
75        Ok(())
76    }
77}
78
79pub fn on_request_rate_limit(
80    context: &AuthContext,
81    request: &Request<Body>,
82) -> Result<Option<RateLimitRejection>, OpenAuthError> {
83    if !context.rate_limit.enabled {
84        return Ok(None);
85    }
86    let Some(config) = resolve_config(context, request)? else {
87        return Ok(None);
88    };
89    let storage = storage(context);
90    let Some(record) = storage.get(&config.key)? else {
91        return Ok(None);
92    };
93
94    if should_rate_limit(config.rule.max, config.rule.window, &record) {
95        return Ok(Some(RateLimitRejection {
96            retry_after: retry_after(record.last_request, config.rule.window),
97        }));
98    }
99
100    Ok(None)
101}
102
103pub fn on_response_rate_limit(
104    context: &AuthContext,
105    request: &Request<Body>,
106) -> Result<(), OpenAuthError> {
107    if !context.rate_limit.enabled {
108        return Ok(());
109    }
110    let Some(config) = resolve_config(context, request)? else {
111        return Ok(());
112    };
113    let storage = storage(context);
114    let now = now_seconds();
115    let next_record = match storage.get(&config.key)? {
116        Some(record) if now.saturating_sub(record.last_request) <= config.rule.window as i64 => {
117            RateLimitRecord {
118                key: config.key.clone(),
119                count: record.count.saturating_add(1),
120                last_request: now,
121            }
122        }
123        _ => RateLimitRecord {
124            key: config.key.clone(),
125            count: 1,
126            last_request: now,
127        },
128    };
129
130    let update = next_record.count > 1;
131    storage.set(&config.key, next_record, config.rule.window, update)
132}
133
134#[derive(Debug)]
135struct ResolvedRateLimit {
136    key: String,
137    rule: RateLimitRule,
138}
139
140fn resolve_config(
141    context: &AuthContext,
142    request: &Request<Body>,
143) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
144    let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
145    let Some(ip) = request_ip(context, request) else {
146        return Ok(None);
147    };
148    let Some(rule) = resolve_rule(context, request, &path)? else {
149        return Ok(None);
150    };
151    Ok(Some(ResolvedRateLimit {
152        key: create_rate_limit_key(&ip, &path),
153        rule,
154    }))
155}
156
157fn resolve_rule(
158    context: &AuthContext,
159    request: &Request<Body>,
160    path: &str,
161) -> Result<Option<RateLimitRule>, OpenAuthError> {
162    let mut rule = default_rule(context);
163    if let Some(special_rule) = default_special_rule(path) {
164        rule = special_rule;
165    }
166    for plugin_rule in &context.rate_limit.plugin_rules {
167        if path_matches(&plugin_rule.path, path) {
168            rule = plugin_rule.rule.clone();
169            break;
170        }
171    }
172    for custom_rule in &context.rate_limit.custom_rules {
173        if path_matches(&custom_rule.path, path) {
174            return Ok(custom_rule.rule.clone());
175        }
176    }
177    for dynamic_rule in &context.rate_limit.dynamic_rules {
178        if path_matches(&dynamic_rule.path, path) {
179            return dynamic_rule.provider.resolve(request, &rule);
180        }
181    }
182    Ok(Some(rule))
183}
184
185fn default_rule(context: &AuthContext) -> RateLimitRule {
186    RateLimitRule {
187        window: context.rate_limit.window,
188        max: context.rate_limit.max,
189    }
190}
191
192fn default_special_rule(path: &str) -> Option<RateLimitRule> {
193    if path.starts_with("/sign-in")
194        || path.starts_with("/sign-up")
195        || path.starts_with("/change-password")
196        || path.starts_with("/change-email")
197    {
198        return Some(RateLimitRule { window: 10, max: 3 });
199    }
200    if path == "/request-password-reset"
201        || path == "/send-verification-email"
202        || path.starts_with("/forget-password")
203        || path == "/email-otp/send-verification-otp"
204        || path == "/email-otp/request-password-reset"
205    {
206        return Some(RateLimitRule { window: 60, max: 3 });
207    }
208    None
209}
210
211fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
212    if context.options.advanced.ip_address.disable_ip_tracking {
213        return None;
214    }
215
216    for header_name in &context.options.advanced.ip_address.headers {
217        if let Some(value) = request
218            .headers()
219            .get(header_name)
220            .and_then(|value| value.to_str().ok())
221        {
222            let Some(candidate) = value.split(',').next().map(str::trim) else {
223                continue;
224            };
225            if is_valid_ip(candidate) {
226                return Some(normalize_ip_with_options(
227                    candidate,
228                    NormalizeIpOptions {
229                        ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
230                    },
231                ));
232            }
233        }
234    }
235
236    if !context.options.production && !is_production() {
237        return Some("127.0.0.1".to_owned());
238    }
239
240    None
241}
242
243fn storage(context: &AuthContext) -> &dyn RateLimitStorage {
244    if let Some(storage) = &context.rate_limit.custom_storage {
245        return storage.as_ref();
246    }
247    match context.rate_limit.storage {
248        RateLimitStorageOption::Memory
249        | RateLimitStorageOption::Database
250        | RateLimitStorageOption::SecondaryStorage => context.rate_limit.memory_storage.as_ref(),
251    }
252}
253
254fn should_rate_limit(max: u64, window: u64, record: &RateLimitRecord) -> bool {
255    let time_since_last_request = now_seconds().saturating_sub(record.last_request);
256    time_since_last_request < window as i64 && record.count >= max
257}
258
259fn retry_after(last_request: i64, window: u64) -> u64 {
260    let retry_after = last_request
261        .saturating_add(window as i64)
262        .saturating_sub(now_seconds());
263    retry_after.max(0) as u64
264}
265
/// Tests `path` against `pattern`, where each `*` in the pattern matches any
/// run of characters (including an empty one and including `/`).
///
/// A pattern without `*` must equal `path` exactly. The text before the first
/// `*` is anchored at the start of `path`, the text after the last `*` is
/// anchored at the end, and the pieces between wildcards must occur in order in
/// between, without overlapping one another.
fn path_matches(pattern: &str, path: &str) -> bool {
    let mut literals = pattern.split('*');

    // `split` always yields at least one piece; the first one is the anchored prefix.
    let prefix = literals.next().unwrap_or_default();
    let Some(mut remaining) = path.strip_prefix(prefix) else {
        return false;
    };

    // Without any wildcard there is no second piece: require an exact match.
    let Some(suffix) = literals.next_back() else {
        return remaining.is_empty();
    };

    // Consume the middle pieces left to right, taking the earliest occurrence
    // of each so the later pieces keep as much text as possible.
    for literal in literals {
        match remaining.find(literal) {
            Some(start) => remaining = &remaining[start + literal.len()..],
            None => return false,
        }
    }

    // The suffix is checked against what is left after the prefix and middle
    // pieces, so it can never overlap them.
    remaining.ends_with(suffix)
}
272
273fn now_seconds() -> i64 {
274    time::OffsetDateTime::now_utc().unix_timestamp()
275}