use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::time::sleep;
use url::Url;
use crate::client::auth::{OidcDiscoveryClient, TokenExchangeClient};
use crate::client::http_middleware::HttpMiddlewareChain;
use crate::client::oauth_middleware::{BearerToken, OAuthClientMiddleware};
use crate::error::{Error, Result};
use crate::server::auth::oauth2::OidcDiscoveryMetadata;
pub use crate::server::auth::provider::{DcrRequest, DcrResponse};
/// Configuration for [`OAuthHelper`].
///
/// At least one of `issuer` or `mcp_server_url` must be set for discovery to succeed.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// OAuth issuer URL. Used for metadata discovery only when `mcp_server_url` is `None`.
    pub issuer: Option<String>,
    /// MCP server URL. When set, discovery runs against its `scheme://host[:port]` origin
    /// and takes precedence over `issuer`.
    pub mcp_server_url: Option<String>,
    /// Pre-registered client id. When set, Dynamic Client Registration is skipped.
    pub client_id: Option<String>,
    /// `client_name` sent during Dynamic Client Registration (`"pmcp-sdk"` when `None`).
    pub client_name: Option<String>,
    /// Whether Dynamic Client Registration may run when `client_id` is `None`.
    pub dcr_enabled: bool,
    /// Requested scopes, sent space-separated in the authorization request.
    pub scopes: Vec<String>,
    /// Path of the on-disk token cache; `None` disables caching.
    pub cache_file: Option<PathBuf>,
    /// Port of the local `127.0.0.1` listener that receives the authorization redirect.
    pub redirect_port: u16,
}
impl Default for OAuthConfig {
fn default() -> Self {
Self {
issuer: None,
mcp_server_url: None,
client_id: None,
client_name: None,
dcr_enabled: true,
scopes: Vec::new(),
cache_file: None,
redirect_port: 8080,
}
}
}
/// Outcome of [`OAuthHelper::authorize_with_details`].
#[derive(Debug, Clone)]
pub struct AuthorizationResult {
    /// Bearer access token.
    pub access_token: String,
    /// Refresh token, if the server issued one.
    pub refresh_token: Option<String>,
    /// Absolute expiry as UNIX seconds, if the server reported `expires_in`.
    pub expires_at: Option<u64>,
    /// Scopes reported by the server, or the requested scopes when it did not report any.
    pub scopes: Vec<String>,
    /// Configured issuer, or the issuer from discovered metadata when none was configured.
    pub issuer: Option<String>,
    /// Client id used for the flow (configured, or issued by Dynamic Client Registration).
    pub client_id: String,
}
/// Token record persisted as JSON at `OAuthConfig::cache_file`.
///
/// NOTE(review): the record is not keyed by issuer or MCP server URL. A single cache file
/// shared across servers could therefore hand a token minted for one server to another.
/// Confirm that callers use a distinct cache file per server.
#[derive(Debug, Serialize, Deserialize)]
struct TokenCache {
    /// Bearer access token.
    access_token: String,
    /// Refresh token, if one was issued.
    refresh_token: Option<String>,
    /// Absolute expiry as UNIX seconds, computed from `expires_in` when the cache is written.
    expires_at: Option<u64>,
    /// Scopes that were *requested* when the token was cached (not necessarily granted).
    scopes: Vec<String>,
}
/// Device authorization response (RFC 8628 §3.2).
#[derive(Debug, Deserialize)]
struct DeviceAuthResponse {
    /// Code the client presents when polling the token endpoint.
    device_code: String,
    /// Code the user types in at the verification URI.
    user_code: String,
    /// Where the user completes authorization.
    ///
    /// Some providers (notably Google) send this under the non-standard name
    /// `verification_url`, so that spelling is accepted as an alias.
    #[serde(alias = "verification_url")]
    verification_uri: String,
    /// Optional URI that already embeds the user code.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of `device_code` / `user_code`, in seconds.
    expires_in: u64,
    /// Minimum polling interval in seconds. RFC 8628 says clients use 5 when it is absent.
    interval: Option<u64>,
}
/// Token endpoint response, as parsed by the device-code and refresh flows in this module.
#[derive(Debug, Deserialize)]
// `token_type` is deserialized for completeness but is not read on every code path,
// hence the dead_code allowance.
#[allow(dead_code)]
struct TokenResponse {
    /// Bearer access token.
    access_token: String,
    /// Refresh token, if the server issued one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Lifetime of `access_token` in seconds, if reported.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Token type reported by the server (e.g. `Bearer`).
    token_type: String,
}
/// Drives OAuth metadata discovery, client registration, interactive authorization and
/// token caching for an MCP client.
#[derive(Debug)]
pub struct OAuthHelper {
    /// User-supplied settings.
    config: OAuthConfig,
    /// HTTP client used for DCR, device-code and refresh requests (30-second timeout).
    client: reqwest::Client,
}
impl OAuthHelper {
pub fn new(config: OAuthConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::internal(format!("Failed to create HTTP client: {e}")))?;
Ok(Self { config, client })
}
async fn do_dynamic_client_registration(
&self,
registration_endpoint: &str,
) -> Result<crate::server::auth::provider::DcrResponse> {
let parsed = Url::parse(registration_endpoint)
.map_err(|e| Error::internal(format!("Invalid registration_endpoint URL: {e}")))?;
let scheme_ok = parsed.scheme() == "https"
|| (parsed.scheme() == "http"
&& matches!(
parsed.host_str(),
Some("localhost") | Some("127.0.0.1") | Some("::1") | Some("[::1]")
));
if !scheme_ok {
return Err(Error::internal(format!(
"registration_endpoint must be https:// (or http://localhost, \
http://127.0.0.1, http://[::1]) — got {}",
registration_endpoint
)));
}
let client_name = self
.config
.client_name
.clone()
.unwrap_or_else(|| "pmcp-sdk".to_string());
let redirect_uri = format!("http://127.0.0.1:{}/callback", self.config.redirect_port);
let request = crate::server::auth::provider::DcrRequest {
redirect_uris: vec![redirect_uri],
client_name: Some(client_name),
client_uri: None,
logo_uri: None,
contacts: vec![],
token_endpoint_auth_method: Some("none".to_string()),
grant_types: vec!["authorization_code".to_string()],
response_types: vec!["code".to_string()],
scope: None,
software_id: None,
software_version: None,
extra: Default::default(),
};
let response = self
.client
.post(registration_endpoint)
.json(&request)
.send()
.await
.map_err(|e| Error::internal(format!("DCR request failed: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Error::internal(format!(
"DCR failed ({}): {}\n\n\
The server rejected dynamic client registration. Pass a \
pre-registered client_id to skip DCR.",
status, body
)));
}
const MAX_DCR_RESPONSE_BYTES: usize = 1_048_576; let bytes = response
.bytes()
.await
.map_err(|e| Error::internal(format!("Failed to read DCR response body: {e}")))?;
if bytes.len() > MAX_DCR_RESPONSE_BYTES {
return Err(Error::internal(format!(
"DCR response exceeds {} byte cap (got {} bytes) — refusing to parse",
MAX_DCR_RESPONSE_BYTES,
bytes.len()
)));
}
serde_json::from_slice::<crate::server::auth::provider::DcrResponse>(&bytes)
.map_err(|e| Error::internal(format!("Failed to parse DCR response: {e}")))
}
async fn resolve_client_id_for_flow(&self, metadata: &OidcDiscoveryMetadata) -> Result<String> {
if let Some(ref id) = self.config.client_id {
return Ok(id.clone());
}
if !self.config.dcr_enabled {
return Err(Error::internal(
"no client_id configured and dcr_enabled is false — \
provide OAuthConfig::client_id or enable dcr_enabled"
.to_string(),
));
}
match metadata.registration_endpoint.as_ref() {
Some(endpoint) => {
tracing::info!("Performing Dynamic Client Registration at {}", endpoint);
let response = self.do_dynamic_client_registration(endpoint).await?;
tracing::info!("DCR succeeded — issued client_id");
Ok(response.client_id)
},
None => Err(Error::internal(
"server does not support DCR — pass a pre-registered client_id".to_string(),
)),
}
}
#[doc(hidden)]
#[cfg(any(test, feature = "oauth"))]
pub async fn test_resolve_client_id_from_discovery(&self) -> Result<String> {
let metadata = self.get_metadata().await?;
self.resolve_client_id_for_flow(&metadata).await
}
fn extract_base_url(mcp_url: &str) -> Result<String> {
let parsed = Url::parse(mcp_url)
.map_err(|e| Error::internal(format!("Invalid MCP server URL: {e}")))?;
let mut base = format!("{}://{}", parsed.scheme(), parsed.host_str().unwrap_or(""));
if let Some(port) = parsed.port() {
let is_default_port = (parsed.scheme() == "https" && port == 443)
|| (parsed.scheme() == "http" && port == 80);
if !is_default_port {
base.push_str(&format!(":{}", port));
}
}
Ok(base)
}
async fn discover_metadata(&self, mcp_url: &str) -> Result<OidcDiscoveryMetadata> {
let base_url = Self::extract_base_url(mcp_url)?;
tracing::info!("Discovering OAuth configuration from {}...", base_url);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(&base_url).await {
Ok(metadata) => {
tracing::info!("OAuth discovery successful");
tracing::debug!("Issuer: {}", metadata.issuer);
if let Some(ref device_endpoint) = metadata.device_authorization_endpoint {
tracing::debug!("Device endpoint: {}", device_endpoint);
}
Ok(metadata)
},
Err(e) => Err(Error::internal(format!(
"Failed to discover OAuth configuration at {}: {}\n\
\n\
Please provide --oauth-issuer explicitly, or ensure the server\n\
exposes OAuth metadata at {}/.well-known/openid-configuration",
base_url, e, base_url
))),
}
}
async fn get_metadata(&self) -> Result<OidcDiscoveryMetadata> {
if let Some(ref mcp_url) = self.config.mcp_server_url {
self.discover_metadata(mcp_url).await
} else if let Some(ref issuer) = self.config.issuer {
tracing::info!("Discovering OAuth configuration from {}...", issuer);
let discovery_client = OidcDiscoveryClient::new();
match discovery_client.discover(issuer).await {
Ok(metadata) => {
tracing::info!("OAuth discovery successful");
Ok(metadata)
},
Err(e) => Err(Error::internal(format!(
"Failed to discover OAuth configuration from issuer {}: {}\n\
\n\
Please ensure the issuer URL exposes OAuth metadata at\n\
{}/.well-known/openid-configuration",
issuer, e, issuer
))),
}
} else {
Err(Error::internal(
"Either oauth_issuer or mcp_server_url must be provided for OAuth authentication"
.to_string(),
))
}
}
pub async fn get_access_token(&self) -> Result<String> {
if let Some(ref cache_file) = self.config.cache_file {
if let Ok(cached) = self.load_cached_token(cache_file).await {
if let Some(expires_at) = cached.expires_at {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if now < expires_at {
tracing::info!("Using cached OAuth token");
return Ok(cached.access_token);
}
}
if let Some(refresh_token) = cached.refresh_token {
tracing::warn!("OAuth token expired, refreshing...");
if let Ok(new_token) = self.refresh_token(&refresh_token).await {
self.cache_token(&new_token, cache_file).await?;
return Ok(new_token.access_token);
}
}
}
}
tracing::info!("No cached token found, starting OAuth flow...");
let metadata = self.get_metadata().await?;
match self.authorization_code_flow(&metadata).await {
Ok(token) => Ok(token),
Err(e) => {
tracing::warn!("Authorization code flow failed: {}", e);
if metadata.device_authorization_endpoint.is_some() {
tracing::info!("Trying device code flow...");
return self.device_code_flow_with_metadata(&metadata).await;
}
Err(Error::internal(
"No supported OAuth flow available.\n\
\n\
The server must support either:\n\
- Authorization code flow (authorization_endpoint), or\n\
- Device code flow (device_authorization_endpoint)"
.to_string(),
))
},
}
}
pub async fn authorize_with_details(&self) -> Result<AuthorizationResult> {
let metadata = self.get_metadata().await?;
let effective_issuer = self
.config
.issuer
.clone()
.or_else(|| Some(metadata.issuer.clone()));
match self.authorization_code_flow_inner(&metadata).await {
Ok((token_response, resolved_client_id)) => Ok(Self::build_auth_result(
token_response,
resolved_client_id,
effective_issuer,
&self.config.scopes,
)),
Err(e) => {
tracing::warn!("Authorization code flow failed: {}", e);
if metadata.device_authorization_endpoint.is_some() {
tracing::info!(
"Trying device code flow (refresh_token may be None per RFC 8628)..."
);
let resolved_client_id = self.resolve_client_id_for_flow(&metadata).await?;
let access_token = self.device_code_flow_with_metadata(&metadata).await?;
return Ok(AuthorizationResult {
access_token,
refresh_token: None,
expires_at: None,
scopes: self.config.scopes.clone(),
issuer: effective_issuer,
client_id: resolved_client_id,
});
}
Err(Error::internal(
"No supported OAuth flow available.\n\
\n\
The server must support either:\n\
- Authorization code flow (authorization_endpoint), or\n\
- Device code flow (device_authorization_endpoint)"
.to_string(),
))
},
}
}
fn build_auth_result(
token_response: crate::client::auth::TokenResponse,
client_id: String,
effective_issuer: Option<String>,
requested_scopes: &[String],
) -> AuthorizationResult {
let expires_at = token_response.expires_in.map(|ttl| {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
now + ttl
});
let granted_scopes = token_response
.scope
.as_deref()
.map(|s| s.split_whitespace().map(String::from).collect::<Vec<_>>())
.unwrap_or_else(|| requested_scopes.to_vec());
AuthorizationResult {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at,
scopes: granted_scopes,
issuer: effective_issuer,
client_id,
}
}
fn generate_code_verifier() -> String {
let random_bytes: [u8; 32] = rand::rng().random();
URL_SAFE_NO_PAD.encode(random_bytes)
}
fn generate_code_challenge(verifier: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let hash = hasher.finalize();
URL_SAFE_NO_PAD.encode(hash)
}
async fn authorization_code_flow(&self, metadata: &OidcDiscoveryMetadata) -> Result<String> {
let (token_response, _client_id) = self.authorization_code_flow_inner(metadata).await?;
Ok(token_response.access_token)
}
async fn authorization_code_flow_inner(
&self,
metadata: &OidcDiscoveryMetadata,
) -> Result<(crate::client::auth::TokenResponse, String)> {
tracing::info!("Starting OAuth authorization code flow...");
let resolved_client_id = self.resolve_client_id_for_flow(metadata).await?;
let code_verifier = Self::generate_code_verifier();
let code_challenge = Self::generate_code_challenge(&code_verifier);
let redirect_port = self.config.redirect_port;
let redirect_uri = format!("http://127.0.0.1:{}/callback", redirect_port);
let listener = TcpListener::bind(format!("127.0.0.1:{}", redirect_port))
.await
.map_err(|e| {
Error::internal(format!(
"Failed to bind to 127.0.0.1:{}.\n\
\n\
This port may already be in use. Try a different port with:\n\
--oauth-redirect-port PORT\n\
\n\
Error: {e}",
redirect_port
))
})?;
tracing::debug!("Local callback server listening on port {}", redirect_port);
tracing::warn!(
"Ensure the redirect URI is registered in your OAuth provider: {}",
redirect_uri
);
let mut auth_url = Url::parse(&metadata.authorization_endpoint)
.map_err(|e| Error::internal(format!("Invalid authorization endpoint: {e}")))?;
auth_url
.query_pairs_mut()
.append_pair("client_id", &resolved_client_id)
.append_pair("response_type", "code")
.append_pair("redirect_uri", &redirect_uri)
.append_pair("scope", &self.config.scopes.join(" "))
.append_pair("code_challenge", &code_challenge)
.append_pair("code_challenge_method", "S256")
.append_pair("state", &Self::generate_code_verifier());
tracing::info!("OAuth Authentication Required");
tracing::info!("Opening browser for authentication...");
tracing::info!("If the browser doesn't open, visit: {}", auth_url.as_str());
if let Err(e) = webbrowser::open(auth_url.as_str()) {
tracing::warn!(
"Failed to open browser: {}. Please open the URL manually.",
e
);
}
let (tx, rx) = oneshot::channel();
let callback_task = tokio::spawn(async move {
if let Ok((mut stream, _)) = listener.accept().await {
let mut reader = BufReader::new(&mut stream);
let mut request_line = String::new();
if reader.read_line(&mut request_line).await.is_ok() {
if let Some(path) = request_line.split_whitespace().nth(1) {
if let Ok(callback_url) = Url::parse(&format!("http://localhost{}", path)) {
let code = callback_url
.query_pairs()
.find(|(key, _)| key == "code")
.map(|(_, value)| value.to_string());
let response = if code.is_some() {
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: green;'>Authentication Successful!</h1>\
<p>You can close this window and return to the terminal.</p>\
</body></html>"
} else {
"HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
<html><body style='font-family: sans-serif; text-align: center; padding: 50px;'>\
<h1 style='color: red;'>Authentication Failed</h1>\
<p>No authorization code received. Please try again.</p>\
</body></html>"
};
let _ = stream.write_all(response.as_bytes()).await;
let _ = stream.flush().await;
if let Some(code) = code {
let _ = tx.send(code);
}
}
}
}
}
});
tracing::info!("Waiting for authorization...");
let authorization_code = tokio::time::timeout(Duration::from_secs(300), rx)
.await
.map_err(|_| {
Error::internal("Timeout waiting for OAuth callback (5 minutes)".to_string())
})?
.map_err(|e| Error::internal(format!("OAuth callback channel error: {e}")))?;
callback_task.abort();
tracing::info!("Authorization code received");
tracing::debug!("Exchanging authorization code for access token...");
let token_exchange = TokenExchangeClient::new();
let token_response = token_exchange
.exchange_code(
&metadata.token_endpoint,
&authorization_code,
&resolved_client_id,
None, &redirect_uri,
Some(&code_verifier), )
.await
.map_err(|e| {
Error::internal(format!(
"Failed to exchange authorization code for token: {e}"
))
})?;
tracing::info!("Authentication successful");
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token_from_response(&token_response, cache_file)
.await?;
}
Ok((token_response, resolved_client_id))
}
async fn device_code_flow_with_metadata(
&self,
metadata: &OidcDiscoveryMetadata,
) -> Result<String> {
tracing::info!("Starting OAuth device code flow...");
let device_auth_endpoint =
metadata
.device_authorization_endpoint
.as_ref()
.ok_or_else(|| {
Error::internal(
"Device authorization endpoint not found in OAuth metadata.\n\
\n\
The OAuth server does not support device code flow (RFC 8628)."
.to_string(),
)
})?;
self.device_code_flow_internal(metadata, device_auth_endpoint)
.await
}
async fn device_code_flow_internal(
&self,
metadata: &OidcDiscoveryMetadata,
device_auth_endpoint: &str,
) -> Result<String> {
let resolved_client_id = self.resolve_client_id_for_flow(metadata).await?;
let scope = self.config.scopes.join(" ");
let response = self
.client
.post(device_auth_endpoint)
.form(&[
("client_id", resolved_client_id.as_str()),
("scope", &scope),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to request device code: {e}")))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"Device authorization failed: {}",
response.text().await.unwrap_or_default()
)));
}
let device_auth: DeviceAuthResponse = response.json().await.map_err(|e| {
Error::internal(format!(
"Failed to parse device authorization response: {e}"
))
})?;
tracing::info!("OAuth device code flow");
tracing::info!("1. Visit: {}", device_auth.verification_uri);
tracing::info!("2. Enter code: {}", device_auth.user_code);
if let Some(complete_uri) = &device_auth.verification_uri_complete {
tracing::info!("Or visit directly: {}", complete_uri);
}
let poll_interval = Duration::from_secs(device_auth.interval.unwrap_or(5));
let token_endpoint = &metadata.token_endpoint;
let expires_at = SystemTime::now() + Duration::from_secs(device_auth.expires_in);
loop {
if SystemTime::now() > expires_at {
return Err(Error::internal(
"Device code expired. Please try again.".to_string(),
));
}
sleep(poll_interval).await;
tracing::debug!("Polling for authorization...");
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", resolved_client_id.as_str()),
("device_code", &device_auth.device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to poll for token: {e}")))?;
let status = response.status();
let body = response
.text()
.await
.map_err(|e| Error::internal(format!("Failed to read token response body: {e}")))?;
if status.is_success() {
let token_response: TokenResponse = serde_json::from_str(&body)
.map_err(|e| Error::internal(format!("Failed to parse token response: {e}")))?;
tracing::info!("Authentication successful");
if let Some(ref cache_file) = self.config.cache_file {
self.cache_token(&token_response, cache_file).await?;
}
return Ok(token_response.access_token);
}
if let Ok(error) = serde_json::from_str::<serde_json::Value>(&body) {
if let Some(error_code) = error.get("error").and_then(|e| e.as_str()) {
match error_code {
"authorization_pending" => continue,
"slow_down" => {
sleep(poll_interval).await;
continue;
},
"access_denied" => {
return Err(Error::internal("User denied authorization".to_string()));
},
"expired_token" => {
return Err(Error::internal("Device code expired".to_string()));
},
_ => {
return Err(Error::internal(format!("OAuth error: {}", error_code)));
},
}
}
}
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let metadata = self.get_metadata().await?;
let token_endpoint = &metadata.token_endpoint;
let client_id = self.config.client_id.as_deref().ok_or_else(|| {
Error::internal("cannot refresh token without a cached client_id".to_string())
})?;
let response = self
.client
.post(token_endpoint)
.form(&[
("client_id", client_id),
("refresh_token", refresh_token),
("grant_type", "refresh_token"),
])
.send()
.await
.map_err(|e| Error::internal(format!("Failed to refresh token: {e}")))?;
if !response.status().is_success() {
return Err(Error::internal(format!(
"Token refresh failed: {}",
response.text().await.unwrap_or_default()
)));
}
response
.json()
.await
.map_err(|e| Error::internal(format!("Failed to parse token response: {e}")))
}
async fn load_cached_token(&self, cache_file: &PathBuf) -> Result<TokenCache> {
let content = tokio::fs::read_to_string(cache_file)
.await
.map_err(|e| Error::internal(format!("Failed to read token cache: {e}")))?;
serde_json::from_str(&content)
.map_err(|e| Error::internal(format!("Failed to parse token cache: {e}")))
}
async fn cache_token(&self, token: &TokenResponse, cache_file: &PathBuf) -> Result<()> {
let expires_at = token.expires_in.map(|secs| {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
+ secs
});
let cache = TokenCache {
access_token: token.access_token.clone(),
refresh_token: token.refresh_token.clone(),
expires_at,
scopes: self.config.scopes.clone(),
};
if let Some(parent) = cache_file.parent() {
tokio::fs::create_dir_all(parent)
.await
.map_err(|e| Error::internal(format!("Failed to create cache directory: {e}")))?;
}
let json = serde_json::to_string_pretty(&cache)
.map_err(|e| Error::internal(format!("Failed to serialize cache: {e}")))?;
tokio::fs::write(cache_file, json)
.await
.map_err(|e| Error::internal(format!("Failed to write token cache: {e}")))?;
tracing::debug!("Token cached to: {}", cache_file.display());
Ok(())
}
async fn cache_token_from_response(
&self,
token: &crate::client::auth::TokenResponse,
cache_file: &PathBuf,
) -> Result<()> {
let internal_token = TokenResponse {
access_token: token.access_token.clone(),
refresh_token: token.refresh_token.clone(),
expires_in: token.expires_in,
token_type: token.token_type.clone(),
};
self.cache_token(&internal_token, cache_file).await
}
pub async fn create_middleware_chain(&self) -> Result<Arc<HttpMiddlewareChain>> {
let access_token = self.get_access_token().await?;
tracing::debug!(
"Creating OAuth middleware with token: {}...",
&access_token[..access_token.len().min(20)]
);
let bearer_token = BearerToken::new(access_token);
let oauth_middleware = OAuthClientMiddleware::new(bearer_token);
let mut chain = HttpMiddlewareChain::new();
chain.add(Arc::new(oauth_middleware));
tracing::info!("OAuth middleware added to chain");
Ok(Arc::new(chain))
}
}
/// Default location of the OAuth token cache: `<home>/.pmcp/oauth-tokens.json`.
///
/// If the home directory cannot be determined, the path is `./.pmcp/oauth-tokens.json`,
/// relative to the current directory.
pub fn default_cache_path() -> PathBuf {
    let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
    home.join(".pmcp").join("oauth-tokens.json")
}
/// Convenience entry point: builds an [`OAuthHelper`] from `config` and returns its
/// bearer-token middleware chain.
///
/// # Errors
///
/// Propagates failures from helper construction and token acquisition.
pub async fn create_oauth_middleware(config: OAuthConfig) -> Result<Arc<HttpMiddlewareChain>> {
    OAuthHelper::new(config)?.create_middleware_chain().await
}
#[cfg(test)]
mod oauth_config_tests {
    //! Shape and default checks for `OAuthConfig` and the re-exported DCR types.
    use super::*;

