use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use serde::{Deserialize, Serialize};
/// Result alias returned by every token-store operation.
pub type StorageResult<T> = Result<T, StorageError>;

/// Failure modes reported by a token store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StorageError {
    /// No record exists under the given key; the payload is that key.
    NotFound(String),
    /// A record exists but is past its `expires_at`; the payload is the key.
    Expired(String),
    /// The storage backend failed (e.g. a poisoned lock in the in-memory store).
    Backend(String),
    /// A record could not be encoded or decoded — not produced by the in-memory store.
    Serialization(String),
    /// A cryptographic operation failed — not produced by the in-memory store.
    Crypto(String),
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::NotFound(key) => ("Not found", key),
            Self::Expired(key) => ("Expired", key),
            Self::Backend(msg) => ("Storage error", msg),
            Self::Serialization(msg) => ("Serialization error", msg),
            Self::Crypto(msg) => ("Crypto error", msg),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for StorageError {}
/// Data stored alongside an authorization code.
///
/// Written by `TokenStore::store_authorization_code` and handed back by
/// `TokenStore::consume_authorization_code`; serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationCodeGrant {
/// Client the code was issued to.
pub client_id: String,
/// Redirect URI recorded with the code (presumably re-checked at token exchange — confirm in the caller).
pub redirect_uri: String,
/// Scopes granted with the code.
pub scopes: Vec<String>,
/// PKCE code challenge, if the client sent one.
pub code_challenge: Option<String>,
/// Method for `code_challenge` (the tests use `"S256"`).
pub code_challenge_method: Option<String>,
/// Subject the code was issued for (a user id such as `"user123"` in the tests).
pub subject: String,
/// Expiry as Unix time in whole seconds; the in-memory store compares it to its clock in seconds.
pub expires_at: u64,
/// Optional nonce carried with the grant (presumably OpenID Connect `nonce` — confirm).
pub nonce: Option<String>,
/// Optional `state` value carried with the grant.
pub state: Option<String>,
}
/// Metadata stored for an issued access token, keyed by the token's hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessTokenData {
/// Subject the token was issued for.
pub subject: String,
/// Client the token was issued to.
pub client_id: String,
/// Scopes carried by the token.
pub scopes: Vec<String>,
/// Expiry as Unix time in whole seconds.
pub expires_at: u64,
/// Issue time as Unix time in whole seconds.
pub issued_at: u64,
/// Hash of a related refresh token, if any (presumably the one issued alongside — confirm with the issuer).
pub refresh_token_hash: Option<String>,
}
/// Metadata stored for an issued refresh token, keyed by the token's hash.
///
/// Tokens sharing a `family_id` can be revoked together via
/// `TokenStore::revoke_refresh_token_family`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshTokenData {
/// Subject the token was issued for.
pub subject: String,
/// Client the token was issued to.
pub client_id: String,
/// Scopes carried by the token.
pub scopes: Vec<String>,
/// Expiry as Unix time in whole seconds.
pub expires_at: u64,
/// Issue time as Unix time in whole seconds.
pub issued_at: u64,
/// Position of this token within its family — presumably incremented on each rotation; confirm.
pub generation: u32,
/// Identifier shared by related tokens (presumably one rotation chain); the unit of family revocation.
pub family_id: String,
/// Set once the token has been redeemed (see `TokenStore::mark_refresh_token_used`).
pub used: bool,
}
/// Heap-allocated, type-erased `Send` future that may borrow data for `'a`.
///
/// Every `TokenStore` method returns one, which keeps the trait usable as `dyn TokenStore`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
/// Asynchronous persistence for authorization codes, access tokens and refresh tokens.
///
/// Every method returns a boxed `Send` future and takes no generic parameters, so the
/// trait can be shared as `Arc<dyn TokenStore>`. Entries are keyed by the strings passed
/// as `code_hash` / `token_hash`; as the names suggest, callers are expected to pass
/// hashes rather than raw secrets (confirm at the call sites).
pub trait TokenStore: Send + Sync + 'static {
    /// Persists `grant` under `code_hash`.
    fn store_authorization_code(
        &self,
        code_hash: &str,
        grant: &AuthorizationCodeGrant,
    ) -> BoxFuture<'_, StorageResult<()>>;

    /// Removes and returns the grant stored under `code_hash`.
    ///
    /// `MemoryTokenStore` removes the entry before checking expiry, so a code can be
    /// consumed at most once; it reports `NotFound` or `Expired` on failure.
    fn consume_authorization_code(
        &self,
        code_hash: &str,
    ) -> BoxFuture<'_, StorageResult<AuthorizationCodeGrant>>;

    /// Persists access-token metadata under `token_hash`.
    fn store_access_token(
        &self,
        token_hash: &str,
        data: &AccessTokenData,
    ) -> BoxFuture<'_, StorageResult<()>>;

    /// Looks up access-token metadata (`NotFound` / `Expired` in `MemoryTokenStore`).
    fn get_access_token(
        &self,
        token_hash: &str,
    ) -> BoxFuture<'_, StorageResult<AccessTokenData>>;

    /// Deletes the access token under `token_hash`.
    fn revoke_access_token(&self, token_hash: &str) -> BoxFuture<'_, StorageResult<()>>;

    /// Persists refresh-token metadata under `token_hash`.
    fn store_refresh_token(
        &self,
        token_hash: &str,
        data: &RefreshTokenData,
    ) -> BoxFuture<'_, StorageResult<()>>;

    /// Looks up refresh-token metadata (`NotFound` / `Expired` in `MemoryTokenStore`).
    fn get_refresh_token(
        &self,
        token_hash: &str,
    ) -> BoxFuture<'_, StorageResult<RefreshTokenData>>;

    /// Records that the refresh token under `token_hash` has been redeemed (`used = true`).
    fn mark_refresh_token_used(&self, token_hash: &str) -> BoxFuture<'_, StorageResult<()>>;

    /// Revokes every refresh token whose `family_id` equals `family_id`.
    fn revoke_refresh_token_family(&self, family_id: &str) -> BoxFuture<'_, StorageResult<()>>;

    /// Removes expired entries and returns how many were dropped.
    ///
    /// The default implementation removes nothing and reports `0`.
    fn cleanup_expired(&self) -> BoxFuture<'_, StorageResult<u64>> {
        let nothing_removed: StorageResult<u64> = Ok(0);
        Box::pin(async move { nothing_removed })
    }
}
/// In-memory `TokenStore` backed by three `RwLock<HashMap<String, _>>` tables.
///
/// All state lives inside this value and is lost when it is dropped; per the
/// deprecation note, that includes Worker restarts.
#[deprecated(
note = "Use DurableObjectTokenStore for production. MemoryTokenStore loses all tokens \
on Worker restart (~15-30 minutes). Only use for testing or development."
)]
#[derive(Debug, Default)]
pub struct MemoryTokenStore {
// Authorization-code grants keyed by code hash.
authorization_codes: RwLock<HashMap<String, AuthorizationCodeGrant>>,
// Access-token metadata keyed by token hash.
access_tokens: RwLock<HashMap<String, AccessTokenData>>,
// Refresh-token metadata keyed by token hash.
refresh_tokens: RwLock<HashMap<String, RefreshTokenData>>,
}
impl MemoryTokenStore {
#[allow(deprecated)]
pub fn new() -> Self {
#[cfg(target_arch = "wasm32")]
web_sys::console::warn_1(
&"⚠️ Using MemoryTokenStore - tokens will be lost on Worker restart (~15-30 minutes). \
Use DurableObjectTokenStore for production deployments."
.into(),
);
Self::default()
}
fn now_secs() -> u64 {
(js_sys::Date::now() / 1000.0) as u64
}
}
impl TokenStore for MemoryTokenStore {
fn store_authorization_code(
&self,
code_hash: &str,
grant: &AuthorizationCodeGrant,
) -> BoxFuture<'_, StorageResult<()>> {
let code_hash = code_hash.to_string();
let grant = grant.clone();
Box::pin(async move {
let mut codes = self
.authorization_codes
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
codes.insert(code_hash, grant);
Ok(())
})
}
fn consume_authorization_code(
&self,
code_hash: &str,
) -> BoxFuture<'_, StorageResult<AuthorizationCodeGrant>> {
let code_hash = code_hash.to_string();
Box::pin(async move {
let mut codes = self
.authorization_codes
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let grant = codes
.remove(&code_hash)
.ok_or_else(|| StorageError::NotFound(code_hash.clone()))?;
if Self::now_secs() > grant.expires_at {
return Err(StorageError::Expired(code_hash));
}
Ok(grant)
})
}
fn store_access_token(
&self,
token_hash: &str,
data: &AccessTokenData,
) -> BoxFuture<'_, StorageResult<()>> {
let token_hash = token_hash.to_string();
let data = data.clone();
Box::pin(async move {
let mut tokens = self
.access_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
tokens.insert(token_hash, data);
Ok(())
})
}
fn get_access_token(&self, token_hash: &str) -> BoxFuture<'_, StorageResult<AccessTokenData>> {
let token_hash = token_hash.to_string();
Box::pin(async move {
let tokens = self
.access_tokens
.read()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let data = tokens
.get(&token_hash)
.ok_or_else(|| StorageError::NotFound(token_hash.clone()))?
.clone();
if Self::now_secs() > data.expires_at {
return Err(StorageError::Expired(token_hash));
}
Ok(data)
})
}
fn revoke_access_token(&self, token_hash: &str) -> BoxFuture<'_, StorageResult<()>> {
let token_hash = token_hash.to_string();
Box::pin(async move {
let mut tokens = self
.access_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
tokens.remove(&token_hash);
Ok(())
})
}
fn store_refresh_token(
&self,
token_hash: &str,
data: &RefreshTokenData,
) -> BoxFuture<'_, StorageResult<()>> {
let token_hash = token_hash.to_string();
let data = data.clone();
Box::pin(async move {
let mut tokens = self
.refresh_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
tokens.insert(token_hash, data);
Ok(())
})
}
fn get_refresh_token(
&self,
token_hash: &str,
) -> BoxFuture<'_, StorageResult<RefreshTokenData>> {
let token_hash = token_hash.to_string();
Box::pin(async move {
let tokens = self
.refresh_tokens
.read()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let data = tokens
.get(&token_hash)
.ok_or_else(|| StorageError::NotFound(token_hash.clone()))?
.clone();
if Self::now_secs() > data.expires_at {
return Err(StorageError::Expired(token_hash));
}
Ok(data)
})
}
fn mark_refresh_token_used(&self, token_hash: &str) -> BoxFuture<'_, StorageResult<()>> {
let token_hash = token_hash.to_string();
Box::pin(async move {
let mut tokens = self
.refresh_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
if let Some(data) = tokens.get_mut(&token_hash) {
data.used = true;
}
Ok(())
})
}
fn revoke_refresh_token_family(&self, family_id: &str) -> BoxFuture<'_, StorageResult<()>> {
let family_id = family_id.to_string();
Box::pin(async move {
let mut tokens = self
.refresh_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
tokens.retain(|_, v| v.family_id != family_id);
Ok(())
})
}
fn cleanup_expired(&self) -> BoxFuture<'_, StorageResult<u64>> {
Box::pin(async move {
let now = Self::now_secs();
let mut count = 0u64;
{
let mut codes = self
.authorization_codes
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let before = codes.len();
codes.retain(|_, v| v.expires_at > now);
count += (before - codes.len()) as u64;
}
{
let mut tokens = self
.access_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let before = tokens.len();
tokens.retain(|_, v| v.expires_at > now);
count += (before - tokens.len()) as u64;
}
{
let mut tokens = self
.refresh_tokens
.write()
.map_err(|e| StorageError::Backend(format!("Lock error: {}", e)))?;
let before = tokens.len();
tokens.retain(|_, v| v.expires_at > now);
count += (before - tokens.len()) as u64;
}
Ok(count)
})
}
}
/// Reference-counted, dynamically dispatched token store that can be shared freely.
pub type SharedTokenStore = Arc<dyn TokenStore + 'static>;
#[cfg(test)]
#[allow(deprecated)] // the tests intentionally exercise the deprecated MemoryTokenStore
mod tests {
    use super::*;

