use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::RwLock;
use crate::sessions::backends::cache::{SessionBackend, SessionError};
use crate::social::core::SocialAuthError;
/// Data remembered for one pending social-auth `state` value.
///
/// Both stores in this module key entries by [`StateData::state`]. Expiry is
/// checked on read via [`StateData::is_expired`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateData {
    /// The opaque state string; used as the lookup key by every `StateStore` here.
    /// NOTE(review): presumably the OAuth `state` parameter — confirm with callers.
    pub state: String,
    /// Optional nonce stored alongside the state.
    /// NOTE(review): presumably the OpenID Connect `nonce` — confirm.
    pub nonce: Option<String>,
    /// Optional code verifier stored alongside the state.
    /// NOTE(review): presumably the PKCE `code_verifier` — confirm.
    pub code_verifier: Option<String>,
    /// Absolute UTC instant after which the entry is treated as expired.
    pub expires_at: DateTime<Utc>,
}
impl StateData {
    /// Builds state data that expires ten minutes from now.
    pub fn new(state: String, nonce: Option<String>, code_verifier: Option<String>) -> Self {
        Self::with_ttl(state, nonce, code_verifier, Duration::minutes(10))
    }

    /// Builds state data that expires `ttl` from now.
    ///
    /// A negative `ttl` produces data that is already expired.
    pub fn with_ttl(
        state: String,
        nonce: Option<String>,
        code_verifier: Option<String>,
        ttl: Duration,
    ) -> Self {
        let expires_at = Utc::now() + ttl;
        Self {
            state,
            nonce,
            code_verifier,
            expires_at,
        }
    }

    /// Returns `true` once the current time is strictly later than `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }
}
/// Async storage for pending [`StateData`], looked up by its `state` string.
///
/// The `Send + Sync` bound lets one store instance be shared across tasks.
/// NOTE(review): `retrieve` and `remove` are separate calls. If a state must be
/// single-use, the caller has to remove it after retrieving, and that
/// sequence is not atomic.
#[async_trait]
pub trait StateStore: Send + Sync {
    /// Saves `data`. The implementations in this module key it by `data.state`.
    async fn store(&self, data: StateData) -> Result<(), SocialAuthError>;
    /// Returns a copy of the data stored for `state`.
    ///
    /// The implementations in this module return `SocialAuthError::InvalidState`
    /// when the state is unknown or has expired.
    async fn retrieve(&self, state: &str) -> Result<StateData, SocialAuthError>;
    /// Deletes the data stored for `state`.
    ///
    /// How a missing state is reported varies: the in-memory store returns
    /// `InvalidState`, while the session-backed store reports whatever the backend's
    /// delete returns.
    async fn remove(&self, state: &str) -> Result<(), SocialAuthError>;
}
/// Process-local [`StateStore`] backed by a `HashMap` behind a tokio `RwLock`.
///
/// Expired entries are purged lazily, each time a new state is stored. Reads
/// still reject any expired entry that has not been purged yet.
#[derive(Debug, Default)]
pub struct InMemoryStateStore {
    store: RwLock<HashMap<String, StateData>>,
}

impl InMemoryStateStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Drops every expired entry from an already-locked map.
    ///
    /// This takes the map rather than locking it, so callers can purge and then
    /// mutate under a single write-lock acquisition.
    fn cleanup_expired(entries: &mut HashMap<String, StateData>) {
        entries.retain(|_, data| !data.is_expired());
    }
}

#[async_trait]
impl StateStore for InMemoryStateStore {
    /// Purges expired entries, then inserts `data` keyed by `data.state`.
    /// An existing entry with the same key is replaced.
    async fn store(&self, data: StateData) -> Result<(), SocialAuthError> {
        // Purge and insert under one write guard, so the map is locked once and no
        // other writer can run between the two steps.
        let mut entries = self.store.write().await;
        Self::cleanup_expired(&mut entries);
        entries.insert(data.state.clone(), data);
        Ok(())
    }

    /// Returns a clone of the live entry for `state`.
    ///
    /// # Errors
    /// Returns `SocialAuthError::InvalidState` if `state` is unknown or expired.
    async fn retrieve(&self, state: &str) -> Result<StateData, SocialAuthError> {
        let entries = self.store.read().await;
        match entries.get(state) {
            // Clone only entries that will actually be returned.
            Some(data) if !data.is_expired() => Ok(data.clone()),
            _ => Err(SocialAuthError::InvalidState),
        }
    }

    /// Deletes the entry for `state`.
    ///
    /// # Errors
    /// Returns `SocialAuthError::InvalidState` if no entry exists for `state`.
    async fn remove(&self, state: &str) -> Result<(), SocialAuthError> {
        let mut entries = self.store.write().await;
        entries
            .remove(state)
            .map(|_| ())
            .ok_or(SocialAuthError::InvalidState)
    }
}
/// [`StateStore`] that persists [`StateData`] through a [`SessionBackend`].
///
/// Each state is saved under `key_prefix + state`. The backend TTL is derived from
/// [`StateData::expires_at`], so the backend can evict entries that are never
/// retrieved or removed.
pub struct SessionStateStore<B: SessionBackend> {
    backend: B,
    key_prefix: String,
}

