use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;
use synaptic_core::SynapticError;
/// OAuth 2.0 client configuration used to obtain bearer tokens for an MCP server.
///
/// When deserializing, only `client_id` and `token_url` are required; every other field
/// falls back to its default (`None`, empty scopes, `pkce = true`).
///
/// `Debug` is implemented by hand so that `client_secret` is redacted and never ends up
/// in logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct McpOAuthConfig {
    /// OAuth client identifier, sent as `client_id` on token requests.
    pub client_id: String,
    /// Optional client secret, sent as `client_secret` when present.
    #[serde(default)]
    pub client_secret: Option<String>,
    /// Token endpoint URL that token requests are POSTed to.
    pub token_url: String,
    /// Optional authorization endpoint URL.
    #[serde(default)]
    pub authorize_url: Option<String>,
    /// Scopes to request; sent space-separated as `scope` when non-empty.
    #[serde(default)]
    pub scopes: Vec<String>,
    /// Whether PKCE parameters are attached to requests; defaults to `true`.
    #[serde(default = "default_pkce")]
    pub pkce: bool,
}

impl std::fmt::Debug for McpOAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a secret is configured, never its value.
        let redacted_secret = self.client_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("McpOAuthConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &redacted_secret)
            .field("token_url", &self.token_url)
            .field("authorize_url", &self.authorize_url)
            .field("scopes", &self.scopes)
            .field("pkce", &self.pkce)
            .finish()
    }
}
/// Serde default for `McpOAuthConfig::pkce`: PKCE stays enabled unless the config
/// explicitly turns it off.
fn default_pkce() -> bool {
    const PKCE_ENABLED_BY_DEFAULT: bool = true;
    PKCE_ENABLED_BY_DEFAULT
}
/// Subset of a token endpoint's JSON success response that this module reads.
///
/// Other fields in the response (for example `token_type` or `scope`) are ignored,
/// because the struct does not opt into `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// The bearer token itself; required, so deserialization fails without it.
    access_token: String,
    /// Token lifetime in seconds, if the server reports one.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Refresh token, if the server issued one with this response.
    #[serde(default)]
    refresh_token: Option<String>,
}
/// A token obtained from the token endpoint, plus the local deadline after which it is
/// no longer handed out from the cache.
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token` verbatim;
/// avoid logging values of this type.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Bearer token returned to callers while the cache entry is fresh.
    access_token: String,
    /// Deadline computed from the response's `expires_in` minus a safety margin; once
    /// `Instant::now()` reaches it, the token is refreshed or re-fetched.
    expires_at: Instant,
    /// Refresh token to use once `expires_at` has passed, if one is available.
    refresh_token: Option<String>,
}
/// Generates a PKCE code verifier (RFC 7636 §4.1).
///
/// The verifier is the unpadded base64url encoding of a SHA-256 digest, so it is always
/// 43 characters drawn from `A-Z a-z 0-9 - _`, within the RFC's 43..=128 range.
///
/// The digest input combines `seed`, the current wall-clock time, the process id, and a
/// value from the standard library's randomly seeded `RandomState`. The random component
/// keeps the verifier from being predictable by someone who knows the seed (often the
/// public token URL), the approximate time, and the PID.
///
/// NOTE(review): `RandomState` is not a CSPRNG; prefer `getrandom`/`rand` if the crate
/// ever takes such a dependency.
pub fn generate_code_verifier(seed: &str) -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    // Each `RandomState::new()` yields hasher keys that differ from previous instances
    // and are randomly seeded per process; finishing an empty hash exposes a value
    // derived from those keys.
    let random_component = RandomState::new().build_hasher().finish();
    let input = format!(
        "{}{}{}{}",
        seed,
        now.as_nanos(),
        std::process::id(),
        random_component
    );
    let hash = Sha256::digest(input.as_bytes());
    base64_url_encode(&hash)
}
/// Derives the PKCE `S256` code challenge for `verifier` (RFC 7636 §4.2):
/// `BASE64URL(SHA256(verifier))`, without `=` padding.
///
/// Deterministic: the same verifier always yields the same challenge.
pub fn generate_code_challenge(verifier: &str) -> String {
    let digest = Sha256::digest(verifier.as_bytes());
    base64_url_encode(&digest[..])
}
fn base64_url_encode(data: &[u8]) -> String {
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
URL_SAFE_NO_PAD.encode(data)
}
/// Percent-encodes `s` for an `application/x-www-form-urlencoded` body.
///
/// - RFC 3986 unreserved bytes (`A-Z a-z 0-9 - _ . ~`) are copied unchanged.
/// - A space becomes `+`.
/// - Every other byte becomes `%XX` with uppercase hex digits, including each byte of a
///   multi-byte UTF-8 character.
fn url_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    s.bytes().for_each(|byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else if byte == b' ' {
            encoded.push('+');
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    });
    encoded
}
/// Obtains OAuth access tokens from `config.token_url` and caches the most recent one.
pub struct OAuthTokenManager {
    // Client configuration used to build every token request.
    config: McpOAuthConfig,
    // HTTP client used to POST token requests.
    client: reqwest::Client,
    // Most recently obtained token; `None` until the first successful fetch. This is an
    // async (tokio) mutex because `get_token` keeps it locked across the token request.
    cached: Arc<Mutex<Option<CachedToken>>>,
}
impl OAuthTokenManager {
pub fn new(config: McpOAuthConfig) -> Self {
Self {
config,
client: reqwest::Client::new(),
cached: Arc::new(Mutex::new(None)),
}
}
pub async fn get_token(&self) -> Result<String, SynapticError> {
let mut guard = self.cached.lock().await;
if let Some(ref cached) = *guard {
if Instant::now() < cached.expires_at {
return Ok(cached.access_token.clone());
}
if let Some(ref rt) = cached.refresh_token {
match self.refresh(rt).await {
Ok(new_token) => {
*guard = Some(new_token.clone());
return Ok(new_token.access_token);
}
Err(e) => {
tracing::warn!(
"OAuth refresh failed, falling back to client_credentials: {}",
e
);
}
}
}
}
let token = self.client_credentials().await?;
let access_token = token.access_token.clone();
*guard = Some(token);
Ok(access_token)
}
async fn client_credentials(&self) -> Result<CachedToken, SynapticError> {
let mut params: HashMap<String, String> = HashMap::new();
params.insert("grant_type".to_string(), "client_credentials".to_string());
params.insert("client_id".to_string(), self.config.client_id.clone());
if let Some(ref secret) = self.config.client_secret {
params.insert("client_secret".to_string(), secret.clone());
}
if !self.config.scopes.is_empty() {
params.insert("scope".to_string(), self.config.scopes.join(" "));
}
if self.config.pkce {
let verifier = generate_code_verifier(&self.config.token_url);
let challenge = generate_code_challenge(&verifier);
params.insert("code_verifier".to_string(), verifier);
params.insert("code_challenge".to_string(), challenge);
params.insert("code_challenge_method".to_string(), "S256".to_string());
}
self.exchange_token(¶ms).await
}
async fn refresh(&self, refresh_token: &str) -> Result<CachedToken, SynapticError> {
let mut params: HashMap<String, String> = HashMap::new();
params.insert("grant_type".to_string(), "refresh_token".to_string());
params.insert("refresh_token".to_string(), refresh_token.to_string());
params.insert("client_id".to_string(), self.config.client_id.clone());
if let Some(ref secret) = self.config.client_secret {
params.insert("client_secret".to_string(), secret.clone());
}
self.exchange_token(¶ms).await
}
async fn exchange_token(
&self,
params: &HashMap<String, String>,
) -> Result<CachedToken, SynapticError> {
let body = params
.iter()
.map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
.collect::<Vec<_>>()
.join("&");
let resp = self
.client
.post(&self.config.token_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body)
.send()
.await
.map_err(|e| SynapticError::Mcp(format!("OAuth token request failed: {}", e)))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(SynapticError::Mcp(format!(
"OAuth token endpoint returned {}: {}",
status, body
)));
}
let token_resp: TokenResponse = resp.json().await.map_err(|e| {
SynapticError::Mcp(format!("Failed to parse OAuth token response: {}", e))
})?;
let expires_in_secs = token_resp.expires_in.unwrap_or(3600);
let safety_margin = 30;
let effective_ttl = expires_in_secs.saturating_sub(safety_margin);
Ok(CachedToken {
access_token: token_resp.access_token,
expires_at: Instant::now() + Duration::from_secs(effective_ttl),
refresh_token: token_resp.refresh_token,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RFC 7636 §4.1: a code verifier is 43..=128 characters long.
    #[test]
    fn code_verifier_length() {
        let verifier = generate_code_verifier("test-seed");
        assert!(
            (43..=128).contains(&verifier.len()),
            "code verifier must be 43..=128 chars, got {}",
            verifier.len()
        );
    }

    /// RFC 7636 §4.1: a code verifier only uses unreserved characters.
    #[test]
    fn code_verifier_uses_unreserved_charset() {
        let verifier = generate_code_verifier("test-seed");
        assert!(
            verifier
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~')),
            "code verifier contains a reserved character: {}",
            verifier
        );
    }

    #[test]
    fn code_challenge_is_base64url() {
        let verifier = generate_code_verifier("test-seed");
        let challenge = generate_code_challenge(&verifier);
        assert!(!challenge.contains('+'), "challenge must not contain '+'");
        assert!(!challenge.contains('/'), "challenge must not contain '/'");
        assert!(!challenge.contains('='), "challenge must not contain '='");
        assert!(!challenge.is_empty());
    }

    /// Test vector from RFC 7636 Appendix B.
    #[test]
    fn code_challenge_matches_rfc7636_example() {
        assert_eq!(
            generate_code_challenge("dBjftJeZ4CVP-mJ92K9ZLUbSsdqW6GINyHLyEPbLwgM"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn oauth_config_default_pkce() {
        let json = serde_json::json!({
            "client_id": "my-client",
            "token_url": "https://auth.example.com/token"
        });
        let config: McpOAuthConfig = serde_json::from_value(json).unwrap();
        assert!(config.pkce, "pkce should default to true");
        assert!(config.client_secret.is_none());
        assert!(config.authorize_url.is_none());
        assert!(config.scopes.is_empty());
    }

    #[test]
    fn oauth_config_full_roundtrip() {
        let config = McpOAuthConfig {
            client_id: "cid".to_string(),
            client_secret: Some("secret".to_string()),
            token_url: "https://auth.example.com/token".to_string(),
            authorize_url: Some("https://auth.example.com/authorize".to_string()),
            scopes: vec!["read".to_string(), "write".to_string()],
            pkce: false,
        };
        let json = serde_json::to_value(&config).unwrap();
        let deserialized: McpOAuthConfig = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.client_id, "cid");
        assert_eq!(deserialized.client_secret.as_deref(), Some("secret"));
        assert!(!deserialized.pkce);
        assert_eq!(deserialized.scopes, vec!["read", "write"]);
    }

    #[test]
    fn code_challenge_deterministic_for_same_input() {
        let challenge1 = generate_code_challenge("same-verifier");
        let challenge2 = generate_code_challenge("same-verifier");
        assert_eq!(challenge1, challenge2);
    }

    #[test]
    fn code_challenge_differs_for_different_input() {
        let challenge1 = generate_code_challenge("verifier-a");
        let challenge2 = generate_code_challenge("verifier-b");
        assert_ne!(challenge1, challenge2);
    }

    /// Form bodies sent to the token endpoint must escape reserved bytes and UTF-8.
    #[test]
    fn url_encode_form_encoding() {
        assert_eq!(url_encode(""), "");
        assert_eq!(url_encode("abcXYZ019-_.~"), "abcXYZ019-_.~");
        assert_eq!(url_encode("read write"), "read+write");
        assert_eq!(url_encode("a&b=c/d"), "a%26b%3Dc%2Fd");
        assert_eq!(url_encode("+%"), "%2B%25");
        assert_eq!(url_encode("é"), "%C3%A9");
    }
}