Skip to main content

mockforge_registry_server/
redis.rs

1//! Redis connection and utilities for rate limiting and caching
2
3use anyhow::Result;
4use chrono::Datelike;
5use std::sync::Arc;
6
7// Import redis crate types
8use redis::{aio::ConnectionManager, AsyncCommands, Client};
9
/// Redis connection wrapper
///
/// Wraps a single [`ConnectionManager`] in an [`Arc`]. Cloning a `RedisPool`
/// (via the derived `Clone`) clones the `Arc`, so all clones refer to the same
/// underlying manager.
#[derive(Clone)]
pub struct RedisPool {
    /// Shared connection manager; each operation clones its own handle from it.
    manager: Arc<ConnectionManager>,
}
15
16impl RedisPool {
17    /// Create a new Redis connection pool
18    pub async fn connect(redis_url: &str) -> Result<Self> {
19        let client = Client::open(redis_url)?;
20        let manager = ConnectionManager::new(client).await?;
21
22        Ok(Self {
23            manager: Arc::new(manager),
24        })
25    }
26
27    /// Get a connection for async operations
28    /// Note: ConnectionManager is already cloneable, so we can use it directly
29    pub fn get_connection(&self) -> Arc<ConnectionManager> {
30        self.manager.clone()
31    }
32
33    /// Increment a counter with expiration
34    /// Returns the new count after increment
35    pub async fn increment_with_expiry(&self, key: &str, expiry_seconds: u64) -> Result<i64> {
36        // ConnectionManager is already async-safe, we can use it directly
37        let mut conn = (*self.manager).clone();
38
39        // Use Redis pipeline for atomic increment + expiry
40        let count: i64 = conn.incr(key, 1).await?;
41
42        // Set expiry on first increment (count == 1)
43        if count == 1 {
44            conn.expire::<_, ()>(key, expiry_seconds as i64).await?;
45        }
46
47        Ok(count)
48    }
49
50    /// Get a counter value
51    pub async fn get_counter(&self, key: &str) -> Result<i64> {
52        let mut conn = (*self.manager).clone();
53        let count: i64 = conn.get(key).await.unwrap_or(0);
54        Ok(count)
55    }
56
57    /// Set a key with expiration
58    pub async fn set_with_expiry(&self, key: &str, value: &str, expiry_seconds: u64) -> Result<()> {
59        let mut conn = (*self.manager).clone();
60        conn.set_ex::<_, _, ()>(key, value, expiry_seconds).await?;
61        Ok(())
62    }
63
64    /// Get a key value
65    pub async fn get(&self, key: &str) -> Result<Option<String>> {
66        let mut conn = (*self.manager).clone();
67        let value: Option<String> = conn.get(key).await?;
68        Ok(value)
69    }
70
71    /// Delete a key
72    pub async fn delete(&self, key: &str) -> Result<()> {
73        let mut conn = (*self.manager).clone();
74        conn.del::<_, ()>(key).await?;
75        Ok(())
76    }
77
78    /// Scan for keys matching a glob pattern using Redis SCAN
79    pub async fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
80        let mut conn = (*self.manager).clone();
81        let mut cursor: u64 = 0;
82        let mut keys = Vec::new();
83
84        loop {
85            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
86                .arg(cursor)
87                .arg("MATCH")
88                .arg(pattern)
89                .arg("COUNT")
90                .arg(100)
91                .query_async(&mut conn)
92                .await?;
93
94            keys.extend(batch);
95            cursor = next_cursor;
96
97            if cursor == 0 {
98                break;
99            }
100        }
101
102        Ok(keys)
103    }
104
105    /// Health check - verify Redis connectivity
106    pub async fn ping(&self) -> Result<()> {
107        let mut conn = (*self.manager).clone();
108        // Use the AsyncCommands trait method directly
109        let _: String = conn.get("__ping_test__").await.unwrap_or_else(|_| "PONG".to_string());
110        Ok(())
111    }
112}
113
114/// Generate Redis key for org usage counter
115pub fn org_usage_key(org_id: &uuid::Uuid, period: &str) -> String {
116    format!("usage:{}:{}", org_id, period)
117}
118
119/// Generate Redis key for org usage counter by type
120pub fn org_usage_key_by_type(org_id: &uuid::Uuid, period: &str, usage_type: &str) -> String {
121    format!("usage:{}:{}:{}", org_id, period, usage_type)
122}
123
124/// Generate Redis key for org rate limit
125pub fn org_rate_limit_key(org_id: &uuid::Uuid) -> String {
126    format!("ratelimit:{}", org_id)
127}
128
129/// Get current month period string (YYYY-MM)
130pub fn current_month_period() -> String {
131    let now = chrono::Utc::now();
132    format!("{}-{:02}", now.year(), now.month())
133}
134
135/// Generate Redis key for 2FA setup secret
136pub fn two_factor_setup_key(user_id: &uuid::Uuid) -> String {
137    format!("2fa_setup:{}", user_id)
138}
139
140/// Generate Redis key for 2FA backup codes (stored during setup)
141pub fn two_factor_backup_codes_key(user_id: &uuid::Uuid) -> String {
142    format!("2fa_backup_codes:{}", user_id)
143}
144
/// Lifetime, in seconds, of a 2FA setup secret: five minutes.
pub const TWO_FACTOR_SETUP_TTL_SECONDS: u64 = 5 * 60;
147
#[cfg(test)]
mod tests {
    //! Unit tests for the key-generation helpers and period formatting.
    //!
    //! Tests that would talk to Redis need a running server, so this module
    //! only covers the functions that work without external services.

