actix_multipart_extract/
extractor.rs

1use actix_web::{dev::Payload, http::ConnectionType, FromRequest, HttpRequest, HttpResponse};
2use futures::{Future, StreamExt, TryStreamExt};
3use serde::Deserialize;
4use serde_aux::prelude::serde_introspect;
5use serde_json::{Map, Number, Value};
6use std::{
7    ops::{Deref, DerefMut},
8    pin::Pin,
9};
10use thiserror::Error;
11
12use crate::{form::MultipartForm, MultipartConfig};
13
14/// Error type for multipart forms.
15#[derive(Error, Debug)]
16pub enum MultipartError {
17    #[error("Error while parsing field: {0}")]
18    ParseError(serde_json::Error),
19    #[error("File for field ({field}) was too large (max size: {limit} bytes)")]
20    FileSizeError { field: String, limit: usize },
21}
22
23/// Representing a file in a multipart form.
24#[derive(Debug, Deserialize)]
25pub struct File {
26    pub content_type: String,
27    pub name: String,
28    pub bytes: Vec<u8>,
29}
30
/// Request extractor that wraps a form value of type `T` parsed from a
/// multipart request body.
#[derive(Debug)]
pub struct Multipart<T>(T);
34
35impl<T> Deref for Multipart<T> {
36    type Target = T;
37
38    fn deref(&self) -> &T {
39        &self.0
40    }
41}
42
43impl<T> DerefMut for Multipart<T> {
44    fn deref_mut(&mut self) -> &mut T {
45        &mut self.0
46    }
47}
48
49impl<T: serde::de::DeserializeOwned + MultipartForm> FromRequest for Multipart<T> {
50    type Error = actix_web::Error;
51    type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
52
53    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
54        let mut multipart = actix_multipart::Multipart::new(req.headers(), payload.take());
55        let req_owned = req.to_owned();
56
57        Box::pin(async move {
58            let config = req_owned.app_data::<MultipartConfig>();
59
60            match multipart_to_json::<T>(serde_introspect::<T>(), &mut multipart).await {
61                Ok(v) => match serde_json::from_value::<T>(v) {
62                    Ok(parsed) => Ok(Multipart(parsed)),
63                    Err(err) => Err(handle_error(MultipartError::ParseError(err), config)),
64                },
65                Err(err) => Err(handle_error(err, config)),
66            }
67        })
68    }
69}
70
71fn handle_error(error: MultipartError, config: Option<&MultipartConfig>) -> actix_web::Error {
72    let mut res = match config {
73        Some(config) => match &config.error_handler {
74            Some(error_handler) => error_handler(error),
75            None => HttpResponse::BadRequest().body(error.to_string()),
76        },
77        None => HttpResponse::BadRequest().body(error.to_string()),
78    };
79
80    // We must do this manually because of a bug in actix_http
81    // Ideally we would have all errors be a `actix_web::Error` by default
82    // SEE: https://github.com/actix/actix-web/pull/2779
83    res.head_mut().set_connection_type(ConnectionType::Close);
84
85    actix_web::error::InternalError::from_response("invalid multipart", res).into()
86}
87
88/// Convert a [`actix_multipart::Multipart`] form to a [`Value::Object`].
89///
90/// This checks for valid fields and file size limits on the [`MultipartForm`].
91async fn multipart_to_json<T: MultipartForm>(
92    valid_fields: &[&str],
93    multipart: &mut actix_multipart::Multipart,
94) -> Result<Value, MultipartError> {
95    let mut map = Map::new();
96
97    while let Ok(Some(mut field)) = multipart.try_next().await {
98        let disposition = field.content_disposition().clone();
99
100        let field_name = match disposition.get_name() {
101            Some(v) => v,
102            None => continue,
103        };
104
105        let field_name_formatted = field_name.replace("[]", "");
106
107        // Make sure the field actually exists on the form
108        if !valid_fields.contains(&field_name) {
109            continue;
110        }
111
112        if field.content_disposition().get_filename().is_some() {
113            // Is a file
114            let mut data: Vec<Value> = Vec::new();
115
116            let max_size = T::max_size(field_name);
117            let mut size = 0;
118
119            while let Some(chunk) = field.next().await {
120                match chunk {
121                    Ok(bytes) => {
122                        size += bytes.len();
123                        if let Some(max_size) = max_size {
124                            if size > max_size {
125                                return Err(MultipartError::FileSizeError {
126                                    field: field_name.to_string(),
127                                    limit: max_size,
128                                });
129                            }
130                        }
131
132                        data.reserve_exact(bytes.len());
133                        for byte in bytes {
134                            data.push(Value::Number(Number::from(byte)));
135                        }
136                    }
137                    Err(_) => {
138                        map.insert(field_name_formatted.to_owned(), Value::Null);
139                        continue;
140                    }
141                }
142            }
143
144            let mut field_map = Map::new();
145            field_map.insert(
146                "content_type".to_owned(),
147                Value::String(field.content_type().to_string()),
148            );
149
150            field_map.insert(
151                "name".to_owned(),
152                Value::String(
153                    field
154                        .content_disposition()
155                        .get_filename()
156                        .unwrap()
157                        .to_string(),
158                ),
159            );
160
161            field_map.insert("bytes".to_owned(), Value::Array(data));
162
163            params_insert(
164                &mut map,
165                field_name,
166                &field_name_formatted,
167                Value::Object(field_map),
168            );
169        } else if let Some(Ok(value)) = field.next().await {
170            // Not a file, parse as other JSON types
171            if let Ok(str) = std::str::from_utf8(&value) {
172                // Attempt to convert into a number
173                match str.parse::<isize>() {
174                    Ok(number) => params_insert(
175                        &mut map,
176                        field_name,
177                        &field_name_formatted,
178                        Value::Number(Number::from(number)),
179                    ),
180                    Err(_) => match str {
181                        "true" => params_insert(
182                            &mut map,
183                            field_name,
184                            &field_name_formatted,
185                            Value::Bool(true),
186                        ),
187                        "false" => params_insert(
188                            &mut map,
189                            field_name,
190                            &field_name_formatted,
191                            Value::Bool(false),
192                        ),
193                        _ => params_insert(
194                            &mut map,
195                            field_name,
196                            &field_name_formatted,
197                            Value::String(str.to_owned()),
198                        ),
199                    },
200                }
201            }
202        } else {
203            // Nothing
204            params_insert(&mut map, field_name, &field_name_formatted, Value::Null)
205        }
206    }
207
208    Ok(Value::Object(map))
209}
210
211/// Insert params to the map. This works with individual fields and arrays.
212fn params_insert(
213    params: &mut Map<String, Value>,
214    field_name: &str,
215    field_name_formatted: &String,
216    element: Value,
217) {
218    if field_name.ends_with("[]") {
219        if params.contains_key(field_name_formatted) {
220            if let Value::Array(val) = params.get_mut(field_name_formatted).unwrap() {
221                val.push(element);
222            }
223        } else {
224            params.insert(field_name_formatted.to_owned(), Value::Array(vec![element]));
225        }
226    } else {
227        params.insert(field_name.to_owned(), element);
228    }
229}