Skip to main content

llm/oauth/
credential_store.rs

1use async_trait::async_trait;
2use keyring::Entry;
3use oauth2::{AccessToken, RefreshToken, TokenResponse};
4use rmcp::transport::auth::{
5    AuthError, CredentialStore, OAuthTokenResponse, StoredCredentials, VendorExtraTokenFields,
6};
7use serde::{Deserialize, Serialize};
8use std::future::Future;
9use std::time::Duration;
10
11use super::OAuthError;
12
/// Keychain service name shared by every OAuth credential entry; the entry's
/// account name is the server/provider ID (see `keychain_entry`).
// NOTE(review): the "-v1" suffix presumably versions the stored format — changing
// this value makes all previously saved entries invisible to lookups.
const KEYCHAIN_SERVICE: &str = "aether-oauth-v1";
14
15/// Credential for an OAuth provider.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OAuthCredential {
18    pub client_id: String,
19    pub access_token: String,
20    pub refresh_token: Option<String>,
21    /// Unix timestamp in milliseconds when the token expires.
22    pub expires_at: Option<u64>,
23}
24
25/// Trait for loading and saving OAuth credentials, keyed by server/provider ID.
26///
27/// The default implementation (`OAuthCredentialStore`) uses the OS keychain.
28/// Tests can use an in-memory fake to avoid keychain popups.
29pub trait OAuthCredentialStorage: Send + Sync {
30    fn load_credential(
31        &self,
32        server_id: &str,
33    ) -> impl Future<Output = Result<Option<OAuthCredential>, OAuthError>> + Send;
34
35    fn save_credential(
36        &self,
37        server_id: &str,
38        credential: OAuthCredential,
39    ) -> impl Future<Output = Result<(), OAuthError>> + Send;
40
41    fn has_credential(&self, server_id: &str) -> bool;
42}
43
/// OAuth credential store that persists credentials in the OS keychain
/// and directly implements rmcp's `CredentialStore` trait.
///
/// Each server/provider ID maps to its own keychain entry. `Debug` is derived
/// because the only field is the server ID, which is not secret.
#[derive(Clone, Debug, Default)]
pub struct OAuthCredentialStore {
    /// Server/provider ID; used as the keychain account name for this store.
    server_id: String,
}
52
53impl OAuthCredentialStore {
54    /// Create a new store for the given server/provider ID.
55    pub fn new(server_id: &str) -> Self {
56        Self { server_id: server_id.to_string() }
57    }
58
59    /// Load the raw `OAuthCredential` for this store's server ID.
60    pub async fn load_credential(&self) -> Result<Option<OAuthCredential>, OAuthError> {
61        let store = self.clone();
62        spawn_blocking(move || store.load_sync()).await
63    }
64
65    /// Save a raw `OAuthCredential` directly, keyed by this store's server ID.
66    pub async fn save_credential(&self, credential: OAuthCredential) -> Result<(), OAuthError> {
67        let store = self.clone();
68        spawn_blocking(move || store.save_sync(&credential)).await
69    }
70
71    /// Check synchronously whether credentials exist for a given server ID.
72    pub fn has_credential(server_id: &str) -> bool {
73        keychain_entry(server_id).ok().and_then(|e| e.get_secret().ok()).is_some()
74    }
75
76    fn load_sync(&self) -> Result<Option<OAuthCredential>, OAuthError> {
77        load_from_keychain(&self.server_id)
78    }
79
80    fn save_sync(&self, credential: &OAuthCredential) -> Result<(), OAuthError> {
81        save_to_keychain(&self.server_id, credential)
82    }
83
84    fn delete_sync(&self) -> Result<(), OAuthError> {
85        let entry = keychain_entry(&self.server_id)?;
86        match entry.delete_credential() {
87            Ok(()) | Err(keyring::Error::NoEntry) => Ok(()),
88            Err(err) => Err(err.into()),
89        }
90    }
91}
92
93impl OAuthCredentialStorage for OAuthCredentialStore {
94    async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
95        let server_id = server_id.to_string();
96        spawn_blocking(move || load_from_keychain(&server_id)).await
97    }
98
99    async fn save_credential(&self, server_id: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
100        let server_id = server_id.to_string();
101        spawn_blocking(move || save_to_keychain(&server_id, &credential)).await
102    }
103
104    fn has_credential(&self, server_id: &str) -> bool {
105        keychain_entry(server_id).ok().and_then(|e| e.get_secret().ok()).is_some()
106    }
107}
108
109#[async_trait]
110impl CredentialStore for OAuthCredentialStore {
111    async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
112        let cred = self.load_credential().await.map_err(|e| AuthError::InternalError(e.to_string()))?;
113
114        Ok(cred.map(|c| {
115            let token_response = build_token_response(&c);
116            build_stored_credentials(&c.client_id, Some(&token_response))
117        }))
118    }
119
120    async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
121        let token = credentials
122            .token_response
123            .ok_or_else(|| AuthError::InternalError("No token response to save".to_string()))?;
124
125        let expires_at = token.expires_in().map(|duration| {
126            let now_ms = u64::try_from(
127                std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
128            )
129            .unwrap_or(u64::MAX);
130            let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
131            now_ms.saturating_add(duration_ms)
132        });
133
134        let credential = OAuthCredential {
135            client_id: credentials.client_id,
136            access_token: token.access_token().secret().clone(),
137            refresh_token: token.refresh_token().map(|t| t.secret().clone()),
138            expires_at,
139        };
140
141        self.save_credential(credential).await.map_err(|e| AuthError::InternalError(e.to_string()))
142    }
143
144    async fn clear(&self) -> Result<(), AuthError> {
145        let store = self.clone();
146        spawn_blocking(move || store.delete_sync()).await.map_err(|e| AuthError::InternalError(e.to_string()))
147    }
148}
149
150fn keychain_entry(server_id: &str) -> Result<Entry, OAuthError> {
151    Ok(Entry::new(KEYCHAIN_SERVICE, server_id)?)
152}
153
154fn load_from_keychain(server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
155    let entry = keychain_entry(server_id)?;
156    match entry.get_secret() {
157        Ok(blob) => serde_json::from_slice(&blob)
158            .map(Some)
159            .map_err(|err| OAuthError::CredentialStore(format!("invalid credential: {err}"))),
160        Err(keyring::Error::NoEntry) => Ok(None),
161        Err(err) => Err(err.into()),
162    }
163}
164
165fn save_to_keychain(server_id: &str, credential: &OAuthCredential) -> Result<(), OAuthError> {
166    let entry = keychain_entry(server_id)?;
167    let blob = serde_json::to_vec(credential)
168        .map_err(|err| OAuthError::CredentialStore(format!("failed to serialize credential: {err}")))?;
169    entry.set_secret(&blob)?;
170    Ok(())
171}
172
173async fn spawn_blocking<T: Send + 'static>(
174    f: impl FnOnce() -> Result<T, OAuthError> + Send + 'static,
175) -> Result<T, OAuthError> {
176    tokio::task::spawn_blocking(f)
177        .await
178        .map_err(|err| OAuthError::CredentialStore(format!("credential task failed: {err}")))?
179}
180
181/// Construct a `StoredCredentials` via serde deserialization.
182///
183/// The upstream struct is `#[non_exhaustive]` with no constructor, so this is
184/// the only way to build one from outside the crate.
185fn build_stored_credentials(client_id: &str, token_response: Option<&OAuthTokenResponse>) -> StoredCredentials {
186    // granted_scopes and token_received_at have #[serde(default)] so we can omit them.
187    serde_json::from_value(serde_json::json!({
188        "client_id": client_id,
189        "token_response": token_response,
190    }))
191    .expect("StoredCredentials deserialization from known-good fields cannot fail")
192}
193
194fn build_token_response(cred: &OAuthCredential) -> OAuthTokenResponse {
195    let mut response = OAuthTokenResponse::new(
196        AccessToken::new(cred.access_token.clone()),
197        oauth2::basic::BasicTokenType::Bearer,
198        VendorExtraTokenFields::default(),
199    );
200
201    if let Some(ref refresh) = cred.refresh_token {
202        response.set_refresh_token(Some(RefreshToken::new(refresh.clone())));
203    }
204
205    if let Some(expires_at_millis) = cred.expires_at {
206        let now_millis = u64::try_from(
207            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
208        )
209        .unwrap_or(u64::MAX);
210
211        if expires_at_millis > now_millis {
212            response.set_expires_in(Some(&Duration::from_millis(expires_at_millis - now_millis)));
213        }
214    }
215
216    response
217}