use anyhow::{Context, Result};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use colored::*;
use pmcp::client::auth::{OidcDiscoveryClient, TokenExchangeClient};
use pmcp::client::http_middleware::HttpMiddlewareChain;
use pmcp::client::oauth_middleware::{BearerToken, OAuthClientMiddleware};
use pmcp::server::auth::oauth2::OidcDiscoveryMetadata;
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::time::sleep;
use url::Url;
/// Settings for [`OAuthHelper`]: where to discover OAuth metadata, which
/// client and scopes to request, and where to cache tokens.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
/// Issuer URL used for OIDC discovery when `mcp_server_url` is not set.
pub issuer: Option<String>,
/// MCP server URL. When set, discovery runs against its origin (scheme,
/// host and non-default port), and this takes precedence over `issuer`.
pub mcp_server_url: Option<String>,
/// OAuth client identifier sent with every authorization and token request.
pub client_id: String,
/// Requested scopes, sent space-separated in the `scope` parameter.
pub scopes: Vec<String>,
/// Token cache file. `None` disables reading and writing cached tokens.
pub cache_file: Option<PathBuf>,
/// Port of the local callback server. The server is bound on 127.0.0.1 and
/// receives the redirect at `http://localhost:{port}/callback`.
pub redirect_port: u16,
}
/// On-disk token cache format, stored as pretty-printed JSON.
#[derive(Debug, Serialize, Deserialize)]
struct TokenCache {
/// Bearer token returned by the token endpoint.
access_token: String,
/// Refresh token, if the server issued one. Used to renew an expired token.
refresh_token: Option<String>,
/// Expiry as Unix time in seconds (time of caching plus `expires_in`).
/// `None` when the server did not report a lifetime.
expires_at: Option<u64>,
/// Scopes configured when the token was cached.
/// NOTE(review): recorded, but not compared against the current
/// configuration when the cache is loaded.
scopes: Vec<String>,
}
/// Response of the device authorization endpoint (RFC 8628 device code flow).
#[derive(Debug, Deserialize)]
struct DeviceAuthResponse {
    /// Opaque code the client sends when polling the token endpoint.
    device_code: String,
    /// Short code the user enters at `verification_uri`.
    user_code: String,
    /// URL the user visits to enter `user_code`.
    ///
    /// Some providers (notably Google) name this field `verification_url`.
    /// That spelling is accepted as an alias so their responses still parse.
    #[serde(alias = "verification_url")]
    verification_uri: String,
    /// Optional URL that already embeds `user_code`.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device and user codes, in seconds.
    expires_in: u64,
    /// Minimum polling interval in seconds, if the server specifies one.
    interval: Option<u64>,
}
/// Token endpoint response, parsed from device-code polling and refresh
/// requests (and converted from the pmcp exchange response before caching).
#[derive(Debug, Deserialize)]
// `token_type` is populated but never read in the visible code; this silences
// the resulting dead-code warning.
#[allow(dead_code)]
struct TokenResponse {
/// Bearer token to send on requests.
access_token: String,
/// Absent when the server does not issue a refresh token.
#[serde(default)]
refresh_token: Option<String>,
/// Token lifetime in seconds, if reported.
#[serde(default)]
expires_in: Option<u64>,
/// Token type reported by the server. Currently unused.
token_type: String,
}
/// Obtains OAuth access tokens for the MCP tester.
///
/// It reuses or refreshes cached tokens. Otherwise it runs the authorization
/// code flow with PKCE (S256), falling back to the device code flow.
pub struct OAuthHelper {
/// Discovery, client and cache settings.
config: OAuthConfig,
/// HTTP client (30 s timeout) used for device-code and refresh requests.
client: reqwest::Client,
}
impl OAuthHelper {
pub fn new(config: OAuthConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.context("Failed to create HTTP client")?;
Ok(Self { config, client })
}
fn extract_base_url(mcp_url: &str) -> Result<String> {
let parsed = Url::parse(mcp_url).context("Invalid MCP server URL")?;
let mut base = format!("{}://{}", parsed.scheme(), parsed.host_str().unwrap_or(""));
if let Some(port) = parsed.port() {
let is_default_port = (parsed.scheme() == "https" && port == 443)
|| (parsed.scheme() == "http" && port == 80);
if !is_default_port {
base.push_str(&format!(":{}", port));
}
}
Ok(base)
}
async fn discover_metadata(&self, mcp_url: &str) -> Result<OidcDiscoveryMetadata> {
let base_url = Self::extract_base_url(mcp_url)?;
eprintln!(
"{}",
format!("Discovering OAuth configuration from {}...", base_url).cyan()
);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(&base_url).await {
Ok(metadata) => {
eprintln!("{}", "✓ OAuth discovery successful!".green());
eprintln!(" Issuer: {}", metadata.issuer.dimmed());
if let Some(ref device_endpoint) = metadata.device_authorization_endpoint {
eprintln!(" Device endpoint: {}", device_endpoint.dimmed());
}
Ok(metadata)
},
Err(e) => {
anyhow::bail!(
"Failed to discover OAuth configuration at {}: {}\n\
\n\
Please provide --oauth-issuer explicitly, or ensure the server\n\
exposes OAuth metadata at {}/.well-known/openid-configuration",
base_url,
e,
base_url
)
},
}
}
async fn get_metadata(&self) -> Result<OidcDiscoveryMetadata> {
if let Some(ref mcp_url) = self.config.mcp_server_url {
self.discover_metadata(mcp_url).await
} else if let Some(ref issuer) = self.config.issuer {
eprintln!(
"{}",
format!("Discovering OAuth configuration from {}...", issuer).cyan()
);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(issuer).await {
Ok(metadata) => {
eprintln!("{}", "✓ OAuth discovery successful!".green());
Ok(metadata)
},
Err(e) => {
anyhow::bail!(
"Failed to discover OAuth configuration from issuer {}: {}\n\
\n\
Please ensure the issuer URL exposes OAuth metadata at\n\
{}/.well-known/openid-configuration",
issuer,
e,
issuer
)
},
}
} else {
anyhow::bail!(
"Either oauth_issuer or mcp_server_url must be provided for OAuth authentication"
)
}
}
pub async fn get_access_token(&self) -> Result<String> {
if let Some(ref cache_file) = self.config.cache_file {
if let Ok(cached) = self.load_cached_token(cache_file).await {
if let Some(expires_at) = cached.expires_at {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if now < expires_at {
eprintln!("{}", "✓ Using cached OAuth token".green());
return Ok(cached.access_token);
}
}
if let Some(refresh_token) = cached.refresh_token {
eprintln!("{}", "Refreshing OAuth token...".yellow());
if let Ok(new_token) = self.refresh_token(&refresh_token).await {
self.cache_token(&new_token, cache_file).await?;
return Ok(new_token.access_token);
}
}
}
}
eprintln!("{}", "No cached token found. Starting OAuth flow...".cyan());
let metadata = self.get_metadata().await?;
match self.authorization_code_flow(&metadata).await {
Ok(token) => return Ok(token),
Err(e) => {
eprintln!(
"{}",
format!("Authorization code flow failed: {}", e).yellow()
);
if metadata.device_authorization_endpoint.is_some() {
eprintln!("{}", "Trying device code flow...".cyan());
return self.device_code_flow_with_metadata(&metadata).await;
} else {
anyhow::bail!(
"No supported OAuth flow available.\n\
\n\
The server must support either:\n\
- Authorization code flow (authorization_endpoint), or\n\
- Device code flow (device_authorization_endpoint)"
);
}
},
}
}
fn generate_code_verifier() -> String {
let random_bytes: [u8; 32] = rand::rng().random();
URL_SAFE_NO_PAD.encode(random_bytes)
}
fn generate_code_challenge(verifier: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let hash = hasher.finalize();
URL_SAFE_NO_PAD.encode(hash)
}
async fn authorization_code_flow(&self, metadata: &OidcDiscoveryMetadata) -> Result<String> {
eprintln!(
"{}",
"Starting OAuth authorization code flow...".cyan().bold()
);
eprintln!();
let code_verifier = Self::generate_code_verifier();
let code_challenge = Self::generate_code_challenge(&code_verifier);
let redirect_port = self.config.redirect_port;
let redirect_uri = format!("http://localhost:{}/callback", redirect_port);
let listener = TcpListener::bind(format!("127.0.0.1:{}", redirect_port))
.await
.with_context(|| {
format!(
"Failed to bind to localhost:{}.\n\
\n\
This port may already be in use. Try a different port with:\n\
--oauth-redirect-port PORT",
redirect_port
)
})?;
eprintln!(
"{}",
format!("Local callback server listening on port {}", redirect_port).dimmed()
);
eprintln!();
eprintln!(
"{}",
format!("IMPORTANT: Ensure the redirect URI is registered in your OAuth provider:")
.yellow()
.bold()
);
eprintln!("{}", format!(" {}", redirect_uri).yellow());
eprintln!();
let mut auth_url = Url::parse(&metadata.authorization_endpoint)
.context("Invalid authorization endpoint")?;
auth_url
.query_pairs_mut()
.append_pair("client_id", &self.config.client_id)
.append_pair("response_type", "code")
.append_pair("redirect_uri", &redirect_uri)
.append_pair("scope", &self.config.scopes.join(" "))
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256")
.append_pair("state", &Self::generate_code_verifier());
eprintln!();
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!("{}", " OAuth Authentication Required".cyan().bold());
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!();
eprintln!(" {}", "Opening browser for authentication...".bold());
eprintln!();
eprintln!(
" {}",
"If the browser doesn't open automatically, visit:".dimmed()
);
eprintln!(" {}", auth_url.as_str().yellow());
eprintln!();
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!();
if let Err(e) = webbrowser::open(auth_url.as_str()) {
eprintln!(
"{}",
format!("Warning: Failed to open browser: {}", e).yellow()
);
eprintln!("{}", "Please open the URL above manually.".yellow());
}
let (tx, rx) = oneshot::channel();
let callback_task = tokio::spawn(async move {
if let Ok((mut stream, _)) = listener.accept().await {
let mut reader = BufReader::new(&mut stream);
let mut request_line = String::new();
if reader.read_line(&mut request_line).await.is_ok() {
if let Some(path) = request_line.split_whitespace().nth(1) {
if let Ok(callback_url) = Url::parse(&format!("http://localhost{}", path)) {
let code = callback_url
.query_pairs()
.find(|(key, _)| key == "code")
.map(|(_, value)| value.to_string());
let response = if code.is_some() {
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: green;'>✓ Authentication Successful!</h1>\
<p>You can close this window and return to the terminal.</p>\
</body></html>"
} else {
"HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: red;'>✗ Authentication Failed</h1>\
<p>No authorization code received. Please try again.</p>\
</body></html>"
};
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.flush().await;
if let Some(code) = code {
let _ = tx.send(code);
}
}
}
}
}
});
print!(" Waiting for authorization...");
std::io::Write::flush(&mut std::io::stdout()).unwrap();
let authorization_code = tokio::time::timeout(Duration::from_secs(300), rx)
.await
.context("Timeout waiting for OAuth callback (5 minutes)")??;
callback_task.abort();
eprintln!("\r{}", "✓ Authorization code received! ".green());
eprintln!();
eprintln!(
"{}",
"Exchanging authorization code for access token...".dimmed()
);
let token_exchange = TokenExchangeClient::new();
let token_response = token_exchange
.exchange_code(
&metadata.token_endpoint,
&authorization_code,
&self.config.client_id,
None, &redirect_uri,
Some(&code_verifier), )
.await
.context("Failed to exchange authorization code for token")?;
eprintln!("{}", "✓ Authentication successful!".green().bold());
eprintln!();
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token_from_response(&token_response, cache_file)
.await?;
}
Ok(token_response.access_token)
}
async fn device_code_flow_with_metadata(
&self,
metadata: &OidcDiscoveryMetadata,
) -> Result<String> {
eprintln!("{}", "Starting OAuth device code flow...".cyan().bold());
eprintln!();
let device_auth_endpoint =
metadata
.device_authorization_endpoint
.as_ref()
.ok_or_else(|| {
anyhow::anyhow!(
"Device authorization endpoint not found in OAuth metadata.\n\
\n\
The OAuth server does not support device code flow (RFC 8628)."
)
})?;
self.device_code_flow_internal(metadata, device_auth_endpoint)
.await
}
async fn device_code_flow_internal(
&self,
metadata: &OidcDiscoveryMetadata,
device_auth_endpoint: &str,
) -> Result<String> {
let scope = self.config.scopes.join(" ");
let response = self
.client
.post(device_auth_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("scope", &scope),
])
.send()
.await
.context("Failed to request device code")?;
if !response.status().is_success() {
anyhow::bail!(
"Device authorization failed: {}",
response.text().await.unwrap_or_default()
);
}
let device_auth: DeviceAuthResponse = response
.json()
.await
.context("Failed to parse device authorization response")?;
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!("{}", " OAuth Authentication Required".cyan().bold());
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!();
eprintln!(
" {} {}",
"1. Visit:".bold(),
device_auth.verification_uri.yellow()
);
eprintln!(
" {} {}",
"2. Enter code:".bold(),
device_auth.user_code.green().bold()
);
if let Some(complete_uri) = &device_auth.verification_uri_complete {
eprintln!();
eprintln!(" {} Or scan this URL:", "Shortcut:".bold());
eprintln!(" {}", complete_uri.yellow());
}
eprintln!();
eprintln!(
"{}",
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━".cyan()
);
eprintln!();
let poll_interval = Duration::from_secs(device_auth.interval.unwrap_or(5));
let token_endpoint = &metadata.token_endpoint;
let expires_at = SystemTime::now() + Duration::from_secs(device_auth.expires_in);
loop {
if SystemTime::now() > expires_at {
anyhow::bail!("Device code expired. Please try again.");
}
sleep(poll_interval).await;
print!(" Waiting for authorization...\r");
let _ = std::io::Write::flush(&mut std::io::stdout());
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("device_code", &device_auth.device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
])
.send()
.await
.context("Failed to poll for token")?;
let status = response.status();
let body = response.text().await?;
if status.is_success() {
let token_response: TokenResponse =
serde_json::from_str(&body).context("Failed to parse token response")?;
eprintln!("{}", "✓ Authentication successful!".green().bold());
eprintln!();
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token(&token_response, cache_file).await?;
}
return Ok(token_response.access_token);
}
if let Ok(error) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(error_code) = error.get("error").and_then(|e| e.as_str()) {
match error_code {
"authorization_pending" => continue,
"slow_down" => {
sleep(poll_interval).await;
continue;
},
"access_denied" => {
anyhow::bail!("User denied authorization");
},
"expired_token" => {
anyhow::bail!("Device code expired");
},
_ => {
anyhow::bail!("OAuth error: {}", error_code);
},
}
}
}
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let metadata = self.get_metadata().await?;
let token_endpoint = &metadata.token_endpoint;
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.await
.context("Failed to refresh token")?;
if !response.status().is_success() {
anyhow::bail!(
"Token refresh failed: {}",
response.text().await.unwrap_or_default()
);
}
response
.json()
.await
.context("Failed to parse token response")
}
async fn load_cached_token(&self, cache_file: &PathBuf) -> Result<TokenCache> {
let content = tokio::fs::read_to_string(cache_file)
.await
.context("Failed to read token cache")?;
serde_json::from_str(&content).context("Failed to parse token cache")
}
async fn cache_token(&self, token: &TokenResponse, cache_file: &PathBuf) -> Result<()> {
let expires_at = token.expires_in.map(|secs| {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ secs
});
let cache = TokenCache {
access_token: token.access_token.clone(),
refresh_token: token.refresh_token.clone(),
expires_at,
scopes: self.config.scopes.clone(),
};
if let Some(parent) = cache_file.parent() {
tokio::fs::create_dir_all(parent)
.await
.context("Failed to create cache directory")?;
}
let json = serde_json::to_string_pretty(&cache).context("Failed to serialize cache")?;
tokio::fs::write(cache_file, json)
.await
.context("Failed to write token cache")?;
eprintln!(
"{}",
format!("Token cached to: {}", cache_file.display()).dimmed()
);
Ok(())
}
async fn cache_token_from_response(
&self,
token: &pmcp::client::auth::TokenResponse,
cache_file: &PathBuf,
) -> Result<()> {
let internal_token = TokenResponse {
access_token: token.access_token.clone(),
refresh_token: token.refresh_token.clone(),
expires_in: token.expires_in,
token_type: token.token_type.clone(),
};
self.cache_token(&internal_token, cache_file).await
}
pub async fn create_middleware_chain(&self) -> Result<Arc<HttpMiddlewareChain>> {
let access_token = self.get_access_token().await?;
eprintln!(
"{}",
format!(
"Creating OAuth middleware with token: {}...",
&access_token[..20]
)
.dimmed()
);
let bearer_token = BearerToken::new(access_token);
let oauth_middleware = OAuthClientMiddleware::new(bearer_token);
let mut chain = HttpMiddlewareChain::new();
chain.add(Arc::new(oauth_middleware));
eprintln!("{}", "OAuth middleware added to chain".green());
Ok(Arc::new(chain))
}
}
/// Default location of the OAuth token cache: `~/.mcp-tester/tokens.json`.
///
/// When the home directory cannot be determined, the path is relative to the
/// current directory instead (`./.mcp-tester/tokens.json`).
pub fn default_cache_path() -> PathBuf {
    let base = match dirs::home_dir() {
        Some(home) => home,
        None => PathBuf::from("."),
    };
    base.join(".mcp-tester").join("tokens.json")
}