use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use crate::error::Result;
/// Storage backend for pending OAuth `state` values, keyed by the state string.
///
/// Implementors must be `Send + Sync` so a single store can be shared across
/// tasks (e.g. behind an `Arc<dyn StateStore>`).
#[async_trait]
pub trait StateStore: Send + Sync {
/// Records `state` as issued for `provider`, valid until `expiry_secs`
/// (an absolute Unix timestamp in seconds).
///
/// # Errors
/// Implementation-specific; e.g. a capacity limit or a backend failure.
async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()>;
/// Looks up `state` and returns its provider together with a seconds-based
/// timestamp. The implementations in this module remove the entry on lookup,
/// so each state can be redeemed at most once.
/// NOTE(review): the meaning of the returned `u64` is implementation-specific;
/// confirm what callers rely on.
///
/// # Errors
/// Returns an error when the state is unknown or cannot be redeemed.
async fn retrieve(&self, state: &str) -> Result<(String, u64)>;
}
/// In-memory [`StateStore`] backed by a concurrent `DashMap`, with a bounded
/// number of entries.
#[derive(Debug)]
pub struct InMemoryStateStore {
/// Maps each state string to `(provider, expiry_secs)`, where the expiry is a
/// Unix timestamp in seconds.
states: Arc<DashMap<String, (String, u64)>>,
/// Upper bound on live entries; `store` rejects new states once it is reached.
max_states: usize,
}
impl InMemoryStateStore {
    /// Default capacity used by `new` (10,000 pending states).
    const MAX_STATES: usize = 10_000;

    /// Creates an empty store with the default capacity.
    pub fn new() -> Self {
        Self::with_max_states(Self::MAX_STATES)
    }

    /// Creates an empty store that holds at most `max_states` entries.
    ///
    /// A requested capacity of `0` is raised to `1` so the store can always
    /// accept at least one state.
    pub fn with_max_states(max_states: usize) -> Self {
        let capacity = if max_states == 0 { 1 } else { max_states };
        Self {
            states: Arc::new(DashMap::new()),
            max_states: capacity,
        }
    }

    /// Evicts every entry whose expiry is not strictly in the future, then
    /// reports whether the store is still full (`true` means "at capacity").
    fn cleanup_expired(&self) -> bool {
        // A clock set before the Unix epoch is treated as time zero.
        let current_secs =
            match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
                Ok(elapsed) => elapsed.as_secs(),
                Err(_) => 0,
            };
        self.states.retain(|_, entry| entry.1 > current_secs);
        self.states.len() >= self.max_states
    }
}
impl Default for InMemoryStateStore {
    /// Builds an empty store with the default capacity, exactly like `new`.
    fn default() -> Self {
        Self::with_max_states(Self::MAX_STATES)
    }
}
#[async_trait]
impl StateStore for InMemoryStateStore {
async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
if self.cleanup_expired() {
return Err(crate::error::AuthError::ConfigError {
message: "State store at capacity, cannot store new state".to_string(),
});
}
self.states.insert(state, (provider, expiry_secs));
Ok(())
}
async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
let (_key, value) =
self.states.remove(state).ok_or_else(|| crate::error::AuthError::InvalidState)?;
Ok(value)
}
}
/// Redis-backed [`StateStore`], compiled only with the `redis-rate-limiting` feature.
#[cfg(feature = "redis-rate-limiting")]
#[derive(Clone)]
pub struct RedisStateStore {
/// Despite its name, this holds a `ConnectionManager`; it is cloned for each
/// store/retrieve call.
client: redis::aio::ConnectionManager,
}
#[cfg(feature = "redis-rate-limiting")]
impl RedisStateStore {
    /// Opens a Redis client for `redis_url` and establishes a managed connection.
    ///
    /// # Errors
    /// Returns `AuthError::ConfigError` carrying the Redis error text if the URL
    /// cannot be parsed or the initial connection fails.
    pub async fn new(redis_url: &str) -> Result<Self> {
        // Non-capturing closure, therefore `Copy`, so it can be reused for both calls.
        let into_config_error = |err: redis::RedisError| crate::error::AuthError::ConfigError {
            message: err.to_string(),
        };
        let client = redis::Client::open(redis_url).map_err(into_config_error)?;
        let manager = client
            .get_connection_manager()
            .await
            .map_err(into_config_error)?;
        Ok(Self { client: manager })
    }

