//! axum-image 0.2.0
//!
//! Image extractors for Axum.
use axum::{
    body::Bytes,
    extract::{FromRequest, Request},
    http::{StatusCode, header::CONTENT_TYPE},
};
use axum_extra::extract::Multipart;
use serde::de::DeserializeOwned;

/// An image extractor accepting:
/// * `multipart/form-data` (only the **first** field is read and used as the image)
/// * `image/png`
/// * `image/jpeg`
/// * `image/avif`
/// * `image/webp`
/// * `image/gif`
///
/// The first element holds the image bytes. The second element is the request's
/// `Content-Type` header value, copied verbatim. For multipart uploads this is the
/// `multipart/form-data; boundary=…` header, not the content type of the uploaded field.
///
/// The bytes themselves are not inspected, so the payload is not guaranteed to be a
/// valid image of the declared type.
///
/// Requests with a missing header or an unsupported content type are rejected with
/// `400 Bad Request`.
#[derive(Debug, Clone)]
pub struct Image(pub Bytes, pub String);

impl<S> FromRequest<S> for Image
where
    Bytes: FromRequest<S>,
    S: Send + Sync,
{
    type Rejection = StatusCode;

    /// Extracts the image bytes and the request's `Content-Type` header.
    ///
    /// For `multipart/form-data`, the first field's bytes are returned. For the
    /// accepted `image/*` types, the whole request body is returned.
    ///
    /// # Errors
    ///
    /// Rejects with `400 Bad Request` in these cases:
    /// * the `Content-Type` header is missing or is not visible ASCII;
    /// * the content type is not one of the accepted types;
    /// * the body (or the first multipart field) cannot be read;
    /// * a multipart upload contains no fields.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // The header value is client-controlled. A value that is not visible ASCII is a
        // bad request, not a reason to panic. Taking an owned copy up front also releases
        // the borrow on `req` before the request is consumed below.
        let content_type = req
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .ok_or(StatusCode::BAD_REQUEST)?
            .to_owned();

        let body = if content_type.starts_with("multipart/form-data") {
            let mut multipart = Multipart::from_request(req, state)
                .await
                .map_err(|_| StatusCode::BAD_REQUEST)?;

            // Only the first field is used; any later fields are ignored.
            let Ok(Some(field)) = multipart.next_field().await else {
                return Err(StatusCode::BAD_REQUEST);
            };

            field.bytes().await.map_err(|_| StatusCode::BAD_REQUEST)?
        } else if matches!(
            content_type.as_str(),
            "image/avif" | "image/jpeg" | "image/png" | "image/webp" | "image/gif"
        ) {
            Bytes::from_request(req, state)
                .await
                .map_err(|_| StatusCode::BAD_REQUEST)?
        } else {
            return Err(StatusCode::BAD_REQUEST);
        };

        Ok(Self(body, content_type))
    }
}

/// A file extractor accepting:
/// * `multipart/form-data`
///
/// Will also attempt to parse out the **last** field in the multipart upload
/// as the given struct from JSON. Every other field is put into a vector of bytes,
/// as they are seen as raw binary data.
pub struct JsonMultipart<T: DeserializeOwned>(pub Vec<Bytes>, pub T);

impl<S, T> FromRequest<S> for JsonMultipart<T>
where
    Bytes: FromRequest<S>,
    S: Send + Sync,
    T: DeserializeOwned,
{
    type Rejection = (StatusCode, String);

    /// Reads every field of a `multipart/form-data` body. The last field is
    /// deserialized as JSON into `T`; the preceding fields are returned as raw bytes,
    /// in upload order.
    ///
    /// # Errors
    ///
    /// Rejects with `400 Bad Request` and a message in these cases:
    /// * the `Content-Type` header is missing, is not visible ASCII, or is not multipart;
    /// * the multipart stream is malformed or a field cannot be read;
    /// * the upload contains no fields;
    /// * the last field is not UTF-8 or does not deserialize into `T`.
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let Some(content_type) = req.headers().get(CONTENT_TYPE) else {
            return Err((
                StatusCode::BAD_REQUEST,
                "no content type header".to_string(),
            ));
        };

        // The header is client input, so report a bad value instead of panicking on it.
        let is_multipart = content_type
            .to_str()
            .map_err(|_| {
                (
                    StatusCode::BAD_REQUEST,
                    "invalid content type header".to_string(),
                )
            })?
            .starts_with("multipart/form-data");

        if !is_multipart {
            return Err((
                StatusCode::BAD_REQUEST,
                "expected multipart/form-data".to_string(),
            ));
        }

        let mut multipart = Multipart::from_request(req, state).await.map_err(|_| {
            (
                StatusCode::BAD_REQUEST,
                "could not read multipart".to_string(),
            )
        })?;

        let mut fields: Vec<Bytes> = Vec::new();
        // Stream errors must not be mistaken for end-of-input. Doing so would silently
        // drop fields and could deserialize the wrong field as the JSON payload.
        while let Some(field) = multipart
            .next_field()
            .await
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
        {
            let data = field
                .bytes()
                .await
                .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
            fields.push(data);
        }

        // The JSON metadata is expected to be the final field of the upload.
        let Some(json_field) = fields.pop() else {
            return Err((
                StatusCode::BAD_REQUEST,
                "could not read json data".to_string(),
            ));
        };

        // Validate UTF-8 in place (no copy) while keeping the dedicated error message.
        let json_text = std::str::from_utf8(&json_field)
            .map_err(|_| (StatusCode::BAD_REQUEST, "json data isn't utf8".to_string()))?;

        let json: T = serde_json::from_str(json_text)
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;

        Ok(Self(fields, json))
    }
}