use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use super::config::AuthConfig;
use super::context::AuthContext as UnifiedAuthContext; use super::types::{AuthCredentials, AuthProvider};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
/// Registry of named authentication providers plus the authorization policy
/// from [`AuthConfig`] used for role/permission checks.
///
/// Providers are keyed by the name they report via `AuthProvider::name`. The
/// registry sits behind an async `RwLock`, so providers can be added or removed
/// through `&self` while the manager is shared.
#[derive(Debug)]
pub struct AuthManager {
    /// Authentication/authorization configuration supplied at construction.
    config: AuthConfig,
    /// Registered providers, keyed by provider name.
    providers: Arc<RwLock<HashMap<String, Arc<dyn AuthProvider>>>>,
}
impl AuthManager {
    /// Creates a manager with the given configuration and no registered providers.
    #[must_use]
    pub fn new(config: AuthConfig) -> Self {
        Self {
            config,
            providers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `provider` under the name it reports.
    ///
    /// Any provider previously registered under that name is replaced.
    pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
        let name = provider.name().to_string();
        self.providers.write().await.insert(name, provider);
    }

    /// Unregisters the provider named `name`.
    ///
    /// Returns `true` if a provider was removed.
    pub async fn remove_provider(&self, name: &str) -> bool {
        self.providers.write().await.remove(name).is_some()
    }

    /// Returns the names of all registered providers, sorted ascending.
    pub async fn list_providers(&self) -> Vec<String> {
        let mut names: Vec<String> = self.providers.read().await.keys().cloned().collect();
        // Sorted so the result does not depend on HashMap iteration order.
        names.sort_unstable();
        names
    }

    /// Authenticates `credentials` against the provider named `provider_name`.
    ///
    /// Records an auth-attempt metric for the provider. If the provider returns
    /// a context with no roles, the configured `authorization.default_roles` are
    /// assigned.
    ///
    /// # Errors
    ///
    /// Returns an error if authentication is disabled in the configuration, if
    /// no provider is registered under `provider_name`, or if the provider
    /// rejects the credentials.
    pub async fn authenticate(
        &self,
        provider_name: &str,
        credentials: AuthCredentials,
    ) -> McpResult<UnifiedAuthContext> {
        self.ensure_enabled()?;
        let provider = self.lookup_provider(provider_name).await?;
        let result = provider.authenticate(credentials).await;
        crate::auth_metrics::record_auth_attempt(provider_name, result.is_ok());
        let auth_context = result?;
        Ok(self.with_default_roles(auth_context))
    }

    /// Validates `token` and returns the resulting auth context.
    ///
    /// - With `Some(name)`, only that provider is consulted.
    /// - With `None`, every registered provider is tried in ascending name
    ///   order. The first one that accepts the token wins, and errors from
    ///   providers that reject it are discarded.
    ///
    /// Records one token-validation outcome and its duration. When no provider
    /// accepted a fallback validation, the outcome is recorded under
    /// `"unknown"`. As in [`AuthManager::authenticate`], a context with no roles
    /// receives the configured `authorization.default_roles`.
    ///
    /// # Errors
    ///
    /// Returns an error if authentication is disabled, if the named provider is
    /// not registered or rejects the token, or (fallback mode) if no provider
    /// accepts the token.
    pub async fn validate_token(
        &self,
        token: &str,
        provider_name: Option<&str>,
    ) -> McpResult<UnifiedAuthContext> {
        self.ensure_enabled()?;

        if let Some(provider_name) = provider_name {
            let provider = self.lookup_provider(provider_name).await?;
            let start = std::time::Instant::now();
            let result = provider.validate_token(token).await;
            let duration = start.elapsed().as_secs_f64();
            crate::auth_metrics::record_token_validation(provider_name, result.is_ok(), false);
            crate::auth_metrics::record_token_validation_duration(duration);
            return result.map(|auth_context| self.with_default_roles(auth_context));
        }

        let candidates = self.providers_sorted_by_name().await;
        let start = std::time::Instant::now();
        for provider in &candidates {
            if let Ok(auth_context) = provider.validate_token(token).await {
                let duration = start.elapsed().as_secs_f64();
                crate::auth_metrics::record_token_validation(provider.name(), true, false);
                crate::auth_metrics::record_token_validation_duration(duration);
                return Ok(self.with_default_roles(auth_context));
            }
        }
        let duration = start.elapsed().as_secs_f64();
        crate::auth_metrics::record_token_validation("unknown", false, false);
        crate::auth_metrics::record_token_validation_duration(duration);
        Err(McpError::internal("Token validation failed".to_string()))
    }

    /// Returns `true` if `context` grants `permission`.
    ///
    /// The permission counts as granted if it is listed directly on the context,
    /// or if any of the context's roles maps to it in
    /// `authorization.inheritance_rules`.
    #[must_use]
    pub fn check_permission(&self, context: &UnifiedAuthContext, permission: &str) -> bool {
        context.permissions.iter().any(|granted| granted == permission)
            || context.roles.iter().any(|role| {
                self.config
                    .authorization
                    .inheritance_rules
                    .get(role)
                    .is_some_and(|perms| perms.iter().any(|granted| granted == permission))
            })
    }

    /// Returns `true` if `context` carries `role` directly.
    ///
    /// Inheritance rules are not consulted.
    #[must_use]
    pub fn check_role(&self, context: &UnifiedAuthContext, role: &str) -> bool {
        context.roles.iter().any(|granted| granted == role)
    }

    /// Fails with an internal error when authentication is disabled in the config.
    fn ensure_enabled(&self) -> McpResult<()> {
        if self.config.enabled {
            Ok(())
        } else {
            Err(McpError::internal("Authentication is disabled".to_string()))
        }
    }

    /// Looks up a provider by name and returns a cloned handle.
    ///
    /// The registry lock is released before the caller awaits on the provider.
    /// Holding it across a potentially slow provider call would block
    /// `add_provider`/`remove_provider`, and, because tokio's `RwLock` is fair,
    /// any readers queued behind them.
    async fn lookup_provider(&self, name: &str) -> McpResult<Arc<dyn AuthProvider>> {
        let provider = self.providers.read().await.get(name).cloned();
        provider.ok_or_else(|| McpError::internal(format!("Provider '{name}' not found")))
    }

    /// Snapshots every registered provider, ordered by registry name.
    ///
    /// The ordering makes fallback validation deterministic. The snapshot lets
    /// the lock be released before any provider is awaited.
    async fn providers_sorted_by_name(&self) -> Vec<Arc<dyn AuthProvider>> {
        let registry = self.providers.read().await;
        let mut entries: Vec<_> = registry.iter().collect();
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        entries
            .into_iter()
            .map(|(_, provider)| Arc::clone(provider))
            .collect()
    }

    /// Assigns `authorization.default_roles` when the provider returned no roles.
    fn with_default_roles(&self, mut auth_context: UnifiedAuthContext) -> UnifiedAuthContext {
        if auth_context.roles.is_empty() {
            auth_context.roles = self.config.authorization.default_roles.clone();
        }
        auth_context
    }
}
/// Process-wide slot for the auth manager used by [`check_auth`].
///
/// Starts out empty (`None`) and is filled by [`set_global_auth_manager`].
static GLOBAL_AUTH_MANAGER: std::sync::LazyLock<tokio::sync::RwLock<Option<Arc<AuthManager>>>> =
    std::sync::LazyLock::new(Default::default);
/// Installs `manager` as the process-wide auth manager.
///
/// Any previously installed manager is replaced.
pub async fn set_global_auth_manager(manager: Arc<AuthManager>) {
    let mut slot = GLOBAL_AUTH_MANAGER.write().await;
    *slot = Some(manager);
}
/// Returns a shared handle to the process-wide auth manager.
///
/// Returns `None` if no manager has been installed yet.
pub async fn global_auth_manager() -> Option<Arc<AuthManager>> {
    let slot = GLOBAL_AUTH_MANAGER.read().await;
    slot.as_ref().cloned()
}
/// Validates `token` using the process-wide auth manager.
///
/// No provider is named, so every registered provider is tried.
///
/// # Errors
///
/// Fails if no global auth manager has been installed, or if token
/// validation fails.
pub async fn check_auth(token: &str) -> McpResult<UnifiedAuthContext> {
    let Some(manager) = global_auth_manager().await else {
        return Err(McpError::internal(
            "Authentication manager not initialized".to_string(),
        ));
    };
    manager.validate_token(token, None).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::{AuthorizationConfig, OAuth2Config, OAuth2FlowType, SecurityLevel},
        providers::ApiKeyProvider,
        types::UserInfo,
    };
    use std::collections::HashMap;

