use std::{
collections::HashMap,
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait;
use oauth2::{
AsyncHttpClient, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
EmptyExtraTokenFields, ExtraTokenFields, HttpClientError, HttpRequest, HttpResponse,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope,
StandardTokenResponse, TokenResponse, TokenUrl, basic::BasicTokenType,
};
use reqwest::{
Client as HttpClient, IntoUrl, StatusCode, Url,
header::{AUTHORIZATION, WWW_AUTHENTICATE},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, warn};
use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION;
struct OAuthReqwestClient(HttpClient);
impl<'c> AsyncHttpClient<'c> for OAuthReqwestClient {
type Error = HttpClientError<reqwest::Error>;
type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>,
>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self
.0
.execute(request.try_into().map_err(Box::new)?)
.await
.map_err(Box::new)?;
let mut builder = oauth2::http::Response::builder()
.status(response.status())
.version(response.version());
for (name, value) in response.headers().iter() {
builder = builder.header(name, value);
}
builder
.body(response.bytes().await.map_err(Box::new)?.to_vec())
.map_err(HttpClientError::Http)
})
}
}
const DEFAULT_EXCHANGE_URL: &str = "http://localhost";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
pub client_id: String,
pub token_response: Option<OAuthTokenResponse>,
#[serde(default)]
pub granted_scopes: Vec<String>,
#[serde(default)]
pub token_received_at: Option<u64>,
}
#[async_trait]
pub trait CredentialStore: Send + Sync {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError>;
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError>;
async fn clear(&self) -> Result<(), AuthError>;
}
#[derive(Debug, Default, Clone)]
pub struct InMemoryCredentialStore {
credentials: Arc<RwLock<Option<StoredCredentials>>>,
}
impl InMemoryCredentialStore {
pub fn new() -> Self {
Self {
credentials: Arc::new(RwLock::new(None)),
}
}
}
#[async_trait::async_trait]
impl CredentialStore for InMemoryCredentialStore {
async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
Ok(self.credentials.read().await.clone())
}
async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
*self.credentials.write().await = Some(credentials);
Ok(())
}
async fn clear(&self) -> Result<(), AuthError> {
*self.credentials.write().await = None;
Ok(())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredAuthorizationState {
pub pkce_verifier: String,
pub csrf_token: String,
pub created_at: u64,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VendorExtraTokenFields(pub HashMap<String, Value>);
impl ExtraTokenFields for VendorExtraTokenFields {}
impl StoredAuthorizationState {
pub fn new(pkce_verifier: &PkceCodeVerifier, csrf_token: &CsrfToken) -> Self {
Self {
pkce_verifier: pkce_verifier.secret().to_string(),
csrf_token: csrf_token.secret().to_string(),
created_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0),
}
}
pub fn into_pkce_verifier(self) -> PkceCodeVerifier {
PkceCodeVerifier::new(self.pkce_verifier)
}
}
#[async_trait]
pub trait StateStore: Send + Sync {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError>;
async fn load(&self, csrf_token: &str) -> Result<Option<StoredAuthorizationState>, AuthError>;
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError>;
}
#[derive(Debug, Default, Clone)]
pub struct InMemoryStateStore {
states: Arc<RwLock<HashMap<String, StoredAuthorizationState>>>,
}
impl InMemoryStateStore {
pub fn new() -> Self {
Self {
states: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl StateStore for InMemoryStateStore {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError> {
self.states
.write()
.await
.insert(csrf_token.to_string(), state);
Ok(())
}
async fn load(&self, csrf_token: &str) -> Result<Option<StoredAuthorizationState>, AuthError> {
Ok(self.states.read().await.get(csrf_token).cloned())
}
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> {
self.states.write().await.remove(csrf_token);
Ok(())
}
}
#[derive(Clone)]
pub struct AuthClient<C> {
pub http_client: C,
pub auth_manager: Arc<Mutex<AuthorizationManager>>,
}
impl<C: std::fmt::Debug> std::fmt::Debug for AuthClient<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AuthorizedClient")
.field("http_client", &self.http_client)
.field("auth_manager", &"...")
.finish()
}
}
impl<C> AuthClient<C> {
pub fn new(http_client: C, auth_manager: AuthorizationManager) -> Self {
Self {
http_client,
auth_manager: Arc::new(Mutex::new(auth_manager)),
}
}
}
impl<C> AuthClient<C> {
pub fn get_access_token(&self) -> impl Future<Output = Result<String, AuthError>> + Send {
let auth_manager = self.auth_manager.clone();
async move { auth_manager.lock().await.get_access_token().await }
}
}
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AuthError {
#[error("OAuth authorization required")]
AuthorizationRequired,
#[error("OAuth authorization failed: {0}")]
AuthorizationFailed(String),
#[error("OAuth token exchange failed: {0}")]
TokenExchangeFailed(String),
#[error("OAuth token refresh failed: {0}")]
TokenRefreshFailed(String),
#[error("HTTP error: {0}")]
HttpError(#[from] reqwest::Error),
#[error("OAuth error: {0}")]
OAuthError(String),
#[error("Metadata error: {0}")]
MetadataError(String),
#[error("URL parse error: {0}")]
UrlError(#[from] url::ParseError),
#[error("No authorization support detected")]
NoAuthorizationSupport,
#[error("Internal error: {0}")]
InternalError(String),
#[error("Invalid token type: {0}")]
InvalidTokenType(String),
#[error("Token expired")]
TokenExpired,
#[error("Invalid scope: {0}")]
InvalidScope(String),
#[error("Registration failed: {0}")]
RegistrationFailed(String),
#[error("Insufficient scope: {required_scope}")]
InsufficientScope {
required_scope: String,
upgrade_url: Option<String>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct AuthorizationMetadata {
pub authorization_endpoint: String,
pub token_endpoint: String,
pub registration_endpoint: Option<String>,
pub issuer: Option<String>,
pub jwks_uri: Option<String>,
pub scopes_supported: Option<Vec<String>>,
pub response_types_supported: Option<Vec<String>>,
pub code_challenge_methods_supported: Option<Vec<String>>,
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Clone, Deserialize)]
struct ResourceServerMetadata {
authorization_server: Option<String>,
authorization_servers: Option<Vec<String>>,
scopes_supported: Option<Vec<String>>,
}
#[derive(Debug, Clone, Default)]
pub struct WWWAuthenticateParams {
pub resource_metadata_url: Option<Url>,
pub scope: Option<String>,
pub error: Option<String>,
pub error_description: Option<String>,
}
impl WWWAuthenticateParams {
pub fn is_insufficient_scope(&self) -> bool {
self.error.as_deref() == Some("insufficient_scope")
}
pub fn is_invalid_token(&self) -> bool {
self.error.as_deref() == Some("invalid_token")
}
}
#[derive(Debug, Clone)]
pub struct OAuthClientConfig {
pub client_id: String,
pub client_secret: Option<String>,
pub scopes: Vec<String>,
pub redirect_uri: String,
}
type OAuthErrorResponse = oauth2::StandardErrorResponse<oauth2::basic::BasicErrorResponseType>;
pub type OAuthTokenResponse = StandardTokenResponse<VendorExtraTokenFields, BasicTokenType>;
type OAuthTokenIntrospection =
oauth2::StandardTokenIntrospectionResponse<EmptyExtraTokenFields, BasicTokenType>;
type OAuthRevocableToken = oauth2::StandardRevocableToken;
type OAuthRevocationError = oauth2::StandardErrorResponse<oauth2::RevocationErrorResponseType>;
type OAuthClient = oauth2::Client<
OAuthErrorResponse,
OAuthTokenResponse,
OAuthTokenIntrospection,
OAuthRevocableToken,
OAuthRevocationError,
oauth2::EndpointSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointNotSet,
oauth2::EndpointSet,
>;
type Credentials = (String, Option<OAuthTokenResponse>);
#[derive(Debug, Clone)]
pub struct ScopeUpgradeConfig {
pub max_upgrade_attempts: u32,
pub auto_upgrade: bool,
}
impl Default for ScopeUpgradeConfig {
fn default() -> Self {
Self {
max_upgrade_attempts: 3,
auto_upgrade: true,
}
}
}
pub struct AuthorizationManager {
http_client: HttpClient,
metadata: Option<AuthorizationMetadata>,
oauth_client: Option<OAuthClient>,
credential_store: Arc<dyn CredentialStore>,
state_store: Arc<dyn StateStore>,
base_url: Url,
current_scopes: RwLock<Vec<String>>,
scope_upgrade_attempts: RwLock<u32>,
scope_upgrade_config: ScopeUpgradeConfig,
www_auth_scopes: RwLock<Vec<String>>,
resource_scopes: RwLock<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ClientRegistrationRequest {
pub client_name: String,
pub redirect_uris: Vec<String>,
pub grant_types: Vec<String>,
pub token_endpoint_auth_method: String,
pub response_types: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
pub client_id: String,
pub client_secret: Option<String>,
pub client_name: Option<String>,
pub redirect_uris: Vec<String>,
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}
fn is_https_url(value: &str) -> bool {
Url::parse(value)
.ok()
.map(|url| url.scheme() == "https" && url.path() != "/" && url.host_str().is_some())
.unwrap_or(false)
}
impl AuthorizationManager {
fn well_known_paths(base_path: &str, resource: &str) -> Vec<String> {
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
let mut candidates = Vec::new();
let mut push_candidate = |candidate: String| {
if !candidates.contains(&candidate) {
candidates.push(candidate);
}
};
let canonical = format!("/.well-known/{resource}");
if trimmed.is_empty() {
push_candidate(canonical);
return candidates;
}
push_candidate(format!("{canonical}/{trimmed}"));
push_candidate(format!("/{trimmed}/.well-known/{resource}"));
push_candidate(canonical);
candidates
}
pub async fn new<U: IntoUrl>(base_url: U) -> Result<Self, AuthError> {
let base_url = base_url.into_url()?;
let http_client = HttpClient::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
let manager = Self {
http_client,
metadata: None,
oauth_client: None,
credential_store: Arc::new(InMemoryCredentialStore::new()),
state_store: Arc::new(InMemoryStateStore::new()),
base_url,
current_scopes: RwLock::new(Vec::new()),
scope_upgrade_attempts: RwLock::new(0),
scope_upgrade_config: ScopeUpgradeConfig::default(),
www_auth_scopes: RwLock::new(Vec::new()),
resource_scopes: RwLock::new(Vec::new()),
};
Ok(manager)
}
pub fn set_scope_upgrade_config(&mut self, config: ScopeUpgradeConfig) {
self.scope_upgrade_config = config;
}
pub fn set_credential_store<S: CredentialStore + 'static>(&mut self, store: S) {
self.credential_store = Arc::new(store);
}
pub fn set_state_store<S: StateStore + 'static>(&mut self, store: S) {
self.state_store = Arc::new(store);
}
pub fn set_metadata(&mut self, metadata: AuthorizationMetadata) {
self.metadata = Some(metadata);
}
pub async fn initialize_from_store(&mut self) -> Result<bool, AuthError> {
if let Some(stored) = self.credential_store.load().await? {
if stored.token_response.is_some() {
if self.metadata.is_none() {
let metadata = self.discover_metadata().await?;
self.metadata = Some(metadata);
}
self.configure_client_id(&stored.client_id)?;
return Ok(true);
}
}
Ok(false)
}
pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> {
self.http_client = http_client;
Ok(())
}
pub async fn discover_metadata(&self) -> Result<AuthorizationMetadata, AuthError> {
if let Some(metadata) = self.discover_oauth_server_via_resource_metadata().await? {
return Ok(metadata);
}
if let Some(metadata) = self.try_discover_oauth_server(&self.base_url).await? {
return Ok(metadata);
}
Err(AuthError::NoAuthorizationSupport)
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
let client_id = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
.client_id();
let stored = self.credential_store.load().await?;
let token_response = stored.and_then(|s| s.token_response);
Ok((client_id.to_string(), token_response))
}
pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> {
if self.metadata.is_none() {
return Err(AuthError::NoAuthorizationSupport);
}
let metadata = self.metadata.as_ref().unwrap();
let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?;
let token_url = TokenUrl::new(metadata.token_endpoint.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?;
let client_id = ClientId::new(config.client_id);
let redirect_url = RedirectUrl::new(config.redirect_uri.clone())
.map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?;
let mut client_builder: OAuthClient = oauth2::Client::new(client_id.clone())
.set_auth_uri(auth_url)
.set_token_uri(token_url)
.set_redirect_uri(redirect_url);
if let Some(secret) = config.client_secret {
client_builder = client_builder.set_client_secret(ClientSecret::new(secret));
}
let uses_secret_post = metadata
.additional_fields
.get("token_endpoint_auth_methods_supported")
.and_then(|v| v.as_array())
.map(|arr| {
let has_basic = arr
.iter()
.any(|m| m.as_str() == Some("client_secret_basic"));
let has_post = arr.iter().any(|m| m.as_str() == Some("client_secret_post"));
has_post && !has_basic
})
.unwrap_or(false);
if uses_secret_post {
client_builder = client_builder.set_auth_type(AuthType::RequestBody);
}
self.oauth_client = Some(client_builder);
Ok(())
}
fn validate_server_metadata(&self, response_type: &str) -> Result<(), AuthError> {
let Some(metadata) = self.metadata.as_ref() else {
return Ok(());
};
if let Some(response_types_supported) = metadata.response_types_supported.as_ref() {
if !response_types_supported.contains(&response_type.to_string()) {
return Err(AuthError::InvalidScope(response_type.to_string()));
}
}
match &metadata.code_challenge_methods_supported {
Some(methods) if !methods.iter().any(|m| m == "S256") => {
warn!(
?methods,
"server does not advertise S256 in code_challenge_methods_supported, \
proceeding with S256 anyway as oauth 2.1 requires it. \
The server is not compliant with the specification!"
);
}
_ => {}
}
Ok(())
}
pub async fn register_client(
&mut self,
name: &str,
redirect_uri: &str,
scopes: &[&str],
) -> Result<OAuthClientConfig, AuthError> {
if self.metadata.is_none() {
return Err(AuthError::NoAuthorizationSupport);
}
let metadata = self.metadata.as_ref().unwrap();
let Some(registration_url) = metadata.registration_endpoint.as_ref() else {
return Err(AuthError::RegistrationFailed(
"Dynamic client registration not supported".to_string(),
));
};
self.validate_server_metadata("code")?;
let registration_request = ClientRegistrationRequest {
client_name: name.to_string(),
redirect_uris: vec![redirect_uri.to_string()],
grant_types: vec![
"authorization_code".to_string(),
"refresh_token".to_string(),
],
token_endpoint_auth_method: "none".to_string(), response_types: vec!["code".to_string()],
scope: if scopes.is_empty() {
None
} else {
Some(scopes.join(" "))
},
};
let response = match self
.http_client
.post(registration_url)
.json(®istration_request)
.send()
.await
{
Ok(response) => response,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"HTTP request error: {}",
e
)));
}
};
if !response.status().is_success() {
let status = response.status();
let error_text = match response.text().await {
Ok(text) => text,
Err(_) => "cannot get error details".to_string(),
};
return Err(AuthError::RegistrationFailed(format!(
"HTTP {}: {}",
status, error_text
)));
}
debug!("registration response: {:?}", response);
let reg_response = match response.json::<ClientRegistrationResponse>().await {
Ok(response) => response,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"analyze response error: {}",
e
)));
}
};
let config = OAuthClientConfig {
client_id: reg_response.client_id,
client_secret: reg_response.client_secret.filter(|s| !s.is_empty()),
redirect_uri: redirect_uri.to_string(),
scopes: scopes.iter().map(|s| s.to_string()).collect(),
};
self.configure_client(config.clone())?;
Ok(config)
}
pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> {
let config = OAuthClientConfig {
client_id: client_id.to_string(),
client_secret: None,
scopes: vec![],
redirect_uri: self.base_url.to_string(),
};
self.configure_client(config)
}
pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result<String, AuthError> {
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
self.validate_server_metadata("code")?;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut auth_request = oauth_client
.authorize_url(CsrfToken::new_random)
.set_pkce_challenge(pkce_challenge)
.add_extra_param("resource", self.base_url.to_string());
for scope in scopes {
auth_request = auth_request.add_scope(Scope::new(scope.to_string()));
}
let (auth_url, csrf_token) = auth_request.url();
let stored_state = StoredAuthorizationState::new(&pkce_verifier, &csrf_token);
self.state_store
.save(csrf_token.secret(), stored_state)
.await?;
Ok(auth_url.to_string())
}
pub async fn get_current_scopes(&self) -> Vec<String> {
self.current_scopes.read().await.clone()
}
fn compute_scope_union(current: &[String], required: &str) -> Vec<String> {
let mut scope_set: std::collections::HashSet<String> = current.iter().cloned().collect();
for scope in required.split_whitespace() {
scope_set.insert(scope.to_string());
}
scope_set.into_iter().collect()
}
pub async fn can_attempt_scope_upgrade(&self) -> bool {
if !self.scope_upgrade_config.auto_upgrade {
return false;
}
let attempts = *self.scope_upgrade_attempts.read().await;
attempts < self.scope_upgrade_config.max_upgrade_attempts
}
pub fn select_scopes(
&self,
www_authenticate_scope: Option<&str>,
default_scopes: &[&str],
) -> Vec<String> {
if let Some(scope) = www_authenticate_scope {
return scope.split_whitespace().map(|s| s.to_string()).collect();
}
if let Ok(guard) = self.www_auth_scopes.try_read() {
if !guard.is_empty() {
return guard.clone();
}
}
if let Ok(guard) = self.resource_scopes.try_read() {
if !guard.is_empty() {
return guard.clone();
}
}
if let Some(metadata) = &self.metadata {
if let Some(scopes_supported) = &metadata.scopes_supported {
if !scopes_supported.is_empty() {
return scopes_supported.clone();
}
}
}
default_scopes.iter().map(|s| s.to_string()).collect()
}
pub async fn request_scope_upgrade(&self, required_scope: &str) -> Result<String, AuthError> {
if !self.scope_upgrade_config.auto_upgrade {
return Err(AuthError::InvalidScope(
"Scope upgrade is disabled".to_string(),
));
}
let mut attempts = self.scope_upgrade_attempts.write().await;
if *attempts >= self.scope_upgrade_config.max_upgrade_attempts {
return Err(AuthError::InvalidScope(format!(
"Maximum scope upgrade attempts ({}) exceeded",
self.scope_upgrade_config.max_upgrade_attempts
)));
}
*attempts += 1;
drop(attempts);
let current_scopes = self.current_scopes.read().await.clone();
let upgraded_scopes = Self::compute_scope_union(¤t_scopes, required_scope);
debug!(
"Requesting scope upgrade: current={:?}, required={}, union={:?}",
current_scopes, required_scope, upgraded_scopes
);
let scope_refs: Vec<&str> = upgraded_scopes.iter().map(|s| s.as_str()).collect();
self.get_authorization_url(&scope_refs).await
}
pub async fn reset_scope_upgrade_attempts(&self) {
*self.scope_upgrade_attempts.write().await = 0;
}
pub async fn get_scope_upgrade_attempts(&self) -> u32 {
*self.scope_upgrade_attempts.read().await
}
pub async fn exchange_code_for_token(
&self,
code: &str,
csrf_token: &str,
) -> Result<OAuthTokenResponse, AuthError> {
debug!("start exchange code for token: {:?}", code);
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
let stored_state =
self.state_store.load(csrf_token).await?.ok_or_else(|| {
AuthError::InternalError("Authorization state not found".to_string())
})?;
self.state_store.delete(csrf_token).await?;
let pkce_verifier = stored_state.into_pkce_verifier();
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|e| AuthError::InternalError(e.to_string()))?;
debug!("client_id: {:?}", oauth_client.client_id());
let token_result = match oauth_client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(pkce_verifier)
.add_extra_param("resource", self.base_url.to_string())
.request_async(&OAuthReqwestClient(http_client))
.await
{
Ok(token) => token,
Err(RequestTokenError::Parse(_, body)) => {
match serde_json::from_slice::<OAuthTokenResponse>(&body) {
Ok(parsed) => {
warn!(
"token exchange failed to parse completely but included a valid token response. Accepting it."
);
parsed
}
Err(parse_err) => {
return Err(AuthError::TokenExchangeFailed(parse_err.to_string()));
}
}
}
Err(e) => {
return Err(AuthError::TokenExchangeFailed(e.to_string()));
}
};
debug!("exchange token result: {:?}", token_result);
let granted_scopes: Vec<String> = token_result
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*self.current_scopes.write().await = granted_scopes.clone();
*self.scope_upgrade_attempts.write().await = 0;
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
const REFRESH_BUFFER_SECS: u64 = 30;
pub async fn get_access_token(&self) -> Result<String, AuthError> {
let stored = self.credential_store.load().await?;
let Some(stored_creds) = stored else {
return Err(AuthError::AuthorizationRequired);
};
let Some(creds) = stored_creds.token_response.as_ref() else {
return Err(AuthError::AuthorizationRequired);
};
if let (Some(expires_in), Some(received_at)) =
(creds.expires_in(), stored_creds.token_received_at)
{
let elapsed = Self::now_epoch_secs().saturating_sub(received_at);
let remaining = expires_in.as_secs().saturating_sub(elapsed);
if remaining < Self::REFRESH_BUFFER_SECS {
tracing::info!(
remaining_secs = remaining,
"Access token expired or nearly expired, refreshing."
);
return self.try_refresh_or_reauth().await;
}
}
Ok(creds.access_token().secret().to_string())
}
async fn try_refresh_or_reauth(&self) -> Result<String, AuthError> {
match self.refresh_token().await {
Ok(new_creds) => {
tracing::info!("Refreshed access token.");
Ok(new_creds.access_token().secret().to_string())
}
Err(e @ (AuthError::AuthorizationRequired | AuthError::TokenRefreshFailed(_))) => {
tracing::warn!(error = %e, "Token refresh not possible, re-authorization required.");
Err(AuthError::AuthorizationRequired)
}
Err(e) => Err(e),
}
}
pub async fn refresh_token(&self) -> Result<OAuthTokenResponse, AuthError> {
let oauth_client = self
.oauth_client
.as_ref()
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?;
let stored = self.credential_store.load().await?;
let current_credentials = stored
.and_then(|s| s.token_response)
.ok_or_else(|| AuthError::AuthorizationRequired)?;
let refresh_token = current_credentials.refresh_token().ok_or_else(|| {
AuthError::TokenRefreshFailed("No refresh token available".to_string())
})?;
debug!("refresh token present, attempting refresh");
let token_result = oauth_client
.exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string()))
.request_async(&OAuthReqwestClient(self.http_client.clone()))
.await
.map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?;
let granted_scopes: Vec<String> = match token_result.scopes() {
Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(),
None => self.current_scopes.read().await.clone(),
};
*self.current_scopes.write().await = granted_scopes.clone();
let client_id = oauth_client.client_id().to_string();
let stored = StoredCredentials {
client_id,
token_response: Some(token_result.clone()),
granted_scopes,
token_received_at: Some(Self::now_epoch_secs()),
};
self.credential_store.save(stored).await?;
Ok(token_result)
}
pub async fn prepare_request(
&self,
request: reqwest::RequestBuilder,
) -> Result<reqwest::RequestBuilder, AuthError> {
let token = self.get_access_token().await?;
Ok(request.header(AUTHORIZATION, format!("Bearer {}", token)))
}
pub async fn handle_response(
&self,
response: reqwest::Response,
) -> Result<reqwest::Response, AuthError> {
if response.status() == StatusCode::UNAUTHORIZED {
return Err(AuthError::AuthorizationRequired);
}
if response.status() == StatusCode::FORBIDDEN {
for value in response.headers().get_all(WWW_AUTHENTICATE).iter() {
let Ok(value_str) = value.to_str() else {
continue;
};
let params = Self::extract_www_authenticate_params(value_str, &self.base_url);
if params.is_insufficient_scope() {
let required_scope = params.scope.unwrap_or_default();
return Err(AuthError::InsufficientScope {
required_scope,
upgrade_url: None,
});
}
}
return Err(AuthError::AuthorizationFailed("Forbidden".to_string()));
}
Ok(response)
}
fn generate_discovery_urls(base_url: &Url) -> Vec<Url> {
let mut candidates = Vec::new();
let path = base_url.path();
let trimmed = path.trim_start_matches('/').trim_end_matches('/');
let mut push_candidate = |discovery_path: String| {
let mut discovery_url = base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&discovery_path);
candidates.push(discovery_url);
};
if trimmed.is_empty() {
push_candidate("/.well-known/oauth-authorization-server".to_string());
push_candidate("/.well-known/openid-configuration".to_string());
} else {
push_candidate(format!("/.well-known/oauth-authorization-server/{trimmed}"));
push_candidate(format!("/.well-known/openid-configuration/{trimmed}"));
push_candidate(format!("/{trimmed}/.well-known/openid-configuration"));
push_candidate("/.well-known/oauth-authorization-server".to_string());
}
candidates
}
async fn try_discover_oauth_server(
&self,
base_url: &Url,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
for discovery_url in Self::generate_discovery_urls(base_url) {
if let Some(metadata) = self.fetch_authorization_metadata(&discovery_url).await? {
return Ok(Some(metadata));
}
}
Ok(None)
}
async fn fetch_authorization_metadata(
&self,
discovery_url: &Url,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
debug!("discovery url: {:?}", discovery_url);
let response = match self
.http_client
.get(discovery_url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("discovery request failed: {}", e);
return Ok(None);
}
};
if response.status() != StatusCode::OK {
debug!("discovery returned non-200: {}", response.status());
return Ok(None);
}
let body = response.text().await?;
match serde_json::from_str::<AuthorizationMetadata>(&body) {
Ok(metadata) => Ok(Some(metadata)),
Err(err) => {
debug!("Failed to parse metadata for {}: {}", discovery_url, err);
Ok(None) }
}
}
async fn discover_oauth_server_via_resource_metadata(
&self,
) -> Result<Option<AuthorizationMetadata>, AuthError> {
let Some(resource_metadata_url) = self.discover_resource_metadata_url().await? else {
return Ok(None);
};
let Some(resource_metadata) = self
.fetch_resource_metadata_from_url(&resource_metadata_url)
.await?
else {
return Ok(None);
};
if let Some(scopes) = resource_metadata.scopes_supported {
if !scopes.is_empty() {
*self.resource_scopes.write().await = scopes;
}
}
let mut candidates = Vec::new();
if let Some(single) = resource_metadata.authorization_server {
candidates.push(single);
}
if let Some(list) = resource_metadata.authorization_servers {
candidates.extend(list);
}
for candidate in candidates {
let candidate = candidate.trim();
if candidate.is_empty() {
continue;
}
let candidate_url = match Url::parse(candidate) {
Ok(url) => url,
Err(_) => match resource_metadata_url.join(candidate) {
Ok(url) => url,
Err(e) => {
debug!("Failed to resolve authorization server URL `{candidate}`: {e}");
continue;
}
},
};
if candidate_url.path().contains("/.well-known/") {
if let Some(metadata) = self.fetch_authorization_metadata(&candidate_url).await? {
return Ok(Some(metadata));
}
continue;
}
if let Some(metadata) = self.try_discover_oauth_server(&candidate_url).await? {
return Ok(Some(metadata));
}
}
Ok(None)
}
async fn discover_resource_metadata_url(&self) -> Result<Option<Url>, AuthError> {
if let Ok(Some(resource_metadata_url)) =
self.fetch_resource_metadata_url(&self.base_url).await
{
return Ok(Some(resource_metadata_url));
}
for candidate_path in
Self::well_known_paths(self.base_url.path(), "oauth-protected-resource")
{
let mut discovery_url = self.base_url.clone();
discovery_url.set_query(None);
discovery_url.set_fragment(None);
discovery_url.set_path(&candidate_path);
if let Ok(Some(resource_metadata_url)) =
self.fetch_resource_metadata_url(&discovery_url).await
{
return Ok(Some(resource_metadata_url));
}
}
Ok(None)
}
async fn fetch_resource_metadata_url(&self, url: &Url) -> Result<Option<Url>, AuthError> {
let response = match self
.http_client
.get(url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("resource metadata probe failed: {}", e);
return Ok(None);
}
};
if response.status() == StatusCode::OK {
return Ok(Some(url.clone()));
} else if response.status() != StatusCode::UNAUTHORIZED {
debug!(
"resource metadata probe returned unexpected status: {}",
response.status()
);
return Ok(None);
}
let mut parsed_url = None;
for value in response.headers().get_all(WWW_AUTHENTICATE).iter() {
let Ok(value_str) = value.to_str() else {
continue;
};
let params = Self::extract_www_authenticate_params(value_str, &self.base_url);
if let Some(url) = params.resource_metadata_url {
if let Some(scope) = ¶ms.scope {
debug!("WWW-Authenticate header contains scope: {}", scope);
let scopes: Vec<String> =
scope.split_whitespace().map(|s| s.to_string()).collect();
*self.www_auth_scopes.write().await = scopes;
}
parsed_url = Some(url);
break;
}
}
Ok(parsed_url)
}
async fn fetch_resource_metadata_from_url(
&self,
resource_metadata_url: &Url,
) -> Result<Option<ResourceServerMetadata>, AuthError> {
debug!(
"resource metadata discovery url: {:?}",
resource_metadata_url
);
let response = match self
.http_client
.get(resource_metadata_url.clone())
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Ok(r) => r,
Err(e) => {
debug!("resource metadata request failed: {}", e);
return Ok(None);
}
};
if response.status() != StatusCode::OK {
debug!(
"resource metadata request returned non-200: {}",
response.status()
);
return Ok(None);
}
let metadata = response
.json::<ResourceServerMetadata>()
.await
.map_err(|e| {
AuthError::MetadataError(format!("Failed to parse resource metadata: {}", e))
})?;
Ok(Some(metadata))
}
fn extract_www_authenticate_params(header: &str, base_url: &Url) -> WWWAuthenticateParams {
let mut params = WWWAuthenticateParams::default();
let header_lowercase = header.to_ascii_lowercase();
let mut search_offset = 0;
let resource_key = "resource_metadata=";
while let Some(pos) = header_lowercase[search_offset..].find(resource_key) {
let global_pos = search_offset + pos + resource_key.len();
let value_slice = &header[global_pos..];
if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) {
if let Ok(url) = Url::parse(&value) {
params.resource_metadata_url = Some(url);
break;
}
if let Ok(url) = base_url.join(&value) {
params.resource_metadata_url = Some(url);
break;
}
debug!("failed to parse resource metadata value `{value}` as URL");
search_offset = global_pos + consumed;
continue;
} else {
break;
}
}
let scope_key = "scope=";
if let Some(pos) = header_lowercase.find(scope_key) {
let global_pos = pos + scope_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.scope = Some(value);
}
}
let error_key = "error=";
if let Some(pos) = header_lowercase.find(error_key) {
let global_pos = pos + error_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.error = Some(value);
}
}
let desc_key = "error_description=";
if let Some(pos) = header_lowercase.find(desc_key) {
let global_pos = pos + desc_key.len();
let value_slice = &header[global_pos..];
if let Some((value, _consumed)) = Self::parse_next_header_value(value_slice) {
params.error_description = Some(value);
}
}
params
}
fn parse_next_header_value(header_fragment: &str) -> Option<(String, usize)> {
let trimmed = header_fragment.trim_start();
let leading_ws = header_fragment.len() - trimmed.len();
if let Some(stripped) = trimmed.strip_prefix('"') {
let mut escaped = false;
let mut result = String::new();
#[allow(clippy::manual_strip)]
for (idx, ch) in stripped.char_indices() {
if escaped {
result.push(ch);
escaped = false;
continue;
}
match ch {
'\\' => escaped = true,
'"' => return Some((result, leading_ws + idx + 2)),
_ => result.push(ch),
}
}
None
} else {
let end = trimmed
.find(|c: char| c == ',' || c == ';' || c.is_whitespace())
.unwrap_or(trimmed.len());
Some((trimmed[..end].to_string(), leading_ws + end))
}
}
}
pub struct AuthorizationSession {
pub auth_manager: AuthorizationManager,
pub auth_url: String,
pub redirect_uri: String,
}
impl AuthorizationSession {
pub async fn new(
mut auth_manager: AuthorizationManager,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
client_metadata_url: Option<&str>,
) -> Result<Self, AuthError> {
let metadata = auth_manager.metadata.as_ref();
let supports_url_based_client_id = metadata
.and_then(|m| {
m.additional_fields
.get("client_id_metadata_document_supported")
})
.and_then(|v| v.as_bool())
.unwrap_or(false);
let config = if supports_url_based_client_id {
if let Some(client_metadata_url) = client_metadata_url {
if !is_https_url(client_metadata_url) {
return Err(AuthError::RegistrationFailed(format!(
"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {}",
client_metadata_url
)));
}
OAuthClientConfig {
client_id: client_metadata_url.to_string(),
client_secret: None,
scopes: scopes.iter().map(|s| s.to_string()).collect(),
redirect_uri: redirect_uri.to_string(),
}
} else {
auth_manager
.register_client(client_name.unwrap_or("MCP Client"), redirect_uri, scopes)
.await
.map_err(|e| {
AuthError::RegistrationFailed(format!("Dynamic registration failed: {}", e))
})?
}
} else {
match auth_manager
.register_client(client_name.unwrap_or("MCP Client"), redirect_uri, scopes)
.await
{
Ok(config) => config,
Err(e) => {
return Err(AuthError::RegistrationFailed(format!(
"Dynamic registration failed: {}",
e
)));
}
}
};
auth_manager.configure_client(config)?;
let auth_url = auth_manager.get_authorization_url(scopes).await?;
Ok(Self {
auth_manager,
auth_url,
redirect_uri: redirect_uri.to_string(),
})
}
pub fn for_scope_upgrade(
auth_manager: AuthorizationManager,
auth_url: String,
redirect_uri: &str,
) -> Self {
Self {
auth_manager,
auth_url,
redirect_uri: redirect_uri.to_string(),
}
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
self.auth_manager.get_credentials().await
}
pub fn get_authorization_url(&self) -> &str {
&self.auth_url
}
pub async fn handle_callback(
&self,
code: &str,
csrf_token: &str,
) -> Result<OAuthTokenResponse, AuthError> {
self.auth_manager
.exchange_code_for_token(code, csrf_token)
.await
}
}
pub struct AuthorizedHttpClient {
auth_manager: Arc<AuthorizationManager>,
inner_client: HttpClient,
}
impl AuthorizedHttpClient {
pub fn new(auth_manager: Arc<AuthorizationManager>, client: Option<HttpClient>) -> Self {
let inner_client = client.unwrap_or_default();
Self {
auth_manager,
inner_client,
}
}
pub async fn request<U: IntoUrl>(
&self,
method: reqwest::Method,
url: U,
) -> Result<reqwest::RequestBuilder, AuthError> {
let request = self.inner_client.request(method, url);
self.auth_manager.prepare_request(request).await
}
pub async fn get<U: IntoUrl>(&self, url: U) -> Result<reqwest::Response, AuthError> {
let request = self.request(reqwest::Method::GET, url).await?;
let response = request.send().await?;
self.auth_manager.handle_response(response).await
}
pub async fn post<U: IntoUrl>(&self, url: U) -> Result<reqwest::RequestBuilder, AuthError> {
self.request(reqwest::Method::POST, url).await
}
}
pub enum OAuthState {
Unauthorized(AuthorizationManager),
Session(AuthorizationSession),
Authorized(AuthorizationManager),
AuthorizedHttpClient(AuthorizedHttpClient),
}
impl OAuthState {
pub async fn new<U: IntoUrl>(
base_url: U,
client: Option<HttpClient>,
) -> Result<Self, AuthError> {
let mut manager = AuthorizationManager::new(base_url).await?;
if let Some(client) = client {
manager.with_client(client)?;
}
Ok(OAuthState::Unauthorized(manager))
}
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
match self {
OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => {
manager.get_credentials().await
}
OAuthState::Session(session) => session.get_credentials().await,
OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await,
}
}
pub async fn set_credentials(
&mut self,
client_id: &str,
credentials: OAuthTokenResponse,
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(manager) = self {
let mut manager = std::mem::replace(
manager,
AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?,
);
let granted_scopes: Vec<String> = credentials
.scopes()
.map(|scopes| scopes.iter().map(|s| s.to_string()).collect())
.unwrap_or_default();
*manager.current_scopes.write().await = granted_scopes.clone();
let stored = StoredCredentials {
client_id: client_id.to_string(),
token_response: Some(credentials),
granted_scopes,
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await?;
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);
manager.configure_client_id(client_id)?;
*self = OAuthState::Authorized(manager);
Ok(())
} else {
Err(AuthError::InternalError(
"Cannot set credentials in this state".to_string(),
))
}
}
pub async fn start_authorization(
&mut self,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
) -> Result<(), AuthError> {
self.start_authorization_with_metadata_url(scopes, redirect_uri, client_name, None)
.await
}
pub async fn start_authorization_with_metadata_url(
&mut self,
scopes: &[&str],
redirect_uri: &str,
client_name: Option<&str>,
client_metadata_url: Option<&str>,
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(mut manager) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
debug!("start discovery");
let metadata = manager.discover_metadata().await?;
manager.metadata = Some(metadata);
let selected_scopes: Vec<String> = if scopes.is_empty() {
manager.select_scopes(None, &[])
} else {
scopes.iter().map(|s| s.to_string()).collect()
};
let scope_refs: Vec<&str> = selected_scopes.iter().map(|s| s.as_str()).collect();
debug!("start session");
let session = AuthorizationSession::new(
manager,
&scope_refs,
redirect_uri,
client_name,
client_metadata_url,
)
.await?;
*self = OAuthState::Session(session);
Ok(())
} else {
Err(AuthError::InternalError(
"Already in session state".to_string(),
))
}
}
pub async fn complete_authorization(&mut self) -> Result<(), AuthError> {
if let OAuthState::Session(session) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::Authorized(session.auth_manager);
Ok(())
} else {
Err(AuthError::InternalError("Not in session state".to_string()))
}
}
pub async fn to_authorized_http_client(&mut self) -> Result<(), AuthError> {
if let OAuthState::Authorized(manager) = std::mem::replace(
self,
OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::AuthorizedHttpClient(AuthorizedHttpClient::new(
Arc::new(manager),
None,
));
Ok(())
} else {
Err(AuthError::InternalError(
"Not in authorized state".to_string(),
))
}
}
pub async fn request_scope_upgrade(
&mut self,
required_scope: &str,
redirect_uri: &str,
) -> Result<String, AuthError> {
let placeholder =
OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?);
let old = std::mem::replace(self, placeholder);
let OAuthState::Authorized(manager) = old else {
*self = old;
return Err(AuthError::InternalError(
"Not in authorized state".to_string(),
));
};
let auth_url = match manager.request_scope_upgrade(required_scope).await {
Ok(url) => url,
Err(e) => {
*self = OAuthState::Authorized(manager);
return Err(e);
}
};
let session =
AuthorizationSession::for_scope_upgrade(manager, auth_url.clone(), redirect_uri);
*self = OAuthState::Session(session);
Ok(auth_url)
}
pub async fn get_authorization_url(&self) -> Result<String, AuthError> {
match self {
OAuthState::Session(session) => Ok(session.get_authorization_url().to_string()),
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in session state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn handle_callback(&mut self, code: &str, csrf_token: &str) -> Result<(), AuthError> {
match self {
OAuthState::Session(session) => {
session.handle_callback(code, csrf_token).await?;
self.complete_authorization().await
}
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in session state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn get_access_token(&self) -> Result<String, AuthError> {
match self {
OAuthState::Unauthorized(manager) => manager.get_access_token().await,
OAuthState::Session(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Authorized(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub async fn refresh_token(&self) -> Result<(), AuthError> {
match self {
OAuthState::Unauthorized(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Session(_) => {
Err(AuthError::InternalError("Not in manager state".to_string()))
}
OAuthState::Authorized(manager) => {
manager.refresh_token().await?;
Ok(())
}
OAuthState::AuthorizedHttpClient(_) => {
Err(AuthError::InternalError("Already authorized".to_string()))
}
}
}
pub fn into_authorization_manager(self) -> Option<AuthorizationManager> {
match self {
OAuthState::Authorized(manager) => Some(manager),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use oauth2::{AuthType, CsrfToken, PkceCodeVerifier};
use url::Url;
use super::{
AuthError, AuthorizationManager, AuthorizationMetadata, InMemoryStateStore,
OAuthClientConfig, ScopeUpgradeConfig, StateStore, StoredAuthorizationState, is_https_url,
};
use crate::transport::auth::VendorExtraTokenFields;
#[test]
fn test_is_https_url_scenarios() {
assert!(is_https_url("https://example.com/client-metadata.json"));
assert!(is_https_url("https://example.com/metadata?version=1"));
assert!(!is_https_url("https://example.com"));
assert!(!is_https_url("https://example.com/"));
assert!(!is_https_url("https://"));
assert!(!is_https_url("http://example.com/metadata"));
assert!(!is_https_url("not a url"));
assert!(!is_https_url(""));
assert!(!is_https_url("javascript:alert(1)"));
assert!(!is_https_url("data:text/html,<script>alert(1)</script>"));
}
#[test]
fn well_known_paths_root() {
let paths = AuthorizationManager::well_known_paths("/", "oauth-authorization-server");
assert_eq!(
paths,
vec!["/.well-known/oauth-authorization-server".to_string()]
);
}
#[test]
fn well_known_paths_with_suffix() {
let paths = AuthorizationManager::well_known_paths("/mcp", "oauth-authorization-server");
assert_eq!(
paths,
vec![
"/.well-known/oauth-authorization-server/mcp".to_string(),
"/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
]
);
}
#[test]
fn well_known_paths_trailing_slash() {
let paths =
AuthorizationManager::well_known_paths("/v1/mcp/", "oauth-authorization-server");
assert_eq!(
paths,
vec![
"/.well-known/oauth-authorization-server/v1/mcp".to_string(),
"/v1/mcp/.well-known/oauth-authorization-server".to_string(),
"/.well-known/oauth-authorization-server".to_string(),
]
);
}
#[test]
fn test_protected_resource_metadata_paths() {
let paths =
AuthorizationManager::well_known_paths("/mcp/example", "oauth-protected-resource");
assert!(paths.contains(&"/.well-known/oauth-protected-resource/mcp/example".to_string()));
assert!(paths.contains(&"/.well-known/oauth-protected-resource".to_string()));
}
#[test]
fn generate_discovery_urls() {
let base_url = Url::parse("https://auth.example.com").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 2);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration"
);
let base_url = Url::parse("https://auth.example.com/tenant1").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/tenant1"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/tenant1/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
let base_url = Url::parse("https://auth.example.com/v1/mcp/").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/v1/mcp"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/v1/mcp/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
let base_url = Url::parse("https://auth.example.com/tenant1/subtenant").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
assert_eq!(urls.len(), 4);
assert_eq!(
urls[0].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server/tenant1/subtenant"
);
assert_eq!(
urls[1].as_str(),
"https://auth.example.com/.well-known/openid-configuration/tenant1/subtenant"
);
assert_eq!(
urls[2].as_str(),
"https://auth.example.com/tenant1/subtenant/.well-known/openid-configuration"
);
assert_eq!(
urls[3].as_str(),
"https://auth.example.com/.well-known/oauth-authorization-server"
);
}
#[test]
fn test_discovery_urls_with_path_suffix() {
let base_url = Url::parse("https://mcp.example.com/mcp").unwrap();
let urls = AuthorizationManager::generate_discovery_urls(&base_url);
let canonical_oauth_fallback =
"https://mcp.example.com/.well-known/oauth-authorization-server";
assert!(
urls.iter().any(|u| u.as_str() == canonical_oauth_fallback),
"Expected discovery URLs to include canonical OAuth fallback '{}', but got: {:?}",
canonical_oauth_fallback,
urls.iter().map(|u| u.as_str()).collect::<Vec<_>>()
);
}
#[test]
fn parse_auth_param_value_handles_quoted_string() {
let fragment = r#""example", realm="foo""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "example");
assert_eq!(parsed.1, 9);
}
#[test]
fn parse_auth_param_value_handles_escaped_quotes_and_whitespace() {
let fragment = r#" "a\"b\\c" ,next=value"#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, r#"a"b\c"#);
assert_eq!(parsed.1, 12);
}
#[test]
fn parse_auth_param_value_handles_token_values() {
let fragment = " token,next";
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "token");
assert_eq!(parsed.1, 7);
}
#[test]
fn parse_auth_param_value_handles_semicolon_separated_tokens() {
let fragment = r#" https://example.com/meta; error="invalid_token""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "https://example.com/meta");
assert_eq!(&fragment[..parsed.1], " https://example.com/meta");
}
#[test]
fn parse_auth_param_value_handles_semicolon_after_quoted_value() {
let fragment = r#" "https://example.com/meta"; error="invalid_token""#;
let parsed = AuthorizationManager::parse_next_header_value(fragment).unwrap();
assert_eq!(parsed.0, "https://example.com/meta");
assert_eq!(&fragment[..parsed.1], r#" "https://example.com/meta""#);
}
#[test]
fn parse_auth_param_value_returns_none_for_unterminated_quotes() {
let fragment = r#""unterminated,value"#;
assert!(AuthorizationManager::parse_next_header_value(fragment).is_none());
}
#[test]
fn parses_resource_metadata_parameter() {
let header = r#"Bearer error="invalid_request", error_description="missing token", resource_metadata="https://example.com/.well-known/oauth-protected-resource/api""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource/api"
);
}
#[test]
fn parses_relative_resource_metadata_parameter() {
let header = r#"Bearer error="invalid_request", resource_metadata="/.well-known/oauth-protected-resource/api""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource/api"
);
}
#[test]
fn extract_www_authenticate_params_with_all_fields() {
let header = r#"Bearer error="invalid_token", resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read:data write:data", error_description="token expired""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource"
);
assert_eq!(params.scope.unwrap(), "read:data write:data");
assert_eq!(params.error.unwrap(), "invalid_token");
assert_eq!(params.error_description.unwrap(), "token expired");
}
#[test]
fn extract_www_authenticate_params_insufficient_scope() {
let header = r#"Bearer error="insufficient_scope", scope="admin:write", error_description="Additional file write permission required""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.is_insufficient_scope());
assert!(!params.is_invalid_token());
assert_eq!(params.scope.unwrap(), "admin:write");
assert_eq!(
params.error_description.unwrap(),
"Additional file write permission required"
);
}
#[test]
fn extract_www_authenticate_params_with_only_resource_metadata() {
let header = r#"Bearer resource_metadata="/.well-known/oauth-protected-resource""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(
params.resource_metadata_url.unwrap().as_str(),
"https://example.com/.well-known/oauth-protected-resource"
);
assert!(params.scope.is_none());
}
#[test]
fn extract_www_authenticate_params_bare_bearer() {
let header = "Bearer";
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.scope.is_none());
assert!(params.error.is_none());
assert!(params.error_description.is_none());
}
#[test]
fn extract_www_authenticate_params_error_only() {
let header = r#"Bearer error="invalid_token""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert!(params.resource_metadata_url.is_none());
assert!(params.scope.is_none());
assert!(params.is_invalid_token());
assert!(!params.is_insufficient_scope());
assert!(params.error_description.is_none());
}
#[test]
fn extract_www_authenticate_params_with_unquoted_scope() {
let header = r#"Bearer scope=read:data, error="insufficient_scope""#;
let base = Url::parse("https://example.com/api").unwrap();
let params = AuthorizationManager::extract_www_authenticate_params(header, &base);
assert_eq!(params.scope.unwrap(), "read:data");
}
#[test]
fn test_stored_authorization_state_serialization() {
let pkce = PkceCodeVerifier::new("my-verifier".to_string());
let csrf = CsrfToken::new("my-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let json = serde_json::to_string(&state).unwrap();
let deserialized: StoredAuthorizationState = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.pkce_verifier, "my-verifier");
assert_eq!(deserialized.csrf_token, "my-csrf");
}
#[test]
fn test_stored_authorization_state_into_pkce_verifier() {
let pkce = PkceCodeVerifier::new("original-verifier".to_string());
let csrf = CsrfToken::new("csrf-token".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
let recovered = state.into_pkce_verifier();
assert_eq!(recovered.secret(), "original-verifier");
}
#[test]
fn test_stored_authorization_state_created_at() {
let pkce = PkceCodeVerifier::new("verifier".to_string());
let csrf = CsrfToken::new("csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
assert!(state.created_at > 1577836800); }
#[tokio::test]
async fn test_in_memory_state_store_save_and_load() {
let store = InMemoryStateStore::new();
let pkce = PkceCodeVerifier::new("test-verifier".to_string());
let csrf = CsrfToken::new("test-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("test-csrf", state).await.unwrap();
let loaded = store.load("test-csrf").await.unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.csrf_token, "test-csrf");
assert_eq!(loaded.pkce_verifier, "test-verifier");
}
#[tokio::test]
async fn test_in_memory_state_store_load_nonexistent() {
let store = InMemoryStateStore::new();
let result = store.load("nonexistent").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_state_store_delete() {
let store = InMemoryStateStore::new();
let pkce = PkceCodeVerifier::new("verifier".to_string());
let csrf = CsrfToken::new("csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("csrf", state).await.unwrap();
store.delete("csrf").await.unwrap();
let result = store.load("csrf").await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn test_in_memory_state_store_overwrite() {
let store = InMemoryStateStore::new();
let csrf_key = "same-csrf";
let pkce1 = PkceCodeVerifier::new("verifier-1".to_string());
let csrf1 = CsrfToken::new(csrf_key.to_string());
let state1 = StoredAuthorizationState::new(&pkce1, &csrf1);
store.save(csrf_key, state1).await.unwrap();
let pkce2 = PkceCodeVerifier::new("verifier-2".to_string());
let csrf2 = CsrfToken::new(csrf_key.to_string());
let state2 = StoredAuthorizationState::new(&pkce2, &csrf2);
store.save(csrf_key, state2).await.unwrap();
let loaded = store.load(csrf_key).await.unwrap().unwrap();
assert_eq!(loaded.pkce_verifier, "verifier-2");
}
#[tokio::test]
async fn test_in_memory_state_store_concurrent_access() {
let store = Arc::new(InMemoryStateStore::new());
let mut handles = vec![];
for i in 0..10 {
let store = Arc::clone(&store);
let handle = tokio::spawn(async move {
let csrf_key = format!("csrf-{}", i);
let verifier = format!("verifier-{}", i);
let pkce = PkceCodeVerifier::new(verifier.clone());
let csrf = CsrfToken::new(csrf_key.clone());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save(&csrf_key, state).await.unwrap();
let loaded = store.load(&csrf_key).await.unwrap().unwrap();
assert_eq!(loaded.pkce_verifier, verifier);
store.delete(&csrf_key).await.unwrap();
let deleted = store.load(&csrf_key).await.unwrap();
assert!(deleted.is_none());
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
}
#[tokio::test]
async fn test_custom_state_store_with_authorization_manager() {
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Debug, Default)]
struct TrackingStateStore {
inner: InMemoryStateStore,
save_count: AtomicUsize,
load_count: AtomicUsize,
delete_count: AtomicUsize,
}
#[async_trait::async_trait]
impl StateStore for TrackingStateStore {
async fn save(
&self,
csrf_token: &str,
state: StoredAuthorizationState,
) -> Result<(), AuthError> {
self.save_count.fetch_add(1, Ordering::SeqCst);
self.inner.save(csrf_token, state).await
}
async fn load(
&self,
csrf_token: &str,
) -> Result<Option<StoredAuthorizationState>, AuthError> {
self.load_count.fetch_add(1, Ordering::SeqCst);
self.inner.load(csrf_token).await
}
async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> {
self.delete_count.fetch_add(1, Ordering::SeqCst);
self.inner.delete(csrf_token).await
}
}
let store = TrackingStateStore::default();
let pkce = PkceCodeVerifier::new("test-verifier".to_string());
let csrf = CsrfToken::new("test-csrf".to_string());
let state = StoredAuthorizationState::new(&pkce, &csrf);
store.save("test-csrf", state).await.unwrap();
assert_eq!(store.save_count.load(Ordering::SeqCst), 1);
let _ = store.load("test-csrf").await.unwrap();
assert_eq!(store.load_count.load(Ordering::SeqCst), 1);
store.delete("test-csrf").await.unwrap();
assert_eq!(store.delete_count.load(Ordering::SeqCst), 1);
let mut manager = AuthorizationManager::new("http://localhost").await.unwrap();
manager.set_state_store(TrackingStateStore::default());
}
async fn manager_with_metadata(
metadata_override: Option<AuthorizationMetadata>,
) -> AuthorizationManager {
let mut mgr = AuthorizationManager::new("http://localhost").await.unwrap();
mgr.set_metadata(metadata_override.unwrap_or(AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
..Default::default()
}));
mgr
}
fn test_client_config() -> OAuthClientConfig {
OAuthClientConfig {
client_id: "my-client".to_string(),
client_secret: Some("my-secret".to_string()),
scopes: vec![],
redirect_uri: "http://localhost/callback".to_string(),
}
}
#[tokio::test]
async fn test_configure_client_uses_client_secret_post_from_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::RequestBody
));
}
#[tokio::test]
async fn test_configure_client_defaults_to_basic_auth() {
let mut mgr = manager_with_metadata(None).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_with_explicit_basic_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_ignores_unsupported_auth_methods_in_metadata() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["private_key_jwt"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[tokio::test]
async fn test_configure_client_prefers_basic_when_both_methods_supported() {
let mut additional_fields = HashMap::new();
additional_fields.insert(
"token_endpoint_auth_methods_supported".to_string(),
serde_json::json!(["client_secret_post", "client_secret_basic"]),
);
let meta = AuthorizationMetadata {
authorization_endpoint: "http://localhost/authorize".to_string(),
token_endpoint: "http://localhost/token".to_string(),
additional_fields,
..Default::default()
};
let mut mgr = manager_with_metadata(Some(meta)).await;
mgr.configure_client(test_client_config()).unwrap();
assert!(matches!(
mgr.oauth_client.as_ref().unwrap().auth_type(),
AuthType::BasicAuth
));
}
#[test]
fn test_code_challenge_methods_supported_deserialization() {
let json = r#"{
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"code_challenge_methods_supported": ["S256", "plain"]
}"#;
let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap();
let methods = metadata.code_challenge_methods_supported.unwrap();
assert!(methods.contains(&"S256".to_string()));
assert!(methods.contains(&"plain".to_string()));
}
#[test]
fn test_code_challenge_methods_supported_missing_from_json() {
let json = r#"{
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token"
}"#;
let metadata: AuthorizationMetadata = serde_json::from_str(json).unwrap();
assert!(metadata.code_challenge_methods_supported.is_none());
}
#[tokio::test]
async fn test_validate_as_metadata_rejects_unsupported_response_type() {
let mut manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: "https://auth.example.com/authorize".to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
response_types_supported: Some(vec!["token".to_string()]),
..Default::default()
};
manager.set_metadata(metadata);
assert!(manager.validate_server_metadata("code").is_err());
}
#[tokio::test]
async fn test_validate_as_metadata_passes_without_pkce_s256() {
let mut manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: "https://auth.example.com/authorize".to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
response_types_supported: Some(vec!["code".to_string()]),
code_challenge_methods_supported: Some(vec!["plain".to_string()]),
..Default::default()
};
manager.set_metadata(metadata);
assert!(manager.validate_server_metadata("code").is_ok());
}
#[tokio::test]
async fn test_validate_as_metadata_passes_without_metadata() {
let manager = AuthorizationManager::new("https://example.com")
.await
.unwrap();
assert!(manager.validate_server_metadata("code").is_ok());
}
#[tokio::test]
async fn test_authorization_url_is_valid() {
let base_url = "https://mcp.example.com/api";
let auth_endpoint = "https://auth.example.com/authorize";
let mut manager = AuthorizationManager::new(base_url).await.unwrap();
let metadata = AuthorizationMetadata {
authorization_endpoint: auth_endpoint.to_string(),
token_endpoint: "https://auth.example.com/token".to_string(),
registration_endpoint: None,
issuer: None,
jwks_uri: None,
scopes_supported: None,
response_types_supported: Some(vec!["code".to_string()]),
code_challenge_methods_supported: Some(vec!["S256".to_string()]),
additional_fields: std::collections::HashMap::new(),
};
manager.set_metadata(metadata);
manager.configure_client_id("test-client-id").unwrap();
let auth_url = manager
.get_authorization_url(&["read", "write"])
.await
.unwrap();
let parsed = Url::parse(&auth_url).unwrap();
assert!(auth_url.starts_with(auth_endpoint));
let params: std::collections::HashMap<_, _> = parsed.query_pairs().collect();
assert_eq!(
params.get("response_type").map(|v| v.as_ref()),
Some("code")
);
assert_eq!(
params.get("client_id").map(|v| v.as_ref()),
Some("test-client-id")
);
assert!(params.contains_key("state"));
assert_eq!(
params.get("redirect_uri").map(|v| v.as_ref()),
Some(base_url)
);
assert!(params.contains_key("code_challenge"));
assert_eq!(
params.get("code_challenge_method").map(|v| v.as_ref()),
Some("S256")
);
assert_eq!(params.get("resource").map(|v| v.as_ref()), Some(base_url));
let scope = params
.get("scope")
.map(|v| v.to_string())
.unwrap_or_default();
assert!(scope.contains("read"));
assert!(scope.contains("write"));
}
#[test]
fn compute_scope_union_adds_new_scopes() {
let current = vec!["read".to_string(), "write".to_string()];
let result = AuthorizationManager::compute_scope_union(¤t, "admin delete");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert!(result.contains(&"admin".to_string()));
assert!(result.contains(&"delete".to_string()));
assert_eq!(result.len(), 4);
}
#[test]
fn compute_scope_union_deduplicates() {
let current = vec!["read".to_string(), "write".to_string()];
let result = AuthorizationManager::compute_scope_union(¤t, "read admin");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert!(result.contains(&"admin".to_string()));
assert_eq!(result.len(), 3);
}
#[test]
fn compute_scope_union_handles_empty_current() {
let current: Vec<String> = vec![];
let result = AuthorizationManager::compute_scope_union(¤t, "read write");
assert!(result.contains(&"read".to_string()));
assert!(result.contains(&"write".to_string()));
assert_eq!(result.len(), 2);
}
#[test]
fn scope_upgrade_config_default_values() {
let config = ScopeUpgradeConfig::default();
assert_eq!(config.max_upgrade_attempts, 3);
assert!(config.auto_upgrade);
}
#[tokio::test]
async fn authorization_manager_tracks_scope_upgrade_attempts() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
assert_eq!(manager.get_scope_upgrade_attempts().await, 0);
*manager.scope_upgrade_attempts.write().await = 2;
assert_eq!(manager.get_scope_upgrade_attempts().await, 2);
manager.reset_scope_upgrade_attempts().await;
assert_eq!(manager.get_scope_upgrade_attempts().await, 0);
}
#[tokio::test]
async fn authorization_manager_can_attempt_scope_upgrade_respects_config() {
let mut manager = AuthorizationManager::new("http://localhost").await.unwrap();
assert!(manager.can_attempt_scope_upgrade().await);
manager.set_scope_upgrade_config(ScopeUpgradeConfig {
max_upgrade_attempts: 3,
auto_upgrade: false,
});
assert!(!manager.can_attempt_scope_upgrade().await);
manager.set_scope_upgrade_config(ScopeUpgradeConfig {
max_upgrade_attempts: 2,
auto_upgrade: true,
});
*manager.scope_upgrade_attempts.write().await = 2;
assert!(!manager.can_attempt_scope_upgrade().await);
*manager.scope_upgrade_attempts.write().await = 1;
assert!(manager.can_attempt_scope_upgrade().await);
}
use super::{OAuthTokenResponse, StoredCredentials};
fn make_token_response(access_token: &str, expires_in_secs: Option<u64>) -> OAuthTokenResponse {
use oauth2::{AccessToken, basic::BasicTokenType};
let mut resp = OAuthTokenResponse::new(
AccessToken::new(access_token.to_string()),
BasicTokenType::Bearer,
VendorExtraTokenFields {
..Default::default()
},
);
if let Some(secs) = expires_in_secs {
resp.set_expires_in(Some(&std::time::Duration::from_secs(secs)));
}
resp
}
#[tokio::test]
async fn get_access_token_returns_error_when_no_credentials() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(matches!(err, AuthError::AuthorizationRequired));
}
#[tokio::test]
async fn get_access_token_returns_token_when_not_expired() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("my-access-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs()),
};
manager.credential_store.save(stored).await.unwrap();
let token = manager.get_access_token().await.unwrap();
assert_eq!(token, "my-access-token");
}
#[tokio::test]
async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response("stale-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}"
);
}
#[tokio::test]
async fn get_access_token_returns_token_without_expiry_info() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("no-expiry-token", None)),
granted_scopes: vec![],
token_received_at: None,
};
manager.credential_store.save(stored).await.unwrap();
let token = manager.get_access_token().await.unwrap();
assert_eq!(token, "no-expiry-token");
}
#[tokio::test]
async fn get_access_token_requires_reauth_when_within_refresh_buffer() {
let mut manager = manager_with_metadata(None).await;
manager.configure_client(test_client_config()).unwrap();
let stored = StoredCredentials {
client_id: "my-client".to_string(),
token_response: Some(make_token_response("almost-expired", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 3590),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::AuthorizationRequired),
"expected AuthorizationRequired when token is within refresh buffer, got: {err:?}"
);
}
#[tokio::test]
async fn get_access_token_propagates_internal_errors() {
let manager = AuthorizationManager::new("http://localhost").await.unwrap();
let stored = StoredCredentials {
client_id: "test".to_string(),
token_response: Some(make_token_response("stale-token", Some(3600))),
granted_scopes: vec![],
token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200),
};
manager.credential_store.save(stored).await.unwrap();
let err = manager.get_access_token().await.unwrap_err();
assert!(
matches!(err, AuthError::InternalError(_)),
"expected InternalError when OAuth client is not configured, got: {err:?}"
);
}
#[test]
fn client_registration_request_includes_scope_when_present() {
let req = super::ClientRegistrationRequest {
client_name: "test".to_string(),
redirect_uris: vec!["http://localhost/callback".to_string()],
grant_types: vec!["authorization_code".to_string()],
token_endpoint_auth_method: "none".to_string(),
response_types: vec!["code".to_string()],
scope: Some("read write".to_string()),
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["scope"], "read write");
}
#[test]
fn client_registration_request_omits_scope_when_none() {
let req = super::ClientRegistrationRequest {
client_name: "test".to_string(),
redirect_uris: vec!["http://localhost/callback".to_string()],
grant_types: vec!["authorization_code".to_string()],
token_endpoint_auth_method: "none".to_string(),
response_types: vec!["code".to_string()],
scope: None,
};
let json = serde_json::to_value(&req).unwrap();
assert!(!json.as_object().unwrap().contains_key("scope"));
}
}