use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::net::TcpListener;
use crate::cli::oauth_defaults::{self, OAUTH_CALLBACK_PORT};
use crate::secrets::{CreateSecretParams, SecretsStore};
use crate::tools::mcp::config::McpServerConfig;
/// Errors produced while discovering, authorizing, refreshing, or storing
/// OAuth credentials for an MCP server.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// No usable OAuth setup: e.g. the protected-resource metadata request
    /// returned a non-success status, the server config has no `oauth`
    /// section where one is required, or the authorization server advertises
    /// no dynamic-registration endpoint.
    #[error("Server does not support OAuth authorization")]
    NotSupported,
    /// Metadata discovery or dynamic client registration failed (network
    /// error, non-success status, or an unparseable response).
    #[error("Failed to discover authorization endpoints: {0}")]
    DiscoveryFailed(String),
    /// The authorization callback reported that the user denied access.
    #[error("Authorization denied by user")]
    AuthorizationDenied,
    /// Exchanging the authorization code for tokens failed.
    #[error("Token exchange failed: {0}")]
    TokenExchangeFailed(String),
    /// Refreshing the access token failed, including when no refresh token
    /// or stored client ID is available.
    #[error("Token expired and refresh failed: {0}")]
    RefreshFailed(String),
    /// No access token is available.
    #[error("No access token available")]
    NoToken,
    /// No authorization callback arrived before the callback wait timed out.
    #[error("Timeout waiting for authorization callback")]
    Timeout,
    /// The local callback listener could not be bound.
    #[error("Could not bind to callback port")]
    PortUnavailable,
    /// HTTP client construction failures and callback-listener I/O errors.
    #[error("HTTP error: {0}")]
    Http(String),
    /// The secrets store returned an error.
    #[error("Secrets error: {0}")]
    Secrets(String),
}
/// Protected-resource metadata (RFC 9728) fetched from the MCP server's
/// `/.well-known/oauth-protected-resource` document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// Resource identifier reported by the server.
    pub resource: String,
    /// Authorization-server issuer URLs; discovery in this module uses the
    /// first entry only. Defaults to empty when absent.
    #[serde(default)]
    pub authorization_servers: Vec<String>,
    /// Scopes the resource advertises; defaults to empty when absent.
    #[serde(default)]
    pub scopes_supported: Vec<String>,
}
/// Authorization-server metadata (RFC 8414) as returned by the
/// `oauth-authorization-server` well-known document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationServerMetadata {
    /// Issuer identifier of the authorization server.
    pub issuer: String,
    /// URL the user's browser is sent to for authorization.
    pub authorization_endpoint: String,
    /// URL used for code exchange and token refresh.
    pub token_endpoint: String,
    /// Dynamic client registration (RFC 7591) endpoint, if offered; required
    /// when the server config has no static `oauth` section.
    #[serde(default)]
    pub registration_endpoint: Option<String>,
    #[serde(default)]
    pub response_types_supported: Vec<String>,
    #[serde(default)]
    pub grant_types_supported: Vec<String>,
    /// PKCE methods advertised by the server.
    // NOTE(review): not currently checked before using S256 — confirm whether
    // servers that omit "S256" need to be rejected.
    #[serde(default)]
    pub code_challenge_methods_supported: Vec<String>,
    /// Scopes advertised by the server; requested verbatim for dynamically
    /// registered clients.
    #[serde(default)]
    pub scopes_supported: Vec<String>,
}
/// JSON body sent to the dynamic client registration endpoint (RFC 7591).
///
/// Field order is the serialized JSON key order.
#[derive(Debug, Clone, Serialize)]
pub struct ClientRegistrationRequest {
    /// Human-readable client name shown by the authorization server.
    pub client_name: String,
    /// Redirect URIs the client will use (the local loopback callback).
    pub redirect_uris: Vec<String>,
    /// OAuth grant types the client intends to use.
    pub grant_types: Vec<String>,
    /// OAuth response types the client intends to use.
    pub response_types: Vec<String>,
    /// Token-endpoint authentication method; `"none"` registers a public client.
    pub token_endpoint_auth_method: String,
}
/// Response from a dynamic client registration endpoint (RFC 7591).
///
/// `Debug` is implemented by hand so the client secret and registration
/// access token never end up in logs or error output.
#[derive(Clone, Deserialize)]
pub struct ClientRegistrationResponse {
    /// Client identifier issued by the authorization server (not secret).
    pub client_id: String,
    /// Client secret, if the server issued one (public clients get none).
    #[serde(default)]
    pub client_secret: Option<String>,
    /// Expiry of `client_secret` as reported by the server, if any.
    #[serde(default)]
    pub client_secret_expires_at: Option<u64>,
    /// Bearer credential for managing this registration, if issued.
    #[serde(default)]
    pub registration_access_token: Option<String>,
    /// URL for managing this registration, if issued.
    #[serde(default)]
    pub registration_client_uri: Option<String>,
}

