microsandbox_server/
middleware.rs

1//! Middleware components for the microsandbox server.
2//!
3//! This module handles:
4//! - Request/response middleware
5//! - Authentication and authorization
6//! - Request tracing and logging
7//! - Error handling
8//!
9//! The module provides:
10//! - Middleware components for common operations
11//! - Authentication middleware for API security
12//! - Logging and tracing middleware
13
14use axum::{
15    body::{to_bytes, Body},
16    extract::State,
17    http::{HeaderMap, Request, StatusCode, Uri},
18    middleware::Next,
19    response::IntoResponse,
20};
21use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
22use serde_json::Value;
23
24use crate::{
25    config::PROXY_AUTH_HEADER,
26    error::{AuthenticationError, ServerError, ValidationError},
27    management::API_KEY_PREFIX,
28    state::AppState,
29    Claims,
30};
31
32//--------------------------------------------------------------------------------------------------
33// Middleware Functions
34//--------------------------------------------------------------------------------------------------
35
36/// Proxy middleware for forwarding requests to a target service
37pub async fn proxy_middleware(
38    State(_state): State<AppState>,
39    req: Request<Body>,
40    next: Next,
41) -> impl IntoResponse {
42    // Default to passing the request to the next handler
43    // This middleware can be extended to implement actual proxying logic
44    next.run(req).await
45}
46
47/// Convert a URI to a proxied URI targeting a sandbox
48pub fn proxy_uri(original_uri: Uri, namespace: &str, sandbox_name: &str) -> Uri {
49    // In a real implementation, you would:
50    // 1. Look up the sandbox's address from a registry or state
51    // 2. Construct a new URI that points to the sandbox
52    // 3. Return the new URI for proxying
53
54    // For demonstration purposes, we'll construct a simple URI
55    // In production, you would get this from a sandbox registry
56    let target_host = format!("sandbox-{}.{}.internal", sandbox_name, namespace);
57
58    let uri_string = if let Some(path_and_query) = original_uri.path_and_query() {
59        format!("http://{}:{}{}", target_host, 8080, path_and_query)
60    } else {
61        format!("http://{}:{}/", target_host, 8080)
62    };
63
64    // Try to parse the string into a URI
65    // In case of errors, fallback to a default URI
66    uri_string
67        .parse()
68        .unwrap_or_else(|_| "http://localhost:8080/".parse().unwrap())
69}
70
71/// Log incoming requests
72pub async fn logging_middleware(
73    req: Request<Body>,
74    next: Next,
75) -> Result<impl IntoResponse, (StatusCode, String)> {
76    let method = req.method().clone();
77    let uri = req.uri().clone();
78
79    // Log the request
80    tracing::info!("Request: {} {}", method, uri);
81
82    // Process the request
83    let response = next.run(req).await;
84
85    // Log the response
86    tracing::info!("Response: {} {}: {}", method, uri, response.status());
87
88    Ok(response)
89}
90
91/// Authentication middleware for verifying API keys and namespace access
92pub async fn auth_middleware(
93    State(state): State<AppState>,
94    req: Request<Body>,
95    next: Next,
96) -> Result<impl IntoResponse, ServerError> {
97    // Skip auth in dev mode if configured
98    if *state.get_config().get_dev_mode() {
99        return Ok(next.run(req).await);
100    }
101
102    // Extract API key from authorization header
103    let api_key = extract_api_key_from_headers(req.headers())?;
104
105    // Validate the token and get its claims
106    let claims = validate_token(&api_key, &state)?;
107
108    // If token has wildcard namespace access, we can skip further namespace validation
109    if claims.namespace == "*" {
110        return Ok(next.run(req).await);
111    }
112
113    // For namespace-specific tokens, we need to ensure the token has access to the requested namespace
114    // We need to read the request body to extract the namespace
115    let (parts, body) = req.into_parts();
116
117    // Use axum's to_bytes to buffer the body
118    let bytes = to_bytes(body, usize::MAX)
119        .await
120        .map_err(|e| ServerError::InternalError(format!("Failed to read request body: {}", e)))?;
121
122    // Parse the JSON-RPC request and extract the namespace
123    let namespace_from_request = extract_namespace_from_json_rpc(&bytes)?;
124
125    // Validate that the token has access to the requested namespace
126    if claims.namespace != namespace_from_request {
127        return Err(ServerError::AuthorizationError(
128            crate::error::AuthorizationError::AccessDenied(format!(
129                "Token does not have access to namespace '{}'",
130                namespace_from_request
131            )),
132        ));
133    }
134
135    // Reconstruct the request with the original body
136    let body = Body::from(bytes);
137    let req = Request::from_parts(parts, body);
138
139    // If everything is valid, continue with the request
140    Ok(next.run(req).await)
141}
142
143//--------------------------------------------------------------------------------------------------
144// Helper Functions
145//--------------------------------------------------------------------------------------------------
146
147/// Extract the namespace from a JSON-RPC request body
148fn extract_namespace_from_json_rpc(bytes: &[u8]) -> Result<String, ServerError> {
149    // Parse the request body as JSON
150    let json_value: Value = serde_json::from_slice(bytes).map_err(|e| {
151        ServerError::ValidationError(ValidationError::InvalidInput(format!(
152            "Invalid JSON-RPC request: {}",
153            e
154        )))
155    })?;
156
157    // Extract the method for logging purposes
158    let method = json_value
159        .get("method")
160        .and_then(Value::as_str)
161        .unwrap_or("unknown");
162
163    // Extract params object
164    let params = json_value.get("params").ok_or_else(|| {
165        ServerError::ValidationError(ValidationError::InvalidInput(
166            "Missing 'params' field in JSON-RPC request".to_string(),
167        ))
168    })?;
169
170    // Extract namespace from params for any method
171    params
172        .get("namespace")
173        .and_then(Value::as_str)
174        .map(String::from)
175        .ok_or_else(|| {
176            ServerError::ValidationError(ValidationError::InvalidInput(format!(
177                "Missing or invalid 'namespace' in params for method '{}'",
178                method
179            )))
180        })
181}
182
183/// Extract API key from request headers
184fn extract_api_key_from_headers(headers: &HeaderMap) -> Result<String, ServerError> {
185    // First check the Proxy-Authorization header
186    if let Some(auth_header) = headers.get(PROXY_AUTH_HEADER) {
187        let auth_value = auth_header.to_str().map_err(|_| {
188            ServerError::Authentication(AuthenticationError::InvalidCredentials(
189                "Invalid authorization header format".to_string(),
190            ))
191        })?;
192
193        // Check if it has the Bearer prefix
194        if let Some(token) = auth_value.strip_prefix("Bearer ") {
195            return Ok(token.to_string());
196        }
197
198        // Or if it's just the raw token
199        return Ok(auth_value.to_string());
200    }
201
202    // Then check standard Authorization header
203    if let Some(auth_header) = headers.get("Authorization") {
204        let auth_value = auth_header.to_str().map_err(|_| {
205            ServerError::Authentication(AuthenticationError::InvalidCredentials(
206                "Invalid authorization header format".to_string(),
207            ))
208        })?;
209
210        // Check if it has the Bearer prefix
211        if let Some(token) = auth_value.strip_prefix("Bearer ") {
212            return Ok(token.to_string());
213        }
214
215        // Or if it's just the raw token
216        return Ok(auth_value.to_string());
217    }
218
219    Err(ServerError::Authentication(
220        AuthenticationError::InvalidCredentials("Missing authorization header".to_string()),
221    ))
222}
223
224/// Convert a custom API key back to a standard JWT format
225fn convert_api_key_to_jwt(api_key: &str) -> Result<String, ServerError> {
226    // Check if the API key has the expected prefix
227    if !api_key.starts_with(API_KEY_PREFIX) {
228        return Err(ServerError::Authentication(
229            AuthenticationError::InvalidCredentials(
230                "Invalid API key format: missing prefix".to_string(),
231            ),
232        ));
233    }
234
235    // Remove the prefix
236    let key_without_prefix = &api_key[API_KEY_PREFIX.len()..];
237
238    // Split into parts and validate
239    let parts: Vec<&str> = key_without_prefix.split('.').collect();
240    if parts.len() != 2 {
241        return Err(ServerError::Authentication(
242            AuthenticationError::InvalidCredentials("Invalid API key format".to_string()),
243        ));
244    }
245
246    // Reconstruct as standard JWT with header, payload, signature
247    // Fix the temporary value issue by storing the header in a variable
248    let header_value = crate::config::DEFAULT_JWT_HEADER.clone();
249    let jwt_header = header_value.as_str();
250    let payload = parts[0];
251    let signature = parts[1];
252
253    Ok(format!("{}.{}.{}", jwt_header, payload, signature))
254}
255
256/// Get the server key from the AppState config
257fn get_server_key(state: &AppState) -> Result<String, ServerError> {
258    // Get the key from the config - we already know we're not in dev mode
259    // by the time this function is called
260    match state.get_config().get_key() {
261        Some(key) => Ok(key.clone()),
262        None => Err(ServerError::Authentication(
263            AuthenticationError::InvalidCredentials(
264                "Server key not found in configuration".to_string(),
265            ),
266        )),
267    }
268}
269
270/// Validate the token
271fn validate_token(api_key: &str, state: &AppState) -> Result<Claims, ServerError> {
272    // Convert API key back to JWT format
273    let jwt = convert_api_key_to_jwt(api_key)?;
274
275    // Get server key for validation
276    let server_key = get_server_key(state)?;
277
278    // Decode and validate the JWT
279    let token_data = decode::<Claims>(
280        &jwt,
281        &DecodingKey::from_secret(server_key.as_bytes()),
282        &Validation::new(Algorithm::HS256),
283    )
284    .map_err(|e| {
285        let error_message = match e.kind() {
286            jsonwebtoken::errors::ErrorKind::ExpiredSignature => "Token expired".to_string(),
287            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
288                "Invalid token signature".to_string()
289            }
290            _ => format!("Token validation error: {}", e),
291        };
292        ServerError::Authentication(AuthenticationError::InvalidToken(error_message))
293    })?;
294
295    Ok(token_data.claims)
296}
297
298/// Validate the token and check namespace access
299pub fn validate_token_and_namespace(
300    api_key: &str,
301    requested_namespace: &str,
302    state: &AppState,
303) -> Result<Claims, ServerError> {
304    // Validate token
305    let claims = validate_token(api_key, state)?;
306
307    // Check if the token's namespace matches the requested namespace
308    if claims.namespace != requested_namespace && claims.namespace != "*" {
309        return Err(ServerError::Authentication(
310            AuthenticationError::InvalidCredentials(format!(
311                "Token does not have access to namespace '{}'",
312                requested_namespace
313            )),
314        ));
315    }
316
317    Ok(claims)
318}