1use crate::context::AuthContext;
4use crate::env::is_production;
5use crate::error::OpenAuthError;
6use crate::options::{RateLimitRecord, RateLimitRule, RateLimitStorage, RateLimitStorageOption};
7use crate::utils::ip::{
8 create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
9};
10use crate::utils::url::normalize_pathname;
11use http::Request;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
/// Raw request body bytes; the rate-limit hooks take `Request<Body>`.
pub type Body = std::vec::Vec<u8>;
16
/// Returned by `on_request_rate_limit` when a request must be rejected.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitRejection {
    /// Seconds until the offending record's window elapses (its last request time plus the
    /// rule window, minus now, floored at zero).
    pub retry_after: u64,
}
21
22#[derive(Debug, Default)]
23pub struct MemoryRateLimitStorage {
24 entries: Mutex<HashMap<String, MemoryRateLimitEntry>>,
25}
26
27#[derive(Debug, Clone)]
28struct MemoryRateLimitEntry {
29 record: RateLimitRecord,
30 expires_at: i64,
31}
32
33impl MemoryRateLimitStorage {
34 pub fn new() -> Self {
35 Self::default()
36 }
37}
38
39impl RateLimitStorage for MemoryRateLimitStorage {
40 fn get(&self, key: &str) -> Result<Option<RateLimitRecord>, OpenAuthError> {
41 let now = now_seconds();
42 let mut entries = self
43 .entries
44 .lock()
45 .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
46 let Some(entry) = entries.get(key) else {
47 return Ok(None);
48 };
49 if now >= entry.expires_at {
50 entries.remove(key);
51 return Ok(None);
52 }
53 Ok(Some(entry.record.clone()))
54 }
55
56 fn set(
57 &self,
58 key: &str,
59 value: RateLimitRecord,
60 ttl_seconds: u64,
61 _update: bool,
62 ) -> Result<(), OpenAuthError> {
63 let expires_at = now_seconds().saturating_add(ttl_seconds as i64);
64 let mut entries = self
65 .entries
66 .lock()
67 .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
68 entries.insert(
69 key.to_owned(),
70 MemoryRateLimitEntry {
71 record: value,
72 expires_at,
73 },
74 );
75 Ok(())
76 }
77}
78
79pub fn on_request_rate_limit(
80 context: &AuthContext,
81 request: &Request<Body>,
82) -> Result<Option<RateLimitRejection>, OpenAuthError> {
83 if !context.rate_limit.enabled {
84 return Ok(None);
85 }
86 let Some(config) = resolve_config(context, request)? else {
87 return Ok(None);
88 };
89 let storage = storage(context);
90 let Some(record) = storage.get(&config.key)? else {
91 return Ok(None);
92 };
93
94 if should_rate_limit(config.rule.max, config.rule.window, &record) {
95 return Ok(Some(RateLimitRejection {
96 retry_after: retry_after(record.last_request, config.rule.window),
97 }));
98 }
99
100 Ok(None)
101}
102
103pub fn on_response_rate_limit(
104 context: &AuthContext,
105 request: &Request<Body>,
106) -> Result<(), OpenAuthError> {
107 if !context.rate_limit.enabled {
108 return Ok(());
109 }
110 let Some(config) = resolve_config(context, request)? else {
111 return Ok(());
112 };
113 let storage = storage(context);
114 let now = now_seconds();
115 let next_record = match storage.get(&config.key)? {
116 Some(record) if now.saturating_sub(record.last_request) <= config.rule.window as i64 => {
117 RateLimitRecord {
118 key: config.key.clone(),
119 count: record.count.saturating_add(1),
120 last_request: now,
121 }
122 }
123 _ => RateLimitRecord {
124 key: config.key.clone(),
125 count: 1,
126 last_request: now,
127 },
128 };
129
130 let update = next_record.count > 1;
131 storage.set(&config.key, next_record, config.rule.window, update)
132}
133
134#[derive(Debug)]
135struct ResolvedRateLimit {
136 key: String,
137 rule: RateLimitRule,
138}
139
140fn resolve_config(
141 context: &AuthContext,
142 request: &Request<Body>,
143) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
144 let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
145 let Some(ip) = request_ip(context, request) else {
146 return Ok(None);
147 };
148 let Some(rule) = resolve_rule(context, request, &path)? else {
149 return Ok(None);
150 };
151 Ok(Some(ResolvedRateLimit {
152 key: create_rate_limit_key(&ip, &path),
153 rule,
154 }))
155}
156
157fn resolve_rule(
158 context: &AuthContext,
159 request: &Request<Body>,
160 path: &str,
161) -> Result<Option<RateLimitRule>, OpenAuthError> {
162 let mut rule = default_rule(context);
163 if let Some(special_rule) = default_special_rule(path) {
164 rule = special_rule;
165 }
166 for plugin_rule in &context.rate_limit.plugin_rules {
167 if path_matches(&plugin_rule.path, path) {
168 rule = plugin_rule.rule.clone();
169 break;
170 }
171 }
172 for custom_rule in &context.rate_limit.custom_rules {
173 if path_matches(&custom_rule.path, path) {
174 return Ok(custom_rule.rule.clone());
175 }
176 }
177 for dynamic_rule in &context.rate_limit.dynamic_rules {
178 if path_matches(&dynamic_rule.path, path) {
179 return dynamic_rule.provider.resolve(request, &rule);
180 }
181 }
182 Ok(Some(rule))
183}
184
185fn default_rule(context: &AuthContext) -> RateLimitRule {
186 RateLimitRule {
187 window: context.rate_limit.window,
188 max: context.rate_limit.max,
189 }
190}
191
192fn default_special_rule(path: &str) -> Option<RateLimitRule> {
193 if path.starts_with("/sign-in")
194 || path.starts_with("/sign-up")
195 || path.starts_with("/change-password")
196 || path.starts_with("/change-email")
197 {
198 return Some(RateLimitRule { window: 10, max: 3 });
199 }
200 if path == "/request-password-reset"
201 || path == "/send-verification-email"
202 || path.starts_with("/forget-password")
203 || path == "/email-otp/send-verification-otp"
204 || path == "/email-otp/request-password-reset"
205 {
206 return Some(RateLimitRule { window: 60, max: 3 });
207 }
208 None
209}
210
211fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
212 if context.options.advanced.ip_address.disable_ip_tracking {
213 return None;
214 }
215
216 for header_name in &context.options.advanced.ip_address.headers {
217 if let Some(value) = request
218 .headers()
219 .get(header_name)
220 .and_then(|value| value.to_str().ok())
221 {
222 let Some(candidate) = value.split(',').next().map(str::trim) else {
223 continue;
224 };
225 if is_valid_ip(candidate) {
226 return Some(normalize_ip_with_options(
227 candidate,
228 NormalizeIpOptions {
229 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
230 },
231 ));
232 }
233 }
234 }
235
236 if !context.options.production && !is_production() {
237 return Some("127.0.0.1".to_owned());
238 }
239
240 None
241}
242
243fn storage(context: &AuthContext) -> &dyn RateLimitStorage {
244 if let Some(storage) = &context.rate_limit.custom_storage {
245 return storage.as_ref();
246 }
247 match context.rate_limit.storage {
248 RateLimitStorageOption::Memory
249 | RateLimitStorageOption::Database
250 | RateLimitStorageOption::SecondaryStorage => context.rate_limit.memory_storage.as_ref(),
251 }
252}
253
254fn should_rate_limit(max: u64, window: u64, record: &RateLimitRecord) -> bool {
255 let time_since_last_request = now_seconds().saturating_sub(record.last_request);
256 time_since_last_request < window as i64 && record.count >= max
257}
258
259fn retry_after(last_request: i64, window: u64) -> u64 {
260 let retry_after = last_request
261 .saturating_add(window as i64)
262 .saturating_sub(now_seconds());
263 retry_after.max(0) as u64
264}
265
/// Matches `path` against a rule pattern.
///
/// A pattern without `*` must equal `path` exactly. Each `*` matches any (possibly empty)
/// sequence of characters, including `/`. The literal text before the first `*` and after
/// the last `*` must not overlap in `path`, and the literal pieces between stars must
/// appear in order.
fn path_matches(pattern: &str, path: &str) -> bool {
    let mut pieces = pattern.split('*');
    let head = pieces.next().unwrap_or("");
    let mut middle: Vec<&str> = pieces.collect();
    let Some(tail) = middle.pop() else {
        // No wildcard at all.
        return pattern == path;
    };

    // Strip head then tail from what remains, so they can never share characters.
    let Some(rest) = path.strip_prefix(head) else {
        return false;
    };
    let Some(mut rest) = rest.strip_suffix(tail) else {
        return false;
    };

    // Leftmost-first matching of the interior pieces is sufficient for `*`-only globs.
    for piece in middle {
        match rest.find(piece) {
            Some(index) => rest = &rest[index + piece.len()..],
            None => return false,
        }
    }
    true
}
272
273fn now_seconds() -> i64 {
274 time::OffsetDateTime::now_utc().unix_timestamp()
275}