use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::time::SystemTime;
use oauth2::RefreshToken;
use serde::{Deserialize, Serialize};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
use super::config::AuthProviderType;
/// User details as returned by [`AuthProvider::get_user_info`].
///
/// All fields are (de)serialized by serde under their Rust names, in declaration order.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct UserInfo {
    /// Identifier of the user (presumably provider-assigned — confirm per provider).
    pub id: String,
    /// Login / handle of the user.
    pub username: String,
    /// E-mail address, when the provider reports one.
    pub email: Option<String>,
    /// Human-readable name, when available.
    pub display_name: Option<String>,
    /// URL of the user's avatar image, when available.
    pub avatar_url: Option<String>,
    /// Additional attributes as arbitrary JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A token response: access token plus optional refresh token, lifetime and scope.
///
/// `Debug` is implemented by hand for this type so that `access_token` and
/// `refresh_token` are redacted; the derived `Serialize` writes them verbatim.
#[derive(Clone)]
#[derive(Serialize, Deserialize)]
pub struct TokenInfo {
    /// The access token string (secret).
    pub access_token: String,
    /// Token type as reported by the issuer (e.g. `"Bearer"` — not validated here).
    pub token_type: String,
    /// Refresh token, if one was issued (secret).
    pub refresh_token: Option<String>,
    /// Lifetime of the access token in seconds, counted from `issued_at`.
    pub expires_in: Option<u64>,
    /// Granted scope string, if reported.
    pub scope: Option<String>,
    /// Issue time, encoded as milliseconds since the Unix epoch; `None` when the
    /// field is missing from the input.
    #[serde(with = "system_time_millis", default)]
    pub issued_at: Option<SystemTime>,
}
impl TokenInfo {
    /// Absolute expiry time: `issued_at + expires_in` seconds.
    ///
    /// Returns `None` if either field is missing or the sum is not representable
    /// as a `SystemTime`.
    #[must_use]
    pub fn expires_at(&self) -> Option<SystemTime> {
        match (self.issued_at, self.expires_in) {
            (Some(issued), Some(lifetime_secs)) => {
                issued.checked_add(std::time::Duration::from_secs(lifetime_secs))
            }
            _ => None,
        }
    }

    /// Reports whether the token is expired, treating it as expired `skew` early.
    ///
    /// * No known expiry (see [`TokenInfo::expires_at`]) → `false`.
    /// * `expiry - skew` not representable → `true` (conservatively expired).
    /// * Otherwise → `true` once the current time reaches `expiry - skew`.
    #[must_use]
    pub fn is_expired_with_skew(&self, skew: std::time::Duration) -> bool {
        let Some(expiry) = self.expires_at() else {
            return false;
        };
        expiry
            .checked_sub(skew)
            .map(|deadline| SystemTime::now() >= deadline)
            .unwrap_or(true)
    }

    /// [`TokenInfo::is_expired_with_skew`] with a fixed 60-second skew.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        const DEFAULT_SKEW: std::time::Duration = std::time::Duration::from_secs(60);
        self.is_expired_with_skew(DEFAULT_SKEW)
    }
}
impl std::fmt::Debug for TokenInfo {
    /// Debug-formats the token info with `access_token` and (if present)
    /// `refresh_token` replaced by `"[REDACTED]"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        // Keep the Some/None distinction visible without leaking the value.
        let refresh_shown = self.refresh_token.as_ref().map(|_| REDACTED);

        let mut builder = f.debug_struct("TokenInfo");
        builder.field("access_token", &REDACTED);
        builder.field("token_type", &self.token_type);
        builder.field("refresh_token", &refresh_shown);
        builder.field("expires_in", &self.expires_in);
        builder.field("issued_at", &self.issued_at);
        builder.field("scope", &self.scope);
        builder.finish()
    }
}
mod system_time_millis {
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Deserializer, Serializer};
pub fn serialize<S: Serializer>(value: &Option<SystemTime>, ser: S) -> Result<S::Ok, S::Error> {
match value {
Some(t) => {
let millis = t
.duration_since(UNIX_EPOCH)
.map_err(serde::ser::Error::custom)?
.as_millis() as u64;
ser.serialize_some(&millis)
}
None => ser.serialize_none(),
}
}
pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Option<SystemTime>, D::Error> {
Option::<u64>::deserialize(de)
.map(|opt| opt.map(|millis| UNIX_EPOCH + Duration::from_millis(millis)))
}
}
/// A pluggable authentication backend.
///
/// Every async operation returns a boxed, pinned, `Send` future borrowing `self`
/// (rather than `impl Future`), so the trait stays object-safe and can be used as
/// `dyn AuthProvider`. Implementations must also be `Send + Sync + Debug`.
pub trait AuthProvider: Send + Sync + std::fmt::Debug {
    /// Name of this provider instance.
    fn name(&self) -> &str;

