use std::collections::HashMap;
use std::time::Duration;
use anyhow::{Context, Result};
use http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use serde::Deserialize;
/// Outcome of [`discover`]: whether the remote MCP server demands OAuth.
#[derive(Clone, Debug)]
pub enum AuthRequirement {
    /// The probe was answered with a success status, 405 or 400, so the
    /// endpoint is used without an OAuth flow.
    None,
    /// The probe was answered with 401; carries what discovery found out.
    Required(OAuthDiscovery),
}

/// OAuth parameters gathered by [`discover`] for a server that answered 401.
#[derive(Clone, Debug)]
pub struct OAuthDiscovery {
    /// First authorization server advertised in the protected resource
    /// metadata, or the server URL itself when none could be found.
    pub authorization_server: String,
    /// Scopes to request; may be empty when nothing was advertised.
    pub scopes: Vec<String>,
    /// Resource identifier: the caller's override, else the server URL.
    // Not read by non-test code in this file; the allow keeps the build
    // warning-free. NOTE(review): confirm whether callers elsewhere use it.
    #[allow(dead_code)]
    pub resource: String,
}
/// The parts of an OAuth Protected Resource Metadata document (the JSON
/// served under `/.well-known/oauth-protected-resource`, RFC 9728) that
/// discovery consumes. Missing fields decode as empty lists; any other
/// members of the document are ignored.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadata {
    /// Authorization servers for this resource; discovery uses the first.
    #[serde(default)]
    authorization_servers: Vec<String>,
    /// Scopes the resource advertises; used when the challenge names none.
    #[serde(default)]
    scopes_supported: Vec<String>,
}
/// Probe `server_url` and decide whether (and how) it expects OAuth.
///
/// A GET carrying the caller's `headers` is sent first. A 2xx, 405 or 400
/// answer means the endpoint is usable as-is and yields
/// [`AuthRequirement::None`]. A 401 starts discovery:
///
/// 1. The `WWW-Authenticate` challenge is parsed.
/// 2. Protected Resource Metadata (PRM) is looked up, first at the
///    challenge's `resource_metadata` URL (if any), then at the well-known
///    locations derived from `server_url`. The first document that loads wins.
///
/// Scopes come from the challenge's `scope` parameter when present, otherwise
/// from the PRM `scopes_supported`. The authorization server is the first PRM
/// `authorization_servers` entry, falling back to `server_url`. The resource
/// is `resource_override` when given, else `server_url`.
///
/// # Errors
///
/// - The probe request fails (transport error, 10 second timeout, ...).
/// - The probe answers with a status other than 2xx, 400, 401 or 405.
/// - The probe answers 401 although the caller supplied an `Authorization` or
///   `Proxy-Authorization` header. A static credential was rejected, and
///   switching to OAuth silently would hide that.
/// - `server_url` cannot be parsed as a URL when deriving PRM locations.
///
/// Failures fetching individual PRM candidates are logged and skipped.
pub async fn discover(
    http_client: &reqwest::Client,
    server_url: &str,
    headers: &HashMap<HeaderName, HeaderValue>,
    resource_override: Option<&str>,
) -> Result<AuthRequirement> {
    tracing::debug!(server_url, "probing MCP server for auth requirements");

    let probe = http_client
        .get(server_url)
        .headers(to_header_map(headers))
        .header(http::header::ACCEPT, "application/json, text/event-stream")
        .timeout(Duration::from_secs(10))
        .send()
        .await
        .with_context(|| format!("probe request to {server_url} failed"))?;

    let status = probe.status();
    let reachable_anonymously = status.is_success()
        || [StatusCode::METHOD_NOT_ALLOWED, StatusCode::BAD_REQUEST].contains(&status);
    if reachable_anonymously {
        tracing::debug!(%status, "server reachable without auth");
        return Ok(AuthRequirement::None);
    }
    if status != StatusCode::UNAUTHORIZED {
        anyhow::bail!(
            "unexpected response from {server_url}: {status}; refusing to assume OAuth flow"
        );
    }

    // Raw challenge text; shared by the error hint and the parser below.
    let challenge = probe
        .headers()
        .get(http::header::WWW_AUTHENTICATE)
        .and_then(|value| value.to_str().ok());

    if let Some(name) = supplied_authz_header(headers) {
        let www_auth_hint = challenge
            .map(|s| format!("; WWW-Authenticate: {s}"))
            .unwrap_or_default();
        anyhow::bail!(
            "remote MCP server at {server_url} rejected the supplied --header '{name}: ...' with 401 Unauthorized{www_auth_hint}. \
             Refusing to fall back to OAuth because a static credential was provided; \
             check your token or omit the header to use OAuth"
        );
    }

    let www_auth = challenge.map(parse_www_authenticate).unwrap_or_default();
    tracing::debug!(?www_auth, "parsed WWW-Authenticate header");

    let resource = resource_override.map_or_else(|| server_url.to_string(), str::to_string);

    // The challenge's own pointer is tried first; well-known URLs are fallbacks.
    let candidates: Vec<String> = www_auth
        .resource_metadata
        .iter()
        .cloned()
        .chain(well_known_prm_urls(server_url)?)
        .collect();
    let prm = first_available_prm(http_client, &candidates).await;

    let scopes: Vec<String> = match (www_auth.scope.as_deref(), prm.as_ref()) {
        (Some(scope), _) => scope.split_whitespace().map(str::to_string).collect(),
        (None, Some(meta)) => meta.scopes_supported.clone(),
        (None, None) => Vec::new(),
    };
    let authorization_server = prm
        .and_then(|meta| meta.authorization_servers.into_iter().next())
        .unwrap_or_else(|| server_url.to_string());

    Ok(AuthRequirement::Required(OAuthDiscovery {
        authorization_server,
        scopes,
        resource,
    }))
}