impl std::fmt::Debug for ClientRegistrationResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only report whether a credential is present, never its value.
        f.debug_struct("ClientRegistrationResponse")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "[REDACTED]"),
            )
            .field("client_secret_expires_at", &self.client_secret_expires_at)
            .field(
                "registration_access_token",
                &self
                    .registration_access_token
                    .as_ref()
                    .map(|_| "[REDACTED]"),
            )
            .field("registration_client_uri", &self.registration_client_uri)
            .finish()
    }
}
/// Tokens obtained from an OAuth token endpoint.
///
/// `Debug` is implemented by hand so the access and refresh tokens are
/// redacted: formatting this value with `{:?}` must never leak credentials.
#[derive(Clone)]
pub struct AccessToken {
    /// Bearer credential presented to the MCP server.
    pub access_token: String,
    /// Token type reported by the server.
    pub token_type: String,
    /// Lifetime reported by the server (`expires_in`), if any.
    pub expires_in: Option<u64>,
    /// Refresh token, if one is available.
    pub refresh_token: Option<String>,
    /// Granted scope string reported by the server, if any.
    pub scope: Option<String>,
}

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show presence of secrets, never their contents.
        f.debug_struct("AccessToken")
            .field("access_token", &"[REDACTED]")
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("scope", &self.scope)
            .finish()
    }
}
/// Raw JSON body returned by the token endpoint, used for both the
/// authorization-code exchange and the refresh flow.
///
/// `Debug` is implemented by hand so bearer credentials are never printed.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    token_type: String,
    expires_in: Option<u64>,
    refresh_token: Option<String>,
    scope: Option<String>,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"[REDACTED]")
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "[REDACTED]"),
            )
            .field("scope", &self.scope)
            .finish()
    }
}
/// PKCE (RFC 7636) verifier/challenge pair for a single authorization request.
///
/// `challenge` is placed in the authorization URL with method `S256`;
/// `verifier` is sent later with the authorization-code token request.
// NOTE(review): the derived `Debug` prints `verifier` in clear text; consider
// redacting it if these values are ever logged.
#[derive(Debug, Clone)]
pub struct PkceChallenge {
    /// Random base64url (no padding) string, kept secret until token exchange.
    pub verifier: String,
    /// base64url (no padding) encoding of SHA-256(`verifier`).
    pub challenge: String,
}
impl PkceChallenge {
    /// Creates a fresh verifier/challenge pair for the `S256` method.
    ///
    /// The verifier is 32 bytes from the thread-local RNG, base64url-encoded
    /// without padding. The challenge is the unpadded base64url encoding of
    /// the SHA-256 digest of the verifier's bytes.
    pub fn generate() -> Self {
        let mut entropy = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut entropy);

        let verifier = URL_SAFE_NO_PAD.encode(entropy);
        let digest = Sha256::digest(verifier.as_bytes());
        let challenge = URL_SAFE_NO_PAD.encode(digest);

