Skip to main content

aether_auth/
fake.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use async_trait::async_trait;
5
6use crate::credential::{OAuthCredential, OAuthCredentialStorage};
7use crate::error::OAuthError;
8
9#[derive(Default)]
10pub struct FakeOAuthCredentialStore {
11    credentials: Mutex<HashMap<String, OAuthCredential>>,
12}
13
14impl FakeOAuthCredentialStore {
15    pub fn new() -> Self {
16        Self { credentials: Mutex::new(HashMap::new()) }
17    }
18
19    pub fn with_credential(self, key: &str, credential: OAuthCredential) -> Self {
20        self.credentials.lock().unwrap().insert(key.to_string(), credential);
21        self
22    }
23}
24
25#[async_trait]
26impl OAuthCredentialStorage for FakeOAuthCredentialStore {
27    async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
28        Ok(self.credentials.lock().unwrap().get(server_id).cloned())
29    }
30
31    async fn save_credential(&self, key: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
32        self.credentials.lock().unwrap().insert(key.to_string(), credential);
33        Ok(())
34    }
35
36    async fn delete_credential(&self, key: &str) -> Result<(), OAuthError> {
37        self.credentials.lock().unwrap().remove(key);
38        Ok(())
39    }
40
41    fn has_credential(&self, key: &str) -> bool {
42        self.credentials.lock().unwrap().contains_key(key)
43    }
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a credential with the given client id and access token and no
    /// refresh token, expiry, or scopes.
    fn credential(client_id: &str, access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: client_id.to_string(),
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at: None,
            granted_scopes: Vec::new(),
        }
    }

    #[tokio::test]
    async fn load_returns_none_when_empty() {
        let store = FakeOAuthCredentialStore::new();
        let result = store.load_credential("unknown").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let store = FakeOAuthCredentialStore::new();
        let cred = OAuthCredential {
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
            ..credential("client_1", "tok_abc")
        };

        // The assertions compare against literals, so the credential can be
        // moved into the store without cloning.
        store.save_credential("my-server", cred).await.unwrap();

        let loaded = store
            .load_credential("my-server")
            .await
            .unwrap()
            .expect("should find saved credential");
        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
        assert!(loaded.granted_scopes.is_empty());
    }

    #[tokio::test]
    async fn save_overwrites_existing_key() {
        let store = FakeOAuthCredentialStore::new();
        store.save_credential("k", credential("old", "t1")).await.unwrap();
        store.save_credential("k", credential("new", "t2")).await.unwrap();

        let loaded = store
            .load_credential("k")
            .await
            .unwrap()
            .expect("credential should exist after save");
        assert_eq!(loaded.client_id, "new");
        assert_eq!(loaded.access_token, "t2");
    }

    #[tokio::test]
    async fn delete_removes_credential() {
        let store = FakeOAuthCredentialStore::new();
        store.save_credential("x", credential("c", "t")).await.unwrap();
        assert!(store.has_credential("x"));

        store.delete_credential("x").await.unwrap();
        assert!(!store.has_credential("x"));
    }

    #[tokio::test]
    async fn delete_missing_key_is_ok_and_leaves_others() {
        let store = FakeOAuthCredentialStore::new().with_credential("keep", credential("c", "t"));

        store.delete_credential("absent").await.unwrap();
        assert!(store.has_credential("keep"));
    }

    #[test]
    fn has_credential_reflects_state() {
        let store = FakeOAuthCredentialStore::new().with_credential("present", credential("c", "t"));

        assert!(store.has_credential("present"));
        assert!(!store.has_credential("absent"));
    }
}