1use async_trait::async_trait;
4use dashmap::DashMap;
5use serde::{Deserialize, Serialize};
6
/// A stored authentication credential that can be serialized and deserialized
/// with serde.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes the contents of
/// `token` and `refresh_token`. Avoid logging this struct if those are secrets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthCredential {
    /// Free-form label for the kind of credential. The tests in this file use
    /// `"oauth2"` and `"api_key"`; no fixed set of values is enforced here.
    pub credential_type: String,
    /// Primary token value, if any.
    /// Presumably the access token or API key — confirm with callers.
    pub token: Option<String>,
    /// Refresh token, if any.
    pub refresh_token: Option<String>,
    /// Expiry, if any.
    /// NOTE(review): looks like a Unix timestamp in seconds, but the unit is
    /// not established anywhere in this file — confirm.
    pub expires_at: Option<u64>,
    /// Arbitrary JSON attached to the credential.
    pub metadata: serde_json::Value,
}
21
/// Errors returned by [`CredentialService`] operations.
#[derive(Debug, thiserror::Error)]
pub enum CredentialError {
    /// The requested credential does not exist.
    ///
    /// NOTE(review): nothing visible in this file constructs this variant; the
    /// in-memory service reports a missing key as `Ok(None)` on load and
    /// `Ok(())` on delete.
    #[error("Credential not found")]
    NotFound,
    /// A failure described by the contained message. The message is the whole
    /// `Display` output, with no prefix added.
    /// Presumably produced by persistent storage backends — confirm.
    #[error("{0}")]
    Storage(String),
}
32
/// Asynchronous, key-addressed storage for [`AuthCredential`] values.
///
/// Implementors must be `Send + Sync`, and every method takes `&self`.
#[async_trait]
pub trait CredentialService: Send + Sync {
    /// Looks up the credential stored under `key`.
    ///
    /// The return type lets an absent credential be reported as `Ok(None)`;
    /// [`InMemoryCredentialService`] does exactly that.
    ///
    /// # Errors
    ///
    /// Implementation-defined; see [`CredentialError`].
    async fn load_credential(&self, key: &str) -> Result<Option<AuthCredential>, CredentialError>;

    /// Stores `credential` under `key`, taking ownership of the credential.
    ///
    /// [`InMemoryCredentialService`] replaces any credential already stored
    /// under the same key.
    ///
    /// # Errors
    ///
    /// Implementation-defined; see [`CredentialError`].
    async fn save_credential(
        &self,
        key: &str,
        credential: AuthCredential,
    ) -> Result<(), CredentialError>;

    /// Removes the credential stored under `key`.
    ///
    /// NOTE(review): whether a missing key is an error is not specified by the
    /// trait. [`InMemoryCredentialService`] returns `Ok(())` in that case.
    ///
    /// # Errors
    ///
    /// Implementation-defined; see [`CredentialError`].
    async fn delete_credential(&self, key: &str) -> Result<(), CredentialError>;
}
49
/// [`CredentialService`] implementation that keeps credentials in a
/// [`DashMap`] held in memory.
pub struct InMemoryCredentialService {
    /// Stored credentials, keyed by the caller-supplied key string.
    inner: DashMap<String, AuthCredential>,
}
54
55impl InMemoryCredentialService {
56 pub fn new() -> Self {
58 Self {
59 inner: DashMap::new(),
60 }
61 }
62}
63
64impl Default for InMemoryCredentialService {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70#[async_trait]
71impl CredentialService for InMemoryCredentialService {
72 async fn load_credential(&self, key: &str) -> Result<Option<AuthCredential>, CredentialError> {
73 Ok(self.inner.get(key).map(|entry| entry.value().clone()))
74 }
75
76 async fn save_credential(
77 &self,
78 key: &str,
79 credential: AuthCredential,
80 ) -> Result<(), CredentialError> {
81 self.inner.insert(key.to_string(), credential);
82 Ok(())
83 }
84
85 async fn delete_credential(&self, key: &str) -> Result<(), CredentialError> {
86 self.inner.remove(key);
87 Ok(())
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully populated OAuth2-style credential shared by the tests.
    fn sample_credential() -> AuthCredential {
        AuthCredential {
            credential_type: "oauth2".to_string(),
            token: Some("access-token-123".to_string()),
            refresh_token: Some("refresh-456".to_string()),
            expires_at: Some(1700000000),
            metadata: serde_json::json!({"scope": "read write"}),
        }
    }

    /// A saved credential can be read back under the same key.
    #[tokio::test]
    async fn save_and_load() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("my-key", sample_credential())
            .await
            .unwrap();

        let loaded = svc
            .load_credential("my-key")
            .await
            .unwrap()
            .expect("credential saved under `my-key` should be loadable");
        assert_eq!(loaded.credential_type, "oauth2");
        assert_eq!(loaded.token.as_deref(), Some("access-token-123"));
    }

    /// Loading a key that was never saved yields `Ok(None)`, not an error.
    #[tokio::test]
    async fn load_nonexistent_returns_none() {
        let svc = InMemoryCredentialService::new();
        let loaded = svc.load_credential("missing").await.unwrap();
        assert!(loaded.is_none());
    }

    /// After deletion the key no longer resolves to a credential.
    #[tokio::test]
    async fn delete_credential() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("key", sample_credential())
            .await
            .unwrap();
        svc.delete_credential("key").await.unwrap();

        let loaded = svc.load_credential("key").await.unwrap();
        assert!(loaded.is_none());
    }

    /// Saving twice under one key keeps only the most recent credential.
    #[tokio::test]
    async fn overwrite_credential() {
        let svc = InMemoryCredentialService::new();
        svc.save_credential("key", sample_credential())
            .await
            .unwrap();

        let updated = AuthCredential {
            credential_type: "api_key".to_string(),
            token: Some("new-token".to_string()),
            refresh_token: None,
            expires_at: None,
            metadata: serde_json::json!({}),
        };
        svc.save_credential("key", updated).await.unwrap();

        let loaded = svc
            .load_credential("key")
            .await
            .unwrap()
            .expect("overwritten credential should still be present");
        assert_eq!(loaded.credential_type, "api_key");
        assert_eq!(loaded.token.as_deref(), Some("new-token"));
        assert_eq!(loaded.refresh_token, None);
        assert_eq!(loaded.expires_at, None);
    }

    /// The trait must stay usable as `dyn CredentialService`. Coercing the
    /// concrete in-memory service checks both the trait and the implementation.
    #[test]
    fn credential_service_is_object_safe() {
        fn assert_object_safe(_: &dyn CredentialService) {}
        assert_object_safe(&InMemoryCredentialService::new());
    }

    /// Every field survives a JSON serialize/deserialize roundtrip.
    #[test]
    fn auth_credential_serde_roundtrip() {
        let cred = sample_credential();
        let json = serde_json::to_string(&cred).unwrap();
        let parsed: AuthCredential = serde_json::from_str(&json).unwrap();

        // `AuthCredential` does not derive `PartialEq`, so compare field by field.
        assert_eq!(parsed.credential_type, cred.credential_type);
        assert_eq!(parsed.token, cred.token);
        assert_eq!(parsed.refresh_token, cred.refresh_token);
        assert_eq!(parsed.expires_at, cred.expires_at);
        assert_eq!(parsed.metadata, cred.metadata);
    }
}