    /// Defaults must leave DCR free to fire: no client id, DCR on, no client name.
    #[test]
    fn oauth_config_default_has_dcr_enabled_and_none_client_id() {
        let defaults = OAuthConfig::default();
        assert!(
            defaults.client_id.is_none(),
            "default client_id must be None for DCR auto-fire"
        );
        assert!(defaults.dcr_enabled, "default dcr_enabled must be true");
        assert!(defaults.client_name.is_none(), "default client_name is None");
    }

    /// Compile-time check that every public field can be set via a struct literal.
    #[test]
    fn oauth_config_struct_literal_with_some_client_id_compiles() {
        let _c = OAuthConfig {
            redirect_port: 8080,
            cache_file: None,
            scopes: vec![],
            dcr_enabled: false,
            client_name: None,
            client_id: Some("my-client".into()),
            mcp_server_url: Some("https://x.example".into()),
            issuer: None,
        };
    }

    /// Compile-time check that `DcrRequest` / `DcrResponse` are reachable through this module.
    #[test]
    fn dcr_types_are_reexported() {
        let _r: super::DcrRequest = super::DcrRequest {
            redirect_uris: vec!["http://localhost:8080/callback".into()],
            client_name: Some("test".into()),
            client_uri: None,
            logo_uri: None,
            contacts: vec![],
            token_endpoint_auth_method: Some("none".into()),
            grant_types: vec!["authorization_code".into()],
            response_types: vec![],
            scope: None,
            software_id: None,
            software_version: None,
            extra: Default::default(),
        };
        let _rsp = super::DcrResponse {
            client_id: "x".into(),
            client_secret: None,
            client_secret_expires_at: None,
            registration_access_token: None,
            registration_client_uri: None,
            token_endpoint_auth_method: None,
            extra: Default::default(),
        };
    }
}
#[cfg(test)]
mod dcr_tests {
    //! Tests for client-id resolution, the DCR endpoint scheme guard, the DCR wire body,
    //! `authorize_with_details` failure handling, and `build_auth_result`.
    use super::*;
    use crate::server::auth::oauth2::OidcDiscoveryMetadata;

