use anyhow::Result;
use chrono::Datelike;
use std::sync::Arc;
use redis::{aio::ConnectionManager, AsyncCommands, Client};
/// Cloneable handle around a single shared Redis [`ConnectionManager`].
///
/// The manager lives behind an [`Arc`], so cloning a `RedisPool` only bumps a
/// reference count; every clone refers to the same `ConnectionManager` value.
#[derive(Clone)]
pub struct RedisPool {
    /// Shared manager; methods clone the inner value to obtain a `&mut` handle.
    manager: Arc<ConnectionManager>,
}
impl RedisPool {
    /// Opens a Redis client for `redis_url` and wraps it in a [`ConnectionManager`].
    ///
    /// # Errors
    /// Returns an error if the URL cannot be parsed or the manager cannot be created.
    pub async fn connect(redis_url: &str) -> Result<Self> {
        let client = Client::open(redis_url)?;
        let manager = ConnectionManager::new(client).await?;
        Ok(Self {
            manager: Arc::new(manager),
        })
    }

    /// Returns a new reference-counted handle to the shared connection manager.
    pub fn get_connection(&self) -> Arc<ConnectionManager> {
        Arc::clone(&self.manager)
    }

    /// Clones the inner connection manager so a command can borrow it mutably.
    fn conn(&self) -> ConnectionManager {
        (*self.manager).clone()
    }

    /// Atomically increments the integer at `key` by 1 and returns the new value.
    ///
    /// When this increment creates the key (the result is 1), its TTL is set to
    /// `expiry_seconds`. Later increments do not refresh the TTL, so the counter expires
    /// `expiry_seconds` after its first increment.
    ///
    /// # Errors
    /// Returns an error if the command fails (e.g. connection loss, or `key` holds a
    /// value that is not an integer).
    pub async fn increment_with_expiry(&self, key: &str, expiry_seconds: u64) -> Result<i64> {
        // A separate INCR followed by EXPIRE is not atomic: if the EXPIRE fails or the
        // process dies after the INCR that created the key, the counter never expires.
        // Redis executes a Lua script as a single atomic unit, which closes that gap.
        const INCR_WITH_EXPIRY: &str = r"
local count = redis.call('INCR', KEYS[1])
if count == 1 then
    redis.call('EXPIRE', KEYS[1], ARGV[1])
end
return count
";
        let mut conn = self.conn();
        let count: i64 = redis::cmd("EVAL")
            .arg(INCR_WITH_EXPIRY)
            .arg(1) // number of KEYS passed to the script
            .arg(key)
            .arg(expiry_seconds)
            .query_async(&mut conn)
            .await?;
        Ok(count)
    }

    /// Reads the integer counter stored at `key`, returning 0 when the key does not exist.
    ///
    /// # Errors
    /// Returns an error if the command fails or the stored value is not an integer.
    pub async fn get_counter(&self, key: &str) -> Result<i64> {
        let mut conn = self.conn();
        // Only a missing key (nil reply -> None) maps to 0. Connection failures and bad
        // values are propagated rather than reported as "no usage recorded".
        let count: Option<i64> = conn.get(key).await?;
        Ok(count.unwrap_or(0))
    }

    /// Stores `value` under `key` with a TTL of `expiry_seconds` (SETEX).
    ///
    /// # Errors
    /// Returns an error if the command fails.
    pub async fn set_with_expiry(&self, key: &str, value: &str, expiry_seconds: u64) -> Result<()> {
        let mut conn = self.conn();
        conn.set_ex::<_, _, ()>(key, value, expiry_seconds).await?;
        Ok(())
    }

    /// Fetches the string stored at `key`, or `None` if the key does not exist.
    ///
    /// # Errors
    /// Returns an error if the command fails or the value cannot be read as a string.
    pub async fn get(&self, key: &str) -> Result<Option<String>> {
        let mut conn = self.conn();
        let value: Option<String> = conn.get(key).await?;
        Ok(value)
    }

    /// Deletes `key` (DEL). Deleting a key that does not exist is not an error.
    ///
    /// # Errors
    /// Returns an error if the command fails.
    pub async fn delete(&self, key: &str) -> Result<()> {
        let mut conn = self.conn();
        conn.del::<_, ()>(key).await?;
        Ok(())
    }