        Self { verifier, challenge }
    }
}
/// Fetches protected-resource metadata for the MCP server at `server_url`.
///
/// Only the URL's origin (scheme, host, port) is kept; any path is dropped
/// before `/.well-known/oauth-protected-resource` is appended. The request
/// uses a 10-second timeout.
///
/// # Errors
/// - [`AuthError::Http`] if the HTTP client cannot be built.
/// - [`AuthError::DiscoveryFailed`] for an unparseable URL, a transport
///   failure, or a body that is not valid metadata JSON.
/// - [`AuthError::NotSupported`] if the server answers with a non-2xx status.
// NOTE(review): RFC 9728 inserts the well-known segment before the resource
// path rather than discarding it — confirm whether path-based servers matter.
pub async fn discover_protected_resource(
    server_url: &str,
) -> Result<ProtectedResourceMetadata, AuthError> {
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .map_err(|err| AuthError::Http(err.to_string()))?;

    let metadata_url = match reqwest::Url::parse(server_url) {
        Ok(url) => format!(
            "{}/.well-known/oauth-protected-resource",
            url.origin().ascii_serialization()
        ),
        Err(err) => {
            return Err(AuthError::DiscoveryFailed(format!(
                "Invalid server URL: {}",
                err
            )))
        }
    };

    let resp = http
        .get(&metadata_url)
        .send()
        .await
        .map_err(|err| AuthError::DiscoveryFailed(err.to_string()))?;

    if resp.status().is_success() {
        resp.json()
            .await
            .map_err(|err| AuthError::DiscoveryFailed(format!("Invalid metadata: {}", err)))
    } else {
        Err(AuthError::NotSupported)
    }
}
/// Fetches RFC 8414 authorization-server metadata for `auth_server_url`.
///
/// For an issuer with a path component (e.g. `https://auth.example.com/tenant`),
/// RFC 8414 §3.1 inserts the well-known segment between the host and the path:
/// `https://auth.example.com/.well-known/oauth-authorization-server/tenant`.
/// That location is tried first. The legacy location is tried afterwards, so
/// servers that only serve it keep working; it is built by appending the
/// suffix to the whole URL. For issuers without a path both forms are
/// identical, and only a single request is made. Each request uses a
/// 10-second timeout.
///
/// # Errors
/// - [`AuthError::Http`] if the HTTP client cannot be built.
/// - [`AuthError::DiscoveryFailed`] if every candidate URL fails (transport
///   error, non-2xx status, or invalid JSON). The error from the last
///   candidate is returned, which is the legacy URL's error, as before.
pub async fn discover_authorization_server(
    auth_server_url: &str,
) -> Result<AuthorizationServerMetadata, AuthError> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(10))
        .build()
        .map_err(|e| AuthError::Http(e.to_string()))?;

    // Only reachable if no candidate URL is produced, which cannot happen:
    // the legacy URL is always included.
    let mut last_error = AuthError::DiscoveryFailed(format!(
        "No metadata URL could be derived from {}",
        auth_server_url
    ));
    for well_known_url in authorization_server_metadata_urls(auth_server_url) {
        match fetch_authorization_server_metadata(&client, &well_known_url).await {
            Ok(metadata) => return Ok(metadata),
            Err(e) => last_error = e,
        }
    }
    Err(last_error)
}

/// Returns the metadata URLs to try for `auth_server_url`, RFC 8414 form first,
/// followed by the legacy "append to the full URL" form.
fn authorization_server_metadata_urls(auth_server_url: &str) -> Vec<String> {
    const WELL_KNOWN: &str = "/.well-known/oauth-authorization-server";

    let base_url = auth_server_url.trim_end_matches('/');
    let appended = format!("{}{}", base_url, WELL_KNOWN);

    // Path insertion only makes sense for http(s) URLs that have a path.
    let inserted = reqwest::Url::parse(base_url).ok().and_then(|parsed| {
        let path = parsed.path().trim_end_matches('/');
        if path.is_empty() || !matches!(parsed.scheme(), "http" | "https") {
            None
        } else {
            Some(format!(
                "{}{}{}",
                parsed.origin().ascii_serialization(),
                WELL_KNOWN,
                path
            ))
        }
    });

    match inserted {
        Some(inserted) if inserted != appended => vec![inserted, appended],
        _ => vec![appended],
    }
}

