use super::traits::{AuthContext, AuthProvider, TokenValidator};
use crate::error::{Error, ErrorCode, Result};
use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Boxed async validator closure: receives an owned token string and resolves to the
/// [`AuthContext`] it authenticates, or an error.
///
/// The future is pinned and boxed so closures whose futures have different concrete
/// types can all be stored under this one type. The `Send`/`Sync` bounds allow the
/// closure to be held by values shared across threads.
pub type TokenValidatorFn = Box<
    dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<AuthContext>> + Send + 'static>>
        + Sync
        + Send
        + 'static,
>;
/// Configuration for [`ProxyProvider`].
///
/// `Debug` is implemented by hand so that `client_secret` is redacted. Logging this
/// configuration, or anything that embeds it, never reveals the secret value.
#[derive(Clone)]
pub struct ProxyProviderConfig {
    /// Upstream URL (set by `ProxyProvider::with_upstream`).
    pub upstream_url: String,
    /// Optional token-introspection endpoint.
    pub introspection_endpoint: Option<String>,
    /// Optional client identifier.
    pub client_id: Option<String>,
    /// Optional client secret; never printed by the `Debug` implementation.
    pub client_secret: Option<String>,
    /// Whether caching is enabled.
    pub enable_cache: bool,
    /// Cache time-to-live. NOTE(review): unit presumably seconds — confirm.
    pub cache_ttl: u64,
}

impl std::fmt::Debug for ProxyProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret is configured, never its value.
        let redacted_secret = self.client_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("ProxyProviderConfig")
            .field("upstream_url", &self.upstream_url)
            .field("introspection_endpoint", &self.introspection_endpoint)
            .field("client_id", &self.client_id)
            .field("client_secret", &redacted_secret)
            .field("enable_cache", &self.enable_cache)
            .field("cache_ttl", &self.cache_ttl)
            .finish()
    }
}
impl Default for ProxyProviderConfig {
fn default() -> Self {
Self {
upstream_url: String::new(),
introspection_endpoint: None,
client_id: None,
client_secret: None,
enable_cache: true,
cache_ttl: 300,
}
}
}
/// Bearer-token auth provider with pluggable token validation.
///
/// Tokens are checked by the first available mechanism: the `token_validator` closure,
/// then the shared `validator`, then token introspection.
pub struct ProxyProvider {
    /// Provider settings (upstream URL, introspection endpoint, credentials, caching).
    config: ProxyProviderConfig,
    /// Optional async validation closure; consulted before `validator`.
    token_validator: Option<TokenValidatorFn>,
    /// Optional shared validator; consulted only when no closure is installed.
    validator: Option<Arc<dyn TokenValidator>>,
}
impl std::fmt::Debug for ProxyProvider {
    /// Prints the configuration and only the presence of each validator. The boxed
    /// validator closure has no `Debug` implementation of its own.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so a newly added field forces this impl to be revisited.
        let Self {
            config,
            token_validator,
            validator,
        } = self;
        let mut builder = f.debug_struct("ProxyProvider");
        builder.field("config", config);
        builder.field("token_validator", &token_validator.is_some());
        builder.field("validator", &validator.is_some());
        builder.finish()
    }
}
impl ProxyProvider {
    /// Creates a provider from an explicit configuration with no validator attached.
    ///
    /// Until a validator is installed, validation falls through to token introspection,
    /// which currently always returns an error.
    pub fn new(config: ProxyProviderConfig) -> Self {
        Self {
            config,
            token_validator: None,
            validator: None,
        }
    }

    /// Creates a provider for `upstream_url`, leaving every other setting at its default.
    pub fn with_upstream(upstream_url: impl Into<String>) -> Self {
        Self::new(ProxyProviderConfig {
            upstream_url: upstream_url.into(),
            ..Default::default()
        })
    }

    /// Installs an async closure used to validate bearer tokens.
    ///
    /// The closure takes precedence over a validator set with `with_validator`.
    /// Calling this again replaces the previous closure.
    pub fn with_validator_fn<F, Fut>(mut self, validator: F) -> Self
    where
        F: Fn(String) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<AuthContext>> + Send + 'static,
    {
        // Box and pin the future so closures with different future types all fit
        // the single `TokenValidatorFn` type.
        self.token_validator = Some(Box::new(move |token| Box::pin(validator(token))));
        self
    }

    /// Installs a shared [`TokenValidator`], used only when no validator closure is set.
    pub fn with_validator(mut self, validator: Arc<dyn TokenValidator>) -> Self {
        self.validator = Some(validator);
        self
    }

