llm/testing/
fake_credential_store.rs1use std::collections::HashMap;
2use std::sync::Mutex;
3
4use crate::oauth::OAuthError;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage};
6
7#[derive(Default)]
8pub struct FakeOAuthCredentialStore {
9 credentials: Mutex<HashMap<String, OAuthCredential>>,
10}
11
12impl FakeOAuthCredentialStore {
13 pub fn new() -> Self {
14 Self { credentials: Mutex::new(HashMap::new()) }
15 }
16
17 pub fn with_credential(self, server_id: &str, credential: OAuthCredential) -> Self {
18 self.credentials.lock().unwrap().insert(server_id.to_string(), credential);
19 self
20 }
21}
22
23impl OAuthCredentialStorage for FakeOAuthCredentialStore {
24 async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
25 Ok(self.credentials.lock().unwrap().get(server_id).cloned())
26 }
27
28 async fn save_credential(&self, server_id: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
29 self.credentials.lock().unwrap().insert(server_id.to_string(), credential);
30 Ok(())
31 }
32
33 fn has_credential(&self, server_id: &str) -> bool {
34 self.credentials.lock().unwrap().contains_key(server_id)
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal credential with no refresh token and no expiry.
    fn credential(client_id: &str, access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: client_id.to_string(),
            access_token: access_token.to_string(),
            refresh_token: None,
            expires_at: None,
        }
    }

    // Uses the same `#[tokio::test]` harness as the other async tests
    // instead of building a runtime by hand.
    #[tokio::test]
    async fn load_returns_none_when_empty() {
        let store = FakeOAuthCredentialStore::new();
        let result = store.load_credential("unknown").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn save_then_load_round_trips() {
        let store = FakeOAuthCredentialStore::new();
        let cred = OAuthCredential {
            client_id: "client_1".to_string(),
            access_token: "tok_abc".to_string(),
            refresh_token: Some("ref_xyz".to_string()),
            expires_at: Some(9_999_999_999_999),
        };

        // `cred` is not used again, so move it rather than cloning.
        store.save_credential("my-server", cred).await.unwrap();

        let loaded = store
            .load_credential("my-server")
            .await
            .unwrap()
            .expect("should find saved credential");
        assert_eq!(loaded.client_id, "client_1");
        assert_eq!(loaded.access_token, "tok_abc");
        assert_eq!(loaded.refresh_token.as_deref(), Some("ref_xyz"));
        assert_eq!(loaded.expires_at, Some(9_999_999_999_999));
    }

    #[tokio::test]
    async fn save_overwrites_existing_credential() {
        let store = FakeOAuthCredentialStore::new().with_credential("srv", credential("c", "old"));

        store.save_credential("srv", credential("c", "new")).await.unwrap();

        let loaded = store.load_credential("srv").await.unwrap().expect("credential should exist");
        assert_eq!(loaded.access_token, "new");
    }

    #[tokio::test]
    async fn has_credential_sees_saved_credential() {
        let store = FakeOAuthCredentialStore::new();
        assert!(!store.has_credential("srv"));

        store.save_credential("srv", credential("c", "t")).await.unwrap();

        assert!(store.has_credential("srv"));
    }

    #[test]
    fn has_credential_reflects_state() {
        let store = FakeOAuthCredentialStore::new().with_credential("present", credential("c", "t"));

        assert!(store.has_credential("present"));
        assert!(!store.has_credential("absent"));
    }
}