use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::time::SystemTime;
use oauth2::RefreshToken;
use serde::{Deserialize, Serialize};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
use super::config::AuthProviderType;
/// Profile information about a user, as returned by [`AuthProvider::get_user_info`].
///
/// Derives `Serialize`/`Deserialize`, so it can be persisted or sent over the wire as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// Identifier of the user (presumably unique within its provider — confirm).
pub id: String,
/// Account / login name.
pub username: String,
/// Email address, if known.
pub email: Option<String>,
/// Human-readable name, if known.
pub display_name: Option<String>,
/// URL of the user's avatar image, if known.
pub avatar_url: Option<String>,
/// Additional provider-specific attributes as arbitrary JSON values.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Token set issued by an [`AuthProvider`] (e.g. the result of [`AuthProvider::refresh_token`]).
///
/// `Debug` is implemented by hand to mask the secrets, but the derived `Serialize` still
/// emits `access_token` and `refresh_token` in plaintext — do not serialize this into logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenInfo {
/// The access token secret.
pub access_token: String,
/// Token type string, e.g. `"Bearer"` (assumes OAuth 2.0 semantics — confirm).
pub token_type: String,
/// Refresh token, if one was issued.
pub refresh_token: Option<String>,
/// Lifetime of the access token; presumably seconds, mirroring OAuth 2.0 `expires_in` — TODO confirm.
pub expires_in: Option<u64>,
/// Granted scope string, if reported (its format is not enforced here).
pub scope: Option<String>,
}
impl std::fmt::Debug for TokenInfo {
    /// Formats the token set with `access_token` and `refresh_token` masked, so secrets never
    /// reach logs; the remaining fields are printed verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const MASK: &str = "[REDACTED]";
        // Keep the Some/None shape of the refresh token visible without revealing its value.
        let masked_refresh: Option<&str> = self.refresh_token.is_some().then_some(MASK);

