Skip to main content

llm/testing/
fake_credential_store.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use crate::oauth::OAuthError;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage};
6
7#[derive(Default)]
8pub struct FakeOAuthCredentialStore {
9    credentials: Mutex<HashMap<String, OAuthCredential>>,
10}
11
12impl FakeOAuthCredentialStore {
13    pub fn new() -> Self {
14        Self { credentials: Mutex::new(HashMap::new()) }
15    }
16
17    pub fn with_credential(self, server_id: &str, credential: OAuthCredential) -> Self {
18        self.credentials.lock().unwrap().insert(server_id.to_string(), credential);
19        self
20    }
21}
22
23impl OAuthCredentialStorage for FakeOAuthCredentialStore {
24    async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
25        Ok(self.credentials.lock().unwrap().get(server_id).cloned())
26    }
27
28    async fn save_credential(&self, server_id: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
29        self.credentials.lock().unwrap().insert(server_id.to_string(), credential);
30        Ok(())
31    }
32
33    fn has_credential(&self, server_id: &str) -> bool {
34        self.credentials.lock().unwrap().contains_key(server_id)
35    }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading an id that was never saved yields `Ok(None)`.
    #[tokio::test]
    async fn load_returns_none_when_empty() {
        let store = FakeOAuthCredentialStore::new();
        let result = store.load_credential("unknown").await;
        assert!(result.unwrap().is_none());
    }

    /// A credential written with `save_credential` comes back unchanged from
    /// `load_credential`, field by field.
    #[tokio::test]
    async fn save_then_load_round_trips() {
        let store = FakeOAuthCredentialStore::new();
        let cred = OAuthCredential {
            client_id: "client_1".to_string(),
            access_token: "tok_abc".to_string(),
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
        };

        // `cred` is not needed afterwards, so it is moved into the store rather than cloned.
        store.save_credential("my-server", cred).await.unwrap();

        let loaded = store.load_credential("my-server").await.unwrap().expect("should find saved credential");
        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
    }

    /// `has_credential` is true only for ids that were seeded into the store.
    #[test]
    fn has_credential_reflects_state() {
        let store = FakeOAuthCredentialStore::new().with_credential(
            "present",
            OAuthCredential {
                client_id: "c".to_string(),
                access_token: "t".to_string(),
                refresh_token: None,
                expires_at: None,
            },
        );

        assert!(store.has_credential("present"));
        assert!(!store.has_credential("absent"));
    }
}