use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::future::BoxFuture;
use tower_service::Service;
use turbomcp_protocol::McpError;
use crate::AuthProvider;
use crate::context::AuthContext;
use super::AuthLayerConfig;
/// Tower middleware that authenticates requests before forwarding them to an
/// inner service.
///
/// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]`
/// would add a `P: Clone` bound, even though the provider is only ever shared
/// through an [`Arc`], whose clone is just a reference-count bump. With the
/// manual impl, the middleware is `Clone` whenever the wrapped service is.
#[derive(Debug)]
pub struct AuthService<S, P> {
    /// The wrapped service that receives requests once they are admitted.
    inner: S,
    /// Shared validator used to check extracted tokens.
    provider: Arc<P>,
    /// Header names, bypass rules and anonymous-access policy.
    config: AuthLayerConfig,
}

impl<S: Clone, P> Clone for AuthService<S, P> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            provider: Arc::clone(&self.provider),
            config: self.config.clone(),
        }
    }
}
impl<S, P> AuthService<S, P>
where
    P: AuthProvider,
{
    /// Wraps `inner` so that requests are checked against `provider` using the
    /// policy in `config`.
    pub fn new(inner: S, provider: Arc<P>, config: AuthLayerConfig) -> Self {
        Self {
            inner,
            provider,
            config,
        }
    }

    /// Returns a shared reference to the wrapped service.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Extracts a credential from the request headers, or `None` if none is present.
    ///
    /// Lookup order:
    /// 1. The configured auth header (`config.auth_header`), when its value is
    ///    `Bearer <token>` or `ApiKey <token>`. The scheme is matched
    ///    case-insensitively (HTTP auth schemes are case-insensitive, RFC 9110
    ///    §11.1) and whitespace around the credential is trimmed.
    /// 2. The configured API-key header (`config.api_key_header`), whose whole
    ///    trimmed value is the key.
    ///
    /// Values that are not visible ASCII (`to_str` fails), use another scheme,
    /// or carry an empty credential are skipped rather than returned as an empty
    /// token.
    fn extract_token(&self, req: &http::Request<()>) -> Option<String> {
        let headers = req.headers();

        let from_auth_header = headers
            .get(&self.config.auth_header)
            .and_then(|value| value.to_str().ok())
            .and_then(credentials_for_known_scheme);
        if let Some(token) = from_auth_header {
            return Some(token.to_owned());
        }

        headers
            .get(&self.config.api_key_header)
            .and_then(|value| value.to_str().ok())
            .map(str::trim)
            .filter(|token| !token.is_empty())
            .map(str::to_owned)
    }
}

/// Parses an `Authorization`-style value of the form `<scheme> <credentials>`.
///
/// Returns the trimmed credentials when the scheme is `Bearer` or `ApiKey`
/// (compared ASCII case-insensitively) and the credentials are non-empty;
/// otherwise returns `None`.
fn credentials_for_known_scheme(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.trim().split_once(' ')?;
    let credentials = credentials.trim();
    let known_scheme =
        scheme.eq_ignore_ascii_case("Bearer") || scheme.eq_ignore_ascii_case("ApiKey");
    (known_scheme && !credentials.is_empty()).then_some(credentials)
}
/// A request payload paired with an optional authentication context and an
/// optional method name.
///
/// The `AuthenticatedRequest` service impl trusts `auth_context` as supplied
/// (it forwards the body without re-validating it). It also consults `method`
/// against the configured bypass list.
#[derive(Debug)]
pub struct AuthenticatedRequest<B> {
    /// Payload forwarded to the inner service.
    pub body: B,
    /// Authentication context attached by the caller, if any.
    pub auth_context: Option<AuthContext>,
    /// Method name checked against the bypass list, if known.
    pub method: Option<String>,
}

impl<B> AuthenticatedRequest<B> {
    /// Bundles a payload with its (optional) auth context and method name.
    pub fn new(body: B, auth_context: Option<AuthContext>, method: Option<String>) -> Self {
        Self {
            method,
            auth_context,
            body,
        }
    }

    /// Borrows the attached authentication context, if one is present.
    pub fn auth(&self) -> Option<&AuthContext> {
        Option::as_ref(&self.auth_context)
    }

