// forge_runtime/gateway/multipart.rs
use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::{Extension, Multipart, Path, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use bytes::BytesMut;
8
9use forge_core::function::AuthContext;
10use forge_core::types::Upload;
11
12use super::rpc::RpcHandler;
13
/// Maximum accepted size for a single uploaded file, in bytes (10 MiB).
const MAX_FILE_SIZE: usize = 10 * 1024 * 1024;
/// Maximum number of file (non-`_json`) fields accepted per request.
const MAX_UPLOAD_FIELDS: usize = 20;
/// Maximum length of a multipart field name, in bytes.
const MAX_FIELD_NAME_LENGTH: usize = 255;
/// Maximum accepted size of the `_json` arguments field, in bytes (1 MiB).
const MAX_JSON_FIELD_SIZE: usize = 1024 * 1024;
/// Reserved field name that carries the JSON-encoded call arguments.
const JSON_FIELD_NAME: &str = "_json";
19
20fn multipart_error(
22 status: StatusCode,
23 code: &str,
24 message: impl Into<String>,
25) -> (StatusCode, axum::Json<serde_json::Value>) {
26 (
27 status,
28 axum::Json(serde_json::json!({
29 "success": false,
30 "error": {
31 "code": code,
32 "message": message.into()
33 }
34 })),
35 )
36}
37
38pub async fn rpc_multipart_handler(
40 State(handler): State<Arc<RpcHandler>>,
41 Extension(auth): Extension<AuthContext>,
42 Path(function): Path<String>,
43 mut multipart: Multipart,
44) -> impl IntoResponse {
45 let mut json_args: Option<serde_json::Value> = None;
46 let mut uploads: HashMap<String, Upload> = HashMap::new();
47
48 loop {
50 let field = match multipart.next_field().await {
51 Ok(Some(f)) => f,
52 Ok(None) => break,
53 Err(e) => {
54 return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
55 }
56 };
57
58 let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
59 Some(n) => n,
60 None => {
61 return multipart_error(
62 StatusCode::BAD_REQUEST,
63 "INVALID_FIELD",
64 "Field name is required",
65 );
66 }
67 };
68
69 if name.len() > MAX_FIELD_NAME_LENGTH {
71 return multipart_error(
72 StatusCode::BAD_REQUEST,
73 "INVALID_FIELD",
74 format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
75 );
76 }
77
78 if name.contains("..") || name.contains('/') || name.contains('\\') {
79 return multipart_error(
80 StatusCode::BAD_REQUEST,
81 "INVALID_FIELD",
82 "Field name contains invalid characters",
83 );
84 }
85
86 if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
88 return multipart_error(
89 StatusCode::BAD_REQUEST,
90 "TOO_MANY_FIELDS",
91 format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
92 );
93 }
94
95 if name == JSON_FIELD_NAME {
96 let mut buffer = BytesMut::new();
97 let mut json_field = field;
98
99 loop {
100 match json_field.chunk().await {
101 Ok(Some(chunk)) => {
102 if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
103 return multipart_error(
104 StatusCode::PAYLOAD_TOO_LARGE,
105 "JSON_TOO_LARGE",
106 format!(
107 "_json field exceeds maximum size of {} bytes",
108 MAX_JSON_FIELD_SIZE
109 ),
110 );
111 }
112 buffer.extend_from_slice(&chunk);
113 }
114 Ok(None) => break,
115 Err(e) => {
116 return multipart_error(
117 StatusCode::BAD_REQUEST,
118 "READ_ERROR",
119 format!("Failed to read _json field: {}", e),
120 );
121 }
122 }
123 }
124
125 let text = match std::str::from_utf8(&buffer) {
126 Ok(s) => s,
127 Err(_) => {
128 return multipart_error(
129 StatusCode::BAD_REQUEST,
130 "INVALID_JSON",
131 "Invalid UTF-8 in _json field",
132 );
133 }
134 };
135
136 match serde_json::from_str(text) {
137 Ok(value) => json_args = Some(value),
138 Err(e) => {
139 return multipart_error(
140 StatusCode::BAD_REQUEST,
141 "INVALID_JSON",
142 format!("Invalid JSON in _json field: {}", e),
143 );
144 }
145 }
146 } else {
147 let filename = field
148 .file_name()
149 .map(String::from)
150 .unwrap_or_else(|| name.clone());
151 let content_type = field
152 .content_type()
153 .map(String::from)
154 .unwrap_or_else(|| "application/octet-stream".to_string());
155
156 let mut buffer = BytesMut::new();
157 let mut field = field;
158
159 loop {
160 match field.chunk().await {
161 Ok(Some(chunk)) => {
162 if buffer.len() + chunk.len() > MAX_FILE_SIZE {
163 return multipart_error(
164 StatusCode::PAYLOAD_TOO_LARGE,
165 "FILE_TOO_LARGE",
166 format!(
167 "File '{}' exceeds maximum size of {} bytes",
168 filename, MAX_FILE_SIZE
169 ),
170 );
171 }
172 buffer.extend_from_slice(&chunk);
173 }
174 Ok(None) => break,
175 Err(e) => {
176 return multipart_error(
177 StatusCode::BAD_REQUEST,
178 "READ_ERROR",
179 format!("Failed to read file field: {}", e),
180 );
181 }
182 }
183 }
184
185 let upload = Upload::new(filename, content_type, buffer.freeze());
186 uploads.insert(name, upload);
187 }
188 }
189
190 let mut args = json_args.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
191
192 if let serde_json::Value::Object(ref mut map) = args {
193 for (name, upload) in uploads {
194 match serde_json::to_value(&upload) {
195 Ok(value) => {
196 map.insert(name, value);
197 }
198 Err(e) => {
199 return multipart_error(
200 StatusCode::INTERNAL_SERVER_ERROR,
201 "SERIALIZE_ERROR",
202 format!("Failed to serialize upload: {}", e),
203 );
204 }
205 }
206 }
207 }
208
209 let request = super::request::RpcRequest::new(function, args);
210 let metadata = forge_core::function::RequestMetadata::new();
211
212 let response = handler.handle(request, auth, metadata).await;
213
214 match serde_json::to_value(&response) {
215 Ok(value) => (StatusCode::OK, axum::Json(value)),
216 Err(e) => multipart_error(
217 StatusCode::INTERNAL_SERVER_ERROR,
218 "SERIALIZE_ERROR",
219 format!("Failed to serialize response: {}", e),
220 ),
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// The reserved multipart field name for JSON arguments is part of the
    /// wire protocol and must remain exactly "_json".
    #[test]
    fn test_json_field_name_constant() {
        assert!(matches!(JSON_FIELD_NAME, "_json"));
    }
}