Skip to main content

api_gateway/middleware/
mime_validation.rs

1//! MIME type validation middleware for enforcing per-operation allowed Content-Type headers
2use axum::extract::Request;
3use axum::http::StatusCode;
4use axum::middleware::Next;
5use axum::response::{IntoResponse, Response};
6use dashmap::DashMap;
7use http::Method;
8use std::sync::Arc;
9
10use modkit::api::{OperationSpec, Problem};
11
12use crate::middleware::common;
13
14/// Map from (method, path) to allowed content types
15pub type MimeValidationMap = Arc<DashMap<(Method, String), Vec<&'static str>>>;
16
17/// Build MIME validation map from operation specs
18#[must_use]
19pub fn build_mime_validation_map(specs: &[OperationSpec]) -> MimeValidationMap {
20    let map = DashMap::new();
21
22    for spec in specs {
23        if let Some(ref allowed) = spec.allowed_request_content_types {
24            let key = (spec.method.clone(), spec.path.clone());
25
26            map.insert(key, allowed.clone());
27        }
28    }
29
30    Arc::new(map)
31}
32
33/// Extract and normalize the Content-Type header value.
34///
35/// Strips parameters like charset from "application/json; charset=utf-8"
36/// to just "application/json".
37fn extract_content_type(req: &Request) -> Option<String> {
38    let ct_header = req.headers().get(http::header::CONTENT_TYPE)?;
39    let ct_str = ct_header.to_str().ok()?;
40    let ct_main = ct_str.split(';').next().map_or(ct_str, str::trim);
41    Some(ct_main.to_owned())
42}
43
44/// Create an Unsupported Media Type error response.
45fn create_unsupported_media_type_error(detail: String) -> Response {
46    Problem::new(
47        StatusCode::UNSUPPORTED_MEDIA_TYPE,
48        "Unsupported Media Type",
49        detail,
50    )
51    .into_response()
52}
53
54/// Validate that the content type is in the allowed list.
55///
56/// Returns Ok(()) if allowed, Err(Response) with error details if not.
57fn validate_content_type(
58    content_type: &str,
59    allowed_types: &[&str],
60    method: &Method,
61    path: &str,
62) -> Result<(), Box<Response>> {
63    if allowed_types.contains(&content_type) {
64        return Ok(());
65    }
66
67    tracing::warn!(
68        method = %method,
69        path = %path,
70        content_type = content_type,
71        allowed_types = ?allowed_types,
72        "MIME type not allowed for this endpoint"
73    );
74
75    let detail = format!(
76        "Content-Type '{}' is not allowed for this endpoint. Allowed types: {}",
77        content_type,
78        allowed_types.join(", ")
79    );
80
81    Err(Box::new(create_unsupported_media_type_error(detail)))
82}
83
84/// MIME validation middleware
85///
86/// Checks the Content-Type header against the allowed types configured
87/// for the operation. Returns 415 Unsupported Media Type if the content
88/// type is not allowed.
89pub async fn mime_validation_middleware(
90    validation_map: MimeValidationMap,
91    req: Request,
92    next: Next,
93) -> Response {
94    let method = req.method().clone();
95    // Use MatchedPath extension (set by Axum router) for accurate route matching
96    let path = req
97        .extensions()
98        .get::<axum::extract::MatchedPath>()
99        .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
100
101    let path = common::resolve_path(&req, path.as_str());
102
103    // Check if this operation has MIME validation configured
104    let Some(allowed_types) = validation_map.get(&(method.clone(), path.clone())) else {
105        // No validation configured - proceed
106        return next.run(req).await;
107    };
108
109    // Extract and validate Content-Type header
110    let Some(content_type) = extract_content_type(&req) else {
111        tracing::warn!(
112            method = %method,
113            path = %path,
114            allowed_types = ?allowed_types.value(),
115            "Missing Content-Type header for endpoint with MIME validation"
116        );
117
118        let detail = format!(
119            "Missing Content-Type header. Allowed types: {}",
120            allowed_types.join(", ")
121        );
122        return create_unsupported_media_type_error(detail);
123    };
124
125    // Validate the content type
126    if let Err(error_response) =
127        validate_content_type(&content_type, &allowed_types, &method, &path)
128    {
129        return *error_response;
130    }
131
132    // Validation passed - proceed
133    next.run(req).await
134}
135
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use modkit::api::operation_builder::VendorExtensions;

    /// Build a POST request, optionally carrying the given `Content-Type` value.
    fn request_with_content_type(value: Option<&str>) -> Request {
        let mut builder = http::Request::builder()
            .method(Method::POST)
            .uri("/files/v1/upload");
        if let Some(value) = value {
            builder = builder.header(http::header::CONTENT_TYPE, value);
        }
        builder
            .body(axum::body::Body::empty())
            .expect("static request parts are valid")
    }

    #[test]
    fn test_build_mime_validation_map() {
        use modkit::api::operation_builder::{RequestBodySchema, RequestBodySpec};

        let specs = vec![OperationSpec {
            method: Method::POST,
            path: "/files/v1/upload".to_owned(),
            operation_id: None,
            summary: None,
            description: None,
            tags: vec![],
            params: vec![],
            request_body: Some(RequestBodySpec {
                content_type: "multipart/form-data",
                description: None,
                schema: RequestBodySchema::MultipartFile {
                    field_name: "file".to_owned(),
                },
                required: true,
            }),
            responses: vec![],
            handler_id: "test".to_owned(),
            authenticated: false,
            is_public: false,
            license_requirement: None,
            rate_limit: None,
            allowed_request_content_types: Some(vec!["multipart/form-data", "application/pdf"]),
            vendor_extensions: VendorExtensions::default(),
        }];

        let map = build_mime_validation_map(&specs);

        assert!(map.contains_key(&(Method::POST, "/files/v1/upload".to_owned())));
        let allowed = map
            .get(&(Method::POST, "/files/v1/upload".to_owned()))
            .unwrap();
        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains(&"multipart/form-data"));
        assert!(allowed.contains(&"application/pdf"));
    }

    #[test]
    fn test_content_type_parameter_stripping() {
        // Exercise the real extractor instead of re-implementing its parsing here,
        // so a regression in `extract_content_type` actually fails this test.
        let cases = [
            ("application/json; charset=utf-8", "application/json"),
            (
                "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
                "multipart/form-data",
            ),
            ("application/pdf", "application/pdf"),
        ];

        for (header, expected) in cases {
            let req = request_with_content_type(Some(header));
            assert_eq!(
                extract_content_type(&req).as_deref(),
                Some(expected),
                "header: {header}"
            );
        }
    }

    #[test]
    fn test_missing_content_type_header() {
        let req = request_with_content_type(None);
        assert_eq!(extract_content_type(&req), None);
    }

    #[test]
    fn test_validate_content_type_allows_listed_type() {
        let allowed = ["multipart/form-data", "application/pdf"];
        let outcome =
            validate_content_type("application/pdf", &allowed, &Method::POST, "/files/v1/upload");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_content_type_rejects_unlisted_type() {
        let allowed = ["multipart/form-data", "application/pdf"];
        let rejection =
            validate_content_type("text/plain", &allowed, &Method::POST, "/files/v1/upload")
                .expect_err("text/plain is not in the allow-list");
        assert_eq!(rejection.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }
}