Skip to main content

oxidite_middleware/
rate_limit.rs

1use oxidite_db::Database;
2use std::sync::Arc;
3use std::collections::HashMap;
4use tokio::sync::Mutex;
5
/// Rate limit configuration
///
/// Describes the request budgets a [`RateLimiter`] enforces for each
/// identifier/endpoint pair.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Maximum number of requests accepted within the trailing 60 seconds.
    pub requests_per_minute: u32,
    /// Maximum number of requests accepted within the trailing 3600 seconds.
    /// `None` disables the hourly check.
    pub requests_per_hour: Option<u32>,
}
12
13impl Default for RateLimitConfig {
14    fn default() -> Self {
15        Self {
16            requests_per_minute: 60,
17            requests_per_hour: Some(1000),
18        }
19    }
20}
21
/// In-memory rate limiter with sliding window
///
/// Decisions are made purely from the in-memory cache. When a database handle
/// is present, accepted requests are additionally recorded in the
/// `rate_limits` table; nothing in this file reads those rows back.
pub struct RateLimiter {
    // Optional database used to persist accepted requests (write-only here).
    db: Option<Arc<dyn Database>>,
    // Budgets enforced by `check`.
    config: RateLimitConfig,
    // In-memory cache: key derived from identifier + endpoint -> Unix-second
    // timestamps of the requests that were accepted (one entry per request).
    cache: Arc<Mutex<HashMap<String, Vec<i64>>>>,
}
29
30impl RateLimiter {
31    pub fn new(config: RateLimitConfig) -> Self {
32        Self {
33            db: None,
34            config,
35            cache: Arc::new(Mutex::new(HashMap::new())),
36        }
37    }
38    
39    pub fn with_db(config: RateLimitConfig, db: Arc<dyn Database>) -> Self {
40        Self {
41            db: Some(db),
42            config,
43            cache: Arc::new(Mutex::new(HashMap::new())),
44        }
45    }
46    
47    /// Check if request is allowed (returns true if allowed)
48    pub async fn check(&self, identifier: &str, endpoint: &str) -> bool {
49        let now = chrono::Utc::now().timestamp();
50        let minute_ago = now - 60;
51        let hour_ago = now - 3600;
52        
53        // Use in-memory cache for performance
54        let mut cache = self.cache.lock().await;
55        let key = format!("{}:{}", identifier, endpoint);
56        
57        // Get timestamps for this identifier+endpoint
58        let timestamps = cache.entry(key.clone()).or_insert_with(Vec::new);
59        
60        // Remove timestamps older than 1 hour
61        timestamps.retain(|&ts| ts > hour_ago);
62        
63        // Count requests in last minute
64        let minute_count = timestamps.iter().filter(|&&ts| ts > minute_ago).count() as u32;
65        
66        // Check minute limit
67        if minute_count >= self.config.requests_per_minute {
68            return false;
69        }
70        
71        // Check hour limit if configured
72        if let Some(hour_limit) = self.config.requests_per_hour {
73            let hour_count = timestamps.len() as u32;
74            if hour_count >= hour_limit {
75                return false;
76            }
77        }
78        
79        // Request allowed - add timestamp
80        timestamps.push(now);
81        
82        // Persist to database if configured (async, don't wait)
83        if let Some(db) = &self.db {
84            let db_clone = db.clone();
85            let ident = identifier.to_string();
86            let ep = endpoint.to_string();
87            tokio::spawn(async move {
88                let _ = Self::record_request(&*db_clone, &ident, &ep).await;
89            });
90        }
91        
92        true
93    }
94    
95    /// Record request in database
96    async fn record_request(db: &dyn Database, identifier: &str, endpoint: &str) -> oxidite_db::Result<()> {
97        let now = chrono::Utc::now().timestamp();
98        let window_start = (now / 60) * 60; // Round to minute
99        
100        // Try to increment existing record
101        let update_query = format!(
102            "UPDATE rate_limits 
103             SET request_count = request_count + 1, updated_at = {}
104             WHERE identifier = '{}' AND endpoint = '{}' AND window_start = {}",
105            now, identifier, endpoint, window_start
106        );
107        
108        let rows = db.execute(&update_query).await?;
109        
110        // If no existing record, insert new one
111        if rows == 0 {
112            let insert_query = format!(
113                "INSERT INTO rate_limits (identifier, endpoint, request_count, window_start, created_at, updated_at)
114                 VALUES ('{}', '{}', 1, {}, {}, {})",
115                identifier, endpoint, window_start, now, now
116            );
117            db.execute(&insert_query).await?;
118        }
119        
120        Ok(())
121    }
122    
123    /// Get remaining requests for identifier
124    pub async fn get_remaining(&self, identifier: &str, endpoint: &str) -> u32 {
125        let now = chrono::Utc::now().timestamp();
126        let minute_ago = now - 60;
127        
128        let cache = self.cache.lock().await;
129        let key = format!("{}:{}", identifier, endpoint);
130        
131        if let Some(timestamps) = cache.get(&key) {
132            let minute_count = timestamps.iter().filter(|&&ts| ts > minute_ago).count() as u32;
133            self.config.requests_per_minute.saturating_sub(minute_count)
134        } else {
135            self.config.requests_per_minute
136        }
137    }
138    
139    /// Clean up old entries from cache (call periodically)
140    pub async fn cleanup(&self) {
141        let now = chrono::Utc::now().timestamp();
142        let hour_ago = now - 3600;
143        
144        let mut cache = self.cache.lock().await;
145        cache.retain(|_, timestamps| {
146            timestamps.retain(|&ts| ts > hour_ago);
147            !timestamps.is_empty()
148        });
149    }
150}