    /// Builds the namespaced Redis key under which `state` is stored.
    fn state_key(state: &str) -> String {
        ["oauth:state:", state].concat()
    }
}
#[cfg(feature = "redis-rate-limiting")]
#[async_trait]
impl StateStore for RedisStateStore {
    /// Persists `state -> provider` under a namespaced key whose Redis TTL is
    /// derived from `expiry_secs` (an absolute Unix timestamp in seconds).
    ///
    /// A state whose expiry is not in the future is not written, so it can never
    /// be redeemed.
    ///
    /// # Errors
    /// Returns `AuthError::ConfigError` with the Redis error text if the write fails.
    async fn store(&self, state: String, provider: String, expiry_secs: u64) -> Result<()> {
        use redis::AsyncCommands;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let ttl = expiry_secs.saturating_sub(now);
        if ttl == 0 {
            // SETEX rejects a zero TTL. Rounding up to 1 s would briefly accept an
            // already-expired state, so skip the write instead.
            return Ok(());
        }
        let key = Self::state_key(&state);
        let mut conn = self.client.clone();
        let _: () = conn.set_ex(&key, &provider, ttl).await.map_err(|e| {
            crate::error::AuthError::ConfigError {
                message: e.to_string(),
            }
        })?;
        Ok(())
    }

