use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "dpop")]
use std::time::Duration;
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use turbomcp_protocol::{Error as McpError, Result as McpResult};
#[cfg(feature = "dpop")]
use super::dpop::DpopAlgorithm;
/// Top-level authentication configuration.
///
/// Derives serde `Serialize`/`Deserialize`. No field carries a serde default,
/// so all three must be present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
/// On/off flag. Presumably toggles authentication as a whole — confirm against the consumer.
pub enabled: bool,
/// Configured providers, one [`AuthProviderConfig`] entry each.
pub providers: Vec<AuthProviderConfig>,
/// Authorization settings (see [`AuthorizationConfig`]).
pub authorization: AuthorizationConfig,
}
/// Configuration entry describing a single authentication provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProviderConfig {
/// Provider name.
pub name: String,
/// Kind of provider (see [`AuthProviderType`]).
pub provider_type: AuthProviderType,
/// Free-form settings as a string-keyed map of JSON values.
/// The expected keys are not defined by this type.
pub settings: HashMap<String, serde_json::Value>,
/// Per-provider on/off flag.
pub enabled: bool,
/// Numeric priority.
/// NOTE(review): whether higher or lower values win is decided by the consumer — confirm.
pub priority: u32,
}
/// Kinds of authentication provider an [`AuthProviderConfig`] can declare.
///
/// No serde renaming is applied, so variants (de)serialize under their Rust names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AuthProviderType {
/// OAuth 2.0 provider.
OAuth2,
/// API-key provider.
ApiKey,
/// JWT provider.
Jwt,
/// Any other provider kind; its handling is defined outside this enum.
Custom,
}
/// Security level selector; [`SecurityLevel::Standard`] is the `Default`.
///
/// What each level actually enables is decided by the code consuming this
/// value, not by this enum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum SecurityLevel {
/// Default level.
#[default]
Standard,
/// Level named "enhanced"; presumably stricter than `Standard` — confirm against consumers.
Enhanced,
/// Level named "maximum"; presumably the strictest — confirm against consumers.
Maximum,
}
/// DPoP configuration; only compiled when the `dpop` feature is enabled.
///
/// On deserialization only `key_algorithm` is required; the other fields fall
/// back to the defaults named in their serde attributes.
#[cfg(feature = "dpop")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DpopConfig {
/// Key algorithm (see `DpopAlgorithm`); no serde default, so it must be present.
pub key_algorithm: DpopAlgorithm,
/// Proof lifetime; taken from `default_proof_lifetime` when absent from the input.
#[serde(default = "default_proof_lifetime")]
pub proof_lifetime: Duration,
/// Clock skew tolerance; taken from `default_clock_skew` when absent from the input.
#[serde(default = "default_clock_skew")]
pub clock_skew_tolerance: Duration,
/// Key storage selection; falls back to `DpopKeyStorageConfig::default()` when absent.
#[serde(default)]
pub key_storage: DpopKeyStorageConfig,
}
#[cfg(feature = "dpop")]
/// Serde fallback for `DpopConfig::proof_lifetime`: a 60-second [`Duration`].
fn default_proof_lifetime() -> Duration {
    const PROOF_LIFETIME_SECS: u64 = 60;
    Duration::from_secs(PROOF_LIFETIME_SECS)
}
#[cfg(feature = "dpop")]
/// Serde fallback for `DpopConfig::clock_skew_tolerance`: a 300-second [`Duration`].
fn default_clock_skew() -> Duration {
    const CLOCK_SKEW_SECS: u64 = 300;
    Duration::from_secs(CLOCK_SKEW_SECS)
}
/// Selects where DPoP keys are stored; only compiled with the `dpop` feature.
///
/// [`DpopKeyStorageConfig::Memory`] is the `Default`.
#[cfg(feature = "dpop")]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum DpopKeyStorageConfig {
/// In-memory storage (the default variant).
#[default]
Memory,
/// Redis-backed storage.
Redis {
/// Redis URL; the accepted format is determined by the storage implementation.
url: String,
},
/// HSM-backed storage.
Hsm {
/// Opaque JSON configuration; its schema is not defined by this enum.
config: serde_json::Value,
},
}
/// Default DPoP configuration: `ES256` keys, the same durations serde falls
/// back to, and the default key storage.
#[cfg(feature = "dpop")]
impl Default for DpopConfig {
    /// Assembles the default configuration from the shared default helpers so
    /// that `Default` and serde's per-field fallbacks cannot drift apart.
    fn default() -> Self {
        let proof_lifetime = default_proof_lifetime();
        let clock_skew_tolerance = default_clock_skew();
        let key_storage = DpopKeyStorageConfig::default();

        DpopConfig {
            key_algorithm: DpopAlgorithm::ES256,
            proof_lifetime,
            clock_skew_tolerance,
            key_storage,
        }
    }
}
/// Authorization settings embedded in [`AuthConfig`].
///
/// All fields are required when deserializing (none has a serde default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationConfig {
/// RBAC on/off flag.
pub rbac_enabled: bool,
/// Role names applied by default. When and to whom they apply is decided by the consumer.
pub default_roles: Vec<String>,
/// Map from a name to a list of names.
/// NOTE(review): presumably role -> inherited roles — confirm direction against the consumer.
pub inheritance_rules: HashMap<String, Vec<String>>,
/// Map from a name to a list of names.
/// NOTE(review): presumably resource -> permissions — confirm against the consumer.
pub resource_permissions: HashMap<String, Vec<String>>,
}
/// OAuth 2.0 client configuration.
///
/// Fields without a `#[serde(default…)]` attribute are required when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Config {
/// Client identifier.
pub client_id: String,
/// Client secret, held in a `SecretString`.
///
/// Serialization goes through `serialize_secret`, which writes a placeholder
/// instead of the real value, so serialized output cannot be used to restore
/// the secret. Deserialization goes through `deserialize_secret`.
#[serde(
serialize_with = "serialize_secret",
deserialize_with = "deserialize_secret"
)]
pub client_secret: SecretString,
/// Authorization endpoint URL.
pub auth_url: String,
/// Token endpoint URL.
pub token_url: String,
/// Optional revocation endpoint URL; `None` when absent from the input.
#[serde(default)]
pub revocation_url: Option<String>,
/// Redirect URI.
pub redirect_uri: String,
/// Scopes, as a list of strings.
pub scopes: Vec<String>,
/// Which OAuth 2.0 flow to use (see [`OAuth2FlowType`]).
pub flow_type: OAuth2FlowType,
/// Extra string key/value parameters; where they are sent is decided by the consumer.
pub additional_params: HashMap<String, String>,
/// Security level; defaults to `SecurityLevel::default()` when absent.
#[serde(default)]
pub security_level: SecurityLevel,
/// Optional DPoP configuration; the field only exists with the `dpop` feature.
#[cfg(feature = "dpop")]
#[serde(default)]
pub dpop_config: Option<DpopConfig>,
/// Optional MCP resource URI; `None` when absent from the input.
#[serde(default)]
pub mcp_resource_uri: Option<String>,
/// Flag that defaults to `true` (via `default_auto_resource_indicators`) when absent.
/// NOTE(review): presumably controls automatic resource-indicator handling — confirm.
#[serde(default = "default_auto_resource_indicators")]
pub auto_resource_indicators: bool,
}
/// Serde `serialize_with` helper that never emits the real secret.
///
/// An empty secret is written as the empty string; any non-empty secret is
/// written as the literal placeholder `[REDACTED]`.
fn serialize_secret<S>(secret: &SecretString, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    // Only the emptiness of the secret is inspected; its content is never copied out.
    let placeholder = if secret.expose_secret().is_empty() {
        ""
    } else {
        "[REDACTED]"
    };
    serializer.serialize_str(placeholder)
}
/// Serde `deserialize_with` helper: reads a plain string and wraps it in a
/// `SecretString`.
///
/// Any error from the underlying string deserialization is returned unchanged.
fn deserialize_secret<'de, D>(deserializer: D) -> Result<SecretString, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw = <String as serde::Deserialize>::deserialize(deserializer)?;
    let secret = SecretString::new(raw.into());
    Ok(secret)
}
/// Serde fallback for `OAuth2Config::auto_resource_indicators`: the flag is
/// `true` unless the input says otherwise.
fn default_auto_resource_indicators() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
/// OAuth 2.0 flow selector used by [`OAuth2Config`].
///
/// No serde renaming is applied, so variants (de)serialize under their Rust names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum OAuth2FlowType {
/// Authorization code flow.
AuthorizationCode,
/// Client credentials flow.
ClientCredentials,
/// Device code flow.
DeviceCode,
/// Implicit flow.
Implicit,
}
/// Result of starting an OAuth 2.0 authorization.
///
/// Which optional fields are populated for which flow is decided by the code
/// that constructs this value (not visible here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2AuthResult {
/// Authorization URL.
pub auth_url: String,
/// `state` value associated with the request.
pub state: String,
/// Optional code verifier (presumably the PKCE verifier — confirm against the producer).
pub code_verifier: Option<String>,
/// Optional device code (presumably set for the device flow only — confirm).
pub device_code: Option<String>,
/// Optional user code (presumably set for the device flow only — confirm).
pub user_code: Option<String>,
/// Optional verification URI (presumably set for the device flow only — confirm).
pub verification_uri: Option<String>,
}
/// Metadata describing one protected resource.
///
/// `Option` fields that are `None` are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
/// Resource identifier (a URI string).
pub resource: String,
/// Authorization server associated with the resource.
pub authorization_server: String,
/// Scopes supported by the resource, if declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub scopes_supported: Option<Vec<String>>,
/// Bearer token transport methods supported, if declared.
#[serde(skip_serializing_if = "Option::is_none")]
pub bearer_methods_supported: Option<Vec<BearerTokenMethod>>,
/// Documentation reference for the resource, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub resource_documentation: Option<String>,
/// Extra metadata. Because of `#[serde(flatten)]` these entries are written
/// as sibling keys of the fields above, and keys not matching a named field
/// are collected here on deserialization.
#[serde(flatten)]
pub additional_metadata: HashMap<String, serde_json::Value>,
}
/// Ways a bearer token can be transported.
///
/// Serialized in lowercase (`header`, `query`, `body`) because of
/// `rename_all = "lowercase"`; [`BearerTokenMethod::Header`] is the `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum BearerTokenMethod {
/// Token carried in a header (default variant).
#[default]
Header,
/// Token carried in the query.
Query,
/// Token carried in the body.
Body,
}
/// Registry of protected resources and their metadata.
///
/// The map sits behind `Arc<RwLock<…>>`, so clones of the registry share the
/// same underlying map.
#[derive(Debug, Clone)]
pub struct McpResourceRegistry {
/// Registered metadata, keyed by the full resource URI built in `register_resource`.
resources: Arc<RwLock<HashMap<String, ProtectedResourceMetadata>>>,
/// Authorization server written into every newly registered resource.
default_auth_server: String,
/// Prefix joined with a resource id to form the full resource URI.
base_resource_uri: String,
}
impl McpResourceRegistry {
    /// Creates an empty registry.
    ///
    /// `base_resource_uri` is the prefix used to build resource URIs and
    /// `auth_server` is recorded as the authorization server of every resource
    /// registered later.
    #[must_use]
    pub fn new(base_resource_uri: String, auth_server: String) -> Self {
        let resources = Arc::new(RwLock::new(HashMap::new()));
        Self {
            resources,
            default_auth_server: auth_server,
            base_resource_uri,
        }
    }

