Skip to main content

mcp_kit/transport/
auth_layer.rs

1//! Axum middleware that enforces authentication on SSE/HTTP routes.
2//!
3//! When an [`AuthProvider`] is configured on the server, every incoming request
4//! passes through [`auth_middleware`] before reaching the route handlers.
5//! On success the [`AuthenticatedIdentity`] is inserted into the request's
6//! Axum [`Extensions`] so that `sse_handler` / `message_handler` can store it
7//! on the [`Session`].
8//!
9//! [`AuthProvider`]: crate::auth::AuthProvider
10//! [`AuthenticatedIdentity`]: crate::auth::AuthenticatedIdentity
11//! [`Session`]: crate::server::session::Session
12
13use std::sync::Arc;
14
15use axum::{
16    extract::{Request, State as AxumState},
17    http::{header, HeaderMap, StatusCode},
18    middleware::Next,
19    response::{IntoResponse, Response},
20};
21
22use crate::auth::{AuthenticatedIdentity, Credentials, DynAuthProvider};
23
24// ─── mTLS peer certificate extension ─────────────────────────────────────────
25
/// Newtype wrapper inserted into Axum request extensions by the TLS transport
/// when the client presents a certificate during the mTLS handshake.
///
/// The inner `Vec<u8>` contains the DER-encoded certificate bytes. When this
/// extension is present on a request, [`extract_credentials`] converts it into
/// `Credentials::ClientCertificate`, and it takes precedence over any
/// header-based credentials.
#[cfg(feature = "auth-mtls")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PeerCertificate(pub Vec<u8>);
33
34// ─── Middleware state ──────────────────────────────────────────────────────────
35
36/// Configuration carried by the Axum middleware layer.
37#[derive(Clone)]
38pub struct AuthMiddlewareState {
39    /// The provider used to validate incoming credentials.
40    pub provider: DynAuthProvider,
41    /// When `true`, requests with no credentials are rejected with 401.
42    /// When `false`, unauthenticated requests proceed (identity stays `None`).
43    pub require_auth: bool,
44}
45
46// ─── Middleware function ───────────────────────────────────────────────────────
47
48/// Axum middleware that authenticates incoming requests.
49///
50/// Pass this to `axum::middleware::from_fn_with_state` with an
51/// [`AuthMiddlewareState`] as the state.
52pub async fn auth_middleware(
53    AxumState(auth): AxumState<AuthMiddlewareState>,
54    mut request: Request,
55    next: Next,
56) -> Response {
57    let credentials = extract_credentials(request.headers(), request.extensions());
58
59    if credentials.is_none() {
60        if auth.require_auth {
61            return unauthorized_response(&credentials);
62        }
63        // No credentials and auth is optional — proceed without an identity.
64        return next.run(request).await;
65    }
66
67    if !auth.provider.accepts(&credentials) {
68        if auth.require_auth {
69            return unauthorized_response(&credentials);
70        }
71        return next.run(request).await;
72    }
73
74    match auth.provider.authenticate(&credentials).await {
75        Ok(identity) => {
76            request
77                .extensions_mut()
78                .insert(Arc::new(identity) as Arc<AuthenticatedIdentity>);
79            next.run(request).await
80        }
81        Err(_) => unauthorized_response(&credentials),
82    }
83}
84
85// ─── Credential extraction ────────────────────────────────────────────────────
86
87/// Extract [`Credentials`] from request headers and extensions.
88///
89/// Precedence:
90/// 1. mTLS peer certificate (from TLS handshake, in request extensions)
91/// 2. `Authorization: Bearer <token>`
92/// 3. `Authorization: Basic <b64>`  → decoded to `(username, password)`
93/// 4. `X-Api-Key: <key>`
94/// 5. Falls back to [`Credentials::None`]
95///
96/// The query-param `?api_key=` fallback is handled separately in the SSE
97/// handler, because Axum exposes query params after routing.
98pub fn extract_credentials(
99    headers: &HeaderMap,
100    extensions: &axum::http::Extensions,
101) -> Credentials {
102    // mTLS peer certificate takes highest precedence.
103    #[cfg(feature = "auth-mtls")]
104    if let Some(cert) = extensions.get::<PeerCertificate>() {
105        return Credentials::ClientCertificate {
106            der: cert.0.clone(),
107        };
108    }
109
110    if let Some(auth_value) = headers.get(header::AUTHORIZATION) {
111        if let Ok(auth_str) = auth_value.to_str() {
112            if let Some(token) = auth_str.strip_prefix("Bearer ") {
113                return Credentials::Bearer {
114                    token: token.trim().to_owned(),
115                };
116            }
117            if let Some(encoded) = auth_str.strip_prefix("Basic ") {
118                if let Ok(decoded) = decode_basic(encoded.trim()) {
119                    return decoded;
120                }
121            }
122        }
123    }
124
125    if let Some(key_value) = headers.get("x-api-key") {
126        if let Ok(key) = key_value.to_str() {
127            return Credentials::ApiKey {
128                key: key.trim().to_owned(),
129            };
130        }
131    }
132
133    Credentials::None
134}
135
136fn decode_basic(encoded: &str) -> Result<Credentials, ()> {
137    use std::str;
138
139    let bytes = BASE64_ENGINE.decode(encoded).map_err(|_| ())?;
140    let decoded = str::from_utf8(&bytes).map_err(|_| ())?;
141    let (username, password) = decoded.split_once(':').ok_or(())?;
142    Ok(Credentials::Basic {
143        username: username.to_owned(),
144        password: password.to_owned(),
145    })
146}
147
148// base64 engine (standard alphabet, with padding)
149use base64::engine::general_purpose::STANDARD as BASE64_ENGINE;
150use base64::Engine as _;
151
152// ─── 401 response ─────────────────────────────────────────────────────────────
153
154fn unauthorized_response(credentials: &Credentials) -> Response {
155    let www_auth = match credentials {
156        Credentials::Bearer { .. } | Credentials::None => r#"Bearer realm="mcp""#,
157        Credentials::Basic { .. } => r#"Basic realm="mcp""#,
158        Credentials::ApiKey { .. } => r#"ApiKey realm="mcp""#,
159        _ => r#"Bearer realm="mcp""#,
160    };
161
162    (
163        StatusCode::UNAUTHORIZED,
164        [(header::WWW_AUTHENTICATE, www_auth)],
165        "Unauthorized",
166    )
167        .into_response()
168}