use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use anyhow::Result;
use chrono::Utc;
use lru::LruCache;
use crate::errors::OAuthStorageError;
use crate::storage::OAuthRequestStorage;
use crate::workflow::OAuthRequest;
/// Shared, lock-protected LRU map from an OAuth `state` value to its pending request.
type SharedRequestCache = Arc<Mutex<LruCache<String, OAuthRequest>>>;

/// In-memory `OAuthRequestStorage` backed by a size-bounded LRU cache.
///
/// The cache sits behind an `Arc`, so cloning this value is cheap and every clone
/// reads and writes the *same* underlying map.
#[derive(Clone)]
pub struct LruOAuthRequestStorage {
    cache: SharedRequestCache,
}
impl LruOAuthRequestStorage {
pub fn new(capacity: NonZeroUsize) -> Self {
Self {
cache: Arc::new(Mutex::new(LruCache::new(capacity))),
}
}
pub fn len(&self) -> usize {
self.cache.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.cache.lock().unwrap().is_empty()
}
pub fn capacity(&self) -> NonZeroUsize {
self.cache.lock().unwrap().cap()
}
pub fn clear(&self) {
self.cache.lock().unwrap().clear();
}
}
#[async_trait::async_trait]
impl OAuthRequestStorage for LruOAuthRequestStorage {
    /// Looks up the pending request stored under `state`.
    ///
    /// A hit whose `expires_at` is still in the future is returned as a clone and
    /// becomes the most-recently-used entry. A hit whose `expires_at` is at or before
    /// the current time is evicted immediately and reported as `None`.
    ///
    /// # Errors
    /// Returns `OAuthStorageError::CacheLockFailedGet` if the cache mutex is poisoned.
    async fn get_oauth_request_by_state(&self, state: &str) -> Result<Option<OAuthRequest>> {
        let mut guard = self.cache.lock().map_err(|poisoned| {
            OAuthStorageError::CacheLockFailedGet {
                details: poisoned.to_string(),
            }
        })?;

        match guard.get(state) {
            Some(found) if found.expires_at > Utc::now() => Ok(Some(found.clone())),
            Some(_) => {
                // Expired entries are purged lazily, on the lookup that discovers them.
                guard.pop(state);
                Ok(None)
            }
            None => Ok(None),
        }
    }

    /// Removes the request stored under `state`. A missing state is not an error.
    ///
    /// # Errors
    /// Returns `OAuthStorageError::CacheLockFailedDelete` if the cache mutex is poisoned.
    async fn delete_oauth_request_by_state(&self, state: &str) -> Result<()> {
        let mut guard = self.cache.lock().map_err(|poisoned| {
            OAuthStorageError::CacheLockFailedDelete {
                details: poisoned.to_string(),
            }
        })?;
        guard.pop(state);
        Ok(())
    }

