1use axum::{
15 body::{to_bytes, Body},
16 extract::State,
17 http::{HeaderMap, Request, StatusCode, Uri},
18 middleware::Next,
19 response::IntoResponse,
20};
21use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
22use serde_json::Value;
23
24use crate::{
25 config::PROXY_AUTH_HEADER,
26 error::{AuthenticationError, ServerError, ValidationError},
27 management::API_KEY_PREFIX,
28 state::AppState,
29 Claims,
30};
31
32pub async fn proxy_middleware(
38 State(_state): State<AppState>,
39 req: Request<Body>,
40 next: Next,
41) -> impl IntoResponse {
42 next.run(req).await
45}
46
47pub fn proxy_uri(original_uri: Uri, namespace: &str, sandbox_name: &str) -> Uri {
49 let target_host = format!("sandbox-{}.{}.internal", sandbox_name, namespace);
57
58 let uri_string = if let Some(path_and_query) = original_uri.path_and_query() {
59 format!("http://{}:{}{}", target_host, 8080, path_and_query)
60 } else {
61 format!("http://{}:{}/", target_host, 8080)
62 };
63
64 uri_string
67 .parse()
68 .unwrap_or_else(|_| "http://localhost:8080/".parse().unwrap())
69}
70
71pub async fn logging_middleware(
73 req: Request<Body>,
74 next: Next,
75) -> Result<impl IntoResponse, (StatusCode, String)> {
76 let method = req.method().clone();
77 let uri = req.uri().clone();
78
79 tracing::info!("Request: {} {}", method, uri);
81
82 let response = next.run(req).await;
84
85 tracing::info!("Response: {} {}: {}", method, uri, response.status());
87
88 Ok(response)
89}
90
91pub async fn auth_middleware(
93 State(state): State<AppState>,
94 req: Request<Body>,
95 next: Next,
96) -> Result<impl IntoResponse, ServerError> {
97 if *state.get_config().get_dev_mode() {
99 return Ok(next.run(req).await);
100 }
101
102 let api_key = extract_api_key_from_headers(req.headers())?;
104
105 let claims = validate_token(&api_key, &state)?;
107
108 if claims.namespace == "*" {
110 return Ok(next.run(req).await);
111 }
112
113 let (parts, body) = req.into_parts();
116
117 let bytes = to_bytes(body, usize::MAX)
119 .await
120 .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
121
122 let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
124
125 if claims.namespace != namespace_from_request {
127 return Err(ServerError::AuthorizationError(
128 crate::error::AuthorizationError::AccessDenied(format!(
129 "Token does not have access to namespace '{}'",
130 namespace_from_request
131 )),
132 ));
133 }
134
135 let body = Body::from(bytes);
137 let req = Request::from_parts(parts, body);
138
139 Ok(next.run(req).await)
141}
142
143pub async fn mcp_smart_auth_middleware(
147 State(state): State<AppState>,
148 req: Request<Body>,
149 next: Next,
150) -> Result<impl IntoResponse, ServerError> {
151 if *state.get_config().get_dev_mode() {
153 return Ok(next.run(req).await);
154 }
155
156 let api_key = extract_api_key_from_headers(req.headers())?;
158
159 let claims = validate_token(&api_key, &state)?;
161
162 if claims.namespace == "*" {
164 return Ok(next.run(req).await);
165 }
166
167 let (parts, body) = req.into_parts();
170
171 let bytes = to_bytes(body, usize::MAX)
173 .await
174 .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
175
176 let json_value: serde_json::Value = serde_json::from_slice(&bytes).map_err(|e| {
178 ServerError::ValidationError(crate::error::ValidationError::InvalidInput(format!(
179 "Invalid JSON-RPC request: {}",
180 e
181 )))
182 })?;
183
184 let method = json_value
185 .get("method")
186 .and_then(serde_json::Value::as_str)
187 .unwrap_or("unknown");
188
189 let requires_namespace_validation = matches!(method, "tools/call");
191
192 if requires_namespace_validation {
193 let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
195
196 if claims.namespace != namespace_from_request {
198 return Err(ServerError::AuthorizationError(
199 crate::error::AuthorizationError::AccessDenied(format!(
200 "Token does not have access to namespace '{}'",
201 namespace_from_request
202 )),
203 ));
204 }
205 }
206
207 let body = Body::from(bytes);
209 let req = Request::from_parts(parts, body);
210
211 Ok(next.run(req).await)
213}
214
215fn extract_namespace_from_json_rpc(bytes: &[u8]) -> Result<String, ServerError> {
221 let json_value: Value = serde_json::from_slice(bytes).map_err(|e| {
223 ServerError::ValidationError(ValidationError::InvalidInput(format!(
224 "Invalid JSON-RPC request: {}",
225 e
226 )))
227 })?;
228
229 let method = json_value
231 .get("method")
232 .and_then(Value::as_str)
233 .unwrap_or("unknown");
234
235 let params = json_value.get("params").ok_or_else(|| {
237 ServerError::ValidationError(ValidationError::InvalidInput(
238 "Missing 'params' field in JSON-RPC request".to_string(),
239 ))
240 })?;
241
242 params
244 .get("namespace")
245 .and_then(Value::as_str)
246 .map(String::from)
247 .ok_or_else(|| {
248 ServerError::ValidationError(ValidationError::InvalidInput(format!(
249 "Missing or invalid 'namespace' in params for method '{}'",
250 method
251 )))
252 })
253}
254
255fn extract_api_key_from_headers(headers: &HeaderMap) -> Result<String, ServerError> {
257 if let Some(auth_header) = headers.get(PROXY_AUTH_HEADER) {
259 let auth_value = auth_header.to_str().map_err(|_| {
260 ServerError::Authentication(AuthenticationError::InvalidCredentials(
261 "Invalid authorization header format".to_string(),
262 ))
263 })?;
264
265 if let Some(token) = auth_value.strip_prefix("Bearer ") {
267 return Ok(token.to_string());
268 }
269
270 return Ok(auth_value.to_string());
272 }
273
274 if let Some(auth_header) = headers.get("Authorization") {
276 let auth_value = auth_header.to_str().map_err(|_| {
277 ServerError::Authentication(AuthenticationError::InvalidCredentials(
278 "Invalid authorization header format".to_string(),
279 ))
280 })?;
281
282 if let Some(token) = auth_value.strip_prefix("Bearer ") {
284 return Ok(token.to_string());
285 }
286
287 return Ok(auth_value.to_string());
289 }
290
291 Err(ServerError::Authentication(
292 AuthenticationError::InvalidCredentials("Missing authorization header".to_string()),
293 ))
294}
295
296fn convert_api_key_to_jwt(api_key: &str) -> Result<String, ServerError> {
298 if !api_key.starts_with(API_KEY_PREFIX) {
300 return Err(ServerError::Authentication(
301 AuthenticationError::InvalidCredentials(
302 "Invalid API key format: missing prefix".to_string(),
303 ),
304 ));
305 }
306
307 Ok(api_key[API_KEY_PREFIX.len()..].to_string())
309}
310
311fn get_server_key(state: &AppState) -> Result<String, ServerError> {
313 match state.get_config().get_key() {
316 Some(key) => Ok(key.clone()),
317 None => Err(ServerError::Authentication(
318 AuthenticationError::InvalidCredentials(
319 "Server key not found in configuration".to_string(),
320 ),
321 )),
322 }
323}
324
325fn validate_token(api_key: &str, state: &AppState) -> Result<Claims, ServerError> {
327 let jwt = convert_api_key_to_jwt(api_key)?;
329
330 let server_key = get_server_key(state)?;
332
333 let token_data = decode::<Claims>(
335 &jwt,
336 &DecodingKey::from_secret(server_key.as_bytes()),
337 &Validation::new(Algorithm::HS256),
338 )
339 .map_err(|e| {
340 let error_message = match e.kind() {
341 jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token expired".to_string(),
342 jsonwebtoken::errors::ErrorKind::InvalidSignature => {
343 "Invalid token signature".to_string()
344 }
345 _ => format!("Token validation error: {}", e),
346 };
347 ServerError::Authentication(AuthenticationError::InvalidToken(error_message))
348 })?;
349
350 Ok(token_data.claims)
351}
352
353pub fn validate_token_and_namespace(
355 api_key: &str,
356 requested_namespace: &str,
357 state: &AppState,
358) -> Result<Claims, ServerError> {
359 let claims = validate_token(api_key, state)?;
361
362 if claims.namespace != requested_namespace && claims.namespace != "*" {
364 return Err(ServerError::Authentication(
365 AuthenticationError::InvalidCredentials(format!(
366 "Token does not have access to namespace '{}'",
367 requested_namespace
368 )),
369 ));
370 }
371
372 Ok(claims)
373}