    /// Kind of provider, as configured via [`AuthProviderType`].
    fn provider_type(&self) -> AuthProviderType;

    /// Resolves `credentials` into an auth context.
    fn authenticate(
        &self,
        credentials: AuthCredentials,
    ) -> Pin<Box<dyn Future<Output = McpResult<crate::context::AuthContext>> + Send + '_>>;

    /// Checks a presented token and resolves it into an auth context.
    fn validate_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<crate::context::AuthContext>> + Send + '_>>;

    /// Uses `refresh_token` to obtain a new [`TokenInfo`].
    fn refresh_token(
        &self,
        refresh_token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<TokenInfo>> + Send + '_>>;

    /// Revokes `token`.
    fn revoke_token(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<()>> + Send + '_>>;

    /// Fetches [`UserInfo`] for the owner of `token`.
    fn get_user_info(
        &self,
        token: &str,
    ) -> Pin<Box<dyn Future<Output = McpResult<UserInfo>> + Send + '_>>;
}
/// Credentials handed to [`AuthProvider::authenticate`].
///
/// `Debug` is implemented by hand for this type to redact secrets, but the derived
/// `Serialize` writes every field — passwords, keys, codes, tokens — verbatim.
/// Only serialize values of this type into trusted storage or channels.
#[derive(Clone)]
#[derive(Serialize, Deserialize)]
pub enum AuthCredentials {
    /// A username and password pair.
    UsernamePassword { username: String, password: String },
    /// A single API key.
    ApiKey { key: String },
    /// An OAuth 2.0 authorization code together with its `state` value.
    OAuth2Code { code: String, state: String },
    /// A JWT presented directly.
    JwtToken { token: String },
    /// Free-form, provider-specific credential data.
    Custom { data: HashMap<String, serde_json::Value> },
}
impl std::fmt::Debug for AuthCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AuthCredentials::UsernamePassword { username, .. } => f
.debug_struct("AuthCredentials::UsernamePassword")
.field("username", username)
.field("password", &"[REDACTED]")
.finish(),
AuthCredentials::ApiKey { .. } => f
.debug_struct("AuthCredentials::ApiKey")
.field("key", &"[REDACTED]")
.finish(),
AuthCredentials::OAuth2Code { state, .. } => f
.debug_struct("AuthCredentials::OAuth2Code")
.field("code", &"[REDACTED]")
.field("state", state)
.finish(),
AuthCredentials::JwtToken { .. } => f
.debug_struct("AuthCredentials::JwtToken")
.field("token", &"[REDACTED]")
.finish(),
AuthCredentials::Custom { .. } => f
.debug_struct("AuthCredentials::Custom")
.field("data", &"[REDACTED]")
.finish(),
}
}
}
/// Storage backend for per-user access and refresh tokens.
///
/// The per-method descriptions are the contract suggested by the method names;
/// concrete backends define the exact semantics. Because methods return
/// `impl Future` (return-position impl trait in a trait), this trait is not
/// object-safe: use it through generics, not `dyn TokenStorage`.
pub trait TokenStorage: Send + Sync + std::fmt::Debug {
    /// Saves `token` as the access token for `user_id`.
    fn store_access_token(
        &self,
        user_id: &str,
        token: &AccessToken,
    ) -> impl Future<Output = McpResult<()>> + Send;

    /// Loads the access token for `user_id`, or `None` if there is none.
    fn get_access_token(
        &self,
        user_id: &str,
    ) -> impl Future<Output = McpResult<Option<AccessToken>>> + Send;

    /// Saves `token` as the refresh token for `user_id`.
    fn store_refresh_token(
        &self,
        user_id: &str,
        token: &RefreshToken,
    ) -> impl Future<Output = McpResult<()>> + Send;

    /// Loads the refresh token for `user_id`, or `None` if there is none.
    fn get_refresh_token(
        &self,
        user_id: &str,
    ) -> impl Future<Output = McpResult<Option<RefreshToken>>> + Send;

    /// Removes or invalidates the stored tokens of `user_id`.
    fn revoke_tokens(
        &self,
        user_id: &str,
    ) -> impl Future<Output = McpResult<()>> + Send;