/// Try each PRM candidate URL in order and return the first document that
/// loads. A 4xx answer or a fetch/decode error moves on to the next URL;
/// `None` means no candidate produced a document.
async fn first_available_prm(
    http_client: &reqwest::Client,
    candidates: &[String],
) -> Option<ProtectedResourceMetadata> {
    for url in candidates {
        match fetch_prm(http_client, url).await {
            Ok(Some(meta)) => {
                tracing::debug!(prm_url = url, "fetched protected resource metadata");
                return Some(meta);
            }
            Ok(None) => {}
            Err(e) => {
                tracing::debug!(prm_url = url, error = %e, "PRM fetch failed; trying next");
            }
        }
    }
    None
}
/// Candidate well-known URLs for the server's Protected Resource Metadata,
/// most specific first.
///
/// For `https://host/a/b` this yields
/// `https://host/.well-known/oauth-protected-resource/a/b` followed by the
/// origin-level `https://host/.well-known/oauth-protected-resource`. Trailing
/// slashes on the path are ignored, and a root path yields only the
/// origin-level URL. Query and fragment are not carried over.
///
/// # Errors
///
/// Fails when `server_url` is rejected by `url::Url::parse`.
fn well_known_prm_urls(server_url: &str) -> Result<Vec<String>> {
    let parsed = url::Url::parse(server_url).context("invalid server URL")?;
    let origin_level = format!(
        "{}/.well-known/oauth-protected-resource",
        parsed.origin().ascii_serialization()
    );
    // After trimming, a root path ("/") becomes empty, so emptiness is the
    // only check needed.
    let resource_path = parsed.path().trim_end_matches('/');
    let candidates = if resource_path.is_empty() {
        vec![origin_level]
    } else {
        vec![format!("{origin_level}{resource_path}"), origin_level]
    };
    Ok(candidates)
}
/// Fetch and decode one Protected Resource Metadata document from `url`.
///
/// Returns `Ok(None)` for any 4xx answer (this location has no metadata) and
/// `Ok(Some(_))` for a 2xx whose body decodes as PRM JSON. Each request is
/// capped at 5 seconds.
///
/// # Errors
///
/// Transport failures, any other non-success status, or a body that does not
/// decode as PRM JSON.
async fn fetch_prm(
    http_client: &reqwest::Client,
    url: &str,
) -> Result<Option<ProtectedResourceMetadata>> {
    let response = http_client
        .get(url)
        .header(http::header::ACCEPT, "application/json")
        .timeout(Duration::from_secs(5))
        .send()
        .await?;

    let status = response.status();
    if status.is_client_error() {
        return Ok(None);
    }
    anyhow::ensure!(status.is_success(), "PRM fetch returned {}", status);

    response
        .json::<ProtectedResourceMetadata>()
        .await
        .context("decoding PRM JSON")
        .map(Some)
}
/// Parameters of interest extracted from a `WWW-Authenticate` header value.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct WwwAuthenticate {
    /// `resource_metadata` parameter: URL of the Protected Resource Metadata.
    resource_metadata: Option<String>,
    /// `scope` parameter: space-separated scopes, exactly as sent.
    scope: Option<String>,
}

