// axum_typed_multipart/try_from_field.rs

1use crate::try_from_chunks::TryFromChunks;
2use crate::{FieldMetadata, TypedMultipartError};
3use async_trait::async_trait;
4use axum::extract::multipart::Field;
5use futures_util::stream::StreamExt;
6use futures_util::TryStreamExt;
7use std::mem;
8
9/// Types that can be created from a multipart field.
10///
11/// Required for all fields in structs deriving [TryFromMultipart](crate::TryFromMultipart).
12///
13/// **Note:** Prefer implementing [TryFromChunks] instead, which automatically provides
14/// this implementation with proper size limit handling.
15#[async_trait]
16pub trait TryFromField: Sized {
17    /// Creates an instance from a multipart field with optional size limit.
18    async fn try_from_field(
19        field: Field<'_>,
20        limit_bytes: Option<usize>,
21    ) -> Result<Self, TypedMultipartError>;
22}
23
24/// Stateful variant of [TryFromField] that provides access to application state during parsing.
25///
26/// ## Example
27///
28/// ```rust,no_run
29#[doc = include_str!("../examples/state.rs")]
30/// ```
31#[async_trait]
32pub trait TryFromFieldWithState<S>: Sized {
33    /// Creates an instance from a field with access to application state.
34    async fn try_from_field_with_state(
35        field: Field<'_>,
36        limit_bytes: Option<usize>,
37        state: &S,
38    ) -> Result<Self, TypedMultipartError>;
39}
40
41#[async_trait]
42impl<T, S> TryFromFieldWithState<S> for T
43where
44    T: TryFromField,
45{
46    async fn try_from_field_with_state(
47        field: Field<'_>,
48        limit_bytes: Option<usize>,
49        _state: &S,
50    ) -> Result<Self, TypedMultipartError> {
51        T::try_from_field(field, limit_bytes).await
52    }
53}
54
55#[async_trait]
56impl<T> TryFromField for T
57where
58    T: TryFromChunks + Send + Sync,
59{
60    async fn try_from_field(
61        field: Field<'_>,
62        limit_bytes: Option<usize>,
63    ) -> Result<Self, TypedMultipartError> {
64        let metadata = FieldMetadata::from(&field);
65        let mut field_name = metadata.name.clone().unwrap_or(String::new());
66        let mut size_bytes = 0;
67
68        let chunks = field.map_err(TypedMultipartError::from).map(|chunk| {
69            if let Ok(chunk) = chunk.as_ref() {
70                size_bytes += chunk.len();
71
72                if let Some(limit_bytes) = limit_bytes {
73                    if size_bytes > limit_bytes {
74                        return Err(TypedMultipartError::FieldTooLarge {
75                            field_name: mem::take(&mut field_name),
76                            limit_bytes,
77                        });
78                    }
79                }
80            }
81
82            chunk
83        });
84
85        T::try_from_chunks(chunks, metadata).await
86    }
87}
88
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use super::*;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Router;
    use axum_test_helper::TestClient;
    use futures_core::Stream;
    use reqwest::multipart::Form;
    use std::borrow::Cow;

    /// Newtype over `String` that gets [TryFromField] through the [TryFromChunks] blanket impl.
    #[derive(Debug)]
    struct Data(String);

    #[async_trait]
    impl TryFromChunks for Data {
        async fn try_from_chunks(
            chunks: impl Stream<Item = Result<bytes::Bytes, TypedMultipartError>> + Send + Sync + Unpin,
            metadata: FieldMetadata,
        ) -> Result<Self, TypedMultipartError> {
            String::try_from_chunks(chunks, metadata).await.map(Data)
        }
    }

    /// Sends `input` as the multipart text field `data`, parses the first field as [Data]
    /// with a 512-byte limit inside the handler, and passes the result to `check`.
    async fn test_try_from_field<T, F>(input: T, check: F)
    where
        T: Into<Cow<'static, str>>,
        F: FnOnce(Result<Data, TypedMultipartError>) + Clone + Send + Sync + 'static,
    {
        let app = Router::new().route(
            "/",
            post(|mut multipart: Multipart| async move {
                let first = multipart.next_field().await.unwrap().unwrap();
                check(Data::try_from_field(first, Some(512)).await);
            }),
        );
        let form = Form::new().text("data", input);

        TestClient::new(app).post("/").multipart(form).send().await.unwrap();
    }

    #[tokio::test]
    async fn test_try_from_field_valid() {
        test_try_from_field("Hello, world!", |res: Result<Data, TypedMultipartError>| {
            assert_eq!(res.unwrap().0, "Hello, world!")
        })
        .await;
    }

    #[tokio::test]
    async fn test_try_from_too_large() {
        let oversized = "x".repeat(513);
        test_try_from_field(oversized, |res: Result<Data, TypedMultipartError>| {
            assert!(matches!(res, Err(TypedMultipartError::FieldTooLarge { .. })))
        })
        .await;
    }
}

#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests_with_state {
    use super::*;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Router;
    use axum_test_helper::TestClient;
    use reqwest::multipart::Form;

    /// Test application state holding a prefix string.
    #[derive(Clone)]
    struct State(String);

    /// Parsed value: the state's prefix joined to the field text with ", ".
    struct DataWithState(String);

    #[async_trait]
    impl TryFromFieldWithState<State> for DataWithState {
        async fn try_from_field_with_state(
            field: Field<'_>,
            limit_bytes: Option<usize>,
            state: &State,
        ) -> Result<Self, TypedMultipartError> {
            let text = String::try_from_field(field, limit_bytes).await?;
            let combined = [state.0.as_str(), text.as_str()].join(", ");
            Ok(DataWithState(combined))
        }
    }

    #[tokio::test]
    async fn test_try_from_field_with_state() {
        let handler = |mut multipart: Multipart| async move {
            let state = State(String::from("Hello"));
            let first = multipart.next_field().await.unwrap().unwrap();
            let parsed = DataWithState::try_from_field_with_state(first, Some(512), &state)
                .await
                .unwrap();
            assert_eq!(parsed.0, "Hello, world!");
        };
        let app = Router::new().route("/", post(handler));

        TestClient::new(app)
            .post("/")
            .multipart(Form::new().text("data", "world!"))
            .send()
            .await
            .unwrap();
    }
}