1use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{
7 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitRule,
8 RateLimitStorage, RateLimitStorageOption, RateLimitStore,
9};
10use crate::utils::ip::{
11 create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
12};
13use crate::utils::url::normalize_pathname;
14use governor::clock::Clock;
15use governor::middleware::StateInformationMiddleware;
16use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
17use http::Request;
18use std::collections::HashMap;
19use std::net::IpAddr;
20use std::num::NonZeroU32;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
/// Body type of the `http::Request`s accepted by the rate-limit entry points.
pub type Body = Vec<u8>;

/// Returned when the rate limiter refuses a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// How long the client should wait before retrying, as reported by the store
    /// that refused the request (the built-in stores round up to whole seconds).
    pub retry_after: u64,
}

/// Client address read from the request extensions when no configured IP
/// header yields a valid address.
///
/// NOTE(review): presumably inserted by the HTTP server integration from the
/// peer socket address — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestClientIp(pub IpAddr);
34
35type GovernorLimiter = DefaultKeyedRateLimiter<String, StateInformationMiddleware>;
36
37#[derive(Default)]
38pub struct GovernorMemoryRateLimitStore {
39 limiters: Mutex<HashMap<(u64, u64), Arc<GovernorLimiter>>>,
40 cleanup_interval: Option<Duration>,
41 last_cleanup: Mutex<Option<Instant>>,
42}
43
44impl GovernorMemoryRateLimitStore {
45 pub fn new() -> Self {
46 Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
47 }
48
49 pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
50 Self {
51 limiters: Mutex::new(HashMap::new()),
52 cleanup_interval,
53 last_cleanup: Mutex::new(None),
54 }
55 }
56
57 fn limiter(&self, rule: &RateLimitRule) -> Result<Arc<GovernorLimiter>, OpenAuthError> {
58 self.cleanup_if_due()?;
59 let quota = quota(rule)?;
60 let mut limiters = self
61 .limiters
62 .lock()
63 .map_err(|_| OpenAuthError::Api("rate limit store lock poisoned".to_owned()))?;
64 let key = (rule.window, rule.max);
65 if let Some(limiter) = limiters.get(&key) {
66 return Ok(Arc::clone(limiter));
67 }
68 let limiter =
69 Arc::new(RateLimiter::keyed(quota).with_middleware::<StateInformationMiddleware>());
70 limiters.insert(key, Arc::clone(&limiter));
71 Ok(limiter)
72 }
73
74 fn cleanup_if_due(&self) -> Result<(), OpenAuthError> {
75 let Some(interval) = self.cleanup_interval else {
76 return Ok(());
77 };
78
79 let mut last_cleanup = self
80 .last_cleanup
81 .lock()
82 .map_err(|_| OpenAuthError::Api("rate limit cleanup lock poisoned".to_owned()))?;
83 let now = Instant::now();
84 if last_cleanup
85 .as_ref()
86 .is_some_and(|last| last.elapsed() < interval)
87 {
88 return Ok(());
89 }
90 *last_cleanup = Some(now);
91 drop(last_cleanup);
92
93 let limiters = self
94 .limiters
95 .lock()
96 .map_err(|_| OpenAuthError::Api("rate limit store lock poisoned".to_owned()))?
97 .values()
98 .cloned()
99 .collect::<Vec<_>>();
100 for limiter in limiters {
101 limiter.retain_recent();
102 limiter.shrink_to_fit();
103 }
104 Ok(())
105 }
106}
107
108impl RateLimitStore for GovernorMemoryRateLimitStore {
109 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
110 Box::pin(async move {
111 let limiter = self.limiter(&input.rule)?;
112 match limiter.check_key(&input.key) {
113 Ok(snapshot) => Ok(RateLimitDecision {
114 permitted: true,
115 retry_after: 0,
116 limit: input.rule.max,
117 remaining: u64::from(snapshot.remaining_burst_capacity()),
118 reset_after: input.rule.window,
119 }),
120 Err(not_until) => {
121 let retry_after =
122 ceil_duration_to_seconds(not_until.wait_time_from(limiter.clock().now()));
123 Ok(RateLimitDecision {
124 permitted: false,
125 retry_after,
126 limit: input.rule.max,
127 remaining: 0,
128 reset_after: retry_after,
129 })
130 }
131 }
132 })
133 }
134}
135
136pub struct LegacyRateLimitStorageAdapter {
137 storage: Arc<dyn RateLimitStorage>,
138}
139
140impl LegacyRateLimitStorageAdapter {
141 pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
142 Self { storage }
143 }
144}
145
146impl RateLimitStore for LegacyRateLimitStorageAdapter {
147 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
148 Box::pin(async move {
149 let window_ms = input.rule.window.saturating_mul(1000) as i64;
150 let existing = self.storage.get(&input.key)?;
151 let decision = match existing {
152 Some(record)
153 if input.now_ms.saturating_sub(record.last_request) <= window_ms
154 && record.count >= input.rule.max =>
155 {
156 denied_decision(&input, record.last_request)
157 }
158 Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
159 let next_count = record.count.saturating_add(1);
160 self.storage.set(
161 &input.key,
162 RateLimitRecord {
163 key: input.key.clone(),
164 count: next_count,
165 last_request: input.now_ms,
166 },
167 input.rule.window,
168 true,
169 )?;
170 allowed_decision(&input, next_count)
171 }
172 _ => {
173 self.storage.set(
174 &input.key,
175 RateLimitRecord {
176 key: input.key.clone(),
177 count: 1,
178 last_request: input.now_ms,
179 },
180 input.rule.window,
181 existing.is_some(),
182 )?;
183 allowed_decision(&input, 1)
184 }
185 };
186 Ok(decision)
187 })
188 }
189}
190
191pub struct HybridRateLimitStore {
192 local: Arc<GovernorMemoryRateLimitStore>,
193 global: Arc<dyn RateLimitStore>,
194 local_multiplier: u64,
195}
196
197impl HybridRateLimitStore {
198 pub fn new(
199 local: Arc<GovernorMemoryRateLimitStore>,
200 global: Arc<dyn RateLimitStore>,
201 local_multiplier: u64,
202 ) -> Self {
203 Self {
204 local,
205 global,
206 local_multiplier: local_multiplier.max(1),
207 }
208 }
209}
210
211impl RateLimitStore for HybridRateLimitStore {
212 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
213 Box::pin(async move {
214 let local_input = RateLimitConsumeInput {
215 key: input.key.clone(),
216 rule: RateLimitRule {
217 window: input.rule.window,
218 max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
219 },
220 now_ms: input.now_ms,
221 };
222 let local = self.local.consume(local_input).await?;
223 if !local.permitted {
224 return Ok(local);
225 }
226 self.global.consume(input).await
227 })
228 }
229}
230
231pub async fn consume_rate_limit(
232 context: &AuthContext,
233 request: &Request<Body>,
234) -> Result<Option<RateLimitRejection>, OpenAuthError> {
235 if !context.rate_limit.enabled {
236 return Ok(None);
237 }
238 let Some(config) = resolve_config(context, request)? else {
239 return Ok(None);
240 };
241 let store = store(context);
242 let decision = store
243 .consume(RateLimitConsumeInput {
244 key: config.key,
245 rule: config.rule,
246 now_ms: now_millis(),
247 })
248 .await?;
249 if decision.permitted {
250 return Ok(None);
251 }
252 Ok(Some(RateLimitRejection {
253 retry_after: decision.retry_after,
254 }))
255}
256
257pub fn on_request_rate_limit(
258 context: &AuthContext,
259 request: &Request<Body>,
260) -> Result<Option<RateLimitRejection>, OpenAuthError> {
261 if !context.rate_limit.enabled {
262 return Ok(None);
263 }
264 if resolve_config(context, request)?.is_none() {
265 return Ok(None);
266 }
267 Err(OpenAuthError::Api(
268 "async rate limit storage requires AuthRouter::handle_async".to_owned(),
269 ))
270}
271
272pub fn on_response_rate_limit(
273 _context: &AuthContext,
274 _request: &Request<Body>,
275) -> Result<(), OpenAuthError> {
276 Ok(())
277}
278
279#[derive(Debug)]
280struct ResolvedRateLimit {
281 key: String,
282 rule: RateLimitRule,
283}
284
285fn resolve_config(
286 context: &AuthContext,
287 request: &Request<Body>,
288) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
289 let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
290 let Some(ip) = request_ip(context, request) else {
291 return Ok(None);
292 };
293 let Some(rule) = resolve_rule(context, request, &path)? else {
294 return Ok(None);
295 };
296 Ok(Some(ResolvedRateLimit {
297 key: create_rate_limit_key(&ip, &path),
298 rule,
299 }))
300}
301
302fn resolve_rule(
303 context: &AuthContext,
304 request: &Request<Body>,
305 path: &str,
306) -> Result<Option<RateLimitRule>, OpenAuthError> {
307 let mut rule = default_rule(context);
308 if let Some(special_rule) = default_special_rule(path) {
309 rule = special_rule;
310 }
311 for plugin_rule in &context.rate_limit.plugin_rules {
312 if path_matches(&plugin_rule.path, path) {
313 rule = plugin_rule.rule.clone();
314 break;
315 }
316 }
317 for custom_rule in &context.rate_limit.custom_rules {
318 if path_matches(&custom_rule.path, path) {
319 return Ok(custom_rule.rule.clone());
320 }
321 }
322 for dynamic_rule in &context.rate_limit.dynamic_rules {
323 if path_matches(&dynamic_rule.path, path) {
324 return dynamic_rule.provider.resolve(request, &rule);
325 }
326 }
327 Ok(Some(rule))
328}
329
330fn default_rule(context: &AuthContext) -> RateLimitRule {
331 RateLimitRule {
332 window: context.rate_limit.window,
333 max: context.rate_limit.max,
334 }
335}
336
337fn default_special_rule(path: &str) -> Option<RateLimitRule> {
338 if path.starts_with("/sign-in")
339 || path.starts_with("/sign-up")
340 || path.starts_with("/change-password")
341 || path.starts_with("/change-email")
342 {
343 return Some(RateLimitRule { window: 10, max: 3 });
344 }
345 if path == "/request-password-reset"
346 || path == "/send-verification-email"
347 || path.starts_with("/forget-password")
348 || path == "/email-otp/send-verification-otp"
349 || path == "/email-otp/request-password-reset"
350 {
351 return Some(RateLimitRule { window: 60, max: 3 });
352 }
353 None
354}
355
356fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
357 if context.options.advanced.ip_address.disable_ip_tracking {
358 return None;
359 }
360
361 for header_name in &context.options.advanced.ip_address.headers {
362 if let Some(value) = request
363 .headers()
364 .get(header_name)
365 .and_then(|value| value.to_str().ok())
366 {
367 let Some(candidate) = value.split(',').next().map(str::trim) else {
368 continue;
369 };
370 if is_valid_ip(candidate) {
371 return Some(normalize_ip_with_options(
372 candidate,
373 NormalizeIpOptions {
374 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
375 },
376 ));
377 }
378 }
379 }
380
381 if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
382 return Some(normalize_ip_with_options(
383 &client_ip.0.to_string(),
384 NormalizeIpOptions {
385 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
386 },
387 ));
388 }
389
390 if !context.options.production && !is_production() {
391 return Some("127.0.0.1".to_owned());
392 }
393
394 None
395}
396
397fn store(context: &AuthContext) -> Arc<dyn RateLimitStore> {
398 if let Some(store) = &context.rate_limit.custom_store {
399 if context.rate_limit.hybrid.enabled {
400 return Arc::new(HybridRateLimitStore::new(
401 Arc::clone(&context.rate_limit.memory_store),
402 Arc::clone(store),
403 context.rate_limit.hybrid.local_multiplier,
404 ));
405 }
406 return Arc::clone(store);
407 }
408 match context.rate_limit.storage {
409 RateLimitStorageOption::Memory
410 | RateLimitStorageOption::Database
411 | RateLimitStorageOption::SecondaryStorage => context.rate_limit.memory_store.clone(),
412 }
413}
414
415fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
416 RateLimitDecision {
417 permitted: true,
418 retry_after: 0,
419 limit: input.rule.max,
420 remaining: input.rule.max.saturating_sub(count),
421 reset_after: input.rule.window,
422 }
423}
424
425fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
426 let retry_after = last_request
427 .saturating_add(input.rule.window.saturating_mul(1000) as i64)
428 .saturating_sub(input.now_ms)
429 .max(0);
430 RateLimitDecision {
431 permitted: false,
432 retry_after: ceil_millis_to_seconds(retry_after),
433 limit: input.rule.max,
434 remaining: 0,
435 reset_after: ceil_millis_to_seconds(retry_after),
436 }
437}
438
/// Rounds `duration` up to whole seconds, saturating at `u64::MAX`.
fn ceil_duration_to_seconds(duration: Duration) -> u64 {
    let whole_seconds = duration.as_secs();
    if duration.subsec_nanos() == 0 {
        whole_seconds
    } else {
        whole_seconds.saturating_add(1)
    }
}
447
448fn quota(rule: &RateLimitRule) -> Result<Quota, OpenAuthError> {
449 if rule.window == 0 {
450 return Err(OpenAuthError::InvalidConfig(
451 "rate limit window must be greater than zero".to_owned(),
452 ));
453 }
454 let max =
455 NonZeroU32::new(u32::try_from(rule.max).map_err(|_| {
456 OpenAuthError::InvalidConfig("rate limit max must fit in u32".to_owned())
457 })?)
458 .ok_or_else(|| {
459 OpenAuthError::InvalidConfig("rate limit max must be greater than zero".to_owned())
460 })?;
461 let mut replenish_interval = Duration::from_secs(rule.window) / max.get();
462 if replenish_interval.is_zero() {
463 replenish_interval = Duration::from_nanos(1);
464 }
465 Quota::with_period(replenish_interval)
466 .map(|quota| quota.allow_burst(max))
467 .ok_or_else(|| {
468 OpenAuthError::InvalidConfig(
469 "rate limit replenish interval must be greater than zero".to_owned(),
470 )
471 })
472}
473
/// Rounds a millisecond count up to whole seconds; zero or negative input yields 0.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        Ok(ms) => ms / 1000 + u64::from(ms % 1000 != 0),
        Err(_) => 0,
    }
}
480
/// Matches a request path against a rule pattern.
///
/// A pattern without `*` must equal `path` exactly. Each `*` matches any
/// (possibly empty) run of characters, including `/`. The literal pieces
/// around and between the stars must appear in `path` in order, without
/// overlapping.
fn path_matches(pattern: &str, path: &str) -> bool {
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == path;
    };
    // Text after the final `*` must end the path. Pieces between the first and
    // last `*` are matched left to right inside the remaining gap.
    let (middle, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    let Some(rest) = path.strip_prefix(head) else {
        return false;
    };
    // Strip the suffix from what is left after the head, rather than testing
    // `ends_with` on the whole path, so the head and the last piece cannot share
    // characters.
    let Some(mut gap) = rest.strip_suffix(last) else {
        return false;
    };
    for piece in middle.split('*') {
        // The leftmost occurrence is always the best choice for a `*` glob.
        match gap.find(piece) {
            Some(at) => gap = &gap[at + piece.len()..],
            None => return false,
        }
    }
    true
}
487
/// Current wall-clock time in milliseconds since the Unix epoch; negative if
/// the system clock is set before 1970.
///
/// Converts to milliseconds before narrowing to `i64` (i64 nanoseconds overflow
/// in the year 2262) and saturates instead of wrapping.
fn now_millis() -> i64 {
    let saturate = |millis: u128| i64::try_from(millis).unwrap_or(i64::MAX);
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => saturate(since_epoch.as_millis()),
        Err(before_epoch) => -saturate(before_epoch.duration().as_millis()),
    }
}