1use std::time::Duration;
2
3use uuid::Uuid;
4
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
/// Strategy a [`RateLimitRule`] uses to count requests.
///
/// This enum only names the strategy; the counting logic lives outside this
/// module. `RateLimitRule::default()` selects `SlidingWindowLog`.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the algorithm can be
/// used as a `HashMap`/`HashSet` key, matching the derive set of `RateLimitKey`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RateLimitAlgorithm {
    /// Sliding-window log.
    SlidingWindowLog,
    /// Token bucket.
    TokenBucket,
    /// Fixed window.
    FixedWindow,
    /// Leaky bucket.
    LeakyBucket,
}
21
22#[derive(Debug, Clone, Hash, PartialEq, Eq)]
24#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
25pub enum RateLimitKey {
26 IpAddress(String),
28 UserId(Uuid),
30 DeviceId(Uuid),
32 Combined {
34 ip: Option<String>,
35 user_id: Option<Uuid>,
36 device_id: Option<Uuid>,
37 },
38 Custom(String),
40}
41
42impl RateLimitKey {
43 pub fn to_cache_key(&self, namespace: &str) -> String {
45 match self {
46 Self::IpAddress(ip) => format!("{}:ip:{}", namespace, ip),
47 Self::UserId(id) => format!("{}:user:{}", namespace, id),
48 Self::DeviceId(id) => format!("{}:device:{}", namespace, id),
49 Self::Combined {
50 ip,
51 user_id,
52 device_id,
53 } => {
54 let parts: Vec<String> = vec![
55 ip.as_ref().map(|i| format!("ip:{}", i)),
56 user_id.map(|u| format!("user:{}", u)),
57 device_id.map(|d| format!("device:{}", d)),
58 ]
59 .into_iter()
60 .flatten()
61 .collect();
62 format!("{}:combined:{}", namespace, parts.join(":"))
63 }
64 Self::Custom(key) => format!("{}:custom:{}", namespace, key),
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
72#[cfg_attr(feature = "serde", serde(default))]
73pub struct RateLimitRule {
74 pub name: String,
76 pub algorithm: RateLimitAlgorithm,
78 pub limit: u32,
80 pub window: Duration,
82 pub exponential_backoff: bool,
84 pub backoff_base: Duration,
86 pub max_backoff: Duration,
88 pub violation_threshold: u32,
90}
91
92impl Default for RateLimitRule {
93 fn default() -> Self {
94 Self {
95 name: "default".to_string(),
96 algorithm: RateLimitAlgorithm::SlidingWindowLog,
97 limit: 10,
98 window: Duration::from_secs(60),
99 exponential_backoff: true,
100 backoff_base: Duration::from_secs(60),
101 max_backoff: Duration::from_secs(3600),
102 violation_threshold: 3,
103 }
104 }
105}
106
107#[derive(Debug, Clone)]
109#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
110#[cfg_attr(feature = "serde", serde(default))]
111pub struct EndpointLimits {
112 pub login: RateLimitRule,
114 pub register: RateLimitRule,
116 pub password_reset: RateLimitRule,
118 pub pin_auth: RateLimitRule,
120 pub device_register: RateLimitRule,
122 pub token_refresh: RateLimitRule,
124 pub setup_start: RateLimitRule,
126 pub setup_confirm: RateLimitRule,
128 pub setup_create_admin: RateLimitRule,
130}
131
132impl Default for EndpointLimits {
133 fn default() -> Self {
134 Self {
135 login: RateLimitRule {
136 name: "login".to_string(),
137 limit: 5,
138 window: Duration::from_secs(300),
139 violation_threshold: 3,
140 ..Default::default()
141 },
142 register: RateLimitRule {
143 name: "register".to_string(),
144 limit: 3,
145 window: Duration::from_secs(3600),
146 violation_threshold: 2,
147 ..Default::default()
148 },
149 password_reset: RateLimitRule {
150 name: "password_reset".to_string(),
151 limit: 3,
152 window: Duration::from_secs(3600),
153 violation_threshold: 2,
154 ..Default::default()
155 },
156 pin_auth: RateLimitRule {
157 name: "pin_auth".to_string(),
158 limit: 10,
159 window: Duration::from_secs(300),
160 violation_threshold: 5,
161 ..Default::default()
162 },
163 device_register: RateLimitRule {
164 name: "device_register".to_string(),
165 limit: 5,
166 window: Duration::from_secs(86400),
167 violation_threshold: 2,
168 ..Default::default()
169 },
170 token_refresh: RateLimitRule {
171 name: "token_refresh".to_string(),
172 limit: 100,
173 window: Duration::from_secs(3600),
174 exponential_backoff: false,
175 ..Default::default()
176 },
177 setup_start: RateLimitRule {
178 name: "setup_start".to_string(),
179 limit: 5,
180 window: Duration::from_secs(120),
181 violation_threshold: 3,
182 ..Default::default()
183 },
184 setup_confirm: RateLimitRule {
185 name: "setup_confirm".to_string(),
186 limit: 5,
187 window: Duration::from_secs(120),
188 violation_threshold: 3,
189 ..Default::default()
190 },
191 setup_create_admin: RateLimitRule {
192 name: "setup_create_admin".to_string(),
193 limit: 2,
194 window: Duration::from_secs(3600),
195 violation_threshold: 1,
196 ..Default::default()
197 },
198 }
199 }
200}
201
202#[derive(Debug, Clone)]
204#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
205pub struct TrustedSources {
206 pub ip_addresses: Vec<String>,
208 pub user_ids: Vec<Uuid>,
210 pub device_ids: Vec<Uuid>,
212}
213
214impl TrustedSources {
215 pub fn is_trusted(&self, key: &RateLimitKey) -> bool {
217 match key {
218 RateLimitKey::IpAddress(ip) => self.ip_addresses.contains(ip),
219 RateLimitKey::UserId(id) => self.user_ids.contains(id),
220 RateLimitKey::DeviceId(id) => self.device_ids.contains(id),
221 RateLimitKey::Combined {
222 ip,
223 user_id,
224 device_id,
225 } => {
226 ip.as_ref()
227 .map(|i| self.ip_addresses.contains(i))
228 .unwrap_or(false)
229 || user_id
230 .map(|u| self.user_ids.contains(&u))
231 .unwrap_or(false)
232 || device_id
233 .map(|d| self.device_ids.contains(&d))
234 .unwrap_or(false)
235 }
236 RateLimitKey::Custom(_) => false,
237 }
238 }
239}