    /// API key registered with the `ApiKeyProvider` fixtures below.
    const TEST_API_KEY: &str = "test_key_abcdefghijklmnopqrstuvwxyz12";

    /// User record associated with `TEST_API_KEY`.
    fn test_user() -> UserInfo {
        UserInfo {
            id: "user123".to_string(),
            username: "testuser".to_string(),
            email: Some("test@example.com".to_string()),
            display_name: Some("Test User".to_string()),
            avatar_url: None,
            metadata: HashMap::new(),
        }
    }

    /// Config with no static providers, RBAC on, and a single default role `"user"`.
    fn test_config(enabled: bool) -> AuthConfig {
        AuthConfig {
            enabled,
            providers: vec![],
            authorization: AuthorizationConfig {
                rbac_enabled: true,
                default_roles: vec!["user".to_string()],
                inheritance_rules: HashMap::new(),
                resource_permissions: HashMap::new(),
            },
        }
    }

    /// Credentials presenting `TEST_API_KEY`.
    fn test_credentials() -> AuthCredentials {
        AuthCredentials::ApiKey {
            key: TEST_API_KEY.to_string(),
        }
    }

    #[test]
    fn test_oauth2_config() {
        let config = OAuth2Config {
            client_id: "test_client".to_string(),
            client_secret: "test_secret".to_string().into(),
            auth_url: "https://auth.example.com/oauth/authorize".to_string(),
            token_url: "https://auth.example.com/oauth/token".to_string(),
            revocation_url: None,
            redirect_uri: "http://localhost:8080/callback".to_string(),
            scopes: vec!["read".to_string(), "write".to_string()],
            flow_type: OAuth2FlowType::AuthorizationCode,
            additional_params: HashMap::new(),
            security_level: SecurityLevel::Standard,
            mcp_resource_uri: None,
            auto_resource_indicators: false,
            #[cfg(feature = "dpop")]
            dpop_config: None,
        };
        assert_eq!(config.client_id, "test_client");
        assert_eq!(config.flow_type, OAuth2FlowType::AuthorizationCode);
    }

