use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Weak};
use std::time::Duration;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::net::TcpListener;
use crate::cli::oauth_defaults::{self, OAUTH_CALLBACK_PORT};
use crate::secrets::{CreateSecretParams, SecretsStore};
use crate::tools::mcp::config::McpServerConfig;
/// Returns the process-wide HTTP client shared by every OAuth request.
///
/// The client is built lazily on first use with a 30-second overall timeout
/// and automatic redirect following disabled, so a response can never bounce
/// a request to a different host. The build outcome (success or failure) is
/// cached, so a failed build is reported again on every later call.
///
/// # Errors
/// [`AuthError::Http`] if the underlying client could not be constructed.
fn oauth_http_client() -> Result<&'static reqwest::Client, AuthError> {
    static CLIENT: std::sync::OnceLock<Result<reqwest::Client, AuthError>> =
        std::sync::OnceLock::new();

    let cached = CLIENT.get_or_init(|| {
        let builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(30))
            .redirect(reqwest::redirect::Policy::none());
        builder.build().map_err(|err| AuthError::Http(err.to_string()))
    });
    match cached {
        Ok(client) => Ok(client),
        Err(err) => Err(err.clone()),
    }
}
/// Registry key for the per-(server, user) refresh mutex handed out by
/// `refresh_lock`; one critical section exists per distinct pair.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct RefreshLockKey {
/// MCP server name (`refresh_access_token` passes `server_config.name`).
server_name: String,
/// User whose stored credentials are being refreshed.
user_id: String,
}
/// Builds the owned [`RefreshLockKey`] for `server_name` / `user_id`.
fn refresh_lock_key(server_name: &str, user_id: &str) -> RefreshLockKey {
    let (server_name, user_id) = (server_name.to_owned(), user_id.to_owned());
    RefreshLockKey { server_name, user_id }
}
/// Returns the async mutex that serializes token refreshes for one
/// (server, user) pair.
///
/// The global registry holds only `Weak` handles, so a mutex is freed once
/// every caller has dropped its `Arc`; dead entries are pruned on each call.
/// Callers asking for the same pair while a mutex is still alive receive the
/// same `Arc`; otherwise a fresh mutex is created and registered.
async fn refresh_lock(server_name: &str, user_id: &str) -> Arc<tokio::sync::Mutex<()>> {
    type LockRegistry = tokio::sync::Mutex<HashMap<RefreshLockKey, Weak<tokio::sync::Mutex<()>>>>;
    static LOCKS: std::sync::OnceLock<LockRegistry> = std::sync::OnceLock::new();

    let mut registry = LOCKS
        .get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
        .lock()
        .await;
    // Forget mutexes nobody holds anymore.
    registry.retain(|_, weak| weak.strong_count() > 0);

    let key = refresh_lock_key(server_name, user_id);
    match registry.get(&key).and_then(Weak::upgrade) {
        Some(existing) => existing,
        None => {
            let fresh = Arc::new(tokio::sync::Mutex::new(()));
            registry.insert(key, Arc::downgrade(&fresh));
            fresh
        }
    }
}
/// Emits a debug log when `response` is an HTTP redirect.
///
/// The shared OAuth client never follows redirects, so a 3xx reply reaches the
/// caller as an ordinary response; logging the `Location` target makes the
/// resulting failure easier to diagnose.
fn log_redirect_if_applicable(url: &str, response: &reqwest::Response) {
    let status = response.status();
    if !status.is_redirection() {
        return;
    }
    let location = response
        .headers()
        .get("location")
        .and_then(|value| value.to_str().ok());
    tracing::debug!(
        "OAuth request to '{}' returned redirect {} -> {:?} (redirects disabled for security)",
        url,
        status,
        location
    );
}
/// Errors produced by MCP OAuth discovery, authorization, token exchange,
/// refresh, and credential storage.
///
/// `Clone` is required by `oauth_http_client`, which hands out a clone of its
/// cached client-build error on every call.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
/// No usable OAuth setup: the server has no `oauth` configuration where one is
/// required, offers no dynamic-registration endpoint, or its protected-resource
/// well-known document answered with a non-success status.
#[error("Server does not support OAuth authorization")]
NotSupported,
/// Metadata discovery, URL safety validation, or dynamic client registration
/// failed; the payload describes why.
#[error("Failed to discover authorization endpoints: {0}")]
DiscoveryFailed(String),
/// The authorization callback reported that the user denied access.
#[error("Authorization denied by user")]
AuthorizationDenied,
/// The authorization-code exchange request failed or returned an unusable body.
#[error("Token exchange failed: {0}")]
TokenExchangeFailed(String),
/// Refreshing failed, or the client id / refresh token needed to refresh is missing.
#[error("Token expired and refresh failed: {0}")]
RefreshFailed(String),
/// No access token is available.
#[error("No access token available")]
NoToken,
/// Timed out waiting for the authorization callback.
#[error("Timeout waiting for authorization callback")]
Timeout,
/// The local callback listener could not be bound.
#[error("Could not bind to callback port")]
PortUnavailable,
/// HTTP-client construction failure or a callback-server error (port, state
/// mismatch, I/O); the payload describes it.
#[error("HTTP error: {0}")]
Http(String),
/// Reading from or writing to the secrets store failed.
#[error("Secrets error: {0}")]
Secrets(String),
}
/// OAuth 2.0 Protected Resource Metadata (RFC 9728) for an MCP server.
///
/// Read from the `oauth-protected-resource` well-known document or from the
/// `resource_metadata` URL of a `WWW-Authenticate` challenge. Only a subset of
/// the RFC fields is modeled; other members are ignored when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
/// Resource identifier of the protected server.
pub resource: String,
/// Authorization-server URLs that discovery continues from; empty when absent.
#[serde(default)]
pub authorization_servers: Vec<String>,
/// Scopes advertised by the resource; empty when absent.
#[serde(default)]
pub scopes_supported: Vec<String>,
}
/// OAuth 2.0 Authorization Server Metadata (RFC 8414), as read from the
/// `oauth-authorization-server` well-known document. Unmodeled members are
/// ignored when deserializing; list fields default to empty when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationServerMetadata {
/// Issuer identifier reported by the server (not compared against the
/// request URL in this module).
pub issuer: String,
/// Endpoint the user's browser is sent to for authorization.
pub authorization_endpoint: String,
/// Endpoint used for code exchange and refresh.
pub token_endpoint: String,
/// RFC 7591 dynamic client registration endpoint, if offered; required when
/// no client is pre-configured.
#[serde(default)]
pub registration_endpoint: Option<String>,
#[serde(default)]
pub response_types_supported: Vec<String>,
#[serde(default)]
pub grant_types_supported: Vec<String>,
#[serde(default)]
pub code_challenge_methods_supported: Vec<String>,
/// Scopes the server supports; requested in full for dynamically registered clients.
#[serde(default)]
pub scopes_supported: Vec<String>,
}
/// JSON body of an RFC 7591 dynamic client registration request, as sent by
/// `register_client`.
#[derive(Debug, Clone, Serialize)]
pub struct ClientRegistrationRequest {
/// Human-readable client name shown by the authorization server.
pub client_name: String,
/// Redirect URIs the client will use (the local callback URL).
pub redirect_uris: Vec<String>,
/// OAuth grant types requested for the client.
pub grant_types: Vec<String>,
/// OAuth response types requested for the client.
pub response_types: Vec<String>,
/// Client authentication method at the token endpoint (`"none"` = public client).
pub token_endpoint_auth_method: String,
}
/// Successful RFC 7591 dynamic client registration response.
#[derive(Debug, Clone, Deserialize)]
pub struct ClientRegistrationResponse {
/// Identifier issued to this client.
pub client_id: String,
/// Secret issued to the client, if any.
#[serde(default)]
pub client_secret: Option<String>,
/// Per RFC 7591 §3.2.1: expiry of `client_secret` in seconds since the Unix
/// epoch, where `0` means the secret does not expire.
#[serde(default)]
pub client_secret_expires_at: Option<u64>,
/// Token for the client-configuration endpoint; not used in this module.
#[serde(default)]
pub registration_access_token: Option<String>,
/// Client-configuration endpoint URL; not used in this module.
#[serde(default)]
pub registration_client_uri: Option<String>,
}
/// Tokens from a code exchange or refresh, or an access token read back from storage.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints the raw tokens — avoid
/// logging this struct.
#[derive(Debug, Clone)]
pub struct AccessToken {
/// The access token string.
pub access_token: String,
/// Token type reported by the server (e.g. `"Bearer"`).
pub token_type: String,
/// Lifetime in seconds, if reported; `None` for tokens read back from storage.
pub expires_in: Option<u64>,
/// Refresh token, if the server issued (or rotated) one.
pub refresh_token: Option<String>,
/// Granted scope string, if reported.
pub scope: Option<String>,
}
/// OAuth client identity used when refreshing tokens.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `client_secret` in clear.
#[derive(Debug, Clone)]
struct ClientCredentials {
/// OAuth client id (from config or dynamic registration).
client_id: String,
/// Client secret, when one is stored and not expired.
client_secret: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
scope: Option<String>,
}
/// PKCE (RFC 7636) verifier/challenge pair for the `S256` method.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
/// Secret `code_verifier`, sent only in the token exchange.
pub verifier: String,
/// `code_challenge` (base64url SHA-256 of the verifier), sent in the authorization URL.
pub challenge: String,
}
impl PkceChallenge {
    /// Creates a fresh `S256` PKCE pair.
    ///
    /// The verifier is 32 bytes from the OS CSPRNG, base64url-encoded without
    /// padding (43 characters). The challenge is the unpadded base64url
    /// encoding of SHA-256 over the verifier's ASCII bytes.
    pub fn generate() -> Self {
        let verifier = {
            let mut entropy = [0u8; 32];
            rand::rngs::OsRng.fill_bytes(&mut entropy);
            URL_SAFE_NO_PAD.encode(entropy)
        };
        let digest = Sha256::digest(verifier.as_bytes());
        let challenge = URL_SAFE_NO_PAD.encode(digest);
        Self { verifier, challenge }
    }
}
/// Builds the RFC 8414 / RFC 9728 well-known metadata URL for `base_url`.
///
/// `/.well-known/{suffix}` is inserted between the origin and the path, and
/// trailing `/` characters are dropped from the path. For example,
/// `https://host/tenant/` with `oauth-authorization-server` becomes
/// `https://host/.well-known/oauth-authorization-server/tenant`. Query and
/// fragment components of `base_url` are not carried over.
///
/// # Errors
/// [`AuthError::DiscoveryFailed`] when `base_url` does not parse as a URL.
pub fn build_well_known_uri(base_url: &str, suffix: &str) -> Result<String, AuthError> {
    let url = match reqwest::Url::parse(base_url) {
        Ok(url) => url,
        Err(e) => return Err(AuthError::DiscoveryFailed(format!("Invalid URL: {}", e))),
    };
    let origin = url.origin().ascii_serialization();
    let trimmed_path = url.path().trim_end_matches('/');
    Ok([origin.as_str(), "/.well-known/", suffix, trimmed_path].concat())
}
/// Returns the RFC 8707 `resource` indicator for an MCP server URL.
///
/// A parseable URL is normalized by the URL parser, loses any fragment, and
/// has trailing `/` characters removed. A string that fails to parse is only
/// trimmed of trailing `/`.
pub fn canonical_resource_uri(server_url: &str) -> String {
    let normalized = reqwest::Url::parse(server_url).map(|mut url| {
        url.set_fragment(None);
        url.to_string()
    });
    let text = match &normalized {
        Ok(serialized) => serialized.as_str(),
        Err(_) => server_url,
    };
    text.trim_end_matches('/').to_owned()
}
/// Returns `true` when `ip` lies in a range that outbound OAuth requests must
/// never reach (SSRF guard).
///
/// IPv4 ranges:
/// - loopback, RFC 1918 private, link-local 169.254.0.0/16, broadcast, unspecified;
/// - carrier-grade NAT 100.64.0.0/10.
///
/// IPv6 ranges:
/// - loopback, unspecified, link-local fe80::/10, site-local fec0::/10;
/// - unique-local fc00::/7, documentation 2001:db8::/32;
/// - IPv4-mapped addresses (`::ffff:a.b.c.d`) whose embedded IPv4 address is
///   itself restricted.
fn is_dangerous_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [a, b, _, _] = v4.octets();
            // Same range as `is_link_local`; kept explicit for readability.
            let metadata_link_local = a == 169 && b == 254;
            // 100.64.0.0/10: top two bits of the second octet are `01`.
            let shared_cgnat = a == 100 && (b & 0xC0) == 64;
            [
                v4.is_loopback(),
                v4.is_private(),
                v4.is_link_local(),
                v4.is_broadcast(),
                v4.is_unspecified(),
                metadata_link_local,
                shared_cgnat,
            ]
            .contains(&true)
        }
        IpAddr::V6(v6) => {
            let segments = v6.segments();
            let head = segments[0];
            let restricted_prefix = (head & 0xffc0) == 0xfe80
                || (head & 0xffc0) == 0xfec0
                || (head & 0xfe00) == 0xfc00
                || (head == 0x2001 && segments[1] == 0x0db8);
            v6.is_loopback()
                || v6.is_unspecified()
                || restricted_prefix
                || matches!(
                    v6.to_ipv4_mapped(),
                    Some(mapped) if is_dangerous_ip(IpAddr::V4(mapped))
                )
        }
    }
}
/// Rejects URLs that outbound OAuth requests must not be sent to.
///
/// Rules:
/// * only `http` and `https` are accepted;
/// * `http` is allowed only for localhost URLs (per `is_localhost_url`), and
///   such URLs skip every further check;
/// * for `https`, an IP-literal host (IPv4, or IPv6 in `[...]` brackets) must
///   not be restricted per [`is_dangerous_ip`]; a DNS host name must not
///   resolve to any restricted address.
///
/// NOTE(review): the host is resolved here and again by the HTTP client when
/// the request is sent, so a DNS-rebinding server can answer differently the
/// second time. Closing that gap requires pinning the vetted addresses into
/// the client (e.g. a custom resolver).
///
/// # Errors
/// [`AuthError::DiscoveryFailed`] naming the first rule that failed, including
/// DNS resolution failures.
async fn validate_url_safe(url: &str) -> Result<(), AuthError> {
    let parsed = reqwest::Url::parse(url)
        .map_err(|e| AuthError::DiscoveryFailed(format!("Invalid URL: {}", e)))?;
    let scheme = parsed.scheme();
    if scheme != "https" && scheme != "http" {
        return Err(AuthError::DiscoveryFailed(format!(
            "Unsupported scheme: {}",
            scheme
        )));
    }
    if scheme == "http" {
        if !crate::tools::mcp::config::is_localhost_url(url) {
            let host = parsed.host_str().unwrap_or("");
            return Err(AuthError::DiscoveryFailed(format!(
                "HTTP is only allowed for localhost; use HTTPS for '{}'",
                host
            )));
        }
        return Ok(());
    }
    let host = parsed
        .host_str()
        .ok_or_else(|| AuthError::DiscoveryFailed("URL has no host".to_string()))?;

    // `host_str()` keeps the brackets around IPv6 literals (`[::1]`); strip
    // them so the literal is checked as an IP address rather than being
    // pushed through DNS resolution.
    let literal = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    if let Ok(ip) = literal.parse::<IpAddr>() {
        if is_dangerous_ip(ip) {
            return Err(AuthError::DiscoveryFailed(format!(
                "URL points to a restricted IP address: {}",
                host
            )));
        }
        return Ok(());
    }

    // Domain name: every resolved address must be allowed.
    let addr = format!("{}:{}", host, parsed.port_or_known_default().unwrap_or(443));
    let addrs = tokio::net::lookup_host(&addr).await.map_err(|e| {
        AuthError::DiscoveryFailed(format!("DNS resolution failed for '{}': {}", host, e))
    })?;
    for socket_addr in addrs {
        if is_dangerous_ip(socket_addr.ip()) {
            return Err(AuthError::DiscoveryFailed(format!(
                "URL hostname '{}' resolves to restricted IP address: {}",
                host,
                socket_addr.ip()
            )));
        }
    }
    Ok(())
}
/// Extracts the `resource_metadata` parameter from a `WWW-Authenticate` value.
///
/// The header is scanned as a sequence of RFC 9110 auth-params:
/// * the parameter name is matched case-insensitively, and only at a token
///   boundary (start, whitespace, or comma);
/// * optional whitespace around `=` is accepted;
/// * the value may be a quoted-string (backslash escapes honored, commas
///   allowed) or a bare token ending at `,` or whitespace;
/// * quoted values of other parameters are skipped whole, so text inside
///   them can never be mistaken for the parameter.
///
/// Returns `None` when the parameter is absent, its value is empty, or a
/// quoted-string is unterminated.
fn parse_resource_metadata_url(www_authenticate: &str) -> Option<String> {
    const KEY: &[u8] = b"resource_metadata";
    let bytes = www_authenticate.as_bytes();
    let is_blank = |b: u8| b == b' ' || b == b'\t';
    let mut i = 0;
    let mut at_token_start = true;
    while i < bytes.len() {
        match bytes[i] {
            b'"' => {
                // Skip another parameter's quoted value in one step.
                let (_, end) = read_quoted_string(bytes, i)?;
                i = end;
                at_token_start = false;
            }
            b',' | b' ' | b'\t' => {
                i += 1;
                at_token_start = true;
            }
            _ if at_token_start
                && bytes
                    .get(i..i + KEY.len())
                    .is_some_and(|name| name.eq_ignore_ascii_case(KEY)) =>
            {
                let mut j = i + KEY.len();
                while bytes.get(j).copied().is_some_and(is_blank) {
                    j += 1;
                }
                if bytes.get(j) == Some(&b'=') {
                    j += 1;
                    while bytes.get(j).copied().is_some_and(is_blank) {
                        j += 1;
                    }
                    let raw = if bytes.get(j) == Some(&b'"') {
                        read_quoted_string(bytes, j)?.0
                    } else {
                        let end = bytes[j..]
                            .iter()
                            .position(|&b| b == b',' || is_blank(b))
                            .map_or(bytes.len(), |offset| j + offset);
                        bytes[j..end].to_vec()
                    };
                    let value = String::from_utf8(raw).ok()?;
                    return (!value.is_empty()).then_some(value);
                }
                // A longer name such as `resource_metadata_v2`: keep scanning.
                i += 1;
                at_token_start = false;
            }
            _ => {
                i += 1;
                at_token_start = false;
            }
        }
    }
    None
}

