use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use axum::{
extract::{Request, State},
http::{header, StatusCode},
middleware::Next,
response::Response,
};
use crate::api::jwt::{Claims, JwtManager};
use crate::Error;
/// Authentication mode for the MCP HTTP endpoint.
#[derive(Clone)]
pub enum McpAuth {
    /// No authentication: `check` accepts every request and yields no claims.
    Disabled,
    /// Bearer-JWT authentication backed by a shared [`JwtManager`].
    Jwt(Arc<JwtManager>),
}

// Written by hand instead of derived so the manager itself is never printed.
// A `JwtManager` is constructed from a signing secret. This also means
// `JwtManager` does not have to implement `Debug`.
impl std::fmt::Debug for McpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Disabled => "McpAuth::Disabled",
            Self::Jwt(_) => "McpAuth::Jwt(<manager>)",
        };
        f.write_str(label)
    }
}
/// Permission level an MCP request must hold.
///
/// Each variant corresponds to a scope string that must appear in the
/// token's `scopes` claim (see [`Scope::token`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Scope {
    /// Satisfied by the `mcp:read` scope.
    Read,
    /// Satisfied by the `mcp:write` scope.
    Write,
}

impl Scope {
    /// JSON-RPC methods that only need [`Scope::Read`]; everything else needs
    /// [`Scope::Write`].
    const READ_ONLY_METHODS: &'static [&'static str] = &[
        "initialize",
        "ping",
        "tools/list",
        "resources/list",
        "resources/read",
    ];

    /// Returns the scope string that must be present in the token's claims.
    pub fn token(self) -> &'static str {
        match self {
            Self::Read => "mcp:read",
            Self::Write => "mcp:write",
        }
    }

    /// Maps a JSON-RPC method name to the scope it requires.
    ///
    /// Only the methods in `READ_ONLY_METHODS` map to [`Scope::Read`]. Any
    /// other name, including unknown or future ones, maps to [`Scope::Write`],
    /// so the mapping fails closed. Matching is exact and case-sensitive.
    ///
    /// NOTE(review): notifications such as `notifications/initialized` also
    /// resolve to `Write` here. Confirm that read-only clients are not expected
    /// to send them through this check.
    pub fn for_method(method: &str) -> Self {
        let read_only = Self::READ_ONLY_METHODS
            .iter()
            .any(|&known| known == method);
        if read_only {
            Self::Read
        } else {
            Self::Write
        }
    }
}
impl McpAuth {
    /// Authorizes a request for `scope` given its raw `Authorization` header
    /// value, if any.
    ///
    /// Returns `Ok(None)` when authentication is disabled. Otherwise it returns
    /// `Ok(Some(claims))` when the bearer token validates and its `scopes`
    /// contain `scope.token()`.
    ///
    /// The `Bearer` scheme name is matched case-insensitively (RFC 7235 §2.1).
    /// Surrounding whitespace and any run of spaces or tabs between the scheme
    /// and the token are tolerated.
    ///
    /// # Errors
    /// - [`AuthError::MissingToken`] if there is no header, the scheme is not
    ///   `Bearer`, or the token part is empty.
    /// - [`AuthError::InvalidToken`] if the JWT manager rejects the token.
    /// - [`AuthError::InsufficientScope`] if the token lacks `scope.token()`.
    pub fn check(&self, header_value: Option<&str>, scope: Scope) -> Result<Option<Claims>, AuthError> {
        let mgr = match self {
            McpAuth::Disabled => return Ok(None),
            McpAuth::Jwt(mgr) => mgr,
        };
        let token = header_value
            .and_then(Self::bearer_token)
            .ok_or(AuthError::MissingToken)?;
        let claims = mgr
            .validate_token(token)
            .map_err(AuthError::InvalidToken)?;
        let needed = scope.token();
        if !claims.scopes.iter().any(|s| s == needed) {
            return Err(AuthError::InsufficientScope(needed));
        }
        Ok(Some(claims))
    }

