microsandbox_server/
middleware.rs

1//! Middleware components for the microsandbox server.
2//!
3//! This module handles:
4//! - Request/response middleware
5//! - Authentication and authorization
6//! - Request tracing and logging
7//! - Error handling
8//!
9//! The module provides:
10//! - Middleware components for common operations
11//! - Authentication middleware for API security
12//! - Logging and tracing middleware
13
14use axum::{
15    body::{to_bytes, Body},
16    extract::State,
17    http::{HeaderMap, Request, StatusCode, Uri},
18    middleware::Next,
19    response::IntoResponse,
20};
21use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
22use serde_json::Value;
23
24use crate::{
25    config::PROXY_AUTH_HEADER,
26    error::{AuthenticationError, ServerError, ValidationError},
27    management::API_KEY_PREFIX,
28    state::AppState,
29    Claims,
30};
31
32//--------------------------------------------------------------------------------------------------
33// Middleware Functions
34//--------------------------------------------------------------------------------------------------
35
36/// Proxy middleware for forwarding requests to a target service
37pub async fn proxy_middleware(
38    State(_state): State<AppState>,
39    req: Request<Body>,
40    next: Next,
41) -> impl IntoResponse {
42    // Default to passing the request to the next handler
43    // This middleware can be extended to implement actual proxying logic
44    next.run(req).await
45}
46
47/// Convert a URI to a proxied URI targeting a sandbox
48pub fn proxy_uri(original_uri: Uri, namespace: &str, sandbox_name: &str) -> Uri {
49    // In a real implementation, you would:
50    // 1. Look up the sandbox's address from a registry or state
51    // 2. Construct a new URI that points to the sandbox
52    // 3. Return the new URI for proxying
53
54    // For demonstration purposes, we'll construct a simple URI
55    // In production, you would get this from a sandbox registry
56    let target_host = format!("sandbox-{}.{}.internal", sandbox_name, namespace);
57
58    let uri_string = if let Some(path_and_query) = original_uri.path_and_query() {
59        format!("http://{}:{}{}", target_host, 8080, path_and_query)
60    } else {
61        format!("http://{}:{}/", target_host, 8080)
62    };
63
64    // Try to parse the string into a URI
65    // In case of errors, fallback to a default URI
66    uri_string
67        .parse()
68        .unwrap_or_else(|_| "http://localhost:8080/".parse().unwrap())
69}
70
71/// Log incoming requests
72pub async fn logging_middleware(
73    req: Request<Body>,
74    next: Next,
75) -> Result<impl IntoResponse, (StatusCode, String)> {
76    let method = req.method().clone();
77    let uri = req.uri().clone();
78
79    // Log the request
80    tracing::info!("Request: {} {}", method, uri);
81
82    // Process the request
83    let response = next.run(req).await;
84
85    // Log the response
86    tracing::info!("Response: {} {}: {}", method, uri, response.status());
87
88    Ok(response)
89}
90
91/// Authentication middleware for verifying API keys and namespace access
92pub async fn auth_middleware(
93    State(state): State<AppState>,
94    req: Request<Body>,
95    next: Next,
96) -> Result<impl IntoResponse, ServerError> {
97    // Skip auth in dev mode if configured
98    if *state.get_config().get_dev_mode() {
99        return Ok(next.run(req).await);
100    }
101
102    // Extract API key from authorization header
103    let api_key = extract_api_key_from_headers(req.headers())?;
104
105    // Validate the token and get its claims
106    let claims = validate_token(&api_key, &state)?;
107
108    // If token has wildcard namespace access, we can skip further namespace validation
109    if claims.namespace == "*" {
110        return Ok(next.run(req).await);
111    }
112
113    // For namespace-specific tokens, we need to ensure the token has access to the requested namespace
114    // We need to read the request body to extract the namespace
115    let (parts, body) = req.into_parts();
116
117    // Use axum's to_bytes to buffer the body
118    let bytes = to_bytes(body, usize::MAX)
119        .await
120        .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
121
122    // Parse the JSON-RPC request and extract the namespace
123    let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
124
125    // Validate that the token has access to the requested namespace
126    if claims.namespace != namespace_from_request {
127        return Err(ServerError::AuthorizationError(
128            crate::error::AuthorizationError::AccessDenied(format!(
129                "Token does not have access to namespace '{}'",
130                namespace_from_request
131            )),
132        ));
133    }
134
135    // Reconstruct the request with the original body
136    let body = Body::from(bytes);
137    let req = Request::from_parts(parts, body);
138
139    // If everything is valid, continue with the request
140    Ok(next.run(req).await)
141}
142
143/// Smart authentication middleware for MCP requests that handles protocol vs tool methods differently
144/// Protocol methods (initialize, tools/list, prompts/list, prompts/get) don't require namespace validation
145/// Tool methods (tools/call) require namespace validation
146pub async fn mcp_smart_auth_middleware(
147    State(state): State<AppState>,
148    req: Request<Body>,
149    next: Next,
150) -> Result<impl IntoResponse, ServerError> {
151    // Skip auth in dev mode if configured
152    if *state.get_config().get_dev_mode() {
153        return Ok(next.run(req).await);
154    }
155
156    // Extract API key from authorization header
157    let api_key = extract_api_key_from_headers(req.headers())?;
158
159    // Validate the token and get its claims
160    let claims = validate_token(&api_key, &state)?;
161
162    // If token has wildcard namespace access, we can skip further namespace validation
163    if claims.namespace == "*" {
164        return Ok(next.run(req).await);
165    }
166
167    // For namespace-specific tokens, we need to check if this is a tool execution method
168    // that requires namespace validation
169    let (parts, body) = req.into_parts();
170
171    // Use axum's to_bytes to buffer the body
172    let bytes = to_bytes(body, usize::MAX)
173        .await
174        .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
175
176    // Parse the JSON to check the method
177    let json_value: serde_json::Value = serde_json::from_slice(&bytes).map_err(|e| {
178        ServerError::ValidationError(crate::error::ValidationError::InvalidInput(format!(
179            "Invalid JSON-RPC request: {}",
180            e
181        )))
182    })?;
183
184    let method = json_value
185        .get("method")
186        .and_then(serde_json::Value::as_str)
187        .unwrap_or("unknown");
188
189    // Check if this is a tool execution method that requires namespace validation
190    let requires_namespace_validation = matches!(method, "tools/call");
191
192    if requires_namespace_validation {
193        // Extract namespace from params for tool execution methods
194        let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
195
196        // Validate that the token has access to the requested namespace
197        if claims.namespace != namespace_from_request {
198            return Err(ServerError::AuthorizationError(
199                crate::error::AuthorizationError::AccessDenied(format!(
200                    "Token does not have access to namespace '{}'",
201                    namespace_from_request
202                )),
203            ));
204        }
205    }
206
207    // Reconstruct the request with the original body
208    let body = Body::from(bytes);
209    let req = Request::from_parts(parts, body);
210
211    // If everything is valid, continue with the request
212    Ok(next.run(req).await)
213}
214
215//--------------------------------------------------------------------------------------------------
216// Helper Functions
217//--------------------------------------------------------------------------------------------------
218
219/// Extract the namespace from a JSON-RPC request body
220fn extract_namespace_from_json_rpc(bytes: &[u8]) -> Result<String, ServerError> {
221    // Parse the request body as JSON
222    let json_value: Value = serde_json::from_slice(bytes).map_err(|e| {
223        ServerError::ValidationError(ValidationError::InvalidInput(format!(
224            "Invalid JSON-RPC request: {}",
225            e
226        )))
227    })?;
228
229    // Extract the method for logging purposes
230    let method = json_value
231        .get("method")
232        .and_then(Value::as_str)
233        .unwrap_or("unknown");
234
235    // Extract params object
236    let params = json_value.get("params").ok_or_else(|| {
237        ServerError::ValidationError(ValidationError::InvalidInput(
238            "Missing 'params' field in JSON-RPC request".to_string(),
239        ))
240    })?;
241
242    // Extract namespace from params for any method
243    params
244        .get("namespace")
245        .and_then(Value::as_str)
246        .map(String::from)
247        .ok_or_else(|| {
248            ServerError::ValidationError(ValidationError::InvalidInput(format!(
249                "Missing or invalid 'namespace' in params for method '{}'",
250                method
251            )))
252        })
253}
254
255/// Extract API key from request headers
256fn extract_api_key_from_headers(headers: &HeaderMap) -> Result<String, ServerError> {
257    // First check the Proxy-Authorization header
258    if let Some(auth_header) = headers.get(PROXY_AUTH_HEADER) {
259        let auth_value = auth_header.to_str().map_err(|_| {
260            ServerError::Authentication(AuthenticationError::InvalidCredentials(
261                "Invalid authorization header format".to_string(),
262            ))
263        })?;
264
265        // Check if it has the Bearer prefix
266        if let Some(token) = auth_value.strip_prefix("Bearer ") {
267            return Ok(token.to_string());
268        }
269
270        // Or if it's just the raw token
271        return Ok(auth_value.to_string());
272    }
273
274    // Then check standard Authorization header
275    if let Some(auth_header) = headers.get("Authorization") {
276        let auth_value = auth_header.to_str().map_err(|_| {
277            ServerError::Authentication(AuthenticationError::InvalidCredentials(
278                "Invalid authorization header format".to_string(),
279            ))
280        })?;
281
282        // Check if it has the Bearer prefix
283        if let Some(token) = auth_value.strip_prefix("Bearer ") {
284            return Ok(token.to_string());
285        }
286
287        // Or if it's just the raw token
288        return Ok(auth_value.to_string());
289    }
290
291    Err(ServerError::Authentication(
292        AuthenticationError::InvalidCredentials("Missing authorization header".to_string()),
293    ))
294}
295
296/// Convert a custom API key back to a standard JWT format
297fn convert_api_key_to_jwt(api_key: &str) -> Result<String, ServerError> {
298    // Check if the API key has the expected prefix
299    if !api_key.starts_with(API_KEY_PREFIX) {
300        return Err(ServerError::Authentication(
301            AuthenticationError::InvalidCredentials(
302                "Invalid API key format: missing prefix".to_string(),
303            ),
304        ));
305    }
306
307    // Remove the prefix and return the original JWT
308    Ok(api_key[API_KEY_PREFIX.len()..].to_string())
309}
310
311/// Get the server key from the AppState config
312fn get_server_key(state: &AppState) -> Result<String, ServerError> {
313    // Get the key from the config - we already know we're not in dev mode
314    // by the time this function is called
315    match state.get_config().get_key() {
316        Some(key) => Ok(key.clone()),
317        None => Err(ServerError::Authentication(
318            AuthenticationError::InvalidCredentials(
319                "Server key not found in configuration".to_string(),
320            ),
321        )),
322    }
323}
324
325/// Validate the token
326fn validate_token(api_key: &str, state: &AppState) -> Result<Claims, ServerError> {
327    // Convert API key back to JWT format
328    let jwt = convert_api_key_to_jwt(api_key)?;
329
330    // Get server key for validation
331    let server_key = get_server_key(state)?;
332
333    // Decode and validate the JWT
334    let token_data = decode::<Claims>(
335        &jwt,
336        &DecodingKey::from_secret(server_key.as_bytes()),
337        &Validation::new(Algorithm::HS256),
338    )
339    .map_err(|e| {
340        let error_message = match e.kind() {
341            jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token expired".to_string(),
342            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
343                "Invalid token signature".to_string()
344            }
345            _ => format!("Token validation error: {}", e),
346        };
347        ServerError::Authentication(AuthenticationError::InvalidToken(error_message))
348    })?;
349
350    Ok(token_data.claims)
351}
352
353/// Validate the token and check namespace access
354pub fn validate_token_and_namespace(
355    api_key: &str,
356    requested_namespace: &str,
357    state: &AppState,
358) -> Result<Claims, ServerError> {
359    // Validate token
360    let claims = validate_token(api_key, state)?;
361
362    // Check if the token's namespace matches the requested namespace
363    if claims.namespace != requested_namespace && claims.namespace != "*" {
364        return Err(ServerError::Authentication(
365            AuthenticationError::InvalidCredentials(format!(
366                "Token does not have access to namespace '{}'",
367                requested_namespace
368            )),
369        ));
370    }
371
372    Ok(claims)
373}