1use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{
7 RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitRecord, RateLimitRule,
8 RateLimitStorage, RateLimitStorageOption, RateLimitStore,
9};
10use crate::utils::ip::{
11 create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
12};
13use crate::utils::url::normalize_pathname;
14use http::Request;
15use std::collections::HashMap;
16use std::net::IpAddr;
17use std::sync::{Arc, Mutex};
18use std::time::{Duration, Instant};
19
/// Request body type accepted by the rate-limit entry points in this module.
pub type Body = Vec<u8>;

/// Outcome handed back to the caller when a request is refused by the rate limiter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitRejection {
    /// Seconds to wait before retrying, copied from the store's `RateLimitDecision`.
    pub retry_after: u64,
}

/// Client address attached to a request as an `http` extension.
///
/// `request_ip` falls back to it when none of the configured IP headers yields a valid
/// address.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RequestClientIp(pub IpAddr);
30
31#[derive(Default)]
32pub struct GovernorMemoryRateLimitStore {
33 records: Mutex<HashMap<String, MemoryRateLimitRecord>>,
34 cleanup_interval: Option<Duration>,
35 last_cleanup: Mutex<Option<Instant>>,
36}
37
38#[derive(Debug, Clone)]
39struct MemoryRateLimitRecord {
40 count: u64,
41 last_request: i64,
42 window_ms: i64,
43}
44
45impl GovernorMemoryRateLimitStore {
46 pub fn new() -> Self {
47 Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
48 }
49
50 pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
51 Self {
52 records: Mutex::new(HashMap::new()),
53 cleanup_interval,
54 last_cleanup: Mutex::new(None),
55 }
56 }
57
58 fn cleanup_if_due(&self, now_ms: i64) -> Result<(), OpenAuthError> {
59 let Some(interval) = self.cleanup_interval else {
60 return Ok(());
61 };
62
63 let mut last_cleanup =
64 self.last_cleanup
65 .lock()
66 .map_err(|_| OpenAuthError::LockPoisoned {
67 context: "rate limit cleanup",
68 })?;
69 let now = Instant::now();
70 if last_cleanup
71 .as_ref()
72 .is_some_and(|last| last.elapsed() < interval)
73 {
74 return Ok(());
75 }
76 *last_cleanup = Some(now);
77 drop(last_cleanup);
78
79 self.records
80 .lock()
81 .map_err(|_| OpenAuthError::LockPoisoned {
82 context: "rate limit store",
83 })?
84 .retain(|_, record| now_ms.saturating_sub(record.last_request) <= record.window_ms);
85 Ok(())
86 }
87}
88
89impl RateLimitStore for GovernorMemoryRateLimitStore {
90 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
91 Box::pin(async move {
92 validate_rule(&input.rule)?;
93 self.cleanup_if_due(input.now_ms)?;
94 let window_ms = rule_window_ms(&input.rule)?;
95 let mut records = self
96 .records
97 .lock()
98 .map_err(|_| OpenAuthError::LockPoisoned {
99 context: "rate limit store",
100 })?;
101 let decision = match records.get_mut(&input.key) {
102 Some(record)
103 if input.now_ms.saturating_sub(record.last_request) <= window_ms
104 && record.count >= input.rule.max =>
105 {
106 denied_decision(&input, record.last_request)
107 }
108 Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
109 record.count = record.count.saturating_add(1);
110 record.last_request = input.now_ms;
111 record.window_ms = window_ms;
112 allowed_decision(&input, record.count)
113 }
114 _ => {
115 records.insert(
116 input.key.clone(),
117 MemoryRateLimitRecord {
118 count: 1,
119 last_request: input.now_ms,
120 window_ms,
121 },
122 );
123 allowed_decision(&input, 1)
124 }
125 };
126 Ok(decision)
127 })
128 }
129}
130
131pub struct LegacyRateLimitStorageAdapter {
132 storage: Arc<dyn RateLimitStorage>,
133}
134
135impl LegacyRateLimitStorageAdapter {
136 pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
137 Self { storage }
138 }
139}
140
141impl RateLimitStore for LegacyRateLimitStorageAdapter {
142 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
143 Box::pin(async move {
144 validate_rule(&input.rule)?;
145 let window_ms = rule_window_ms(&input.rule)?;
146 let existing = self.storage.get(&input.key)?;
147 let decision = match existing {
148 Some(record)
149 if input.now_ms.saturating_sub(record.last_request) <= window_ms
150 && record.count >= input.rule.max =>
151 {
152 denied_decision(&input, record.last_request)
153 }
154 Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
155 let next_count = record.count.saturating_add(1);
156 self.storage.set(
157 &input.key,
158 RateLimitRecord {
159 key: input.key.clone(),
160 count: next_count,
161 last_request: input.now_ms,
162 },
163 input.rule.window,
164 true,
165 )?;
166 allowed_decision(&input, next_count)
167 }
168 _ => {
169 self.storage.set(
170 &input.key,
171 RateLimitRecord {
172 key: input.key.clone(),
173 count: 1,
174 last_request: input.now_ms,
175 },
176 input.rule.window,
177 existing.is_some(),
178 )?;
179 allowed_decision(&input, 1)
180 }
181 };
182 Ok(decision)
183 })
184 }
185}
186
187pub struct HybridRateLimitStore {
188 local: Arc<GovernorMemoryRateLimitStore>,
189 global: Arc<dyn RateLimitStore>,
190 local_multiplier: u64,
191}
192
193impl HybridRateLimitStore {
194 pub fn new(
195 local: Arc<GovernorMemoryRateLimitStore>,
196 global: Arc<dyn RateLimitStore>,
197 local_multiplier: u64,
198 ) -> Self {
199 Self {
200 local,
201 global,
202 local_multiplier: local_multiplier.max(1),
203 }
204 }
205}
206
207impl RateLimitStore for HybridRateLimitStore {
208 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
209 Box::pin(async move {
210 let local_input = RateLimitConsumeInput {
211 key: input.key.clone(),
212 rule: RateLimitRule {
213 window: input.rule.window,
214 max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
215 },
216 now_ms: input.now_ms,
217 };
218 let local = self.local.consume(local_input).await?;
219 if !local.permitted {
220 return Ok(local);
221 }
222 self.global.consume(input).await
223 })
224 }
225}
226
227pub async fn consume_rate_limit(
228 context: &AuthContext,
229 request: &Request<Body>,
230) -> Result<Option<RateLimitRejection>, OpenAuthError> {
231 if !context.rate_limit.enabled {
232 return Ok(None);
233 }
234 let Some(config) = resolve_config(context, request)? else {
235 return Ok(None);
236 };
237 let store = store(context)?;
238 let decision = store
239 .consume(RateLimitConsumeInput {
240 key: config.key,
241 rule: config.rule,
242 now_ms: now_millis(),
243 })
244 .await?;
245 if decision.permitted {
246 return Ok(None);
247 }
248 Ok(Some(RateLimitRejection {
249 retry_after: decision.retry_after,
250 }))
251}
252
253pub fn on_request_rate_limit(
254 context: &AuthContext,
255 request: &Request<Body>,
256) -> Result<Option<RateLimitRejection>, OpenAuthError> {
257 if !context.rate_limit.enabled {
258 return Ok(None);
259 }
260 if resolve_config(context, request)?.is_none() {
261 return Ok(None);
262 }
263 Err(OpenAuthError::Api(
264 "async rate limit storage requires AuthRouter::handle_async".to_owned(),
265 ))
266}
267
268pub fn on_response_rate_limit(
269 _context: &AuthContext,
270 _request: &Request<Body>,
271) -> Result<(), OpenAuthError> {
272 Ok(())
273}
274
275#[derive(Debug)]
276struct ResolvedRateLimit {
277 key: String,
278 rule: RateLimitRule,
279}
280
281fn resolve_config(
282 context: &AuthContext,
283 request: &Request<Body>,
284) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
285 let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
286 let Some(ip) = request_ip(context, request) else {
287 return Ok(None);
288 };
289 let Some(rule) = resolve_rule(context, request, &path)? else {
290 return Ok(None);
291 };
292 Ok(Some(ResolvedRateLimit {
293 key: create_rate_limit_key(&ip, &path),
294 rule,
295 }))
296}
297
298fn resolve_rule(
299 context: &AuthContext,
300 request: &Request<Body>,
301 path: &str,
302) -> Result<Option<RateLimitRule>, OpenAuthError> {
303 let mut rule = default_rule(context);
304 if let Some(special_rule) = default_special_rule(path) {
305 rule = special_rule;
306 }
307 for plugin_rule in &context.rate_limit.plugin_rules {
308 if path_matches(&plugin_rule.path, path) {
309 rule = plugin_rule.rule.clone();
310 break;
311 }
312 }
313 for custom_rule in &context.rate_limit.custom_rules {
314 if path_matches(&custom_rule.path, path) {
315 return Ok(custom_rule.rule.clone());
316 }
317 }
318 for dynamic_rule in &context.rate_limit.dynamic_rules {
319 if path_matches(&dynamic_rule.path, path) {
320 return dynamic_rule.provider.resolve(request, &rule);
321 }
322 }
323 Ok(Some(rule))
324}
325
326fn default_rule(context: &AuthContext) -> RateLimitRule {
327 RateLimitRule {
328 window: context.rate_limit.window,
329 max: context.rate_limit.max,
330 }
331}
332
333fn default_special_rule(path: &str) -> Option<RateLimitRule> {
334 if path.starts_with("/sign-in")
335 || path.starts_with("/sign-up")
336 || path.starts_with("/change-password")
337 || path.starts_with("/change-email")
338 {
339 return Some(RateLimitRule { window: 10, max: 3 });
340 }
341 if path == "/request-password-reset"
342 || path == "/send-verification-email"
343 || path.starts_with("/forget-password")
344 || path == "/email-otp/send-verification-otp"
345 || path == "/email-otp/request-password-reset"
346 {
347 return Some(RateLimitRule { window: 60, max: 3 });
348 }
349 None
350}
351
352fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
353 if context.options.advanced.ip_address.disable_ip_tracking {
354 return None;
355 }
356
357 for header_name in &context.options.advanced.ip_address.headers {
358 if let Some(value) = request
359 .headers()
360 .get(header_name)
361 .and_then(|value| value.to_str().ok())
362 {
363 let Some(candidate) = value.split(',').next().map(str::trim) else {
364 continue;
365 };
366 if is_valid_ip(candidate) {
367 return Some(normalize_ip_with_options(
368 candidate,
369 NormalizeIpOptions {
370 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
371 },
372 ));
373 }
374 }
375 }
376
377 if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
378 return Some(normalize_ip_with_options(
379 &client_ip.0.to_string(),
380 NormalizeIpOptions {
381 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
382 },
383 ));
384 }
385
386 if !context.options.production && !is_production() {
387 return Some("127.0.0.1".to_owned());
388 }
389
390 None
391}
392
393fn store(context: &AuthContext) -> Result<Arc<dyn RateLimitStore>, OpenAuthError> {
394 if let Some(store) = &context.rate_limit.custom_store {
395 if context.rate_limit.hybrid.enabled {
396 return Ok(Arc::new(HybridRateLimitStore::new(
397 Arc::clone(&context.rate_limit.memory_store),
398 Arc::clone(store),
399 context.rate_limit.hybrid.local_multiplier,
400 )));
401 }
402 return Ok(Arc::clone(store));
403 }
404 match context.rate_limit.storage {
405 RateLimitStorageOption::Memory => Ok(context.rate_limit.memory_store.clone()),
406 RateLimitStorageOption::Database => Err(OpenAuthError::InvalidConfig(
407 "database rate limit storage requires a concrete RateLimitStore".to_owned(),
408 )),
409 RateLimitStorageOption::SecondaryStorage => Err(OpenAuthError::InvalidConfig(
410 "secondary-storage rate limit storage requires a concrete RateLimitStore".to_owned(),
411 )),
412 }
413}
414
415fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
416 RateLimitDecision {
417 permitted: true,
418 retry_after: 0,
419 limit: input.rule.max,
420 remaining: input.rule.max.saturating_sub(count),
421 reset_after: input.rule.window,
422 }
423}
424
425fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
426 let window_ms = i64::try_from(input.rule.window.saturating_mul(1000)).unwrap_or(i64::MAX);
427 let retry_after = last_request
428 .saturating_add(window_ms)
429 .saturating_sub(input.now_ms)
430 .max(0);
431 RateLimitDecision {
432 permitted: false,
433 retry_after: ceil_millis_to_seconds(retry_after),
434 limit: input.rule.max,
435 remaining: 0,
436 reset_after: ceil_millis_to_seconds(retry_after),
437 }
438}
439
440fn validate_rule(rule: &RateLimitRule) -> Result<(), OpenAuthError> {
441 if rule.window == 0 {
442 return Err(OpenAuthError::InvalidConfig(
443 "rate limit window must be greater than zero".to_owned(),
444 ));
445 }
446 if rule.max == 0 {
447 return Err(OpenAuthError::InvalidConfig(
448 "rate limit max must be greater than zero".to_owned(),
449 ));
450 }
451 Ok(())
452}
453
454fn rule_window_ms(rule: &RateLimitRule) -> Result<i64, OpenAuthError> {
455 let milliseconds = rule
456 .window
457 .checked_mul(1000)
458 .ok_or_else(|| OpenAuthError::InvalidConfig("rate limit window is too large".to_owned()))?;
459 i64::try_from(milliseconds)
460 .map_err(|_| OpenAuthError::InvalidConfig("rate limit window is too large".to_owned()))
461}
462
/// Rounds a millisecond duration up to whole seconds; zero or negative input yields 0.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        Ok(ms) => ms / 1000 + u64::from(ms % 1000 != 0),
        Err(_) => 0,
    }
}
469
/// Reports whether `path` matches the rule `pattern`.
///
/// A pattern without `*` must equal `path` exactly. Each `*` matches any run of characters,
/// including an empty run and including `/`. The literal pieces between wildcards must
/// appear in order without overlapping, so `"/a*a"` matches `"/aa"` but not `"/a"`.
fn path_matches(pattern: &str, path: &str) -> bool {
    let Some((prefix, after_first_star)) = pattern.split_once('*') else {
        return pattern == path;
    };
    let Some(mut remaining) = path.strip_prefix(prefix) else {
        return false;
    };
    // The text after the last `*` must end the path. Any literals between the first and
    // last `*` must occur, in order, in what is left once the prefix is consumed.
    let (middle, suffix) = match after_first_star.rsplit_once('*') {
        Some((middle, suffix)) => (Some(middle), suffix),
        None => (None, after_first_star),
    };
    for literal in middle.into_iter().flat_map(|pieces| pieces.split('*')) {
        // Taking the leftmost occurrence leaves the most room for later pieces.
        match remaining.find(literal) {
            Some(start) => remaining = &remaining[start + literal.len()..],
            None => return false,
        }
    }
    // Checking the suffix against the unconsumed tail prevents it from overlapping the
    // prefix or any middle piece.
    remaining.ends_with(suffix)
}
476
477fn now_millis() -> i64 {
478 time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000_000
479}