// api_gateway/middleware/mime_validation.rs
use std::sync::Arc;

use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use dashmap::DashMap;
use http::Method;
use modkit::api::{OperationSpec, Problem};

use crate::middleware::common;
14pub type MimeValidationMap = Arc<DashMap<(Method, String), Vec<&'static str>>>;
16
17#[must_use]
19pub fn build_mime_validation_map(specs: &[OperationSpec]) -> MimeValidationMap {
20 let map = DashMap::new();
21
22 for spec in specs {
23 if let Some(ref allowed) = spec.allowed_request_content_types {
24 let key = (spec.method.clone(), spec.path.clone());
25
26 map.insert(key, allowed.clone());
27 }
28 }
29
30 Arc::new(map)
31}
32
33fn extract_content_type(req: &Request) -> Option<String> {
38 let ct_header = req.headers().get(http::header::CONTENT_TYPE)?;
39 let ct_str = ct_header.to_str().ok()?;
40 let ct_main = ct_str.split(';').next().map_or(ct_str, str::trim);
41 Some(ct_main.to_owned())
42}
43
44fn create_unsupported_media_type_error(detail: String) -> Response {
46 Problem::new(
47 StatusCode::UNSUPPORTED_MEDIA_TYPE,
48 "Unsupported Media Type",
49 detail,
50 )
51 .into_response()
52}
53
54fn validate_content_type(
58 content_type: &str,
59 allowed_types: &[&str],
60 method: &Method,
61 path: &str,
62) -> Result<(), Box<Response>> {
63 if allowed_types.contains(&content_type) {
64 return Ok(());
65 }
66
67 tracing::warn!(
68 method = %method,
69 path = %path,
70 content_type = content_type,
71 allowed_types = ?allowed_types,
72 "MIME type not allowed for this endpoint"
73 );
74
75 let detail = format!(
76 "Content-Type '{}' is not allowed for this endpoint. Allowed types: {}",
77 content_type,
78 allowed_types.join(", ")
79 );
80
81 Err(Box::new(create_unsupported_media_type_error(detail)))
82}
83
84pub async fn mime_validation_middleware(
90 validation_map: MimeValidationMap,
91 req: Request,
92 next: Next,
93) -> Response {
94 let method = req.method().clone();
95 let path = req
97 .extensions()
98 .get::<axum::extract::MatchedPath>()
99 .map_or_else(|| req.uri().path().to_owned(), |p| p.as_str().to_owned());
100
101 let path = common::resolve_path(&req, path.as_str());
102
103 let Some(allowed_types) = validation_map.get(&(method.clone(), path.clone())) else {
105 return next.run(req).await;
107 };
108
109 let Some(content_type) = extract_content_type(&req) else {
111 tracing::warn!(
112 method = %method,
113 path = %path,
114 allowed_types = ?allowed_types.value(),
115 "Missing Content-Type header for endpoint with MIME validation"
116 );
117
118 let detail = format!(
119 "Missing Content-Type header. Allowed types: {}",
120 allowed_types.join(", ")
121 );
122 return create_unsupported_media_type_error(detail);
123 };
124
125 if let Err(error_response) =
127 validate_content_type(&content_type, &allowed_types, &method, &path)
128 {
129 return *error_response;
130 }
131
132 next.run(req).await
134}
135
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use axum::body::Body;
    use modkit::api::operation_builder::VendorExtensions;

    /// Builds a minimal POST request, optionally carrying a `Content-Type` header.
    fn request_with_content_type(content_type: Option<&str>) -> Request {
        let mut builder = http::Request::builder()
            .method(Method::POST)
            .uri("/files/v1/upload");
        if let Some(ct) = content_type {
            builder = builder.header(http::header::CONTENT_TYPE, ct);
        }
        builder
            .body(Body::empty())
            .expect("static request parts are valid")
    }

    #[test]
    fn test_build_mime_validation_map() {
        use modkit::api::operation_builder::{RequestBodySchema, RequestBodySpec};

        let specs = vec![OperationSpec {
            method: Method::POST,
            path: "/files/v1/upload".to_owned(),
            operation_id: None,
            summary: None,
            description: None,
            tags: vec![],
            params: vec![],
            request_body: Some(RequestBodySpec {
                content_type: "multipart/form-data",
                description: None,
                schema: RequestBodySchema::MultipartFile {
                    field_name: "file".to_owned(),
                },
                required: true,
            }),
            responses: vec![],
            handler_id: "test".to_owned(),
            authenticated: false,
            is_public: false,
            license_requirement: None,
            rate_limit: None,
            allowed_request_content_types: Some(vec!["multipart/form-data", "application/pdf"]),
            vendor_extensions: VendorExtensions::default(),
        }];

        let map = build_mime_validation_map(&specs);

        assert!(map.contains_key(&(Method::POST, "/files/v1/upload".to_owned())));
        let allowed = map
            .get(&(Method::POST, "/files/v1/upload".to_owned()))
            .unwrap();
        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains(&"multipart/form-data"));
        assert!(allowed.contains(&"application/pdf"));
    }

    #[test]
    fn test_content_type_parameter_stripping() {
        // Exercise the real extractor instead of re-implementing its parsing here.
        let req = request_with_content_type(Some("application/json; charset=utf-8"));
        assert_eq!(
            extract_content_type(&req).as_deref(),
            Some("application/json")
        );

        let req = request_with_content_type(Some(
            "multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW",
        ));
        assert_eq!(
            extract_content_type(&req).as_deref(),
            Some("multipart/form-data")
        );

        let req = request_with_content_type(Some("application/pdf"));
        assert_eq!(
            extract_content_type(&req).as_deref(),
            Some("application/pdf")
        );
    }

    #[test]
    fn test_missing_content_type_header() {
        let req = request_with_content_type(None);
        assert!(extract_content_type(&req).is_none());
    }

    #[test]
    fn test_validate_content_type_accepts_listed_type() {
        let allowed = ["multipart/form-data", "application/pdf"];
        assert!(
            validate_content_type("application/pdf", &allowed, &Method::POST, "/files/v1/upload")
                .is_ok()
        );
    }

    #[test]
    fn test_validate_content_type_rejects_unlisted_type() {
        let allowed = ["multipart/form-data", "application/pdf"];
        let err = validate_content_type("text/plain", &allowed, &Method::POST, "/files/v1/upload")
            .expect_err("text/plain is not in the allow-list");
        assert_eq!(err.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
    }
}