/// Key prefix used by `SessionStateStore::new`.
const DEFAULT_KEY_PREFIX: &str = "_social_auth_state:";

impl<B: SessionBackend> SessionStateStore<B> {
    /// Creates a store that namespaces its keys with `DEFAULT_KEY_PREFIX`.
    pub fn new(backend: B) -> Self {
        Self::with_prefix(backend, DEFAULT_KEY_PREFIX)
    }

    /// Creates a store that namespaces its keys with a caller-chosen `prefix`.
    pub fn with_prefix(backend: B, prefix: impl Into<String>) -> Self {
        Self {
            backend,
            key_prefix: prefix.into(),
        }
    }

    /// Builds the backend key for `state`: the prefix followed by the state.
    fn session_key(&self, state: &str) -> String {
        format!("{}{}", self.key_prefix, state)
    }

    /// Computes the backend TTL, in whole seconds, for `data`.
    ///
    /// The result is rounded up, so the backend never evicts an entry before
    /// `expires_at`. It is always at least one second.
    ///
    /// The value is always `Some`. The `Option` wrapper only matches the
    /// backend's `save` signature.
    fn compute_ttl(data: &StateData) -> Option<u64> {
        const MIN_TTL_SECS: u64 = 1;
        let remaining_ms = (data.expires_at - Utc::now()).num_milliseconds();
        let ttl_secs = match u64::try_from(remaining_ms) {
            // `ms` is at most i64::MAX, so adding 999 cannot overflow u64.
            Ok(ms) => (ms + 999) / 1000,
            // Negative remaining time: the state is already expired.
            Err(_) => 0,
        };
        // Never hand the backend `None`. Truncating a sub-second remainder to zero
        // would store the entry without a TTL, which presumably means "no expiry"
        // (NOTE(review): confirm against `SessionBackend::save`). A stale state would
        // then linger until explicitly deleted.
        Some(ttl_secs.max(MIN_TTL_SECS))
    }
}
/// Wraps a session-backend failure as `SocialAuthError::Storage`.
///
/// Only the backend error's `Display` text is kept.
fn map_session_error(err: SessionError) -> SocialAuthError {
    let message = err.to_string();
    SocialAuthError::Storage(message)
}
#[async_trait]
impl<B: SessionBackend + 'static> StateStore for SessionStateStore<B> {
    /// Saves `data` under its prefixed key, with a TTL derived from `expires_at`.
    async fn store(&self, data: StateData) -> Result<(), SocialAuthError> {
        let key = self.session_key(&data.state);
        let ttl = Self::compute_ttl(&data);
        let outcome = self.backend.save(&key, &data, ttl).await;
        outcome.map_err(map_session_error)
    }

    /// Loads the data stored for `state`.
    ///
    /// # Errors
    /// - `SocialAuthError::Storage` if the backend load fails.
    /// - `SocialAuthError::InvalidState` if nothing is stored or the stored data
    ///   has expired. An expired entry is also deleted from the backend on a
    ///   best-effort basis.
    async fn retrieve(&self, state: &str) -> Result<StateData, SocialAuthError> {
        let key = self.session_key(state);
        let loaded: Option<StateData> = self.backend.load(&key).await.map_err(map_session_error)?;
        match loaded {
            None => Err(SocialAuthError::InvalidState),
            Some(stale) if stale.is_expired() => {
                // Best-effort eviction: a failed delete must not mask the
                // InvalidState result.
                let _ = self.backend.delete(&key).await;
                Err(SocialAuthError::InvalidState)
            }
            Some(fresh) => Ok(fresh),
        }
    }

    /// Deletes the entry for `state`. Backend failures are mapped to
    /// `SocialAuthError::Storage`.
    async fn remove(&self, state: &str) -> Result<(), SocialAuthError> {
        let key = self.session_key(state);
        let outcome = self.backend.delete(&key).await;
        outcome.map_err(map_session_error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sessions::backends::InMemorySessionBackend;
    use rstest::rstest;

    // Lookup failures are asserted to be `InvalidState` specifically, not just
    // "some error". Otherwise a storage failure could masquerade as a
    // missing or expired state and still pass.

    #[rstest]
    #[tokio::test]
    async fn test_state_data_expiration() {
        let data = StateData::new("test_state".to_string(), None, None);
        let expired_data = StateData::with_ttl(
            "expired_state".to_string(),
            None,
            None,
            Duration::seconds(-1),
        );
        assert!(!data.is_expired());
        assert!(expired_data.is_expired());
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_store_retrieve() {
        let store = InMemoryStateStore::new();
        let data = StateData::new(
            "test_state".to_string(),
            Some("test_nonce".to_string()),
            Some("test_verifier".to_string()),
        );
        store.store(data.clone()).await.unwrap();
        let retrieved = store.retrieve("test_state").await.unwrap();
        assert_eq!(retrieved.state, "test_state");
        assert_eq!(retrieved.nonce, Some("test_nonce".to_string()));
        assert_eq!(retrieved.code_verifier, Some("test_verifier".to_string()));
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_store_remove() {
        let store = InMemoryStateStore::new();
        let data = StateData::new("test_state".to_string(), None, None);
        store.store(data).await.unwrap();
        store.remove("test_state").await.unwrap();
        let result = store.retrieve("test_state").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_store_remove_nonexistent() {
        let store = InMemoryStateStore::new();
        let result = store.remove("nonexistent").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_store_nonexistent() {
        let store = InMemoryStateStore::new();
        let result = store.retrieve("nonexistent").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_in_memory_store_expired() {
        let store = InMemoryStateStore::new();
        let expired_data = StateData::with_ttl(
            "expired_state".to_string(),
            None,
            None,
            Duration::seconds(-1),
        );
        store.store(expired_data).await.unwrap();
        let result = store.retrieve("expired_state").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_cleanup_expired() {
        let store = InMemoryStateStore::new();
        let valid_data = StateData::new("valid".to_string(), None, None);
        let expired_data =
            StateData::with_ttl("expired".to_string(), None, None, Duration::seconds(-1));
        store.store(valid_data).await.unwrap();
        store.store(expired_data).await.unwrap();
        // Storing another state triggers the purge of expired entries.
        let new_data = StateData::new("new".to_string(), None, None);
        store.store(new_data).await.unwrap();
        assert!(store.retrieve("valid").await.is_ok());
        assert!(store.retrieve("new").await.is_ok());
        assert!(matches!(
            store.retrieve("expired").await,
            Err(SocialAuthError::InvalidState)
        ));
        // `retrieve` rejects expired data on its own, so inspect the map directly to
        // prove the purge really removed the entry.
        assert!(!store.store.read().await.contains_key("expired"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_session_state_store_store_and_retrieve() {
        let backend = InMemorySessionBackend::new();
        let store = SessionStateStore::new(backend);
        let data = StateData::new(
            "oauth_state_abc".to_string(),
            Some("nonce_123".to_string()),
            Some("verifier_xyz".to_string()),
        );
        store.store(data).await.unwrap();
        let retrieved = store.retrieve("oauth_state_abc").await.unwrap();
        assert_eq!(retrieved.state, "oauth_state_abc");
        assert_eq!(retrieved.nonce, Some("nonce_123".to_string()));
        assert_eq!(retrieved.code_verifier, Some("verifier_xyz".to_string()));
    }

    #[rstest]
    #[tokio::test]
    async fn test_session_state_store_retrieve_expired_state() {
        let backend = InMemorySessionBackend::new();
        let store = SessionStateStore::new(backend);
        let expired_data = StateData::with_ttl(
            "expired_state".to_string(),
            None,
            None,
            Duration::seconds(-1),
        );
        // Save directly with a long backend TTL, so only the logical expiry check
        // can reject the entry.
        let key = format!("{}{}", DEFAULT_KEY_PREFIX, "expired_state");
        store
            .backend
            .save(&key, &expired_data, Some(300))
            .await
            .unwrap();
        let result = store.retrieve("expired_state").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
        // The expired entry is evicted from the backend as a side effect.
        let still_present: bool = store.backend.exists(key.as_str()).await.unwrap();
        assert!(!still_present);
    }

    #[rstest]
    #[tokio::test]
    async fn test_session_state_store_retrieve_non_existent() {
        let backend = InMemorySessionBackend::new();
        let store = SessionStateStore::new(backend);
        let result = store.retrieve("non_existent_state").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_session_state_store_delete() {
        let backend = InMemorySessionBackend::new();
        let store = SessionStateStore::new(backend);
        let data = StateData::new("state_to_delete".to_string(), None, None);
        store.store(data).await.unwrap();
        store.remove("state_to_delete").await.unwrap();
        let result = store.retrieve("state_to_delete").await;
        assert!(matches!(result, Err(SocialAuthError::InvalidState)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_session_state_store_custom_key_prefix() {
        let backend = InMemorySessionBackend::new();
        let custom_prefix = "custom_prefix:";
        // Keep a clone of the backend so its keys can be inspected after the store
        // takes ownership.
        let store = SessionStateStore::with_prefix(backend.clone(), custom_prefix);
        let data = StateData::new("prefixed_state".to_string(), None, None);
        store.store(data).await.unwrap();
        let exists_with_custom_prefix: bool = backend
            .exists("custom_prefix:prefixed_state")
            .await
            .unwrap();
        let exists_with_default_prefix: bool = backend
            .exists("_social_auth_state:prefixed_state")
            .await
            .unwrap();
        assert!(exists_with_custom_prefix);
        assert!(!exists_with_default_prefix);
    }
}