1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13use tokio::sync::Mutex;
14
15use synaptic_core::SynapticError;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct McpOAuthConfig {
24 pub client_id: String,
26 #[serde(default)]
28 pub client_secret: Option<String>,
29 pub token_url: String,
31 #[serde(default)]
33 pub authorize_url: Option<String>,
34 #[serde(default)]
36 pub scopes: Vec<String>,
37 #[serde(default = "default_pkce")]
39 pub pkce: bool,
40}
41
/// Serde default for `McpOAuthConfig::pkce`: PKCE is enabled unless the
/// configuration explicitly sets `pkce` to `false`.
fn default_pkce() -> bool {
    true
}
45
/// The fields of the token endpoint's JSON response that this module reads.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token`
/// verbatim; avoid logging this type with `{:?}`.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token string issued by the server.
    access_token: String,
    /// Token lifetime in seconds; `None` when the field is absent.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Refresh token, when the server issued one.
    #[serde(default)]
    refresh_token: Option<String>,
}
58
/// A token held in `OAuthTokenManager`'s cache together with its local expiry.
///
/// NOTE(review): the derived `Debug` prints both tokens verbatim; avoid
/// logging this type with `{:?}`.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Access token handed out by `get_token`.
    access_token: String,
    /// Instant after which the cached access token is no longer returned
    /// (`exchange_token` already subtracts a safety margin when computing it).
    expires_at: Instant,
    /// Refresh token used to renew the access token once it expires, if any.
    refresh_token: Option<String>,
}
69
70pub fn generate_code_verifier(seed: &str) -> String {
80 let now = std::time::SystemTime::now()
81 .duration_since(std::time::UNIX_EPOCH)
82 .unwrap_or_default();
83 let input = format!("{}{}{}", seed, now.as_nanos(), std::process::id());
84 let hash = Sha256::digest(input.as_bytes());
85 base64_url_encode(&hash)
86}
87
88pub fn generate_code_challenge(verifier: &str) -> String {
92 let hash = Sha256::digest(verifier.as_bytes());
93 base64_url_encode(&hash)
94}
95
96fn base64_url_encode(data: &[u8]) -> String {
98 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
99 use base64::Engine;
100 URL_SAFE_NO_PAD.encode(data)
101}
102
/// Encodes `s` for an `application/x-www-form-urlencoded` body.
///
/// Encoding rules, applied per UTF-8 byte:
/// - ASCII alphanumerics and `-`, `_`, `.`, `~` are copied unchanged.
/// - A space becomes `+`.
/// - Every other byte becomes `%XX`, using uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else if byte == b' ' {
            encoded.push('+');
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
120
/// Obtains OAuth access tokens from the configured token endpoint and caches
/// them until they expire.
///
/// Tokens are requested with the `client_credentials` grant. When the cached
/// token expires and carries a refresh token, the `refresh_token` grant is
/// tried first.
pub struct OAuthTokenManager {
    /// OAuth settings used to build every token request.
    config: McpOAuthConfig,
    /// HTTP client used to call the token endpoint.
    client: reqwest::Client,
    /// Most recently obtained token; `None` until the first successful fetch.
    cached: Arc<Mutex<Option<CachedToken>>>,
}
134
135impl OAuthTokenManager {
136 pub fn new(config: McpOAuthConfig) -> Self {
138 Self {
139 config,
140 client: reqwest::Client::new(),
141 cached: Arc::new(Mutex::new(None)),
142 }
143 }
144
145 pub async fn get_token(&self) -> Result<String, SynapticError> {
147 let mut guard = self.cached.lock().await;
148
149 if let Some(ref cached) = *guard {
151 if Instant::now() < cached.expires_at {
152 return Ok(cached.access_token.clone());
153 }
154
155 if let Some(ref rt) = cached.refresh_token {
157 match self.refresh(rt).await {
158 Ok(new_token) => {
159 *guard = Some(new_token.clone());
160 return Ok(new_token.access_token);
161 }
162 Err(e) => {
163 tracing::warn!(
164 "OAuth refresh failed, falling back to client_credentials: {}",
165 e
166 );
167 }
168 }
169 }
170 }
171
172 let token = self.client_credentials().await?;
174 let access_token = token.access_token.clone();
175 *guard = Some(token);
176 Ok(access_token)
177 }
178
179 async fn client_credentials(&self) -> Result<CachedToken, SynapticError> {
181 let mut params: HashMap<String, String> = HashMap::new();
182 params.insert("grant_type".to_string(), "client_credentials".to_string());
183 params.insert("client_id".to_string(), self.config.client_id.clone());
184
185 if let Some(ref secret) = self.config.client_secret {
186 params.insert("client_secret".to_string(), secret.clone());
187 }
188
189 if !self.config.scopes.is_empty() {
190 params.insert("scope".to_string(), self.config.scopes.join(" "));
191 }
192
193 if self.config.pkce {
195 let verifier = generate_code_verifier(&self.config.token_url);
196 let challenge = generate_code_challenge(&verifier);
197 params.insert("code_verifier".to_string(), verifier);
198 params.insert("code_challenge".to_string(), challenge);
199 params.insert("code_challenge_method".to_string(), "S256".to_string());
200 }
201
202 self.exchange_token(¶ms).await
203 }
204
205 async fn refresh(&self, refresh_token: &str) -> Result<CachedToken, SynapticError> {
207 let mut params: HashMap<String, String> = HashMap::new();
208 params.insert("grant_type".to_string(), "refresh_token".to_string());
209 params.insert("refresh_token".to_string(), refresh_token.to_string());
210 params.insert("client_id".to_string(), self.config.client_id.clone());
211
212 if let Some(ref secret) = self.config.client_secret {
213 params.insert("client_secret".to_string(), secret.clone());
214 }
215
216 self.exchange_token(¶ms).await
217 }
218
219 async fn exchange_token(
222 &self,
223 params: &HashMap<String, String>,
224 ) -> Result<CachedToken, SynapticError> {
225 let body = params
227 .iter()
228 .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
229 .collect::<Vec<_>>()
230 .join("&");
231
232 let resp = self
233 .client
234 .post(&self.config.token_url)
235 .header("Content-Type", "application/x-www-form-urlencoded")
236 .body(body)
237 .send()
238 .await
239 .map_err(|e| SynapticError::Mcp(format!("OAuth token request failed: {}", e)))?;
240
241 if !resp.status().is_success() {
242 let status = resp.status();
243 let body = resp.text().await.unwrap_or_default();
244 return Err(SynapticError::Mcp(format!(
245 "OAuth token endpoint returned {}: {}",
246 status, body
247 )));
248 }
249
250 let token_resp: TokenResponse = resp.json().await.map_err(|e| {
251 SynapticError::Mcp(format!("Failed to parse OAuth token response: {}", e))
252 })?;
253
254 let expires_in_secs = token_resp.expires_in.unwrap_or(3600);
256 let safety_margin = 30;
257 let effective_ttl = expires_in_secs.saturating_sub(safety_margin);
258
259 Ok(CachedToken {
260 access_token: token_resp.access_token,
261 expires_at: Instant::now() + Duration::from_secs(effective_ttl),
262 refresh_token: token_resp.refresh_token,
263 })
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 7636 §4.1 requires a code_verifier of at least 43 characters.
    #[test]
    fn code_verifier_length() {
        let len = generate_code_verifier("test-seed").len();
        assert!(len >= 43, "code verifier must be >= 43 chars, got {}", len);
    }

    /// The S256 challenge must be unpadded base64url: none of `+`, `/`, `=`.
    #[test]
    fn code_challenge_is_base64url() {
        let challenge = generate_code_challenge(&generate_code_verifier("test-seed"));

        let forbidden = [
            ('+', "challenge must not contain '+'"),
            ('/', "challenge must not contain '/'"),
            ('=', "challenge must not contain '='"),
        ];
        for &(ch, message) in forbidden.iter() {
            assert!(!challenge.contains(ch), "{}", message);
        }

        assert!(!challenge.is_empty());
    }

    /// Only `client_id` and `token_url` are required; the other fields take defaults.
    #[test]
    fn oauth_config_default_pkce() {
        let minimal = serde_json::json!({
            "client_id": "my-client",
            "token_url": "https://auth.example.com/token"
        });
        let parsed: McpOAuthConfig = serde_json::from_value(minimal).unwrap();

        assert!(parsed.pkce, "pkce should default to true");
        assert!(parsed.client_secret.is_none());
        assert!(parsed.authorize_url.is_none());
        assert!(parsed.scopes.is_empty());
    }

    /// A fully populated config survives a serialize/deserialize round trip.
    #[test]
    fn oauth_config_full_roundtrip() {
        let original = McpOAuthConfig {
            client_id: "cid".into(),
            client_secret: Some("secret".into()),
            token_url: "https://auth.example.com/token".into(),
            authorize_url: Some("https://auth.example.com/authorize".into()),
            scopes: vec!["read".into(), "write".into()],
            pkce: false,
        };
        let value = serde_json::to_value(&original).unwrap();
        let restored: McpOAuthConfig = serde_json::from_value(value).unwrap();

        assert_eq!(restored.client_id, "cid");
        assert_eq!(restored.client_secret.as_deref(), Some("secret"));
        assert!(!restored.pkce);
        assert_eq!(restored.scopes, vec!["read", "write"]);
    }

    /// The challenge is a pure function of the verifier.
    #[test]
    fn code_challenge_deterministic_for_same_input() {
        assert_eq!(
            generate_code_challenge("same-verifier"),
            generate_code_challenge("same-verifier")
        );
    }

    /// Distinct verifiers map to distinct challenges.
    #[test]
    fn code_challenge_differs_for_different_input() {
        assert_ne!(
            generate_code_challenge("verifier-a"),
            generate_code_challenge("verifier-b")
        );
    }
}