use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Authentication strategy applied to outgoing MCP requests.
///
/// The default value is [`McpAuth::None`], i.e. no authentication at all.
#[derive(Default, Clone)]
pub enum McpAuth {
    /// No authentication; no headers are produced.
    #[default]
    None,
    /// A static token, sent as `Authorization: Bearer <token>`.
    Bearer(String),
    /// A static key sent verbatim under a caller-chosen header name.
    ApiKey {
        /// Name of the header that carries the key.
        header: String,
        /// The key value placed in that header.
        key: String,
    },
    /// OAuth2 configuration; shared through an `Arc` so that clones of this
    /// value refer to the same configuration (and its token cache).
    OAuth2(Arc<OAuth2Config>),
}
impl std::fmt::Debug for McpAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
McpAuth::None => write!(f, "McpAuth::None"),
McpAuth::Bearer(_) => write!(f, "McpAuth::Bearer([REDACTED])"),
McpAuth::ApiKey { header, .. } => write!(f, "McpAuth::ApiKey {{ header: {} }}", header),
McpAuth::OAuth2(_) => write!(f, "McpAuth::OAuth2([CONFIG])"),
}
}
}
impl McpAuth {
pub fn bearer(token: impl Into<String>) -> Self {
McpAuth::Bearer(token.into())
}
pub fn api_key(header: impl Into<String>, key: impl Into<String>) -> Self {
McpAuth::ApiKey { header: header.into(), key: key.into() }
}
pub fn oauth2(config: OAuth2Config) -> Self {
McpAuth::OAuth2(Arc::new(config))
}
pub async fn get_headers(&self) -> Result<HashMap<String, String>, AuthError> {
let mut headers = HashMap::new();
match self {
McpAuth::None => {}
McpAuth::Bearer(token) => {
headers.insert("Authorization".to_string(), format!("Bearer {}", token));
}
McpAuth::ApiKey { header, key } => {
headers.insert(header.clone(), key.clone());
}
McpAuth::OAuth2(config) => {
let token = config.get_or_refresh_token().await?;
headers.insert("Authorization".to_string(), format!("Bearer {}", token));
}
}
Ok(headers)
}
pub fn is_configured(&self) -> bool {
!matches!(self, McpAuth::None)
}
}
/// Settings for obtaining OAuth2 access tokens, plus a cache for the most
/// recently fetched token.
pub struct OAuth2Config {
    /// Client identifier sent as `client_id` in the token request.
    pub client_id: String,
    /// Optional client secret; sent as `client_secret` only when present.
    pub client_secret: Option<String>,
    /// URL the token request is POSTed to.
    pub token_url: String,
    /// Requested scopes; joined with single spaces into the `scope` parameter
    /// when non-empty.
    pub scopes: Vec<String>,
    /// Last fetched token, if any. Behind an async `RwLock` because it is
    /// read and replaced through `&self`.
    token_cache: RwLock<Option<CachedToken>>,
}
impl OAuth2Config {
pub fn new(client_id: impl Into<String>, token_url: impl Into<String>) -> Self {
Self {
client_id: client_id.into(),
client_secret: None,
token_url: token_url.into(),
scopes: Vec::new(),
token_cache: RwLock::new(None),
}
}
pub fn with_secret(mut self, secret: impl Into<String>) -> Self {
self.client_secret = Some(secret.into());
self
}
pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
self.scopes = scopes;
self
}
pub async fn get_or_refresh_token(&self) -> Result<String, AuthError> {
{
let cache = self.token_cache.read().await;
if let Some(ref cached) = *cache {
if !cached.is_expired() {
return Ok(cached.access_token.clone());
}
}
}
let token = self.fetch_token().await?;
{
let mut cache = self.token_cache.write().await;
*cache = Some(token.clone());
}
Ok(token.access_token)
}
async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
let mut params = vec![
("grant_type", "client_credentials".to_string()),
("client_id", self.client_id.clone()),
];
if let Some(ref secret) = self.client_secret {
params.push(("client_secret", secret.clone()));
}
if !self.scopes.is_empty() {
params.push(("scope", self.scopes.join(" ")));
}
#[cfg(feature = "http-transport")]
{
let client = reqwest::Client::new();
let response = client
.post(&self.token_url)
.form(¶ms)
.send()
.await
.map_err(|e| AuthError::TokenFetch(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AuthError::TokenFetch(format!(
"Token request failed: {} - {}",
status, body
)));
}
let token_response: TokenResponse =
response.json().await.map_err(|e| AuthError::TokenParse(e.to_string()))?;
Ok(CachedToken::from_response(token_response))
}
#[cfg(not(feature = "http-transport"))]
{
Err(AuthError::NotSupported("OAuth2 requires the 'http-transport' feature".to_string()))
}
}
pub async fn clear_cache(&self) {
let mut cache = self.token_cache.write().await;
*cache = None;
}
}
/// An access token held in the cache together with its expiry information.
// NOTE(review): `dead_code` is presumably allowed because `refresh_token` is
// not read anywhere in the visible code — confirm before removing.
#[derive(Clone)]
#[allow(dead_code)]
struct CachedToken {
    /// Token value handed out to callers.
    access_token: String,
    /// Instant from which the token counts as expired; `None` means no expiry
    /// is known and the token is never treated as expired.
    expires_at: Option<std::time::Instant>,
    /// Refresh token from the server response, if one was supplied.
    refresh_token: Option<String>,
}
// `from_response` is only called from code compiled with the `http-transport`
// feature, so it would be reported as dead code in builds without it.
#[allow(dead_code)]
impl CachedToken {
    /// Seconds subtracted from the server-reported lifetime so that a token is
    /// refreshed shortly before it really expires.
    const EXPIRY_MARGIN_SECS: u64 = 60;

