Skip to main content

forge_runtime/gateway/
multipart.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::{Extension, Multipart, Path, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use bytes::BytesMut;
8
9use forge_core::function::AuthContext;
10use forge_core::types::Upload;
11
12use super::rpc::RpcHandler;
13
/// Largest accepted size, in bytes, of a single uploaded file field (10 MiB).
const MAX_FILE_SIZE: usize = 10 << 20;
/// Largest number of non-`_json` upload fields accepted in one request.
const MAX_UPLOAD_FIELDS: usize = 20;
/// Largest accepted field-name length, measured in bytes (`str::len`).
const MAX_FIELD_NAME_LENGTH: usize = 255;
/// Largest accepted size, in bytes, of the `_json` argument field (1 MiB).
const MAX_JSON_FIELD_SIZE: usize = 1 << 20;
/// Reserved field name whose body carries the JSON-encoded RPC arguments.
const JSON_FIELD_NAME: &str = "_json";
19
20/// Create a multipart error response.
21fn multipart_error(
22    status: StatusCode,
23    code: &str,
24    message: impl Into<String>,
25) -> (StatusCode, axum::Json<serde_json::Value>) {
26    (
27        status,
28        axum::Json(serde_json::json!({
29            "success": false,
30            "error": {
31                "code": code,
32                "message": message.into()
33            }
34        })),
35    )
36}
37
38/// Handle multipart form data for RPC calls with file uploads.
39pub async fn rpc_multipart_handler(
40    State(handler): State<Arc<RpcHandler>>,
41    Extension(auth): Extension<AuthContext>,
42    Path(function): Path<String>,
43    mut multipart: Multipart,
44) -> impl IntoResponse {
45    let mut json_args: Option<serde_json::Value> = None;
46    let mut uploads: HashMap<String, Upload> = HashMap::new();
47
48    // Parse multipart fields
49    loop {
50        let field = match multipart.next_field().await {
51            Ok(Some(f)) => f,
52            Ok(None) => break,
53            Err(e) => {
54                return multipart_error(StatusCode::BAD_REQUEST, "MULTIPART_ERROR", e.to_string());
55            }
56        };
57
58        let name = match field.name().map(String::from).filter(|n| !n.is_empty()) {
59            Some(n) => n,
60            None => {
61                return multipart_error(
62                    StatusCode::BAD_REQUEST,
63                    "INVALID_FIELD",
64                    "Field name is required",
65                );
66            }
67        };
68
69        // Validate field name length
70        if name.len() > MAX_FIELD_NAME_LENGTH {
71            return multipart_error(
72                StatusCode::BAD_REQUEST,
73                "INVALID_FIELD",
74                format!("Field name too long (max {} chars)", MAX_FIELD_NAME_LENGTH),
75            );
76        }
77
78        if name.contains("..") || name.contains('/') || name.contains('\\') {
79            return multipart_error(
80                StatusCode::BAD_REQUEST,
81                "INVALID_FIELD",
82                "Field name contains invalid characters",
83            );
84        }
85
86        // Check upload count before processing to prevent bypass via _json field ordering
87        if name != JSON_FIELD_NAME && uploads.len() >= MAX_UPLOAD_FIELDS {
88            return multipart_error(
89                StatusCode::BAD_REQUEST,
90                "TOO_MANY_FIELDS",
91                format!("Maximum {} upload fields allowed", MAX_UPLOAD_FIELDS),
92            );
93        }
94
95        if name == JSON_FIELD_NAME {
96            let mut buffer = BytesMut::new();
97            let mut json_field = field;
98
99            loop {
100                match json_field.chunk().await {
101                    Ok(Some(chunk)) => {
102                        if buffer.len() + chunk.len() > MAX_JSON_FIELD_SIZE {
103                            return multipart_error(
104                                StatusCode::PAYLOAD_TOO_LARGE,
105                                "JSON_TOO_LARGE",
106                                format!(
107                                    "_json field exceeds maximum size of {} bytes",
108                                    MAX_JSON_FIELD_SIZE
109                                ),
110                            );
111                        }
112                        buffer.extend_from_slice(&chunk);
113                    }
114                    Ok(None) => break,
115                    Err(e) => {
116                        return multipart_error(
117                            StatusCode::BAD_REQUEST,
118                            "READ_ERROR",
119                            format!("Failed to read _json field: {}", e),
120                        );
121                    }
122                }
123            }
124
125            let text = match std::str::from_utf8(&buffer) {
126                Ok(s) => s,
127                Err(_) => {
128                    return multipart_error(
129                        StatusCode::BAD_REQUEST,
130                        "INVALID_JSON",
131                        "Invalid UTF-8 in _json field",
132                    );
133                }
134            };
135
136            match serde_json::from_str(text) {
137                Ok(value) => json_args = Some(value),
138                Err(e) => {
139                    return multipart_error(
140                        StatusCode::BAD_REQUEST,
141                        "INVALID_JSON",
142                        format!("Invalid JSON in _json field: {}", e),
143                    );
144                }
145            }
146        } else {
147            let filename = field
148                .file_name()
149                .map(String::from)
150                .unwrap_or_else(|| name.clone());
151            let content_type = field
152                .content_type()
153                .map(String::from)
154                .unwrap_or_else(|| "application/octet-stream".to_string());
155
156            let mut buffer = BytesMut::new();
157            let mut field = field;
158
159            loop {
160                match field.chunk().await {
161                    Ok(Some(chunk)) => {
162                        if buffer.len() + chunk.len() > MAX_FILE_SIZE {
163                            return multipart_error(
164                                StatusCode::PAYLOAD_TOO_LARGE,
165                                "FILE_TOO_LARGE",
166                                format!(
167                                    "File '{}' exceeds maximum size of {} bytes",
168                                    filename, MAX_FILE_SIZE
169                                ),
170                            );
171                        }
172                        buffer.extend_from_slice(&chunk);
173                    }
174                    Ok(None) => break,
175                    Err(e) => {
176                        return multipart_error(
177                            StatusCode::BAD_REQUEST,
178                            "READ_ERROR",
179                            format!("Failed to read file field: {}", e),
180                        );
181                    }
182                }
183            }
184
185            let upload = Upload::new(filename, content_type, buffer.freeze());
186            uploads.insert(name, upload);
187        }
188    }
189
190    let mut args = json_args.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
191
192    if let serde_json::Value::Object(ref mut map) = args {
193        for (name, upload) in uploads {
194            match serde_json::to_value(&upload) {
195                Ok(value) => {
196                    map.insert(name, value);
197                }
198                Err(e) => {
199                    return multipart_error(
200                        StatusCode::INTERNAL_SERVER_ERROR,
201                        "SERIALIZE_ERROR",
202                        format!("Failed to serialize upload: {}", e),
203                    );
204                }
205            }
206        }
207    }
208
209    let request = super::request::RpcRequest::new(function, args);
210    let metadata = forge_core::function::RequestMetadata::new();
211
212    let response = handler.handle(request, auth, metadata).await;
213
214    match serde_json::to_value(&response) {
215        Ok(value) => (StatusCode::OK, axum::Json(value)),
216        Err(e) => multipart_error(
217            StatusCode::INTERNAL_SERVER_ERROR,
218            "SERIALIZE_ERROR",
219            format!("Failed to serialize response: {}", e),
220        ),
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::JSON_FIELD_NAME;

    /// Clients must send their JSON arguments under exactly this field name.
    #[test]
    fn test_json_field_name_constant() {
        let expected = "_json";
        assert_eq!(JSON_FIELD_NAME, expected);
    }
}