mcp_kit/transport/
auth_layer.rs1use std::sync::Arc;
14
15use axum::{
16 extract::{Request, State as AxumState},
17 http::{header, HeaderMap, StatusCode},
18 middleware::Next,
19 response::{IntoResponse, Response},
20};
21
22use crate::auth::{AuthenticatedIdentity, Credentials, DynAuthProvider};
23
/// Raw bytes of the peer's certificate, carried as a request extension.
///
/// `extract_credentials` looks this type up in the request extensions and, when
/// present, copies the wrapped bytes into the `der` field of
/// `Credentials::ClientCertificate`.
///
/// NOTE(review): presumably the TLS acceptor inserts this after a successful
/// mTLS handshake and the bytes are DER-encoded — confirm against the
/// transport code that populates the extension.
///
/// Only compiled when the `auth-mtls` feature is enabled.
#[cfg(feature = "auth-mtls")]
#[derive(Clone)]
pub struct PeerCertificate(
    /// Certificate bytes, forwarded verbatim to the auth provider.
    pub Vec<u8>,
);
33
34#[derive(Clone)]
38pub struct AuthMiddlewareState {
39 pub provider: DynAuthProvider,
41 pub require_auth: bool,
44}
45
46pub async fn auth_middleware(
53 AxumState(auth): AxumState<AuthMiddlewareState>,
54 mut request: Request,
55 next: Next,
56) -> Response {
57 let credentials = extract_credentials(request.headers(), request.extensions());
58
59 if credentials.is_none() {
60 if auth.require_auth {
61 return unauthorized_response(&credentials);
62 }
63 return next.run(request).await;
65 }
66
67 if !auth.provider.accepts(&credentials) {
68 if auth.require_auth {
69 return unauthorized_response(&credentials);
70 }
71 return next.run(request).await;
72 }
73
74 match auth.provider.authenticate(&credentials).await {
75 Ok(identity) => {
76 request
77 .extensions_mut()
78 .insert(Arc::new(identity) as Arc<AuthenticatedIdentity>);
79 next.run(request).await
80 }
81 Err(_) => unauthorized_response(&credentials),
82 }
83}
84
85pub fn extract_credentials(
99 headers: &HeaderMap,
100 extensions: &axum::http::Extensions,
101) -> Credentials {
102 #[cfg(feature = "auth-mtls")]
104 if let Some(cert) = extensions.get::<PeerCertificate>() {
105 return Credentials::ClientCertificate {
106 der: cert.0.clone(),
107 };
108 }
109
110 if let Some(auth_value) = headers.get(header::AUTHORIZATION) {
111 if let Ok(auth_str) = auth_value.to_str() {
112 if let Some(token) = auth_str.strip_prefix("Bearer ") {
113 return Credentials::Bearer {
114 token: token.trim().to_owned(),
115 };
116 }
117 if let Some(encoded) = auth_str.strip_prefix("Basic ") {
118 if let Ok(decoded) = decode_basic(encoded.trim()) {
119 return decoded;
120 }
121 }
122 }
123 }
124
125 if let Some(key_value) = headers.get("x-api-key") {
126 if let Ok(key) = key_value.to_str() {
127 return Credentials::ApiKey {
128 key: key.trim().to_owned(),
129 };
130 }
131 }
132
133 Credentials::None
134}
135
136fn decode_basic(encoded: &str) -> Result<Credentials, ()> {
137 use std::str;
138
139 let bytes = BASE64_ENGINE.decode(encoded).map_err(|_| ())?;
140 let decoded = str::from_utf8(&bytes).map_err(|_| ())?;
141 let (username, password) = decoded.split_once(':').ok_or(())?;
142 Ok(Credentials::Basic {
143 username: username.to_owned(),
144 password: password.to_owned(),
145 })
146}
147
148use base64::engine::general_purpose::STANDARD as BASE64_ENGINE;
150use base64::Engine as _;
151
152fn unauthorized_response(credentials: &Credentials) -> Response {
155 let www_auth = match credentials {
156 Credentials::Bearer { .. } | Credentials::None => r#"Bearer realm="mcp""#,
157 Credentials::Basic { .. } => r#"Basic realm="mcp""#,
158 Credentials::ApiKey { .. } => r#"ApiKey realm="mcp""#,
159 _ => r#"Bearer realm="mcp""#,
160 };
161
162 (
163 StatusCode::UNAUTHORIZED,
164 [(header::WWW_AUTHENTICATE, www_auth)],
165 "Unauthorized",
166 )
167 .into_response()
168}