1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use async_trait::async_trait;
5
6use crate::credential::{OAuthCredential, OAuthCredentialStorage};
7use crate::error::OAuthError;
8
9#[derive(Default)]
10pub struct FakeOAuthCredentialStore {
11 credentials: Mutex<HashMap<String, OAuthCredential>>,
12}
13
14impl FakeOAuthCredentialStore {
15 pub fn new() -> Self {
16 Self { credentials: Mutex::new(HashMap::new()) }
17 }
18
19 pub fn with_credential(self, key: &str, credential: OAuthCredential) -> Self {
20 self.credentials.lock().unwrap().insert(key.to_string(), credential);
21 self
22 }
23}
24
25#[async_trait]
26impl OAuthCredentialStorage for FakeOAuthCredentialStore {
27 async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
28 Ok(self.credentials.lock().unwrap().get(server_id).cloned())
29 }
30
31 async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
32 self.credentials.lock().unwrap().insert(key.to_string(), credential);
33 Ok(())
34 }
35
36 async fn delete_credential(&self, key: &str) -> Result<(), OAuthError> {
37 self.credentials.lock().unwrap().remove(key);
38 Ok(())
39 }
40
41 fn has_credential(&self, key: &str) -> bool {
42 self.credentials.lock().unwrap().contains_key(key)
43 }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal credential with no refresh token, expiry, or scopes;
    /// tests override individual fields with struct-update syntax as needed.
    fn credential(client_id: &str, access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: client_id.to_string(),
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: Vec::new(),
        }
    }

    #[tokio::test]
    async fn load_returns_none_when_empty() {
        let store = FakeOAuthCredentialStore::new();
        let result = store.load_credential("unknown").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let store = FakeOAuthCredentialStore::new();
        let cred = OAuthCredential {
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
            ..credential("client_1", "tok_abc")
        };

        // `cred` is not used after this point, so it is moved rather than cloned.
        store.save_credential("my-server", cred).await.unwrap();

        let loaded = store
            .load_credential("my-server")
            .await
            .unwrap()
            .expect("should find saved credential");
        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
        assert!(loaded.granted_scopes.is_empty());
    }

    #[tokio::test]
    async fn save_overwrites_existing_key() {
        let store = FakeOAuthCredentialStore::new();
        store.save_credential("x", credential("old", "t1")).await.unwrap();
        store.save_credential("x", credential("new", "t2")).await.unwrap();

        let loaded = store
            .load_credential("x")
            .await
            .unwrap()
            .expect("credential should still be present");
        assert_eq!(loaded.client_id, "new");
        assert_eq!(loaded.access_token, "t2");
    }

    #[tokio::test]
    async fn delete_removes_credential() {
        let store = FakeOAuthCredentialStore::new();
        store.save_credential("x", credential("c", "t")).await.unwrap();
        assert!(store.has_credential("x"));

        store.delete_credential("x").await.unwrap();
        assert!(!store.has_credential("x"));
        assert!(store.load_credential("x").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_of_missing_key_is_ok() {
        let store = FakeOAuthCredentialStore::new();
        store.delete_credential("never-saved").await.unwrap();
        assert!(!store.has_credential("never-saved"));
    }

    #[test]
    fn has_credential_reflects_state() {
        let store =
            FakeOAuthCredentialStore::new().with_credential("present", credential("c", "t"));

        assert!(store.has_credential("present"));
        assert!(!store.has_credential("absent"));
    }
}