    /// Converts a token endpoint response into a cache entry.
    ///
    /// When the response carries `expires_in`, the expiry is set to now plus
    /// that lifetime minus [`Self::EXPIRY_MARGIN_SECS`] (saturating at zero, so
    /// very short-lived tokens are expired immediately). A lifetime too large
    /// to be represented as an `Instant` is treated like a missing one — the
    /// token never expires — instead of panicking on overflow.
    fn from_response(response: TokenResponse) -> Self {
        let expires_at = response.expires_in.and_then(|secs| {
            let lifetime =
                std::time::Duration::from_secs(secs.saturating_sub(Self::EXPIRY_MARGIN_SECS));
            // `Instant + Duration` panics on overflow; `checked_add` does not.
            std::time::Instant::now().checked_add(lifetime)
        });
        Self {
            access_token: response.access_token,
            expires_at,
            refresh_token: response.refresh_token,
        }
    }

    /// Returns `true` once the stored expiry instant has been reached.
    ///
    /// A token without expiry information is never considered expired.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(expires_at) => std::time::Instant::now() >= expires_at,
            None => false,
        }
    }
}
/// Body of a successful token endpoint response, as deserialized by serde.
///
/// Only `access_token` is required; the other fields fall back to `None`
/// when absent.
// NOTE(review): `dead_code` is presumably allowed because `token_type` is not
// read anywhere in the visible code — confirm before removing.
#[derive(serde::Deserialize)]
#[allow(dead_code)]
struct TokenResponse {
    /// The issued access token.
    access_token: String,
    /// Token lifetime in seconds, if the server reported one.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Refresh token, if the server issued one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token type reported by the server, if any.
    #[serde(default)]
    token_type: Option<String>,
}
/// Errors produced while obtaining authentication credentials.
///
/// Every variant carries a human-readable detail message.
#[derive(Clone, Debug)]
pub enum AuthError {
    /// The token request could not be sent or was rejected by the server.
    TokenFetch(String),
    /// The token response could not be decoded.
    TokenParse(String),
    /// A token has expired.
    TokenExpired(String),
    /// The requested operation is unavailable in this build.
    NotSupported(String),
}
impl std::fmt::Display for AuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthError::TokenFetch(msg) => write!(f, "Token fetch failed: {}", msg),
AuthError::TokenParse(msg) => write!(f, "Token parse failed: {}", msg),
AuthError::TokenExpired(msg) => write!(f, "Token expired: {}", msg),
AuthError::NotSupported(msg) => write!(f, "Not supported: {}", msg),
}
}
}
// Marker impl using the trait's default methods: every variant carries only a
// message string, so there is no underlying source error to expose.
impl std::error::Error for AuthError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// `McpAuth::None` counts as "not configured".
    #[test]
    fn test_mcp_auth_none() {
        assert!(!McpAuth::None.is_configured());
    }

    /// A bearer token counts as configured authentication.
    #[test]
    fn test_mcp_auth_bearer() {
        let is_configured = McpAuth::bearer("test-token").is_configured();
        assert!(is_configured);
    }

    /// An API key counts as configured authentication.
    #[test]
    fn test_mcp_auth_api_key() {
        let is_configured = McpAuth::api_key("X-API-Key", "secret-key").is_configured();
        assert!(is_configured);
    }

    /// Bearer auth yields an `Authorization: Bearer <token>` header.
    #[tokio::test]
    async fn test_bearer_headers() {
        let headers = McpAuth::bearer("my-token").get_headers().await.unwrap();
        let value = headers.get("Authorization").map(String::as_str);
        assert_eq!(value, Some("Bearer my-token"));
    }

    /// API-key auth yields the key under the configured header name.
    #[tokio::test]
    async fn test_api_key_headers() {
        let headers = McpAuth::api_key("X-API-Key", "secret").get_headers().await.unwrap();
        let value = headers.get("X-API-Key").map(String::as_str);
        assert_eq!(value, Some("secret"));
    }

    /// The builder methods store the id, secret and scopes they are given.
    #[test]
    fn test_oauth2_config() {
        let requested_scopes = vec!["read".to_string(), "write".to_string()];
        let config = OAuth2Config::new("client-id", "https://auth.example.com/token")
            .with_secret("client-secret")
            .with_scopes(requested_scopes);

        assert_eq!(config.client_id, "client-id");
        assert_eq!(config.client_secret.as_deref(), Some("client-secret"));
        assert_eq!(config.scopes, ["read", "write"]);
    }
}