1use actix_web::{dev::Payload, http::ConnectionType, FromRequest, HttpRequest, HttpResponse};
2use futures::{Future, StreamExt, TryStreamExt};
3use serde::Deserialize;
4use serde_aux::prelude::serde_introspect;
5use serde_json::{Map, Number, Value};
6use std::{
7 ops::{Deref, DerefMut},
8 pin::Pin,
9};
10use thiserror::Error;
11
12use crate::{form::MultipartForm, MultipartConfig};
13
14#[derive(Error, Debug)]
16pub enum MultipartError {
17 #[error("Error while parsing field: {0}")]
18 ParseError(serde_json::Error),
19 #[error("File for field ({field}) was too large (max size: {limit} bytes)")]
20 FileSizeError { field: String, limit: usize },
21}
22
23#[derive(Debug, Deserialize)]
25pub struct File {
26 pub content_type: String,
27 pub name: String,
28 pub bytes: Vec<u8>,
29}
30
/// Extractor that parses a multipart request body into `T`.
///
/// Dereferences to the parsed `T`; use [`Multipart::into_inner`] (or the public
/// field) to take ownership of it.
#[derive(Debug)]
pub struct Multipart<T>(pub T);

impl<T> Multipart<T> {
    /// Consumes the extractor and returns the parsed form.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> Deref for Multipart<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> DerefMut for Multipart<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}
48
49impl<T: serde::de::DeserializeOwned + MultipartForm> FromRequest for Multipart<T> {
50 type Error = actix_web::Error;
51 type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
52
53 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
54 let mut multipart = actix_multipart::Multipart::new(req.headers(), payload.take());
55 let req_owned = req.to_owned();
56
57 Box::pin(async move {
58 let config = req_owned.app_data::<MultipartConfig>();
59
60 match multipart_to_json::<T>(serde_introspect::<T>(), &mut multipart).await {
61 Ok(v) => match serde_json::from_value::<T>(v) {
62 Ok(parsed) => Ok(Multipart(parsed)),
63 Err(err) => Err(handle_error(MultipartError::ParseError(err), config)),
64 },
65 Err(err) => Err(handle_error(err, config)),
66 }
67 })
68 }
69}
70
71fn handle_error(error: MultipartError, config: Option<&MultipartConfig>) -> actix_web::Error {
72 let mut res = match config {
73 Some(config) => match &config.error_handler {
74 Some(error_handler) => error_handler(error),
75 None => HttpResponse::BadRequest().body(error.to_string()),
76 },
77 None => HttpResponse::BadRequest().body(error.to_string()),
78 };
79
80 res.head_mut().set_connection_type(ConnectionType::Close);
84
85 actix_web::error::InternalError::from_response("invalid multipart", res).into()
86}
87
88async fn multipart_to_json<T: MultipartForm>(
92 valid_fields: &[&str],
93 multipart: &mut actix_multipart::Multipart,
94) -> Result<Value, MultipartError> {
95 let mut map = Map::new();
96
97 while let Ok(Some(mut field)) = multipart.try_next().await {
98 let disposition = field.content_disposition().clone();
99
100 let field_name = match disposition.get_name() {
101 Some(v) => v,
102 None => continue,
103 };
104
105 let field_name_formatted = field_name.replace("[]", "");
106
107 if !valid_fields.contains(&field_name) {
109 continue;
110 }
111
112 if field.content_disposition().get_filename().is_some() {
113 let mut data: Vec<Value> = Vec::new();
115
116 let max_size = T::max_size(field_name);
117 let mut size = 0;
118
119 while let Some(chunk) = field.next().await {
120 match chunk {
121 Ok(bytes) => {
122 size += bytes.len();
123 if let Some(max_size) = max_size {
124 if size > max_size {
125 return Err(MultipartError::FileSizeError {
126 field: field_name.to_string(),
127 limit: max_size,
128 });
129 }
130 }
131
132 data.reserve_exact(bytes.len());
133 for byte in bytes {
134 data.push(Value::Number(Number::from(byte)));
135 }
136 }
137 Err(_) => {
138 map.insert(field_name_formatted.to_owned(), Value::Null);
139 continue;
140 }
141 }
142 }
143
144 let mut field_map = Map::new();
145 field_map.insert(
146 "content_type".to_owned(),
147 Value::String(field.content_type().to_string()),
148 );
149
150 field_map.insert(
151 "name".to_owned(),
152 Value::String(
153 field
154 .content_disposition()
155 .get_filename()
156 .unwrap()
157 .to_string(),
158 ),
159 );
160
161 field_map.insert("bytes".to_owned(), Value::Array(data));
162
163 params_insert(
164 &mut map,
165 field_name,
166 &field_name_formatted,
167 Value::Object(field_map),
168 );
169 } else if let Some(Ok(value)) = field.next().await {
170 if let Ok(str) = std::str::from_utf8(&value) {
172 match str.parse::<isize>() {
174 Ok(number) => params_insert(
175 &mut map,
176 field_name,
177 &field_name_formatted,
178 Value::Number(Number::from(number)),
179 ),
180 Err(_) => match str {
181 "true" => params_insert(
182 &mut map,
183 field_name,
184 &field_name_formatted,
185 Value::Bool(true),
186 ),
187 "false" => params_insert(
188 &mut map,
189 field_name,
190 &field_name_formatted,
191 Value::Bool(false),
192 ),
193 _ => params_insert(
194 &mut map,
195 field_name,
196 &field_name_formatted,
197 Value::String(str.to_owned()),
198 ),
199 },
200 }
201 }
202 } else {
203 params_insert(&mut map, field_name, &field_name_formatted, Value::Null)
205 }
206 }
207
208 Ok(Value::Object(map))
209}
210
211fn params_insert(
213 params: &mut Map<String, Value>,
214 field_name: &str,
215 field_name_formatted: &String,
216 element: Value,
217) {
218 if field_name.ends_with("[]") {
219 if params.contains_key(field_name_formatted) {
220 if let Value::Array(val) = params.get_mut(field_name_formatted).unwrap() {
221 val.push(element);
222 }
223 } else {
224 params.insert(field_name_formatted.to_owned(), Value::Array(vec![element]));
225 }
226 } else {
227 params.insert(field_name.to_owned(), element);
228 }
229}