    /// Returns `true` when an authentication context is attached.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.auth_context, Some(_))
    }

    /// Consumes the request and yields only its payload.
    pub fn into_body(self) -> B {
        let Self { body, .. } = self;
        body
    }
}
pub type AuthServiceFuture<T, E> = BoxFuture<'static, Result<T, E>>;
impl<S, P, B, ResBody> Service<http::Request<B>> for AuthService<S, P>
where
S: Service<http::Request<B>, Response = http::Response<ResBody>> + Clone + Send + 'static,
S::Future: Send,
S::Error: Into<McpError>,
P: AuthProvider + Send + Sync + 'static,
B: Send + 'static,
{
type Response = http::Response<ResBody>;
type Error = McpError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: http::Request<B>) -> Self::Future {
let path = req.uri().path();
let method = path
.strip_prefix("/")
.unwrap_or(path)
.split('/')
.collect::<Vec<_>>()
.join("/");
if self.config.should_bypass(&method) {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
return Box::pin(async move { inner.call(req).await.map_err(Into::into) });
}
let (parts, body) = req.into_parts();
let token_req = http::Request::from_parts(parts.clone(), ());
let token = self.extract_token(&token_req);
match token {
Some(token) => {
let provider = Arc::clone(&self.provider);
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
let allow_anonymous = self.config.allow_anonymous;
Box::pin(async move {
match provider.validate_token(&token).await {
Ok(auth_context) => {
let mut req = http::Request::from_parts(parts, body);
req.extensions_mut().insert(auth_context);
inner.call(req).await.map_err(Into::into)
}
Err(e) => {
if allow_anonymous {
let req = http::Request::from_parts(parts, body);
inner.call(req).await.map_err(Into::into)
} else {
Err(e)
}
}
}
})
}
None => {
if self.config.allow_anonymous {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
let req = http::Request::from_parts(parts, body);
Box::pin(async move { inner.call(req).await.map_err(Into::into) })
} else {
Box::pin(async move {
Err(McpError::authentication("No authentication token provided"))
})
}
}
}
}
}
impl<S, P> Service<AuthenticatedRequest<serde_json::Value>> for AuthService<S, P>
where
S: Service<serde_json::Value, Response = serde_json::Value> + Clone + Send + 'static,
S::Future: Send,
S::Error: Into<McpError>,
P: AuthProvider + Send + Sync + 'static,
{
type Response = serde_json::Value;
type Error = McpError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(Into::into)
}
fn call(&mut self, req: AuthenticatedRequest<serde_json::Value>) -> Self::Future {
if let Some(ref method) = req.method
&& self.config.should_bypass(method)
{
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
return Box::pin(async move { inner.call(req.body).await.map_err(Into::into) });
}
if req.is_authenticated() {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
return Box::pin(async move { inner.call(req.body).await.map_err(Into::into) });
}
if self.config.allow_anonymous {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
Box::pin(async move { inner.call(req.body).await.map_err(Into::into) })
} else {
Box::pin(async move {
Err(McpError::authentication(
"Authentication required for this operation",
))
})
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::ApiKeyProvider;

    /// A request built without an auth context reports itself as unauthenticated.
    #[test]
    fn test_authenticated_request() {
        let request = AuthenticatedRequest::new(
            serde_json::json!({"test": "value"}),
            None,
            Some("test/method".to_string()),
        );

        assert!(!request.is_authenticated());
        assert!(request.auth().is_none());
    }

    /// A request carrying an auth context exposes it through `auth()`.
    #[test]
    fn test_authenticated_request_with_context() {
        use crate::UserInfo;
        use std::collections::HashMap;

        let payload = serde_json::json!({"test": "value"});
        let user = UserInfo {
            id: "test-user".to_string(),
            username: "testuser".to_string(),
            email: None,
            display_name: None,
            avatar_url: None,
            metadata: HashMap::new(),
        };
        let context = AuthContext::builder()
            .subject("test-user")
            .user(user)
            .provider("test")
            .build()
            .expect("auth context should build");

        let request = AuthenticatedRequest::new(payload, Some(context), None);

        assert!(request.is_authenticated());
        let attached = request.auth().expect("auth context should be attached");
        assert_eq!(attached.sub, "test-user");
    }

    /// The middleware can wrap a plain `service_fn` together with an API-key provider.
    #[test]
    fn test_auth_service_creation() {
        let api_keys = Arc::new(ApiKeyProvider::new("test-provider".to_string()));
        let layer_config = AuthLayerConfig::default();
        let echo = tower::service_fn(|_req: serde_json::Value| async move {
            Ok::<_, McpError>(serde_json::json!({"result": "ok"}))
        });

        let _service = AuthService::new(echo, api_keys, layer_config);
    }
}