    /// Records the token-introspection endpoint in the configuration.
    ///
    /// NOTE(review): `introspect_token` below does not read this value yet.
    pub fn introspection_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.config.introspection_endpoint = Some(endpoint.into());
        self
    }

    /// Records client credentials in the configuration.
    ///
    /// NOTE(review): nothing in this impl reads these values yet. They are presumably
    /// meant for authenticating to the introspection endpoint — confirm.
    pub fn client_credentials(
        mut self,
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
    ) -> Self {
        self.config.client_id = Some(client_id.into());
        self.config.client_secret = Some(client_secret.into());
        self
    }

    /// Enables or disables caching in the configuration.
    ///
    /// NOTE(review): `enable_cache` is not read by any method in this impl.
    pub fn cache(mut self, enable: bool) -> Self {
        self.config.enable_cache = enable;
        self
    }

    /// Extracts the bearer token from an `Authorization` header value.
    ///
    /// The scheme name is matched case-insensitively, as HTTP authentication schemes
    /// are (RFC 7235 §2.1). So `Bearer abc`, `bearer abc` and `BEARER abc` all yield
    /// `abc`. Surrounding whitespace is ignored.
    ///
    /// Returns `None` in three cases: the header is absent, it uses another scheme,
    /// or the token part is empty.
    fn extract_bearer_token(authorization_header: Option<&str>) -> Option<String> {
        let (scheme, credentials) = authorization_header?.trim().split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return None;
        }
        let token = credentials.trim();
        // An empty credential means "no token", rather than being passed to a validator.
        (!token.is_empty()).then(|| token.to_string())
    }

    /// Validates `token` with the first configured mechanism, in this order:
    /// 1. the closure from `with_validator_fn`;
    /// 2. the shared validator from `with_validator`;
    /// 3. token introspection.
    async fn validate_token_internal(&self, token: String) -> Result<AuthContext> {
        if let Some(ref validator_fn) = self.token_validator {
            return validator_fn(token).await;
        }
        if let Some(ref validator) = self.validator {
            return validator.validate(&token).await;
        }
        self.introspect_token(token).await
    }

    /// Placeholder for token introspection: always returns a `METHOD_NOT_FOUND`
    /// protocol error asking the caller to configure a custom validator.
    async fn introspect_token(&self, _token: String) -> Result<AuthContext> {
        Err(Error::protocol(
            ErrorCode::METHOD_NOT_FOUND,
            "Token introspection not yet implemented. Please provide a custom validator.",
        ))
    }
}
#[async_trait]
impl AuthProvider for ProxyProvider {
    /// Validates the bearer token carried in `authorization_header`.
    ///
    /// # Returns
    /// - `Ok(None)` when `extract_bearer_token` finds no token.
    /// - `Ok(Some(ctx))` when the token validates and the context is not expired.
    ///
    /// # Errors
    /// Propagates any validator error unchanged. Returns an `INVALID_REQUEST` protocol
    /// error ("Token expired") when the validated context reports itself as expired.
    async fn validate_request(
        &self,
        authorization_header: Option<&str>,
    ) -> Result<Option<AuthContext>> {
        let Some(token) = Self::extract_bearer_token(authorization_header) else {
            return Ok(None);
        };
        let auth_context = self.validate_token_internal(token).await?;
        // Validators are not trusted to reject expired tokens; check expiry here.
        if auth_context.is_expired() {
            return Err(Error::protocol(ErrorCode::INVALID_REQUEST, "Token expired"));
        }
        Ok(Some(auth_context))
    }

    /// This provider always uses the `Bearer` scheme.
    fn auth_scheme(&self) -> &'static str {
        "Bearer"
    }
}
#[async_trait]
impl TokenValidator for ProxyProvider {
    /// Validates a raw token string through the provider's validator chain
    /// (closure, then shared validator, then introspection).
    ///
    /// NOTE(review): unlike `AuthProvider::validate_request`, this path never calls
    /// `AuthContext::is_expired`, so an expired context is returned as-is. Confirm that
    /// callers check expiry themselves.
    async fn validate(&self, token: &str) -> Result<AuthContext> {
        let owned_token = String::from(token);
        self.validate_token_internal(owned_token).await
    }
}
/// Authentication provider that accepts every request with a fixed identity.
///
/// NOTE(review): the hard-coded `dev-user` identity suggests this is meant for
/// development/testing only — confirm it is never wired into production.
#[derive(Clone, Debug)]
pub struct NoOpAuthProvider;
#[async_trait]
impl AuthProvider for NoOpAuthProvider {
    /// Ignores the header and always succeeds. The returned context is an authenticated,
    /// non-expiring identity (`dev-user` via client `dev-client`) with the scopes
    /// `read`, `write`, `admin` and `mcp:tools:use`.
    async fn validate_request(
        &self,
        _authorization_header: Option<&str>,
    ) -> Result<Option<AuthContext>> {
        const DEV_SCOPES: [&str; 4] = ["read", "write", "admin", "mcp:tools:use"];

        let context = AuthContext {
            subject: String::from("dev-user"),
            scopes: DEV_SCOPES.iter().copied().map(String::from).collect(),
            claims: Default::default(),
            token: None,
            client_id: Some(String::from("dev-client")),
            expires_at: None,
            authenticated: true,
        };
        Ok(Some(context))
    }

    /// Authentication is never required by this provider.
    fn is_required(&self) -> bool {
        false
    }
}
/// Wraps another `AuthProvider` and reports authentication as not required.
///
/// Validation is delegated to the wrapped provider. The `AuthProvider` bound lives on
/// the impl blocks that need it, not on the struct definition. The type can therefore
/// be named and `Debug`-printed without repeating the bound.
#[derive(Debug)]
pub struct OptionalAuthProvider<P> {
    inner: P,
}
impl<P> OptionalAuthProvider<P>
where
    P: AuthProvider,
{
    /// Wraps `provider` so that authentication through it is reported as optional.
    pub fn new(provider: P) -> Self {
        let inner = provider;
        Self { inner }
    }
}
#[async_trait]
impl<P> AuthProvider for OptionalAuthProvider<P>
where
    P: AuthProvider,
{
    /// Forwards validation to the wrapped provider unchanged. Any error it returns is
    /// propagated as-is, so wrapping does not turn a failed validation into success.
    async fn validate_request(
        &self,
        authorization_header: Option<&str>,
    ) -> Result<Option<AuthContext>> {
        let inner = &self.inner;
        inner.validate_request(authorization_header).await
    }

    /// Reuses the wrapped provider's authentication scheme.
    fn auth_scheme(&self) -> &'static str {
        let inner = &self.inner;
        inner.auth_scheme()
    }

    /// Always `false`, whatever the wrapped provider reports.
    fn is_required(&self) -> bool {
        false
    }
}