    /// Extracts the credential from a `Bearer <token>` header value.
    ///
    /// Returns `None` if the scheme is not `Bearer` (in any letter case), no
    /// whitespace separates scheme and token, or the token is empty.
    fn bearer_token(value: &str) -> Option<&str> {
        let (scheme, rest) = value
            .trim()
            .split_once(|c: char| c == ' ' || c == '\t')?;
        if !scheme.eq_ignore_ascii_case("bearer") {
            return None;
        }
        let token = rest.trim();
        (!token.is_empty()).then_some(token)
    }
}
/// Reason a request failed MCP authorization.
#[derive(Debug)]
pub enum AuthError {
    /// No usable bearer credential: the header is absent or is not a
    /// `Bearer` credential.
    MissingToken,
    /// The JWT manager rejected the token; carries its error.
    InvalidToken(Error),
    /// The token validated but lacks the named scope (e.g. `"mcp:write"`).
    InsufficientScope(&'static str),
}

impl AuthError {
    /// HTTP status to reply with. An absent or invalid credential gives
    /// `401 Unauthorized`; a valid credential without the required scope
    /// gives `403 Forbidden`.
    pub fn status(&self) -> StatusCode {
        match self {
            AuthError::MissingToken | AuthError::InvalidToken(_) => StatusCode::UNAUTHORIZED,
            AuthError::InsufficientScope(_) => StatusCode::FORBIDDEN,
        }
    }

    /// Human-readable response body. This is the same text as the `Display`
    /// impl.
    pub fn message(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AuthError::MissingToken => f.write_str("missing bearer token"),
            AuthError::InvalidToken(e) => write!(f, "invalid token: {e}"),
            AuthError::InsufficientScope(s) => write!(f, "missing scope: {s}"),
        }
    }
}

