Skip to main content

synaptic_mcp/
oauth.rs

//! OAuth 2.1 + PKCE support for MCP server connections.
//!
//! Provides [`McpOAuthConfig`] for configuring OAuth client credentials,
//! [`OAuthTokenManager`] for automatic `client_credentials` token acquisition,
//! caching, and refresh, and PKCE (S256) helpers ([`generate_code_verifier`],
//! [`generate_code_challenge`]).
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13use tokio::sync::Mutex;
14
15use synaptic_core::SynapticError;
16
17// ---------------------------------------------------------------------------
18// Config
19// ---------------------------------------------------------------------------
20
21/// OAuth 2.1 configuration for an MCP server connection.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct McpOAuthConfig {
24    /// OAuth client identifier.
25    pub client_id: String,
26    /// OAuth client secret (omit for public clients).
27    #[serde(default)]
28    pub client_secret: Option<String>,
29    /// Token endpoint URL.
30    pub token_url: String,
31    /// Authorization endpoint URL (for authorization code flows).
32    #[serde(default)]
33    pub authorize_url: Option<String>,
34    /// Requested scopes.
35    #[serde(default)]
36    pub scopes: Vec<String>,
37    /// Whether to use PKCE (S256). Defaults to `true`.
38    #[serde(default = "default_pkce")]
39    pub pkce: bool,
40}
41
42fn default_pkce() -> bool {
43    true
44}
45
46// ---------------------------------------------------------------------------
47// Token response (from the OAuth server)
48// ---------------------------------------------------------------------------
49
50#[derive(Debug, Deserialize)]
51struct TokenResponse {
52    access_token: String,
53    #[serde(default)]
54    expires_in: Option<u64>,
55    #[serde(default)]
56    refresh_token: Option<String>,
57}
58
59// ---------------------------------------------------------------------------
60// Cached token
61// ---------------------------------------------------------------------------
62
/// A token kept in the manager's cache together with its local deadline.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Access token returned by the token endpoint.
    access_token: String,
    /// Moment after which the token is treated as expired; this deadline is
    /// computed with a safety margin already subtracted.
    expires_at: Instant,
    /// Refresh token usable for a `refresh_token` grant, if one was issued.
    refresh_token: Option<String>,
}
69
70// ---------------------------------------------------------------------------
71// PKCE helpers
72// ---------------------------------------------------------------------------
73
74/// Generate a code verifier using a deterministic hash-based approach.
75///
76/// Uses the current timestamp and token_url as entropy source, hashed through
77/// SHA-256, then URL-safe base64 encoded (no padding). The result is always
78/// 43 characters, satisfying the PKCE spec (43..128).
79pub fn generate_code_verifier(seed: &str) -> String {
80    let now = std::time::SystemTime::now()
81        .duration_since(std::time::UNIX_EPOCH)
82        .unwrap_or_default();
83    let input = format!("{}{}{}", seed, now.as_nanos(), std::process::id());
84    let hash = Sha256::digest(input.as_bytes());
85    base64_url_encode(&hash)
86}
87
88/// Compute the S256 code challenge from a code verifier.
89///
90/// `code_challenge = BASE64URL(SHA256(code_verifier))`
91pub fn generate_code_challenge(verifier: &str) -> String {
92    let hash = Sha256::digest(verifier.as_bytes());
93    base64_url_encode(&hash)
94}
95
96/// URL-safe base64 encoding without padding.
97fn base64_url_encode(data: &[u8]) -> String {
98    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
99    use base64::Engine;
100    URL_SAFE_NO_PAD.encode(data)
101}
102
/// Percent-encode `s` as an `application/x-www-form-urlencoded` value.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged, a space
/// becomes `+`, and every other byte is written as `%XX` with upper-case hex
/// digits. This includes each byte of a multi-byte UTF-8 sequence.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes().fold(String::with_capacity(s.len()), |mut out, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else if byte == b' ' {
            out.push('+');
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
        out
    })
}
120
121// ---------------------------------------------------------------------------
122// OAuthTokenManager
123// ---------------------------------------------------------------------------
124
125/// Manages OAuth 2.1 token lifecycle: acquisition, caching, and refresh.
126///
127/// Thread-safe — the internal state is behind a `Mutex` so multiple concurrent
128/// tool calls share and reuse the same cached token.
129pub struct OAuthTokenManager {
130    config: McpOAuthConfig,
131    client: reqwest::Client,
132    cached: Arc<Mutex<Option<CachedToken>>>,
133}
134
135impl OAuthTokenManager {
136    /// Create a new token manager for the given OAuth configuration.
137    pub fn new(config: McpOAuthConfig) -> Self {
138        Self {
139            config,
140            client: reqwest::Client::new(),
141            cached: Arc::new(Mutex::new(None)),
142        }
143    }
144
145    /// Return a valid access token, refreshing or re-acquiring as needed.
146    pub async fn get_token(&self) -> Result<String, SynapticError> {
147        let mut guard = self.cached.lock().await;
148
149        // Return cached token if still valid.
150        if let Some(ref cached) = *guard {
151            if Instant::now() < cached.expires_at {
152                return Ok(cached.access_token.clone());
153            }
154
155            // Try refreshing if we have a refresh token.
156            if let Some(ref rt) = cached.refresh_token {
157                match self.refresh(rt).await {
158                    Ok(new_token) => {
159                        *guard = Some(new_token.clone());
160                        return Ok(new_token.access_token);
161                    }
162                    Err(e) => {
163                        tracing::warn!(
164                            "OAuth refresh failed, falling back to client_credentials: {}",
165                            e
166                        );
167                    }
168                }
169            }
170        }
171
172        // Fresh client_credentials grant.
173        let token = self.client_credentials().await?;
174        let access_token = token.access_token.clone();
175        *guard = Some(token);
176        Ok(access_token)
177    }
178
179    /// Perform a `client_credentials` grant.
180    async fn client_credentials(&self) -> Result<CachedToken, SynapticError> {
181        let mut params: HashMap<String, String> = HashMap::new();
182        params.insert("grant_type".to_string(), "client_credentials".to_string());
183        params.insert("client_id".to_string(), self.config.client_id.clone());
184
185        if let Some(ref secret) = self.config.client_secret {
186            params.insert("client_secret".to_string(), secret.clone());
187        }
188
189        if !self.config.scopes.is_empty() {
190            params.insert("scope".to_string(), self.config.scopes.join(" "));
191        }
192
193        // PKCE: include code_verifier and code_challenge for client_credentials.
194        if self.config.pkce {
195            let verifier = generate_code_verifier(&self.config.token_url);
196            let challenge = generate_code_challenge(&verifier);
197            params.insert("code_verifier".to_string(), verifier);
198            params.insert("code_challenge".to_string(), challenge);
199            params.insert("code_challenge_method".to_string(), "S256".to_string());
200        }
201
202        self.exchange_token(&params).await
203    }
204
205    /// Perform a `refresh_token` grant.
206    async fn refresh(&self, refresh_token: &str) -> Result<CachedToken, SynapticError> {
207        let mut params: HashMap<String, String> = HashMap::new();
208        params.insert("grant_type".to_string(), "refresh_token".to_string());
209        params.insert("refresh_token".to_string(), refresh_token.to_string());
210        params.insert("client_id".to_string(), self.config.client_id.clone());
211
212        if let Some(ref secret) = self.config.client_secret {
213            params.insert("client_secret".to_string(), secret.clone());
214        }
215
216        self.exchange_token(&params).await
217    }
218
219    /// Shared POST logic: sends form-encoded params to `token_url` and parses
220    /// the JSON response into a [`CachedToken`].
221    async fn exchange_token(
222        &self,
223        params: &HashMap<String, String>,
224    ) -> Result<CachedToken, SynapticError> {
225        // Build URL-encoded form body manually to avoid reqwest `form` feature.
226        let body = params
227            .iter()
228            .map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
229            .collect::<Vec<_>>()
230            .join("&");
231
232        let resp = self
233            .client
234            .post(&self.config.token_url)
235            .header("Content-Type", "application/x-www-form-urlencoded")
236            .body(body)
237            .send()
238            .await
239            .map_err(|e| SynapticError::Mcp(format!("OAuth token request failed: {}", e)))?;
240
241        if !resp.status().is_success() {
242            let status = resp.status();
243            let body = resp.text().await.unwrap_or_default();
244            return Err(SynapticError::Mcp(format!(
245                "OAuth token endpoint returned {}: {}",
246                status, body
247            )));
248        }
249
250        let token_resp: TokenResponse = resp.json().await.map_err(|e| {
251            SynapticError::Mcp(format!("Failed to parse OAuth token response: {}", e))
252        })?;
253
254        // Default to 1 hour if expires_in is not provided, with 30s safety margin.
255        let expires_in_secs = token_resp.expires_in.unwrap_or(3600);
256        let safety_margin = 30;
257        let effective_ttl = expires_in_secs.saturating_sub(safety_margin);
258
259        Ok(CachedToken {
260            access_token: token_resp.access_token,
261            expires_at: Instant::now() + Duration::from_secs(effective_ttl),
262            refresh_token: token_resp.refresh_token,
263        })
264    }
265}
266
267// ---------------------------------------------------------------------------
268// Tests
269// ---------------------------------------------------------------------------
270
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `s` only uses the unpadded URL-safe base64 alphabet.
    fn is_base64url(s: &str) -> bool {
        s.bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
    }

    #[test]
    fn code_verifier_length_and_charset() {
        let verifier = generate_code_verifier("test-seed");
        // SHA-256 output is 32 bytes; URL-safe base64 no-pad of 32 bytes = 43 chars,
        // which is within RFC 7636's 43..=128 range.
        assert_eq!(verifier.len(), 43, "unexpected verifier: {}", verifier);
        assert!(is_base64url(&verifier), "verifier not base64url: {}", verifier);
    }

    #[test]
    fn code_challenge_is_base64url() {
        let verifier = generate_code_verifier("test-seed");
        let challenge = generate_code_challenge(&verifier);

        // Must not contain standard base64 chars that are NOT URL-safe.
        assert!(!challenge.contains('+'), "challenge must not contain '+'");
        assert!(!challenge.contains('/'), "challenge must not contain '/'");
        assert!(!challenge.contains('='), "challenge must not contain '='");
        assert!(is_base64url(&challenge));

        // Must be non-empty.
        assert!(!challenge.is_empty());
    }

    #[test]
    fn code_challenge_matches_rfc7636_appendix_b() {
        // Example verifier/challenge pair from RFC 7636, Appendix B.
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        assert_eq!(
            generate_code_challenge(verifier),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuKbXYXaeeM"
        );
    }

    #[test]
    fn url_encode_form_values() {
        assert_eq!(url_encode(""), "");
        assert_eq!(url_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(url_encode("a b&c=d"), "a+b%26c%3Dd");
        assert_eq!(url_encode("é/+"), "%C3%A9%2F%2B");
    }

    #[test]
    fn oauth_config_default_pkce() {
        let json = serde_json::json!({
            "client_id": "my-client",
            "token_url": "https://auth.example.com/token"
        });
        let config: McpOAuthConfig = serde_json::from_value(json).unwrap();
        assert!(config.pkce, "pkce should default to true");
        assert!(config.client_secret.is_none());
        assert!(config.authorize_url.is_none());
        assert!(config.scopes.is_empty());
    }

    #[test]
    fn oauth_config_full_roundtrip() {
        let config = McpOAuthConfig {
            client_id: "cid".to_string(),
            client_secret: Some("secret".to_string()),
            token_url: "https://auth.example.com/token".to_string(),
            authorize_url: Some("https://auth.example.com/authorize".to_string()),
            scopes: vec!["read".to_string(), "write".to_string()],
            pkce: false,
        };
        let json = serde_json::to_value(&config).unwrap();
        let deserialized: McpOAuthConfig = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.client_id, "cid");
        assert_eq!(deserialized.client_secret.as_deref(), Some("secret"));
        assert_eq!(
            deserialized.authorize_url.as_deref(),
            Some("https://auth.example.com/authorize")
        );
        assert!(!deserialized.pkce);
        assert_eq!(deserialized.scopes, vec!["read", "write"]);
    }

    #[test]
    fn code_challenge_deterministic_for_same_input() {
        let challenge1 = generate_code_challenge("same-verifier");
        let challenge2 = generate_code_challenge("same-verifier");
        assert_eq!(challenge1, challenge2);
    }

    #[test]
    fn code_challenge_differs_for_different_input() {
        let challenge1 = generate_code_challenge("verifier-a");
        let challenge2 = generate_code_challenge("verifier-b");
        assert_ne!(challenge1, challenge2);
    }
}