    /// Refresh-token record for `subject` in `family_id`, valid for one hour from `now`.
    fn refresh_token(subject: &str, family_id: &str, generation: u32, now: u64) -> RefreshTokenData {
        RefreshTokenData {
            subject: subject.to_string(),
            client_id: "test-client".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: now + 3600,
            issued_at: now,
            generation,
            family_id: family_id.to_string(),
            used: false,
        }
    }

    #[tokio::test]
    async fn test_memory_store_authorization_code() {
        let store = MemoryTokenStore::new();
        // Use the store's own clock: `js_sys::Date::now()` panics outside wasm32.
        let now = MemoryTokenStore::now_secs();
        let grant = AuthorizationCodeGrant {
            client_id: "test-client".to_string(),
            redirect_uri: "https://example.com/callback".to_string(),
            scopes: vec!["read".to_string()],
            code_challenge: Some("challenge".to_string()),
            code_challenge_method: Some("S256".to_string()),
            subject: "user123".to_string(),
            expires_at: now + 300,
            nonce: None,
            state: Some("state123".to_string()),
        };
        store
            .store_authorization_code("code_hash_123", &grant)
            .await
            .unwrap();
        let retrieved = store
            .consume_authorization_code("code_hash_123")
            .await
            .unwrap();
        assert_eq!(retrieved.client_id, "test-client");
        assert_eq!(retrieved.subject, "user123");
        // Codes are single-use: a second redemption finds nothing.
        assert!(matches!(
            store.consume_authorization_code("code_hash_123").await,
            Err(StorageError::NotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_memory_store_expired_access_token() {
        let store = MemoryTokenStore::new();
        let now = MemoryTokenStore::now_secs();
        let data = AccessTokenData {
            subject: "user123".to_string(),
            client_id: "test-client".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: now.saturating_sub(60),
            issued_at: now.saturating_sub(3600),
            refresh_token_hash: None,
        };
        store.store_access_token("stale", &data).await.unwrap();
        assert_eq!(
            store.get_access_token("stale").await.unwrap_err(),
            StorageError::Expired("stale".to_string())
        );
    }

    #[tokio::test]
    async fn test_memory_store_refresh_token_family_revocation() {
        let store = MemoryTokenStore::new();
        let now = MemoryTokenStore::now_secs();
        for generation in 0..3 {
            let data = refresh_token("user123", "family-abc", generation, now);
            store
                .store_refresh_token(&format!("token_{}", generation), &data)
                .await
                .unwrap();
        }
        let other = refresh_token("user456", "family-xyz", 0, now);
        store
            .store_refresh_token("token_other", &other)
            .await
            .unwrap();

        store
            .revoke_refresh_token_family("family-abc")
            .await
            .unwrap();

        for generation in 0..3 {
            assert!(matches!(
                store
                    .get_refresh_token(&format!("token_{}", generation))
                    .await,
                Err(StorageError::NotFound(_))
            ));
        }
        // Tokens from other families are untouched.
        assert!(store.get_refresh_token("token_other").await.is_ok());
    }
}