use actix_web::{dev::Payload, http::ConnectionType, FromRequest, HttpRequest, HttpResponse};
use futures::{Future, StreamExt, TryStreamExt};
use serde::Deserialize;
use serde_aux::prelude::serde_introspect;
use serde_json::{Map, Number, Value};
use std::{
ops::{Deref, DerefMut},
pin::Pin,
};
use thiserror::Error;
use crate::{form::MultipartForm, MultipartConfig};
#[derive(Error, Debug)]
pub enum MultipartError {
#[error("Error while parsing field: {0}")]
ParseError(serde_json::Error),
#[error("File for field ({field}) was too large (max size: {limit} bytes)")]
FileSizeError { field: String, limit: usize },
}
/// An uploaded file part, buffered entirely in memory.
///
/// Deserialized from the JSON object built for a multipart part that carries a `filename` in
/// its `Content-Disposition` header.
#[derive(Debug, Clone, Deserialize)]
pub struct File {
    /// The part's `Content-Type`, rendered as a string.
    pub content_type: String,
    /// The file name supplied by the client in `Content-Disposition`.
    ///
    /// This is untrusted input: do not use it directly as a filesystem path.
    pub name: String,
    /// The raw contents of the file.
    pub bytes: Vec<u8>,
}
/// Extractor that parses a `multipart/form-data` request body into `T`.
///
/// Dereferences to the parsed `T`; use [`Multipart::into_inner`] to take ownership of it.
#[derive(Debug)]
pub struct Multipart<T>(T);

impl<T> Multipart<T> {
    /// Unwraps the extractor, returning the parsed form by value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> Deref for Multipart<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> DerefMut for Multipart<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}
impl<T: serde::de::DeserializeOwned + MultipartForm> FromRequest for Multipart<T> {
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self, Self::Error>>>>;
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let mut multipart = actix_multipart::Multipart::new(req.headers(), payload.take());
let req_owned = req.to_owned();
Box::pin(async move {
let config = req_owned.app_data::<MultipartConfig>();
match multipart_to_json::<T>(serde_introspect::<T>(), &mut multipart).await {
Ok(v) => match serde_json::from_value::<T>(v) {
Ok(parsed) => Ok(Multipart(parsed)),
Err(err) => Err(handle_error(MultipartError::ParseError(err), config)),
},
Err(err) => Err(handle_error(err, config)),
}
})
}
}
fn handle_error(error: MultipartError, config: Option<&MultipartConfig>) -> actix_web::Error {
let mut res = match config {
Some(config) => match &config.error_handler {
Some(error_handler) => error_handler(error),
None => HttpResponse::BadRequest().body(error.to_string()),
},
None => HttpResponse::BadRequest().body(error.to_string()),
};
res.head_mut().set_connection_type(ConnectionType::Close);
actix_web::error::InternalError::from_response("invalid multipart", res).into()
}
/// Reads every part of `multipart` into a JSON object shaped for deserializing `T`.
///
/// * Parts with no `name` in their `Content-Disposition`, or whose name (minus a trailing
///   `[]`) is not in `valid_fields`, are skipped.
/// * A part named `foo[]` is appended to a JSON array under the key `foo`. Any other part is
///   stored under its own name and replaces an earlier part with the same name.
/// * A part with a `filename` becomes an object with `content_type`, `name` and `bytes` keys
///   (the shape of [`File`]), size-checked against `T::max_size`.
/// * Any other part is read in full as UTF-8 text and converted by `text_to_value`. A text part
///   whose body yields no chunks becomes `null`. A text part that is not valid UTF-8 is skipped.
/// * A part whose body stream yields an error becomes `null`.
///
/// # Errors
///
/// Returns [`MultipartError::FileSizeError`] as soon as a file part exceeds the limit that
/// `T::max_size` reports for its field.
async fn multipart_to_json<T: MultipartForm>(
    valid_fields: &[&str],
    multipart: &mut actix_multipart::Multipart,
) -> Result<Value, MultipartError> {
    let mut map = Map::new();
    // NOTE(review): an error from the multipart stream itself ends the loop silently, and the
    // fields gathered so far are returned. Deserialization then reports any that are missing.
    while let Ok(Some(mut field)) = multipart.try_next().await {
        // Cloned so the name/filename borrows below do not keep `field` borrowed while its body
        // is being read.
        let disposition = field.content_disposition().clone();
        let field_name = match disposition.get_name() {
            Some(name) => name,
            None => continue,
        };
        // Only a trailing `[]` marks an array field. This must agree with `params_insert`,
        // which uses `ends_with("[]")`.
        let field_name_formatted = field_name
            .strip_suffix("[]")
            .unwrap_or(field_name)
            .to_owned();
        // `valid_fields` holds `T`'s serde field names, which carry no `[]` marker. Checking the
        // raw name would therefore reject every array field.
        if !valid_fields.contains(&field_name_formatted.as_str()) {
            continue;
        }
        let value = match disposition.get_filename() {
            Some(filename) => {
                // NOTE(review): assumes `max_size` is keyed by the form's field name (without
                // `[]`), like `valid_fields` — confirm against the `MultipartForm` impl.
                let max_size = T::max_size(&field_name_formatted);
                Some(read_file_field(&mut field, field_name, filename, max_size).await?)
            }
            None => read_text_field(&mut field).await,
        };
        if let Some(value) = value {
            params_insert(&mut map, field_name, &field_name_formatted, value);
        }
    }
    Ok(Value::Object(map))
}

