axum_typed_multipart/
try_from_field.rs1use crate::try_from_chunks::TryFromChunks;
2use crate::{FieldMetadata, TypedMultipartError};
3use async_trait::async_trait;
4use axum::extract::multipart::Field;
5use futures_util::stream::StreamExt;
6use futures_util::TryStreamExt;
7use std::mem;
8
9#[async_trait]
16pub trait TryFromField: Sized {
17 async fn try_from_field(
19 field: Field<'_>,
20 limit_bytes: Option<usize>,
21 ) -> Result<Self, TypedMultipartError>;
22}
23
24#[doc = include_str!("../examples/state.rs")]
30#[async_trait]
32pub trait TryFromFieldWithState<S>: Sized {
33 async fn try_from_field_with_state(
35 field: Field<'_>,
36 limit_bytes: Option<usize>,
37 state: &S,
38 ) -> Result<Self, TypedMultipartError>;
39}
40
41#[async_trait]
42impl<T, S> TryFromFieldWithState<S> for T
43where
44 T: TryFromField,
45{
46 async fn try_from_field_with_state(
47 field: Field<'_>,
48 limit_bytes: Option<usize>,
49 _state: &S,
50 ) -> Result<Self, TypedMultipartError> {
51 T::try_from_field(field, limit_bytes).await
52 }
53}
54
55#[async_trait]
56impl<T> TryFromField for T
57where
58 T: TryFromChunks + Send + Sync,
59{
60 async fn try_from_field(
61 field: Field<'_>,
62 limit_bytes: Option<usize>,
63 ) -> Result<Self, TypedMultipartError> {
64 let metadata = FieldMetadata::from(&field);
65 let mut field_name = metadata.name.clone().unwrap_or(String::new());
66 let mut size_bytes = 0;
67
68 let chunks = field.map_err(TypedMultipartError::from).map(|chunk| {
69 if let Ok(chunk) = chunk.as_ref() {
70 size_bytes += chunk.len();
71
72 if let Some(limit_bytes) = limit_bytes {
73 if size_bytes > limit_bytes {
74 return Err(TypedMultipartError::FieldTooLarge {
75 field_name: mem::take(&mut field_name),
76 limit_bytes,
77 });
78 }
79 }
80 }
81
82 chunk
83 });
84
85 T::try_from_chunks(chunks, metadata).await
86 }
87}
88
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use super::*;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Router;
    use axum_test_helper::TestClient;
    use futures_core::Stream;
    use reqwest::multipart::Form;
    use std::borrow::Cow;

    /// Wrapper around `String` that implements `TryFromChunks` itself, so the
    /// tests go through the blanket `TryFromField` implementation.
    #[derive(Debug)]
    struct Data(String);

    #[async_trait]
    impl TryFromChunks for Data {
        async fn try_from_chunks(
            chunks: impl Stream<Item = Result<bytes::Bytes, TypedMultipartError>> + Send + Sync + Unpin,
            metadata: FieldMetadata,
        ) -> Result<Self, TypedMultipartError> {
            String::try_from_chunks(chunks, metadata).await.map(Self)
        }
    }

    /// Posts `input` as the text field `data` to a one-route test server whose
    /// handler parses the first field into `Data` with a 512 byte limit and
    /// passes the outcome to `validator`.
    async fn test_try_from_field<T, F>(input: T, validator: F)
    where
        T: Into<Cow<'static, str>>,
        F: FnOnce(Result<Data, TypedMultipartError>) + Clone + Send + Sync + 'static,
    {
        let handler = |mut multipart: Multipart| async move {
            let first_field = multipart.next_field().await.unwrap().unwrap();
            let outcome = Data::try_from_field(first_field, Some(512)).await;
            validator(outcome);
        };
        let app = Router::new().route("/", post(handler));

        TestClient::new(app)
            .post("/")
            .multipart(Form::new().text("data", input))
            .send()
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_try_from_field_valid() {
        test_try_from_field("Hello, world!", |res: Result<Data, TypedMultipartError>| {
            assert_eq!(res.unwrap().0, "Hello, world!");
        })
        .await;
    }

    #[tokio::test]
    async fn test_try_from_too_large() {
        // One byte over the 512 byte limit used by the helper.
        let oversized = "x".repeat(513);
        test_try_from_field(oversized, |res: Result<Data, TypedMultipartError>| {
            assert!(matches!(res, Err(TypedMultipartError::FieldTooLarge { .. })));
        })
        .await;
    }
}
150
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests_with_state {
    use super::*;
    use axum::extract::Multipart;
    use axum::routing::post;
    use axum::Router;
    use axum_test_helper::TestClient;
    use reqwest::multipart::Form;

    /// State value handed to the stateful parser under test.
    #[derive(Clone)]
    struct State(String);

    /// Holds the state string and the field text joined by `", "`.
    struct DataWithState(String);

    #[async_trait]
    impl TryFromFieldWithState<State> for DataWithState {
        async fn try_from_field_with_state(
            field: Field<'_>,
            limit_bytes: Option<usize>,
            state: &State,
        ) -> Result<Self, TypedMultipartError> {
            let text = String::try_from_field(field, limit_bytes).await?;
            let combined = format!("{}, {}", state.0, text);
            Ok(Self(combined))
        }
    }

    #[tokio::test]
    async fn test_try_from_field_with_state() {
        let handler = |mut multipart: Multipart| async move {
            let first_field = multipart.next_field().await.unwrap().unwrap();
            let state = State(String::from("Hello"));
            let parsed =
                DataWithState::try_from_field_with_state(first_field, Some(512), &state).await;
            assert_eq!(parsed.unwrap().0, "Hello, world!");
        };
        let app = Router::new().route("/", post(handler));

        TestClient::new(app)
            .post("/")
            .multipart(Form::new().text("data", "world!"))
            .send()
            .await
            .unwrap();
    }
}