    /// Registers (or replaces) the metadata for `resource_id`.
    ///
    /// The resource URI is the base URI with trailing `/` characters removed,
    /// followed by `/` and `resource_id`. The stored metadata advertises the
    /// `Header` and `Body` bearer methods and has no additional metadata.
    /// An existing entry under the same URI is overwritten.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    pub async fn register_resource(
        &self,
        resource_id: &str,
        scopes: Vec<String>,
        documentation: Option<String>,
    ) -> McpResult<()> {
        let base = self.base_resource_uri.trim_end_matches('/');
        let resource_uri = format!("{}/{}", base, resource_id);

        let bearer_methods = vec![BearerTokenMethod::Header, BearerTokenMethod::Body];
        let metadata = ProtectedResourceMetadata {
            resource: resource_uri.clone(),
            authorization_server: self.default_auth_server.clone(),
            scopes_supported: Some(scopes),
            bearer_methods_supported: Some(bearer_methods),
            resource_documentation: documentation,
            additional_metadata: HashMap::new(),
        };

        let mut registry = self.resources.write().await;
        registry.insert(resource_uri, metadata);
        Ok(())
    }

    /// Returns a clone of the metadata stored under `resource_uri`, or `None`
    /// when nothing is registered under that exact URI.
    pub async fn get_resource_metadata(
        &self,
        resource_uri: &str,
    ) -> Option<ProtectedResourceMetadata> {
        let registry = self.resources.read().await;
        registry.get(resource_uri).cloned()
    }

    /// Returns the URIs of all registered resources (in the map's iteration order).
    pub async fn list_resources(&self) -> Vec<String> {
        let registry = self.resources.read().await;
        registry.keys().cloned().collect()
    }

    /// Returns a snapshot (clone) of the whole URI -> metadata map.
    pub async fn generate_well_known_metadata(&self) -> HashMap<String, ProtectedResourceMetadata> {
        let registry = self.resources.read().await;
        (*registry).clone()
    }

    /// Checks whether `token_scopes` contains at least one scope supported by
    /// the resource.
    ///
    /// Returns `Ok(true)` when the resource declares no scope list at all, and
    /// otherwise `Ok(true)` only if some supported scope is present in
    /// `token_scopes`. A declared but empty scope list therefore yields
    /// `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-params error when `resource_uri` is not registered.
    pub async fn validate_scope_for_resource(
        &self,
        resource_uri: &str,
        token_scopes: &[String],
    ) -> McpResult<bool> {
        match self.get_resource_metadata(resource_uri).await {
            None => Err(McpError::invalid_params(format!(
                "Unknown resource: {}",
                resource_uri
            ))),
            Some(metadata) => {
                let allowed = match metadata.scopes_supported {
                    None => true,
                    Some(supported) => supported
                        .iter()
                        .any(|candidate| token_scopes.contains(candidate)),
                };
                Ok(allowed)
            }
        }
    }
}
/// Client metadata sent to a dynamic client registration endpoint.
///
/// Every field is optional and is omitted from serialized output when `None`.
/// The field names appear to follow the client metadata names of RFC 7591 /
/// OpenID Connect dynamic registration — confirm before relying on exact
/// spec conformance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationRequest {
/// Redirect URIs for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub redirect_uris: Option<Vec<String>>,
/// Response types the client will use.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_types: Option<Vec<String>>,
/// Grant types the client will use.
#[serde(skip_serializing_if = "Option::is_none")]
pub grant_types: Option<Vec<String>>,
/// Application type (see [`ApplicationType`]).
#[serde(skip_serializing_if = "Option::is_none")]
pub application_type: Option<ApplicationType>,
/// Client name.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
/// Client URI.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_uri: Option<String>,
/// Logo URI.
#[serde(skip_serializing_if = "Option::is_none")]
pub logo_uri: Option<String>,
/// Requested scope, as a single string.
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
/// Contact entries.
#[serde(skip_serializing_if = "Option::is_none")]
pub contacts: Option<Vec<String>>,
/// Terms-of-service URI.
#[serde(skip_serializing_if = "Option::is_none")]
pub tos_uri: Option<String>,
/// Policy URI.
#[serde(skip_serializing_if = "Option::is_none")]
pub policy_uri: Option<String>,
/// Software identifier.
#[serde(skip_serializing_if = "Option::is_none")]
pub software_id: Option<String>,
/// Software version.
#[serde(skip_serializing_if = "Option::is_none")]
pub software_version: Option<String>,
}
/// Successful response from a dynamic client registration endpoint.
///
/// Only `client_id` is required; every other field is optional and is omitted
/// from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationResponse {
/// Identifier issued to the client.
pub client_id: String,
/// Client secret, if one was issued.
/// NOTE(review): stored as a plain `String`, so it appears in the derived
/// `Debug` output and in serialized form — consider whether that is intended.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret: Option<String>,
/// Registration access token, if provided.
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_access_token: Option<String>,
/// Registration client URI, if provided.
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_client_uri: Option<String>,
/// Issue time of the client id (presumably seconds since the Unix epoch — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id_issued_at: Option<i64>,
/// Expiry time of the client secret (presumably seconds since the Unix epoch — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub client_secret_expires_at: Option<i64>,
/// Redirect URIs registered for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub redirect_uris: Option<Vec<String>>,
/// Response types registered for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_types: Option<Vec<String>>,
/// Grant types registered for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub grant_types: Option<Vec<String>>,
/// Application type registered for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub application_type: Option<ApplicationType>,
/// Client name registered for the client.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_name: Option<String>,
/// Scope registered for the client, as a single string.
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
}
/// Application type used in client registration.
///
/// Serialized in lowercase (`web`, `native`) because of
/// `rename_all = "lowercase"`; [`ApplicationType::Web`] is the `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum ApplicationType {
/// Web application (default variant).
#[default]
Web,
/// Native application.
Native,
}
/// Error body returned by a client registration endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientRegistrationError {
/// Error code (see [`ClientRegistrationErrorCode`]); required.
pub error: ClientRegistrationErrorCode,
/// Optional human-readable description; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error_description: Option<String>,
}
/// Error codes for client registration failures.
///
/// Serialized in `snake_case` (e.g. `invalid_redirect_uri`). Any other code
/// string fails to deserialize, since there is no catch-all variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ClientRegistrationErrorCode {
/// Wire value `invalid_redirect_uri`.
InvalidRedirectUri,
/// Wire value `invalid_client_metadata`.
InvalidClientMetadata,
/// Wire value `invalid_software_statement`.
InvalidSoftwareStatement,
/// Wire value `unapproved_software_statement`.
UnapprovedSoftwareStatement,
}
/// HTTP client for a dynamic client registration endpoint.
///
/// The `default_*` fields are used by `register_client` to fill in request
/// fields the caller left as `None`.
#[derive(Debug, Clone)]
pub struct DynamicClientRegistration {
/// URL that registration requests are POSTed to.
registration_endpoint: String,
/// Application type used when the request does not specify one.
default_application_type: ApplicationType,
/// Grant types used when the request does not specify any.
default_grant_types: Vec<String>,
/// Response types used when the request does not specify any.
default_response_types: Vec<String>,
/// HTTP client used to send registration requests.
client: reqwest::Client,
}
impl DynamicClientRegistration {
    /// Creates a registration client for `registration_endpoint`.
    ///
    /// Defaults applied to incomplete requests: application type `Web`, grant
    /// type `authorization_code`, response type `code`. A fresh
    /// `reqwest::Client` is created for the HTTP calls.
    #[must_use]
    pub fn new(registration_endpoint: String) -> Self {
        Self {
            registration_endpoint,
            default_application_type: ApplicationType::Web,
            default_grant_types: vec!["authorization_code".to_string()],
            default_response_types: vec!["code".to_string()],
            client: reqwest::Client::new(),
        }
    }