    /// Discovery metadata for `https://issuer.example` with an optional
    /// `registration_endpoint`; every other optional field is empty.
    fn metadata(reg: Option<&str>) -> OidcDiscoveryMetadata {
        OidcDiscoveryMetadata {
            issuer: "https://issuer.example".into(),
            authorization_endpoint: "https://issuer.example/auth".into(),
            token_endpoint: "https://issuer.example/token".into(),
            jwks_uri: None,
            userinfo_endpoint: None,
            registration_endpoint: reg.map(String::from),
            revocation_endpoint: None,
            introspection_endpoint: None,
            device_authorization_endpoint: None,
            response_types_supported: vec![],
            grant_types_supported: vec![],
            scopes_supported: vec![],
            token_endpoint_auth_methods_supported: vec![],
            code_challenge_methods_supported: vec![],
        }
    }

    /// A configured client_id wins even when DCR is enabled and advertised.
    #[tokio::test]
    async fn dcr_skipped_when_client_id_provided() {
        let cfg = OAuthConfig {
            client_id: Some("preset".into()),
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let resolved = helper
            .resolve_client_id_for_flow(&metadata(Some("https://x/register")))
            .await
            .unwrap();
        assert_eq!(resolved, "preset");
    }

    /// A configured client_id is used even with DCR disabled and not advertised.
    #[tokio::test]
    async fn dcr_skipped_when_dcr_disabled_with_client_id() {
        let cfg = OAuthConfig {
            client_id: Some("preset".into()),
            dcr_enabled: false,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let resolved = helper
            .resolve_client_id_for_flow(&metadata(None))
            .await
            .unwrap();
        assert_eq!(resolved, "preset");
    }

    /// No client_id and no registration_endpoint: the error must tell the user what to do.
    #[tokio::test]
    async fn dcr_needed_but_unsupported_errors_with_actionable_message() {
        let cfg = OAuthConfig {
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .resolve_client_id_for_flow(&metadata(None))
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("server does not support DCR"),
            "expected actionable DCR-missing message, got: {msg}"
        );
    }

    /// No client_id and DCR disabled: fails even though the server advertises DCR.
    #[tokio::test]
    async fn dcr_needed_but_disabled_errors_when_client_id_none() {
        let cfg = OAuthConfig {
            client_id: None,
            dcr_enabled: false,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .resolve_client_id_for_flow(&metadata(Some("https://x/register")))
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("dcr_enabled is false"),
            "expected dcr_enabled=false error, got: {msg}"
        );
    }

    /// Plain http to a non-loopback host is refused before any request is made.
    #[tokio::test]
    async fn dcr_rejects_http_non_localhost_endpoint() {
        let cfg = OAuthConfig {
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .do_dynamic_client_registration("http://attacker.example/register")
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("must be https"), "got: {msg}");
    }

    /// Serialized DcrRequest carries the RFC 7591 public-client PKCE fields.
    #[test]
    fn dcr_request_body_matches_rfc7591_public_pkce_shape() {
        let req = crate::server::auth::provider::DcrRequest {
            redirect_uris: vec!["http://localhost:8080/callback".into()],
            client_name: Some("pmcp-sdk".into()),
            client_uri: None,
            logo_uri: None,
            contacts: vec![],
            token_endpoint_auth_method: Some("none".into()),
            grant_types: vec!["authorization_code".into()],
            response_types: vec![],
            scope: None,
            software_id: None,
            software_version: None,
            extra: Default::default(),
        };
        let v: serde_json::Value = serde_json::to_value(&req).unwrap();
        assert_eq!(v["client_name"], "pmcp-sdk");
        assert_eq!(
            v["redirect_uris"],
            serde_json::json!(["http://localhost:8080/callback"])
        );
        assert_eq!(v["grant_types"], serde_json::json!(["authorization_code"]));
        assert_eq!(v["token_endpoint_auth_method"], "none");
    }

    /// `response_types` appears on the wire exactly as `["code"]`.
    #[test]
    fn dcr_request_body_contains_response_types_code() {
        let req = crate::server::auth::provider::DcrRequest {
            redirect_uris: vec!["http://localhost:8080/callback".into()],
            client_name: Some("pmcp-sdk".into()),
            client_uri: None,
            logo_uri: None,
            contacts: vec![],
            token_endpoint_auth_method: Some("none".into()),
            grant_types: vec!["authorization_code".into()],
            response_types: vec!["code".into()],
            scope: None,
            software_id: None,
            software_version: None,
            extra: Default::default(),
        };
        let s = serde_json::to_string(&req).unwrap();
        assert!(
            s.contains(r#""response_types":["code"]"#),
            "RFC 7591 §3.1 response_types missing from wire body: {s}"
        );
    }

    /// The DCR body registers the 127.0.0.1 loopback redirect (not `localhost`). The mock
    /// only matches a body containing that redirect_uris value.
    #[tokio::test]
    async fn dcr_advertises_127_0_0_1_redirect_not_localhost() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/register")
            .match_body(mockito::Matcher::PartialJsonString(
                serde_json::json!({
                    "redirect_uris": ["http://127.0.0.1:8080/callback"]
                })
                .to_string(),
            ))
            .with_status(201)
            .with_body(r#"{"client_id":"ok"}"#)
            .create_async()
            .await;
        let helper = OAuthHelper::new(OAuthConfig {
            dcr_enabled: true,
            redirect_port: 8080,
            ..OAuthConfig::default()
        })
        .unwrap();
        let result = helper
            .do_dynamic_client_registration(&format!("{}/register", server.url()))
            .await;
        assert!(
            result.is_ok(),
            "DCR body did not pin 127.0.0.1 redirect_uri"
        );
        mock.assert_async().await;
    }

    // The next three tests target port 9, where presumably nothing listens. The request is
    // expected to fail *after* the scheme guard, so only the guard's message is checked.

    /// http://[::1] passes the scheme guard.
    #[tokio::test]
    async fn dcr_accepts_ipv6_loopback_registration_endpoint() {
        let cfg = OAuthConfig {
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .do_dynamic_client_registration("http://[::1]:9/register")
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            !msg.contains("must be https"),
            "scheme guard should accept http://[::1] but rejected: {msg}"
        );
    }

    /// http://localhost passes the scheme guard.
    #[tokio::test]
    async fn dcr_accepts_http_localhost_registration_endpoint() {
        let cfg = OAuthConfig {
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .do_dynamic_client_registration("http://localhost:9/register")
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            !msg.contains("must be https"),
            "scheme guard should accept http://localhost but rejected: {msg}"
        );
    }

    /// http://127.0.0.1 passes the scheme guard.
    #[tokio::test]
    async fn dcr_accepts_http_ipv4_loopback_registration_endpoint() {
        let cfg = OAuthConfig {
            dcr_enabled: true,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper
            .do_dynamic_client_registration("http://127.0.0.1:9/register")
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            !msg.contains("must be https"),
            "scheme guard should accept http://127.0.0.1 but rejected: {msg}"
        );
    }

    /// With no reachable server (port 1, presumably closed), discovery fails and an error
    /// is returned rather than a panic or a hang.
    #[tokio::test]
    async fn authorize_with_details_fails_cleanly_without_server() {
        let cfg = OAuthConfig {
            mcp_server_url: Some("http://localhost:1/nonexistent".into()),
            client_id: Some("x".into()),
            dcr_enabled: false,
            ..OAuthConfig::default()
        };
        let helper = OAuthHelper::new(cfg).unwrap();
        let err = helper.authorize_with_details().await.unwrap_err();
        let _ = format!("{err}");
    }

    /// Compile-time check of AuthorizationResult's public fields.
    #[test]
    fn authorization_result_struct_has_expected_fields() {
        let _r = AuthorizationResult {
            access_token: "a".into(),
            refresh_token: Some("r".into()),
            expires_at: Some(1),
            scopes: vec!["openid".into()],
            issuer: Some("https://i.example".into()),
            client_id: "c".into(),
        };
    }

    /// expires_in becomes an absolute timestamp (±1 s tolerance for clock ticks), and the
    /// granted scope string wins over the requested scopes.
    #[test]
    fn build_auth_result_converts_expires_in_to_expires_at() {
        let token = crate::client::auth::TokenResponse {
            access_token: "a".into(),
            token_type: "Bearer".into(),
            expires_in: Some(3600),
            refresh_token: Some("r".into()),
            scope: Some("openid profile".into()),
        };
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let r = OAuthHelper::build_auth_result(
            token,
            "c1".into(),
            Some("https://i.example".into()),
            &["openid".into()],
        );
        assert_eq!(r.client_id, "c1");
        assert_eq!(r.refresh_token.as_deref(), Some("r"));
        assert_eq!(r.issuer.as_deref(), Some("https://i.example"));
        assert_eq!(r.scopes, vec!["openid".to_string(), "profile".into()]);
        let expires_at = r.expires_at.expect("expires_at populated");
        assert!(
            expires_at >= now + 3599 && expires_at <= now + 3601,
            "expires_at ({}) should be approximately now+3600 ({})",
            expires_at,
            now + 3600
        );
    }

    /// With no scope in the response, the requested scopes are reported.
    #[test]
    fn build_auth_result_falls_back_to_requested_scopes_when_no_grant() {
        let token = crate::client::auth::TokenResponse {
            access_token: "a".into(),
            token_type: "Bearer".into(),
            expires_in: None,
            refresh_token: None,
            scope: None,
        };
        let requested = vec!["openid".to_string(), "email".to_string()];
        let r = OAuthHelper::build_auth_result(token, "c".into(), None, &requested);
        assert_eq!(r.scopes, requested);
        assert!(r.expires_at.is_none());
        assert!(r.refresh_token.is_none());
    }
}
#[cfg(test)]
mod dcr_proptest {
    //! Property tests: DcrRequest serde round-trips, and OAuthHelper construction accepts
    //! every combination of client_id / client_name / dcr_enabled.
    use super::*;
    use proptest::prelude::*;

    /// Generates DcrRequests with 1–2 loopback redirect URIs, an optional client name,
    /// and an optional token endpoint auth method drawn from three standard values.
    fn arb_dcr_request() -> impl Strategy<Value = crate::server::auth::provider::DcrRequest> {
        (
            prop::collection::vec("[a-z][a-z0-9-]{2,30}", 1..3),
            prop::option::of("[a-zA-Z][a-zA-Z0-9 _-]{1,40}"),
            prop::option::of(
                prop::string::string_regex("(none|client_secret_basic|client_secret_post)")
                    .unwrap(),
            ),
        )
            .prop_map(|(uris, name, auth_method)| {
                let redirect_uris = uris
                    .into_iter()
                    .map(|u| format!("http://localhost:8080/{u}"))
                    .collect();
                crate::server::auth::provider::DcrRequest {
                    redirect_uris,
                    client_name: name,
                    client_uri: None,
                    logo_uri: None,
                    contacts: vec![],
                    token_endpoint_auth_method: auth_method,
                    grant_types: vec!["authorization_code".into()],
                    response_types: vec!["code".into()],
                    scope: None,
                    software_id: None,
                    software_version: None,
                    extra: Default::default(),
                }
            })
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(64))]
        // Serialize -> deserialize preserves the generated fields.
        #[test]
        fn dcr_request_serde_roundtrip(req in arb_dcr_request()) {
            let v = serde_json::to_value(&req).unwrap();
            let back: crate::server::auth::provider::DcrRequest =
                serde_json::from_value(v).unwrap();
            prop_assert_eq!(req.redirect_uris, back.redirect_uris);
            prop_assert_eq!(req.client_name, back.client_name);
            prop_assert_eq!(req.token_endpoint_auth_method, back.token_endpoint_auth_method);
        }
        // OAuthHelper::new never fails on any of the 8 flag combinations.
        #[test]
        fn oauth_config_builder_allows_all_combinations(
            has_id in any::<bool>(),
            has_name in any::<bool>(),
            dcr in any::<bool>(),
        ) {
            let cfg = OAuthConfig {
                client_id: has_id.then(|| "id".into()),
                client_name: has_name.then(|| "name".into()),
                dcr_enabled: dcr,
                mcp_server_url: Some("https://x.example".into()),
                ..OAuthConfig::default()
            };
            OAuthHelper::new(cfg).unwrap();
        }
    }
}
#[cfg(test)]
mod dcr_parser_fuzz {
    //! Fuzz-style properties for parsing DCR responses.
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]
        // Arbitrary bytes (up to 4 KiB) may fail to parse but must never panic.
        #[test]
        fn parser_never_panics(bytes in prop::collection::vec(any::<u8>(), 0..4096)) {
            let _ = serde_json::from_slice::<
                crate::server::auth::provider::DcrResponse
            >(&bytes);
        }
        // `{"client_id": ...}` alone parses, and an optional client_secret round-trips.
        #[test]
        fn parser_accepts_minimal_valid_response(
            id in "[a-zA-Z0-9-]{8,40}",
            has_secret in any::<bool>(),
        ) {
            let mut v = serde_json::json!({"client_id": id});
            if has_secret {
                v["client_secret"] = serde_json::json!("s3cret");
            }
            let parsed: crate::server::auth::provider::DcrResponse =
                serde_json::from_value(v).unwrap();
            prop_assert_eq!(parsed.client_id, id);
            prop_assert_eq!(parsed.client_secret.is_some(), has_secret);
        }
    }
}