use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use crate::auth::error::Result;
/// Backend that persists OAuth `state` values between the authorization
/// redirect and the provider callback.
///
/// Implementors must be `Send + Sync` so a single store can be shared across tasks.
#[async_trait]
pub trait StateStore: Send + Sync {
    /// Saves `state`, remembering which `provider` issued it and when it expires.
    ///
    /// The implementations in this file interpret `expiry_secs` as an absolute
    /// Unix timestamp in seconds.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;

    /// Looks up `state` and returns its provider together with a `u64` timestamp.
    ///
    /// The implementations in this file remove the entry as part of the lookup,
    /// so a given state can be redeemed at most once.
    ///
    /// # Errors
    ///
    /// The implementations in this file return `AuthError::InvalidState` when
    /// `state` is not present.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
}
/// Process-local [`StateStore`] backed by a concurrent [`DashMap`].
///
/// Data lives only in this process's memory, so it is neither persisted nor
/// shared with other instances.
#[derive(Debug)]
pub struct InMemoryStateStore {
    /// Maps each state string to `(provider, expiry_secs)`.
    states: Arc<DashMap<String, (String, u64)>>,
}
impl InMemoryStateStore {
    /// Creates a store containing no states.
    pub fn new() -> Self {
        let states = Arc::new(DashMap::new());
        Self { states }
    }
}
impl Default for InMemoryStateStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl StateStore for InMemoryStateStore {
async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
self.states.insert(state, (provider, expiry_secs));
Ok(())
}
async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
let (_key, value) = self
.states
.remove(state)
.ok_or_else(|| crate::auth::error::AuthError::InvalidState)?;
Ok(value)
}
}
#[cfg(feature = "redis-rate-limiting")]
/// [`StateStore`] that keeps OAuth states in Redis, one key per state.
#[derive(Clone)]
pub struct RedisStateStore {
    /// Connection manager; each operation clones it to get its own handle.
    client: redis::aio::ConnectionManager,
}
#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Opens a Redis client for `redis_url` and builds a connection manager from it.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` containing the Redis error text when the
    /// URL is rejected or the connection manager cannot be created.
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Non-capturing closure, hence `Copy`, so it can be reused for both calls.
        let as_config_error = |err: redis::RedisError| {
            crate::auth::error::AuthError::ConfigError {
                message: err.to_string(),
            }
        };

        let redis_client = redis::Client::open(redis_url).map_err(as_config_error)?;
        let manager = redis_client
            .get_connection_manager()
            .await
            .map_err(as_config_error)?;

        Ok(Self { client: manager })
    }

    /// Builds the namespaced Redis key used to store `state`.
    fn state_key(state: &str) -> String {
        format!("oauth:state:{state}")
    }
}
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl StateStore for RedisStateStore {
    /// Stores `state -> provider` with a Redis TTL derived from `expiry_secs`.
    ///
    /// `expiry_secs` is an absolute Unix timestamp. The TTL is the number of
    /// seconds remaining until then, clamped to at least one second.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::ConfigError` if the Redis command fails.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;
        let key = Self::state_key(&state);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Redis rejects a zero SETEX TTL, so clamp to a minimum of one second.
        let ttl = expiry_secs.saturating_sub(now).max(1);
        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;
        Ok(())
    }

    /// Atomically reads and deletes `state`, returning `(provider, expiry_secs)`.
    ///
    /// GET, TTL and DEL run inside a single MULTI/EXEC transaction. With separate
    /// round-trips, two concurrent callbacks could both read the state before
    /// either deleted it, which would allow a replay. The returned expiry is
    /// reconstructed as `now + remaining TTL`.
    ///
    /// # Errors
    ///
    /// - `AuthError::InvalidState` if the key does not exist (never stored,
    ///   already consumed, or expired).
    /// - `AuthError::ConfigError` if the Redis transaction fails.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();
        let (provider, ttl): (Option<String>, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .ignore()
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::auth::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;
        let provider = provider.ok_or(crate::auth::error::AuthError::InvalidState)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // TTL reports a negative value for keys without an expiry. `store` always
        // sets one, but fall back to "expires now" rather than failing.
        let expiry_secs = now.saturating_add(u64::try_from(ttl).unwrap_or(0));
        Ok((provider, expiry_secs))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local Redis server that the optional integration tests try to reach.
    #[cfg(feature = "redis-rate-limiting")]
    const REDIS_URL: &str = "redis://localhost:6379";

    /// Unix timestamp `secs` seconds from now, used as a state expiry.
    fn expiry_from_now(secs: u64) -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs() + secs
    }

    /// A stored state yields its provider once; a second lookup fails.
    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();
        store
            .store(String::from("state123"), String::from("google"), expiry_from_now(600))
            .await
            .unwrap();

        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");
        assert!(store.retrieve("state123").await.is_err());
    }

    /// Looking up a state that was never stored is an error.
    #[tokio::test]
    async fn test_state_not_found() {
        let store = InMemoryStateStore::new();
        assert!(store.retrieve("nonexistent").await.is_err());
    }

    /// A consumed state cannot be replayed.
    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();
        let expiry = expiry_from_now(600);
        store
            .store(String::from("state_abc"), String::from("auth0"), expiry)
            .await
            .unwrap();

        let first = store.retrieve("state_abc").await;
        assert!(first.is_ok());
        let second = store.retrieve("state_abc").await;
        assert!(second.is_err());
    }

    /// Independent states keep their own providers.
    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let expiry = expiry_from_now(600);
        let cases = [("state1", "google"), ("state2", "auth0"), ("state3", "okta")];

        for (state, provider) in cases {
            store
                .store(String::from(state), String::from(provider), expiry)
                .await
                .unwrap();
        }
        for (state, expected) in cases {
            let (provider, _) = store.retrieve(state).await.unwrap();
            assert_eq!(provider, expected);
        }
    }

    /// The in-memory store works behind `Arc<dyn StateStore>`.
    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());
        let expiry = expiry_from_now(600);
        store
            .store(String::from("state_trait"), String::from("test_provider"), expiry)
            .await
            .unwrap();

        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    /// Round-trip against a live Redis; skipped when no server is reachable.
    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        if let Ok(store) = RedisStateStore::new(REDIS_URL).await {
            let expiry = expiry_from_now(600);
            store
                .store(String::from("redis_state_1"), String::from("google"), expiry)
                .await
                .unwrap();

            let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
            assert_eq!(provider, "google");
            assert!(store.retrieve("redis_state_1").await.is_err());
        } else {
            eprintln!("Skipping Redis tests - Redis server not available");
        }
    }

    /// A Redis-backed state cannot be replayed.
    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        if let Ok(store) = RedisStateStore::new(REDIS_URL).await {
            let expiry = expiry_from_now(600);
            store
                .store(String::from("redis_replay_test"), String::from("auth0"), expiry)
                .await
                .unwrap();

            let first = store.retrieve("redis_replay_test").await;
            assert!(first.is_ok());
            let second = store.retrieve("redis_replay_test").await;
            assert!(second.is_err());
        }
    }

    /// Independent Redis-backed states keep their own providers.
    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        if let Ok(store) = RedisStateStore::new(REDIS_URL).await {
            let expiry = expiry_from_now(600);
            let cases = [("redis_state_a", "google"), ("redis_state_b", "okta")];

            for (state, provider) in cases {
                store
                    .store(String::from(state), String::from(provider), expiry)
                    .await
                    .unwrap();
            }
            for (state, expected) in cases {
                let (provider, _) = store.retrieve(state).await.unwrap();
                assert_eq!(provider, expected);
            }
        }
    }
}