Skip to main content

forge_runtime/gateway/
multipart.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::{Extension, Multipart, Path, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use bytes::BytesMut;
8
9use forge_core::function::AuthContext;
10use forge_core::types::Upload;
11
12use super::rpc::RpcHandler;
13
/// Largest accepted size of a single uploaded file, in bytes (10 MiB).
const MAX_FILE_SIZE: usize = 10 << 20;
/// Largest accepted combined size of all fields in one multipart request, in bytes (20 MiB).
const MAX_TOTAL_UPLOAD_SIZE: usize = 20 << 20;
/// Maximum number of distinct file (non-`_json`) fields per request.
const MAX_UPLOAD_FIELDS: usize = 20;
/// Maximum length of a multipart field name, measured with `str::len` (bytes).
const MAX_FIELD_NAME_LENGTH: usize = 255;
/// Largest accepted size of the `_json` arguments field, in bytes (1 MiB).
const MAX_JSON_FIELD_SIZE: usize = 1 << 20;
/// Name of the multipart field that carries the JSON-encoded RPC arguments.
const JSON_FIELD_NAME: &str = "_json";
20
21/// Create a multipart error response.
22fn multipart_error(
23    status: StatusCode,
24    code: &str,
25    message: impl Into<String>,
26) -> (StatusCode, axum::Json<serde_json::Value>) {
27    (
28        status,
29        axum::Json(serde_json::json!({
30            "success": false,
31            "error": {
32                "code": code,
33                "message": message.into()
34            }
35        })),
36    )
37}
38
39/// Handle multipart form data for RPC calls with file uploads.
40pub async fn rpc_multipart_handler(
41    State(handler): State<Arc<RpcHandler>>,
42    Extension(auth): Extension<AuthContext>,
43    Path(function): Path<String>,
44    mut multipart: Multipart,
45) -> impl IntoResponse {
46    let mut json_args: Option<serde_json::Value> = None;
47    let mut uploads: HashMap<String, Upload> = HashMap::new();
48    let mut total_read: usize = 0;
49
50    // Parse multipart fields
51    loop {
52        let field = match multipart.next_field().await {
53            Ok(Some(f)) => f,
54            Ok(None) => break,
55            Err(e) => {
56                return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
57            }
58        };
59
60        let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
61            Some(n) => n,
62            None => {
63                return multipart_error(
64                    StatusCode::BAD_REQUEST,
65                    "INVALID_FIELD",
66                    "Field name is required",
67                );
68            }
69        };
70
71        // Validate field name length
72        if name.len() > MAX_FIELD_NAME_LENGTH {
73            return multipart_error(
74                StatusCode::BAD_REQUEST,
75                "INVALID_FIELD",
76                format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
77            );
78        }
79
80        if name.contains("..") || name.contains('/') || name.contains('\\') {
81            return multipart_error(
82                StatusCode::BAD_REQUEST,
83                "INVALID_FIELD",
84                "Field name contains invalid characters",
85            );
86        }
87
88        // Check upload count before processing to prevent bypass via _json field ordering
89        if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
90            return multipart_error(
91                StatusCode::BAD_REQUEST,
92                "TOO_MANY_FIELDS",
93                format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
94            );
95        }
96
97        if name == JSON_FIELD_NAME {
98            let mut buffer = BytesMut::new();
99            let mut json_field = field;
100
101            loop {
102                match json_field.chunk().await {
103                    Ok(Some(chunk)) => {
104                        if total_read + chunk.len() > MAX_TOTAL_UPLOAD_SIZE {
105                            return multipart_error(
106                                StatusCode::PAYLOAD_TOO_LARGE,
107                                "PAYLOAD_TOO_LARGE",
108                                format!(
109                                    "Multipart payload exceeds maximum size of {} bytes",
110                                    MAX_TOTAL_UPLOAD_SIZE
111                                ),
112                            );
113                        }
114                        if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
115                            return multipart_error(
116                                StatusCode::PAYLOAD_TOO_LARGE,
117                                "JSON_TOO_LARGE",
118                                format!(
119                                    "_json field exceeds maximum size of {} bytes",
120                                    MAX_JSON_FIELD_SIZE
121                                ),
122                            );
123                        }
124                        total_read += chunk.len();
125                        buffer.extend_from_slice(&chunk);
126                    }
127                    Ok(None) => break,
128                    Err(e) => {
129                        return multipart_error(
130                            StatusCode::BAD_REQUEST,
131                            "READ_ERROR",
132                            format!("Failed to read _json field: {}", e),
133                        );
134                    }
135                }
136            }
137
138            let text = match std::str::from_utf8(&buffer) {
139                Ok(s) => s,
140                Err(_) => {
141                    return multipart_error(
142                        StatusCode::BAD_REQUEST,
143                        "INVALID_JSON",
144                        "Invalid UTF-8 in _json field",
145                    );
146                }
147            };
148
149            match serde_json::from_str(text) {
150                Ok(value) => json_args = Some(value),
151                Err(e) => {
152                    return multipart_error(
153                        StatusCode::BAD_REQUEST,
154                        "INVALID_JSON",
155                        format!("Invalid JSON in _json field: {}", e),
156                    );
157                }
158            }
159        } else {
160            let filename = field
161                .file_name()
162                .map(String::from)
163                .unwrap_or_else(|| name.clone());
164            let content_type = field
165                .content_type()
166                .map(String::from)
167                .unwrap_or_else(|| "application/octet-stream".to_string());
168
169            let mut buffer = BytesMut::new();
170            let mut field = field;
171
172            loop {
173                match field.chunk().await {
174                    Ok(Some(chunk)) => {
175                        if total_read + chunk.len() > MAX_TOTAL_UPLOAD_SIZE {
176                            return multipart_error(
177                                StatusCode::PAYLOAD_TOO_LARGE,
178                                "PAYLOAD_TOO_LARGE",
179                                format!(
180                                    "Multipart payload exceeds maximum size of {} bytes",
181                                    MAX_TOTAL_UPLOAD_SIZE
182                                ),
183                            );
184                        }
185                        if buffer.len() + chunk.len() > MAX_FILE_SIZE {
186                            return multipart_error(
187                                StatusCode::PAYLOAD_TOO_LARGE,
188                                "FILE_TOO_LARGE",
189                                format!(
190                                    "File '{}' exceeds maximum size of {} bytes",
191                                    filename, MAX_FILE_SIZE
192                                ),
193                            );
194                        }
195                        total_read += chunk.len();
196                        buffer.extend_from_slice(&chunk);
197                    }
198                    Ok(None) => break,
199                    Err(e) => {
200                        return multipart_error(
201                            StatusCode::BAD_REQUEST,
202                            "READ_ERROR",
203                            format!("Failed to read file field: {}", e),
204                        );
205                    }
206                }
207            }
208
209            let upload = Upload::new(filename, content_type, buffer.freeze());
210            uploads.insert(name, upload);
211        }
212    }
213
214    let mut args = json_args.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
215
216    if let serde_json::Value::Object(ref mut map) = args {
217        for (name, upload) in uploads {
218            match serde_json::to_value(&upload) {
219                Ok(value) => {
220                    map.insert(name, value);
221                }
222                Err(e) => {
223                    return multipart_error(
224                        StatusCode::INTERNAL_SERVER_ERROR,
225                        "SERIALIZE_ERROR",
226                        format!("Failed to serialize upload: {}", e),
227                    );
228                }
229            }
230        }
231    }
232
233    let request = super::request::RpcRequest::new(function, args);
234    let metadata = forge_core::function::RequestMetadata::new();
235
236    let response = handler.handle(request, auth, metadata).await;
237
238    match serde_json::to_value(&response) {
239        Ok(value) => (StatusCode::OK, axum::Json(value)),
240        Err(e) => multipart_error(
241            StatusCode::INTERNAL_SERVER_ERROR,
242            "SERIALIZE_ERROR",
243            format!("Failed to serialize response: {}", e),
244        ),
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_field_name_constant() {
        assert_eq!(JSON_FIELD_NAME, "_json");
    }

    /// Clients rely on the `{"success": false, "error": {code, message}}`
    /// envelope, so pin its shape and the status passthrough.
    #[test]
    fn test_multipart_error_envelope() {
        let (status, axum::Json(body)) =
            multipart_error(StatusCode::PAYLOAD_TOO_LARGE, "FILE_TOO_LARGE", "too big");
        assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(body["success"], false);
        assert_eq!(body["error"]["code"], "FILE_TOO_LARGE");
        assert_eq!(body["error"]["message"], "too big");
    }
}