use std::{
collections::HashMap,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait;
use oauth2::{
AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
EmptyExtraTokenFields, ExtraTokenFields, HttpClientError, HttpRequest, HttpResponse,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicTokenType,
};
use reqwest::{
Client as HttpClient, IntoUrl, StatusCode, Url,
header::{AUTHORIZATION, WWW_AUTHENTICATE},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, warn};
use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION;
struct OAuthReqwestClient(HttpClient);
impl<'c> AsyncHttpClient<'c> for OAuthReqwestClient {
type Error = HttpClientError<reqwest::Error>;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>,
>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self
.0
.execute(request.try_into().map_err(Box::new)?)
.await
.map_err(Box::new)?;
let mut builder = oauth2::http::Response::builder()
.status(response.status())
.version(response.version());
for (name, value) in response.headers().iter() {
builder = builder.header(name, value);
}
builder
.body(response.bytes().await.map_err(Box::new)?.to_vec())
.map_err(HttpClientError::Http)
})
}
}
const DEFAULT_EXCHANGE_URL: &str = "http://localhost";
#[derive(Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StoredCredentials {
pub client_id: String,
pub token_response: Option<OAuthTokenResponse>,
#[serde(default)]
pub granted_scopes: Vec<String>,
#[serde(default)]
pub token_received_at: Option<u64>,
}
impl std::fmt::Debug for StoredCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StoredCredentials")
.field("client_id", &self.client_id)
.field(
"token_response",
&self.token_response.as_ref().map(|_| "[REDACTED]"),
)
.field("granted_scopes", &self.granted_scopes)
.field("token_received_at", &self.token_received_at)
.finish()
}
}
impl StoredCredentials {
pub fn new(
client_id: String,
token_response: Option<OAuthTokenResponse>,
granted_scopes: Vec<String>,
token_received_at: Option<u64>,
) -> Self {
Self {
client_id,
token_response,
granted_scopes,
token_received_at,
}
}
}
#[async_trait]
pub trait CredentialStore: Send + Sync {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError>;
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError>;
async fn clear(&self) -> Result<(), AuthError>;
}
#[derive(Debug, Default, Clone)]
pub struct InMemoryCredentialStore {
credentials: Arc<RwLock<Option<StoredCredentials>>>,
}
impl InMemoryCredentialStore {
pub fn new() -> Self {
Self {
credentials: Arc::new(RwLock::new(None)),
}
}
}
#[async_trait::async_trait]
impl CredentialStore for InMemoryCredentialStore {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
Ok(self.credentials.read().await.clone())
}
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
*self.credentials.write().await = Some(credentials);
Ok(())
}
async fn clear(&self) -> Result<(), AuthError> {
*self.credentials.write().await = None;
Ok(())
}
}
#[derive(Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StoredAuthorizationState {
pub pkce_verifier: String,
pub csrf_token: String,
pub created_at: u64,
}
impl std::fmt::Debug for StoredAuthorizationState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("StoredAuthorizationState")
.field("pkce_verifier", &"[REDACTED]")
.field("csrf_token", &"[REDACTED]")
.field("created_at", &self.created_at)
.finish()
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct VendorExtraTokenFields(pub HashMap<String, Value>);
impl ExtraTokenFields for VendorExtraTokenFields {}
impl StoredAuthorizationState {
pub fn new(pkce_verifier: &PkceCodeVerifier, csrf_token: &CsrfToken) -> Self {
Self {
pkce_verifier: pkce_verifier.secret().to_string(),
csrf_token: csrf_token.secret().to_string(),
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
}
}
pub fn into_pkce_verifier(self) -> PkceCodeVerifier {
PkceCodeVerifier::new(self.pkce_verifier)
}
}
#[async_trait]
pub trait StateStore: Send + Sync {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError>;
async fn load(&self, csrf_token: &str) -> Result<Option<StoredAuthorizationState>, AuthError>;
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError>;
}
#[derive(Debug, Default, Clone)]
pub struct InMemoryStateStore {
states: Arc<RwLock<HashMap<String, StoredAuthorizationState>>>,
}
impl InMemoryStateStore {
pub fn new() -> Self {
Self {
states: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl StateStore for InMemoryStateStore {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError> {
self.states
.write()
.await
.insert(csrf_token.to_string(), state);
Ok(())
}
async fn load(&self, csrf_token: &str) -> Result<Option<StoredAuthorizationState>, AuthError> {
Ok(self.states.read().await.get(csrf_token).cloned())
}
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> {
self.states.write().await.remove(csrf_token);
Ok(())
}
}
#[derive(Clone)]
#[non_exhaustive]
pub struct AuthClient<C> {
pub http_client: C,
pub auth_manager: Arc<Mutex<AuthorizationManager>>,
}
impl<C: std::fmt::Debug> std::fmt::Debug for AuthClient<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AuthorizedClient")
.field("http_client", &self.http_client)
.field("auth_manager", &"...")
.finish()
}
}
impl<C> AuthClient<C> {
pub fn new(http_client: C, auth_manager: AuthorizationManager) -> Self {
Self {
http_client,
auth_manager: Arc::new(Mutex::new(auth_manager)),
}
}
}
impl<C> AuthClient<C> {
pub fn get_access_token(&self) -> impl Future<Output = Result<String, AuthError>> + Send {
let auth_manager = self.auth_manager.clone();
async move { auth_manager.lock().await.get_access_token().await }
}
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AuthError {
#[error("OAuth authorization required")]
AuthorizationRequired,
#[error("OAuth authorization failed: {0}")]
AuthorizationFailed(String),
#[error("OAuth token exchange failed: {0}")]
TokenExchangeFailed(String),
#[error("OAuth token refresh failed: {0}")]
TokenRefreshFailed(String),
#[error("HTTP error: {0}")]
HttpError(#[from] reqwest::Error),
#[error("OAuth error: {0}")]
OAuthError(String),
#[error("Metadata error: {0}")]
MetadataError(String),
#[error("URL parse error: {0}")]
UrlError(#[from] url::ParseError),
#[error("No authorization support detected")]
NoAuthorizationSupport,
#[error("Internal error: {0}")]
InternalError(String),
#[error("Invalid token type: {0}")]
InvalidTokenType(String),
#[error("Token expired")]
TokenExpired,
#[error("Invalid scope: {0}")]
InvalidScope(String),
#[error("Registration failed: {0}")]
RegistrationFailed(String),
#[error("Insufficient scope: {required_scope}")]
InsufficientScope {
required_scope: String,
upgrade_url: Option<String>,
},
#[error("Client credentials error: {0}")]
ClientCredentialsError(String),
#[cfg(feature = "auth-client-credentials-jwt")]
#[error("JWT signing error: {0}")]
JwtSigningError(String),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[non_exhaustive]
pub struct AuthorizationMetadata {
pub authorization_endpoint: String,
pub token_endpoint: String,
pub registration_endpoint: Option<String>,
pub issuer: Option<String>,
pub jwks_uri: Option<String>,
pub scopes_supported: Option<Vec<String>>,
pub response_types_supported: Option<Vec<String>>,
pub code_challenge_methods_supported: Option<Vec<String>>,
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Deserialize)]
struct ResourceServerMetadata {
authorization_server: Option<String>,
authorization_servers: Option<Vec<String>>,
scopes_supported: Option<Vec<String>>,
}
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct WWWAuthenticateParams {
pub resource_metadata_url: Option<Url>,
pub scope: Option<String>,
pub error: Option<String>,
pub error_description: Option<String>,
}
impl WWWAuthenticateParams {
pub fn is_insufficient_scope(&self) -> bool {
self.error.as_deref() == Some("insufficient_scope")
}
pub fn is_invalid_token(&self) -> bool {
self.error.as_deref() == Some("invalid_token")
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OAuthClientConfig {
pub client_id: String,
pub client_secret: Option<String>,
pub scopes: Vec<String>,
pub redirect_uri: String,
}
impl OAuthClientConfig {
pub fn new(client_id: impl Into<String>, redirect_uri: impl Into<String>) -> Self {
Self {
client_id: client_id.into(),
client_secret: None,
scopes: Vec::new(),
redirect_uri: redirect_uri.into(),
}
}
pub fn with_client_secret(mut self, secret: impl Into<String>) -> Self {
self.client_secret = Some(secret.into());
self
}
pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
self.scopes = scopes;
self
}
}
type OAuthErrorResponse = oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>;
pub type OAuthTokenResponse = StandardTokenResponse<VendorExtraTokenFields, BasicTokenType>;
type OAuthTokenIntrospection =
oauth2::StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>;
type OAuthRevocableToken = oauth2::StandardRevocableToken;
type OAuthRevocationError = oauth2::StandardErrorResponse<oauth2::RevocationErrorResponseType>;
type OAuthClient = oauth2::Client<
OAuthErrorResponse,
OAuthTokenResponse,
OAuthTokenIntrospection,
OAuthRevocableToken,
OAuthRevocationError,
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>;
type Credentials = (String, Option<OAuthTokenResponse>);
pub const EXTENSION_OAUTH_CLIENT_CREDENTIALS: &str =
"io.modelcontextprotocol/oauth-client-credentials";
#[cfg(feature = "auth-client-credentials-jwt")]
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum JwtSigningAlgorithm {
RS256,
RS384,
RS512,
ES256,
ES384,
}
#[cfg(feature = "auth-client-credentials-jwt")]
impl JwtSigningAlgorithm {
fn to_jsonwebtoken_algorithm(self) -> jsonwebtoken::Algorithm {
match self {
JwtSigningAlgorithm::RS256 => jsonwebtoken::Algorithm::RS256,
JwtSigningAlgorithm::RS384 => jsonwebtoken::Algorithm::RS384,
JwtSigningAlgorithm::RS512 => jsonwebtoken::Algorithm::RS512,
JwtSigningAlgorithm::ES256 => jsonwebtoken::Algorithm::ES256,
JwtSigningAlgorithm::ES384 => jsonwebtoken::Algorithm::ES384,
}
}
fn as_str(self) -> &'static str {
match self {
JwtSigningAlgorithm::RS256 => "RS256",
JwtSigningAlgorithm::RS384 => "RS384",
JwtSigningAlgorithm::RS512 => "RS512",
JwtSigningAlgorithm::ES256 => "ES256",
JwtSigningAlgorithm::ES384 => "ES384",
}
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ClientCredentialsConfig {
ClientSecret {
client_id: String,
client_secret: String,
scopes: Vec<String>,
resource: Option<String>,
},
#[cfg(feature = "auth-client-credentials-jwt")]
PrivateKeyJwt {
client_id: String,
signing_key: Vec<u8>,
signing_algorithm: JwtSigningAlgorithm,
token_endpoint_audience: Option<String>,
scopes: Vec<String>,
resource: Option<String>,
},
}
impl ClientCredentialsConfig {
fn client_id(&self) -> &str {
match self {
ClientCredentialsConfig::ClientSecret { client_id, .. } => client_id,
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { client_id, .. } => client_id,
}
}
fn scopes(&self) -> &[String] {
match self {
ClientCredentialsConfig::ClientSecret { scopes, .. } => scopes,
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { scopes, .. } => scopes,
}
}
fn resource(&self) -> Option<&str> {
match self {
ClientCredentialsConfig::ClientSecret { resource, .. } => resource.as_deref(),
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { resource, .. } => resource.as_deref(),
}
}
fn auth_method(&self) -> &str {
match self {
ClientCredentialsConfig::ClientSecret { .. } => "client_secret_post",
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { .. } => "private_key_jwt",
}
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ScopeUpgradeConfig {
pub max_upgrade_attempts: u32,
pub auto_upgrade: bool,
}
impl Default for ScopeUpgradeConfig {
fn default() -> Self {
Self {
max_upgrade_attempts: 3,
auto_upgrade: true,
}
}
}
pub struct AuthorizationManager {
http_client: HttpClient,
metadata: Option<AuthorizationMetadata>,
oauth_client: Option<OAuthClient>,
credential_store: Arc<dyn CredentialStore>,
state_store: Arc<dyn StateStore>,
base_url: Url,
current_scopes: RwLock<Vec<String>>,
scope_upgrade_attempts: RwLock<u32>,
scope_upgrade_config: ScopeUpgradeConfig,
www_auth_scopes: RwLock<Vec<String>>,
resource_scopes: RwLock<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ClientRegistrationRequest {
pub client_name: String,
pub redirect_uris: Vec<String>,
pub grant_types: Vec<String>,
pub token_endpoint_auth_method: String,
pub response_types: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ClientRegistrationResponse {
pub client_id: String,
pub client_secret: Option<String>,
pub client_name: Option<String>,
pub redirect_uris: Vec<String>,
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}
impl ClientRegistrationResponse {
pub fn new(client_id: impl Into<String>, redirect_uris: Vec<String>) -> Self {
Self {
client_id: client_id.into(),
client_secret: None,
client_name: None,
redirect_uris,
additional_fields: HashMap::new(),
}
}
}
fn is_https_url(value: &str) -> bool {
Url::parse(value)
.ok()
.map(|url| url.scheme() == "https" && url.path() != "/" && url.host_str().is_some())
.unwrap_or(false)
}
impl AuthorizationManager {
fn well_known_paths(base_path: &str, resource: &str) -> Vec<String> {
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
let mut candidates = Vec::new();
let mut push_candidate = |candidate: String| {
if !candidates.contains(&candidate) {
candidates.push(candidate);
}
};
let canonical = format!("/.well-known/{resource}");
if trimmed.is_empty() {
push_candidate(canonical);
return candidates;
}
push_candidate(format!("{canonical}/{trimmed}"));
push_candidate(format!("/{trimmed}/.well-known/{resource}"));
push_candidate(canonical);
candidates
}
pub async fn new<U: IntoUrl>(base_url: U) -> Result<Self, AuthError> {
let base_url = base_url.into_url()?;
let http_client = HttpClient::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
let manager = Self {
http_client,
metadata: None,
oauth_client: None,
credential_store: Arc::new(InMemoryCredentialStore::new()),
state_store: Arc::new(InMemoryStateStore::new()),
base_url,
current_scopes: RwLock::new(Vec::new()),
scope_upgrade_attempts: RwLock::new(0),
scope_upgrade_config: ScopeUpgradeConfig::default(),
www_auth_scopes: RwLock::new(Vec::new()),
resource_scopes: RwLock::new(Vec::new()),
};
Ok(manager)
}
pub fn set_scope_upgrade_config(&mut self, config: ScopeUpgradeConfig) {
self.scope_upgrade_config = config;
}
pub fn set_credential_store<S: CredentialStore + 'static>(&mut self, store: S) {
self.credential_store = Arc::new(store);
}
pub fn set_state_store<S: StateStore + 'static>(&mut self, store: S) {
self.state_store = Arc::new(store);
}
pub fn set_metadata(&mut self, metadata: AuthorizationMetadata) {
self.metadata = Some(metadata);
}
pub async fn initialize_from_store(&mut self) -> Result<bool, AuthError> {
if let Some(stored) = self.credential_store.load().await? {
if stored.token_response.is_some() {
if self.metadata.is_none() {
let metadata = self.discover_metadata().await?;
self.metadata = Some(metadata);
}
self.configure_client_id(&stored.client_id)?;
return Ok(true);
}
}
Ok(false)
}
pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> {
self.http_client = http_client;
Ok(())
}
pub async fn discover_metadata(&self) -> Result<AuthorizationMetadata, AuthError> {
if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? {
return Ok(metadata);
}
if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? {
return Ok(metadata);
}
Err(AuthError::NoAuthorizationSupport)
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
let client_id = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
.client_id();
let stored = self.credential_store.load().await?;
let token_response = stored.and_then(|s| s.token_response);
Ok((client_id.to_string(), token_response))
}
pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> {
if self.metadata.is_none() {
return Err(AuthError::NoAuthorizationSupport);
}
let metadata = self.metadata.as_ref().unwrap();
let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?;
let token_url = TokenUrl::new(metadata.token_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?;
let client_id = ClientId::new(config.client_id);
let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?;
let mut client_builder: OAuthClient = oauth2::Client::new(client_id.clone())
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url);
if let Some(secret) = config.client_secret {
client_builder = client_builder.set_client_secret(ClientSecret::new(secret));
}
let uses_secret_post = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
.map(|arr| {
let has_basic = arr
.iter()
.any(|m| m.as_str() == Some("client_secret_basic"));
let has_post = arr.iter().any(|m| m.as_str() == Some("client_secret_post"));
has_post && !has_basic
})
.unwrap_or(false);
if uses_secret_post {
client_builder = client_builder.set_auth_type(AuthType::RequestBody);
}
self.oauth_client = Some(client_builder);
Ok(())
}
fn validate_server_metadata(&self, response_type: &str) -> Result<(), AuthError> {
let Some(metadata) = self.metadata.as_ref() else {
return Ok(());
};
if let Some(response_types_supported) = metadata.response_types_supported.as_ref() {
if !response_types_supported.contains(&response_type.to_string()) {
return Err(AuthError::InvalidScope(response_type.to_string()));
}
}
match &metadata.code_challenge_methods_supported {
Some(methods) if !methods.iter().any(|m| m == "S256") => {
warn!(
?methods,
"server does not advertise S256 in code_challenge_methods_supported, \
proceeding with S256 anyway as oauth 2.1 requires it. \
The server is not compliant with the specification!"
);
}
_ => {}
}
Ok(())
}
pub async fn register_client(
&mut self,
name: &str,
redirect_uri: &str,
scopes: &[&str],
) -> Result<OAuthClientConfig, AuthError> {
if self.metadata.is_none() {
return Err(AuthError::NoAuthorizationSupport);
}
let metadata = self.metadata.as_ref().unwrap();
let Some(registration_url) = metadata.registration_endpoint.as_ref() else {
return Err(AuthError::RegistrationFailed(
"Dynamic client registration not supported".to_string(),
));
};
self.validate_server_metadata("code")?;
let registration_request = ClientRegistrationRequest {
client_name: name.to_string(),
redirect_uris: vec![redirect_uri.to_string()],
grant_types: vec![
"authorization_code".to_string(),
"refresh_token".to_string(),
],
token_endpoint_auth_method: "none".to_string(), response_types: vec!["code".to_string()],
scope: if scopes.is_empty() {
None
} else {
Some(scopes.join(" "))
},
};
let response = match self
.http_client
.post(registration_url)
.json(®istration_request)
.send()
.await
{
Ok(response) => response,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"HTTP request error: {}",
e
)));
}
};
if !response.status().is_success() {
let status = response.status();
let error_text = match response.text().await {
Ok(text) => text,
Err(_) => "cannot get error details".to_string(),
};
return Err(AuthError::RegistrationFailed(format!(
"HTTP {}: {}",
status, error_text
)));
}
debug!("registration response: {:?}", response);
let reg_response = match response.json::<ClientRegistrationResponse>().await {
Ok(response) => response,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"analyze response error: {}",
e
)));
}
};
let config = OAuthClientConfig {
client_id: reg_response.client_id,
client_secret: reg_response.client_secret.filter(|s| !s.is_empty()),
redirect_uri: redirect_uri.to_string(),
scopes: scopes.iter().map(|s| s.to_string()).collect(),
};
self.configure_client(config.clone())?;
Ok(config)
}
pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> {
let config = OAuthClientConfig {
client_id: client_id.to_string(),
client_secret: None,
scopes: vec![],
redirect_uri: self.base_url.to_string(),
};
self.configure_client(config)
}
pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result<String, AuthError> {
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
self.validate_server_metadata("code")?;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut auth_request = oauth_client
.authorize_url(CsrfToken::new_random)
.set_pkce_challenge(pkce_challenge)
.add_extra_param("resource", self.base_url.to_string());
for scope in scopes {
auth_request = auth_request.add_scope(Scope::new(scope.to_string()));
}
let (auth_url, csrf_token) = auth_request.url();
let stored_state = StoredAuthorizationState::new(&pkce_verifier, &csrf_token);
self.state_store
.save(csrf_token.secret(), stored_state)
.await?;
Ok(auth_url.to_string())
}
pub async fn get_current_scopes(&self) -> Vec<String> {
self.current_scopes.read().await.clone()
}
fn compute_scope_union(current: &[String], required: &str) -> Vec<String> {
let mut scope_set: std::collections::HashSet<String> = current.iter().cloned().collect();
for scope in required.split_whitespace() {
scope_set.insert(scope.to_string());
}
scope_set.into_iter().collect()
}
pub async fn can_attempt_scope_upgrade(&self) -> bool {
if !self.scope_upgrade_config.auto_upgrade {
return false;
}
let attempts = *self.scope_upgrade_attempts.read().await;
attempts < self.scope_upgrade_config.max_upgrade_attempts
}
pub fn select_scopes(
&self,
www_authenticate_scope: Option<&str>,
default_scopes: &[&str],
) -> Vec<String> {
let mut scopes = self.select_base_scopes(www_authenticate_scope, default_scopes);
self.add_offline_access_if_supported(&mut scopes);
scopes
}
fn select_base_scopes(
&self,
www_authenticate_scope: Option<&str>,
default_scopes: &[&str],
) -> Vec<String> {
if let Some(scope) = www_authenticate_scope {
return scope.split_whitespace().map(|s| s.to_string()).collect();
}
if let Ok(guard) = self.www_auth_scopes.try_read() {
if !guard.is_empty() {
return guard.clone();
}
}
if let Ok(guard) = self.resource_scopes.try_read() {
if !guard.is_empty() {
return guard.clone();
}
}
if let Some(metadata) = &self.metadata {
if let Some(scopes_supported) = &metadata.scopes_supported {
if !scopes_supported.is_empty() {
return scopes_supported.clone();
}
}
}
default_scopes.iter().map(|s| s.to_string()).collect()
}
fn add_offline_access_if_supported(&self, scopes: &mut Vec<String>) {
if scopes.is_empty() || scopes.iter().any(|s| s == "offline_access") {
return;
}
if let Some(metadata) = &self.metadata {
if let Some(supported) = &metadata.scopes_supported {
if supported.iter().any(|s| s == "offline_access") {
scopes.push("offline_access".to_string());
}
}
}
}
pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result<String, AuthError> {
if !self.scope_upgrade_config.auto_upgrade {
return Err(AuthError::InvalidScope(
"Scope upgrade is disabled".to_string(),
));
}
let mut attempts = self.scope_upgrade_attempts.write().await;
if *attempts >= self.scope_upgrade_config.max_upgrade_attempts {
return Err(AuthError::InvalidScope(format!(
"Maximum scope upgrade attempts ({}) exceeded",
self.scope_upgrade_config.max_upgrade_attempts
)));
}
*attempts += 1;
drop(attempts);
let current_scopes = self.current_scopes.read().await.clone();
let upgraded_scopes = Self::compute_scope_union(¤t_scopes, required_scope);
debug!(
"Requesting scope upgrade: current={:?}, required={}, union={:?}",
current_scopes, required_scope, upgraded_scopes
);
let scope_refs: Vec<&str> = upgraded_scopes.iter().map(|s| s.as_str()).collect();
self.get_authorization_url(&scope_refs).await
}
pub async fn reset_scope_upgrade_attempts(&self) {
*self.scope_upgrade_attempts.write().await = 0;
}
pub async fn get_scope_upgrade_attempts(&self) -> u32 {
*self.scope_upgrade_attempts.read().await
}
pub async fn exchange_code_for_token(
&self,
code: &str,
csrf_token: &str,
) -> Result<OAuthTokenResponse, AuthError> {
debug!("start exchange code for token: {:?}", code);
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
let stored_state =
self.state_store.load(csrf_token).await?.ok_or_else(|| {
AuthError::InternalError("Authorization state not found".to_string())
})?;
self.state_store.delete(csrf_token).await?;
let pkce_verifier = stored_state.into_pkce_verifier();
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
debug!("client_id: {:?}", oauth_client.client_id());
let token_result = match oauth_client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(pkce_verifier)
.add_extra_param("resource", self.base_url.to_string())
.request_async(&OAuthReqwestClient(http_client))
.await
{
Ok(token) => token,
Err(RequestTokenError::Parse(_, body)) => {
match serde_json::from_slice::<OAuthTokenResponse>(&body) {
Ok(parsed) => {
warn!(
"token exchange failed to parse completely but included a valid token response. Accepting it."
);
parsed
}
Err(parse_err) => {
return Err(AuthError::TokenExchangeFailed(parse_err.to_string()));
}
}
}
Err(e) => {
return Err(AuthError::TokenExchangeFailed(e.to_string()));
}
};
debug!("exchange token result: {:?}", token_result);
let granted_scopes: Vec<String> = token_result
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*self.current_scopes.write().await = granted_scopes.clone();
*self.scope_upgrade_attempts.write().await = 0;
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
const REFRESH_BUFFER_SECS: u64 = 30;
pub async fn get_access_token(&self) -> Result<String, AuthError> {
let stored = self.credential_store.load().await?;
let Some(stored_creds) = stored else {
return Err(AuthError::AuthorizationRequired);
};
let Some(creds) = stored_creds.token_response.as_ref() else {
return Err(AuthError::AuthorizationRequired);
};
if let (Some(expires_in), Some(received_at)) =
(creds.expires_in(), stored_creds.token_received_at)
{
let elapsed = Self::now_epoch_secs().saturating_sub(received_at);
let remaining = expires_in.as_secs().saturating_sub(elapsed);
if remaining < Self::REFRESH_BUFFER_SECS {
tracing::info!(
remaining_secs = remaining,
"Access token expired or nearly expired, refreshing."
);
return self.try_refresh_or_reauth().await;
}
}
Ok(creds.access_token().secret().to_string())
}
async fn try_refresh_or_reauth(&self) -> Result<String, AuthError> {
match self.refresh_token().await {
Ok(new_creds) => {
tracing::info!("Refreshed access token.");
Ok(new_creds.access_token().secret().to_string())
}
Err(e @ (AuthError::AuthorizationRequired | AuthError::TokenRefreshFailed(_))) => {
tracing::warn!(error = %e, "Token refresh not possible, re-authorization required.");
Err(AuthError::AuthorizationRequired)
}
Err(e) => Err(e),
}
}
pub async fn refresh_token(&self) -> Result<OAuthTokenResponse, AuthError> {
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
let stored = self.credential_store.load().await?;
let stored_credentials = stored.ok_or(AuthError::AuthorizationRequired)?;
let current_credentials = stored_credentials
.token_response
.ok_or(AuthError::AuthorizationRequired)?;
let refresh_token = current_credentials.refresh_token().ok_or_else(|| {
AuthError::TokenRefreshFailed("No refresh token available".to_string())
})?;
debug!("refresh token present, attempting refresh");
let refresh_token_value = RefreshToken::new(refresh_token.secret().to_string());
let mut refresh_request = oauth_client.exchange_refresh_token(&refresh_token_value);
for scope in &stored_credentials.granted_scopes {
refresh_request = refresh_request.add_scope(Scope::new(scope.clone()));
}
let token_result = refresh_request
.request_async(&OAuthReqwestClient(self.http_client.clone()))
.await
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;
let granted_scopes: Vec<String> = match token_result.scopes() {
Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(),
None => self.current_scopes.read().await.clone(),
};
*self.current_scopes.write().await = granted_scopes.clone();
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
pub async fn prepare_request(
&self,
request: reqwest::RequestBuilder,
) -> Result<reqwest::RequestBuilder, AuthError> {
let token = self.get_access_token().await?;
Ok(request.header(AUTHORIZATION, format!("Bearer {}", token)))
}
pub async fn handle_response(
&self,
response: reqwest::Response,
) -> Result<reqwest::Response, AuthError> {
if response.status() == StatusCode::UNAUTHORIZED {
return Err(AuthError::AuthorizationRequired);
}
if response.status() == StatusCode::FORBIDDEN {
for value in response.headers().get_all(WWW_AUTHENTICATE).iter() {
let Ok(value_str) = value.to_str() else {
continue;
};
let params = Self::extract_www_authenticate_params(value_str, &self.base_url);
if params.is_insufficient_scope() {
let required_scope = params.scope.unwrap_or_default();
return Err(AuthError::InsufficientScope {
required_scope,
upgrade_url: None,
});
}
}
return Err(AuthError::AuthorizationFailed("Forbidden".to_string()));
}
Ok(response)
}
fn generate_discovery_urls(base_url: &Url) -> Vec<Url> {
let mut candidates = Vec::new();
let path = base_url.path();
let trimmed = path.trim_start_matches('/').trim_end_matches('/');
let mut push_candidate = |discovery_path: String| {
let mut discovery_url = base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&discovery_path);
candidates.push(discovery_url);
};
if trimmed.is_empty() {
push_candidate("/.well-known/oauth-authorization-server".to_string());
push_candidate("/.well-known/openid-configuration".to_string());
} else {
push_candidate(format!("/.well-known/oauth-authorization-server/{trimmed}"));
push_candidate(format!("/.well-known/openid-configuration/{trimmed}"));
push_candidate(format!("/{trimmed}/.well-known/openid-configuration"));
push_candidate("/.well-known/oauth-authorization-server".to_string());
}
candidates
}
async fn try_discover_oauth_server(
&self,
base_url: &Url,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
for discovery_url in Self::generate_discovery_urls(base_url) {
if let Some(metadata) = self.fetch_authorization_metadata(&discovery_url).await? {
return Ok(Some(metadata));
}
}
Ok(None)
}
async fn fetch_authorization_metadata(
&self,
discovery_url: &Url,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
debug!("discovery url: {:?}", discovery_url);
let response = match self
.http_client
.get(discovery_url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("discovery request failed: {}", e);
return Ok(None);
}
};
if response.status() != StatusCode::OK {
debug!("discovery returned non-200: {}", response.status());
return Ok(None);
}
let body = response.text().await?;
match serde_json::from_str::<AuthorizationMetadata>(&body) {
Ok(metadata) => Ok(Some(metadata)),
Err(err) => {
debug!("Failed to parse metadata for {}: {}", discovery_url, err);
Ok(None) }
}
}
async fn discover_oauth_server_via_resource_metadata(
&self,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
let Some(resource_metadata_url) = self.discover_resource_metadata_url().await? else {
return Ok(None);
};
let Some(resource_metadata) = self
.fetch_resource_metadata_from_url(&resource_metadata_url)
.await?
else {
return Ok(None);
};
if let Some(scopes) = resource_metadata.scopes_supported {
if !scopes.is_empty() {
*self.resource_scopes.write().await = scopes;
}
}
let mut candidates = Vec::new();
if let Some(single) = resource_metadata.authorization_server {
candidates.push(single);
}
if let Some(list) = resource_metadata.authorization_servers {
candidates.extend(list);
}
for candidate in candidates {
let candidate = candidate.trim();
if candidate.is_empty() {
continue;
}
let candidate_url = match Url::parse(candidate) {
Ok(url) => url,
Err(_) => match resource_metadata_url.join(candidate) {
Ok(url) => url,
Err(e) => {
debug!("Failed to resolve authorization server URL `{candidate}`: {e}");
continue;
}
},
};
if candidate_url.path().contains("/.well-known/") {
if let Some(metadata) = self.fetch_authorization_metadata(&candidate_url).await? {
return Ok(Some(metadata));
}
continue;
}
if let Some(metadata) = self.try_discover_oauth_server(&candidate_url).await? {
return Ok(Some(metadata));
}
}
Ok(None)
}
async fn discover_resource_metadata_url(&self) -> Result<Option<Url>, AuthError> {
if let Ok(Some(resource_metadata_url)) =
self.fetch_resource_metadata_url(&self.base_url).await
{
return Ok(Some(resource_metadata_url));
}
for candidate_path in
Self::well_known_paths(self.base_url.path(), "oauth-protected-resource")
{
let mut discovery_url = self.base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&candidate_path);
if let Ok(Some(resource_metadata_url)) =
self.fetch_resource_metadata_url(&discovery_url).await
{
return Ok(Some(resource_metadata_url));
}
}
Ok(None)
}
async fn fetch_resource_metadata_url(&self, url: &Url) -> Result<Option<Url>, AuthError> {
let response = match self
.http_client
.get(url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("resource metadata probe failed: {}", e);
return Ok(None);
}
};
if response.status() == StatusCode::OK {
return Ok(Some(url.clone()));
} else if response.status() != StatusCode::UNAUTHORIZED {
debug!(
"resource metadata probe returned unexpected status: {}",
response.status()
);
return Ok(None);
}
let mut parsed_url = None;
for value in response.headers().get_all(WWW_AUTHENTICATE).iter() {
let Ok(value_str) = value.to_str() else {
continue;
};
let params = Self::extract_www_authenticate_params(value_str, &self.base_url);
if let Some(url) = params.resource_metadata_url {
if let Some(scope) = ¶ms.scope {
debug!("WWW-Authenticate header contains scope: {}", scope);
let scopes: Vec<String> =
scope.split_whitespace().map(|s| s.to_string()).collect();
*self.www_auth_scopes.write().await = scopes;
}
parsed_url = Some(url);
break;
}
}
Ok(parsed_url)
}
async fn fetch_resource_metadata_from_url(
&self,
resource_metadata_url: &Url,
) -> Result<Option<ResourceServerMetadata>, AuthError> {
debug!(
"resource metadata discovery url: {:?}",
resource_metadata_url
);
let response = match self
.http_client
.get(resource_metadata_url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("resource metadata request failed: {}", e);
return Ok(None);
}
};
if response.status() != StatusCode::OK {
debug!(
"resource metadata request returned non-200: {}",
response.status()
);
return Ok(None);
}
let metadata = match response.json::<ResourceServerMetadata>().await {
Ok(metadata) => metadata,
Err(e) => {
debug!("failed to parse resource metadata as JSON: {}", e);
return Ok(None);
}
};
Ok(Some(metadata))
}
fn extract_www_authenticate_params(header: &str, base_url: &Url) -> WWWAuthenticateParams {
let mut params = WWWAuthenticateParams::default();
let header_lowercase = header.to_ascii_lowercase();
let mut search_offset = 0;
let resource_key = "resource_metadata=";
while let Some(pos) = header_lowercase[search_offset..].find(resource_key) {
let global_pos = search_offset + pos + resource_key.len();
let value_slice = &header[global_pos..];
if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) {
if let Ok(url) = Url::parse(&value) {
params.resource_metadata_url = Some(url);
break;
}
if let Ok(url) = base_url.join(&value) {
params.resource_metadata_url = Some(url);
break;
}
debug!("failed to parse resource metadata value `{value}` as URL");
search_offset = global_pos + consumed;
continue;
} else {
break;
}
}
let scope_key = "scope=";
if let Some(pos) = header_lowercase.find(scope_key) {
let global_pos = pos + scope_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.scope = Some(value);
}
}
let error_key = "error=";
if let Some(pos) = header_lowercase.find(error_key) {
let global_pos = pos + error_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.error = Some(value);
}
}
let desc_key = "error_description=";
if let Some(pos) = header_lowercase.find(desc_key) {
let global_pos = pos + desc_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.error_description = Some(value);
}
}
params
}
fn parse_next_header_value(header_fragment: &str) -> Option<(String, usize)> {
let trimmed = header_fragment.trim_start();
let leading_ws = header_fragment.len() - trimmed.len();
if let Some(stripped) = trimmed.strip_prefix('"') {
let mut escaped = false;
let mut result = String::new();
#[allow(clippy::manual_strip)]
for (idx, ch) in stripped.char_indices() {
if escaped {
result.push(ch);
escaped = false;
continue;
}
match ch {
'\\' => escaped = true,
'"' => return Some((result, leading_ws + idx + 2)),
_ => result.push(ch),
}
}
None
} else {
let end = trimmed
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
.unwrap_or(trimmed.len());
Some((trimmed[..end].to_string(), leading_ws + end))
}
}
pub fn validate_client_credentials_metadata(
&self,
config: &ClientCredentialsConfig,
) -> Result<(), AuthError> {
let Some(metadata) = self.metadata.as_ref() else {
return Ok(());
};
if let Some(methods) = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
{
let is_supported = match config {
ClientCredentialsConfig::ClientSecret { .. } => {
methods.iter().any(|m| {
matches!(
m.as_str(),
Some("client_secret_post") | Some("client_secret_basic")
)
})
}
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { .. } => methods
.iter()
.any(|m| m.as_str() == Some("private_key_jwt")),
};
if !is_supported {
let required_method = config.auth_method();
let supported: Vec<&str> = methods.iter().filter_map(|m| m.as_str()).collect();
return Err(AuthError::ClientCredentialsError(format!(
"Authorization server does not support auth method '{}'. Supported: {:?}",
required_method, supported
)));
}
}
#[cfg(feature = "auth-client-credentials-jwt")]
if let ClientCredentialsConfig::PrivateKeyJwt {
signing_algorithm, ..
} = config
{
if let Some(algs) = metadata
.additional_fields
.get("token_endpoint_auth_signing_alg_values_supported")
.and_then(|v| v.as_array())
{
let alg_str = signing_algorithm.as_str();
if !algs.iter().any(|a| a.as_str() == Some(alg_str)) {
let supported: Vec<&str> = algs.iter().filter_map(|a| a.as_str()).collect();
return Err(AuthError::ClientCredentialsError(format!(
"Authorization server does not support signing algorithm '{}'. Supported: {:?}",
alg_str, supported
)));
}
}
}
Ok(())
}
pub fn configure_client_credentials(
&mut self,
config: &ClientCredentialsConfig,
) -> Result<(), AuthError> {
let metadata = self
.metadata
.as_ref()
.ok_or(AuthError::NoAuthorizationSupport)?;
let token_url = TokenUrl::new(metadata.token_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?;
let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?;
let client_id = ClientId::new(config.client_id().to_string());
let mut client_builder: OAuthClient = oauth2::Client::new(client_id)
.set_auth_uri(auth_url)
.set_token_uri(token_url);
match config {
ClientCredentialsConfig::ClientSecret { client_secret, .. } => {
client_builder =
client_builder.set_client_secret(ClientSecret::new(client_secret.clone()));
let only_basic = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
.map(|arr| {
let (has_basic, has_post) =
arr.iter()
.fold((false, false), |(b, p), m| match m.as_str() {
Some("client_secret_basic") => (true, p),
Some("client_secret_post") => (b, true),
_ => (b, p),
});
has_basic && !has_post
})
.unwrap_or_default();
if !only_basic {
client_builder = client_builder.set_auth_type(AuthType::RequestBody);
}
}
#[cfg(feature = "auth-client-credentials-jwt")]
ClientCredentialsConfig::PrivateKeyJwt { .. } => {
}
}
self.oauth_client = Some(client_builder);
Ok(())
}
pub async fn exchange_client_credentials(
&self,
config: &ClientCredentialsConfig,
) -> Result<OAuthTokenResponse, AuthError> {
if config.resource().is_none() {
return Err(AuthError::ClientCredentialsError(
"resource parameter is required by the MCP auth spec".to_string(),
));
}
#[cfg(feature = "auth-client-credentials-jwt")]
if matches!(config, ClientCredentialsConfig::PrivateKeyJwt { .. }) {
return self.exchange_client_credentials_jwt(config).await;
}
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
let mut request = oauth_client.exchange_client_credentials();
for scope in config.scopes() {
request = request.add_scope(Scope::new(scope.clone()));
}
if let Some(resource) = config.resource() {
request = request.add_extra_param("resource", resource);
}
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
let token_result = match request
.request_async(&OAuthReqwestClient(http_client))
.await
{
Ok(token) => token,
Err(RequestTokenError::Parse(_, body)) => {
match serde_json::from_slice::<OAuthTokenResponse>(&body) {
Ok(parsed) => {
warn!(
"client credentials token exchange failed to parse completely but included a valid token response. Accepting it."
);
parsed
}
Err(parse_err) => {
return Err(AuthError::ClientCredentialsError(format!(
"Token exchange parse error: {}",
parse_err
)));
}
}
}
Err(e) => {
return Err(AuthError::ClientCredentialsError(format!(
"Token exchange failed: {}",
e
)));
}
};
debug!("client credentials token result: {:?}", token_result);
let granted_scopes: Vec<String> = token_result
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*self.current_scopes.write().await = granted_scopes.clone();
let client_id = config.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
#[cfg(feature = "auth-client-credentials-jwt")]
async fn exchange_client_credentials_jwt(
&self,
config: &ClientCredentialsConfig,
) -> Result<OAuthTokenResponse, AuthError> {
let ClientCredentialsConfig::PrivateKeyJwt {
client_id,
signing_key,
signing_algorithm,
token_endpoint_audience,
scopes,
resource,
} = config
else {
return Err(AuthError::InternalError(
"expected PrivateKeyJwt config".to_string(),
));
};
let metadata = self
.metadata
.as_ref()
.ok_or(AuthError::NoAuthorizationSupport)?;
let token_endpoint_url = url::Url::parse(&metadata.token_endpoint).map_err(|e| {
AuthError::ClientCredentialsError(format!(
"Invalid token endpoint URL in authorization metadata: {e}"
))
})?;
if token_endpoint_url.scheme() != "https" {
return Err(AuthError::ClientCredentialsError(
"Insecure token endpoint URL: HTTPS is required for client credentials flow"
.to_string(),
));
}
let audience = token_endpoint_audience
.as_deref()
.unwrap_or(&metadata.token_endpoint);
let assertion =
Self::build_jwt_assertion(client_id, audience, signing_key, *signing_algorithm)?;
let scope_str = scopes.join(" ");
let mut serializer = url::form_urlencoded::Serializer::new(String::new());
serializer.append_pair("grant_type", "client_credentials");
serializer.append_pair(
"client_assertion_type",
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
);
serializer.append_pair("client_assertion", &assertion);
if !scope_str.is_empty() {
serializer.append_pair("scope", &scope_str);
}
if let Some(res) = resource.as_deref() {
serializer.append_pair("resource", res);
}
let body_str = serializer.finish();
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
let response = http_client
.post(token_endpoint_url.as_str())
.header("content-type", "application/x-www-form-urlencoded")
.body(body_str)
.send()
.await
.map_err(|e| {
AuthError::ClientCredentialsError(format!("Token exchange request failed: {e}"))
})?;
let status = response.status();
let body = response.bytes().await.map_err(|e| {
AuthError::ClientCredentialsError(format!("Failed to read token response: {e}"))
})?;
if !status.is_success() {
let msg = if let Ok(v) = serde_json::from_slice::<serde_json::Value>(&body) {
let error = v.get("error").and_then(|e| e.as_str()).unwrap_or("unknown");
let desc = v
.get("error_description")
.and_then(|d| d.as_str())
.unwrap_or("");
format!("Token exchange failed: {error}: {desc}")
} else {
format!("Token exchange failed: HTTP {status}")
};
return Err(AuthError::ClientCredentialsError(msg));
}
let token_result = serde_json::from_slice::<OAuthTokenResponse>(&body).map_err(|e| {
AuthError::ClientCredentialsError(format!("Failed to parse token response: {e}"))
})?;
debug!("client credentials JWT token result: {:?}", token_result);
let granted_scopes: Vec<String> = token_result
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*self.current_scopes.write().await = granted_scopes.clone();
let stored = StoredCredentials {
client_id: client_id.clone(),
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
#[cfg(feature = "auth-client-credentials-jwt")]
fn build_jwt_assertion(
client_id: &str,
audience: &str,
signing_key: &[u8],
algorithm: JwtSigningAlgorithm,
) -> Result<String, AuthError> {
use serde_json::json;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let jti = uuid::Uuid::new_v4().to_string();
let claims = json!({
"iss": client_id,
"sub": client_id,
"aud": audience,
"iat": now,
"exp": now + 300, "jti": jti,
});
let header = jsonwebtoken::Header::new(algorithm.to_jsonwebtoken_algorithm());
let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(signing_key).or_else(|_| {
jsonwebtoken::EncodingKey::from_ec_pem(signing_key).map_err(|e| {
AuthError::JwtSigningError(format!("Failed to parse signing key: {}", e))
})
})?;
jsonwebtoken::encode(&header, &claims, &encoding_key)
.map_err(|e| AuthError::JwtSigningError(format!("Failed to sign JWT: {}", e)))
}
}
#[non_exhaustive]
pub struct AuthorizationSession {
pub auth_manager: AuthorizationManager,
pub auth_url: String,
pub redirect_uri: String,
}
impl AuthorizationSession {
pub async fn new(
mut auth_manager: AuthorizationManager,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
client_metadata_url: Option<&str>,
) -> Result<Self, AuthError> {
let metadata = auth_manager.metadata.as_ref();
let supports_url_based_client_id = metadata
.and_then(|m| {
m.additional_fields
.get("client_id_metadata_document_supported")
})
.and_then(|v| v.as_bool())
.unwrap_or(false);
let config = if supports_url_based_client_id {
if let Some(client_metadata_url) = client_metadata_url {
if !is_https_url(client_metadata_url) {
return Err(AuthError::RegistrationFailed(format!(
"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {}",
client_metadata_url
)));
}
OAuthClientConfig {
client_id: client_metadata_url.to_string(),
client_secret: None,
scopes: scopes.iter().map(|s| s.to_string()).collect(),
redirect_uri: redirect_uri.to_string(),
}
} else {
auth_manager
.register_client(client_name.unwrap_or("MCP Client"), redirect_uri, scopes)
.await
.map_err(|e| {
AuthError::RegistrationFailed(format!("Dynamic registration failed: {}", e))
})?
}
} else {
match auth_manager
.register_client(client_name.unwrap_or("MCP Client"), redirect_uri, scopes)
.await
{
Ok(config) => config,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"Dynamic registration failed: {}",
e
)));
}
}
};
auth_manager.configure_client(config)?;
let auth_url = auth_manager.get_authorization_url(scopes).await?;
Ok(Self {
auth_manager,
auth_url,
redirect_uri: redirect_uri.to_string(),
})
}
pub fn for_scope_upgrade(
auth_manager: AuthorizationManager,
auth_url: String,
redirect_uri: &str,
) -> Self {
Self {
auth_manager,
auth_url,
redirect_uri: redirect_uri.to_string(),
}
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
self.auth_manager.get_credentials().await
}
pub fn get_authorization_url(&self) -> &str {
&self.auth_url
}
pub async fn handle_callback(
&self,
code: &str,
csrf_token: &str,
) -> Result<OAuthTokenResponse, AuthError> {
self.auth_manager
.exchange_code_for_token(code, csrf_token)
.await
}
}
pub struct AuthorizedHttpClient {
auth_manager: Arc<AuthorizationManager>,
inner_client: HttpClient,
}
impl AuthorizedHttpClient {
pub fn new(auth_manager: Arc<AuthorizationManager>, client: Option<HttpClient>) -> Self {
let inner_client = client.unwrap_or_default();
Self {
auth_manager,
inner_client,
}
}
pub async fn request<U: IntoUrl>(
&self,
method: reqwest::Method,
url: U,
) -> Result<reqwest::RequestBuilder, AuthError> {
let request = self.inner_client.request(method, url);
self.auth_manager.prepare_request(request).await
}
pub async fn get<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response, AuthError> {
let request = self.request(reqwest::Method::GET, url).await?;
let response = request.send().await?;
self.auth_manager.handle_response(response).await
}
pub async fn post<U: IntoUrl>(&self, url: U) -> Result<reqwest::RequestBuilder, AuthError> {
self.request(reqwest::Method::POST, url).await
}
}
#[non_exhaustive]
pub enum OAuthState {
Unauthorized(AuthorizationManager),
Session(AuthorizationSession),
Authorized(AuthorizationManager),
AuthorizedHttpClient(AuthorizedHttpClient),
}
impl OAuthState {
pub async fn new<U: IntoUrl>(
base_url: U,
client: Option<HttpClient>,
) -> Result<Self, AuthError> {
let mut manager = AuthorizationManager::new(base_url).await?;
if let Some(client) = client {
manager.with_client(client)?;
}
Ok(OAuthState::Unauthorized(manager))
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
match self {
OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => {
manager.get_credentials().await
}
OAuthState::Session(session) => session.get_credentials().await,
OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await,
}
}
pub async fn set_credentials(
&mut self,
client_id: &str,
credentials: OAuthTokenResponse,
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(manager) = self {
let mut manager = std::mem::replace(
manager,
AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?,
);
let granted_scopes: Vec<String> = credentials
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*manager.current_scopes.write().await = granted_scopes.clone();
let stored = StoredCredentials {
client_id: client_id.to_string(),
token_response: Some(credentials),
granted_scopes,
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await?;
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);
manager.configure_client_id(client_id)?;
*self = OAuthState::Authorized(manager);
Ok(())
} else {
Err(AuthError::InternalError(
"Cannot set credentials in this state".to_string(),
))
}
}
pub async fn start_authorization(
&mut self,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
) -> Result<(), AuthError> {
self.start_authorization_with_metadata_url(scopes, redirect_uri, client_name, None)
.await
}
pub async fn start_authorization_with_metadata_url(
&mut self,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
client_metadata_url: Option<&str>,
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(mut manager) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
debug!("start discovery");
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);
let selected_scopes: Vec<String> = if scopes.is_empty() {
manager.select_scopes(None, &[])
} else {
let mut s: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
manager.add_offline_access_if_supported(&mut s);
s
};
let scope_refs: Vec<&str> = selected_scopes.iter().map(|s| s.as_str()).collect();
debug!("start session");
let session = AuthorizationSession::new(
manager,
&scope_refs,
redirect_uri,
client_name,
client_metadata_url,
)
.await?;
*self = OAuthState::Session(session);
Ok(())
} else {
Err(AuthError::InternalError(
"Already in session state".to_string(),
))
}
}
pub async fn complete_authorization(&mut self) -> Result<(), AuthError> {
if let OAuthState::Session(session) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::Authorized(session.auth_manager);
Ok(())
} else {
Err(AuthError::InternalError("Not in session state".to_string()))
}
}
pub async fn to_authorized_http_client(&mut self) -> Result<(), AuthError> {
if let OAuthState::Authorized(manager) = std::mem::replace(
self,
OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::AuthorizedHttpClient(AuthorizedHttpClient::new(
Arc::new(manager),
None,
));
Ok(())
} else {
Err(AuthError::InternalError(
"Not in authorized state".to_string(),
))
}
}
pub async fn request_scope_upgrade(
&mut self,
required_scope: &str,
redirect_uri: &str,
) -> Result<String, AuthError> {
let placeholder =
OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?);
let old = std::mem::replace(self, placeholder);
let OAuthState::Authorized(manager) = old else {
*self = old;
return Err(AuthError::InternalError(
"Not in authorized state".to_string(),
));
};
let auth_url = match manager.request_scope_upgrade(required_scope).await {
Ok(url) => url,
Err(e) => {
*self = OAuthState::Authorized(manager);
return Err(e);
}
};
let session =
AuthorizationSession::for_scope_upgrade(manager, auth_url.clone(), redirect_uri);
*self = OAuthState::Session(session);
Ok(auth_url)
}
pub async fn get_authorization_url(&self) -> Result<String, AuthError> {
match self {
OAuthState::Session(session) => Ok(session.get_authorization_url().to_string()),
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in session state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn handle_callback(&mut self, code: &str, csrf_token: &str) -> Result<(), AuthError> {
match self {
OAuthState::Session(session) => {
session.handle_callback(code, csrf_token).await?;
self.complete_authorization().await
}
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in session state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn get_access_token(&self) -> Result<String, AuthError> {
match self {
OAuthState::Unauthorized(manager) => manager.get_access_token().await,
OAuthState::Session(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn refresh_token(&self) -> Result<(), AuthError> {
match self {
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Session(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Authorized(manager) => {
manager.refresh_token().await?;
Ok(())
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub fn into_authorization_manager(self) -> Option<AuthorizationManager> {
match self {
OAuthState::Authorized(manager) => Some(manager),
_ => None,
}
}
pub async fn authenticate_client_credentials(
&mut self,
config: ClientCredentialsConfig,
) -> Result<(), AuthError> {
let OAuthState::Unauthorized(mut manager) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) else {
return Err(AuthError::InternalError(
"Client credentials flow requires Unauthorized state".to_string(),
));
};
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);
manager.validate_client_credentials_metadata(&config)?;
manager.configure_client_credentials(&config)?;
manager.exchange_client_credentials(&config).await?;
*self = OAuthState::Authorized(manager);
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use oauth2::{AuthType, CsrfToken, PkceCodeVerifier};
use url::Url;
use super::{
AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore,
OAuthClientConfig, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url,
};
use crate::transport::auth::VendorExtraTokenFields;
#[test]
fn test_is_https_url_scenarios() {
assert!(is_https_url("https://example.com/client-metadata.json"));
assert!(is_https_url("https://example.com/metadata?version=1"));
assert!(!is_https_url("https://example.com"));
assert!(!is_https_url("https://example.com/"));
assert!(!is_https_url("https://"));
assert!(!is_https_url("http://example.com/metadata"));
assert!(!is_https_url("not a url"));
assert!(!is_https_url(""));
assert!(!is_https_url("javascript:alert(1)"));
assert!(!is_https_url("data:text/html,<script>alert(1)</script>"));
}
#[test]
fn well_known_paths_root() {
let paths = AuthorizationManager::well_known_paths("/", "oauth-authorization-server");
assert_eq!(
paths,
vec!["/.well-known/oauth-authorization-server".to_string()]
);
}
#[test]
fn well_known_paths_with_suffix() {
let paths = AuthorizationManager::well_known_paths("/mcp", "oauth-authorization-server");
assert_eq!(
paths,
vec![
"/.well-known/oauth-authorization-server/mcp".to_string(),
"/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
]
);
}
#[test]
fn well_known_paths_trailing_slash() {
let paths =
AuthorizationManager::well_known_paths("/v1/mcp/", "oauth-authorization-server");
assert_eq!(
paths,
vec![
"/.well-known/oauth-authorization-server/v1/mcp".to_string(),
"/v1/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
]
);
}
#[test]
fn test_protected_resource_metadata_paths() {
let paths =
AuthorizationManager::well_known_paths("/mcp/example", "oauth-protected-resource");
assert!(paths.contains(&"/.well-known/oauth-protected-resource/mcp/example".to_string()));
assert!(paths.contains(&"/.well-known/oauth-protected-resource".to_string()));
}
#[test]
fn generate_discovery_urls() {
let base_url = Url::parse("https://auth.example.com").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 2);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration"
);
let base_url = Url::parse("https://auth.example.com/tenant1").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/tenant1"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/tenant1/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
let base_url = Url::parse("https://auth.example.com/v1/mcp/").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/v1/mcp"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/v1/mcp/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
let base_url = Url::parse("https://auth.example.com/tenant1/subtenant").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/tenant1/subtenant"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/tenant1/subtenant"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/tenant1/subtenant/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
}
#[test]
fn test_discovery_urls_with_path_suffix() {
let base_url = Url::parse("https://mcp.example.com/mcp").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
let canonical_oauth_fallback =
"https://mcp.example.com/.well-known/oauth-authorization-server";
assert!(
urls.iter().any(|u| u.as_str() == canonical_oauth_fallback),
"Expected discovery URLs to include canonical OAuth fallback '{}', but got: {:?}",
canonical_oauth_fallback,
urls.iter().map(|u| u.as_str()).collect::<Vec<_>>()
);
}
#[test]
fn parse_auth_param_value_handles_quoted_string() {
let fragment = r#""example", realm="foo""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "example");
assert_eq!(parsed.1, 9);
}
#[test]
fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() {
let fragment = r#" "a\"b\\c" ,next=value"#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, r#"a"b\c"#);
assert_eq!(parsed.1, 12);
}
#[test]
fn parse_auth_param_value_handles_token_values() {
let fragment = " token,next";
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "token");
assert_eq!(parsed.1, 7);
}
#[test]
fn parse_auth_param_value_handles_semicolon_separated_tokens() {
let fragment = r#" https://example.com/meta; error="invalid_token""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "https://example.com/meta");
assert_eq!(&fragment[..parsed.1], " https://example.com/meta");
}
#[test]
fn parse_auth_param_value_handles_semicolon_after_quoted_value() {
let fragment = r#" "https://example.com/meta"; error="invalid_token""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "https://example.com/meta");
assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#);
}
#[test]
fn parse_auth_param_value_returns_none_for_unterminated_quotes() {
let fragment = r#""unterminated,value"#;
assert!(AuthorizationManager::parse_next_header_value(fragment).is_none());
}
#[test]
fn parses_resource_metadata_parameter() {
let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource/api"
);
}
#[test]
fn parses_relative_resource_metadata_parameter() {
let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource/api"
);
}
#[test]
fn extract_www_authenticate_params_with_all_fields() {
let header = r#"Bearer error="invalid_token", resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read:data write:data", error_description="token expired""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource"
);
assert_eq!(params.scope.unwrap(), "read:data write:data");
assert_eq!(params.error.unwrap(), "invalid_token");
assert_eq!(params.error_description.unwrap(), "token expired");
}
#[test]
fn extract_www_authenticate_params_insufficient_scope() {
let header = r#"Bearer error="insufficient_scope", scope="admin:write", error_description="Additional file write permission required""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.is_insufficient_scope());
assert!(!params.is_invalid_token());
assert_eq!(params.scope.unwrap(), "admin:write");
assert_eq!(
params.error_description.unwrap(),
"Additional file write permission required"
);
}
#[test]
fn extract_www_authenticate_params_with_only_resource_metadata() {
let header = r#"Bearer resource_metadata="/.well-known/oauth-protected-resource""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource"
);
assert!(params.scope.is_none());
}
#[test]
fn extract_www_authenticate_params_bare_bearer() {
let header = "Bearer";
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.scope.is_none());
assert!(params.error.is_none());
assert!(params.error_description.is_none());
}
#[test]
fn extract_www_authenticate_params_error_only() {
let header = r#"Bearer error="invalid_token""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.scope.is_none());
assert!(params.is_invalid_token());
assert!(!params.is_insufficient_scope());
assert!(params.error_description.is_none());
}
#[test]
fn extract_www_authenticate_params_with_unquoted_scope() {
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(params.scope.unwrap(), "read:data");
}
#[test]
fn test_stored_authorization_state_serialization() {
let pkce = PkceCodeVerifier::new("my-verifier".to_string());
let csrf = CsrfToken::new("my-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let json = serde_json::to_string(&state).unwrap();
let deserialized: StoredAuthorizationState = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.pkce_verifier, "my-verifier");
assert_eq!(deserialized.csrf_token, "my-csrf");
}
#[test]
fn test_stored_authorization_state_debug_redacts_secrets() {
let pkce = PkceCodeVerifier::new("super-secret-verifier".to_string());
let csrf = CsrfToken::new("super-secret-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let debug_output = format!("{:?}", state);
assert!(!debug_output.contains("super-secret-verifier"));
assert!(!debug_output.contains("super-secret-csrf"));
assert!(debug_output.contains("[REDACTED]"));
assert!(debug_output.contains("created_at"));
assert!(debug_output.contains("created_at"));
}
#[test]
fn test_stored_credentials_debug_redacts_token_response() {
use oauth2::{AccessToken, basic::BasicTokenType};
use super::{OAuthTokenResponse, StoredCredentials};
let token_response = OAuthTokenResponse::new(
AccessToken::new("super-secret-access-token".to_string()),
BasicTokenType::Bearer,
VendorExtraTokenFields::default(),
);
let creds = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(token_response),
granted_scopes: vec![],
token_received_at: None,
};
let debug_output = format!("{:?}", creds);
assert!(!debug_output.contains("super-secret-access-token"));
assert!(debug_output.contains("[REDACTED]"));
assert!(debug_output.contains("my-client"));
}
#[test]
fn test_stored_authorization_state_into_pkce_verifier() {
let pkce = PkceCodeVerifier::new("original-verifier".to_string());
let csrf = CsrfToken::new("csrf-token".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let recovered = state.into_pkce_verifier();
assert_eq!(recovered.secret(), "original-verifier");
}
#[test]
fn test_stored_authorization_state_created_at() {
let pkce = PkceCodeVerifier::new("verifier".to_string());
let csrf = CsrfToken::new("csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
assert!(state.created_at > 1577836800); }
#[tokio::test]
async fn test_in_memory_state_store_save_and_load() {
let store = InMemoryStateStore::new();
let pkce = PkceCodeVerifier::new("test-verifier".to_string());
let csrf = CsrfToken::new("test-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("test-csrf", state).await.unwrap();
let loaded = store.load("test-csrf").await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.csrf_token, "test-csrf");
assert_eq!(loaded.pkce_verifier, "test-verifier");
}
#[tokio::test]
async fn test_in_memory_state_store_load_nonexistent() {
let store = InMemoryStateStore::new();
let result = store.load("nonexistent").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_state_store_delete() {
let store = InMemoryStateStore::new();
let pkce = PkceCodeVerifier::new("verifier".to_string());
let csrf = CsrfToken::new("csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("csrf", state).await.unwrap();
store.delete("csrf").await.unwrap();
let result = store.load("csrf").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_state_store_overwrite() {
let store = InMemoryStateStore::new();
let csrf_key = "same-csrf";
let pkce1 = PkceCodeVerifier::new("verifier-1".to_string());
let csrf1 = CsrfToken::new(csrf_key.to_string());
let state1 = StoredAuthorizationState::new(&pkce1, &csrf1);
store.save(csrf_key, state1).await.unwrap();
let pkce2 = PkceCodeVerifier::new("verifier-2".to_string());
let csrf2 = CsrfToken::new(csrf_key.to_string());
let state2 = StoredAuthorizationState::new(&pkce2, &csrf2);
store.save(csrf_key, state2).await.unwrap();
let loaded = store.load(csrf_key).await.unwrap().unwrap();
assert_eq!(loaded.pkce_verifier, "verifier-2");
}
#[tokio::test]
async fn test_in_memory_state_store_concurrent_access() {
let store = Arc::new(InMemoryStateStore::new());
let mut handles = vec![];
for i in 0..10 {
let store = Arc::clone(&store);
let handle = tokio::spawn(async move {
let csrf_key = format!("csrf-{}", i);
let verifier = format!("verifier-{}", i);
let pkce = PkceCodeVerifier::new(verifier.clone());
let csrf = CsrfToken::new(csrf_key.clone());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save(&csrf_key, state).await.unwrap();
let loaded = store.load(&csrf_key).await.unwrap().unwrap();
assert_eq!(loaded.pkce_verifier, verifier);
store.delete(&csrf_key).await.unwrap();
let deleted = store.load(&csrf_key).await.unwrap();
assert!(deleted.is_none());
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
}
#[tokio::test]
async fn test_custom_state_store_with_authorization_manager() {
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug, Default)]
struct TrackingStateStore {
inner: InMemoryStateStore,
save_count: AtomicUsize,
load_count: AtomicUsize,
delete_count: AtomicUsize,
}
#[async_trait::async_trait]
impl StateStore for TrackingStateStore {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError> {
self.save_count.fetch_add(1, Ordering::SeqCst);
self.inner.save(csrf_token, state).await
}
async fn load(
&self,
csrf_token: &str,
) -> Result<Option<StoredAuthorizationState>, AuthError> {
self.load_count.fetch_add(1, Ordering::SeqCst);
self.inner.load(csrf_token).await
}
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> {
self.delete_count.fetch_add(1, Ordering::SeqCst);
self.inner.delete(csrf_token).await
}
}
let store = TrackingStateStore::default();
let pkce = PkceCodeVerifier::new("test-verifier".to_string());
let csrf = CsrfToken::new("test-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("test-csrf", state).await.unwrap();
assert_eq!(store.save_count.load(Ordering::SeqCst), 1);
let _ = store.load("test-csrf").await.unwrap();
assert_eq!(store.load_count.load(Ordering::SeqCst), 1);
store.delete("test-csrf").await.unwrap();
assert_eq!(store.delete_count.load(Ordering::SeqCst), 1);
let mut manager = AuthorizationManager::new("http://localhost").await.unwrap();
manager.set_state_store(TrackingStateStore::default());
}
async fn manager_with_metadata(
metadata_override: Option<AuthorizationMetadata>,
) -> AuthorizationManager {
let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap();
mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
..Default::default()
}));
mgr
}
fn test_client_config() -> OAuthClientConfig {
OAuthClientConfig {
client_id: "my-client".to_string(),
client_secret: Some("my-secret".to_string()),
scopes: vec![],
redirect_uri: "http://localhost/callback".to_string(),
}
}
#[tokio::test]
async fn test_configure_client_uses_client_secret_post_from_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::RequestBody
));
}
#[tokio::test]
async fn test_configure_client_defaults_to_basic_auth() {
let mut mgr = manager_with_metadata(None).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_with_explicit_basic_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_ignores_unsupported_auth_methods_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["private_key_jwt"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_prefers_basic_when_both_methods_supported() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post", "client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[test]
fn test_code_challenge_methods_supported_deserialization() {
let json = r#"{
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"code_challenge_methods_supported": ["S256", "plain"]
}"#;
let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap();
let methods = metadata.code_challenge_methods_supported.unwrap();
assert!(methods.contains(&"S256".to_string()));
assert!(methods.contains(&"plain".to_string()));
}
#[test]
fn test_code_challenge_methods_supported_missing_from_json() {
let json = r#"{
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token"
}"#;
let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap();
assert!(metadata.code_challenge_methods_supported.is_none());
}
#[tokio::test]
async fn test_validate_as_metadata_rejects_unsupported_response_type() {
let mut manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: "https://auth.example.com/authorize".to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
response_types_supported: Some(vec!["token".to_string()]),
..Default::default()
};
manager.set_metadata(metadata);
assert!(manager.validate_server_metadata("code").is_err());
}
#[tokio::test]
async fn test_validate_as_metadata_passes_without_pkce_s256() {
let mut manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: "https://auth.example.com/authorize".to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
response_types_supported: Some(vec!["code".to_string()]),
code_challenge_methods_supported: Some(vec!["plain".to_string()]),
..Default::default()
};
manager.set_metadata(metadata);
assert!(manager.validate_server_metadata("code").is_ok());
}
#[tokio::test]
async fn test_validate_as_metadata_passes_without_metadata() {
let manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
assert!(manager.validate_server_metadata("code").is_ok());
}
#[tokio::test]
async fn test_authorization_url_is_valid() {
let base_url = "https://mcp.example.com/api";
let auth_endpoint = "https://auth.example.com/authorize";
let mut manager = AuthorizationManager::new(base_url).await.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: auth_endpoint.to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
registration_endpoint: None,
issuer: None,
jwks_uri: None,
scopes_supported: None,
response_types_supported: Some(vec!["code".to_string()]),
code_challenge_methods_supported: Some(vec!["S256".to_string()]),
additional_fields: std::collections::HashMap::new(),
};
manager.set_metadata(metadata);
manager.configure_client_id("test-client-id").unwrap();
let auth_url = manager
.get_authorization_url(&["read", "write"])
.await
.unwrap();
let parsed = Url::parse(&auth_url).unwrap();
assert!(auth_url.starts_with(auth_endpoint));
let params: std::collections::HashMap<_, _> = parsed.query_pairs().collect();
assert_eq!(
params.get("response_type").map(|v| v.as_ref()),
Some("code")
);
assert_eq!(
params.get("client_id").map(|v| v.as_ref()),
Some("test-client-id")
);
assert!(params.contains_key("state"));
assert_eq!(
params.get("redirect_uri").map(|v| v.as_ref()),
Some(base_url)
);
assert!(params.contains_key("code_challenge"));
assert_eq!(
params.get("code_challenge_method").map(|v| v.as_ref()),
Some("S256")
);
assert_eq!(params.get("resource").map(|v| v.as_ref()), Some(base_url));
let scope = params
.get("scope")
.map(|v| v.to_string())
.unwrap_or_default();
assert!(scope.contains("read"));
assert!(scope.contains("write"));
}
#[test]
fn compute_scope_union_adds_new_scopes() {
let current = vec!["read".to_string(), "write".to_string()];
let result = AuthorizationManager::compute_scope_union(¤t, "admin delete");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert!(result.contains(&"admin".to_string()));
assert!(result.contains(&"delete".to_string()));
assert_eq!(result.len(), 4);
}
#[test]
fn compute_scope_union_deduplicates() {
let current = vec!["read".to_string(), "write".to_string()];
let result = AuthorizationManager::compute_scope_union(¤t, "read admin");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert!(result.contains(&"admin".to_string()));
assert_eq!(result.len(), 3);
}
#[test]
fn compute_scope_union_handles_empty_current() {
let current: Vec<String> = vec![];
let result = AuthorizationManager::compute_scope_union(¤t, "read write");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert_eq!(result.len(), 2);
}
#[tokio::test]
async fn select_scopes_adds_offline_access_when_as_supports_it() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
*mgr.resource_scopes.write().await = vec!["profile".to_string()];
let scopes = mgr.select_scopes(None, &[]);
assert!(
scopes.contains(&"offline_access".to_string()),
"offline_access should be added when AS supports it"
);
assert!(scopes.contains(&"profile".to_string()));
}
#[tokio::test]
async fn select_scopes_does_not_add_offline_access_when_as_does_not_support_it() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]),
..Default::default()
}))
.await;
*mgr.resource_scopes.write().await = vec!["profile".to_string()];
let scopes = mgr.select_scopes(None, &[]);
assert!(
!scopes.contains(&"offline_access".to_string()),
"offline_access should not be added when AS does not support it"
);
}
#[tokio::test]
async fn select_scopes_falls_back_to_defaults() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: None,
..Default::default()
}))
.await;
let scopes = mgr.select_scopes(None, &["default_scope"]);
assert_eq!(scopes, vec!["default_scope".to_string()]);
}
#[tokio::test]
async fn select_scopes_does_not_duplicate_offline_access() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
let scopes = mgr.select_scopes(None, &[]);
let count = scopes.iter().filter(|s| *s == "offline_access").count();
assert_eq!(count, 1, "offline_access should not be duplicated");
}
#[tokio::test]
async fn select_scopes_adds_offline_access_to_www_authenticate_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
*mgr.www_auth_scopes.write().await = vec!["profile".to_string()];
let scopes = mgr.select_scopes(None, &[]);
assert!(scopes.contains(&"offline_access".to_string()));
assert!(scopes.contains(&"profile".to_string()));
}
#[tokio::test]
async fn select_scopes_adds_offline_access_to_www_authenticate_argument() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
let scopes = mgr.select_scopes(Some("profile email"), &[]);
assert!(scopes.contains(&"offline_access".to_string()));
assert!(scopes.contains(&"profile".to_string()));
assert!(scopes.contains(&"email".to_string()));
}
#[tokio::test]
async fn add_offline_access_if_supported_works_with_explicit_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
let mut explicit = vec!["read".to_string(), "write".to_string()];
mgr.add_offline_access_if_supported(&mut explicit);
assert!(explicit.contains(&"offline_access".to_string()));
}
#[tokio::test]
async fn add_offline_access_if_supported_skips_empty_scopes() {
let mgr = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
scopes_supported: Some(vec!["profile".to_string(), "offline_access".to_string()]),
..Default::default()
}))
.await;
let mut empty: Vec<String> = vec![];
mgr.add_offline_access_if_supported(&mut empty);
assert!(
empty.is_empty(),
"offline_access should not be the only scope"
);
}
#[test]
fn scope_upgrade_config_default_values() {
let config = ScopeUpgradeConfig::default();
assert_eq!(config.max_upgrade_attempts, 3);
assert!(config.auto_upgrade);
}
#[tokio::test]
async fn authorization_manager_tracks_scope_upgrade_attempts() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
assert_eq!(manager.get_scope_upgrade_attempts().await, 0);
*manager.scope_upgrade_attempts.write().await = 2;
assert_eq!(manager.get_scope_upgrade_attempts().await, 2);
manager.reset_scope_upgrade_attempts().await;
assert_eq!(manager.get_scope_upgrade_attempts().await, 0);
}
#[tokio::test]
async fn authorization_manager_can_attempt_scope_upgrade_respects_config() {
let mut manager = AuthorizationManager::new("http://localhost").await.unwrap();
assert!(manager.can_attempt_scope_upgrade().await);
manager.set_scope_upgrade_config(ScopeUpgradeConfig {
max_upgrade_attempts: 3,
auto_upgrade: false,
});
assert!(!manager.can_attempt_scope_upgrade().await);
manager.set_scope_upgrade_config(ScopeUpgradeConfig {
max_upgrade_attempts: 2,
auto_upgrade: true,
});
*manager.scope_upgrade_attempts.write().await = 2;
assert!(!manager.can_attempt_scope_upgrade().await);
*manager.scope_upgrade_attempts.write().await = 1;
assert!(manager.can_attempt_scope_upgrade().await);
}
use super::{OAuthTokenResponse, StoredCredentials};
fn make_token_response(access_token: &str, expires_in_secs: Option<u64>) -> OAuthTokenResponse {
use oauth2::{AccessToken, basic::BasicTokenType};
let mut resp = OAuthTokenResponse::new(
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
VendorExtraTokenFields {
..Default::default()
},
);
if let Some(secs) = expires_in_secs {
resp.set_expires_in(Some(&std::time::Duration::from_secs(secs)));
}
resp
}
#[tokio::test]
async fn get_access_token_returns_error_when_no_credentials() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(matches!(err, AuthError::AuthorizationRequired));
}
#[tokio::test]
async fn get_access_token_returns_token_when_not_expired() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("my-access-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();
let token = manager.get_access_token().await.unwrap();
assert_eq!(token, "my-access-token");
}
#[tokio::test]
async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response("stale-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}"
);
}
#[tokio::test]
async fn get_access_token_returns_token_without_expiry_info() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("no-expiry-token", None)),
granted_scopes: vec![],
token_received_at: None,
};
manager.credential_store.save(stored).await.unwrap();
let token = manager.get_access_token().await.unwrap();
assert_eq!(token, "no-expiry-token");
}
#[tokio::test]
async fn get_access_token_requires_reauth_when_within_refresh_buffer() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response("almost-expired", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 3590),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when token is within refresh buffer, got: {err:?}"
);
}
#[tokio::test]
async fn get_access_token_propagates_internal_errors() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("stale-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::InternalError(_)),
"expected InternalError when OAuth client is not configured, got: {err:?}"
);
}
#[test]
fn client_registration_request_includes_scope_when_present() {
let req = super::ClientRegistrationRequest {
client_name: "test".to_string(),
redirect_uris: vec!["http://localhost/callback".to_string()],
grant_types: vec!["authorization_code".to_string()],
token_endpoint_auth_method: "none".to_string(),
response_types: vec!["code".to_string()],
scope: Some("read write".to_string()),
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["scope"], "read write");
}
#[test]
fn client_registration_request_omits_scope_when_none() {
let req = super::ClientRegistrationRequest {
client_name: "test".to_string(),
redirect_uris: vec!["http://localhost/callback".to_string()],
grant_types: vec!["authorization_code".to_string()],
token_endpoint_auth_method: "none".to_string(),
response_types: vec!["code".to_string()],
scope: None,
};
let json = serde_json::to_value(&req).unwrap();
assert!(!json.as_object().unwrap().contains_key("scope"));
}
#[tokio::test]
async fn configure_client_credentials_uses_request_body_auth_for_client_secret() {
let mut mgr = manager_with_metadata(None).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "my-m2m-client".to_string(),
client_secret: "super-secret".to_string(),
scopes: vec!["read".to_string()],
resource: None,
};
mgr.configure_client_credentials(&config).unwrap();
let oauth_client = mgr.oauth_client.as_ref().unwrap();
assert!(matches!(oauth_client.auth_type(), AuthType::RequestBody));
}
#[tokio::test]
async fn configure_client_credentials_sets_correct_client_id() {
let mut mgr = manager_with_metadata(None).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "my-m2m-client".to_string(),
client_secret: "super-secret".to_string(),
scopes: vec!["read".to_string()],
resource: None,
};
mgr.configure_client_credentials(&config).unwrap();
let oauth_client = mgr.oauth_client.as_ref().unwrap();
assert_eq!(oauth_client.client_id().as_str(), "my-m2m-client");
}
#[tokio::test]
async fn configure_client_credentials_returns_error_without_metadata() {
let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap();
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
let err = mgr.configure_client_credentials(&config).unwrap_err();
assert!(matches!(err, AuthError::NoAuthorizationSupport));
}
#[tokio::test]
async fn validate_client_credentials_metadata_rejects_unsupported_method() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["tls_client_auth", "private_key_jwt"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mgr = manager_with_metadata(Some(meta)).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
let err = mgr
.validate_client_credentials_metadata(&config)
.unwrap_err();
assert!(
err.to_string().contains("tls_client_auth"),
"expected error to mention unsupported method, got: {err}"
);
}
#[tokio::test]
async fn validate_client_credentials_metadata_accepts_supported_method() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post", "client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mgr = manager_with_metadata(Some(meta)).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
mgr.validate_client_credentials_metadata(&config).unwrap();
}
#[tokio::test]
async fn validate_client_credentials_metadata_permits_when_field_absent() {
let mgr = manager_with_metadata(None).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
mgr.validate_client_credentials_metadata(&config).unwrap();
}
#[tokio::test]
async fn validate_client_credentials_metadata_accepts_client_secret_basic_only() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mgr = manager_with_metadata(Some(meta)).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
mgr.validate_client_credentials_metadata(&config).unwrap();
}
#[tokio::test]
async fn configure_client_credentials_uses_basic_auth_when_server_only_supports_basic() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "id".to_string(),
client_secret: "secret".to_string(),
scopes: vec![],
resource: None,
};
mgr.configure_client_credentials(&config).unwrap();
let oauth_client = mgr.oauth_client.as_ref().unwrap();
assert!(
!matches!(oauth_client.auth_type(), AuthType::RequestBody),
"expected HTTP Basic auth when server only supports client_secret_basic"
);
}
#[test]
fn client_credentials_config_returns_correct_accessor_values() {
let config = super::ClientCredentialsConfig::ClientSecret {
client_id: "test-id".to_string(),
client_secret: "test-secret".to_string(),
scopes: vec!["scope1".to_string(), "scope2".to_string()],
resource: Some("https://example.com".to_string()),
};
assert_eq!(config.client_id(), "test-id");
assert_eq!(config.scopes(), &["scope1", "scope2"]);
assert_eq!(config.resource(), Some("https://example.com"));
assert_eq!(config.auth_method(), "client_secret_post");
}
#[test]
fn extension_constant_matches_spec() {
assert_eq!(
super::EXTENSION_OAUTH_CLIENT_CREDENTIALS,
"io.modelcontextprotocol/oauth-client-credentials"
);
}
fn make_token_response_with_refresh(
access_token: &str,
refresh_token: &str,
) -> OAuthTokenResponse {
use oauth2::RefreshToken;
let mut resp = make_token_response(access_token, Some(3600));
resp.set_refresh_token(Some(RefreshToken::new(refresh_token.to_string())));
resp
}
#[tokio::test]
async fn refresh_token_returns_error_when_no_stored_credentials() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let err = manager.refresh_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when no credentials stored, got: {err:?}"
);
}
#[tokio::test]
async fn refresh_token_returns_error_when_no_token_response() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: None,
granted_scopes: vec![],
token_received_at: None,
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.refresh_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when token_response is None, got: {err:?}"
);
}
#[tokio::test]
async fn refresh_token_returns_error_when_no_refresh_token() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response("old-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.refresh_token().await.unwrap_err();
assert!(
matches!(err, AuthError::TokenRefreshFailed(_)),
"expected TokenRefreshFailed when no refresh token, got: {err:?}"
);
}
async fn start_token_server() -> (String, Arc<std::sync::Mutex<Option<String>>>) {
use axum::{Router, body::Body, http::Response, routing::post};
let captured: Arc<std::sync::Mutex<Option<String>>> = Arc::new(std::sync::Mutex::new(None));
let captured_clone = Arc::clone(&captured);
let app = Router::new().route(
"/token",
post(move |body: axum::body::Bytes| {
let cap = Arc::clone(&captured_clone);
async move {
*cap.lock().unwrap() =
Some(String::from_utf8(body.to_vec()).unwrap());
Response::builder()
.status(200)
.header("content-type", "application/json")
.body(Body::from(
r#"{"access_token":"new-token","token_type":"Bearer","expires_in":3600}"#,
))
.unwrap()
}
}),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
(format!("http://{}", addr), captured)
}
#[tokio::test]
async fn refresh_token_sends_granted_scopes_in_request() {
let (base_url, captured) = start_token_server().await;
let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("{}/authorize", base_url),
token_endpoint: format!("{}/token", base_url),
..Default::default()
}))
.await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response_with_refresh(
"old-token",
"my-refresh-token",
)),
granted_scopes: vec!["read".to_string(), "write".to_string()],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();
manager.refresh_token().await.unwrap();
let body = captured.lock().unwrap().take().unwrap();
let params: std::collections::HashMap<_, _> = url::form_urlencoded::parse(body.as_bytes())
.into_owned()
.collect();
let scope = params
.get("scope")
.expect("scope should be present in refresh request");
let mut scope_parts: Vec<&str> = scope.split_whitespace().collect();
scope_parts.sort_unstable();
assert_eq!(scope_parts, vec!["read", "write"]);
}
#[tokio::test]
async fn refresh_token_omits_scope_when_granted_scopes_is_empty() {
let (base_url, captured) = start_token_server().await;
let mut manager = manager_with_metadata(Some(AuthorizationMetadata {
authorization_endpoint: format!("{}/authorize", base_url),
token_endpoint: format!("{}/token", base_url),
..Default::default()
}))
.await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response_with_refresh(
"old-token",
"my-refresh-token",
)),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();
manager.refresh_token().await.unwrap();
let body = captured.lock().unwrap().take().unwrap();
let params: std::collections::HashMap<_, _> = url::form_urlencoded::parse(body.as_bytes())
.into_owned()
.collect();
assert!(
!params.contains_key("scope"),
"scope should be absent when granted_scopes is empty, body: {body}"
);
}
}