/// Buffers a file part into the JSON shape of [`File`] (`content_type`, `name`, `bytes`).
///
/// Returns `Ok(Value::Null)` if the body stream yields an error. In that case the field is
/// reported as missing rather than as a truncated file.
///
/// # Errors
///
/// Returns [`MultipartError::FileSizeError`] (naming `field_name`) as soon as more than
/// `max_size` bytes have been received.
async fn read_file_field(
    field: &mut actix_multipart::Field,
    field_name: &str,
    filename: &str,
    max_size: Option<usize>,
) -> Result<Value, MultipartError> {
    // One JSON number per byte, so `data.len()` is also the running byte count.
    let mut data: Vec<Value> = Vec::new();
    while let Some(chunk) = field.next().await {
        let bytes = match chunk {
            Ok(bytes) => bytes,
            Err(_) => return Ok(Value::Null),
        };
        if let Some(limit) = max_size {
            if data.len() + bytes.len() > limit {
                return Err(MultipartError::FileSizeError {
                    field: field_name.to_string(),
                    limit,
                });
            }
        }
        // `extend` grows the buffer with amortized doubling. A per-chunk `reserve_exact` could
        // reallocate and copy everything received so far on every chunk.
        data.extend(bytes.iter().map(|&byte| Value::Number(Number::from(byte))));
    }
    let mut file = Map::new();
    file.insert(
        "content_type".to_owned(),
        Value::String(field.content_type().to_string()),
    );
    file.insert("name".to_owned(), Value::String(filename.to_string()));
    file.insert("bytes".to_owned(), Value::Array(data));
    Ok(Value::Object(file))
}

/// Reads the whole body of a text part and converts it with `text_to_value`.
///
/// Returns `Some(Value::Null)` when the body yields no chunks or its stream errors, and `None`
/// when the body is not valid UTF-8.
async fn read_text_field(field: &mut actix_multipart::Field) -> Option<Value> {
    // NOTE(review): text parts are buffered with no size limit — consider capping them.
    let mut buf = Vec::new();
    let mut saw_chunk = false;
    while let Some(chunk) = field.next().await {
        match chunk {
            Ok(bytes) => {
                saw_chunk = true;
                buf.extend_from_slice(&bytes);
            }
            Err(_) => return Some(Value::Null),
        }
    }
    if !saw_chunk {
        return Some(Value::Null);
    }
    // Decode only after all chunks are joined, because a multi-byte character may straddle a
    // chunk boundary.
    std::str::from_utf8(&buf).ok().map(text_to_value)
}

/// Converts a text part into a JSON scalar. The result is an integer if the text parses as
/// `isize`, a boolean if it is exactly `true` or `false`, and a string otherwise.
///
/// NOTE(review): digit-only text becomes a JSON number, which `serde_json` will not deserialize
/// into a `String` field.
fn text_to_value(text: &str) -> Value {
    if let Ok(number) = text.parse::<isize>() {
        return Value::Number(Number::from(number));
    }
    match text {
        "true" => Value::Bool(true),
        "false" => Value::Bool(false),
        _ => Value::String(text.to_owned()),
    }
}
/// Stores `element` in `params` under the key for this multipart field.
///
/// * If `field_name` ends in `[]`, `element` is appended to the JSON array stored under
///   `field_name_formatted`, and the array is created on first use. If that key already holds
///   a non-array value, `element` is dropped.
/// * Otherwise `element` is stored under `field_name` itself and replaces any previous value.
fn params_insert(
    params: &mut Map<String, Value>,
    field_name: &str,
    field_name_formatted: &str,
    element: Value,
) {
    if !field_name.ends_with("[]") {
        params.insert(field_name.to_owned(), element);
        return;
    }
    // The entry API does a single lookup instead of `contains_key` followed by
    // `get_mut().unwrap()`.
    if let Value::Array(items) = params
        .entry(field_name_formatted)
        .or_insert_with(|| Value::Array(Vec::new()))
    {
        items.push(element);
    }
}