Skip to main content

forge_runtime/gateway/
multipart.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::{Extension, Multipart, Path, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use bytes::BytesMut;
8
9use forge_core::function::AuthContext;
10use forge_core::types::Upload;
11
12use super::rpc::RpcHandler;
13
/// Maximum number of upload (non-`_json`) fields accepted in a single request.
const MAX_UPLOAD_FIELDS: usize = 20;
/// Maximum length of a multipart field name, compared against `str::len` (bytes).
const MAX_FIELD_NAME_LENGTH: usize = 255;
/// Maximum size of the `_json` field: 1 MiB.
const MAX_JSON_FIELD_SIZE: usize = 1 << 20;
/// Per-file limit: 10 MiB. Total multipart body can be larger (default 20 MiB)
/// but no single file may exceed this.
const DEFAULT_MAX_FILE_SIZE: usize = 10 << 20;
/// Name of the multipart field that carries the JSON-encoded call arguments.
const JSON_FIELD_NAME: &str = "_json";
21
/// Configurable limits for multipart uploads, injected via Axum extension.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (for example in tests or when checking whether a reload changed anything).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultipartConfig {
    /// Upper bound, in bytes, on the total data read from one multipart
    /// request. Used when the target function does not declare its own
    /// upload size limit.
    pub max_body_size_bytes: usize,
}
27
28/// Create a multipart error response.
29fn multipart_error(
30    status: StatusCode,
31    code: &str,
32    message: impl Into<String>,
33) -> (StatusCode, axum::Json<serde_json::Value>) {
34    (
35        status,
36        axum::Json(serde_json::json!({
37            "success": false,
38            "error": {
39                "code": code,
40                "message": message.into()
41            }
42        })),
43    )
44}
45
46/// Handle multipart form data for RPC calls with file uploads.
47pub async fn rpc_multipart_handler(
48    State(handler): State<Arc<RpcHandler>>,
49    Extension(auth): Extension<AuthContext>,
50    Extension(mp_config): Extension<MultipartConfig>,
51    Path(function): Path<String>,
52    mut multipart: Multipart,
53) -> impl IntoResponse {
54    // Validate function name to prevent log injection and path traversal
55    if function.is_empty()
56        || function.len() > 256
57        || !function
58            .chars()
59            .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-')
60    {
61        return multipart_error(
62            StatusCode::BAD_REQUEST,
63            "INVALID_FUNCTION",
64            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
65        );
66    }
67
68    // Per-function limit takes priority, then global config.
69    let max_total = handler
70        .function_info(&function)
71        .and_then(|info| info.max_upload_size_bytes)
72        .unwrap_or(mp_config.max_body_size_bytes);
73    let max_file = max_total.min(DEFAULT_MAX_FILE_SIZE);
74
75    let mut json_args: Option<serde_json::Value> = None;
76    let mut uploads: HashMap<String, Upload> = HashMap::new();
77    let mut total_read: usize = 0;
78
79    // Parse multipart fields
80    loop {
81        let field = match multipart.next_field().await {
82            Ok(Some(f)) => f,
83            Ok(None) => break,
84            Err(e) => {
85                return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
86            }
87        };
88
89        let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
90            Some(n) => n,
91            None => {
92                return multipart_error(
93                    StatusCode::BAD_REQUEST,
94                    "INVALID_FIELD",
95                    "Field name is required",
96                );
97            }
98        };
99
100        // Validate field name length
101        if name.len() > MAX_FIELD_NAME_LENGTH {
102            return multipart_error(
103                StatusCode::BAD_REQUEST,
104                "INVALID_FIELD",
105                format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
106            );
107        }
108
109        if name.contains("..")
110            || name.contains('/')
111            || name.contains('\\')
112            || name.contains(|c: char| c.is_control())
113        {
114            return multipart_error(
115                StatusCode::BAD_REQUEST,
116                "INVALID_FIELD",
117                "Field name contains invalid characters",
118            );
119        }
120
121        // Check upload count before processing to prevent bypass via _json field ordering
122        if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
123            return multipart_error(
124                StatusCode::BAD_REQUEST,
125                "TOO_MANY_FIELDS",
126                format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
127            );
128        }
129
130        if name == JSON_FIELD_NAME {
131            let mut buffer = BytesMut::new();
132            let mut json_field = field;
133
134            loop {
135                match json_field.chunk().await {
136                    Ok(Some(chunk)) => {
137                        if total_read + chunk.len() > max_total {
138                            return multipart_error(
139                                StatusCode::PAYLOAD_TOO_LARGE,
140                                "PAYLOAD_TOO_LARGE",
141                                format!(
142                                    "Multipart payload exceeds maximum size of {} bytes",
143                                    max_total
144                                ),
145                            );
146                        }
147                        if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
148                            return multipart_error(
149                                StatusCode::PAYLOAD_TOO_LARGE,
150                                "JSON_TOO_LARGE",
151                                format!(
152                                    "_json field exceeds maximum size of {} bytes",
153                                    MAX_JSON_FIELD_SIZE
154                                ),
155                            );
156                        }
157                        total_read += chunk.len();
158                        buffer.extend_from_slice(&chunk);
159                    }
160                    Ok(None) => break,
161                    Err(e) => {
162                        return multipart_error(
163                            StatusCode::BAD_REQUEST,
164                            "READ_ERROR",
165                            format!("Failed to read _json field: {}", e),
166                        );
167                    }
168                }
169            }
170
171            let text = match std::str::from_utf8(&buffer) {
172                Ok(s) => s,
173                Err(_) => {
174                    return multipart_error(
175                        StatusCode::BAD_REQUEST,
176                        "INVALID_JSON",
177                        "Invalid UTF-8 in _json field",
178                    );
179                }
180            };
181
182            match serde_json::from_str(text) {
183                Ok(value) => json_args = Some(value),
184                Err(e) => {
185                    return multipart_error(
186                        StatusCode::BAD_REQUEST,
187                        "INVALID_JSON",
188                        format!("Invalid JSON in _json field: {}", e),
189                    );
190                }
191            }
192        } else {
193            let raw_filename = field
194                .file_name()
195                .map(String::from)
196                .unwrap_or_else(|| name.clone());
197            // Sanitize filename: strip path components to prevent path traversal
198            let filename = raw_filename
199                .rsplit(['/', '\\'])
200                .next()
201                .unwrap_or(&raw_filename)
202                .replace("..", "_")
203                .to_string();
204            if filename.is_empty() {
205                return multipart_error(
206                    StatusCode::BAD_REQUEST,
207                    "INVALID_FILENAME",
208                    "Filename is empty after sanitization",
209                );
210            }
211            let content_type = field
212                .content_type()
213                .map(String::from)
214                .unwrap_or_else(|| "application/octet-stream".to_string());
215
216            let mut buffer = BytesMut::new();
217            let mut field = field;
218
219            loop {
220                match field.chunk().await {
221                    Ok(Some(chunk)) => {
222                        if total_read + chunk.len() > max_total {
223                            return multipart_error(
224                                StatusCode::PAYLOAD_TOO_LARGE,
225                                "PAYLOAD_TOO_LARGE",
226                                format!(
227                                    "Multipart payload exceeds maximum size of {} bytes",
228                                    max_total
229                                ),
230                            );
231                        }
232                        if buffer.len() + chunk.len() > max_file {
233                            return multipart_error(
234                                StatusCode::PAYLOAD_TOO_LARGE,
235                                "FILE_TOO_LARGE",
236                                format!(
237                                    "File '{}' exceeds maximum size of {} bytes",
238                                    filename, max_file
239                                ),
240                            );
241                        }
242                        total_read += chunk.len();
243                        buffer.extend_from_slice(&chunk);
244                    }
245                    Ok(None) => break,
246                    Err(e) => {
247                        return multipart_error(
248                            StatusCode::BAD_REQUEST,
249                            "READ_ERROR",
250                            format!("Failed to read file field: {}", e),
251                        );
252                    }
253                }
254            }
255
256            let upload = Upload::new(filename, content_type, buffer.freeze());
257            uploads.insert(name, upload);
258        }
259    }
260
261    let mut args = json_args.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
262
263    if let serde_json::Value::Object(ref mut map) = args {
264        for (name, upload) in uploads {
265            // Prevent upload fields from overwriting JSON args (parameter tampering)
266            if map.contains_key(&name) {
267                return multipart_error(
268                    StatusCode::BAD_REQUEST,
269                    "DUPLICATE_FIELD",
270                    format!("Upload field '{}' conflicts with JSON argument", name),
271                );
272            }
273            match serde_json::to_value(&upload) {
274                Ok(value) => {
275                    map.insert(name, value);
276                }
277                Err(e) => {
278                    return multipart_error(
279                        StatusCode::INTERNAL_SERVER_ERROR,
280                        "SERIALIZE_ERROR",
281                        format!("Failed to serialize upload: {}", e),
282                    );
283                }
284            }
285        }
286    }
287
288    let request = super::request::RpcRequest::new(function, args);
289    let metadata = forge_core::function::RequestMetadata::new();
290
291    let response = handler.handle(request, auth, metadata).await;
292
293    // Use RpcResponse's IntoResponse to preserve correct HTTP status codes
294    let status = if response.success {
295        StatusCode::OK
296    } else {
297        response
298            .error
299            .as_ref()
300            .map(|e| e.status_code())
301            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
302    };
303    match serde_json::to_value(&response) {
304        Ok(value) => (status, axum::Json(value)),
305        Err(e) => multipart_error(
306            StatusCode::INTERNAL_SERVER_ERROR,
307            "SERIALIZE_ERROR",
308            format!("Failed to serialize response: {}", e),
309        ),
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// The JSON argument field name is part of the wire format; pin its value.
    #[test]
    fn test_json_field_name_constant() {
        assert_eq!("_json", JSON_FIELD_NAME);
    }
}