// llm/oauth/credential_store.rs
use async_trait::async_trait;
use keyring::Entry;
use oauth2::{AccessToken, RefreshToken, TokenResponse};
use rmcp::transport::auth::{
    AuthError, CredentialStore, OAuthTokenResponse, StoredCredentials, VendorExtraTokenFields,
};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::time::Duration;

use super::OAuthError;

/// Keyring service name under which every server's OAuth credential is stored;
/// the per-server `server_id` is used as the entry's account name.
///
/// The `-v1` suffix presumably versions the stored format. Changing this value
/// makes previously saved entries unreachable.
const KEYCHAIN_SERVICE: &str = "aether-oauth-v1";

15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OAuthCredential {
18 pub client_id: String,
19 pub access_token: String,
20 pub refresh_token: Option<String>,
21 pub expires_at: Option<u64>,
23}
24
25pub trait OAuthCredentialStorage: Send + Sync {
30 fn load_credential(
31 &self,
32 server_id: &str,
33 ) -> impl Future<Output = Result<Option<OAuthCredential>, OAuthError>> + Send;
34
35 fn save_credential(
36 &self,
37 server_id: &str,
38 credential: OAuthCredential,
39 ) -> impl Future<Output = Result<(), OAuthError>> + Send;
40
41 fn has_credential(&self, server_id: &str) -> bool;
42}
43
/// Keyring-backed credential store bound to a single server id.
///
/// The only state is the id, so cloning it to move into a blocking task is cheap.
#[derive(Default, Clone)]
pub struct OAuthCredentialStore {
    /// Account name of the keyring entry this store reads, writes and deletes.
    server_id: String,
}

53impl OAuthCredentialStore {
54 pub fn new(server_id: &str) -> Self {
56 Self { server_id: server_id.to_string() }
57 }
58
59 pub async fn load_credential(&self) -> Result<Option<OAuthCredential>, OAuthError> {
61 let store = self.clone();
62 spawn_blocking(move || store.load_sync()).await
63 }
64
65 pub async fn save_credential(&self, credential: OAuthCredential) -> Result<(), OAuthError> {
67 let store = self.clone();
68 spawn_blocking(move || store.save_sync(&credential)).await
69 }
70
71 pub fn has_credential(server_id: &str) -> bool {
73 keychain_entry(server_id).ok().and_then(|e| e.get_secret().ok()).is_some()
74 }
75
76 fn load_sync(&self) -> Result<Option<OAuthCredential>, OAuthError> {
77 load_from_keychain(&self.server_id)
78 }
79
80 fn save_sync(&self, credential: &OAuthCredential) -> Result<(), OAuthError> {
81 save_to_keychain(&self.server_id, credential)
82 }
83
84 fn delete_sync(&self) -> Result<(), OAuthError> {
85 let entry = keychain_entry(&self.server_id)?;
86 match entry.delete_credential() {
87 Ok(()) | Err(keyring::Error::NoEntry) => Ok(()),
88 Err(err) => Err(err.into()),
89 }
90 }
91}
92
93impl OAuthCredentialStorage for OAuthCredentialStore {
94 async fn load_credential(&self, server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
95 let server_id = server_id.to_string();
96 spawn_blocking(move || load_from_keychain(&server_id)).await
97 }
98
99 async fn save_credential(&self, server_id: &str, credential: OAuthCredential) -> Result<(), OAuthError> {
100 let server_id = server_id.to_string();
101 spawn_blocking(move || save_to_keychain(&server_id, &credential)).await
102 }
103
104 fn has_credential(&self, server_id: &str) -> bool {
105 keychain_entry(server_id).ok().and_then(|e| e.get_secret().ok()).is_some()
106 }
107}
108
109#[async_trait]
110impl CredentialStore for OAuthCredentialStore {
111 async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
112 let cred = self.load_credential().await.map_err(|e| AuthError::InternalError(e.to_string()))?;
113
114 Ok(cred.map(|c| {
115 let token_response = build_token_response(&c);
116 build_stored_credentials(&c.client_id, Some(&token_response))
117 }))
118 }
119
120 async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
121 let token = credentials
122 .token_response
123 .ok_or_else(|| AuthError::InternalError("No token response to save".to_string()))?;
124
125 let expires_at = token.expires_in().map(|duration| {
126 let now_ms = u64::try_from(
127 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
128 )
129 .unwrap_or(u64::MAX);
130 let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
131 now_ms.saturating_add(duration_ms)
132 });
133
134 let credential = OAuthCredential {
135 client_id: credentials.client_id,
136 access_token: token.access_token().secret().clone(),
137 refresh_token: token.refresh_token().map(|t| t.secret().clone()),
138 expires_at,
139 };
140
141 self.save_credential(credential).await.map_err(|e| AuthError::InternalError(e.to_string()))
142 }
143
144 async fn clear(&self) -> Result<(), AuthError> {
145 let store = self.clone();
146 spawn_blocking(move || store.delete_sync()).await.map_err(|e| AuthError::InternalError(e.to_string()))
147 }
148}
149
150fn keychain_entry(server_id: &str) -> Result<Entry, OAuthError> {
151 Ok(Entry::new(KEYCHAIN_SERVICE, server_id)?)
152}
153
154fn load_from_keychain(server_id: &str) -> Result<Option<OAuthCredential>, OAuthError> {
155 let entry = keychain_entry(server_id)?;
156 match entry.get_secret() {
157 Ok(blob) => serde_json::from_slice(&blob)
158 .map(Some)
159 .map_err(|err| OAuthError::CredentialStore(format!("invalid credential: {err}"))),
160 Err(keyring::Error::NoEntry) => Ok(None),
161 Err(err) => Err(err.into()),
162 }
163}
164
165fn save_to_keychain(server_id: &str, credential: &OAuthCredential) -> Result<(), OAuthError> {
166 let entry = keychain_entry(server_id)?;
167 let blob = serde_json::to_vec(credential)
168 .map_err(|err| OAuthError::CredentialStore(format!("failed to serialize credential: {err}")))?;
169 entry.set_secret(&blob)?;
170 Ok(())
171}
172
173async fn spawn_blocking<T: Send + 'static>(
174 f: impl FnOnce() -> Result<T, OAuthError> + Send + 'static,
175) -> Result<T, OAuthError> {
176 tokio::task::spawn_blocking(f)
177 .await
178 .map_err(|err| OAuthError::CredentialStore(format!("credential task failed: {err}")))?
179}
180
181fn build_stored_credentials(client_id: &str, token_response: Option<&OAuthTokenResponse>) -> StoredCredentials {
186 serde_json::from_value(serde_json::json!({
188 "client_id": client_id,
189 "token_response": token_response,
190 }))
191 .expect("StoredCredentials deserialization from known-good fields cannot fail")
192}
193
194fn build_token_response(cred: &OAuthCredential) -> OAuthTokenResponse {
195 let mut response = OAuthTokenResponse::new(
196 AccessToken::new(cred.access_token.clone()),
197 oauth2::basic::BasicTokenType::Bearer,
198 VendorExtraTokenFields::default(),
199 );
200
201 if let Some(ref refresh) = cred.refresh_token {
202 response.set_refresh_token(Some(RefreshToken::new(refresh.clone())));
203 }
204
205 if let Some(expires_at_millis) = cred.expires_at {
206 let now_millis = u64::try_from(
207 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
208 )
209 .unwrap_or(u64::MAX);
210
211 if expires_at_millis > now_millis {
212 response.set_expires_in(Some(&Duration::from_millis(expires_at_millis - now_millis)));
213 }
214 }
215
216 response
217}