skp_ratelimit/
manager.rs

1//! Rate limit manager for per-route configuration.
2//!
3//! The `RateLimitManager` allows you to configure different rate limits
4//! for different routes or patterns, with optional default fallback.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use oc_ratelimit_advanced::{RateLimitManager, Quota, GCRA, MemoryStorage};
10//!
11//! let storage = MemoryStorage::new();
12//! let manager = RateLimitManager::builder()
13//!     .default_quota(Quota::per_second(10))
14//!     .route("/api/search", Quota::per_minute(30))
15//!     .route("/api/auth/login", Quota::per_minute(5))
16//!     .route_pattern("/api/users/*", Quota::per_second(20))
17//!     .build(GCRA::new(), storage);
18//! ```
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use crate::algorithm::Algorithm;
24use crate::decision::Decision;
25use crate::error::Result;
26use crate::key::Key;
27use crate::quota::Quota;
28use crate::storage::Storage;
29
30/// A rate limit configuration for a specific route.
31#[derive(Debug, Clone)]
32pub struct RouteConfig {
33    /// The quota for this route.
34    pub quota: Quota,
35    /// Optional custom key suffix.
36    pub key_suffix: Option<String>,
37}
38
39impl RouteConfig {
40    /// Create a new route config with the given quota.
41    pub fn new(quota: Quota) -> Self {
42        Self {
43            quota,
44            key_suffix: None,
45        }
46    }
47
48    /// Add a custom key suffix.
49    pub fn with_key_suffix(mut self, suffix: impl Into<String>) -> Self {
50        self.key_suffix = Some(suffix.into());
51        self
52    }
53}
54
55impl From<Quota> for RouteConfig {
56    fn from(quota: Quota) -> Self {
57        Self::new(quota)
58    }
59}
60
61/// Manager for per-route rate limiting.
62///
63/// This provides a centralized way to configure different rate limits
64/// for different routes or patterns.
65pub struct RateLimitManager<A, S, K> {
66    algorithm: A,
67    storage: Arc<S>,
68    key_extractor: K,
69    default_quota: Option<Quota>,
70    routes: HashMap<String, RouteConfig>,
71    patterns: Vec<(String, RouteConfig)>,
72}
73
74impl<A, S, K> RateLimitManager<A, S, K>
75where
76    A: Algorithm,
77    S: Storage,
78{
79    /// Create a new rate limit manager builder.
80    pub fn builder() -> RateLimitManagerBuilder<K> {
81        RateLimitManagerBuilder::new()
82    }
83
84    /// Check and record a request.
85    pub async fn check_and_record<R>(&self, path: &str, request: &R) -> Result<Decision>
86    where
87        K: Key<R>,
88    {
89        let config = self.get_config(path);
90
91        let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
92            // No quota configured, allow the request
93            return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
94                u64::MAX,
95                u64::MAX,
96                std::time::Instant::now() + std::time::Duration::from_secs(3600),
97                std::time::Instant::now(),
98            )));
99        };
100
101        // Build the key
102        let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
103        let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
104            format!("{}:{}", base_key, suffix)
105        } else {
106            format!("{}:{}", base_key, path)
107        };
108
109        self.algorithm
110            .check_and_record(&*self.storage, &key, quota)
111            .await
112    }
113
114    /// Check without recording.
115    pub async fn check<R>(&self, path: &str, request: &R) -> Result<Decision>
116    where
117        K: Key<R>,
118    {
119        let config = self.get_config(path);
120
121        let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
122            return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
123                u64::MAX,
124                u64::MAX,
125                std::time::Instant::now() + std::time::Duration::from_secs(3600),
126                std::time::Instant::now(),
127            )));
128        };
129
130        let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
131        let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
132            format!("{}:{}", base_key, suffix)
133        } else {
134            format!("{}:{}", base_key, path)
135        };
136
137        self.algorithm.check(&*self.storage, &key, quota).await
138    }
139
140    /// Get the configuration for a path.
141    fn get_config(&self, path: &str) -> Option<&RouteConfig> {
142        // Exact match first
143        if let Some(config) = self.routes.get(path) {
144            return Some(config);
145        }
146
147        // Pattern matching
148        for (pattern, config) in &self.patterns {
149            if pattern_matches(pattern, path) {
150                return Some(config);
151            }
152        }
153
154        None
155    }
156
157    /// Reset rate limit for a specific key.
158    pub async fn reset(&self, key: &str) -> Result<()> {
159        self.algorithm.reset(&*self.storage, key).await
160    }
161}
162
/// Check if a glob-style route pattern matches a request path.
///
/// Both inputs are split on `/` and empty segments are dropped, so leading,
/// trailing and repeated slashes are ignored.
///
/// Matching rules, segment by segment:
/// - `*` matches exactly one path segment
/// - `**` matches zero or more segments and may appear anywhere in the
///   pattern, not only at the end
/// - any other segment must equal the path segment exactly
///
/// Returns `true` only if the whole path is consumed by the whole pattern.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    /// Split on `/`, discarding empty segments.
    fn segments(s: &str) -> Vec<&str> {
        s.split('/').filter(|seg| !seg.is_empty()).collect()
    }

    let pattern_parts = segments(pattern);
    let path_parts = segments(path);

    let mut pi = 0; // pattern index
    let mut pa = 0; // path index
    // Position to fall back to after the most recent `**`:
    // (pattern index just past the `**`, path index where it started matching).
    let mut resume: Option<(usize, usize)> = None;

    while pa < path_parts.len() {
        match pattern_parts.get(pi).copied() {
            Some("**") => {
                // Start by letting `**` match zero segments; widen on mismatch.
                resume = Some((pi + 1, pa));
                pi += 1;
            }
            Some(seg) if seg == "*" || seg == path_parts[pa] => {
                pi += 1;
                pa += 1;
            }
            _ => {
                // Mismatch or pattern exhausted: let the last `**` swallow
                // one more segment and retry, or fail if there is none.
                let Some((after_star, star_start)) = resume else {
                    return false;
                };
                pi = after_star;
                pa = star_start + 1;
                resume = Some((after_star, pa));
            }
        }
    }

    // Path exhausted: whatever pattern remains must be able to match nothing.
    pattern_parts[pi..].iter().all(|seg| *seg == "**")
}
197
198/// Builder for RateLimitManager.
199pub struct RateLimitManagerBuilder<K> {
200    default_quota: Option<Quota>,
201    routes: HashMap<String, RouteConfig>,
202    patterns: Vec<(String, RouteConfig)>,
203    key_extractor: Option<K>,
204}
205
206impl<K> Default for RateLimitManagerBuilder<K> {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl<K> RateLimitManagerBuilder<K> {
213    /// Create a new builder.
214    pub fn new() -> Self {
215        Self {
216            default_quota: None,
217            routes: HashMap::new(),
218            patterns: Vec::new(),
219            key_extractor: None,
220        }
221    }
222
223    /// Set the default quota for routes without specific configuration.
224    pub fn default_quota(mut self, quota: Quota) -> Self {
225        self.default_quota = Some(quota);
226        self
227    }
228
229    /// Add a rate limit for a specific route.
230    pub fn route(mut self, path: impl Into<String>, config: impl Into<RouteConfig>) -> Self {
231        self.routes.insert(path.into(), config.into());
232        self
233    }
234
235    /// Add a rate limit for a route pattern.
236    ///
237    /// Patterns support `*` for single segment and `**` for multiple segments.
238    pub fn route_pattern(
239        mut self,
240        pattern: impl Into<String>,
241        config: impl Into<RouteConfig>,
242    ) -> Self {
243        self.patterns.push((pattern.into(), config.into()));
244        self
245    }
246
247    /// Set the key extractor.
248    pub fn key_extractor(mut self, extractor: K) -> Self {
249        self.key_extractor = Some(extractor);
250        self
251    }
252
253    /// Build the manager with the given algorithm and storage.
254    pub fn build<A, S>(self, algorithm: A, storage: S) -> RateLimitManager<A, S, K>
255    where
256        K: Default,
257    {
258        RateLimitManager {
259            algorithm,
260            storage: Arc::new(storage),
261            key_extractor: self.key_extractor.unwrap_or_default(),
262            default_quota: self.default_quota,
263            routes: self.routes,
264            patterns: self.patterns,
265        }
266    }
267
268    /// Build the manager with a specific key extractor.
269    pub fn build_with_key<A, S>(
270        self,
271        algorithm: A,
272        storage: S,
273        key_extractor: K,
274    ) -> RateLimitManager<A, S, K> {
275        RateLimitManager {
276            algorithm,
277            storage: Arc::new(storage),
278            key_extractor,
279            default_quota: self.default_quota,
280            routes: self.routes,
281            patterns: self.patterns,
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `pattern_matches` over `(pattern, path, expected)` triples and
    /// reports the offending pair on failure.
    fn assert_cases(cases: &[(&str, &str, bool)]) {
        for &(pattern, path, expected) in cases {
            assert_eq!(
                pattern_matches(pattern, path),
                expected,
                "pattern {pattern:?} against path {path:?}"
            );
        }
    }

    #[test]
    fn test_pattern_matches_exact() {
        assert_cases(&[
            ("/api/users", "/api/users", true),
            ("/api/users", "/api/posts", false),
        ]);
    }

    #[test]
    fn test_pattern_matches_single_wildcard() {
        assert_cases(&[
            ("/api/*/posts", "/api/users/posts", true),
            ("/api/*/posts", "/api/admins/posts", true),
            ("/api/*/posts", "/api/users/comments", false),
        ]);
    }

    #[test]
    fn test_pattern_matches_double_wildcard() {
        assert_cases(&[
            ("/api/**", "/api/users", true),
            ("/api/**", "/api/users/123/posts", true),
            ("/api/**", "/v2/api/users", false),
        ]);
    }

    #[test]
    fn test_route_config_from_quota() {
        // Conversion must keep the quota and leave the suffix unset.
        let config = RouteConfig::from(Quota::per_minute(60));
        assert_eq!(config.quota.max_requests(), 60);
        assert!(config.key_suffix.is_none());
    }
}
316}