/// Performs one GET against `well_known_url` and parses the metadata JSON.
async fn fetch_authorization_server_metadata(
    client: &reqwest::Client,
    well_known_url: &str,
) -> Result<AuthorizationServerMetadata, AuthError> {
    let response = client
        .get(well_known_url)
        .send()
        .await
        .map_err(|e| AuthError::DiscoveryFailed(e.to_string()))?;
    if !response.status().is_success() {
        return Err(AuthError::DiscoveryFailed(format!(
            "HTTP {}",
            response.status()
        )));
    }
    response
        .json()
        .await
        .map_err(|e| AuthError::DiscoveryFailed(format!("Invalid metadata: {}", e)))
}
pub async fn discover_oauth_endpoints(
server_config: &McpServerConfig,
) -> Result<(String, String), AuthError> {
let oauth = server_config
.oauth
.as_ref()
.ok_or(AuthError::NotSupported)?;
if let (Some(auth_url), Some(token_url)) = (&oauth.authorization_url, &oauth.token_url) {
return Ok((auth_url.clone(), token_url.clone()));
}
let resource_meta = discover_protected_resource(&server_config.url).await?;
let auth_server_url = resource_meta
.authorization_servers
.first()
.ok_or_else(|| AuthError::DiscoveryFailed("No authorization servers listed".to_string()))?;
let auth_meta = discover_authorization_server(auth_server_url).await?;
Ok((auth_meta.authorization_endpoint, auth_meta.token_endpoint))
}
/// Runs the two-step discovery chain for `server_url`.
///
/// First it reads the server's protected-resource metadata, then it fetches
/// metadata from the first authorization server that document lists.
///
/// # Errors
/// - Anything [`discover_protected_resource`] or
///   [`discover_authorization_server`] returns.
/// - [`AuthError::DiscoveryFailed`] if no authorization server is listed.
pub async fn discover_full_oauth_metadata(
    server_url: &str,
) -> Result<AuthorizationServerMetadata, AuthError> {
    let resource = discover_protected_resource(server_url).await?;
    match resource.authorization_servers.first() {
        Some(issuer) => discover_authorization_server(issuer).await,
        None => Err(AuthError::DiscoveryFailed(
            "No authorization servers listed".to_string(),
        )),
    }
}
/// Registers this application as a public OAuth client via dynamic client
/// registration (RFC 7591).
///
/// The registration requests the `authorization_code` and `refresh_token`
/// grants, the `code` response type, and token-endpoint auth method `none`,
/// with `redirect_uri` as the only redirect URI. The request uses a
/// 30-second timeout.
///
/// # Errors
/// - [`AuthError::Http`] if the HTTP client cannot be built.
/// - [`AuthError::DiscoveryFailed`] on transport failure, a non-2xx status
///   (the response body is included), or an unparseable response.
pub async fn register_client(
    registration_endpoint: &str,
    redirect_uri: &str,
) -> Result<ClientRegistrationResponse, AuthError> {
    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .map_err(|err| AuthError::Http(err.to_string()))?;

    let body = ClientRegistrationRequest {
        client_name: String::from("IronClaw"),
        redirect_uris: vec![redirect_uri.to_owned()],
        grant_types: ["authorization_code", "refresh_token"]
            .iter()
            .map(|grant| grant.to_string())
            .collect(),
        response_types: vec![String::from("code")],
        // No client secret: the client authenticates as a public client.
        token_endpoint_auth_method: String::from("none"),
    };

    let resp = http
        .post(registration_endpoint)
        .json(&body)
        .send()
        .await
        .map_err(|err| AuthError::DiscoveryFailed(format!("DCR request failed: {}", err)))?;

    let status = resp.status();
    if !status.is_success() {
        let detail = resp.text().await.unwrap_or_default();
        return Err(AuthError::DiscoveryFailed(format!(
            "DCR failed: HTTP {} - {}",
            status, detail
        )));
    }

    resp.json()
        .await
        .map_err(|err| AuthError::DiscoveryFailed(format!("Invalid DCR response: {}", err)))
}
/// Runs the interactive authorization-code flow for `server_config` and
/// persists the resulting tokens for `user_id`.
///
/// There are two ways to obtain client settings:
/// - With an `oauth` section, its client ID, PKCE flag, scopes and extra
///   parameters are used, and the endpoints come from
///   [`discover_oauth_endpoints`].
/// - Without one, the endpoints are discovered and a client is registered
///   dynamically. PKCE is forced on, and every scope the authorization
///   server advertises is requested.
///
/// The user's browser is then opened to the authorization URL (the URL is
/// printed if that fails). The function waits for the loopback callback,
/// exchanges the code for tokens, and stores them. A dynamically registered
/// client ID is also stored so later refreshes can reuse it.
///
/// # Errors
/// Propagates any [`AuthError`] from port binding, discovery, registration,
/// the callback wait, token exchange, or the secrets store.
pub async fn authorize_mcp_server(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<AccessToken, AuthError> {
    // Bind the callback listener first so the redirect URI (which dynamic
    // registration embeds) names the port that is actually being listened on.
    let (listener, port) = find_available_port().await?;
    let redirect_uri = format!("http://localhost:{}/callback", port);

    let (client_id, authorization_url, token_url, use_pkce, scopes, extra_params) =
        match server_config.oauth.as_ref() {
            Some(oauth) => {
                let (endpoint_auth, endpoint_token) =
                    discover_oauth_endpoints(server_config).await?;
                (
                    oauth.client_id.clone(),
                    endpoint_auth,
                    endpoint_token,
                    oauth.use_pkce,
                    oauth.scopes.clone(),
                    oauth.extra_params.clone(),
                )
            }
            None => {
                println!(" Discovering OAuth endpoints...");
                let metadata = discover_full_oauth_metadata(&server_config.url).await?;
                let registration_endpoint = match metadata.registration_endpoint {
                    Some(endpoint) => endpoint,
                    None => return Err(AuthError::NotSupported),
                };

                println!(" Registering client dynamically...");
                let registration = register_client(&registration_endpoint, &redirect_uri).await?;
                println!(" ✓ Client registered: {}", registration.client_id);

                // Dynamically registered clients are public, so PKCE is always used.
                (
                    registration.client_id,
                    metadata.authorization_endpoint,
                    metadata.token_endpoint,
                    true,
                    metadata.scopes_supported,
                    HashMap::new(),
                )
            }
        };

    let pkce = use_pkce.then(PkceChallenge::generate);

    let auth_url = build_authorization_url(
        &authorization_url,
        &client_id,
        &redirect_uri,
        &scopes,
        pkce.as_ref(),
        &extra_params,
    );

    println!(" Opening browser for {} login...", server_config.name);
    if let Err(open_err) = open::that(&auth_url) {
        // Fall back to letting the user copy the URL by hand.
        println!(" Could not open browser: {}", open_err);
        println!(" Please open this URL manually:");
        println!(" {}", auth_url);
    }

    println!(" Waiting for authorization...");
    let code = wait_for_authorization_callback(listener, &server_config.name).await?;

    println!(" Exchanging code for token...");
    let token = exchange_code_for_token(
        &token_url,
        &client_id,
        &code,
        &redirect_uri,
        pkce.as_ref(),
    )
    .await?;

    store_tokens(secrets, user_id, server_config, &token).await?;

    // Configured clients keep their ID in `server_config.oauth`; only a
    // dynamically registered ID has to be persisted for later refreshes.
    if server_config.oauth.is_none() {
        store_client_id(secrets, user_id, server_config, &client_id).await?;
    }

    Ok(token)
}
/// Binds the OAuth callback listener and returns it with its port.
///
/// Despite the name, no port search happens here: the listener comes from
/// `oauth_defaults::bind_callback_listener`, and the reported port is always
/// `OAUTH_CALLBACK_PORT`.
///
/// # Errors
/// [`AuthError::PortUnavailable`] if binding fails (the underlying error is
/// discarded).
pub async fn find_available_port() -> Result<(TcpListener, u16), AuthError> {
    match oauth_defaults::bind_callback_listener().await {
        Ok(listener) => Ok((listener, OAUTH_CALLBACK_PORT)),
        Err(_) => Err(AuthError::PortUnavailable),
    }
}
/// Builds the authorization-request URL the user's browser is sent to.
///
/// Parameters are emitted in this order:
/// 1. `client_id`, `response_type=code`, `redirect_uri`.
/// 2. `scope` (the space-joined `scopes`), only when `scopes` is non-empty.
/// 3. `code_challenge` and `code_challenge_method=S256`, when `pkce` is given.
/// 4. Every `extra_params` entry, sorted by key so the output is deterministic.
///
/// All keys and values are percent-encoded. If `base_url` already carries a
/// query string, the parameters are appended with `&` instead of starting a
/// second `?`.
pub fn build_authorization_url(
    base_url: &str,
    client_id: &str,
    redirect_uri: &str,
    scopes: &[String],
    pkce: Option<&PkceChallenge>,
    extra_params: &HashMap<String, String>,
) -> String {
    let scope = scopes.join(" ");

    let mut params: Vec<(&str, &str)> = vec![
        ("client_id", client_id),
        ("response_type", "code"),
        ("redirect_uri", redirect_uri),
    ];
    if !scopes.is_empty() {
        params.push(("scope", scope.as_str()));
    }
    if let Some(pkce) = pkce {
        params.push(("code_challenge", pkce.challenge.as_str()));
        params.push(("code_challenge_method", "S256"));
    }

    // HashMap iteration order is unspecified; sort for a reproducible URL.
    let mut extras: Vec<(&String, &String)> = extra_params.iter().collect();
    extras.sort_unstable();
    params.extend(
        extras
            .into_iter()
            .map(|(key, value)| (key.as_str(), value.as_str())),
    );

    let query = params
        .iter()
        .map(|(key, value)| {
            format!(
                "{}={}",
                urlencoding::encode(key),
                urlencoding::encode(value)
            )
        })
        .collect::<Vec<_>>()
        .join("&");

    // Respect a query string already present on the endpoint URL.
    let separator = if !base_url.contains('?') {
        "?"
    } else if base_url.ends_with('?') || base_url.ends_with('&') {
        ""
    } else {
        "&"
    };
    format!("{}{}{}", base_url, separator, query)
}
pub async fn wait_for_authorization_callback(
listener: TcpListener,
server_name: &str,
) -> Result<String, AuthError> {
oauth_defaults::wait_for_callback(listener, "/callback", "code", server_name)
.await
.map_err(|e| match e {
oauth_defaults::OAuthCallbackError::Denied => AuthError::AuthorizationDenied,
oauth_defaults::OAuthCallbackError::Timeout => AuthError::Timeout,
oauth_defaults::OAuthCallbackError::PortInUse(_, msg) => {
AuthError::Http(format!("Port error: {}", msg))
}
oauth_defaults::OAuthCallbackError::Io(msg) => AuthError::Http(msg),
})
}
pub async fn exchange_code_for_token(
token_url: &str,
client_id: &str,
code: &str,
redirect_uri: &str,
pkce: Option<&PkceChallenge>,
) -> Result<AccessToken, AuthError> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| AuthError::Http(e.to_string()))?;
let mut params = vec![
("grant_type", "authorization_code".to_string()),
("code", code.to_string()),
("redirect_uri", redirect_uri.to_string()),
("client_id", client_id.to_string()),
];
if let Some(pkce) = pkce {
params.push(("code_verifier", pkce.verifier.clone()));
}
let response = client
.post(token_url)
.form(¶ms)
.send()
.await
.map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AuthError::TokenExchangeFailed(format!(
"HTTP {} - {}",
status, body
)));
}
let token_response: TokenResponse = response
.json()
.await
.map_err(|e| AuthError::TokenExchangeFailed(format!("Invalid response: {}", e)))?;
Ok(AccessToken {
access_token: token_response.access_token,
token_type: token_response.token_type,
expires_in: token_response.expires_in,
refresh_token: token_response.refresh_token,
scope: token_response.scope,
})
}
/// Persists `token` in the secrets store for `user_id`.
///
/// The access token is always written under `token_secret_name()`. The
/// refresh token is written under `refresh_token_secret_name()` only when
/// present, so a previously stored refresh token is left in place otherwise.
/// Both secrets are tagged with provider `mcp:<server name>`.
///
/// # Errors
/// [`AuthError::Secrets`] if either write fails. The access token is written
/// first, so a failure on the refresh token leaves the new access token stored.
pub async fn store_tokens(
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
    server_config: &McpServerConfig,
    token: &AccessToken,
) -> Result<(), AuthError> {
    let provider = format!("mcp:{}", server_config.name);

    let access_secret =
        CreateSecretParams::new(server_config.token_secret_name(), &token.access_token)
            .with_provider(provider.clone());
    secrets
        .create(user_id, access_secret)
        .await
        .map_err(|err| AuthError::Secrets(err.to_string()))?;

    if let Some(refresh_token) = &token.refresh_token {
        let refresh_secret =
            CreateSecretParams::new(server_config.refresh_token_secret_name(), refresh_token)
                .with_provider(provider);
        secrets
            .create(user_id, refresh_secret)
            .await
            .map_err(|err| AuthError::Secrets(err.to_string()))?;
    }

    Ok(())
}
/// Persists a dynamically registered `client_id` for `user_id` under
/// `client_id_secret_name()`, tagged with provider `mcp:<server name>`.
///
/// # Errors
/// [`AuthError::Secrets`] if the secrets store rejects the write.
pub async fn store_client_id(
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
    server_config: &McpServerConfig,
    client_id: &str,
) -> Result<(), AuthError> {
    let provider = format!("mcp:{}", server_config.name);
    let secret = CreateSecretParams::new(server_config.client_id_secret_name(), client_id)
        .with_provider(provider);

    match secrets.create(user_id, secret).await {
        Ok(_) => Ok(()),
        Err(err) => Err(AuthError::Secrets(err.to_string())),
    }
}
/// Returns the OAuth client ID to use for `server_config`.
///
/// A statically configured `oauth.client_id` wins. Otherwise the ID stored
/// after dynamic registration is read from the secrets store.
///
/// # Errors
/// - [`AuthError::RefreshFailed`] if no stored client ID exists.
/// - [`AuthError::Secrets`] for any other secrets-store failure.
async fn get_client_id(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<String, AuthError> {
    if let Some(oauth) = &server_config.oauth {
        return Ok(oauth.client_id.clone());
    }

    let stored = secrets
        .get_decrypted(user_id, &server_config.client_id_secret_name())
        .await;

    stored
        .map(|secret| secret.expose().to_string())
        .map_err(|err| match err {
            crate::secrets::SecretError::NotFound(_) => AuthError::RefreshFailed(
                "No client ID found. Please re-authenticate.".to_string(),
            ),
            other => AuthError::Secrets(other.to_string()),
        })
}
/// Reads the stored access token for `server_config` and `user_id`.
///
/// Returns `Ok(None)` when no token has been stored yet. The stored token's
/// expiry is not checked here.
///
/// # Errors
/// [`AuthError::Secrets`] for any secrets-store failure other than "not found".
pub async fn get_access_token(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<Option<String>, AuthError> {
    let lookup = secrets
        .get_decrypted(user_id, &server_config.token_secret_name())
        .await;

    match lookup {
        Ok(decrypted) => Ok(Some(decrypted.expose().to_string())),
        Err(crate::secrets::SecretError::NotFound(_)) => Ok(None),
        Err(other) => Err(AuthError::Secrets(other.to_string())),
    }
}
/// Reports whether an access token is stored for `server_config` and `user_id`.
///
/// Secrets-store errors are treated as "not authenticated".
pub async fn is_authenticated(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> bool {
    let secret_name = server_config.token_secret_name();
    matches!(secrets.exists(user_id, &secret_name).await, Ok(true))
}
/// Uses the stored refresh token to obtain a new access token, then persists it.
///
/// The token endpoint is the configured `oauth.token_url` when set; otherwise
/// it is discovered via [`discover_full_oauth_metadata`]. The request is a
/// form-encoded `refresh_token` grant with a 30-second timeout.
///
/// RFC 6749 §6 lets the server omit a new refresh token. In that case
/// [`store_tokens`] leaves the existing stored refresh token untouched, and
/// the returned [`AccessToken`] reports that existing refresh token instead
/// of `None`, so callers see the credential that is actually still usable.
///
/// # Errors
/// - Errors from [`get_client_id`] (missing client ID, secrets failures).
/// - [`AuthError::RefreshFailed`] if no refresh token can be read, on
///   transport failure, a non-2xx status (the body is included), or an
///   unparseable response.
/// - [`AuthError::Http`] if the HTTP client cannot be built.
/// - Discovery or secrets-store errors.
pub async fn refresh_access_token(
    server_config: &McpServerConfig,
    secrets: &Arc<dyn SecretsStore + Send + Sync>,
    user_id: &str,
) -> Result<AccessToken, AuthError> {
    let client_id = get_client_id(server_config, secrets, user_id).await?;
    let refresh_token = secrets
        .get_decrypted(user_id, &server_config.refresh_token_secret_name())
        .await
        .map_err(|e| AuthError::RefreshFailed(format!("No refresh token: {}", e)))?;

    // An explicitly configured token endpoint wins; otherwise discover it.
    let configured_token_url = server_config
        .oauth
        .as_ref()
        .and_then(|oauth| oauth.token_url.clone());
    let token_url = match configured_token_url {
        Some(url) => url,
        None => {
            discover_full_oauth_metadata(&server_config.url)
                .await?
                .token_endpoint
        }
    };

    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .build()
        .map_err(|e| AuthError::Http(e.to_string()))?;
    let params = vec![
        ("grant_type", "refresh_token".to_string()),
        ("refresh_token", refresh_token.expose().to_string()),
        ("client_id", client_id),
    ];
    let response = client
        .post(&token_url)
        .form(&params)
        .send()
        .await
        .map_err(|e| AuthError::RefreshFailed(e.to_string()))?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(AuthError::RefreshFailed(format!(
            "HTTP {} - {}",
            status, body
        )));
    }
    let token_response: TokenResponse = response
        .json()
        .await
        .map_err(|e| AuthError::RefreshFailed(format!("Invalid response: {}", e)))?;

    let mut token = AccessToken {
        access_token: token_response.access_token,
        token_type: token_response.token_type,
        expires_in: token_response.expires_in,
        refresh_token: token_response.refresh_token,
        scope: token_response.scope,
    };
    // Store exactly what the server returned: a missing refresh token must not
    // overwrite the one already in the store.
    store_tokens(secrets, user_id, server_config, &token).await?;

    // No rotation: the previous refresh token is still the usable one.
    if token.refresh_token.is_none() {
        token.refresh_token = Some(refresh_token.expose().to_string());
    }
    Ok(token)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pkce_challenge_generation() {
        let pkce = PkceChallenge::generate();
        assert!(!pkce.verifier.is_empty());
        assert!(!pkce.verifier.contains('+'));
        assert!(!pkce.verifier.contains('/'));
        assert!(!pkce.verifier.contains('='));
        assert_ne!(pkce.verifier, pkce.challenge);
        let pkce2 = PkceChallenge::generate();
        assert_ne!(pkce.verifier, pkce2.verifier);
    }

    #[test]
    fn test_pkce_challenge_is_s256_of_verifier() {
        let pkce = PkceChallenge::generate();
        // RFC 7636 §4.1 allows 43..=128 chars; 32 random bytes encode to 43.
        assert_eq!(pkce.verifier.len(), 43);
        // RFC 7636 §4.2: challenge = BASE64URL(SHA256(ASCII(verifier))).
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    #[test]
    fn test_build_authorization_url() {
        let url = build_authorization_url(
            "https://auth.example.com/authorize",
            "client-123",
            "http://localhost:9876/callback",
            &["read".to_string(), "write".to_string()],
            None,
            &HashMap::new(),
        );
        assert!(url.starts_with("https://auth.example.com/authorize?"));
        assert!(url.contains("client_id=client-123"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A9876%2Fcallback"));
        assert!(url.contains("scope=read%20write"));
        // No PKCE requested, so no challenge parameters.
        assert!(!url.contains("code_challenge"));
    }

    #[test]
    fn test_build_authorization_url_with_pkce() {
        let pkce = PkceChallenge::generate();
        let url = build_authorization_url(
            "https://auth.example.com/authorize",
            "client-123",
            "http://localhost:9876/callback",
            &[],
            Some(&pkce),
            &HashMap::new(),
        );
        assert!(url.contains(&format!("code_challenge={}", pkce.challenge)));
        assert!(url.contains("code_challenge_method=S256"));
        // Empty scope list must not produce an empty `scope=` parameter.
        assert!(!url.contains("scope="));
    }

    #[test]
    fn test_build_authorization_url_with_extra_params() {
        let mut extra = HashMap::new();
        extra.insert("owner".to_string(), "user".to_string());
        extra.insert("state".to_string(), "abc123".to_string());
        let url = build_authorization_url(
            "https://auth.example.com/authorize",
            "client-123",
            "http://localhost:9876/callback",
            &[],
            None,
            &extra,
        );
        assert!(url.contains("owner=user"));
        assert!(url.contains("state=abc123"));
    }
}