mockforge_registry_server/
redis.rs1use anyhow::Result;
4use chrono::Datelike;
5use std::sync::Arc;
6
7use redis::{aio::ConnectionManager, AsyncCommands, Client};
9
10#[derive(Clone)]
12pub struct RedisPool {
13 manager: Arc<ConnectionManager>,
14}
15
16impl RedisPool {
17 pub async fn connect(redis_url: &str) -> Result<Self> {
19 let client = Client::open(redis_url)?;
20 let manager = ConnectionManager::new(client).await?;
21
22 Ok(Self {
23 manager: Arc::new(manager),
24 })
25 }
26
27 pub fn get_connection(&self) -> Arc<ConnectionManager> {
30 self.manager.clone()
31 }
32
33 pub async fn increment_with_expiry(&self, key: &str, expiry_seconds: u64) -> Result<i64> {
36 let mut conn = (*self.manager).clone();
38
39 let count: i64 = conn.incr(key, 1).await?;
41
42 if count == 1 {
44 conn.expire::<_, ()>(key, expiry_seconds as i64).await?;
45 }
46
47 Ok(count)
48 }
49
50 pub async fn get_counter(&self, key: &str) -> Result<i64> {
52 let mut conn = (*self.manager).clone();
53 let count: i64 = conn.get(key).await.unwrap_or(0);
54 Ok(count)
55 }
56
57 pub async fn set_with_expiry(&self, key: &str, value: &str, expiry_seconds: u64) -> Result<()> {
59 let mut conn = (*self.manager).clone();
60 conn.set_ex::<_, _, ()>(key, value, expiry_seconds).await?;
61 Ok(())
62 }
63
64 pub async fn get(&self, key: &str) -> Result<Option<String>> {
66 let mut conn = (*self.manager).clone();
67 let value: Option<String> = conn.get(key).await?;
68 Ok(value)
69 }
70
71 pub async fn delete(&self, key: &str) -> Result<()> {
73 let mut conn = (*self.manager).clone();
74 conn.del::<_, ()>(key).await?;
75 Ok(())
76 }
77
78 pub async fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
80 let mut conn = (*self.manager).clone();
81 let mut cursor: u64 = 0;
82 let mut keys = Vec::new();
83
84 loop {
85 let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
86 .arg(cursor)
87 .arg("MATCH")
88 .arg(pattern)
89 .arg("COUNT")
90 .arg(100)
91 .query_async(&mut conn)
92 .await?;
93
94 keys.extend(batch);
95 cursor = next_cursor;
96
97 if cursor == 0 {
98 break;
99 }
100 }
101
102 Ok(keys)
103 }
104
105 pub async fn ping(&self) -> Result<()> {
107 let mut conn = (*self.manager).clone();
108 let _: String = conn.get("__ping_test__").await.unwrap_or_else(|_| "PONG".to_string());
110 Ok(())
111 }
112}
113
114pub fn org_usage_key(org_id: &uuid::Uuid, period: &str) -> String {
116 format!("usage:{}:{}", org_id, period)
117}
118
119pub fn org_usage_key_by_type(org_id: &uuid::Uuid, period: &str, usage_type: &str) -> String {
121 format!("usage:{}:{}:{}", org_id, period, usage_type)
122}
123
124pub fn org_rate_limit_key(org_id: &uuid::Uuid) -> String {
126 format!("ratelimit:{}", org_id)
127}
128
129pub fn current_month_period() -> String {
131 let now = chrono::Utc::now();
132 format!("{}-{:02}", now.year(), now.month())
133}
134
135pub fn two_factor_setup_key(user_id: &uuid::Uuid) -> String {
137 format!("2fa_setup:{}", user_id)
138}
139
140pub fn two_factor_backup_codes_key(user_id: &uuid::Uuid) -> String {
142 format!("2fa_backup_codes:{}", user_id)
143}
144
/// Time-to-live, in seconds, for two-factor setup state: five minutes.
pub const TWO_FACTOR_SETUP_TTL_SECONDS: u64 = 5 * 60;
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `now` as `<year>-<MM>`, the format `current_month_period` promises.
    fn period_of(now: chrono::DateTime<chrono::Utc>) -> String {
        format!("{}-{:02}", now.year(), now.month())
    }

    /// True when `key` only uses ASCII letters/digits plus `-`, `_` and `:`.
    fn is_safe_key(key: &str) -> bool {
        key.chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | ':'))
    }

    #[test]
    fn test_org_usage_key() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key = org_usage_key(&org_id, period);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert_eq!(key, format!("usage:{}:{}", org_id, period));
    }

    #[test]
    fn test_org_usage_key_by_type() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let usage_type = "api_calls";

        let key = org_usage_key_by_type(&org_id, period, usage_type);

        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert!(key.contains(usage_type));
        assert_eq!(key, format!("usage:{}:{}:{}", org_id, period, usage_type));
    }

    #[test]
    fn test_org_rate_limit_key() {
        let org_id = uuid::Uuid::new_v4();

        let key = org_rate_limit_key(&org_id);

        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(&org_id.to_string()));
        assert_eq!(key, format!("ratelimit:{}", org_id));
    }

    #[test]
    fn test_two_factor_keys() {
        let user_id = uuid::Uuid::new_v4();

        let setup = two_factor_setup_key(&user_id);
        let backup = two_factor_backup_codes_key(&user_id);

        assert_eq!(setup, format!("2fa_setup:{}", user_id));
        assert_eq!(backup, format!("2fa_backup_codes:{}", user_id));
        assert_ne!(setup, backup);
        assert!(is_safe_key(&setup));
        assert!(is_safe_key(&backup));
    }

    #[test]
    fn test_current_month_period_format() {
        let period = current_month_period();

        assert_eq!(period.len(), 7);

        let (year, month) = period
            .split_once('-')
            .expect("Period should contain a '-' separator");

        assert_eq!(year.len(), 4);
        assert_eq!(month.len(), 2);

        let year: i32 = year.parse().expect("Year should be numeric");
        let month: u32 = month.parse().expect("Month should be numeric");

        assert!(year >= 2025);
        assert!((1..=12).contains(&month));
    }

    #[test]
    fn test_current_month_period_consistency() {
        let before = period_of(chrono::Utc::now());
        let period1 = current_month_period();
        let period2 = current_month_period();
        let after = period_of(chrono::Utc::now());

        // The comparison is only meaningful if the clock stayed inside one
        // month for the whole test. Otherwise it would fail spuriously at
        // month rollover.
        if before == after {
            assert_eq!(period1, period2);
            assert_eq!(period1, before);
        }
    }

    #[test]
    fn test_org_usage_key_different_periods() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-02");

        assert_ne!(key1, key2);
        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-02"));
    }

    #[test]
    fn test_org_usage_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key(&org_id1, period);
        let key2 = org_usage_key(&org_id2, period);

        assert_ne!(key1, key2);
        assert!(key1.contains(&org_id1.to_string()));
        assert!(key2.contains(&org_id2.to_string()));
    }

    #[test]
    fn test_org_usage_key_by_type_different_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth");

        assert_ne!(key1, key2);
        assert_ne!(key2, key3);
        assert_ne!(key1, key3);
        assert!(key1.contains("api_calls"));
        assert!(key2.contains("storage"));
        assert!(key3.contains("bandwidth"));
    }

    #[test]
    fn test_org_rate_limit_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();

        let key1 = org_rate_limit_key(&org_id1);
        let key2 = org_rate_limit_key(&org_id2);

        assert_ne!(key1, key2);
    }

    #[test]
    fn test_key_format_no_spaces() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        assert!(!key1.contains(' '));
        assert!(!key2.contains(' '));
        assert!(!key3.contains(' '));
    }

    #[test]
    fn test_key_format_no_special_chars() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);

        assert!(is_safe_key(&key1));
        assert!(is_safe_key(&key2));
        assert!(is_safe_key(&key3));
    }

    #[test]
    fn test_usage_key_with_special_period_formats() {
        let org_id = uuid::Uuid::new_v4();

        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-12");
        let key3 = org_usage_key(&org_id, "2024-06");

        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-12"));
        assert!(key3.contains("2024-06"));
    }

    #[test]
    fn test_usage_key_by_type_with_special_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";

        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage_gb");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth_mb");

        assert!(key1.ends_with("api_calls"));
        assert!(key2.ends_with("storage_gb"));
        assert!(key3.ends_with("bandwidth_mb"));
    }

    #[test]
    fn test_redis_pool_clone() {
        // Compile-time check only: RedisPool must stay Clone so it can be
        // shared across handlers without a live Redis server in tests.
        fn requires_clone<T: Clone>() {}
        requires_clone::<RedisPool>();
    }

    #[test]
    fn test_current_month_period_matches_chrono() {
        // Sample the clock on both sides of the call so a month rollover in
        // between cannot make this test fail spuriously.
        let before = period_of(chrono::Utc::now());
        let period = current_month_period();
        let after = period_of(chrono::Utc::now());

        assert!(
            period == before || period == after,
            "period {period} matches neither {before} nor {after}"
        );
    }
}