/// Extract `resource_metadata` and `scope` from a `WWW-Authenticate` value.
///
/// The value may carry several challenges, e.g.
/// `Basic realm="x", Bearer resource_metadata="https://..."`. Only
/// parameters belonging to a `Bearer` challenge are used, with the scheme
/// matched case-insensitively, so a `scope` from another scheme cannot leak
/// in. Parameters that appear before any scheme token are also accepted,
/// which keeps a bare `resource_metadata="..."` list working.
///
/// Quoted values are unquoted and their backslash escapes resolved. If a
/// parameter repeats, the last occurrence wins. Unparseable pieces are
/// skipped rather than reported.
fn parse_www_authenticate(header: &str) -> WwwAuthenticate {
    let mut out = WwwAuthenticate::default();
    // True while the parameters being read belong to a Bearer challenge, or
    // to no explicit challenge at all (nothing seen before any scheme token).
    let mut in_bearer = true;
    for part in split_top_level(header) {
        let part = part.trim();
        if part.is_empty() {
            continue;
        }
        let Some((lhs, value)) = part.split_once('=') else {
            // A lone token (optionally followed by a token68) starts a new
            // challenge without auth-params, e.g. `Negotiate`.
            let scheme = part.split_whitespace().next().unwrap_or(part);
            in_bearer = scheme.eq_ignore_ascii_case("bearer");
            continue;
        };
        let lhs = lhs.trim();
        // `Scheme key=value` opens a new challenge; a plain `key=value`
        // continues the current one.
        let key = match lhs.split_once(char::is_whitespace) {
            Some((scheme, key)) => {
                in_bearer = scheme.eq_ignore_ascii_case("bearer");
                key.trim()
            }
            None => lhs,
        };
        if !in_bearer {
            continue;
        }
        let value = unquote_param(value);
        match key.to_ascii_lowercase().as_str() {
            "resource_metadata" => out.resource_metadata = Some(value),
            "scope" => out.scope = Some(value),
            _ => {}
        }
    }
    out
}

/// Turn a raw auth-param value into its text.
///
/// A properly quoted string has its surrounding quotes removed and each
/// `\x` quoted-pair replaced by `x`. Anything else is trimmed and has stray
/// `"` characters stripped from both ends, matching the lenient handling
/// used for malformed values.
fn unquote_param(raw: &str) -> String {
    let raw = raw.trim();
    match raw.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        Some(inner) => {
            let mut out = String::with_capacity(inner.len());
            let mut chars = inner.chars();
            while let Some(c) = chars.next() {
                if c == '\\' {
                    // A dangling backslash at the very end is dropped.
                    if let Some(escaped) = chars.next() {
                        out.push(escaped);
                    }
                } else {
                    out.push(c);
                }
            }
            out
        }
        None => raw.trim_matches('"').to_string(),
    }
}