    #[test]
    fn test_oauth2_pkce_integration() {
        let (challenge1, _verifier1) = oauth2::PkceCodeChallenge::new_random_sha256();
        let (challenge2, _verifier2) = oauth2::PkceCodeChallenge::new_random_sha256();
        assert_ne!(challenge1.as_str(), challenge2.as_str());
        assert!(!challenge1.as_str().is_empty());
        assert!(!challenge2.as_str().is_empty());
    }

    #[tokio::test]
    async fn test_api_key_provider() {
        let provider = ApiKeyProvider::new("test_api".to_string());
        provider
            .add_api_key(TEST_API_KEY.to_string(), test_user())
            .await;
        // `expect` surfaces the provider's error instead of a bare `is_ok()` failure.
        let context = provider
            .authenticate(test_credentials())
            .await
            .expect("registered API key should authenticate");
        assert_eq!(context.user.username, "testuser");
        assert_eq!(context.provider, "test_api");
    }

    #[tokio::test]
    async fn test_auth_manager() {
        let manager = AuthManager::new(test_config(true));
        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        manager.add_provider(api_provider.clone()).await;
        let providers = manager.list_providers().await;
        assert!(providers.contains(&"api".to_string()));
    }

    #[tokio::test]
    async fn test_auth_manager_authenticate_and_check_role() {
        let manager = AuthManager::new(test_config(true));
        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        api_provider
            .add_api_key(TEST_API_KEY.to_string(), test_user())
            .await;
        manager.add_provider(api_provider.clone()).await;

        let mut context = manager
            .authenticate("api", test_credentials())
            .await
            .expect("manager should route credentials to the registered provider");
        assert_eq!(context.user.username, "testuser");

        assert!(manager
            .authenticate("missing", test_credentials())
            .await
            .is_err());

        context.roles = vec!["admin".to_string()];
        assert!(manager.check_role(&context, "admin"));
        assert!(!manager.check_role(&context, "user"));
    }

    #[tokio::test]
    async fn test_auth_manager_disabled_rejects_everything() {
        let manager = AuthManager::new(test_config(false));
        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        api_provider
            .add_api_key(TEST_API_KEY.to_string(), test_user())
            .await;
        manager.add_provider(api_provider.clone()).await;

        assert!(manager
            .authenticate("api", test_credentials())
            .await
            .is_err());
        assert!(manager.validate_token(TEST_API_KEY, None).await.is_err());
    }

    #[tokio::test]
    async fn test_auth_manager_remove_provider() {
        let manager = AuthManager::new(test_config(true));
        let api_provider = Arc::new(ApiKeyProvider::new("api".to_string()));
        manager.add_provider(api_provider.clone()).await;

        assert!(manager.remove_provider("api").await);
        assert!(!manager.remove_provider("api").await);
        assert!(!manager.list_providers().await.contains(&"api".to_string()));
    }
}