    /// Stores `request` keyed by its `oauth_state`.
    ///
    /// Any previous entry for the same state is replaced. When the cache is full, the
    /// least-recently-used entry is evicted.
    ///
    /// # Errors
    /// Returns `OAuthStorageError::CacheLockFailedInsert` if the cache mutex is poisoned.
    async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
        let mut guard = self.cache.lock().map_err(|poisoned| {
            OAuthStorageError::CacheLockFailedInsert {
                details: poisoned.to_string(),
            }
        })?;
        let key = request.oauth_state.clone();
        guard.put(key, request);
        Ok(())
    }

    /// Evicts every entry whose `expires_at` is at or before the current time.
    ///
    /// Returns how many entries were removed.
    ///
    /// # Errors
    /// Returns `OAuthStorageError::CacheLockFailedCleanup` if the cache mutex is poisoned.
    async fn clear_expired_oauth_requests(&self) -> Result<u64> {
        let mut guard = self.cache.lock().map_err(|poisoned| {
            OAuthStorageError::CacheLockFailedCleanup {
                details: poisoned.to_string(),
            }
        })?;

        let cutoff = Utc::now();
        // Collect the keys first: the cache cannot be mutated while it is being iterated.
        let stale: Vec<String> = guard
            .iter()
            .filter(|(_, pending)| pending.expires_at <= cutoff)
            .map(|(key, _)| key.clone())
            .collect();

        for key in &stale {
            guard.pop(key);
        }
        Ok(stale.len() as u64)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{Duration, Utc};
    use std::num::NonZeroUsize;

    /// Creates a storage holding at most `capacity` entries.
    ///
    /// Panics if `capacity` is zero; every caller passes a positive literal.
    fn storage_with_capacity(capacity: usize) -> LruOAuthRequestStorage {
        let capacity = NonZeroUsize::new(capacity).expect("test capacity must be non-zero");
        LruOAuthRequestStorage::new(capacity)
    }

    /// Builds a request for `state` that expires ten minutes from now.
    ///
    /// Every derived string field embeds `state`, so fixtures for different states are
    /// distinguishable. `created_at` and `expires_at` share a single clock reading.
    fn create_test_oauth_request(state: &str, issuer: &str) -> OAuthRequest {
        let now = Utc::now();
        OAuthRequest {
            oauth_state: state.to_string(),
            issuer: issuer.to_string(),
            authorization_server: issuer.to_string(),
            nonce: format!("nonce-{}", state),
            pkce_verifier: format!("verifier-{}", state),
            signing_public_key: format!("pubkey-{}", state),
            dpop_private_key: format!("privkey-{}", state),
            created_at: now,
            expires_at: now + Duration::minutes(10),
        }
    }

    /// Builds a request for `state` that was created 20 minutes ago and expired
    /// 10 minutes ago.
    ///
    /// All other fields match [`create_test_oauth_request`].
    fn create_expired_oauth_request(state: &str, issuer: &str) -> OAuthRequest {
        let now = Utc::now();
        OAuthRequest {
            created_at: now - Duration::minutes(20),
            expires_at: now - Duration::minutes(10),
            ..create_test_oauth_request(state, issuer)
        }
    }

    #[tokio::test]
    async fn test_new_storage() {
        let storage = storage_with_capacity(100);
        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());
        assert_eq!(storage.capacity().get(), 100);
    }

    #[tokio::test]
    async fn test_basic_operations() -> Result<()> {
        let storage = storage_with_capacity(10);

        let result = storage.get_oauth_request_by_state("unknown-state").await?;
        assert_eq!(result, None);

        let request = create_test_oauth_request("test-state", "https://pds.example.com");
        storage.insert_oauth_request(request.clone()).await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert!(result.is_some());
        assert_eq!(result.as_ref().unwrap().oauth_state, request.oauth_state);
        assert_eq!(storage.len(), 1);

        // Re-inserting the same state replaces the entry rather than adding a second one.
        let updated_request =
            create_test_oauth_request("test-state", "https://updated.example.com");
        storage
            .insert_oauth_request(updated_request.clone())
            .await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert!(result.is_some());
        assert_eq!(result.as_ref().unwrap().issuer, updated_request.issuer);
        assert_eq!(storage.len(), 1);

        storage.delete_oauth_request_by_state("test-state").await?;
        let result = storage.get_oauth_request_by_state("test-state").await?;
        assert_eq!(result, None);
        assert_eq!(storage.len(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_expiration_handling() -> Result<()> {
        let storage = storage_with_capacity(10);
        let expired_request =
            create_expired_oauth_request("expired-state", "https://pds.example.com");
        storage.insert_oauth_request(expired_request).await?;
        assert_eq!(storage.len(), 1);

        let result = storage.get_oauth_request_by_state("expired-state").await?;
        assert_eq!(result, None);
        assert_eq!(storage.len(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_expired_lookup_leaves_other_entries() -> Result<()> {
        let storage = storage_with_capacity(10);
        storage
            .insert_oauth_request(create_expired_oauth_request(
                "expired-state",
                "https://pds.example.com",
            ))
            .await?;
        storage
            .insert_oauth_request(create_test_oauth_request(
                "valid-state",
                "https://pds.example.com",
            ))
            .await?;

        assert_eq!(storage.get_oauth_request_by_state("expired-state").await?, None);
        assert_eq!(storage.len(), 1);
        assert!(storage
            .get_oauth_request_by_state("valid-state")
            .await?
            .is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_lru_eviction() -> Result<()> {
        let storage = storage_with_capacity(2);
        let req1 = create_test_oauth_request("state1", "https://pds.example.com");
        let req2 = create_test_oauth_request("state2", "https://pds.example.com");
        storage.insert_oauth_request(req1.clone()).await?;
        storage.insert_oauth_request(req2).await?;
        assert_eq!(storage.len(), 2);

        // Touching state1 makes state2 the least-recently-used entry.
        let _ = storage.get_oauth_request_by_state("state1").await?;
        let req3 = create_test_oauth_request("state3", "https://pds.example.com");
        storage.insert_oauth_request(req3.clone()).await?;
        assert_eq!(storage.len(), 2);

        let result1 = storage.get_oauth_request_by_state("state1").await?;
        assert!(result1.is_some());
        assert_eq!(result1.unwrap().oauth_state, req1.oauth_state);
        let result3 = storage.get_oauth_request_by_state("state3").await?;
        assert!(result3.is_some());
        assert_eq!(result3.unwrap().oauth_state, req3.oauth_state);
        assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);
        Ok(())
    }

    #[tokio::test]
    async fn test_clear() -> Result<()> {
        let storage = storage_with_capacity(10);
        let req1 = create_test_oauth_request("state1", "https://pds.example.com");
        let req2 = create_test_oauth_request("state2", "https://pds.example.com");
        storage.insert_oauth_request(req1).await?;
        storage.insert_oauth_request(req2).await?;
        assert_eq!(storage.len(), 2);

        storage.clear();
        assert_eq!(storage.len(), 0);
        assert!(storage.is_empty());
        assert_eq!(storage.get_oauth_request_by_state("state1").await?, None);
        assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);
        Ok(())
    }

    #[tokio::test]
    async fn test_clear_expired_requests() -> Result<()> {
        let storage = storage_with_capacity(10);
        let expired1 = create_expired_oauth_request("expired1", "https://pds.example.com");
        let expired2 = create_expired_oauth_request("expired2", "https://pds.example.com");
        let valid1 = create_test_oauth_request("valid1", "https://pds.example.com");
        let valid2 = create_test_oauth_request("valid2", "https://pds.example.com");
        storage.insert_oauth_request(expired1).await?;
        storage.insert_oauth_request(valid1).await?;
        storage.insert_oauth_request(expired2).await?;
        storage.insert_oauth_request(valid2).await?;
        assert_eq!(storage.len(), 4);

        let removed_count = storage.clear_expired_oauth_requests().await?;
        assert_eq!(removed_count, 2);
        assert_eq!(storage.len(), 2);

        assert!(storage
            .get_oauth_request_by_state("valid1")
            .await?
            .is_some());
        assert!(storage
            .get_oauth_request_by_state("valid2")
            .await?
            .is_some());
        assert_eq!(storage.get_oauth_request_by_state("expired1").await?, None);
        assert_eq!(storage.get_oauth_request_by_state("expired2").await?, None);
        Ok(())
    }

    #[tokio::test]
    async fn test_clear_expired_on_empty_storage() -> Result<()> {
        let storage = storage_with_capacity(10);
        assert_eq!(storage.clear_expired_oauth_requests().await?, 0);
        assert!(storage.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_delete_nonexistent() -> Result<()> {
        let storage = storage_with_capacity(10);
        storage
            .delete_oauth_request_by_state("non-existent-state")
            .await?;
        assert_eq!(storage.len(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_clones_share_cache() -> Result<()> {
        let storage = storage_with_capacity(10);
        let handle = storage.clone();

        handle
            .insert_oauth_request(create_test_oauth_request(
                "shared-state",
                "https://pds.example.com",
            ))
            .await?;
        assert_eq!(storage.len(), 1);
        assert!(storage
            .get_oauth_request_by_state("shared-state")
            .await?
            .is_some());

        storage.clear();
        assert!(handle.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_thread_safety() -> Result<()> {
        let storage = storage_with_capacity(100);
        let mut handles = Vec::new();
        for i in 0..10 {
            // The storage is `Clone` over a shared `Arc`, so each task gets a handle to the
            // same cache without an extra `Arc` wrapper.
            let storage = storage.clone();
            let handle = tokio::spawn(async move {
                let state = format!("state{}", i);
                let issuer = format!("https://pds{}.example.com", i);
                let request = create_test_oauth_request(&state, &issuer);
                storage.insert_oauth_request(request.clone()).await?;
                let result = storage.get_oauth_request_by_state(&state).await?;
                assert!(result.is_some());
                assert_eq!(result.unwrap().oauth_state, request.oauth_state);
                storage.delete_oauth_request_by_state(&state).await?;
                let result = storage.get_oauth_request_by_state(&state).await?;
                assert_eq!(result, None);
                Ok::<(), anyhow::Error>(())
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await??;
        }
        assert_eq!(storage.len(), 0);
        Ok(())
    }
}