/// Split `s` on commas that are not inside a double-quoted string.
///
/// Inside quotes, a backslash escapes the next character, so `\"` does not
/// end the string. Pieces are returned untrimmed and may be empty.
fn split_top_level(s: &str) -> Vec<&str> {
    let mut out = Vec::new();
    let mut start = 0;
    let mut in_quotes = false;
    let mut escaped = false;
    for (i, ch) in s.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' if in_quotes => escaped = true,
            '"' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                out.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    out.push(&s[start..]);
    out
}
fn to_header_map(map: &HashMap<HeaderName, HeaderValue>) -> HeaderMap {
let mut hm = HeaderMap::with_capacity(map.len());
for (k, v) in map {
hm.insert(k.clone(), v.clone());
}
hm
}
/// Name of the static credential header the caller supplied, if any.
///
/// `Authorization` is reported in preference to `Proxy-Authorization` when
/// both are present. The returned label is used in user-facing errors.
fn supplied_authz_header(headers: &HashMap<HeaderName, HeaderValue>) -> Option<&'static str> {
    // Checked in order; the first header present wins.
    let known = [
        (http::header::AUTHORIZATION, "Authorization"),
        (http::header::PROXY_AUTHORIZATION, "Proxy-Authorization"),
    ];
    known
        .iter()
        .find(|(name, _)| headers.contains_key(name))
        .map(|(_, label)| *label)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::extract::State;
    use axum::http::{HeaderMap as AxumHeaderMap, StatusCode as AxumStatus};
    use axum::response::IntoResponse;
    use axum::routing::get;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn parse_basic_www_auth() {
        let h = parse_www_authenticate(
            r#"Bearer error="invalid_request", resource_metadata="https://x/.well-known/foo", scope="read write""#,
        );
        assert_eq!(
            h.resource_metadata.as_deref(),
            Some("https://x/.well-known/foo")
        );
        assert_eq!(h.scope.as_deref(), Some("read write"));
    }

    #[test]
    fn parse_www_auth_without_bearer_prefix() {
        let h = parse_www_authenticate(r#"resource_metadata="https://x/m""#);
        assert_eq!(h.resource_metadata.as_deref(), Some("https://x/m"));
    }

    #[test]
    fn well_known_paths_for_subpath() {
        let urls = well_known_prm_urls("https://example.com/mcp/v1")
            .expect("valid URL must yield PRM candidates");
        assert_eq!(
            urls,
            vec![
                "https://example.com/.well-known/oauth-protected-resource/mcp/v1".to_string(),
                "https://example.com/.well-known/oauth-protected-resource".to_string()
            ]
        );
    }

    #[test]
    fn well_known_paths_for_root() {
        let urls = well_known_prm_urls("https://example.com/")
            .expect("valid URL must yield PRM candidates");
        assert_eq!(
            urls,
            vec!["https://example.com/.well-known/oauth-protected-resource".to_string()]
        );
    }

    #[test]
    fn split_top_level_respects_quotes() {
        let parts = split_top_level(r#"a=1, b="x,y", c=3"#);
        assert_eq!(parts, vec!["a=1", r#" b="x,y""#, " c=3"]);
    }

    #[test]
    fn detects_static_authorization_header() {
        let mut h = HashMap::new();
        h.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer xyz"),
        );
        assert_eq!(supplied_authz_header(&h), Some("Authorization"));
    }

    #[test]
    fn detects_static_proxy_authorization_header() {
        let mut h = HashMap::new();
        h.insert(
            http::header::PROXY_AUTHORIZATION,
            HeaderValue::from_static("Bearer xyz"),
        );
        assert_eq!(supplied_authz_header(&h), Some("Proxy-Authorization"));
    }

    #[test]
    fn no_authz_header_returns_none() {
        let mut h = HashMap::new();
        h.insert(
            HeaderName::from_static("x-custom"),
            HeaderValue::from_static("value"),
        );
        assert!(supplied_authz_header(&h).is_none());
    }

    // ---- mock MCP server ------------------------------------------------

    /// How the mock's `/mcp` endpoint answers the discovery probe.
    enum ProbeBehavior {
        Ok,
        MethodNotAllowed,
        BadRequest,
        Unauthorized { www_authenticate: Option<String> },
        ServerError,
    }

    /// Shared state of one mock server instance.
    struct MockState {
        probe: ProbeBehavior,
        /// JSON served at both PRM well-known paths; `None` makes them 404.
        prm_body: Option<String>,
        /// Number of requests received on the PRM paths.
        prm_hits: AtomicUsize,
    }

    /// Build mock state with a zeroed PRM hit counter.
    fn mock_state(probe: ProbeBehavior, prm_body: Option<String>) -> Arc<MockState> {
        Arc::new(MockState {
            probe,
            prm_body,
            prm_hits: AtomicUsize::new(0),
        })
    }

    /// Number of PRM requests the mock has served so far.
    fn prm_hit_count(state: &MockState) -> usize {
        state.prm_hits.load(Ordering::SeqCst)
    }

    async fn handle_probe(
        State(state): State<Arc<MockState>>,
        _headers: AxumHeaderMap,
    ) -> axum::response::Response {
        match &state.probe {
            ProbeBehavior::Ok => (AxumStatus::OK, "hello").into_response(),
            ProbeBehavior::MethodNotAllowed => {
                (AxumStatus::METHOD_NOT_ALLOWED, "nope").into_response()
            }
            ProbeBehavior::BadRequest => (AxumStatus::BAD_REQUEST, "bad").into_response(),
            ProbeBehavior::ServerError => {
                (AxumStatus::INTERNAL_SERVER_ERROR, "boom").into_response()
            }
            ProbeBehavior::Unauthorized { www_authenticate } => {
                let mut headers = AxumHeaderMap::new();
                if let Some(v) = www_authenticate {
                    headers.insert("WWW-Authenticate", v.parse().expect("valid header"));
                }
                (AxumStatus::UNAUTHORIZED, headers, "go away").into_response()
            }
        }
    }

    async fn handle_prm(State(state): State<Arc<MockState>>) -> axum::response::Response {
        state.prm_hits.fetch_add(1, Ordering::SeqCst);
        match &state.prm_body {
            Some(body) => (
                AxumStatus::OK,
                [("content-type", "application/json")],
                body.clone(),
            )
                .into_response(),
            None => (AxumStatus::NOT_FOUND, "no prm").into_response(),
        }
    }

    /// Start the mock on an ephemeral loopback port and return its base URL.
    async fn spawn_mock(state: Arc<MockState>) -> (String, tokio::task::JoinHandle<()>) {
        let app = Router::new()
            .route("/mcp", get(handle_probe))
            .route("/.well-known/oauth-protected-resource/mcp", get(handle_prm))
            .route("/.well-known/oauth-protected-resource", get(handle_prm))
            .with_state(state);
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("local_addr");
        // No readiness sleep is needed: the socket is already bound and
        // listening, so clients can connect before `serve` first accepts.
        let handle = tokio::spawn(async move {
            let _ = axum::serve(listener, app).await;
        });
        (format!("http://{addr}"), handle)
    }

    fn empty_headers() -> HashMap<HeaderName, HeaderValue> {
        HashMap::new()
    }

    /// Run `discover` against `/mcp` on a fresh mock backed by `state`.
    /// Returns the probed URL alongside the result.
    async fn discover_against(
        state: &Arc<MockState>,
        headers: &HashMap<HeaderName, HeaderValue>,
        resource_override: Option<&str>,
    ) -> (String, Result<AuthRequirement>) {
        let (base, _server) = spawn_mock(Arc::clone(state)).await;
        let server_url = format!("{base}/mcp");
        let client = reqwest::Client::new();
        let out = discover(&client, &server_url, headers, resource_override).await;
        (server_url, out)
    }

    // ---- discover() end-to-end ------------------------------------------

    #[tokio::test]
    async fn discover_returns_none_on_2xx() {
        let state = mock_state(ProbeBehavior::Ok, None);
        let (_, out) = discover_against(&state, &empty_headers(), None).await;
        assert!(matches!(out.expect("discover"), AuthRequirement::None));
        assert_eq!(prm_hit_count(&state), 0, "no PRM lookup without a 401");
    }

    #[tokio::test]
    async fn discover_returns_none_on_405() {
        let state = mock_state(ProbeBehavior::MethodNotAllowed, None);
        let (_, out) = discover_against(&state, &empty_headers(), None).await;
        assert!(matches!(out.expect("discover"), AuthRequirement::None));
        assert_eq!(prm_hit_count(&state), 0, "no PRM lookup without a 401");
    }

    #[tokio::test]
    async fn discover_returns_none_on_400() {
        let state = mock_state(ProbeBehavior::BadRequest, None);
        let (_, out) = discover_against(&state, &empty_headers(), None).await;
        assert!(matches!(out.expect("discover"), AuthRequirement::None));
        assert_eq!(prm_hit_count(&state), 0, "no PRM lookup without a 401");
    }

    #[tokio::test]
    async fn discover_errors_on_5xx() {
        let state = mock_state(ProbeBehavior::ServerError, None);
        let (_, out) = discover_against(&state, &empty_headers(), None).await;
        let err = out.expect_err("5xx must not be silently treated as anonymous-OK");
        let msg = format!("{err:#}");
        assert!(
            msg.contains("unexpected response") || msg.contains("500"),
            "got: {msg}"
        );
        assert_eq!(prm_hit_count(&state), 0, "no PRM lookup on a 5xx");
    }

    #[tokio::test]
    async fn discover_bails_when_static_authz_header_rejected() {
        let state = mock_state(
            ProbeBehavior::Unauthorized {
                www_authenticate: Some("Bearer error=\"invalid_token\"".to_string()),
            },
            None,
        );
        let mut headers = HashMap::new();
        headers.insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer not-real"),
        );
        let (_, out) = discover_against(&state, &headers, None).await;
        let err = out.expect_err("static-credential 401 must not silently switch to OAuth");
        let msg = format!("{err:#}");
        assert!(msg.contains("Authorization"), "got: {msg}");
        assert!(msg.contains("401"), "got: {msg}");
        assert!(msg.contains("invalid_token"), "challenge hint missing: {msg}");
        assert_eq!(prm_hit_count(&state), 0, "must bail before OAuth discovery");
    }

    #[tokio::test]
    async fn discover_resolves_oauth_requirement_from_prm() {
        let prm = serde_json::json!({
            "authorization_servers": ["https://auth.example.com"],
            "scopes_supported": ["read", "write"],
        })
        .to_string();
        let state = mock_state(
            ProbeBehavior::Unauthorized {
                www_authenticate: None,
            },
            Some(prm),
        );
        let (server_url, out) = discover_against(&state, &empty_headers(), None).await;
        match out.expect("discover") {
            AuthRequirement::Required(d) => {
                assert_eq!(d.authorization_server, "https://auth.example.com");
                assert_eq!(d.scopes, vec!["read".to_string(), "write".to_string()]);
                assert_eq!(d.resource, server_url, "resource defaults to the server URL");
            }
            AuthRequirement::None => panic!("expected Required, got None"),
        }
        assert_eq!(
            prm_hit_count(&state),
            1,
            "the path-specific PRM URL answers, so the root one is never tried"
        );
    }

    #[tokio::test]
    async fn discover_prefers_www_authenticate_scope_over_prm() {
        let prm = serde_json::json!({
            "authorization_servers": ["https://auth.example.com"],
            "scopes_supported": ["prm-only"],
        })
        .to_string();
        let state = mock_state(
            ProbeBehavior::Unauthorized {
                www_authenticate: Some(r#"Bearer scope="header-a header-b""#.to_string()),
            },
            Some(prm),
        );
        let (_, out) = discover_against(&state, &empty_headers(), Some("custom-resource")).await;
        match out.expect("discover") {
            AuthRequirement::Required(d) => {
                assert_eq!(
                    d.scopes,
                    vec!["header-a".to_string(), "header-b".to_string()],
                    "header scope must win over PRM scopes_supported"
                );
                assert_eq!(d.resource, "custom-resource");
                assert_eq!(d.authorization_server, "https://auth.example.com");
            }
            AuthRequirement::None => panic!("expected Required"),
        }
        assert_eq!(prm_hit_count(&state), 1);
    }

    #[tokio::test]
    async fn discover_falls_back_to_server_url_when_no_prm() {
        let state = mock_state(
            ProbeBehavior::Unauthorized {
                www_authenticate: None,
            },
            None,
        );
        let (server_url, out) = discover_against(&state, &empty_headers(), None).await;
        match out.expect("discover") {
            AuthRequirement::Required(d) => {
                assert_eq!(
                    d.authorization_server, server_url,
                    "server URL must be the fallback authorization server"
                );
                assert!(d.scopes.is_empty(), "no scope information available");
                assert_eq!(d.resource, server_url);
            }
            AuthRequirement::None => panic!("expected Required"),
        }
        assert_eq!(
            prm_hit_count(&state),
            2,
            "both the path-specific and the root PRM URL must be tried"
        );
    }
}