use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::time::sleep;
use url::Url;
use crate::client::auth::{OidcDiscoveryClient, TokenExchangeClient};
use crate::client::http_middleware::HttpMiddlewareChain;
use crate::client::oauth_middleware::{BearerToken, OAuthClientMiddleware};
use crate::error::{Error, Result};
use crate::server::auth::oauth2::OidcDiscoveryMetadata;
/// Settings that drive [`OAuthHelper`].
///
/// At least one of `issuer` or `mcp_server_url` must be set; otherwise metadata
/// discovery returns an error.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
/// Issuer URL used for metadata discovery when `mcp_server_url` is `None`.
pub issuer: Option<String>,
/// MCP server URL; its `scheme://host[:port]` part is used for metadata
/// discovery and is consulted before `issuer`.
pub mcp_server_url: Option<String>,
/// OAuth client identifier sent with every authorization and token request.
pub client_id: String,
/// Requested scopes; joined with a single space when sent to the server.
pub scopes: Vec<String>,
/// File used to persist tokens between runs; `None` disables caching.
pub cache_file: Option<PathBuf>,
/// Local port the authorization-code callback listener binds on `127.0.0.1`.
pub redirect_port: u16,
}
/// On-disk (JSON) representation of a cached token.
#[derive(Debug, Serialize, Deserialize)]
struct TokenCache {
/// The bearer token returned to callers while it is still valid.
access_token: String,
/// Refresh token used to obtain a new access token once this one expires.
refresh_token: Option<String>,
/// Expiry as seconds since the Unix epoch; `None` when the server sent no `expires_in`.
expires_at: Option<u64>,
/// Copy of the configured scopes at the time the token was cached.
scopes: Vec<String>,
}
/// Body of a successful device authorization response (device code flow).
#[derive(Debug, Deserialize)]
struct DeviceAuthResponse {
/// Opaque code sent back to the token endpoint while polling.
device_code: String,
/// Short code shown to the user to enter at `verification_uri`.
user_code: String,
/// URL the user is asked to visit.
verification_uri: String,
/// Optional URL shown to the user as a direct link.
#[serde(default)]
verification_uri_complete: Option<String>,
/// Lifetime of the device code, in seconds.
expires_in: u64,
/// Polling interval in seconds; the flow falls back to 5 when absent.
interval: Option<u64>,
}
/// Token endpoint response as parsed by the device-code and refresh paths.
#[derive(Debug, Deserialize)]
// NOTE(review): presumably needed because `token_type` is deserialized but never
// read in this file — confirm before removing.
#[allow(dead_code)]
struct TokenResponse {
/// The bearer token handed to callers.
access_token: String,
/// Refresh token, when the server issued one.
#[serde(default)]
refresh_token: Option<String>,
/// Token lifetime in seconds, when the server reported one.
#[serde(default)]
expires_in: Option<u64>,
/// Token type string as reported by the server.
token_type: String,
}
/// Obtains, refreshes and caches OAuth access tokens according to an [`OAuthConfig`].
#[derive(Debug)]
pub struct OAuthHelper {
/// Settings supplied by the caller.
config: OAuthConfig,
/// HTTP client used for device-code and refresh requests (built with a 30 s timeout).
client: reqwest::Client,
}
impl OAuthHelper {
pub fn new(config: OAuthConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::internal(format!("Failed to create HTTP client: {e}")))?;
Ok(Self { config, client })
}
fn extract_base_url(mcp_url: &str) -> Result<String> {
let parsed = Url::parse(mcp_url)
.map_err(|e| Error::internal(format!("Invalid MCP server URL: {e}")))?;
let mut base = format!("{}://{}", parsed.scheme(), parsed.host_str().unwrap_or(""));
if let Some(port) = parsed.port() {
let is_default_port = (parsed.scheme() == "https" && port == 443)
|| (parsed.scheme() == "http" && port == 80);
if !is_default_port {
base.push_str(&format!(":{}", port));
}
}
Ok(base)
}
async fn discover_metadata(&self, mcp_url: &str) -> Result<OidcDiscoveryMetadata> {
let base_url = Self::extract_base_url(mcp_url)?;
tracing::info!("Discovering OAuth configuration from {}...", base_url);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(&base_url).await {
Ok(metadata) => {
tracing::info!("OAuth discovery successful");
tracing::debug!("Issuer: {}", metadata.issuer);
if let Some(ref device_endpoint) = metadata.device_authorization_endpoint {
tracing::debug!("Device endpoint: {}", device_endpoint);
}
Ok(metadata)
},
Err(e) => Err(Error::internal(format!(
"Failed to discover OAuth configuration at {}: {}\n\
\n\
Please provide --oauth-issuer explicitly, or ensure the server\n\
exposes OAuth metadata at {}/.well-known/openid-configuration",
base_url, e, base_url
))),
}
}
async fn get_metadata(&self) -> Result<OidcDiscoveryMetadata> {
if let Some(ref mcp_url) = self.config.mcp_server_url {
self.discover_metadata(mcp_url).await
} else if let Some(ref issuer) = self.config.issuer {
tracing::info!("Discovering OAuth configuration from {}...", issuer);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(issuer).await {
Ok(metadata) => {
tracing::info!("OAuth discovery successful");
Ok(metadata)
},
Err(e) => Err(Error::internal(format!(
"Failed to discover OAuth configuration from issuer {}: {}\n\
\n\
Please ensure the issuer URL exposes OAuth metadata at\n\
{}/.well-known/openid-configuration",
issuer, e, issuer
))),
}
} else {
Err(Error::internal(
"Either oauth_issuer or mcp_server_url must be provided for OAuth authentication"
.to_string(),
))
}
}
pub async fn get_access_token(&self) -> Result<String> {
if let Some(ref cache_file) = self.config.cache_file {
if let Ok(cached) = self.load_cached_token(cache_file).await {
if let Some(expires_at) = cached.expires_at {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if now < expires_at {
tracing::info!("Using cached OAuth token");
return Ok(cached.access_token);
}
}
if let Some(refresh_token) = cached.refresh_token {
tracing::warn!("OAuth token expired, refreshing...");
if let Ok(new_token) = self.refresh_token(&refresh_token).await {
self.cache_token(&new_token, cache_file).await?;
return Ok(new_token.access_token);
}
}
}
}
tracing::info!("No cached token found, starting OAuth flow...");
let metadata = self.get_metadata().await?;
match self.authorization_code_flow(&metadata).await {
Ok(token) => Ok(token),
Err(e) => {
tracing::warn!("Authorization code flow failed: {}", e);
if metadata.device_authorization_endpoint.is_some() {
tracing::info!("Trying device code flow...");
return self.device_code_flow_with_metadata(&metadata).await;
}
Err(Error::internal(
"No supported OAuth flow available.\n\
\n\
The server must support either:\n\
- Authorization code flow (authorization_endpoint), or\n\
- Device code flow (device_authorization_endpoint)"
.to_string(),
))
},
}
}
/// Produces a random string: 32 random bytes encoded as unpadded URL-safe base64.
fn generate_code_verifier() -> String {
    let mut rng = rand::rng();
    let entropy: [u8; 32] = rng.random();
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Derives the `S256` PKCE challenge for `verifier`: the SHA-256 digest of the
/// verifier's bytes, encoded as unpadded URL-safe base64.
fn generate_code_challenge(verifier: &str) -> String {
    let digest = Sha256::digest(verifier.as_bytes());
    URL_SAFE_NO_PAD.encode(digest)
}
async fn authorization_code_flow(&self, metadata: &OidcDiscoveryMetadata) -> Result<String> {
tracing::info!("Starting OAuth authorization code flow...");
let code_verifier = Self::generate_code_verifier();
let code_challenge = Self::generate_code_challenge(&code_verifier);
let redirect_port = self.config.redirect_port;
let redirect_uri = format!("http://localhost:{}/callback", redirect_port);
let listener = TcpListener::bind(format!("127.0.0.1:{}", redirect_port))
.await
.map_err(|e| {
Error::internal(format!(
"Failed to bind to localhost:{}.\n\
\n\
This port may already be in use. Try a different port with:\n\
--oauth-redirect-port PORT\n\
\n\
Error: {e}",
redirect_port
))
})?;
tracing::debug!("Local callback server listening on port {}", redirect_port);
tracing::warn!(
"Ensure the redirect URI is registered in your OAuth provider: {}",
redirect_uri
);
let mut auth_url = Url::parse(&metadata.authorization_endpoint)
.map_err(|e| Error::internal(format!("Invalid authorization endpoint: {e}")))?;
auth_url
.query_pairs_mut()
.append_pair("client_id", &self.config.client_id)
.append_pair("response_type", "code")
.append_pair("redirect_uri", &redirect_uri)
.append_pair("scope", &self.config.scopes.join(" "))
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256")
.append_pair("state", &Self::generate_code_verifier());
tracing::info!("OAuth Authentication Required");
tracing::info!("Opening browser for authentication...");
tracing::info!("If the browser doesn't open, visit: {}", auth_url.as_str());
if let Err(e) = webbrowser::open(auth_url.as_str()) {
tracing::warn!(
"Failed to open browser: {}. Please open the URL manually.",
e
);
}
let (tx, rx) = oneshot::channel();
let callback_task = tokio::spawn(async move {
if let Ok((mut stream, _)) = listener.accept().await {
let mut reader = BufReader::new(&mut stream);
let mut request_line = String::new();
if reader.read_line(&mut request_line).await.is_ok() {
if let Some(path) = request_line.split_whitespace().nth(1) {
if let Ok(callback_url) = Url::parse(&format!("http://localhost{}", path)) {
let code = callback_url
.query_pairs()
.find(|(key, _)| key == "code")
.map(|(_, value)| value.to_string());
let response = if code.is_some() {
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: green;'>Authentication Successful!</h1>\
<p>You can close this window and return to the terminal.</p>\
</body></html>"
} else {
"HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: red;'>Authentication Failed</h1>\
<p>No authorization code received. Please try again.</p>\
</body></html>"
};
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.flush().await;
if let Some(code) = code {
let _ = tx.send(code);
}
}
}
}
}
});
tracing::info!("Waiting for authorization...");
let authorization_code = tokio::time::timeout(Duration::from_secs(300), rx)
.await
.map_err(|_| {
Error::internal("Timeout waiting for OAuth callback (5 minutes)".to_string())
})?
.map_err(|e| Error::internal(format!("OAuth callback channel error: {e}")))?;
callback_task.abort();
tracing::info!("Authorization code received");
tracing::debug!("Exchanging authorization code for access token...");
let token_exchange = TokenExchangeClient::new();
let token_response = token_exchange
.exchange_code(
&metadata.token_endpoint,
&authorization_code,
&self.config.client_id,
None, &redirect_uri,
Some(&code_verifier), )
.await
.map_err(|e| {
Error::internal(format!(
"Failed to exchange authorization code for token: {e}"
))
})?;
tracing::info!("Authentication successful");
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token_from_response(&token_response, cache_file)
.await?;
}
Ok(token_response.access_token)
}
async fn device_code_flow_with_metadata(
&self,
metadata: &OidcDiscoveryMetadata,
) -> Result<String> {
tracing::info!("Starting OAuth device code flow...");
let device_auth_endpoint =
metadata
.device_authorization_endpoint
.as_ref()
.ok_or_else(|| {
Error::internal(
"Device authorization endpoint not found in OAuth metadata.\n\
\n\
The OAuth server does not support device code flow (RFC 8628)."
.to_string(),
)
})?;
self.device_code_flow_internal(metadata, device_auth_endpoint)
.await
}
async fn device_code_flow_internal(
&self,
metadata: &OidcDiscoveryMetadata,
device_auth_endpoint: &str,
) -> Result<String> {
let scope = self.config.scopes.join(" ");
let response = self
.client
.post(device_auth_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("scope", &scope),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to request device code: {e}")))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"Device authorization failed: {}",
response.text().await.unwrap_or_default()
)));
}
let device_auth: DeviceAuthResponse = response.json().await.map_err(|e| {
Error::internal(format!(
"Failed to parse device authorization response: {e}"
))
})?;
tracing::info!("OAuth device code flow");
tracing::info!("1. Visit: {}", device_auth.verification_uri);
tracing::info!("2. Enter code: {}", device_auth.user_code);
if let Some(complete_uri) = &device_auth.verification_uri_complete {
tracing::info!("Or visit directly: {}", complete_uri);
}
let poll_interval = Duration::from_secs(device_auth.interval.unwrap_or(5));
let token_endpoint = &metadata.token_endpoint;
let expires_at = SystemTime::now() + Duration::from_secs(device_auth.expires_in);
loop {
if SystemTime::now() > expires_at {
return Err(Error::internal(
"Device code expired. Please try again.".to_string(),
));
}
sleep(poll_interval).await;
tracing::debug!("Polling for authorization...");
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("device_code", &device_auth.device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to poll for token: {e}")))?;
let status = response.status();
let body = response
.text()
.await
.map_err(|e| Error::internal(format!("Failed to read token response body: {e}")))?;
if status.is_success() {
let token_response: TokenResponse = serde_json::from_str(&body)
.map_err(|e| Error::internal(format!("Failed to parse token response: {e}")))?;
tracing::info!("Authentication successful");
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token(&token_response, cache_file).await?;
}
return Ok(token_response.access_token);
}
if let Ok(error) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(error_code) = error.get("error").and_then(|e| e.as_str()) {
match error_code {
"authorization_pending" => continue,
"slow_down" => {
sleep(poll_interval).await;
continue;
},
"access_denied" => {
return Err(Error::internal("User denied authorization".to_string()));
},
"expired_token" => {
return Err(Error::internal("Device code expired".to_string()));
},
_ => {
return Err(Error::internal(format!("OAuth error: {}", error_code)));
},
}
}
}
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let metadata = self.get_metadata().await?;
let token_endpoint = &metadata.token_endpoint;
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to refresh token: {e}")))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"Token refresh failed: {}",
response.text().await.unwrap_or_default()
)));
}
response
.json()
.await
.map_err(|e| Error::internal(format!("Failed to parse token response: {e}")))
}
async fn load_cached_token(&self, cache_file: &PathBuf) -> Result<TokenCache> {
let content = tokio::fs::read_to_string(cache_file)
.await
.map_err(|e| Error::internal(format!("Failed to read token cache: {e}")))?;
serde_json::from_str(&content)
.map_err(|e| Error::internal(format!("Failed to parse token cache: {e}")))
}
/// Persists `token` to `cache_file` as pretty-printed JSON.
///
/// The expiry is stored as an absolute Unix timestamp (now + `expires_in`).
/// Missing parent directories are created. On Unix the file is created with,
/// and reset to, mode `0600` because it contains bearer and refresh tokens.
///
/// # Errors
///
/// Fails when the directory cannot be created, the cache cannot be serialized,
/// or the file cannot be written.
async fn cache_token(&self, token: &TokenResponse, cache_file: &PathBuf) -> Result<()> {
    // Saturate instead of overflowing on an absurd `expires_in`; if the clock
    // is before the Unix epoch, store no expiry rather than panicking.
    let expires_at = token.expires_in.and_then(|secs| {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .map(|now| now.as_secs().saturating_add(secs))
    });

    let cache = TokenCache {
        access_token: token.access_token.clone(),
        refresh_token: token.refresh_token.clone(),
        expires_at,
        scopes: self.config.scopes.clone(),
    };

    if let Some(parent) = cache_file.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .map_err(|e| Error::internal(format!("Failed to create cache directory: {e}")))?;
    }

    let json = serde_json::to_string_pretty(&cache)
        .map_err(|e| Error::internal(format!("Failed to serialize cache: {e}")))?;

    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;

        // `mode` only applies when the file is created, so permissions are
        // also tightened explicitly for a pre-existing file before any secret
        // is written into it.
        let mut file = tokio::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(cache_file)
            .await
            .map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
        file.set_permissions(std::fs::Permissions::from_mode(0o600))
            .await
            .map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
        file.write_all(json.as_bytes())
            .await
            .map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
        // Tokio file writes complete in the background; flush to be sure the
        // data reached the file before reporting success.
        file.flush()
            .await
            .map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
    }

    #[cfg(not(unix))]
    {
        tokio::fs::write(cache_file, json)
            .await
            .map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
    }

    tracing::debug!("Token cached to: {}", cache_file.display());
    Ok(())
}
/// Caches a token returned by the shared token-exchange client by converting
/// it to this module's own `TokenResponse` first.
///
/// # Errors
///
/// Propagates any failure from `cache_token`.
async fn cache_token_from_response(
    &self,
    token: &crate::client::auth::TokenResponse,
    cache_file: &PathBuf,
) -> Result<()> {
    let access_token = token.access_token.clone();
    let refresh_token = token.refresh_token.clone();
    let token_type = token.token_type.clone();

    let converted = TokenResponse {
        access_token,
        refresh_token,
        expires_in: token.expires_in,
        token_type,
    };

    self.cache_token(&converted, cache_file).await
}
/// Obtains an access token and returns an HTTP middleware chain that carries
/// it as a bearer token.
///
/// # Errors
///
/// Propagates any failure from `get_access_token`.
pub async fn create_middleware_chain(&self) -> Result<Arc<HttpMiddlewareChain>> {
    let access_token = self.get_access_token().await?;

    // Never log token material: even a prefix is sensitive, and slicing a
    // `String` by byte offset can panic on a non-ASCII boundary.
    tracing::debug!(
        "Creating OAuth middleware with token ({} bytes)",
        access_token.len()
    );

    let bearer_token = BearerToken::new(access_token);
    let oauth_middleware = OAuthClientMiddleware::new(bearer_token);

    let mut chain = HttpMiddlewareChain::new();
    chain.add(Arc::new(oauth_middleware));
    tracing::info!("OAuth middleware added to chain");

    Ok(Arc::new(chain))
}
}
/// Returns the default token cache location: `.pmcp/oauth-tokens.json` under
/// the user's home directory, or under `.` when no home directory is found.
pub fn default_cache_path() -> PathBuf {
    let base_dir = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    base_dir.join(".pmcp").join("oauth-tokens.json")
}
/// Convenience wrapper: builds an [`OAuthHelper`] from `config` and returns its
/// bearer-token middleware chain.
///
/// # Errors
///
/// Fails when the helper cannot be constructed or no access token can be
/// obtained.
pub async fn create_oauth_middleware(config: OAuthConfig) -> Result<Arc<HttpMiddlewareChain>> {
    match OAuthHelper::new(config) {
        Ok(helper) => helper.create_middleware_chain().await,
        Err(e) => Err(e),
    }
}