    /// Returns every key matching the glob `pattern`, iterating with cursor-based SCAN.
    ///
    /// Each key appears at most once, in the order it was first returned by Redis.
    ///
    /// # Errors
    /// Returns an error if any SCAN round trip fails.
    pub async fn scan_keys(&self, pattern: &str) -> Result<Vec<String>> {
        let mut conn = self.conn();
        let mut cursor: u64 = 0;
        let mut keys = Vec::new();
        // SCAN may return the same key more than once during a full iteration, so keep
        // only the first occurrence of each key.
        let mut seen = std::collections::HashSet::new();
        loop {
            let (next_cursor, batch): (u64, Vec<String>) = redis::cmd("SCAN")
                .arg(cursor)
                .arg("MATCH")
                .arg(pattern)
                .arg("COUNT")
                .arg(100)
                .query_async(&mut conn)
                .await?;
            for key in batch {
                if seen.insert(key.clone()) {
                    keys.push(key);
                }
            }
            cursor = next_cursor;
            // A returned cursor of 0 means the iteration is complete.
            if cursor == 0 {
                break;
            }
        }
        Ok(keys)
    }

    /// Health check: sends PING and succeeds only if the server replies.
    ///
    /// # Errors
    /// Returns an error if Redis cannot be reached or the command fails.
    pub async fn ping(&self) -> Result<()> {
        let mut conn = self.conn();
        // The previous implementation mapped every error to "PONG" and so could never fail.
        let _: String = redis::cmd("PING").query_async(&mut conn).await?;
        Ok(())
    }
}
/// Builds the Redis key for an organization's usage in `period`.
///
/// Shape: `usage:<org_id>:<period>`, with the UUID rendered through `Display`.
pub fn org_usage_key(org_id: &uuid::Uuid, period: &str) -> String {
    let org = org_id.to_string();
    ["usage", org.as_str(), period].join(":")
}
/// Builds the Redis key for one kind of usage (`usage_type`) of an organization in `period`.
///
/// Shape: `usage:<org_id>:<period>:<usage_type>`.
pub fn org_usage_key_by_type(org_id: &uuid::Uuid, period: &str, usage_type: &str) -> String {
    let org = org_id.to_string();
    ["usage", org.as_str(), period, usage_type].join(":")
}
/// Builds the Redis key for an organization's rate-limit state: `ratelimit:<org_id>`.
pub fn org_rate_limit_key(org_id: &uuid::Uuid) -> String {
    let mut key = String::from("ratelimit:");
    key.push_str(&org_id.to_string());
    key
}
/// Returns the current UTC month as `YYYY-MM` (month zero-padded to two digits).
pub fn current_month_period() -> String {
    let timestamp = chrono::Utc::now();
    let (year, month) = (timestamp.year(), timestamp.month());
    format!("{}-{:02}", year, month)
}
/// Builds the Redis key for a user's pending two-factor setup: `2fa_setup:<user_id>`.
pub fn two_factor_setup_key(user_id: &uuid::Uuid) -> String {
    let mut key = String::from("2fa_setup:");
    key.push_str(&user_id.to_string());
    key
}
/// Builds the Redis key for a user's two-factor backup codes: `2fa_backup_codes:<user_id>`.
pub fn two_factor_backup_codes_key(user_id: &uuid::Uuid) -> String {
    "2fa_backup_codes:".to_owned() + &user_id.to_string()
}
/// TTL in seconds (five minutes) for two-factor setup data.
/// NOTE(review): the call sites that apply this TTL are not visible in this file.
pub const TWO_FACTOR_SETUP_TTL_SECONDS: u64 = 5 * 60;
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `dt` as `YYYY-MM`, the format `current_month_period` is expected to produce.
    fn period_of(dt: chrono::DateTime<chrono::Utc>) -> String {
        format!("{}-{:02}", dt.year(), dt.month())
    }

    #[test]
    fn test_org_usage_key() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let key = org_usage_key(&org_id, period);
        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert_eq!(key, format!("usage:{}:{}", org_id, period));
    }

    #[test]
    fn test_org_usage_key_by_type() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let usage_type = "api_calls";
        let key = org_usage_key_by_type(&org_id, period, usage_type);
        assert!(key.starts_with("usage:"));
        assert!(key.contains(&org_id.to_string()));
        assert!(key.contains(period));
        assert!(key.contains(usage_type));
        assert_eq!(key, format!("usage:{}:{}:{}", org_id, period, usage_type));
    }

    #[test]
    fn test_org_rate_limit_key() {
        let org_id = uuid::Uuid::new_v4();
        let key = org_rate_limit_key(&org_id);
        assert!(key.starts_with("ratelimit:"));
        assert!(key.contains(&org_id.to_string()));
        assert_eq!(key, format!("ratelimit:{}", org_id));
    }

    #[test]
    fn test_current_month_period_format() {
        let period = current_month_period();
        assert_eq!(period.len(), 7);
        assert!(period.contains('-'));
        let parts: Vec<&str> = period.split('-').collect();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].len(), 4);
        let year: i32 = parts[0].parse().expect("Year should be numeric");
        assert!(year >= 2025);
        assert_eq!(parts[1].len(), 2);
        let month: u32 = parts[1].parse().expect("Month should be numeric");
        assert!((1..=12).contains(&month));
    }

    #[test]
    fn test_current_month_period_consistency() {
        // Two calls may legitimately differ if they straddle a month boundary. Bracket them
        // with clock readings and require equality only when no boundary was crossed.
        let before = period_of(chrono::Utc::now());
        let period1 = current_month_period();
        let period2 = current_month_period();
        let after = period_of(chrono::Utc::now());
        if before == after {
            assert_eq!(period1, period2);
        }
    }

    #[test]
    fn test_org_usage_key_different_periods() {
        let org_id = uuid::Uuid::new_v4();
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-02");
        assert_ne!(key1, key2);
        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-02"));
    }

    #[test]
    fn test_org_usage_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();
        let period = "2025-01";
        let key1 = org_usage_key(&org_id1, period);
        let key2 = org_usage_key(&org_id2, period);
        assert_ne!(key1, key2);
        assert!(key1.contains(&org_id1.to_string()));
        assert!(key2.contains(&org_id2.to_string()));
    }

    #[test]
    fn test_org_usage_key_by_type_different_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth");
        assert_ne!(key1, key2);
        assert_ne!(key2, key3);
        assert_ne!(key1, key3);
        assert!(key1.contains("api_calls"));
        assert!(key2.contains("storage"));
        assert!(key3.contains("bandwidth"));
    }

    #[test]
    fn test_org_rate_limit_key_different_orgs() {
        let org_id1 = uuid::Uuid::new_v4();
        let org_id2 = uuid::Uuid::new_v4();
        let key1 = org_rate_limit_key(&org_id1);
        let key2 = org_rate_limit_key(&org_id2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_key_format_no_spaces() {
        let org_id = uuid::Uuid::new_v4();
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);
        assert!(!key1.contains(' '));
        assert!(!key2.contains(' '));
        assert!(!key3.contains(' '));
    }

    #[test]
    fn test_key_format_no_special_chars() {
        let org_id = uuid::Uuid::new_v4();
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key_by_type(&org_id, "2025-01", "api_calls");
        let key3 = org_rate_limit_key(&org_id);
        let valid_chars =
            |s: &str| s.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == ':');
        assert!(valid_chars(&key1));
        assert!(valid_chars(&key2));
        assert!(valid_chars(&key3));
    }

    #[test]
    fn test_usage_key_with_special_period_formats() {
        let org_id = uuid::Uuid::new_v4();
        let key1 = org_usage_key(&org_id, "2025-01");
        let key2 = org_usage_key(&org_id, "2025-12");
        let key3 = org_usage_key(&org_id, "2024-06");
        assert!(key1.contains("2025-01"));
        assert!(key2.contains("2025-12"));
        assert!(key3.contains("2024-06"));
    }

    #[test]
    fn test_usage_key_by_type_with_special_types() {
        let org_id = uuid::Uuid::new_v4();
        let period = "2025-01";
        let key1 = org_usage_key_by_type(&org_id, period, "api_calls");
        let key2 = org_usage_key_by_type(&org_id, period, "storage_gb");
        let key3 = org_usage_key_by_type(&org_id, period, "bandwidth_mb");
        assert!(key1.ends_with("api_calls"));
        assert!(key2.ends_with("storage_gb"));
        assert!(key3.ends_with("bandwidth_mb"));
    }

    #[test]
    fn test_redis_pool_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<RedisPool>();
    }

    #[test]
    fn test_current_month_period_matches_chrono() {
        // Sample the clock on both sides of the call so a month rollover between the
        // call and the reference reading cannot make the test fail spuriously.
        let before = period_of(chrono::Utc::now());
        let period = current_month_period();
        let after = period_of(chrono::Utc::now());
        assert!(
            period == before || period == after,
            "period {} matches neither {} nor {}",
            period,
            before,
            after
        );
    }
}