    /// Lists the user ids known to this storage.
    fn list_users(&self) -> impl Future<Output = McpResult<Vec<String>>> + Send;
}
/// An access token plus its optional expiry, granted scopes and extra metadata.
///
/// Fields are crate-private; use the accessor methods from outside the crate.
/// The hand-written `Debug` impl for this type redacts `token`.
#[derive(Clone)]
pub struct AccessToken {
    /// The raw token string (secret).
    pub(crate) token: String,
    /// Absolute expiry time, if known.
    pub(crate) expires_at: Option<SystemTime>,
    /// Scopes associated with the token.
    pub(crate) scopes: Vec<String>,
    /// Arbitrary extra data as JSON values.
    pub(crate) metadata: HashMap<String, serde_json::Value>,
}
impl std::fmt::Debug for AccessToken {
    /// Debug-formats the token with the token string replaced by `"[REDACTED]"`;
    /// expiry, scopes and metadata are printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited,
        // so a new secret cannot leak into Debug output unnoticed.
        let Self {
            token: _,
            expires_at,
            scopes,
            metadata,
        } = self;
        f.debug_struct("AccessToken")
            .field("token", &"[REDACTED]")
            .field("expires_at", expires_at)
            .field("scopes", scopes)
            .field("metadata", metadata)
            .finish()
    }
}
impl AccessToken {
    /// Assembles an access token from its parts; no validation is performed.
    #[must_use]
    pub fn new(
        token: String,
        expires_at: Option<SystemTime>,
        scopes: Vec<String>,
        metadata: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self { token, expires_at, scopes, metadata }
    }

    /// The raw token string.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Absolute expiry time, if one was recorded.
    #[must_use]
    pub fn expires_at(&self) -> Option<SystemTime> {
        self.expires_at
    }

    /// Scopes associated with this token.
    #[must_use]
    pub fn scopes(&self) -> &[String] {
        self.scopes.as_slice()
    }

    /// Extra metadata attached to this token.
    #[must_use]
    pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
        &self.metadata
    }
}
pub trait AuthMiddleware: Send + Sync {
fn extract_token(
&self,
headers: &HashMap<String, String>,
) -> impl Future<Output = Option<String>> + Send;
fn handle_auth_failure(&self, error: McpError) -> impl Future<Output = McpResult<()>> + Send;
}
/// Stock [`AuthMiddleware`] that reads credentials from common HTTP headers.
///
/// Lookup order:
/// 1. `Authorization: Bearer <token>`
/// 2. `Authorization: ApiKey <token>`
/// 3. `X-API-Key: <token>` (used when no token was found in `Authorization`)
///
/// Header names and the authorization scheme are matched ASCII-case-insensitively
/// (both are case-insensitive in HTTP), surrounding whitespace is trimmed, and
/// empty tokens are treated as absent.
#[derive(Debug, Clone)]
pub struct DefaultAuthMiddleware;

impl DefaultAuthMiddleware {
    /// Returns the value of header `name`: an exact-key match wins, otherwise the
    /// first key equal to `name` ignoring ASCII case.
    fn find_header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
        headers
            .get(name)
            .or_else(|| {
                headers
                    .iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
            .map(String::as_str)
    }

    /// If `value` has the form `<scheme> <credentials>` with `scheme` matched
    /// case-insensitively, returns the trimmed credentials when they are non-empty.
    fn strip_auth_scheme<'a>(value: &'a str, scheme: &str) -> Option<&'a str> {
        let (head, rest) = value.trim().split_once(char::is_whitespace)?;
        let credentials = rest.trim();
        (head.eq_ignore_ascii_case(scheme) && !credentials.is_empty()).then_some(credentials)
    }
}

impl AuthMiddleware for DefaultAuthMiddleware {
    fn extract_token(
        &self,
        headers: &HashMap<String, String>,
    ) -> impl Future<Output = Option<String>> + Send {
        // Resolve synchronously and hand back an already-completed future: only the
        // owned result crosses into the future, so the header map is not cloned.
        let from_authorization = Self::find_header(headers, "authorization").and_then(|value| {
            Self::strip_auth_scheme(value, "Bearer")
                .or_else(|| Self::strip_auth_scheme(value, "ApiKey"))
        });
        let token = from_authorization.or_else(|| {
            Self::find_header(headers, "x-api-key")
                .map(str::trim)
                .filter(|key| !key.is_empty())
        });
        std::future::ready(token.map(str::to_owned))
    }

    /// Logs the failure at `warn` level and returns `error` unchanged as `Err`.
    async fn handle_auth_failure(&self, error: McpError) -> McpResult<()> {
        tracing::warn!("Authentication failed: {}", error);
        Err(error)
    }
}