        let mut out = f.debug_struct("TokenInfo");
        out.field("access_token", &MASK);
        out.field("token_type", &self.token_type);
        out.field("refresh_token", &masked_refresh);
        out.field("expires_in", &self.expires_in);
        out.field("scope", &self.scope);
        out.finish()
    }
}
/// Pluggable authentication backend; its kind is reported via [`AuthProviderType`].
///
/// Every async operation returns a pinned, boxed `Send` future that may borrow `self` (`'_`).
/// Boxing (instead of `impl Future`) keeps the trait dyn-compatible, so providers can be held
/// as `dyn AuthProvider`. Implementors must be `Send + Sync + Debug`.
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
/// Name of this provider instance.
fn name(&self) -> &str;
/// The kind of this provider.
fn provider_type(&self) -> AuthProviderType;
/// Authenticates using `credentials`, resolving to an `AuthContext` or an error.
fn authenticate(
&self,
credentials: AuthCredentials,
) -> Pin<Box<dyn Future<Output = McpResult<crate::context::AuthContext>> + Send + '_>>;
/// Validates `token`, resolving to the `AuthContext` it represents or an error.
fn validate_token(
&self,
token: &str,
) -> Pin<Box<dyn Future<Output = McpResult<crate::context::AuthContext>> + Send + '_>>;
/// Uses `refresh_token` to obtain a new [`TokenInfo`].
fn refresh_token(
&self,
refresh_token: &str,
) -> Pin<Box<dyn Future<Output = McpResult<TokenInfo>> + Send + '_>>;
/// Revokes `token`; resolves to `()` on success.
fn revoke_token(&self, token: &str)
-> Pin<Box<dyn Future<Output = McpResult<()>> + Send + '_>>;
/// Fetches the [`UserInfo`] associated with `token`.
fn get_user_info(
&self,
token: &str,
) -> Pin<Box<dyn Future<Output = McpResult<UserInfo>> + Send + '_>>;
}
/// Credentials presented to [`AuthProvider::authenticate`].
///
/// `Debug` is hand-written to mask secrets, but the derived `Serialize` emits every field
/// (passwords, keys, codes, tokens) in plaintext — never serialize this into logs.
#[derive(Clone, Serialize, Deserialize)]
pub enum AuthCredentials {
/// Username + password login.
UsernamePassword {
username: String,
password: String,
},
/// A static API key.
ApiKey {
key: String,
},
/// OAuth 2.0 authorization-code callback: the `code` plus the `state` value
/// (presumably checked against the issued state by the provider — confirm).
OAuth2Code {
code: String,
state: String,
},
/// A JWT presented directly.
JwtToken {
token: String,
},
/// Provider-specific credential data as arbitrary JSON values.
Custom {
data: HashMap<String, serde_json::Value>,
},
}
impl std::fmt::Debug for AuthCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthCredentials::UsernamePassword { username, .. } => f
.debug_struct("AuthCredentials::UsernamePassword")
.field("username", username)
.field("password", &"[REDACTED]")
.finish(),
AuthCredentials::ApiKey { .. } => f
.debug_struct("AuthCredentials::ApiKey")
.field("key", &"[REDACTED]")
.finish(),
AuthCredentials::OAuth2Code { state, .. } => f
.debug_struct("AuthCredentials::OAuth2Code")
.field("code", &"[REDACTED]")
.field("state", state)
.finish(),
AuthCredentials::JwtToken { .. } => f
.debug_struct("AuthCredentials::JwtToken")
.field("token", &"[REDACTED]")
.finish(),
AuthCredentials::Custom { .. } => f
.debug_struct("AuthCredentials::Custom")
.field("data", &"[REDACTED]")
.finish(),
}
}
}
/// Per-user persistence for access and refresh tokens, keyed by `user_id`.
///
/// Methods return `impl Future + Send`, so this trait is not dyn-compatible; use it through
/// generics rather than `dyn TokenStorage`.
pub trait TokenStorage: Send + Sync + std::fmt::Debug {
/// Stores `token` as the access token for `user_id` (overwrite semantics are up to the
/// implementation — confirm).
fn store_access_token(
&self,
user_id: &str,
token: &AccessToken,
) -> impl Future<Output = McpResult<()>> + Send;
/// Returns the stored access token for `user_id`, or `None` if there is none.
fn get_access_token(
&self,
user_id: &str,
) -> impl Future<Output = McpResult<Option<AccessToken>>> + Send;
/// Stores `token` as the refresh token for `user_id`.
fn store_refresh_token(
&self,
user_id: &str,
token: &RefreshToken,
) -> impl Future<Output = McpResult<()>> + Send;
/// Returns the stored refresh token for `user_id`, or `None` if there is none.
fn get_refresh_token(
&self,
user_id: &str,
) -> impl Future<Output = McpResult<Option<RefreshToken>>> + Send;
/// Revokes the tokens held for `user_id` (presumably both access and refresh — confirm).
fn revoke_tokens(&self, user_id: &str) -> impl Future<Output = McpResult<()>> + Send;
/// Lists the user ids known to this storage (presumably those with stored tokens — confirm).
fn list_users(&self) -> impl Future<Output = McpResult<Vec<String>>> + Send;
}
/// An access token together with its expiry, granted scopes and free-form metadata.
///
/// The secret is only readable through [`AccessToken::token`]; the `Debug` output masks it.
#[derive(Clone)]
pub struct AccessToken {
    pub(crate) token: String,
    pub(crate) expires_at: Option<SystemTime>,
    pub(crate) scopes: Vec<String>,
    pub(crate) metadata: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for AccessToken {
    /// Prints every field except the secret, which is replaced by `"[REDACTED]"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces a decision about its visibility here.
        let Self {
            token: _,
            expires_at,
            scopes,
            metadata,
        } = self;
        let mut out = f.debug_struct("AccessToken");
        out.field("token", &"[REDACTED]");
        out.field("expires_at", expires_at);
        out.field("scopes", scopes);
        out.field("metadata", metadata);
        out.finish()
    }
}

impl AccessToken {
    /// Builds a token from its raw parts; no validation is performed.
    #[must_use]
    pub fn new(
        token: String,
        expires_at: Option<SystemTime>,
        scopes: Vec<String>,
        metadata: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self { token, expires_at, scopes, metadata }
    }

    /// The raw token secret — the only unmasked view of it.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Expiry instant, or `None` when no expiry is set.
    #[must_use]
    pub fn expires_at(&self) -> Option<SystemTime> {
        self.expires_at
    }

