use async_trait::async_trait;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
/// A stored authentication credential (for example an OAuth2 access token or an API key).
///
/// `Debug` is implemented by hand so that `token` and `refresh_token` are never
/// printed: `{:?}` only reveals whether each secret is present
/// (`Some("<redacted>")` vs `None`). `Serialize` still emits the secrets in full,
/// since that is what persistence requires.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthCredential {
    /// Kind of credential, such as `"oauth2"` or `"api_key"`.
    pub credential_type: String,
    /// Primary secret (access token / API key), if any.
    pub token: Option<String>,
    /// Refresh token, if any (presumably used to renew `token` — confirm with producers).
    pub refresh_token: Option<String>,
    /// Expiry timestamp, if known. NOTE(review): looks like Unix seconds — confirm.
    pub expires_at: Option<u64>,
    /// Free-form, credential-type-specific extra data. Printed as-is by `Debug`.
    pub metadata: serde_json::Value,
}

impl std::fmt::Debug for AuthCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder shown instead of secret material.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("AuthCredential")
            .field("credential_type", &self.credential_type)
            .field("token", &self.token.as_ref().map(|_| REDACTED))
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("expires_at", &self.expires_at)
            .field("metadata", &self.metadata)
            .finish()
    }
}
#[derive(Debug, thiserror::Error)]
pub enum CredentialError {
#[error("Credential not found")]
NotFound,
#[error("{0}")]
Storage(String),
}
/// Asynchronous key/value store for [`AuthCredential`] records.
///
/// The `Send + Sync` bound lets one instance be shared between tasks, including
/// through a trait object such as `&dyn CredentialService`.
#[async_trait]
pub trait CredentialService: Send + Sync {
    /// Fetches the credential stored under `key`.
    ///
    /// Yields `Ok(None)` when nothing is stored under `key`.
    ///
    /// # Errors
    /// Returns a [`CredentialError`] if the backing store fails.
    async fn load_credential(
        &self,
        key: &str,
    ) -> Result<Option<AuthCredential>, CredentialError>;

    /// Stores `credential` under `key`, taking ownership of it.
    ///
    /// # Errors
    /// Returns a [`CredentialError`] if the backing store fails.
    async fn save_credential(
        &self,
        key: &str,
        credential: AuthCredential,
    ) -> Result<(), CredentialError>;

    /// Removes the credential stored under `key`.
    ///
    /// # Errors
    /// Returns a [`CredentialError`] if the backing store fails.
    async fn delete_credential(
        &self,
        key: &str,
    ) -> Result<(), CredentialError>;
}
pub struct InMemoryCredentialService {
inner: DashMap<String, AuthCredential>,
}
impl InMemoryCredentialService {
pub fn new() -> Self {
Self {
inner: DashMap::new(),
}
}
}
impl Default for InMemoryCredentialService {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl CredentialService for InMemoryCredentialService {
async fn load_credential(&self, key: &str) -> Result<Option<AuthCredential>, CredentialError> {
Ok(self.inner.get(key).map(|entry| entry.value().clone()))
}
async fn save_credential(
&self,
key: &str,
credential: AuthCredential,
) -> Result<(), CredentialError> {
self.inner.insert(key.to_string(), credential);
Ok(())
}
async fn delete_credential(&self, key: &str) -> Result<(), CredentialError> {
self.inner.remove(key);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully-populated OAuth2 credential shared by several tests.
    fn sample_credential() -> AuthCredential {
        AuthCredential {
            credential_type: "oauth2".to_string(),
            token: Some("access-token-123".to_string()),
            refresh_token: Some("refresh-456".to_string()),
            expires_at: Some(1_700_000_000),
            metadata: serde_json::json!({"scope": "read write"}),
        }
    }

    /// Asserts that every field of `cred` matches [`sample_credential`].
    fn assert_is_sample(cred: &AuthCredential) {
        assert_eq!(cred.credential_type, "oauth2");
        assert_eq!(cred.token.as_deref(), Some("access-token-123"));
        assert_eq!(cred.refresh_token.as_deref(), Some("refresh-456"));
        assert_eq!(cred.expires_at, Some(1_700_000_000));
        assert_eq!(cred.metadata, serde_json::json!({"scope": "read write"}));
    }

    #[tokio::test]
    async fn save_and_load() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("my-key", sample_credential())
            .await
            .unwrap();
        let loaded = svc
            .load_credential("my-key")
            .await
            .unwrap()
            .expect("credential should be present after save");
        assert_is_sample(&loaded);
    }

    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let svc = InMemoryCredentialService::new();
        let loaded = svc.load_credential("missing").await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn delete_credential() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("key", sample_credential())
            .await
            .unwrap();
        svc.delete_credential("key").await.unwrap();
        let loaded = svc.load_credential("key").await.unwrap();
        assert!(loaded.is_none());
    }

    /// Deleting a key that was never stored is treated as success.
    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let svc = InMemoryCredentialService::new();
        svc.delete_credential("missing").await.unwrap();
        assert!(svc.load_credential("missing").await.unwrap().is_none());
    }

    /// Operations on one key must not affect another.
    #[tokio::test]
    async fn keys_are_independent() {
        let svc = InMemoryCredentialService::default();
        svc.save_credential("a", sample_credential()).await.unwrap();
        assert!(svc.load_credential("b").await.unwrap().is_none());
        svc.delete_credential("b").await.unwrap();
        let a = svc
            .load_credential("a")
            .await
            .unwrap()
            .expect("deleting another key must not remove this one");
        assert_is_sample(&a);
    }

    #[tokio::test]
    async fn overwrite_credential() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("key", sample_credential())
            .await
            .unwrap();
        let updated = AuthCredential {
            credential_type: "api_key".to_string(),
            token: Some("new-token".to_string()),
            refresh_token: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        };
        svc.save_credential("key", updated).await.unwrap();
        let loaded = svc
            .load_credential("key")
            .await
            .unwrap()
            .expect("credential should be present after overwrite");
        assert_eq!(loaded.credential_type, "api_key");
        assert_eq!(loaded.token.as_deref(), Some("new-token"));
        // The old refresh token / expiry must not survive the overwrite.
        assert_eq!(loaded.refresh_token, None);
        assert_eq!(loaded.expires_at, None);
        assert_eq!(loaded.metadata, serde_json::json!({}));
    }

    /// Compile-time check that the trait can be used as a trait object and
    /// that the in-memory implementation coerces to it.
    #[test]
    fn credential_service_is_object_safe() {
        fn _assert(_: &dyn CredentialService) {}
        _assert(&InMemoryCredentialService::new());
    }

    #[test]
    fn auth_credential_serde_roundtrip() {
        let json = serde_json::to_string(&sample_credential()).unwrap();
        let parsed: AuthCredential = serde_json::from_str(&json).unwrap();
        assert_is_sample(&parsed);
    }
}