1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::{Extension, Multipart, Path, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use bytes::BytesMut;
8
9use forge_core::function::AuthContext;
10use forge_core::types::Upload;
11
12use super::rpc::RpcHandler;
13
/// Maximum number of file-upload fields (fields other than `_json`) per request.
const MAX_UPLOAD_FIELDS: usize = 20;
/// Maximum length of a multipart field name, in characters.
const MAX_FIELD_NAME_LENGTH: usize = 255;
/// Maximum size of the `_json` arguments field (1 MiB).
const MAX_JSON_FIELD_SIZE: usize = 1024 * 1024;
/// Per-file size cap (10 MiB); also bounded by the request-wide limit below.
const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
/// Reserved field name whose body carries the JSON arguments for the RPC call.
const JSON_FIELD_NAME: &str = "_json";
21
/// Server-level multipart settings, injected into the handler as an axum
/// `Extension` (see `rpc_multipart_handler`).
#[derive(Debug, Clone)]
pub struct MultipartConfig {
    /// Default cap on the combined size of all fields in one multipart body;
    /// a function may override it via its own `max_upload_size_bytes`.
    pub max_body_size_bytes: usize,
}
27
28fn multipart_error(
30 status: StatusCode,
31 code: &str,
32 message: impl Into<String>,
33) -> (StatusCode, axum::Json<serde_json::Value>) {
34 (
35 status,
36 axum::Json(serde_json::json!({
37 "success": false,
38 "error": {
39 "code": code,
40 "message": message.into()
41 }
42 })),
43 )
44}
45
46pub async fn rpc_multipart_handler(
48 State(handler): State<Arc<RpcHandler>>,
49 Extension(auth): Extension<AuthContext>,
50 Extension(mp_config): Extension<MultipartConfig>,
51 Path(function): Path<String>,
52 mut multipart: Multipart,
53) -> impl IntoResponse {
54 if function.is_empty()
56 || function.len() > 256
57 || !function
58 .chars()
59 .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-')
60 {
61 return multipart_error(
62 StatusCode::BAD_REQUEST,
63 "INVALID_FUNCTION",
64 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
65 );
66 }
67
68 let max_total = handler
70 .function_info(&function)
71 .and_then(|info| info.max_upload_size_bytes)
72 .unwrap_or(mp_config.max_body_size_bytes);
73 let max_file = max_total.min(DEFAULT_MAX_FILE_SIZE);
74
75 let mut json_args: Option<serde_json::Value> = None;
76 let mut uploads: HashMap<String, Upload> = HashMap::new();
77 let mut total_read: usize = 0;
78
79 loop {
81 let field = match multipart.next_field().await {
82 Ok(Some(f)) => f,
83 Ok(None) => break,
84 Err(e) => {
85 return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
86 }
87 };
88
89 let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
90 Some(n) => n,
91 None => {
92 return multipart_error(
93 StatusCode::BAD_REQUEST,
94 "INVALID_FIELD",
95 "Field name is required",
96 );
97 }
98 };
99
100 if name.len() > MAX_FIELD_NAME_LENGTH {
102 return multipart_error(
103 StatusCode::BAD_REQUEST,
104 "INVALID_FIELD",
105 format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
106 );
107 }
108
109 if name.contains("..")
110 || name.contains('/')
111 || name.contains('\\')
112 || name.contains(|c: char| c.is_control())
113 {
114 return multipart_error(
115 StatusCode::BAD_REQUEST,
116 "INVALID_FIELD",
117 "Field name contains invalid characters",
118 );
119 }
120
121 if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
123 return multipart_error(
124 StatusCode::BAD_REQUEST,
125 "TOO_MANY_FIELDS",
126 format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
127 );
128 }
129
130 if name == JSON_FIELD_NAME {
131 let mut buffer = BytesMut::new();
132 let mut json_field = field;
133
134 loop {
135 match json_field.chunk().await {
136 Ok(Some(chunk)) => {
137 if total_read + chunk.len() > max_total {
138 return multipart_error(
139 StatusCode::PAYLOAD_TOO_LARGE,
140 "PAYLOAD_TOO_LARGE",
141 format!(
142 "Multipart payload exceeds maximum size of {} bytes",
143 max_total
144 ),
145 );
146 }
147 if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
148 return multipart_error(
149 StatusCode::PAYLOAD_TOO_LARGE,
150 "JSON_TOO_LARGE",
151 format!(
152 "_json field exceeds maximum size of {} bytes",
153 MAX_JSON_FIELD_SIZE
154 ),
155 );
156 }
157 total_read += chunk.len();
158 buffer.extend_from_slice(&chunk);
159 }
160 Ok(None) => break,
161 Err(e) => {
162 return multipart_error(
163 StatusCode::BAD_REQUEST,
164 "READ_ERROR",
165 format!("Failed to read _json field: {}", e),
166 );
167 }
168 }
169 }
170
171 let text = match std::str::from_utf8(&buffer) {
172 Ok(s) => s,
173 Err(_) => {
174 return multipart_error(
175 StatusCode::BAD_REQUEST,
176 "INVALID_JSON",
177 "Invalid UTF-8 in _json field",
178 );
179 }
180 };
181
182 match serde_json::from_str(text) {
183 Ok(value) => json_args = Some(value),
184 Err(e) => {
185 return multipart_error(
186 StatusCode::BAD_REQUEST,
187 "INVALID_JSON",
188 format!("Invalid JSON in _json field: {}", e),
189 );
190 }
191 }
192 } else {
193 let raw_filename = field
194 .file_name()
195 .map(String::from)
196 .unwrap_or_else(|| name.clone());
197 let filename = raw_filename
199 .rsplit(['/', '\\'])
200 .next()
201 .unwrap_or(&raw_filename)
202 .replace("..", "_")
203 .to_string();
204 if filename.is_empty() {
205 return multipart_error(
206 StatusCode::BAD_REQUEST,
207 "INVALID_FILENAME",
208 "Filename is empty after sanitization",
209 );
210 }
211 let content_type = field
212 .content_type()
213 .map(String::from)
214 .unwrap_or_else(|| "application/octet-stream".to_string());
215
216 let mut buffer = BytesMut::new();
217 let mut field = field;
218
219 loop {
220 match field.chunk().await {
221 Ok(Some(chunk)) => {
222 if total_read + chunk.len() > max_total {
223 return multipart_error(
224 StatusCode::PAYLOAD_TOO_LARGE,
225 "PAYLOAD_TOO_LARGE",
226 format!(
227 "Multipart payload exceeds maximum size of {} bytes",
228 max_total
229 ),
230 );
231 }
232 if buffer.len() + chunk.len() > max_file {
233 return multipart_error(
234 StatusCode::PAYLOAD_TOO_LARGE,
235 "FILE_TOO_LARGE",
236 format!(
237 "File '{}' exceeds maximum size of {} bytes",
238 filename, max_file
239 ),
240 );
241 }
242 total_read += chunk.len();
243 buffer.extend_from_slice(&chunk);
244 }
245 Ok(None) => break,
246 Err(e) => {
247 return multipart_error(
248 StatusCode::BAD_REQUEST,
249 "READ_ERROR",
250 format!("Failed to read file field: {}", e),
251 );
252 }
253 }
254 }
255
256 let upload = Upload::new(filename, content_type, buffer.freeze());
257 uploads.insert(name, upload);
258 }
259 }
260
261 let mut args = json_args.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
262
263 if let serde_json::Value::Object(ref mut map) = args {
264 for (name, upload) in uploads {
265 if map.contains_key(&name) {
267 return multipart_error(
268 StatusCode::BAD_REQUEST,
269 "DUPLICATE_FIELD",
270 format!("Upload field '{}' conflicts with JSON argument", name),
271 );
272 }
273 match serde_json::to_value(&upload) {
274 Ok(value) => {
275 map.insert(name, value);
276 }
277 Err(e) => {
278 return multipart_error(
279 StatusCode::INTERNAL_SERVER_ERROR,
280 "SERIALIZE_ERROR",
281 format!("Failed to serialize upload: {}", e),
282 );
283 }
284 }
285 }
286 }
287
288 let request = super::request::RpcRequest::new(function, args);
289 let metadata = forge_core::function::RequestMetadata::new();
290
291 let response = handler.handle(request, auth, metadata).await;
292
293 let status = if response.success {
295 StatusCode::OK
296 } else {
297 response
298 .error
299 .as_ref()
300 .map(|e| e.status_code())
301 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
302 };
303 match serde_json::to_value(&response) {
304 Ok(value) => (status, axum::Json(value)),
305 Err(e) => multipart_error(
306 StatusCode::INTERNAL_SERVER_ERROR,
307 "SERIALIZE_ERROR",
308 format!("Failed to serialize response: {}", e),
309 ),
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against accidental renames of the reserved JSON arguments field.
    #[test]
    fn test_json_field_name_constant() {
        let expected = "_json";
        assert_eq!(JSON_FIELD_NAME, expected);
    }
}
321}