    /// Scopes attached to this token.
    #[must_use]
    pub fn scopes(&self) -> &[String] {
        self.scopes.as_slice()
    }

    /// Additional metadata attached to this token.
    #[must_use]
    pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
        &self.metadata
    }
}
/// Hook for pulling a credential out of request headers and reacting to authentication failures.
///
/// Methods return `impl Future + Send`, so this trait is used via generics, not `dyn`.
pub trait AuthMiddleware: Send + Sync {
/// Extracts a raw token from `headers`, or `None` if no credential is present.
fn extract_token(
&self,
headers: &HashMap<String, String>,
) -> impl Future<Output = Option<String>> + Send;
/// Invoked with the authentication `error`; returning `Err` propagates the failure, while an
/// implementation may choose to return `Ok(())` to recover (behavior is per implementor).
fn handle_auth_failure(&self, error: McpError) -> impl Future<Output = McpResult<()>> + Send;
}
/// Default [`AuthMiddleware`]: reads a credential from common HTTP-style headers.
///
/// Lookup order:
/// 1. `Authorization: Bearer <token>`
/// 2. `Authorization: ApiKey <key>`
/// 3. `X-API-Key: <key>`
///
/// Header names and the `Bearer`/`ApiKey` scheme are matched ASCII case-insensitively, as HTTP
/// defines both to be case-insensitive (RFC 9110 §5.1, §11.1). Whitespace around the credential
/// is trimmed, and an empty credential is treated as absent.
#[derive(Debug, Clone)]
pub struct DefaultAuthMiddleware;

impl DefaultAuthMiddleware {
    /// Returns the value of a header named by `spellings`: each exact key is tried first
    /// (cheap hash lookups, preserving their priority), then an ASCII case-insensitive scan
    /// for `spellings[0]`.
    fn find_header<'a>(
        headers: &'a HashMap<String, String>,
        spellings: &[&str],
    ) -> Option<&'a str> {
        spellings
            .iter()
            .find_map(|name| headers.get(*name))
            .or_else(|| {
                let wanted = spellings.first()?;
                headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(wanted))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// If `value` starts with the auth `scheme` (ASCII case-insensitive) followed by
    /// whitespace, returns the trimmed, non-empty credential after it.
    fn strip_scheme<'a>(value: &'a str, scheme: &str) -> Option<&'a str> {
        let value = value.trim_start();
        // `get` yields `None` instead of panicking when `value` is shorter than the scheme
        // or the cut would land inside a multi-byte character.
        let head = value.get(..scheme.len())?;
        let rest = value.get(scheme.len()..)?;
        if !head.eq_ignore_ascii_case(scheme) || !rest.starts_with([' ', '\t']) {
            return None;
        }
        Some(rest.trim()).filter(|credential| !credential.is_empty())
    }

    /// Synchronously extracts the credential following the lookup order documented on the type.
    fn token_from_headers(headers: &HashMap<String, String>) -> Option<String> {
        let from_authorization = Self::find_header(headers, &["authorization", "Authorization"])
            .and_then(|value| {
                Self::strip_scheme(value, "Bearer").or_else(|| Self::strip_scheme(value, "ApiKey"))
            });
        from_authorization
            .or_else(|| {
                Self::find_header(headers, &["x-api-key", "X-API-Key"])
                    .map(str::trim)
                    .filter(|key| !key.is_empty())
            })
            .map(str::to_owned)
    }
}

impl AuthMiddleware for DefaultAuthMiddleware {
    /// Extracts a bearer token or API key from `headers` (see [`DefaultAuthMiddleware`]).
    fn extract_token(
        &self,
        headers: &HashMap<String, String>,
    ) -> impl Future<Output = Option<String>> + Send {
        // Header parsing never blocks, so resolve it eagerly: the returned future owns its
        // result and needs neither a borrow of `headers` nor a clone of the whole map.
        std::future::ready(Self::token_from_headers(headers))
    }

    /// Logs the failure at `warn` level and propagates it unchanged as `Err`.
    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()> {
        tracing::warn!("Authentication failed: {}", error);
        Err(error)
    }
}