// Lets `AuthError` be boxed as `dyn std::error::Error` or used with `?` in
// generic error contexts. `source()` is left at its default because this file
// does not show whether the wrapped `Error` implements `std::error::Error`.
impl std::error::Error for AuthError {}
/// Refuses to expose an unauthenticated MCP endpoint beyond the local host.
///
/// Loopback binds are always allowed. These are IPv4 `127.0.0.0/8`, IPv6 `::1`,
/// and IPv4-mapped loopback such as `::ffff:127.0.0.1`. Any other address,
/// including the wildcards `0.0.0.0` and `::`, requires [`McpAuth::Jwt`].
///
/// # Errors
/// Returns a human-readable explanation when `addr` is not loopback and `auth`
/// is [`McpAuth::Disabled`].
pub fn bind_safety_check(addr: SocketAddr, auth: &McpAuth) -> Result<(), String> {
    let is_loopback = match addr.ip() {
        IpAddr::V4(v4) => v4.is_loopback(),
        // `Ipv6Addr::is_loopback` only recognises `::1`. An IPv4-mapped
        // loopback address is also reachable only from this host, so it must
        // not trigger a false refusal.
        IpAddr::V6(v6) => {
            v6.is_loopback() || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_loopback())
        }
    };
    if is_loopback {
        return Ok(());
    }
    match auth {
        McpAuth::Jwt(_) => Ok(()),
        McpAuth::Disabled => Err(format!(
            "refusing to bind MCP endpoint on non-loopback address {addr}: \
             authentication is disabled. Configure McpAuth::Jwt(...) or bind \
             to 127.0.0.1 / [::1] / a Unix socket instead."
        )),
    }
}
/// Axum middleware that requires the `mcp:read` scope.
///
/// If authorization fails, the request is short-circuited with the
/// [`AuthError`] status and message. Otherwise it is forwarded to `next`.
pub async fn require_read_scope(
    State(mcp_auth): State<McpAuth>,
    req: Request,
    next: Next,
) -> Result<Response, (StatusCode, String)> {
    let required = Scope::Read;
    enforce(&mcp_auth, required, req, next).await
}
/// Axum middleware that requires the `mcp:write` scope.
///
/// If authorization fails, the request is short-circuited with the
/// [`AuthError`] status and message. Otherwise it is forwarded to `next`.
pub async fn require_write_scope(
    State(mcp_auth): State<McpAuth>,
    req: Request,
    next: Next,
) -> Result<Response, (StatusCode, String)> {
    let required = Scope::Write;
    enforce(&mcp_auth, required, req, next).await
}
/// Shared body of the scope middlewares.
///
/// Checks the request's `Authorization` header against `auth` for `scope`.
/// On success the request is passed to `next`. On failure it returns the
/// error's status and message without running the rest of the stack. A header
/// value that is not valid visible ASCII (`to_str` fails) is treated as
/// absent.
async fn enforce(
    auth: &McpAuth,
    scope: Scope,
    req: Request,
    next: Next,
) -> Result<Response, (StatusCode, String)> {
    let authorization = match req.headers().get(header::AUTHORIZATION) {
        Some(value) => value.to_str().ok(),
        None => None,
    };
    if let Err(err) = auth.check(authorization, scope) {
        return Err((err.status(), err.message()));
    }
    Ok(next.run(req).await)
}
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};
    use std::net::Ipv4Addr;
    use uuid::Uuid;

    /// Builds a manager and a token signed with the same secret whose
    /// `scopes` claim is exactly `scopes`.
    fn jwt_with_scopes(scopes: &[&str]) -> (Arc<JwtManager>, String) {
        let manager = Arc::new(JwtManager::new(b"test-secret"));
        // Start from a token the manager itself issued so the other claims are
        // well-formed. Then swap in the requested scopes and re-sign.
        let issued = manager
            .generate_token("u".into(), "t".into(), Uuid::new_v4())
            .unwrap();
        let mut claims = manager.validate_token(&issued).unwrap();
        claims.scopes = scopes.iter().map(|scope| scope.to_string()).collect();
        let key = EncodingKey::from_secret(b"test-secret");
        let token = encode(&Header::default(), &claims, &key).unwrap();
        (manager, token)
    }

    /// Formats `token` as an `Authorization` header value.
    fn bearer(token: &str) -> String {
        format!("Bearer {token}")
    }

    #[test]
    fn disabled_check_passes() {
        let outcome = McpAuth::Disabled.check(None, Scope::Write).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn jwt_read_scope_passes_for_read() {
        let (manager, token) = jwt_with_scopes(&["mcp:read"]);
        let header = bearer(&token);
        assert!(McpAuth::Jwt(manager).check(Some(&header), Scope::Read).is_ok());
    }

    #[test]
    fn jwt_read_scope_fails_for_write() {
        let (manager, token) = jwt_with_scopes(&["mcp:read"]);
        let header = bearer(&token);
        let err = McpAuth::Jwt(manager)
            .check(Some(&header), Scope::Write)
            .unwrap_err();
        assert!(
            matches!(err, AuthError::InsufficientScope("mcp:write")),
            "unexpected: {err:?}"
        );
    }

    #[test]
    fn jwt_missing_token_errors() {
        let auth = McpAuth::Jwt(Arc::new(JwtManager::new(b"test-secret")));
        let outcome = auth.check(None, Scope::Read);
        assert!(matches!(outcome, Err(AuthError::MissingToken)));
    }

    #[test]
    fn jwt_invalid_token_errors() {
        let auth = McpAuth::Jwt(Arc::new(JwtManager::new(b"test-secret")));
        let outcome = auth.check(Some("Bearer not-a-token"), Scope::Read);
        assert!(matches!(outcome, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn loopback_bind_always_ok() {
        let addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 9000));
        assert!(bind_safety_check(addr, &McpAuth::Disabled).is_ok());
    }

    #[test]
    fn public_bind_disabled_auth_refused() {
        let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 9000));
        let message = bind_safety_check(addr, &McpAuth::Disabled).unwrap_err();
        assert!(message.contains("non-loopback"), "got: {message}");
    }

    #[test]
    fn public_bind_jwt_ok() {
        let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 9000));
        let auth = McpAuth::Jwt(Arc::new(JwtManager::new(b"strong-secret")));
        assert!(bind_safety_check(addr, &auth).is_ok());
    }

    #[test]
    fn scope_for_method_routing() {
        let expectations = [
            ("initialize", Scope::Read),
            ("tools/list", Scope::Read),
            ("tools/call", Scope::Write),
            ("custom/whatever", Scope::Write),
        ];
        for (method, expected) in expectations.iter() {
            assert_eq!(Scope::for_method(method), *expected, "method {method}");
        }
    }
}