    /// Atomically fetches and deletes `state`.
    ///
    /// Returns the provider and the state's approximate expiry (Unix seconds),
    /// reconstructed as "now + remaining TTL".
    ///
    /// # Errors
    /// - `AuthError::InvalidState` if no live entry exists (unknown, expired, or
    ///   already redeemed).
    /// - `AuthError::ConfigError` with the Redis error text if the transaction fails.
    async fn retrieve(&self, state: &str) -> Result<(String, u64)> {
        let key = Self::state_key(state);
        let mut conn = self.client.clone();
        // GET, TTL and DEL run in a single MULTI/EXEC transaction. With separate GET
        // and DEL round-trips, two concurrent callbacks presenting the same state
        // could both read it before either deleted it, defeating replay protection.
        let (provider, ttl): (Option<String>, i64) = redis::pipe()
            .atomic()
            .get(&key)
            .ttl(&key)
            .del(&key)
            .ignore()
            .query_async(&mut conn)
            .await
            .map_err(|e| crate::error::AuthError::ConfigError {
                message: e.to_string(),
            })?;
        let provider = provider.ok_or(crate::error::AuthError::InvalidState)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // TTL is negative for a key without an expiry; clamp so the result is never before `now`.
        let remaining = u64::try_from(ttl).unwrap_or(0);
        Ok((provider, now.saturating_add(remaining)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "redis-rate-limiting")]
    const REDIS_URL: &str = "redis://localhost:6379";

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// An expiry timestamp ten minutes (600 s) from now.
    fn ten_minutes_out() -> u64 {
        unix_now() + 600
    }

    #[tokio::test]
    async fn test_in_memory_state_store() {
        let store = InMemoryStateStore::new();
        store
            .store("state123".to_string(), "google".to_string(), ten_minutes_out())
            .await
            .unwrap();
        let (provider, _expiry) = store.retrieve("state123").await.unwrap();
        assert_eq!(provider, "google");
        // A state may only be redeemed once.
        assert!(store.retrieve("state123").await.is_err());
    }

    #[tokio::test]
    async fn test_state_not_found() {
        let store = InMemoryStateStore::new();
        assert!(store.retrieve("nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_state_replay_prevention() {
        let store = InMemoryStateStore::new();
        store
            .store("state_abc".to_string(), "auth0".to_string(), ten_minutes_out())
            .await
            .unwrap();
        let first = store.retrieve("state_abc").await;
        assert!(first.is_ok());
        let replay = store.retrieve("state_abc").await;
        assert!(replay.is_err());
    }

    #[tokio::test]
    async fn test_in_memory_multiple_states() {
        let store = InMemoryStateStore::new();
        let expiry = ten_minutes_out();
        let cases = [("state1", "google"), ("state2", "auth0"), ("state3", "okta")];
        for &(state, provider) in &cases {
            store
                .store(state.to_string(), provider.to_string(), expiry)
                .await
                .unwrap();
        }
        for &(state, expected) in &cases {
            let (provider, _) = store.retrieve(state).await.unwrap();
            assert_eq!(provider, expected);
        }
    }

    #[tokio::test]
    async fn test_in_memory_state_store_trait_object() {
        let store: Arc<dyn StateStore> = Arc::new(InMemoryStateStore::new());
        store
            .store("state_trait".to_string(), "test_provider".to_string(), ten_minutes_out())
            .await
            .unwrap();
        let (provider, _) = store.retrieve("state_trait").await.unwrap();
        assert_eq!(provider, "test_provider");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_bounded() {
        let store = InMemoryStateStore::with_max_states(5);
        let expiry = ten_minutes_out();
        for i in 0..5 {
            store
                .store(format!("state_{}", i), "google".to_string(), expiry)
                .await
                .unwrap();
        }
        let overflow = store
            .store("state_5".to_string(), "google".to_string(), expiry)
            .await;
        assert!(overflow.is_err(), "Should reject insertion when at capacity");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_cleanup_expired() {
        let store = InMemoryStateStore::with_max_states(3);
        let now = unix_now();
        for i in 0..3 {
            store
                .store(format!("expired_{}", i), "google".to_string(), now - 100)
                .await
                .unwrap();
        }
        let expiry = now + 600;
        let first_valid = store
            .store("valid_state".to_string(), "auth0".to_string(), expiry)
            .await;
        assert!(first_valid.is_ok(), "Should succeed after cleaning up expired states");
        for &(state, provider) in &[("valid_state_2", "google"), ("valid_state_3", "okta")] {
            store
                .store(state.to_string(), provider.to_string(), expiry)
                .await
                .unwrap();
        }
        let overflow = store
            .store("valid_state_4".to_string(), "auth0".to_string(), expiry)
            .await;
        assert!(overflow.is_err(), "Should be at capacity now");
    }

    #[tokio::test]
    async fn test_in_memory_state_store_custom_max_size() {
        let store_small = InMemoryStateStore::with_max_states(1);
        let store_large = InMemoryStateStore::with_max_states(100);
        let expiry = ten_minutes_out();
        store_small
            .store("s1".to_string(), "p1".to_string(), expiry)
            .await
            .unwrap();
        let rejected = store_small
            .store("s2".to_string(), "p2".to_string(), expiry)
            .await;
        assert!(rejected.is_err());
        for i in 0..50 {
            store_large
                .store(format!("state_{}", i), "provider".to_string(), expiry)
                .await
                .unwrap();
        }
        assert_eq!(store_large.states.len(), 50);
    }

    #[tokio::test]
    async fn test_in_memory_state_store_zero_max_enforced() {
        let store = InMemoryStateStore::with_max_states(0);
        let outcome = store
            .store("state1".to_string(), "google".to_string(), ten_minutes_out())
            .await;
        assert!(outcome.is_ok(), "Should allow at least 1 state minimum");
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_store_basic() {
        let store = match RedisStateStore::new(REDIS_URL).await {
            Ok(store) => store,
            Err(_) => {
                eprintln!("Skipping Redis tests - Redis server not available");
                return;
            },
        };
        store
            .store("redis_state_1".to_string(), "google".to_string(), ten_minutes_out())
            .await
            .unwrap();
        let (provider, _) = store.retrieve("redis_state_1").await.unwrap();
        assert_eq!(provider, "google");
        assert!(store.retrieve("redis_state_1").await.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_state_replay_prevention() {
        // Silently skip when no Redis server is reachable.
        let store = match RedisStateStore::new(REDIS_URL).await {
            Ok(store) => store,
            Err(_) => return,
        };
        store
            .store("redis_replay_test".to_string(), "auth0".to_string(), ten_minutes_out())
            .await
            .unwrap();
        let first = store.retrieve("redis_replay_test").await;
        assert!(first.is_ok());
        let replay = store.retrieve("redis_replay_test").await;
        assert!(replay.is_err());
    }

    #[cfg(feature = "redis-rate-limiting")]
    #[tokio::test]
    async fn test_redis_multiple_states() {
        // Silently skip when no Redis server is reachable.
        let store = match RedisStateStore::new(REDIS_URL).await {
            Ok(store) => store,
            Err(_) => return,
        };
        let expiry = ten_minutes_out();
        let cases = [("redis_state_a", "google"), ("redis_state_b", "okta")];
        for &(state, provider) in &cases {
            store
                .store(state.to_string(), provider.to_string(), expiry)
                .await
                .unwrap();
        }
        for &(state, expected) in &cases {
            let (provider, _) = store.retrieve(state).await.unwrap();
            assert_eq!(provider, expected);
        }
    }
}