    use super::*;

    /// Period string computed directly from chrono, independent of
    /// `current_month_period`.
    fn chrono_period() -> String {
        let now = chrono::Utc::now();
        format!("{}-{:02}", now.year(), now.month())
    }

    #[test]
    fn test_org_usage_key() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key = org_usage_key(&org_id, period);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert_eq!(key, format!("usage:{}:{}", org_id, period));
    }

    #[test]
    fn test_org_usage_key_by_type() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let usage_type = "api_calls";

        let key = org_usage_key_by_type(&org_id, period, usage_type);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert!(key.contains(usage_type));
        assert_eq!(key, format!("usage:{}:{}:{}", org_id, period, usage_type));
    }

    #[test]
    fn test_org_rate_limit_key() {
        let org_id = uuid::Uuid::new_v4();

        let key = org_rate_limit_key(&org_id);

        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(&org_id.to_string()));
        assert_eq!(key, format!("ratelimit:{}", org_id));
    }

    #[test]
    fn test_current_month_period_format() {
        let period = current_month_period();

        // Should be in YYYY-MM format
        assert_eq!(period.len(), 7); // "YYYY-MM" is 7 characters
        assert!(period.contains('-'));

        // Parse and validate format
        let parts: Vec<&str> = period.split('-').collect();
        assert_eq!(parts.len(), 2);

        // Year should be 4 digits
        assert_eq!(parts[0].len(), 4);
        let year: i32 = parts[0].parse().expect("Year should be numeric");
        assert!(year >= 2025); // Should be current year or later

        // Month should be 2 digits
        assert_eq!(parts[1].len(), 2);
        let month: u32 = parts[1].parse().expect("Month should be numeric");
        assert!((1..=12).contains(&month));
    }

    #[test]
    fn test_current_month_period_consistency() {
        // Calls in quick succession should return the same value.
        let first = current_month_period();
        let second = current_month_period();

        // The month can roll over between the two reads. In that rare case
        // the value must at least be stable from the second read onwards.
        if first != second {
            assert_eq!(second, current_month_period());
        }
    }

    #[test]
    fn test_org_usage_key_different_periods() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-02");

        assert_ne!(key1, key2);
        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-02"));
    }

    #[test]
    fn test_org_usage_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key(&org_id1, period);
        let key2 = org_usage_key(&org_id2, period);

        assert_ne!(key1, key2);
        assert!(key1.contains(&org_id1.to_string()));
        assert!(key2.contains(&org_id2.to_string()));
    }

    #[test]
    fn test_org_usage_key_by_type_different_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth");

        assert_ne!(key1, key2);
        assert_ne!(key2, key3);
        assert!(key1.contains("api_calls"));
        assert!(key2.contains("storage"));
        assert!(key3.contains("bandwidth"));
    }

    #[test]
    fn test_org_rate_limit_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();

        let key1 = org_rate_limit_key(&org_id1);
        let key2 = org_rate_limit_key(&org_id2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_key_format_no_spaces() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        assert!(!key1.contains(' '));
        assert!(!key2.contains(' '));
        assert!(!key3.contains(' '));
    }

    #[test]
    fn test_key_format_no_special_chars() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        // Keys should only contain alphanumeric, hyphens, underscores, and colons
        let valid_chars =
            |s: &str| s.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == ':');

        assert!(valid_chars(&key1));
        assert!(valid_chars(&key2));
        assert!(valid_chars(&key3));
    }

    #[test]
    fn test_usage_key_with_special_period_formats() {
        let org_id = uuid::Uuid::new_v4();

        // Test various period formats
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-12");
        let key3 = org_usage_key(&org_id, "2024-06");

        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-12"));
        assert!(key3.contains("2024-06"));
    }

    #[test]
    fn test_usage_key_by_type_with_special_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        // Test various usage types
        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage_gb");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth_mb");

        assert!(key1.ends_with("api_calls"));
        assert!(key2.ends_with("storage_gb"));
        assert!(key3.ends_with("bandwidth_mb"));
    }

    #[test]
    fn test_redis_pool_clone() {
        // This tests the Clone trait on RedisPool
        // We can't actually create a RedisPool without a Redis server,
        // but we can verify the trait is implemented via compilation
        fn requires_clone<T: Clone>() {}
        requires_clone::<RedisPool>();
    }

    #[test]
    fn test_current_month_period_matches_chrono() {
        // Bracket the call with two independent clock reads so a month
        // rollover in between cannot make the test fail spuriously.
        let before = chrono_period();
        let period = current_month_period();
        let after = chrono_period();

        assert!(
            period == before || period == after,
            "period {period} matches neither {before} nor {after}"
        );
    }

    // Mock-based tests would require a Redis server, so we focus on
    // testing the key generation functions which don't require external services
}