use std::collections::HashMap;
use std::sync::Arc;

use axum::extract::{Extension, Multipart, Path, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use bytes::BytesMut;

use forge_core::function::AuthContext;
use forge_core::types::Upload;

use super::rpc::RpcHandler;

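// Limits enforced while streaming the multipart body: individual files, the
// `_json` arguments field, and the payload as a whole are each bounded.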
const MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
const MAX_TOTAL_UPLOAD_SIZE: usize = 20 * 1024 * 1024;
const MAX_UPLOAD_FIELDS: usize = 20;
const MAX_FIELD_NAME_LENGTH: usize = 255;
const MAX_JSON_FIELD_SIZE: usize = 1024 * 1024;
const JSON_FIELD_NAME: &str = "_json";

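/// Builds a JSON error envelope
/// (`{ "success": false, "error": { "code", "message" } }`)
/// paired with the given status code.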
fn multipart_error(
    status: StatusCode,
    code: &str,
    message: impl Into<String>,
) -> (StatusCode, axum::Json<serde_json::Value>) {
    (
        status,
        axum::Json(serde_json::json!({
            "success": false,
            "error": {
                "code": code,
                "message": message.into()
            }
        })),
    )
}

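/// Axum handler for multipart RPC calls.
///
/// The reserved `_json` field carries the JSON-encoded function arguments;
/// every other field is buffered as an `Upload` and merged into the
/// arguments object under its field name before the call is dispatched.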
pub async fn rpc_multipart_handler(
    State(handler): State<Arc<RpcHandler>>,
    Extension(auth): Extension<AuthContext>,
    Path(function): Path<String>,
    mut multipart: Multipart,
) -> impl IntoResponse {
    let mut json_args: Option<serde_json::Value> = None;
    let mut uploads: HashMap<String, Upload> = HashMap::new();
    let mut total_read: usize = 0;

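    // Stream the fields one at a time, tracking the cumulative bytes read so
    // the total budget is enforced across fields, not just per field.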
    loop {
        let field = match multipart.next_field().await {
            Ok(Some(f)) => f,
            Ok(None) => break,
            Err(e) => {
                return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
            }
        };

        let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
            Some(n) => n,
            None => {
                return multipart_error(
                    StatusCode::BAD_REQUEST,
                    "INVALID_FIELD",
                    "Field name is required",
                );
            }
        };

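        // Validate the field name: it becomes a key in the arguments object,
        // so oversized or path-like names are rejected outright.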
        if name.len() > MAX_FIELD_NAME_LENGTH {
            return multipart_error(
                StatusCode::BAD_REQUEST,
                "INVALID_FIELD",
                format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
            );
        }

        if name.contains("..") || name.contains('/') || name.contains('\\') {
            return multipart_error(
                StatusCode::BAD_REQUEST,
                "INVALID_FIELD",
                "Field name contains invalid characters",
            );
        }

        if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
            return multipart_error(
                StatusCode::BAD_REQUEST,
                "TOO_MANY_FIELDS",
                format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
            );
        }

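        // The reserved `_json` field holds the call arguments: buffer it
        // within its own size limit, then parse it as UTF-8 JSON.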
        if name == JSON_FIELD_NAME {
            let mut buffer = BytesMut::new();
            let mut json_field = field;

            loop {
                match json_field.chunk().await {
                    Ok(Some(chunk)) => {
                        if total_read + chunk.len() > MAX_TOTAL_UPLOAD_SIZE {
                            return multipart_error(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "PAYLOAD_TOO_LARGE",
                                format!(
                                    "Multipart payload exceeds maximum size of {} bytes",
                                    MAX_TOTAL_UPLOAD_SIZE
                                ),
                            );
                        }
                        if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
                            return multipart_error(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "JSON_TOO_LARGE",
                                format!(
                                    "_json field exceeds maximum size of {} bytes",
                                    MAX_JSON_FIELD_SIZE
                                ),
                            );
                        }
                        total_read += chunk.len();
                        buffer.extend_from_slice(&chunk);
                    }
                    Ok(None) => break,
                    Err(e) => {
                        return multipart_error(
                            StatusCode::BAD_REQUEST,
                            "READ_ERROR",
                            format!("Failed to read _json field: {}", e),
                        );
                    }
                }
            }

            let text = match std::str::from_utf8(&buffer) {
                Ok(s) => s,
                Err(_) => {
                    return multipart_error(
                        StatusCode::BAD_REQUEST,
                        "INVALID_JSON",
                        "Invalid UTF-8 in _json field",
                    );
                }
            };

            match serde_json::from_str(text) {
                Ok(value) => json_args = Some(value),
                Err(e) => {
                    return multipart_error(
                        StatusCode::BAD_REQUEST,
                        "INVALID_JSON",
                        format!("Invalid JSON in _json field: {}", e),
                    );
                }
            }
        } else {
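            // Every other field is a file upload: buffer it within the
            // per-file limit, defaulting the filename to the field name and
            // the content type to `application/octet-stream`.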
            let filename = field
                .file_name()
                .map(String::from)
                .unwrap_or_else(|| name.clone());
            let content_type = field
                .content_type()
                .map(String::from)
                .unwrap_or_else(|| "application/octet-stream".to_string());

            let mut buffer = BytesMut::new();
            let mut field = field;

            loop {
                match field.chunk().await {
                    Ok(Some(chunk)) => {
                        if total_read + chunk.len() > MAX_TOTAL_UPLOAD_SIZE {
                            return multipart_error(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "PAYLOAD_TOO_LARGE",
                                format!(
                                    "Multipart payload exceeds maximum size of {} bytes",
                                    MAX_TOTAL_UPLOAD_SIZE
                                ),
                            );
                        }
                        if buffer.len() + chunk.len() > MAX_FILE_SIZE {
                            return multipart_error(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "FILE_TOO_LARGE",
                                format!(
                                    "File '{}' exceeds maximum size of {} bytes",
                                    filename, MAX_FILE_SIZE
                                ),
                            );
                        }
                        total_read += chunk.len();
                        buffer.extend_from_slice(&chunk);
                    }
                    Ok(None) => break,
                    Err(e) => {
                        return multipart_error(
                            StatusCode::BAD_REQUEST,
                            "READ_ERROR",
                            format!("Failed to read file field: {}", e),
                        );
                    }
                }
            }

            let upload = Upload::new(filename, content_type, buffer.freeze());
            uploads.insert(name, upload);
        }
    }

    let mut args =
        json_args.unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));

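    // Merge the uploads into the arguments object under their field names.
    // If `_json` decoded to a non-object value, there is nothing to merge
    // into, so the uploads are dropped and `args` is forwarded as-is.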
    if let serde_json::Value::Object(ref mut map) = args {
        for (name, upload) in uploads {
            match serde_json::to_value(&upload) {
                Ok(value) => {
                    map.insert(name, value);
                }
                Err(e) => {
                    return multipart_error(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "SERIALIZE_ERROR",
                        format!("Failed to serialize upload: {}", e),
                    );
                }
            }
        }
    }

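    // Dispatch the assembled request through the regular RPC handler.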
    let request = super::request::RpcRequest::new(function, args);
    let metadata = forge_core::function::RequestMetadata::new();

    let response = handler.handle(request, auth, metadata).await;

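    // Serialize the handler's response; a failure here surfaces as a 500 in
    // the same error envelope.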
    match serde_json::to_value(&response) {
        Ok(value) => (StatusCode::OK, axum::Json(value)),
        Err(e) => multipart_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "SERIALIZE_ERROR",
            format!("Failed to serialize response: {}", e),
        ),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_field_name_constant() {
        assert_eq!(JSON_FIELD_NAME, "_json");
    }
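
    // Additional sanity checks: the error-envelope shape produced by
    // `multipart_error`, and the relationship between the size constants.
    #[test]
    fn test_multipart_error_envelope() {
        let (status, body) = multipart_error(StatusCode::BAD_REQUEST, "TEST_CODE", "boom");
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.0["success"], serde_json::json!(false));
        assert_eq!(body.0["error"]["code"], "TEST_CODE");
        assert_eq!(body.0["error"]["message"], "boom");
    }

    #[test]
    fn test_size_limits_are_consistent() {
        // A maximal single file or `_json` field must fit in the total budget.
        assert!(MAX_FILE_SIZE <= MAX_TOTAL_UPLOAD_SIZE);
        assert!(MAX_JSON_FIELD_SIZE <= MAX_TOTAL_UPLOAD_SIZE);
    }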
}