use std::cell::RefCell;
use std::rc::Rc;
use turbomcp_core::auth::{
AuthError, Authenticator, CredentialExtractor, HeaderExtractor, Principal,
};
use worker::{Request, Response};
use super::server::McpServer;
use super::types::{JsonRpcResponse, error_codes};
pub struct WithAuth<A, E = HeaderExtractor>
where
A: Authenticator<Error = AuthError> + Clone + 'static,
E: CredentialExtractor + 'static,
{
server: McpServer,
authenticator: A,
extractor: E,
current_principal: Rc<RefCell<Option<Principal>>>,
skip_auth_methods: Vec<String>,
}
impl<A> WithAuth<A, HeaderExtractor>
where
    A: Authenticator<Error = AuthError> + Clone + 'static,
{
    /// Wraps `server` with `authenticator`, reading credentials via [`HeaderExtractor`].
    ///
    /// This is shorthand for [`WithAuth::with_extractor`] with a `HeaderExtractor`. The
    /// resulting skip list is the default one: `initialize`, `notifications/initialized`
    /// and `ping`.
    pub fn new(server: McpServer, authenticator: A) -> Self {
        Self::with_extractor(server, authenticator, HeaderExtractor)
    }
}
impl<A, E> WithAuth<A, E>
where
A: Authenticator<Error = AuthError> + Clone + 'static,
E: CredentialExtractor + 'static,
{
pub fn with_extractor(server: McpServer, authenticator: A, extractor: E) -> Self {
Self {
server,
authenticator,
extractor,
current_principal: Rc::new(RefCell::new(None)),
skip_auth_methods: vec![
"initialize".to_string(),
"notifications/initialized".to_string(),
"ping".to_string(),
],
}
}
pub fn skip_auth_for(mut self, methods: impl IntoIterator<Item = impl Into<String>>) -> Self {
self.skip_auth_methods = methods.into_iter().map(Into::into).collect();
self
}
pub fn also_skip_auth_for(mut self, method: impl Into<String>) -> Self {
self.skip_auth_methods.push(method.into());
self
}
pub fn principal(&self) -> Option<Principal> {
self.current_principal.borrow().clone()
}
pub async fn handle(&self, req: Request) -> worker::Result<Response> {
let request_origin = req.headers().get("origin").ok().flatten();
let origin_ref = request_origin.as_deref();
if req.method() == worker::Method::Options {
return self.server.handle(req).await;
}
let credential = {
let headers = req.headers();
self.extractor
.extract(|name| headers.get(name).ok().flatten())
};
if let Some(cred) = credential {
match self.authenticator.authenticate(&cred).await {
Ok(principal) => {
*self.current_principal.borrow_mut() = Some(principal);
}
Err(e) => {
*self.current_principal.borrow_mut() = None;
return self.auth_error_response(&e, origin_ref);
}
}
}
let response = self.server.handle(req).await;
*self.current_principal.borrow_mut() = None;
response
}
fn auth_error_response(
&self,
error: &AuthError,
request_origin: Option<&str>,
) -> worker::Result<Response> {
let headers = worker::Headers::new();
let origin = request_origin.unwrap_or("*");
let _ = headers.set("Access-Control-Allow-Origin", origin);
if request_origin.is_some() {
let _ = headers.set("Vary", "Origin");
}
let _ = headers.set("Content-Type", "application/json");
let _ = headers.set("WWW-Authenticate", "Bearer");
let response = JsonRpcResponse::error(
None,
error_codes::INTERNAL_ERROR - 5, error.to_string(),
);
let json = serde_json::to_string(&response)
.unwrap_or_else(|_| r#"{"error":"Authentication failed"}"#.to_string());
Response::error(json, 401).map(|r| r.with_headers(headers))
}
}
pub trait AuthExt {
fn with_auth<A>(self, authenticator: A) -> WithAuth<A, HeaderExtractor>
where
A: Authenticator<Error = AuthError> + Clone + 'static;
fn with_auth_extractor<A, E>(self, authenticator: A, extractor: E) -> WithAuth<A, E>
where
A: Authenticator<Error = AuthError> + Clone + 'static,
E: CredentialExtractor + 'static;
}
/// Each method delegates to the matching `WithAuth` constructor.
impl AuthExt for McpServer {
    fn with_auth<A>(self, authenticator: A) -> WithAuth<A, HeaderExtractor>
    where
        A: Authenticator<Error = AuthError> + Clone + 'static,
    {
        let server = self;
        WithAuth::<A, HeaderExtractor>::new(server, authenticator)
    }

    fn with_auth_extractor<A, E>(self, authenticator: A, extractor: E) -> WithAuth<A, E>
    where
        A: Authenticator<Error = AuthError> + Clone + 'static,
        E: CredentialExtractor + 'static,
    {
        let server = self;
        WithAuth::<A, E>::with_extractor(server, authenticator, extractor)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `WithAuth` can be named for an authenticator satisfying the
    /// public API's bounds, combined with an arbitrary credential extractor.
    // `A` is deliberately unused; it exists only so the signature is checked against the
    // same authenticator bounds the public API imposes.
    #[allow(clippy::extra_unused_type_parameters)]
    fn _assert_with_auth_compiles<A: Authenticator<Error = AuthError> + Clone + 'static>() {
        fn _accepts_with_auth<Auth, Ext>(_: WithAuth<Auth, Ext>)
        where
            Auth: Authenticator<Error = AuthError> + Clone + 'static,
            Ext: CredentialExtractor,
        {
        }
    }
}