use std::collections::HashMap;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::{timeout, Duration};
/// Upper bound, in milliseconds, on how long a [`TokenStore`] operation waits
/// for its internal lock before failing with [`TokenStoreError::LockTimeout`].
const DEFAULT_LOCK_TIMEOUT_MS: u64 = 500_u64;
#[derive(Debug, thiserror::Error)]
pub enum TokenStoreError {
#[error("Token store lock timeout after {}ms", ms)]
LockTimeout {
ms: u64,
},
}
/// Shorthand for a [`Result`] whose error type is [`TokenStoreError`].
pub type TokenStoreResult<T> = Result<T, TokenStoreError>;
/// In-memory map from caller-chosen string keys to [`TokenInfo`] values.
///
/// The map sits behind a `tokio` read/write lock, so concurrent readers do not
/// block one another. Every lock acquisition performed by the methods on this
/// type is bounded by a timeout.
// `Debug` is intentionally not derived: it would format the stored tokens.
#[derive(Default)]
pub struct TokenStore {
    // Stored tokens, keyed by the string passed to `store_token`.
    tokens: RwLock<HashMap<String, TokenInfo>>,
}
/// A stored token together with its expiry, granted scopes, and optional user
/// identity.
///
/// `Debug` is implemented by hand so that `access_token` and `refresh_token`
/// are redacted. A derived `Debug` would write the raw credentials into any
/// log line or panic message that formats this value with `{:?}`.
#[derive(Clone, serde::Deserialize, serde::Serialize)]
pub struct TokenInfo {
    /// The access token string. Redacted in `Debug` output.
    pub access_token: String,
    /// Optional refresh token. `Debug` output only reveals whether it is present.
    pub refresh_token: Option<String>,
    /// Expiry instant (UTC). `TokenStore::cleanup_expired` drops entries whose
    /// `expires_at` is not strictly after the current time.
    pub expires_at: chrono::DateTime<chrono::Utc>,
    /// Scopes associated with this token.
    pub scopes: Vec<String>,
    /// Optional identifier of the associated user.
    pub user_id: Option<String>,
    /// Optional email address of the associated user.
    pub user_email: Option<String>,
}

impl std::fmt::Debug for TokenInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenInfo")
            .field("access_token", &"<redacted>")
            // Keep the Some/None distinction, which is useful for debugging,
            // without exposing the secret itself.
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_at", &self.expires_at)
            .field("scopes", &self.scopes)
            .field("user_id", &self.user_id)
            .field("user_email", &self.user_email)
            .finish()
    }
}
impl TokenStore {
    /// Creates an empty token store.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `token` under `key`, replacing any token previously stored there.
    ///
    /// # Errors
    ///
    /// Returns [`TokenStoreError::LockTimeout`] if the write lock cannot be
    /// acquired within `DEFAULT_LOCK_TIMEOUT_MS` milliseconds.
    pub async fn store_token(&self, key: String, token: TokenInfo) -> TokenStoreResult<()> {
        let mut guard = self.acquire_write_lock().await?;
        guard.insert(key, token);
        Ok(())
    }

    /// Returns a clone of the token stored under `key`, or `None` if absent.
    ///
    /// Expired tokens are returned as-is; expiry is only enforced by
    /// [`TokenStore::cleanup_expired`].
    ///
    /// # Errors
    ///
    /// Returns [`TokenStoreError::LockTimeout`] if the read lock cannot be
    /// acquired within `DEFAULT_LOCK_TIMEOUT_MS` milliseconds.
    pub async fn get_token(&self, key: &str) -> TokenStoreResult<Option<TokenInfo>> {
        let guard = self.acquire_read_lock().await?;
        Ok(guard.get(key).cloned())
    }

    /// Removes the token stored under `key`. Removing a missing key is not an error.
    ///
    /// # Errors
    ///
    /// Returns [`TokenStoreError::LockTimeout`] if the write lock cannot be
    /// acquired within `DEFAULT_LOCK_TIMEOUT_MS` milliseconds.
    pub async fn remove_token(&self, key: &str) -> TokenStoreResult<()> {
        let mut guard = self.acquire_write_lock().await?;
        guard.remove(key);
        Ok(())
    }

    /// Drops every token whose `expires_at` is not strictly after the current time.
    ///
    /// # Errors
    ///
    /// Returns [`TokenStoreError::LockTimeout`] if the write lock cannot be
    /// acquired within `DEFAULT_LOCK_TIMEOUT_MS` milliseconds.
    pub async fn cleanup_expired(&self) -> TokenStoreResult<()> {
        let mut guard = self.acquire_write_lock().await?;
        // Sample the clock only once the lock is held. Acquisition may wait up to
        // the lock timeout, and tokens that expire during that wait should be
        // purged in this pass as well.
        let now = chrono::Utc::now();
        guard.retain(|_, token| token.expires_at > now);
        Ok(())
    }

    /// Acquires the shared read lock, giving up after `DEFAULT_LOCK_TIMEOUT_MS`.
    async fn acquire_read_lock(
        &self,
    ) -> TokenStoreResult<RwLockReadGuard<'_, HashMap<String, TokenInfo>>> {
        Self::within_lock_timeout(self.tokens.read()).await
    }

    /// Acquires the exclusive write lock, giving up after `DEFAULT_LOCK_TIMEOUT_MS`.
    async fn acquire_write_lock(
        &self,
    ) -> TokenStoreResult<RwLockWriteGuard<'_, HashMap<String, TokenInfo>>> {
        Self::within_lock_timeout(self.tokens.write()).await
    }

    /// Awaits `acquire`, mapping a timeout into [`TokenStoreError::LockTimeout`].
    ///
    /// If the timeout elapses, the pending `acquire` future is dropped, so no
    /// guard is obtained.
    async fn within_lock_timeout<F>(acquire: F) -> TokenStoreResult<F::Output>
    where
        F: std::future::Future,
    {
        timeout(Duration::from_millis(DEFAULT_LOCK_TIMEOUT_MS), acquire)
            .await
            .map_err(|_elapsed| TokenStoreError::LockTimeout {
                ms: DEFAULT_LOCK_TIMEOUT_MS,
            })
    }
}