    /// Sends `request` to the registration endpoint and returns the parsed
    /// response.
    ///
    /// `application_type`, `grant_types` and `response_types` are filled from
    /// this client's defaults when the caller left them as `None`; all other
    /// fields are sent as given.
    ///
    /// # Errors
    ///
    /// Returns an invalid-params error when:
    /// - the HTTP request cannot be sent,
    /// - a success response cannot be parsed as [`ClientRegistrationResponse`],
    /// - a non-success response cannot be parsed as [`ClientRegistrationError`],
    /// - the server answers with a non-success status (the message then carries
    ///   the error code name and the description, if any).
    pub async fn register_client(
        &self,
        request: ClientRegistrationRequest,
    ) -> McpResult<ClientRegistrationResponse> {
        let mut registration_request = request;
        if registration_request.application_type.is_none() {
            registration_request.application_type = Some(self.default_application_type.clone());
        }
        if registration_request.grant_types.is_none() {
            registration_request.grant_types = Some(self.default_grant_types.clone());
        }
        if registration_request.response_types.is_none() {
            registration_request.response_types = Some(self.default_response_types.clone());
        }

        let response = self
            .client
            .post(&self.registration_endpoint)
            .header("Content-Type", "application/json")
            // Borrow the request: the original text here was a mangled `&reg…` token.
            .json(&registration_request)
            .send()
            .await
            .map_err(|e| McpError::invalid_params(format!("Registration request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_response: ClientRegistrationError = response
                .json()
                .await
                .map_err(|e| McpError::invalid_params(format!("Invalid error response: {}", e)))?;
            // Report the error code by name; casting the enum to an integer only
            // exposed its declaration index, which means nothing to the reader.
            return Err(McpError::invalid_params(format!(
                "Client registration failed: {:?} - {}",
                error_response.error,
                error_response.error_description.unwrap_or_default()
            )));
        }

        response
            .json()
            .await
            .map_err(|e| McpError::invalid_params(format!("Invalid registration response: {}", e)))
    }

    /// Builds a registration request pre-filled for an MCP client.
    ///
    /// The request uses the authorization-code grant with response type
    /// `code`, application type `Web`, the client name `MCP Client: <name>`,
    /// `mcp_server_uri` as the client URI, a fixed MCP scope string, software
    /// id `turbomcp` and this crate's version as the software version.
    /// `logo_uri`, `contacts`, `tos_uri` and `policy_uri` are left unset.
    #[must_use]
    pub fn create_mcp_client_request(
        client_name: &str,
        redirect_uris: Vec<String>,
        mcp_server_uri: &str,
    ) -> ClientRegistrationRequest {
        ClientRegistrationRequest {
            redirect_uris: Some(redirect_uris),
            response_types: Some(vec!["code".to_string()]),
            grant_types: Some(vec!["authorization_code".to_string()]),
            application_type: Some(ApplicationType::Web),
            client_name: Some(format!("MCP Client: {}", client_name)),
            client_uri: Some(mcp_server_uri.to_string()),
            scope: Some(
                "mcp:tools:read mcp:tools:execute mcp:resources:read mcp:prompts:read".to_string(),
            ),
            software_id: Some("turbomcp".to_string()),
            software_version: Some(env!("CARGO_PKG_VERSION").to_string()),
            logo_uri: None,
            contacts: None,
            tos_uri: None,
            policy_uri: None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub verification_uri_complete: Option<String>,
pub expires_in: u64,
pub interval: u64,
}
/// Provider-specific settings (not serde-enabled: only `Debug` and `Clone` are derived).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
/// Which provider this configuration is for (see [`ProviderType`]).
pub provider_type: ProviderType,
/// Scopes used by default for this provider.
pub default_scopes: Vec<String>,
/// Token refresh strategy (see [`RefreshBehavior`]).
pub refresh_behavior: RefreshBehavior,
/// Optional user-info endpoint.
pub userinfo_endpoint: Option<String>,
/// Extra string key/value parameters; where they are sent is decided by the consumer.
pub additional_params: HashMap<String, String>,
}
/// Identity providers a [`ProviderConfig`] can target.
///
/// Derives `PartialEq` (but not `Eq` or `Hash`), so values can be compared but
/// not used as hash-map keys.
#[derive(Debug, Clone, PartialEq)]
pub enum ProviderType {
/// Google.
Google,
/// Microsoft.
Microsoft,
/// GitHub.
GitHub,
/// GitLab.
GitLab,
/// Apple.
Apple,
/// Okta.
Okta,
/// Auth0.
Auth0,
/// Keycloak.
Keycloak,
/// Generic provider with no provider-specific identity.
Generic,
/// Custom provider carrying a caller-supplied string (presumably its name — confirm).
Custom(String),
}
/// Token refresh strategy selector used by [`ProviderConfig`].
///
/// The behaviour behind each variant is implemented by consuming code, not
/// by this enum.
#[derive(Debug, Clone)]
pub enum RefreshBehavior {
/// Presumably refreshes ahead of expiry — confirm against the consumer.
Proactive,
/// Presumably refreshes only after a failure or expiry is observed — confirm against the consumer.
Reactive,
/// Custom strategy; its meaning is defined by the consumer.
Custom,
}