/// Reads an HTTP quoted-string whose opening `"` is at `bytes[start]`.
///
/// Returns the unescaped contents and the index just past the closing quote,
/// or `None` if the string is unterminated.
fn read_quoted_string(bytes: &[u8], start: usize) -> Option<(Vec<u8>, usize)> {
    let mut out = Vec::new();
    let mut i = start + 1;
    while i < bytes.len() {
        match bytes[i] {
            b'\\' if i + 1 < bytes.len() => {
                out.push(bytes[i + 1]);
                i += 2;
            }
            b'"' => return Some((out, i + 1)),
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    None
}
/// Downloads and parses a protected-resource metadata document (RFC 9728)
/// from an explicit `url`, such as the `resource_metadata` link of a
/// `WWW-Authenticate` challenge.
///
/// The URL is vetted with [`validate_url_safe`] first. The request uses a
/// 10-second timeout and does not follow redirects.
///
/// # Errors
/// - [`AuthError::DiscoveryFailed`] for an unsafe URL, a transport failure, a
///   non-2xx status, or an undecodable JSON body.
/// - [`AuthError::Http`] if the shared client is unavailable.
async fn fetch_resource_metadata(url: &str) -> Result<ProtectedResourceMetadata, AuthError> {
    validate_url_safe(url).await?;
    let request = oauth_http_client()?
        .get(url)
        .timeout(Duration::from_secs(10));
    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => return Err(AuthError::DiscoveryFailed(e.to_string())),
    };
    log_redirect_if_applicable(url, &response);
    let status = response.status();
    if !status.is_success() {
        return Err(AuthError::DiscoveryFailed(format!("HTTP {}", status)));
    }
    response
        .json::<ProtectedResourceMetadata>()
        .await
        .map_err(|e| AuthError::DiscoveryFailed(format!("Invalid metadata: {}", e)))
}
/// Discovers authorization-server metadata by provoking an auth challenge.
///
/// Flow:
/// 1. Sends an unauthenticated JSON `POST {}` to `server_url` and expects a
///    401 or 400 reply.
/// 2. Reads the `resource_metadata` URL from that reply's `WWW-Authenticate`
///    header.
/// 3. Fetches the protected-resource document at that URL and delegates to
///    [`try_discover_from_auth_servers`].
///
/// # Errors
/// [`AuthError::DiscoveryFailed`] for any of:
/// - an unsafe URL or a transport failure;
/// - a status other than 401/400;
/// - a missing header or parameter;
/// - a failure in the follow-up discovery steps.
async fn discover_via_401(server_url: &str) -> Result<AuthorizationServerMetadata, AuthError> {
    validate_url_safe(server_url).await?;
    let client = oauth_http_client()?;
    let response = client
        .post(server_url)
        .timeout(Duration::from_secs(10))
        .header("Content-Type", "application/json")
        .body("{}")
        .send()
        .await
        .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?;
    log_redirect_if_applicable(server_url, &response);

    let status = response.status().as_u16();
    if !matches!(status, 400 | 401) {
        return Err(AuthError::DiscoveryFailed(format!(
            "Expected 401 or 400, got {}",
            response.status()
        )));
    }
    let Some(www_auth) = response
        .headers()
        .get("WWW-Authenticate")
        .and_then(|v| v.to_str().ok())
    else {
        return Err(AuthError::DiscoveryFailed(format!(
            "No WWW-Authenticate header in {} response",
            status
        )));
    };
    let Some(resource_metadata_url) = parse_resource_metadata_url(www_auth) else {
        return Err(AuthError::DiscoveryFailed(
            "No resource_metadata URL in WWW-Authenticate header".to_string(),
        ));
    };
    let resource_meta = fetch_resource_metadata(&resource_metadata_url).await?;
    try_discover_from_auth_servers(&resource_meta).await
}
async fn try_discover_from_auth_servers(
resource_meta: &ProtectedResourceMetadata,
) -> Result<AuthorizationServerMetadata, AuthError> {
let auth_server_url = resource_meta
.authorization_servers
.first()
.ok_or_else(|| AuthError::DiscoveryFailed("No authorization servers listed".to_string()))?;
discover_authorization_server(auth_server_url).await
}
/// Fetches the RFC 9728 protected-resource metadata of an MCP server.
///
/// After vetting `server_url` with [`validate_url_safe`], the function
/// requests `/.well-known/oauth-protected-resource` (built with
/// [`build_well_known_uri`]) using a 10-second timeout and no redirects.
///
/// # Errors
/// * [`AuthError::NotSupported`] when the well-known endpoint answers non-2xx;
/// * [`AuthError::DiscoveryFailed`] for an invalid/unsafe URL, transport
///   failure, or undecodable JSON;
/// * [`AuthError::Http`] if the shared client is unavailable.
pub async fn discover_protected_resource(
    server_url: &str,
) -> Result<ProtectedResourceMetadata, AuthError> {
    validate_url_safe(server_url).await?;
    let http = oauth_http_client()?;
    let metadata_url = build_well_known_uri(server_url, "oauth-protected-resource")?;
    let response = http
        .get(metadata_url.as_str())
        .timeout(Duration::from_secs(10))
        .send()
        .await
        .map_err(|err| AuthError::DiscoveryFailed(err.to_string()))?;
    log_redirect_if_applicable(&metadata_url, &response);
    if response.status().is_success() {
        response
            .json()
            .await
            .map_err(|err| AuthError::DiscoveryFailed(format!("Invalid metadata: {}", err)))
    } else {
        Err(AuthError::NotSupported)
    }
}
/// Fetches RFC 8414 authorization-server metadata for `auth_server_url`.
///
/// The URL is vetted with [`validate_url_safe`]. The document is then read
/// from `/.well-known/oauth-authorization-server`, built with
/// [`build_well_known_uri`], using a 10-second timeout and no redirects.
///
/// # Errors
/// - [`AuthError::DiscoveryFailed`] for an invalid or unsafe URL, a transport
///   failure, a non-2xx status, or undecodable JSON.
/// - [`AuthError::Http`] if the shared client is unavailable.
pub async fn discover_authorization_server(
    auth_server_url: &str,
) -> Result<AuthorizationServerMetadata, AuthError> {
    validate_url_safe(auth_server_url).await?;
    let http = oauth_http_client()?;
    let metadata_url = build_well_known_uri(auth_server_url, "oauth-authorization-server")?;
    let sent = http
        .get(metadata_url.as_str())
        .timeout(Duration::from_secs(10))
        .send()
        .await;
    let response = sent.map_err(|err| AuthError::DiscoveryFailed(err.to_string()))?;
    log_redirect_if_applicable(&metadata_url, &response);
    let status = response.status();
    if !status.is_success() {
        return Err(AuthError::DiscoveryFailed(format!("HTTP {}", status)));
    }
    let parsed: Result<AuthorizationServerMetadata, _> = response.json().await;
    parsed.map_err(|err| AuthError::DiscoveryFailed(format!("Invalid metadata: {}", err)))
}
pub async fn discover_oauth_endpoints(
server_config: &McpServerConfig,
) -> Result<(String, String), AuthError> {
let oauth = server_config
.oauth
.as_ref()
.ok_or(AuthError::NotSupported)?;
if let (Some(auth_url), Some(token_url)) = (&oauth.authorization_url, &oauth.token_url) {
return Ok((auth_url.clone(), token_url.clone()));
}
let resource_meta = discover_protected_resource(&server_config.url).await?;
let auth_server_url = resource_meta
.authorization_servers
.first()
.ok_or_else(|| AuthError::DiscoveryFailed("No authorization servers listed".to_string()))?;
let auth_meta = discover_authorization_server(auth_server_url).await?;
Ok((auth_meta.authorization_endpoint, auth_meta.token_endpoint))
}
pub async fn discover_full_oauth_metadata(
server_url: &str,
) -> Result<AuthorizationServerMetadata, AuthError> {
if let Ok(meta) = discover_via_401(server_url).await {
return Ok(meta);
}
if let Ok(resource_meta) = discover_protected_resource(server_url).await
&& let Ok(meta) = try_discover_from_auth_servers(&resource_meta).await
{
return Ok(meta);
}
discover_authorization_server(server_url).await
}
pub async fn register_client(
registration_endpoint: &str,
redirect_uri: &str,
) -> Result<ClientRegistrationResponse, AuthError> {
validate_url_safe(registration_endpoint).await?;
let client = oauth_http_client()?;
let request = ClientRegistrationRequest {
client_name: "IronClaw".to_string(),
redirect_uris: vec![redirect_uri.to_string()],
grant_types: vec![
"authorization_code".to_string(),
"refresh_token".to_string(),
],
response_types: vec!["code".to_string()],
token_endpoint_auth_method: "none".to_string(), };
let response = client
.post(registration_endpoint)
.json(&request)
.send()
.await
.map_err(|e| AuthError::DiscoveryFailed(format!("DCR request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AuthError::DiscoveryFailed(format!(
"DCR failed: HTTP {} - {}",
status, body
)));
}
response
.json()
.await
.map_err(|e| AuthError::DiscoveryFailed(format!("Invalid DCR response: {}", e)))
}
/// Runs the interactive OAuth authorization-code flow for an MCP server and
/// stores the resulting tokens for `user_id`.
///
/// Flow:
/// 1. Binds the local callback listener and builds the
///    `http://{callback_host}:{port}/callback` redirect URI. A warning is
///    printed when the callback host is not a loopback address.
/// 2. With a configured `oauth` section, uses its client id, PKCE flag,
///    scopes and extra parameters, and resolves endpoints via
///    [`discover_oauth_endpoints`]. Otherwise it discovers metadata with
///    [`discover_full_oauth_metadata`] and registers a client dynamically; in
///    that case PKCE is always used and every advertised scope is requested.
/// 3. Adds a random `state` and the RFC 8707 `resource` indicator, vets the
///    authorization endpoint, and opens the authorization URL in a browser,
///    printing it when the browser cannot be opened.
/// 4. Waits for the callback, exchanges the code, and stores the tokens. For
///    dynamically registered clients it also stores the client id and any
///    client secret.
///
/// NOTE(review): the generated `state` is sent in the URL but is not handed to
/// `wait_for_authorization_callback`. Confirm the callback's `state` is
/// verified somewhere.
///
/// # Errors
/// - [`AuthError::NotSupported`] when dynamic registration is needed but the
///   server offers no registration endpoint.
/// - Otherwise propagates errors from port binding, discovery, registration,
///   URL validation, the callback wait, token exchange, and secret storage.
pub async fn authorize_mcp_server(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<AccessToken, AuthError> {
    let (listener, port) = find_available_port().await?;
    let host = oauth_defaults::callback_host();
    let redirect_uri = format!("http://{}:{}/callback", host, port);
    if !oauth_defaults::is_loopback_host(&host) {
        println!("Warning: MCP OAuth callback is using plain HTTP to a remote host ({host}).");
        println!(" Authorization codes will be transmitted unencrypted.");
        println!(" Consider SSH port forwarding instead:");
        println!(" ssh -L {port}:127.0.0.1:{port} user@{host}");
    }
    let (
        client_id,
        client_secret,
        client_secret_expires_at,
        authorization_url,
        token_url,
        use_pkce,
        scopes,
        mut extra_params,
    ) = if let Some(oauth) = &server_config.oauth {
        let (auth_url, tok_url) = discover_oauth_endpoints(server_config).await?;
        (
            oauth.client_id.clone(),
            None,
            None,
            auth_url,
            tok_url,
            oauth.use_pkce,
            oauth.scopes.clone(),
            oauth.extra_params.clone(),
        )
    } else {
        println!(" Discovering OAuth endpoints...");
        let auth_meta = discover_full_oauth_metadata(&server_config.url).await?;
        let registration_endpoint = auth_meta
            .registration_endpoint
            .ok_or(AuthError::NotSupported)?;
        println!(" Registering client dynamically...");
        let registration = register_client(&registration_endpoint, &redirect_uri).await?;
        println!(" Client registered: {}", registration.client_id);
        (
            registration.client_id,
            registration.client_secret,
            registration.client_secret_expires_at,
            auth_meta.authorization_endpoint,
            auth_meta.token_endpoint,
            // Dynamically registered clients always use PKCE.
            true,
            auth_meta.scopes_supported,
            HashMap::new(),
        )
    };
    let pkce = if use_pkce {
        Some(PkceChallenge::generate())
    } else {
        None
    };
    // Random CSRF `state` (overrides any configured `state` extra param).
    let mut state_bytes = [0u8; 16];
    rand::rngs::OsRng.fill_bytes(&mut state_bytes);
    let state = URL_SAFE_NO_PAD.encode(state_bytes);
    extra_params.insert("state".to_string(), state);
    let resource = canonical_resource_uri(&server_config.url);
    validate_url_safe(&authorization_url)
        .await
        .map_err(|e| AuthError::DiscoveryFailed(format!("Unsafe authorization endpoint: {}", e)))?;
    let auth_url = build_authorization_url(
        &authorization_url,
        &client_id,
        &redirect_uri,
        &scopes,
        pkce.as_ref(),
        &extra_params,
        Some(&resource),
    );
    println!(" Opening browser for {} login...", server_config.name);
    if let Err(e) = open::that(&auth_url) {
        println!(" Could not open browser: {}", e);
        println!(" Please open this URL manually:");
        println!(" {}", auth_url);
    }
    println!(" Waiting for authorization...");
    let code = wait_for_authorization_callback(listener, &server_config.name).await?;
    println!(" Exchanging code for token...");
    let token = exchange_code_for_token(
        &token_url,
        &client_id,
        client_secret.as_deref(),
        &code,
        &redirect_uri,
        pkce.as_ref(),
        Some(&resource),
    )
    .await?;
    store_tokens(secrets, user_id, server_config, &token).await?;
    // Pre-configured clients keep their identity in config; only dynamically
    // registered credentials need persisting for later refreshes.
    if server_config.oauth.is_none() {
        store_client_id(secrets, user_id, server_config, &client_id).await?;
        if let Some(ref client_secret) = client_secret {
            store_client_secret(
                secrets,
                user_id,
                server_config,
                client_secret,
                client_secret_expires_at,
            )
            .await?;
        }
    }
    Ok(token)
}
/// Binds the local OAuth callback listener and reports the port it listens on.
///
/// The port is read back from the bound socket, so the `redirect_uri` built
/// from it always targets the listener that will receive the callback.
/// [`OAUTH_CALLBACK_PORT`] is used only if the socket cannot report its
/// local address.
///
/// # Errors
/// [`AuthError::PortUnavailable`] when the listener cannot be bound.
pub async fn find_available_port() -> Result<(TcpListener, u16), AuthError> {
    let listener = oauth_defaults::bind_callback_listener()
        .await
        .map_err(|_| AuthError::PortUnavailable)?;
    let port = listener
        .local_addr()
        .map(|addr| addr.port())
        .unwrap_or(OAUTH_CALLBACK_PORT);
    Ok((listener, port))
}
/// Builds the OAuth authorization-request URL the user's browser is sent to.
///
/// The URL carries these parameters, all percent-encoded:
/// - `client_id`, `response_type=code` and `redirect_uri`;
/// - `scope` (space-joined), only when `scopes` is non-empty;
/// - the PKCE `S256` challenge, when `pkce` is given;
/// - every entry of `extra_params`, sorted by key so the output is
///   deterministic;
/// - the RFC 8707 `resource` indicator, when given.
///
/// If `base_url` already contains a query string, the parameters are appended
/// to it rather than starting a second `?`.
pub fn build_authorization_url(
    base_url: &str,
    client_id: &str,
    redirect_uri: &str,
    scopes: &[String],
    pkce: Option<&PkceChallenge>,
    extra_params: &HashMap<String, String>,
    resource: Option<&str>,
) -> String {
    use std::fmt::Write as _;

    let separator = if !base_url.contains('?') {
        "?"
    } else if base_url.ends_with('?') || base_url.ends_with('&') {
        ""
    } else {
        "&"
    };
    let mut url = format!(
        "{}{}client_id={}&response_type=code&redirect_uri={}",
        base_url,
        separator,
        urlencoding::encode(client_id),
        urlencoding::encode(redirect_uri)
    );
    // Writing into a `String` cannot fail, so the `fmt::Result`s are ignored.
    if !scopes.is_empty() {
        let _ = write!(url, "&scope={}", urlencoding::encode(&scopes.join(" ")));
    }
    if let Some(pkce) = pkce {
        let _ = write!(
            url,
            "&code_challenge={}&code_challenge_method=S256",
            urlencoding::encode(&pkce.challenge)
        );
    }
    let mut extras: Vec<(&String, &String)> = extra_params.iter().collect();
    extras.sort_unstable();
    for (key, value) in extras {
        let _ = write!(
            url,
            "&{}={}",
            urlencoding::encode(key),
            urlencoding::encode(value)
        );
    }
    if let Some(resource) = resource {
        let _ = write!(url, "&resource={}", urlencoding::encode(resource));
    }
    url
}
pub async fn wait_for_authorization_callback(
listener: TcpListener,
server_name: &str,
) -> Result<String, AuthError> {
oauth_defaults::wait_for_callback(listener, "/callback", "code", server_name, None)
.await
.map_err(|e| match e {
oauth_defaults::OAuthCallbackError::Denied => AuthError::AuthorizationDenied,
oauth_defaults::OAuthCallbackError::Timeout => AuthError::Timeout,
oauth_defaults::OAuthCallbackError::PortInUse(_, msg) => {
AuthError::Http(format!("Port error: {}", msg))
}
oauth_defaults::OAuthCallbackError::StateMismatch { .. } => {
AuthError::Http("CSRF state mismatch in OAuth callback".to_string())
}
oauth_defaults::OAuthCallbackError::Io(msg) => AuthError::Http(msg),
})
}
pub async fn exchange_code_for_token(
token_url: &str,
client_id: &str,
client_secret: Option<&str>,
code: &str,
redirect_uri: &str,
pkce: Option<&PkceChallenge>,
resource: Option<&str>,
) -> Result<AccessToken, AuthError> {
validate_url_safe(token_url).await?;
let client = oauth_http_client()?;
let mut params = vec![
("grant_type", "authorization_code".to_string()),
("code", code.to_string()),
("redirect_uri", redirect_uri.to_string()),
("client_id", client_id.to_string()),
];
if let Some(secret) = client_secret {
params.push(("client_secret", secret.to_string()));
}
if let Some(pkce) = pkce {
params.push(("code_verifier", pkce.verifier.clone()));
}
if let Some(resource) = resource {
params.push(("resource", resource.to_string()));
}
let response = client
.post(token_url)
.form(¶ms)
.send()
.await
.map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AuthError::TokenExchangeFailed(format!(
"HTTP {} - {}",
status, body
)));
}
let token_response: TokenResponse = response
.json()
.await
.map_err(|e| AuthError::TokenExchangeFailed(format!("Invalid response: {}", e)))?;
Ok(AccessToken {
access_token: token_response.access_token,
token_type: token_response.token_type,
expires_in: token_response.expires_in,
refresh_token: token_response.refresh_token,
scope: token_response.scope,
})
}
/// Persists `token` for `user_id` in the secrets store.
///
/// - The access token is written under `token_secret_name()`. When the server
///   reported `expires_in`, it gets an expiry of now + `expires_in`.
/// - A refresh token, if present, is written under
///   `refresh_token_secret_name()` with no expiry. When the response carries
///   no refresh token, any previously stored one is not touched.
/// - Both secrets are tagged with the `mcp:{server name}` provider.
///
/// `expires_in` comes straight from the token endpoint. Oversized values are
/// therefore clamped to what `chrono::Duration::seconds` accepts, and an
/// expiry that would overflow the calendar is left unset instead of panicking.
///
/// # Errors
/// [`AuthError::Secrets`] if either write fails.
pub async fn store_tokens(
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
    server_config: &McpServerConfig,
    token: &AccessToken,
) -> Result<(), AuthError> {
    let mut params =
        CreateSecretParams::new(server_config.token_secret_name(), &token.access_token)
            .with_provider(format!("mcp:{}", server_config.name));
    if let Some(secs) = token.expires_in {
        // `chrono::Duration::seconds` panics above `i64::MAX / 1_000` seconds.
        let secs = i64::try_from(secs)
            .unwrap_or(i64::MAX)
            .min(i64::MAX / 1_000);
        if let Some(expires_at) =
            chrono::Utc::now().checked_add_signed(chrono::Duration::seconds(secs))
        {
            params = params.with_expiry(expires_at);
        }
    }
    secrets
        .create(user_id, params)
        .await
        .map_err(|e| AuthError::Secrets(e.to_string()))?;
    if let Some(ref refresh_token) = token.refresh_token {
        let params =
            CreateSecretParams::new(server_config.refresh_token_secret_name(), refresh_token)
                .with_provider(format!("mcp:{}", server_config.name));
        secrets
            .create(user_id, params)
            .await
            .map_err(|e| AuthError::Secrets(e.to_string()))?;
    }
    Ok(())
}
/// Persists a dynamically registered `client_id` for `user_id`, tagged with
/// the `mcp:{server name}` provider, so later refreshes can reuse it.
///
/// # Errors
/// [`AuthError::Secrets`] if the secrets store rejects the write.
pub async fn store_client_id(
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
    server_config: &McpServerConfig,
    client_id: &str,
) -> Result<(), AuthError> {
    let secret_name = server_config.client_id_secret_name();
    let params = CreateSecretParams::new(secret_name, client_id)
        .with_provider(format!("mcp:{}", server_config.name));
    match secrets.create(user_id, params).await {
        Ok(_) => Ok(()),
        Err(e) => Err(AuthError::Secrets(e.to_string())),
    }
}
/// Persists a dynamically registered client secret for `user_id`, tagged with
/// the `mcp:{server name}` provider.
///
/// Per RFC 7591 §3.2.1, `client_secret_expires_at` is in seconds since the
/// Unix epoch, and `0` means the secret never expires. The secret is stored
/// without an expiry in these cases:
/// - the value is `None`;
/// - the value is `0`;
/// - the value does not fit a valid timestamp.
///
/// Storing `0` as 1970-01-01 would make the secret read back as expired
/// immediately.
///
/// # Errors
/// [`AuthError::Secrets`] if the secrets store rejects the write.
pub async fn store_client_secret(
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
    server_config: &McpServerConfig,
    client_secret: &str,
    client_secret_expires_at: Option<u64>,
) -> Result<(), AuthError> {
    let mut params =
        CreateSecretParams::new(server_config.client_secret_secret_name(), client_secret)
            .with_provider(format!("mcp:{}", server_config.name));
    if let Some(expires_at) = client_secret_expires_at
        && expires_at != 0
        && let Ok(secs) = i64::try_from(expires_at)
        && let Some(dt) = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, 0)
    {
        params = params.with_expiry(dt);
    }
    secrets
        .create(user_id, params)
        .await
        .map(|_| ())
        .map_err(|e| AuthError::Secrets(e.to_string()))
}
/// Returns the OAuth client id to use for `server_config`.
///
/// A configured `oauth` section supplies the id directly. Otherwise the id
/// saved by dynamic registration is read from the secrets store.
///
/// # Errors
/// - [`AuthError::RefreshFailed`] when no stored client id exists.
/// - [`AuthError::Secrets`] for other secrets-store failures.
async fn get_client_id(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<String, AuthError> {
    if let Some(oauth) = server_config.oauth.as_ref() {
        return Ok(oauth.client_id.clone());
    }
    let lookup = secrets
        .get_decrypted(user_id, &server_config.client_id_secret_name())
        .await;
    lookup
        .map(|stored| stored.expose().to_string())
        .map_err(|err| match err {
            crate::secrets::SecretError::NotFound(_) => AuthError::RefreshFailed(
                "No client ID found. Please re-authenticate.".to_string(),
            ),
            other => AuthError::Secrets(other.to_string()),
        })
}
async fn get_client_credentials(
server_config: &McpServerConfig,
secrets: &Arc<dyn SecretsStore + Send + Sync>,
user_id: &str,
) -> Result<ClientCredentials, AuthError> {
let client_id = get_client_id(server_config, secrets, user_id).await?;
let client_secret = match secrets
.get_decrypted(user_id, &server_config.client_secret_secret_name())
.await
{
Ok(secret) => Some(secret.expose().to_string()),
Err(crate::secrets::SecretError::NotFound(_) | crate::secrets::SecretError::Expired) => {
None
}
Err(e) => return Err(AuthError::Secrets(e.to_string())),
};
Ok(ClientCredentials {
client_id,
client_secret,
})
}
/// Returns the stored access token for `user_id`, or `None` if none is stored.
///
/// NOTE(review): any other secrets-store error is returned as
/// [`AuthError::Secrets`]. That presumably includes an expired token, which
/// `refresh_access_token` treats as "needs refresh". Confirm callers expect
/// this.
///
/// # Errors
/// [`AuthError::Secrets`] for any lookup failure other than "not found".
pub async fn get_access_token(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<Option<String>, AuthError> {
    let lookup = secrets
        .get_decrypted(user_id, &server_config.token_secret_name())
        .await;
    match lookup {
        Err(crate::secrets::SecretError::NotFound(_)) => Ok(None),
        Err(other) => Err(AuthError::Secrets(other.to_string())),
        Ok(token) => Ok(Some(token.expose().to_string())),
    }
}
/// Reports whether an access-token secret exists for `user_id`.
///
/// Only existence is checked; the token is not decrypted or checked for
/// expiry. A secrets-store failure is reported as `false`.
pub async fn is_authenticated(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> bool {
    let token_name = server_config.token_secret_name();
    secrets
        .exists(user_id, &token_name)
        .await
        .unwrap_or_default()
}
/// Returns a usable access token for `user_id`, refreshing it with the OAuth
/// `refresh_token` grant when the stored token is missing or expired.
///
/// Everything runs under the per-(server, user) mutex from `refresh_lock`, so
/// concurrent callers never redeem the same refresh token twice.
///
/// Steps:
/// 1. If the stored access token still decrypts, it is returned as-is, with
///    `token_type` `"Bearer"` and no other fields, and no request is made.
///    NOTE(review): a token the server has revoked but the store still holds
///    is handed back unchanged. Confirm callers handle that case.
/// 2. Client credentials are loaded, and the refresh token is read, falling
///    back to the legacy secret name.
/// 3. The token endpoint comes from the `oauth` config when set; otherwise it
///    is found with [`discover_full_oauth_metadata`]. Either way it is vetted
///    with [`validate_url_safe`].
/// 4. The refresh goes through the exchange proxy when one is configured;
///    otherwise it is sent directly to the token endpoint. Both paths include
///    the RFC 8707 `resource` indicator.
/// 5. The new tokens are persisted with [`store_tokens`].
///
/// # Errors
/// - [`AuthError::RefreshFailed`] when the client id or refresh token is
///   missing, or the refresh request fails.
/// - [`AuthError::Secrets`] for secrets-store failures.
/// - Discovery and URL-validation errors, propagated unchanged.
pub async fn refresh_access_token(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<AccessToken, AuthError> {
    let lock = refresh_lock(&server_config.name, user_id).await;
    let _guard = lock.lock().await;
    // Re-check under the lock: another task may already have refreshed.
    match secrets
        .get_decrypted(user_id, &server_config.token_secret_name())
        .await
    {
        Ok(token) => {
            return Ok(AccessToken {
                access_token: token.expose().to_string(),
                token_type: "Bearer".to_string(),
                expires_in: None,
                refresh_token: None,
                scope: None,
            });
        }
        Err(crate::secrets::SecretError::Expired | crate::secrets::SecretError::NotFound(_)) => {}
        Err(e) => return Err(AuthError::Secrets(e.to_string())),
    }
    let credentials = get_client_credentials(server_config, secrets, user_id).await?;
    let refresh_token = match secrets
        .get_decrypted(user_id, &server_config.refresh_token_secret_name())
        .await
    {
        Ok(token) => token,
        Err(crate::secrets::SecretError::NotFound(_) | crate::secrets::SecretError::Expired) => {
            secrets
                .get_decrypted(user_id, &server_config.legacy_refresh_token_secret_name())
                .await
                .map_err(|e| AuthError::RefreshFailed(format!("No refresh token: {}", e)))?
        }
        Err(e) => {
            return Err(AuthError::RefreshFailed(format!(
                "Failed to read refresh token: {e}"
            )));
        }
    };
    let token_url = if let Some(ref oauth) = server_config.oauth {
        if let Some(ref url) = oauth.token_url {
            url.clone()
        } else {
            let auth_meta = discover_full_oauth_metadata(&server_config.url).await?;
            auth_meta.token_endpoint
        }
    } else {
        let auth_meta = discover_full_oauth_metadata(&server_config.url).await?;
        auth_meta.token_endpoint
    };
    validate_url_safe(&token_url).await?;
    let token = if let Some(proxy_url) = oauth_defaults::exchange_proxy_url() {
        let resource = canonical_resource_uri(&server_config.url);
        let provider = format!("mcp:{}", server_config.name);
        let gateway_token = oauth_defaults::oauth_proxy_auth_token().ok_or_else(|| {
            AuthError::RefreshFailed(
                "OAuth refresh proxy is configured but no proxy auth token is available"
                    .to_string(),
            )
        })?;
        let token_response =
            oauth_defaults::refresh_token_via_proxy(oauth_defaults::ProxyRefreshTokenRequest {
                proxy_url: &proxy_url,
                gateway_token: &gateway_token,
                token_url: &token_url,
                client_id: &credentials.client_id,
                client_secret: credentials.client_secret.as_deref(),
                refresh_token: refresh_token.expose(),
                resource: Some(&resource),
                provider: Some(provider.as_str()),
            })
            .await
            .map_err(|e| AuthError::RefreshFailed(e.to_string()))?;
        AccessToken {
            access_token: token_response.access_token,
            token_type: token_response
                .token_type
                .unwrap_or_else(|| "Bearer".to_string()),
            expires_in: token_response.expires_in,
            refresh_token: token_response.refresh_token,
            scope: token_response.scope,
        }
    } else {
        let client = oauth_http_client()?;
        let resource = canonical_resource_uri(&server_config.url);
        let mut params = vec![
            ("grant_type", "refresh_token".to_string()),
            ("refresh_token", refresh_token.expose().to_string()),
            ("client_id", credentials.client_id.clone()),
            ("resource", resource),
        ];
        if let Some(client_secret) = credentials.client_secret.as_deref() {
            params.push(("client_secret", client_secret.to_string()));
        }
        let response = client
            .post(&token_url)
            .form(&params)
            .send()
            .await
            .map_err(|e| AuthError::RefreshFailed(e.to_string()))?;
        log_redirect_if_applicable(&token_url, &response);
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(AuthError::RefreshFailed(format!(
                "HTTP {} - {}",
                status, body
            )));
        }
        let token_response: TokenResponse = response
            .json()
            .await
            .map_err(|e| AuthError::RefreshFailed(format!("Invalid response: {}", e)))?;
        AccessToken {
            access_token: token_response.access_token,
            token_type: token_response.token_type,
            expires_in: token_response.expires_in,
            refresh_token: token_response.refresh_token,
            scope: token_response.scope,
        }
    };
    store_tokens(secrets, user_id, server_config, &token).await?;
    Ok(token)
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use axum::{
Router,
extract::{Form, State},
routing::post,
};
use secrecy::SecretString;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use crate::config::helpers::lock_env;
use crate::secrets::{InMemorySecretsStore, SecretsCrypto};
use crate::testing::credentials::TEST_GATEWAY_CRYPTO_KEY;
#[derive(Clone, Debug, Default)]
struct RecordedRefreshRequest {
authorization: Option<String>,
form: HashMap<String, String>,
}
#[derive(Clone, Default)]
struct MockRefreshState {
requests: Arc<Mutex<Vec<RecordedRefreshRequest>>>,
}
impl MockRefreshState {
async fn requests(&self) -> Vec<RecordedRefreshRequest> {
self.requests.lock().await.clone()
}
}
fn test_secrets_store() -> Arc<dyn SecretsStore + Send + Sync> {
Arc::new(InMemorySecretsStore::new(Arc::new(
SecretsCrypto::new(SecretString::from(TEST_GATEWAY_CRYPTO_KEY.to_string()))
.expect("test crypto"),
)))
}
async fn start_refresh_server() -> Option<(String, MockRefreshState)> {
async fn token_handler(
State(state): State<MockRefreshState>,
headers: axum::http::HeaderMap,
Form(form): Form<HashMap<String, String>>,
) -> axum::Json<serde_json::Value> {
state.requests.lock().await.push(RecordedRefreshRequest {
authorization: headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.map(str::to_string),
form,
});
axum::Json(serde_json::json!({
"access_token": "refreshed-access-token",
"token_type": "Bearer",
"refresh_token": "rotated-refresh-token",
"expires_in": 3600
}))
}
let state = MockRefreshState::default();
let app = Router::new()
.route("/token", post(token_handler))
.route("/oauth/refresh", post(token_handler))
.with_state(state.clone());
let listener = match TcpListener::bind("127.0.0.1:0").await {
Ok(listener) => listener,
Err(error) if error.kind() == std::io::ErrorKind::PermissionDenied => {
eprintln!("Skipping refresh server test: loopback bind denied by sandbox");
return None;
}
Err(error) => panic!("failed to bind refresh test server: {error}"),
};
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
Some((format!("http://127.0.0.1:{}", addr.port()), state))
}
#[test]
fn test_pkce_challenge_generation() {
let pkce = PkceChallenge::generate();
assert!(!pkce.verifier.is_empty());
assert!(!pkce.verifier.contains('+'));
assert!(!pkce.verifier.contains('/'));
assert!(!pkce.verifier.contains('='));
assert_ne!(pkce.verifier, pkce.challenge);
let pkce2 = PkceChallenge::generate();
assert_ne!(pkce.verifier, pkce2.verifier);
}
#[test]
fn test_build_authorization_url() {
let url = build_authorization_url(
"https://auth.example.com/authorize",
"client-123",
"http://localhost:9876/callback",
&["read".to_string(), "write".to_string()],
None,
&HashMap::new(),
None,
);
assert!(url.starts_with("https://auth.example.com/authorize?"));
assert!(url.contains("client_id=client-123"));
assert!(url.contains("response_type=code"));
assert!(url.contains("redirect_uri="));
assert!(url.contains("scope=read%20write"));
}
#[test]
fn test_build_authorization_url_with_pkce() {
let pkce = PkceChallenge::generate();
let url = build_authorization_url(
"https://auth.example.com/authorize",
"client-123",
"http://localhost:9876/callback",
&[],
Some(&pkce),
&HashMap::new(),
None,
);
assert!(url.contains(&format!("code_challenge={}", pkce.challenge)));
assert!(url.contains("code_challenge_method=S256"));
}
#[test]
fn test_build_authorization_url_with_extra_params() {
let mut extra = HashMap::new();
extra.insert("owner".to_string(), "user".to_string());
extra.insert("state".to_string(), "abc123".to_string());
let url = build_authorization_url(
"https://auth.example.com/authorize",
"client-123",
"http://localhost:9876/callback",
&[],
None,
&extra,
None,
);
assert!(url.contains("owner=user"));
assert!(url.contains("state=abc123"));
}
#[test]
fn test_pkce_challenge_s256_is_correct_sha256() {
let pkce = PkceChallenge::generate();
let mut hasher = Sha256::new();
hasher.update(pkce.verifier.as_bytes());
let expected = URL_SAFE_NO_PAD.encode(hasher.finalize());
assert_eq!(pkce.challenge, expected);
}
/// With no scopes requested, no `scope` query parameter is emitted at all.
#[test]
fn test_build_authorization_url_empty_scopes_no_scope_param() {
    let no_scopes: &[String] = &[];
    let url = build_authorization_url(
        "https://auth.example.com/authorize",
        "client-123",
        "http://localhost:9876/callback",
        no_scopes,
        None,
        &HashMap::new(),
        None,
    );
    let has_scope_param = url.contains("scope=");
    assert!(!has_scope_param);
}
/// Reserved characters in the client id and redirect URI are percent-encoded, so they cannot
/// inject additional query parameters.
#[test]
fn test_build_authorization_url_special_characters_are_encoded() {
    let hostile_client_id = "client id&evil=true";
    let awkward_redirect = "http://localhost:9876/call back?x=1";
    let url = build_authorization_url(
        "https://auth.example.com/authorize",
        hostile_client_id,
        awkward_redirect,
        &[],
        None,
        &HashMap::new(),
        None,
    );
    let encoded_client = "client_id=client%20id%26evil%3Dtrue";
    let encoded_redirect = "redirect_uri=http%3A%2F%2Flocalhost%3A9876%2Fcall%20back%3Fx%3D1";
    assert!(url.contains(encoded_client));
    assert!(url.contains(encoded_redirect));
}
/// A fully populated `ProtectedResourceMetadata` survives a JSON string round trip.
#[test]
fn test_protected_resource_metadata_serde_roundtrip_full() {
    let original = ProtectedResourceMetadata {
        resource: "https://mcp.example.com".to_string(),
        authorization_servers: Vec::from(
            ["https://auth1.example.com", "https://auth2.example.com"].map(String::from),
        ),
        scopes_supported: Vec::from(["read", "write"].map(String::from)),
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: ProtectedResourceMetadata = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.resource, original.resource);
    assert_eq!(decoded.authorization_servers, original.authorization_servers);
    assert_eq!(decoded.scopes_supported, original.scopes_supported);
}
/// Only `resource` is required; the list fields fall back to empty vectors.
#[test]
fn test_protected_resource_metadata_serde_roundtrip_minimal() {
    let parsed: ProtectedResourceMetadata =
        serde_json::from_str(r#"{"resource": "https://mcp.example.com"}"#).unwrap();
    assert_eq!(parsed.resource, "https://mcp.example.com");
    assert!(parsed.authorization_servers.is_empty());
    assert!(parsed.scopes_supported.is_empty());
}
/// Every field of `AuthorizationServerMetadata` survives a JSON string round trip.
#[test]
fn test_authorization_server_metadata_serde_roundtrip_all_fields() {
    let to_strings = |items: &[&str]| items.iter().map(|s| s.to_string()).collect::<Vec<_>>();
    let original = AuthorizationServerMetadata {
        issuer: "https://auth.example.com".to_string(),
        authorization_endpoint: "https://auth.example.com/authorize".to_string(),
        token_endpoint: "https://auth.example.com/token".to_string(),
        registration_endpoint: Some("https://auth.example.com/register".to_string()),
        response_types_supported: to_strings(&["code"]),
        grant_types_supported: to_strings(&["authorization_code", "refresh_token"]),
        code_challenge_methods_supported: to_strings(&["S256"]),
        scopes_supported: to_strings(&["openid", "profile"]),
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: AuthorizationServerMetadata = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.issuer, original.issuer);
    assert_eq!(decoded.authorization_endpoint, original.authorization_endpoint);
    assert_eq!(decoded.token_endpoint, original.token_endpoint);
    assert_eq!(decoded.registration_endpoint, original.registration_endpoint);
    assert_eq!(decoded.response_types_supported, original.response_types_supported);
    assert_eq!(decoded.grant_types_supported, original.grant_types_supported);
    assert_eq!(
        decoded.code_challenge_methods_supported,
        original.code_challenge_methods_supported
    );
    assert_eq!(decoded.scopes_supported, original.scopes_supported);
}
/// Metadata without a registration endpoint parses, leaving optional and list fields empty.
#[test]
fn test_authorization_server_metadata_serde_without_registration() {
    let document = r#"{
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token"
}"#;
    let parsed: AuthorizationServerMetadata = serde_json::from_str(document).unwrap();
    assert_eq!(parsed.issuer, "https://auth.example.com");
    assert!(parsed.registration_endpoint.is_none());
    assert!(parsed.response_types_supported.is_empty());
    assert!(parsed.grant_types_supported.is_empty());
}
/// A dynamic client registration request serializes with the expected JSON field names.
#[test]
fn test_client_registration_request_serialization() {
    let request = ClientRegistrationRequest {
        client_name: String::from("IronClaw"),
        redirect_uris: vec![String::from("http://localhost:9876/callback")],
        grant_types: ["authorization_code", "refresh_token"].map(String::from).into(),
        response_types: vec![String::from("code")],
        token_endpoint_auth_method: String::from("none"),
    };
    let json: serde_json::Value = serde_json::to_value(&request).unwrap();
    let expectations = [
        (&json["client_name"], "IronClaw"),
        (&json["redirect_uris"][0], "http://localhost:9876/callback"),
        (&json["grant_types"][0], "authorization_code"),
        (&json["grant_types"][1], "refresh_token"),
        (&json["response_types"][0], "code"),
        (&json["token_endpoint_auth_method"], "none"),
    ];
    for (actual, expected) in expectations {
        assert_eq!(*actual, expected);
    }
}
/// Every optional field of a registration response is picked up when present.
#[test]
fn test_client_registration_response_deserialization_full() {
    let document = r#"{
"client_id": "abc-123",
"client_secret": "s3cret",
"client_secret_expires_at": 1700000000,
"registration_access_token": "reg-tok",
"registration_client_uri": "https://auth.example.com/register/abc-123"
}"#;
    let parsed: ClientRegistrationResponse = serde_json::from_str(document).unwrap();
    assert_eq!(parsed.client_id, "abc-123");
    assert_eq!(parsed.client_secret.as_deref(), Some("s3cret"));
    assert_eq!(parsed.client_secret_expires_at, Some(1_700_000_000));
    assert_eq!(parsed.registration_access_token.as_deref(), Some("reg-tok"));
    let expected_uri = "https://auth.example.com/register/abc-123";
    assert_eq!(parsed.registration_client_uri.as_deref(), Some(expected_uri));
}
/// A response carrying only `client_id` parses, leaving every optional field unset.
#[test]
fn test_client_registration_response_deserialization_minimal() {
    let parsed: ClientRegistrationResponse =
        serde_json::from_str(r#"{"client_id": "xyz-789"}"#).unwrap();
    assert_eq!(parsed.client_id, "xyz-789");
    assert!(parsed.client_secret.is_none());
    assert!(parsed.client_secret_expires_at.is_none());
    assert!(parsed.registration_access_token.is_none());
    assert!(parsed.registration_client_uri.is_none());
}
/// `AccessToken` stores exactly the values it is built with, with or without optional parts.
#[test]
fn test_access_token_construction() {
    let full = AccessToken {
        access_token: String::from("at-abc"),
        token_type: String::from("Bearer"),
        expires_in: Some(3600),
        refresh_token: Some(String::from("rt-xyz")),
        scope: Some(String::from("read write")),
    };
    assert_eq!(full.access_token, "at-abc");
    assert_eq!(full.token_type, "Bearer");
    assert_eq!(full.expires_in, Some(3600));
    assert_eq!(full.refresh_token.as_deref(), Some("rt-xyz"));
    assert_eq!(full.scope.as_deref(), Some("read write"));

    let bare = AccessToken {
        access_token: String::from("tok"),
        token_type: String::from("bearer"),
        expires_in: None,
        refresh_token: None,
        scope: None,
    };
    assert!(bare.expires_in.is_none());
    assert!(bare.refresh_token.is_none());
    assert!(bare.scope.is_none());
}
/// Mapping a raw token-endpoint JSON body field by field into `AccessToken` works for both
/// full and minimal responses.
#[test]
fn test_token_response_to_access_token_pattern() {
    // Shared conversion applied to both payloads below.
    let token_from_json = |raw: &str| {
        let body: serde_json::Value = serde_json::from_str(raw).unwrap();
        AccessToken {
            access_token: body["access_token"].as_str().unwrap().to_string(),
            token_type: body["token_type"].as_str().unwrap().to_string(),
            expires_in: body["expires_in"].as_u64(),
            refresh_token: body["refresh_token"].as_str().map(String::from),
            scope: body["scope"].as_str().map(String::from),
        }
    };

    let full = token_from_json(
        r#"{
"access_token": "eyJ-token",
"token_type": "Bearer",
"expires_in": 7200,
"refresh_token": "refresh-me",
"scope": "openid profile"
}"#,
    );
    assert_eq!(full.access_token, "eyJ-token");
    assert_eq!(full.token_type, "Bearer");
    assert_eq!(full.expires_in, Some(7200));
    assert_eq!(full.refresh_token.as_deref(), Some("refresh-me"));
    assert_eq!(full.scope.as_deref(), Some("openid profile"));

    let minimal = token_from_json(r#"{"access_token": "tok", "token_type": "bearer"}"#);
    assert!(minimal.expires_in.is_none());
    assert!(minimal.refresh_token.is_none());
    assert!(minimal.scope.is_none());
}
/// Each `AuthError` variant renders the exact user-facing message expected.
#[test]
fn test_auth_error_display_strings() {
    let cases = [
        (AuthError::NotSupported, "Server does not support OAuth authorization"),
        (
            AuthError::DiscoveryFailed("timeout".to_string()),
            "Failed to discover authorization endpoints: timeout",
        ),
        (AuthError::AuthorizationDenied, "Authorization denied by user"),
        (
            AuthError::TokenExchangeFailed("bad code".to_string()),
            "Token exchange failed: bad code",
        ),
        (
            AuthError::RefreshFailed("expired".to_string()),
            "Token expired and refresh failed: expired",
        ),
        (AuthError::NoToken, "No access token available"),
        (AuthError::Timeout, "Timeout waiting for authorization callback"),
        (AuthError::PortUnavailable, "Could not bind to callback port"),
        (
            AuthError::Http("connection refused".to_string()),
            "HTTP error: connection refused",
        ),
        (
            AuthError::Secrets("decrypt failed".to_string()),
            "Secrets error: decrypt failed",
        ),
    ];
    for (error, expected) in cases {
        let rendered = error.to_string();
        assert_eq!(
            rendered, expected,
            "AuthError display mismatch for {:?}",
            error
        );
    }
}
/// Cloning an `AuthError::Http` keeps both the variant and its message.
#[test]
fn test_auth_error_clone_preserves_http_variant_and_payload() {
    let original = AuthError::Http("builder failed".to_string());
    match original.clone() {
        AuthError::Http(message) => assert_eq!(message, "builder failed"),
        other => panic!("expected AuthError::Http variant, got {other:?}"),
    }
}
/// A bare origin gets the well-known suffix appended directly.
#[test]
fn test_build_well_known_uri_no_path() {
    let expected = "https://example.com/.well-known/oauth-authorization-server";
    let built = build_well_known_uri("https://example.com", "oauth-authorization-server");
    assert_eq!(built.unwrap(), expected);
}
/// An issuer path is moved after the well-known segment.
#[test]
fn test_build_well_known_uri_with_path() {
    let expected = "https://example.com/.well-known/oauth-authorization-server/path";
    let built = build_well_known_uri("https://example.com/path", "oauth-authorization-server");
    assert_eq!(built.unwrap(), expected);
}
/// A trailing slash on the path is dropped before the path is re-appended.
#[test]
fn test_build_well_known_uri_with_trailing_slash() {
    let expected = "https://example.com/.well-known/oauth-protected-resource/path";
    let built = build_well_known_uri("https://example.com/path/", "oauth-protected-resource");
    assert_eq!(built.unwrap(), expected);
}
/// A root-only trailing slash yields no extra path segment.
#[test]
fn test_build_well_known_uri_root_trailing_slash() {
    let expected = "https://example.com/.well-known/oauth-authorization-server";
    let built = build_well_known_uri("https://example.com/", "oauth-authorization-server");
    assert_eq!(built.unwrap(), expected);
}
/// URL fragments are removed from the canonical resource URI.
#[test]
fn test_canonical_resource_uri_strips_fragment() {
    let canonical = canonical_resource_uri("https://mcp.example.com/v1#section");
    assert_eq!(canonical, "https://mcp.example.com/v1");
}
/// A trailing slash is removed from the canonical resource URI.
#[test]
fn test_canonical_resource_uri_strips_trailing_slash() {
    let canonical = canonical_resource_uri("https://mcp.example.com/v1/");
    assert_eq!(canonical, "https://mcp.example.com/v1");
}
/// An already canonical URI is returned unchanged.
#[test]
fn test_canonical_resource_uri_no_changes_needed() {
    let input = "https://mcp.example.com/v1";
    assert_eq!(canonical_resource_uri(input), input);
}
/// Addresses in 127.0.0.0/8, not just 127.0.0.1, are treated as dangerous.
#[test]
fn test_is_dangerous_ip_loopback_v4() {
    use std::net::Ipv4Addr;
    for loopback in [Ipv4Addr::LOCALHOST, Ipv4Addr::new(127, 0, 0, 2)] {
        assert!(is_dangerous_ip(IpAddr::V4(loopback)));
    }
}
/// One address from each RFC 1918 private range is treated as dangerous.
#[test]
fn test_is_dangerous_ip_private_v4() {
    for raw in ["10.0.0.1", "172.16.0.1", "192.168.1.1"] {
        let private: IpAddr = raw.parse().unwrap();
        assert!(is_dangerous_ip(private));
    }
}
/// The IPv4 link-local address 169.254.169.254 is treated as dangerous.
#[test]
fn test_is_dangerous_ip_link_local_v4() {
    let link_local: IpAddr = "169.254.169.254".parse().unwrap();
    assert!(is_dangerous_ip(link_local));
}
/// Both ends of the 100.64.0.0/10 shared address space are treated as dangerous.
#[test]
fn test_is_dangerous_ip_cgnat() {
    use std::net::Ipv4Addr;
    let edges = [Ipv4Addr::new(100, 64, 0, 1), Ipv4Addr::new(100, 127, 255, 254)];
    for shared in edges {
        assert!(is_dangerous_ip(IpAddr::V4(shared)));
    }
}
/// Ordinary public resolvers are not flagged.
#[test]
fn test_is_dangerous_ip_safe_v4() {
    for octets in [[8, 8, 8, 8], [1, 1, 1, 1]] {
        assert!(!is_dangerous_ip(IpAddr::from(octets)));
    }
}
/// Loopback wrapped as an IPv4-mapped IPv6 address (::ffff:127.0.0.1) is still caught.
#[test]
fn test_is_dangerous_ip_ipv4_mapped_v6_loopback() {
    let mapped = std::net::Ipv4Addr::LOCALHOST.to_ipv6_mapped();
    assert!(is_dangerous_ip(IpAddr::V6(mapped)));
}
/// Link-local wrapped as an IPv4-mapped IPv6 address is still caught.
#[test]
fn test_is_dangerous_ip_ipv4_mapped_v6_link_local() {
    let mapped = std::net::Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped();
    assert!(is_dangerous_ip(IpAddr::V6(mapped)));
}
/// The unspecified address is rejected for both IP families.
#[test]
fn test_is_dangerous_ip_unspecified() {
    use std::net::{Ipv4Addr, Ipv6Addr};
    let unspecified = [
        IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        IpAddr::V6(Ipv6Addr::UNSPECIFIED),
    ];
    for ip in unspecified {
        assert!(is_dangerous_ip(ip));
    }
}
/// The IPv6 loopback address ::1 is treated as dangerous.
#[test]
fn test_is_dangerous_ip_v6_loopback() {
    assert!(is_dangerous_ip(IpAddr::V6(std::net::Ipv6Addr::LOCALHOST)));
}
/// A plain HTTPS URL on a public hostname passes validation.
#[tokio::test]
async fn test_validate_url_safe_https() {
    let verdict = validate_url_safe("https://example.com/path").await;
    assert!(verdict.is_ok());
}
/// Plain HTTP is accepted for `localhost`, with or without an explicit port.
#[tokio::test]
async fn test_validate_url_safe_http_localhost_allowed() {
    for local in ["http://localhost/path", "http://localhost:8080/path"] {
        assert!(validate_url_safe(local).await.is_ok());
    }
}
/// Plain HTTP to a non-local host is refused.
#[tokio::test]
async fn test_validate_url_safe_http_non_localhost_rejected() {
    let verdict = validate_url_safe("http://example.com/path").await;
    assert!(verdict.is_err());
}
/// Schemes other than http/https are refused outright.
#[tokio::test]
async fn test_validate_url_safe_bad_scheme() {
    for unsupported in ["ftp://example.com/path", "file:///etc/passwd"] {
        assert!(validate_url_safe(unsupported).await.is_err());
    }
}
/// Literal loopback over HTTP is accepted, but private and link-local literals are refused
/// over either scheme.
#[tokio::test]
async fn test_validate_url_safe_private_ip() {
    assert!(validate_url_safe("http://127.0.0.1/path").await.is_ok());
    let blocked = [
        "https://10.0.0.1/path",
        "https://169.254.169.254/latest/meta-data",
        "http://10.0.0.1/path",
    ];
    for target in blocked {
        assert!(validate_url_safe(target).await.is_err());
    }
}
/// An HTTPS URL with a public IP literal passes validation.
#[tokio::test]
async fn test_validate_url_safe_public_ip() {
    let verdict = validate_url_safe("https://8.8.8.8/dns").await;
    assert!(verdict.is_ok());
}
/// The `resource_metadata` parameter is extracted from a Bearer challenge header.
#[test]
fn test_parse_resource_metadata_url_bearer() {
    let expected = "https://res.example.com/.well-known/oauth-protected-resource";
    let header = r#"Bearer resource_metadata="https://res.example.com/.well-known/oauth-protected-resource""#;
    assert_eq!(
        parse_resource_metadata_url(header).as_deref(),
        Some(expected)
    );
}
/// Other auth-params before `resource_metadata` do not prevent its extraction.
#[test]
fn test_parse_resource_metadata_url_with_other_params() {
    let extracted = parse_resource_metadata_url(
        r#"Bearer realm="example", resource_metadata="https://res.example.com/meta""#,
    );
    assert_eq!(extracted.as_deref(), Some("https://res.example.com/meta"));
}
/// A challenge without `resource_metadata` yields `None`.
#[test]
fn test_parse_resource_metadata_url_missing() {
    assert!(parse_resource_metadata_url(r#"Bearer realm="example""#).is_none());
}
/// A supplied resource indicator is percent-encoded into the `resource` parameter.
#[test]
fn test_build_authorization_url_with_resource() {
    let resource = "https://mcp.example.com/v1";
    let url = build_authorization_url(
        "https://auth.example.com/authorize",
        "client-123",
        "http://localhost:9876/callback",
        &[],
        None,
        &HashMap::new(),
        Some(resource),
    );
    assert!(url.contains("resource=https%3A%2F%2Fmcp.example.com%2Fv1"));
}
/// Without a resource indicator, no `resource` parameter is emitted.
#[test]
fn test_build_authorization_url_without_resource() {
    let no_resource: Option<&str> = None;
    let url = build_authorization_url(
        "https://auth.example.com/authorize",
        "client-123",
        "http://localhost:9876/callback",
        &[],
        None,
        &HashMap::new(),
        no_resource,
    );
    let mentions_resource = url.contains("resource=");
    assert!(!mentions_resource);
}
/// A random `state` passed as an extra parameter appears in the authorization URL.
///
/// This test also checks its own 16-byte OsRng + unpadded base64url state generation: it must
/// be URL-safe, unpadded, at least 22 chars long, and differ between two draws.
#[test]
fn test_authorization_url_includes_state_parameter() {
    // 16 random bytes encoded as unpadded base64url.
    let fresh_state = || {
        let mut raw = [0u8; 16];
        rand::rngs::OsRng.fill_bytes(&mut raw);
        URL_SAFE_NO_PAD.encode(raw)
    };
    let state = fresh_state();
    let extra_params = HashMap::from([("state".to_string(), state.clone())]);
    let pkce = PkceChallenge::generate();
    let scopes: Vec<String> = ["mcp", "offline_access", "openid"]
        .iter()
        .map(|scope| scope.to_string())
        .collect();
    let url = build_authorization_url(
        "https://app.attio.com/oidc/authorize",
        "test-client",
        "http://127.0.0.1:9876/callback",
        &scopes,
        Some(&pkce),
        &extra_params,
        Some("https://mcp.attio.com/mcp"),
    );
    assert!(
        url.contains(&format!("state={}", state)),
        "Authorization URL must include the state parameter, got: {}",
        url,
    );
    assert!(!state.contains('+'), "State must be base64url-safe");
    assert!(!state.contains('/'), "State must be base64url-safe");
    assert!(!state.contains('='), "State must not have padding");
    assert!(
        state.len() >= 22,
        "State must have at least 128 bits of entropy, got {} chars",
        state.len(),
    );
    let state_2 = fresh_state();
    assert_ne!(state, state_2, "State must be unique per request");
}
#[tokio::test]
async fn test_refresh_access_token_direct_includes_stored_client_secret() {
let secrets = test_secrets_store();
let user_id = "test-user";
let Some((base_url, state)) = start_refresh_server().await else {
return;
};
let server = McpServerConfig::new("notion", "https://mcp.notion.com/mcp").with_oauth(
crate::tools::mcp::config::OAuthConfig::new("configured-client")
.with_endpoints("http://127.0.0.1/authorize", format!("{base_url}/token")),
);
secrets
.create(
user_id,
CreateSecretParams::new(server.refresh_token_secret_name(), "refresh-token-123"),
)
.await
.unwrap();
store_client_secret(&secrets, user_id, &server, "stored-client-secret", None)
.await
.unwrap();
let token = refresh_access_token(&server, &secrets, user_id)
.await
.expect("refresh succeeds");
assert_eq!(token.access_token, "refreshed-access-token");
assert_eq!(
token.refresh_token.as_deref(),
Some("rotated-refresh-token")
);
let requests = state.requests().await;
assert_eq!(requests.len(), 1);
assert_eq!(
requests[0].form.get("client_id").map(String::as_str),
Some("configured-client")
);
assert_eq!(
requests[0].form.get("client_secret").map(String::as_str),
Some("stored-client-secret")
);
assert_eq!(
requests[0].form.get("refresh_token").map(String::as_str),
Some("refresh-token-123")
);
assert_eq!(
requests[0].form.get("resource").map(String::as_str),
Some("https://mcp.notion.com/mcp")
);
let stored_refresh = secrets
.get_decrypted(user_id, &server.refresh_token_secret_name())
.await
.unwrap();
assert_eq!(stored_refresh.expose(), "rotated-refresh-token");
}
/// With IRONCLAW_OAUTH_EXCHANGE_URL and IRONCLAW_OAUTH_PROXY_AUTH_TOKEN set, the refresh is
/// sent to the exchange URL with the gateway bearer token. The request carries the real
/// token URL, the `mcp:<server>` provider id, the stored client secret and the resource.
// The env-lock guard is deliberately held across `.await`s for the whole test.
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn test_refresh_access_token_uses_proxy_when_configured() {
    let _env_guard = lock_env();
    let Some((base_url, state)) = start_refresh_server().await else {
        return;
    };
    // Point the exchange proxy at the local test server.
    let _proxy_url_guard = set_env_var("IRONCLAW_OAUTH_EXCHANGE_URL", Some(&base_url));
    let _proxy_token_guard = set_env_var(
        "IRONCLAW_OAUTH_PROXY_AUTH_TOKEN",
        Some("gateway-test-token"),
    );
    let expected_token_url = format!("{base_url}/token");
    let secrets = test_secrets_store();
    let user_id = "test-user";
    let oauth = crate::tools::mcp::config::OAuthConfig::new("configured-client")
        .with_endpoints("http://127.0.0.1/authorize", expected_token_url.clone());
    let server = McpServerConfig::new("notion", "https://mcp.notion.com/mcp").with_oauth(oauth);
    let refresh_secret =
        CreateSecretParams::new(server.refresh_token_secret_name(), "refresh-token-123");
    secrets.create(user_id, refresh_secret).await.unwrap();
    store_client_secret(&secrets, user_id, &server, "stored-client-secret", None)
        .await
        .unwrap();
    refresh_access_token(&server, &secrets, user_id)
        .await
        .expect("proxy refresh succeeds");

    let requests = state.requests().await;
    assert_eq!(requests.len(), 1);
    let sent = &requests[0];
    assert_eq!(
        sent.authorization.as_deref(),
        Some("Bearer gateway-test-token")
    );
    let expected_fields = [
        ("token_url", expected_token_url.as_str()),
        ("provider", "mcp:notion"),
        ("client_secret", "stored-client-secret"),
        ("resource", "https://mcp.notion.com/mcp"),
    ];
    for (field, expected) in expected_fields {
        assert_eq!(sent.form.get(field).map(String::as_str), Some(expected));
    }
}
/// Two concurrent refreshes for the same server and user both succeed, but only one
/// request reaches the token endpoint.
// The env-lock guard is deliberately held across `.await`s for the whole test.
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn test_refresh_access_token_serializes_concurrent_refreshes() {
    let _env_guard = lock_env();
    // Clear the exchange-proxy variables (set by the proxy test) so both refreshes go to the
    // token endpoint configured below.
    let _proxy_url_guard = set_env_var("IRONCLAW_OAUTH_EXCHANGE_URL", None);
    let _proxy_token_guard = set_env_var("IRONCLAW_OAUTH_PROXY_AUTH_TOKEN", None);
    let secrets = test_secrets_store();
    let user_id = "test-user";
    let Some((base_url, state)) = start_refresh_server().await else {
        return;
    };
    let token_url = format!("{base_url}/token");
    let oauth = crate::tools::mcp::config::OAuthConfig::new("configured-client")
        .with_endpoints("http://127.0.0.1/authorize", token_url);
    let server = McpServerConfig::new("notion", "https://mcp.notion.com/mcp").with_oauth(oauth);
    let refresh_secret =
        CreateSecretParams::new(server.refresh_token_secret_name(), "refresh-token-123");
    secrets.create(user_id, refresh_secret).await.unwrap();

    let (first, second) = tokio::join!(
        refresh_access_token(&server, &secrets, user_id),
        refresh_access_token(&server, &secrets, user_id),
    );
    assert!(first.is_ok(), "first refresh should succeed: {first:?}");
    assert!(second.is_ok(), "second refresh should succeed: {second:?}");

    let outbound = state.requests().await.len();
    assert_eq!(
        outbound, 1,
        "only one outbound refresh should run for concurrent callers"
    );
}
/// While handles are alive, the same (server, user) key maps to the same mutex, and a
/// different user gets a different one.
#[tokio::test]
async fn test_refresh_lock_reuses_same_key() {
    // Tuple fields are evaluated left to right, so the three lookups run sequentially.
    let (first, second, other_user) = (
        refresh_lock("notion", "user-a").await,
        refresh_lock("notion", "user-a").await,
        refresh_lock("notion", "user-b").await,
    );
    assert!(Arc::ptr_eq(&first, &second));
    assert!(!Arc::ptr_eq(&first, &other_user));
}
/// The registry holds only weak references. Once every handle is dropped the mutex is
/// freed, and the next lookup creates a new mutex that later lookups share.
#[tokio::test]
async fn test_refresh_lock_recreates_dropped_entry() {
    let original = refresh_lock("notion-recreate", "user-recreate").await;
    let probe = Arc::downgrade(&original);
    drop(original);
    assert!(probe.upgrade().is_none());

    let (second, third) = (
        refresh_lock("notion-recreate", "user-recreate").await,
        refresh_lock("notion-recreate", "user-recreate").await,
    );
    assert!(Arc::ptr_eq(&second, &third));
}
/// Guard returned by `set_env_var` that remembers an environment variable's previous state,
/// so it can be put back when the guard is dropped.
struct EnvVarGuard {
    /// Name of the environment variable being overridden.
    key: &'static str,
    /// Value observed before the override, or `None` if the variable was unset
    /// (or not valid Unicode).
    original: Option<String>,
}
impl Drop for EnvVarGuard {
    /// Restores the variable to the value captured when the guard was created, or
    /// removes it if it was unset at that point.
    fn drop(&mut self) {
        match self.original.as_deref() {
            // SAFETY: environment mutation is unsafe because other threads may read the
            // environment concurrently. Tests here declare these guards after
            // `lock_env()`, so they drop, and restore, while that lock is still held.
            // NOTE(review): this only serializes against code that also takes `lock_env()`.
            Some(previous) => unsafe { std::env::set_var(self.key, previous) },
            // SAFETY: same invariant as above.
            None => unsafe { std::env::remove_var(self.key) },
        }
    }
}
/// Overrides `key` for the lifetime of the returned guard.
///
/// `Some(value)` sets the variable and `None` removes it. The value seen before the call is
/// captured first, so dropping the guard restores it.
fn set_env_var(key: &'static str, value: Option<&str>) -> EnvVarGuard {
    let original = std::env::var(key).ok();
    match value {
        // SAFETY: callers in this module hold the `lock_env()` guard while overriding, which
        // serializes environment mutation between tests that take that lock.
        // NOTE(review): code that reads the environment without the lock is not covered.
        Some(new_value) => unsafe { std::env::set_var(key, new_value) },
        // SAFETY: same invariant as above.
        None => unsafe { std::env::remove_var(key) },
    }
    EnvVarGuard { key, original }
}
}