Skip to main content

ferrex_model/
rate_limit.rs

1use std::time::Duration;
2
3use uuid::Uuid;
4
5#[cfg(feature = "serde")]
6use serde::{Deserialize, Serialize};
7
/// Rate limiting algorithm type.
///
/// This enum only names the strategy a rule asks for; the strategies
/// themselves are not implemented in this file.
///
/// It is a field-less `Copy` enum, so besides `PartialEq` it also derives
/// `Eq` and `Hash`, which lets it be used directly as a `HashMap`/`HashSet`
/// key (for example to group rules per algorithm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RateLimitAlgorithm {
    /// Sliding window log algorithm (most accurate).
    SlidingWindowLog,
    /// Token bucket algorithm (allows bursts).
    TokenBucket,
    /// Fixed window counter (simplest).
    FixedWindow,
    /// Leaky bucket algorithm (smooth rate).
    LeakyBucket,
}
21
/// Identifier for rate limiting (IP, user_id, device_id, etc.).
///
/// Each variant names the subject a limit is tracked against. The type
/// derives `Hash` and `Eq`, so values can be used as map keys.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RateLimitKey {
    /// IP address based limiting.
    ///
    /// The address is held as text; nothing in this type validates it.
    IpAddress(String),
    /// User ID based limiting.
    UserId(Uuid),
    /// Device ID based limiting.
    DeviceId(Uuid),
    /// Combined key for more granular control.
    ///
    /// Every component is optional, so a value with all three set to `None`
    /// is representable.
    Combined {
        /// IP address in text form, if it is part of the key.
        ip: Option<String>,
        /// User ID, if it is part of the key.
        user_id: Option<Uuid>,
        /// Device ID, if it is part of the key.
        device_id: Option<Uuid>,
    },
    /// Custom key for flexibility.
    Custom(String),
}
41
42impl RateLimitKey {
43    /// Create a cache key for Redis.
44    pub fn to_cache_key(&self, namespace: &str) -> String {
45        match self {
46            Self::IpAddress(ip) => format!("{}:ip:{}", namespace, ip),
47            Self::UserId(id) => format!("{}:user:{}", namespace, id),
48            Self::DeviceId(id) => format!("{}:device:{}", namespace, id),
49            Self::Combined {
50                ip,
51                user_id,
52                device_id,
53            } => {
54                let parts: Vec<String> = vec![
55                    ip.as_ref().map(|i| format!("ip:{}", i)),
56                    user_id.map(|u| format!("user:{}", u)),
57                    device_id.map(|d| format!("device:{}", d)),
58                ]
59                .into_iter()
60                .flatten()
61                .collect();
62                format!("{}:combined:{}", namespace, parts.join(":"))
63            }
64            Self::Custom(key) => format!("{}:custom:{}", namespace, key),
65        }
66    }
67}
68
/// Configuration for a single rate limiting rule.
///
/// This is a plain data holder; the code that enforces a rule is not part
/// of this type. With the `serde` feature enabled, the container-level
/// `serde(default)` attribute means fields missing from the input are
/// filled in from this type's `Default` implementation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct RateLimitRule {
    /// Name of the rule for identification.
    pub name: String,
    /// Algorithm to use.
    pub algorithm: RateLimitAlgorithm,
    /// Maximum number of requests allowed.
    pub limit: u32,
    /// Time window for the limit.
    pub window: Duration,
    /// Whether to apply exponential backoff on violations.
    pub exponential_backoff: bool,
    /// Base duration for backoff calculation.
    pub backoff_base: Duration,
    /// Maximum backoff duration.
    pub max_backoff: Duration,
    /// Number of violations before applying stricter measures.
    pub violation_threshold: u32,
}
91
92impl Default for RateLimitRule {
93    fn default() -> Self {
94        Self {
95            name: "default".to_string(),
96            algorithm: RateLimitAlgorithm::SlidingWindowLog,
97            limit: 10,
98            window: Duration::from_secs(60),
99            exponential_backoff: true,
100            backoff_base: Duration::from_secs(60),
101            max_backoff: Duration::from_secs(3600),
102            violation_threshold: 3,
103        }
104    }
105}
106
/// Endpoint-specific rate limit configuration.
///
/// Holds one [`RateLimitRule`] per rate-limited endpoint. With the `serde`
/// feature enabled, the container-level `serde(default)` attribute means
/// rules missing from the input are taken from this type's `Default`
/// implementation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(default))]
pub struct EndpointLimits {
    /// Login endpoint limits.
    pub login: RateLimitRule,
    /// Registration endpoint limits.
    pub register: RateLimitRule,
    /// Password reset limits.
    pub password_reset: RateLimitRule,
    /// PIN authentication limits.
    pub pin_auth: RateLimitRule,
    /// Device registration limits.
    pub device_register: RateLimitRule,
    /// Token refresh limits.
    pub token_refresh: RateLimitRule,
    /// Setup claim start limits (LAN-only endpoint).
    pub setup_start: RateLimitRule,
    /// Setup claim confirm limits (LAN-only endpoint).
    pub setup_confirm: RateLimitRule,
    /// Setup create admin limits.
    pub setup_create_admin: RateLimitRule,
}
131
132impl Default for EndpointLimits {
133    fn default() -> Self {
134        Self {
135            login: RateLimitRule {
136                name: "login".to_string(),
137                limit: 5,
138                window: Duration::from_secs(300),
139                violation_threshold: 3,
140                ..Default::default()
141            },
142            register: RateLimitRule {
143                name: "register".to_string(),
144                limit: 3,
145                window: Duration::from_secs(3600),
146                violation_threshold: 2,
147                ..Default::default()
148            },
149            password_reset: RateLimitRule {
150                name: "password_reset".to_string(),
151                limit: 3,
152                window: Duration::from_secs(3600),
153                violation_threshold: 2,
154                ..Default::default()
155            },
156            pin_auth: RateLimitRule {
157                name: "pin_auth".to_string(),
158                limit: 10,
159                window: Duration::from_secs(300),
160                violation_threshold: 5,
161                ..Default::default()
162            },
163            device_register: RateLimitRule {
164                name: "device_register".to_string(),
165                limit: 5,
166                window: Duration::from_secs(86400),
167                violation_threshold: 2,
168                ..Default::default()
169            },
170            token_refresh: RateLimitRule {
171                name: "token_refresh".to_string(),
172                limit: 100,
173                window: Duration::from_secs(3600),
174                exponential_backoff: false,
175                ..Default::default()
176            },
177            setup_start: RateLimitRule {
178                name: "setup_start".to_string(),
179                limit: 5,
180                window: Duration::from_secs(120),
181                violation_threshold: 3,
182                ..Default::default()
183            },
184            setup_confirm: RateLimitRule {
185                name: "setup_confirm".to_string(),
186                limit: 5,
187                window: Duration::from_secs(120),
188                violation_threshold: 3,
189                ..Default::default()
190            },
191            setup_create_admin: RateLimitRule {
192                name: "setup_create_admin".to_string(),
193                limit: 2,
194                window: Duration::from_secs(3600),
195                violation_threshold: 1,
196                ..Default::default()
197            },
198        }
199    }
200}
201
/// Configuration for trusted sources that bypass rate limiting.
///
/// A plain list-based allowlist: each field holds the trusted values for
/// one kind of identifier. The struct itself performs no validation of the
/// entries it stores.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TrustedSources {
    /// Trusted IP addresses or CIDR blocks.
    ///
    /// Entries are stored as plain strings.
    pub ip_addresses: Vec<String>,
    /// Trusted user IDs.
    pub user_ids: Vec<Uuid>,
    /// Trusted device IDs.
    pub device_ids: Vec<Uuid>,
}
213
214impl TrustedSources {
215    /// Check if a key is from a trusted source.
216    pub fn is_trusted(&self, key: &RateLimitKey) -> bool {
217        match key {
218            RateLimitKey::IpAddress(ip) => self.ip_addresses.contains(ip),
219            RateLimitKey::UserId(id) => self.user_ids.contains(id),
220            RateLimitKey::DeviceId(id) => self.device_ids.contains(id),
221            RateLimitKey::Combined {
222                ip,
223                user_id,
224                device_id,
225            } => {
226                ip.as_ref()
227                    .map(|i| self.ip_addresses.contains(i))
228                    .unwrap_or(false)
229                    || user_id
230                        .map(|u| self.user_ids.contains(&u))
231                        .unwrap_or(false)
232                    || device_id
233                        .map(|d| self.device_ids.contains(&d))
234                        .unwrap_or(false)
235            }
236            RateLimitKey::Custom(_) => false,
237        }
238    }
239}