use anyhow::{anyhow, Context, Result};
use regex::Regex;
use reqwest;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::Duration;
use super::cimd::resolve_client_id;
use super::OAuthConfig;
lazy_static::lazy_static! {
    /// Process-wide cache of discovered OAuth configurations, keyed by MCP server name.
    ///
    /// `discover_oauth_from_mcp_server` consults it before doing any network discovery
    /// and stores its result here afterwards.
    static ref DISCOVERED_OAUTH_CACHE: RwLock<HashMap<String, OAuthConfig>> =
        RwLock::new(HashMap::new());
}
/// Report whether a discovered OAuth config is currently cached for `server_name`.
///
/// A poisoned cache lock is treated as "not cached" and yields `false`.
pub fn has_cached_discovery(server_name: &str) -> bool {
    match DISCOVERED_OAUTH_CACHE.read() {
        Ok(guard) => guard.contains_key(server_name),
        Err(_) => false,
    }
}
/// OAuth 2.0 Protected Resource Metadata document (RFC 9728 §2).
#[derive(Debug, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// Identifier of the protected resource (`resource`, required by RFC 9728).
    pub resource: String,
    /// Issuer identifiers of the authorization servers usable with this resource.
    ///
    /// RFC 9728 marks this member as optional, so a missing field deserializes to an
    /// empty list instead of failing the whole document. Discovery code can then
    /// report "no authorization servers" rather than an opaque JSON error.
    #[serde(default)]
    pub authorization_servers: Vec<String>,
    /// Scope values the resource advertises, if any.
    #[serde(default)]
    pub scopes_supported: Option<Vec<String>>,
}
/// OAuth 2.0 Authorization Server Metadata document (RFC 8414 §2).
///
/// Only the members this module reads are modelled.
#[derive(Deserialize, Debug)]
pub struct AuthServerMetadata {
    /// Issuer identifier of the authorization server.
    pub issuer: String,
    /// URL of the authorization endpoint; copied into `OAuthConfig::authorization_url`.
    pub authorization_endpoint: String,
    /// URL of the token endpoint; copied into `OAuthConfig::token_url`.
    pub token_endpoint: String,
    /// Scopes the server advertises. Used when the resource metadata lists none.
    #[serde(default)]
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE code challenge methods the server advertises (e.g. `S256`).
    #[serde(default)]
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Dynamic client registration endpoint, if the server offers one
    /// (presumably consumed by `resolve_client_id` for DCR — confirm there).
    #[serde(default)]
    pub registration_endpoint: Option<String>,
    /// Whether the server advertises Client ID Metadata Document support
    /// (presumably consumed by `resolve_client_id` for CIMD — confirm there).
    #[serde(default)]
    pub client_id_metadata_document_supported: Option<bool>,
}
pub fn parse_www_authenticate_header(header_value: &str) -> Result<String> {
let re = Regex::new(r#"resource_metadata="([^"]+)""#)
.context("Failed to compile regex for WWW-Authenticate parsing")?;
let captures = re.captures(header_value).ok_or_else(|| {
anyhow!(
"WWW-Authenticate header does not contain resource_metadata URL. Header: {}",
header_value
)
})?;
let url = captures
.get(1)
.ok_or_else(|| anyhow!("Failed to extract resource_metadata URL from captures"))?
.as_str()
.to_string();
crate::log_debug!("Extracted resource_metadata URL: {}", url);
Ok(url)
}
/// Fetch and parse an OAuth 2.0 Protected Resource Metadata document (RFC 9728).
///
/// A fresh HTTP client with a 10-second timeout is used for the request.
///
/// # Errors
///
/// Fails when the HTTP client cannot be built, the request cannot be sent, the server
/// answers with a non-success status, or the body is not a valid metadata document.
/// Every message names `metadata_url`, so callers trying several candidate URLs can
/// tell the attempts apart.
pub async fn fetch_protected_resource_metadata(
    metadata_url: &str,
) -> Result<ProtectedResourceMetadata> {
    crate::log_debug!(
        "Fetching Protected Resource Metadata from: {}",
        metadata_url
    );
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .context("Failed to create HTTP client")?;
    // `with_context` only formats the message on the error path.
    let response = client.get(metadata_url).send().await.with_context(|| {
        format!(
            "Failed to fetch Protected Resource Metadata from {}",
            metadata_url
        )
    })?;
    let status = response.status();
    if !status.is_success() {
        return Err(anyhow!(
            "Protected Resource Metadata request failed with status: {} (RFC 9728 discovery at {})",
            status,
            metadata_url
        ));
    }
    let metadata: ProtectedResourceMetadata = response.json().await.with_context(|| {
        format!(
            "Failed to parse Protected Resource Metadata JSON from {}",
            metadata_url
        )
    })?;
    crate::log_debug!(
        "Protected Resource Metadata: resource={}, auth_servers={:?}",
        metadata.resource,
        metadata.authorization_servers
    );
    Ok(metadata)
}
/// Build the candidate RFC 8414 metadata URLs for `issuer`, most preferred first.
///
/// RFC 8414 §3.1 inserts `/.well-known/oauth-authorization-server` between the host
/// and any path component of the issuer, e.g. `https://as.example/tenant` becomes
/// `https://as.example/.well-known/oauth-authorization-server/tenant`. The form that
/// simply appends the suffix to the issuer is kept as a fallback for servers that
/// publish there. For an issuer without a path, both forms are the same URL and only
/// one is returned.
fn auth_server_metadata_urls(issuer: &str) -> Vec<String> {
    const WELL_KNOWN: &str = "/.well-known/oauth-authorization-server";
    let issuer_trimmed = issuer.trim_end_matches('/');
    let appended = format!("{}{}", issuer_trimmed, WELL_KNOWN);

    let mut urls = Vec::with_capacity(2);
    if let Ok(mut url) = reqwest::Url::parse(issuer_trimmed) {
        let issuer_path = url.path().trim_end_matches('/').to_string();
        url.set_path(&format!("{}{}", WELL_KNOWN, issuer_path));
        // Issuer identifiers carry no query or fragment (RFC 8414 §2); drop stray ones.
        url.set_query(None);
        url.set_fragment(None);
        urls.push(url.to_string());
    }
    if !urls.contains(&appended) {
        urls.push(appended);
    }
    urls
}

/// Fetch and parse one Authorization Server Metadata document from `metadata_url`.
async fn fetch_auth_server_metadata_at(
    client: &reqwest::Client,
    metadata_url: &str,
) -> Result<AuthServerMetadata> {
    crate::log_debug!(
        "Fetching Authorization Server Metadata from: {}",
        metadata_url
    );
    let response = client.get(metadata_url).send().await.with_context(|| {
        format!(
            "Failed to fetch Authorization Server Metadata from {}",
            metadata_url
        )
    })?;
    let status = response.status();
    if !status.is_success() {
        return Err(anyhow!(
            "Authorization Server Metadata request failed with status: {} (RFC 8414 discovery at {})",
            status,
            metadata_url
        ));
    }
    let metadata: AuthServerMetadata = response
        .json()
        .await
        .context("Failed to parse Authorization Server Metadata JSON")?;
    crate::log_debug!(
        "Authorization Server Metadata: issuer={}, auth_endpoint={}, token_endpoint={}",
        metadata.issuer,
        metadata.authorization_endpoint,
        metadata.token_endpoint
    );
    Ok(metadata)
}

/// Fetch OAuth 2.0 Authorization Server Metadata (RFC 8414) for `issuer`.
///
/// Candidate well-known URLs are tried in order (see `auth_server_metadata_urls`),
/// and the first document that fetches and parses successfully is returned. For an
/// issuer without a path, exactly one request is made, to the same URL as before.
///
/// # Errors
///
/// Fails when the HTTP client cannot be built or every candidate URL fails. The
/// returned error is the failure from the last candidate. When more than one URL was
/// tried, it is wrapped with context listing all of them.
pub async fn fetch_auth_server_metadata(issuer: &str) -> Result<AuthServerMetadata> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .context("Failed to create HTTP client")?;

    let candidates = auth_server_metadata_urls(issuer);
    let mut last_error = None;
    for metadata_url in &candidates {
        match fetch_auth_server_metadata_at(&client, metadata_url).await {
            Ok(metadata) => return Ok(metadata),
            Err(err) => {
                crate::log_debug!(
                    "Authorization Server Metadata not available at {}: {:#}",
                    metadata_url,
                    err
                );
                last_error = Some(err);
            }
        }
    }

    // `candidates` always holds at least the appended URL; this fallback is defensive.
    let err = last_error.unwrap_or_else(|| {
        anyhow!(
            "No Authorization Server Metadata URL could be derived from issuer {}",
            issuer
        )
    });
    if candidates.len() > 1 {
        Err(err.context(format!(
            "Authorization Server Metadata not found at any of: {}",
            candidates.join(", ")
        )))
    } else {
        Err(err)
    }
}
/// Assemble an [`OAuthConfig`] from discovered authorization-server and resource metadata.
///
/// Scope selection:
/// - If the resource metadata has `scopes_supported` (even an empty list), it is used.
/// - Otherwise the authorization server's list is used.
/// - Otherwise the scopes are empty.
///
/// The returned config has these fixed values:
/// - `client_id` and `client_secret` are empty.
/// - The callback is `http://localhost:34567/oauth/callback`.
/// - `state` is `None`.
/// - `refresh_buffer_seconds` is 300.
pub fn build_oauth_config_from_metadata(
    auth_metadata: &AuthServerMetadata,
    resource_metadata: &ProtectedResourceMetadata,
) -> OAuthConfig {
    let scopes: Vec<String> = match (
        &resource_metadata.scopes_supported,
        &auth_metadata.scopes_supported,
    ) {
        (Some(chosen), _) | (None, Some(chosen)) => chosen.clone(),
        (None, None) => Vec::new(),
    };
    crate::log_debug!("Building OAuthConfig: scopes={:?}", scopes);

    let authorization_url = auth_metadata.authorization_endpoint.clone();
    let token_url = auth_metadata.token_endpoint.clone();
    OAuthConfig {
        client_id: String::new(),
        client_secret: String::new(),
        authorization_url,
        token_url,
        callback_url: "http://localhost:34567/oauth/callback".to_string(),
        scopes,
        state: None,
        refresh_buffer_seconds: 300,
    }
}
pub async fn discover_oauth_from_mcp_server(
server_url: &str,
server_name: &str,
) -> Result<OAuthConfig> {
{
let cache = DISCOVERED_OAUTH_CACHE.read().unwrap();
if let Some(cached_config) = cache.get(server_name) {
crate::log_debug!(
"Using cached OAuth config for server '{}' (skipping discovery)",
server_name
);
return Ok(cached_config.clone());
}
}
crate::log_debug!(
"Starting MCP Authorization discovery for server '{}' at {}",
server_name,
server_url
);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.context("Failed to create HTTP client for MCP discovery")?;
let server_url_trimmed = server_url.trim_end_matches('/');
let pre_discovery_url = format!(
"{}/.well-known/oauth-protected-resource",
server_url_trimmed
);
crate::log_debug!("Trying pre-discovery at: {}", pre_discovery_url);
let resource_metadata = match fetch_protected_resource_metadata(&pre_discovery_url).await {
Ok(metadata) => {
crate::log_debug!("Pre-discovery successful for server '{}'", server_name);
Some(metadata)
}
Err(e) => {
crate::log_debug!(
"Pre-discovery failed for server '{}': {}, falling back to 401 flow",
server_name,
e
);
None
}
};
let resource_metadata = match resource_metadata {
Some(m) => m,
None => {
crate::log_debug!("Making initial JSON-RPC request to MCP server (expecting 401)...");
let jsonrpc_request = serde_json::json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list",
"params": {}
});
let response = client
.post(server_url)
.header("Content-Type", "application/json")
.json(&jsonrpc_request)
.send()
.await
.context(format!("Failed to connect to MCP server at {}", server_url))?;
if response.status() != reqwest::StatusCode::UNAUTHORIZED {
return Err(anyhow!(
"MCP Authorization discovery requires 401 Unauthorized response, got: {}. \
Server may not support MCP Authorization (RFC 9728).",
response.status()
));
}
crate::log_debug!("Received 401 Unauthorized, proceeding with discovery...");
let www_auth_header = response
.headers()
.get("WWW-Authenticate")
.ok_or_else(|| {
anyhow!(
"MCP server returned 401 but missing WWW-Authenticate header. \
Server does not support MCP Authorization (RFC 9728)."
)
})?
.to_str()
.context("WWW-Authenticate header contains invalid UTF-8")?;
crate::log_debug!("WWW-Authenticate header: {}", www_auth_header);
let resource_metadata_url = parse_www_authenticate_header(www_auth_header)
.context("Failed to parse WWW-Authenticate header")?;
fetch_protected_resource_metadata(&resource_metadata_url)
.await
.context("Failed to fetch Protected Resource Metadata")?
}
};
let auth_server_issuer = resource_metadata
.authorization_servers
.first()
.ok_or_else(|| anyhow!("Protected Resource Metadata contains no authorization servers"))?;
crate::log_debug!("Using authorization server: {}", auth_server_issuer);
let auth_metadata = fetch_auth_server_metadata(auth_server_issuer)
.await
.context("Failed to fetch Authorization Server Metadata via RFC 8414")?;
let oauth_config = build_oauth_config_from_metadata(&auth_metadata, &resource_metadata);
let oauth_config = resolve_client_id(oauth_config, &auth_metadata)
.await
.context("Failed to resolve OAuth client_id via CIMD/DCR")?;
crate::log_debug!(
"MCP Authorization discovery completed successfully for '{}' (client_id: {})",
server_name,
if oauth_config.client_id.len() > 50 {
format!("{}...", &oauth_config.client_id[..50])
} else {
oauth_config.client_id.clone()
}
);
{
let mut cache = DISCOVERED_OAUTH_CACHE.write().unwrap();
cache.insert(server_name.to_string(), oauth_config.clone());
crate::log_debug!(
"Cached OAuth config for server '{}' to avoid repeated discovery",
server_name
);
}
Ok(oauth_config)
}
/// Remove the cached discovery result for `server_name`, if one is present.
///
/// A poisoned cache lock is recovered instead of panicking. Entries are only ever
/// inserted or removed whole in this module, so the map stays usable after another
/// holder panicked.
pub fn clear_discovered_oauth_cache(server_name: &str) {
    let mut cache = DISCOVERED_OAUTH_CACHE
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    if cache.remove(server_name).is_some() {
        crate::log_debug!("Cleared cached OAuth config for server '{}'", server_name);
    }
}
/// Drop every cached discovery result, so later discoveries hit the network again.
///
/// A poisoned cache lock is recovered instead of panicking. Entries are only ever
/// inserted or removed whole in this module, so clearing the map is always safe.
pub fn clear_all_discovered_oauth_cache() {
    let mut cache = DISCOVERED_OAUTH_CACHE
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let count = cache.len();
    cache.clear();
    crate::log_debug!("Cleared all {} cached OAuth configs", count);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata URL shared by the header-parsing test.
    const METADATA_URL: &str = "https://api.example.com/.well-known/oauth-protected-resource";

    #[test]
    fn test_parse_www_authenticate_header() {
        let header = format!(r#"Bearer resource_metadata="{}""#, METADATA_URL);
        let extracted =
            parse_www_authenticate_header(&header).expect("header carries resource_metadata");
        assert_eq!(extracted, METADATA_URL);
    }

    #[test]
    fn test_parse_www_authenticate_header_invalid() {
        // A challenge without resource_metadata must be rejected.
        let outcome = parse_www_authenticate_header("Bearer realm=\"example\"");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_build_oauth_config() {
        let issuer = "https://api.example.com".to_string();
        let auth_metadata = AuthServerMetadata {
            issuer: issuer.clone(),
            authorization_endpoint: "https://api.example.com/oauth/authorize".to_string(),
            token_endpoint: "https://api.example.com/oauth/token".to_string(),
            scopes_supported: Some(vec!["read".to_string(), "write".to_string()]),
            code_challenge_methods_supported: Some(vec!["S256".to_string()]),
            registration_endpoint: None,
            client_id_metadata_document_supported: None,
        };
        let resource_metadata = ProtectedResourceMetadata {
            resource: issuer.clone(),
            authorization_servers: vec![issuer],
            scopes_supported: None,
        };

        let config = build_oauth_config_from_metadata(&auth_metadata, &resource_metadata);

        assert!(config.client_id.is_empty());
        assert!(config.client_secret.is_empty());
        assert_eq!(
            config.authorization_url,
            "https://api.example.com/oauth/authorize"
        );
        assert_eq!(config.token_url, "https://api.example.com/oauth/token");
        assert_eq!(config.scopes, vec!["read", "write"]);
    }
}