1use core::{
2 fmt,
3 ops::{Deref, DerefMut},
4};
5
6use axum::{
7 extract::{FromRequest, FromRequestParts, Request},
8 http::header,
9 response::{IntoResponse, Response},
10};
11use bytes::{Bytes, BytesMut};
12
13use crate::{Accept, CodecDecode, CodecEncode, CodecRejection, ContentType, IntoCodecResponse};
14
/// Wrapper that decodes request data and encodes response data in a negotiated format.
///
/// As an extractor it chooses the decoder from the request's `Content-Type` header; for
/// responses, [`Codec::to_response`] serializes the value with a caller-supplied content type.
/// Equality, ordering, hashing, cloning and defaulting all go through the wrapped value.
#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Codec<T>(pub T);
51
52impl<T> Codec<T> {
53 pub fn into_inner(self) -> T {
55 self.0
56 }
57}
58
59impl<T> Codec<T>
60where
61 T: CodecEncode,
62{
63 pub fn to_response<C: Into<ContentType>>(&self, content_type: C) -> Response {
68 let content_type = content_type.into();
69 let bytes = match self.to_bytes(content_type) {
70 Ok(bytes) => bytes,
71 Err(rejection) => return rejection.into_response(),
72 };
73
74 ([(header::CONTENT_TYPE, content_type.into_header())], bytes).into_response()
75 }
76}
77
78impl<T> Deref for Codec<T> {
79 type Target = T;
80
81 fn deref(&self) -> &Self::Target {
82 &self.0
83 }
84}
85
86impl<T> DerefMut for Codec<T> {
87 fn deref_mut(&mut self) -> &mut Self::Target {
88 &mut self.0
89 }
90}
91
92impl<T: fmt::Display> fmt::Display for Codec<T> {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 self.0.fmt(f)
95 }
96}
97
98impl<T, S> FromRequest<S> for Codec<T>
99where
100 T: for<'de> CodecDecode<'de>,
101 S: Send + Sync + 'static,
102{
103 type Rejection = Response;
104
105 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
106 let (mut parts, body) = req.into_parts();
107 let accept = Accept::from_request_parts(&mut parts, state).await.unwrap();
108
109 let req = Request::from_parts(parts, body);
110
111 let content_type = req
112 .headers()
113 .get(header::CONTENT_TYPE)
114 .and_then(ContentType::from_header)
115 .unwrap_or_default();
116
117 let data = match () {
118 #[cfg(feature = "form")]
119 () if content_type == ContentType::Form && req.method() == axum::http::Method::GET => {
120 let query = req.uri().query().unwrap_or("");
121
122 Codec::from_form(query.as_bytes()).map_err(CodecRejection::from)
123 }
124 () => {
125 let bytes = Bytes::from_request(req, state)
126 .await
127 .map_err(|e| CodecRejection::from(e).into_codec_response(accept.into()))?;
128
129 Codec::from_bytes(&bytes, content_type)
130 }
131 }
132 .map_err(|e| e.into_codec_response(accept.into()))?;
133
134 Ok(data)
135 }
136}
137
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationInput for Codec<T>
where
    T: schemars::JsonSchema,
{
    /// Documents the request input exactly as `axum::Json<T>` would.
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        <axum::Json<T> as aide::operation::OperationInput>::operation_input(ctx, operation);
    }

    /// Reuses the inferred early responses reported for `axum::Json<T>`.
    fn inferred_early_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationInput>::inferred_early_responses(ctx, operation)
    }
}
157
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationOutput for Codec<T>
where
    T: schemars::JsonSchema,
{
    type Inner = T;

    /// Documents the response exactly as `axum::Json<T>` would.
    fn operation_response(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Option<aide::openapi::Response> {
        <axum::Json<T> as aide::operation::OperationOutput>::operation_response(ctx, operation)
    }

    /// Reuses the inferred responses reported for `axum::Json<T>`.
    fn inferred_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationOutput>::inferred_responses(ctx, operation)
    }
}
179
#[cfg(feature = "validator")]
impl<T> validator::Validate for Codec<T>
where
    T: validator::Validate,
{
    /// Validates the wrapped value; the wrapper itself adds no rules of its own.
    fn validate(&self) -> Result<(), validator::ValidationErrors> {
        let Self(inner) = self;
        validator::Validate::validate(inner)
    }
}
189
/// Decoded request data that may borrow directly from the raw bytes it was decoded from.
///
/// The raw buffer is stored next to the decoded value so that those borrows keep pointing at
/// memory owned by this struct.
pub struct BorrowCodec<T> {
    // Declared before `bytes` on purpose: struct fields are dropped in declaration order, so
    // `data` (which may reference the buffer) is dropped before the buffer it points into.
    // Do not reorder these fields.
    data: T,
    // Never read directly; it exists only to own the buffer that `data` may borrow from.
    #[allow(dead_code)]
    #[doc(hidden)]
    bytes: BytesMut,
}
230
231impl<T> BorrowCodec<T> {
232 pub unsafe fn as_mut_unchecked(&mut self) -> &mut T {
240 &mut self.data
241 }
242}
243
244impl<T> AsRef<T> for BorrowCodec<T> {
245 fn as_ref(&self) -> &T {
246 self
247 }
248}
249
250impl<T> Deref for BorrowCodec<T> {
251 type Target = T;
252
253 fn deref(&self) -> &Self::Target {
254 &self.data
255 }
256}
257
258impl<T> fmt::Debug for BorrowCodec<T>
259where
260 T: fmt::Debug,
261{
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 f.debug_struct("BorrowCodec")
264 .field("data", &self.data)
265 .finish_non_exhaustive()
266 }
267}
268
269impl<T> PartialEq for BorrowCodec<T>
270where
271 T: PartialEq,
272{
273 fn eq(&self, other: &Self) -> bool {
274 self.data == other.data
275 }
276}
277
278impl<T> Eq for BorrowCodec<T> where T: Eq {}
279
280impl<T> PartialOrd for BorrowCodec<T>
281where
282 T: PartialOrd,
283{
284 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
285 self.data.partial_cmp(&other.data)
286 }
287}
288
289impl<T> Ord for BorrowCodec<T>
290where
291 T: Ord,
292{
293 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
294 self.data.cmp(&other.data)
295 }
296}
297
298impl<T> std::hash::Hash for BorrowCodec<T>
299where
300 T: std::hash::Hash,
301{
302 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
303 self.data.hash(state);
304 }
305}
306
impl<'de, T> BorrowCodec<T>
where
    T: CodecDecode<'de>,
{
    /// Decodes `bytes` as `content_type`, keeping the buffer alive next to the result.
    ///
    /// The decoder is handed a slice into `bytes`, so the returned `T` may borrow from it. The
    /// buffer is then moved into `Self` so those borrows keep pointing at memory owned by the
    /// returned value.
    ///
    /// # Errors
    ///
    /// Returns an error when `Codec::<T>::from_bytes` fails to decode the buffer.
    ///
    /// NOTE(review): `'de` is picked by the caller (the `FromRequest` impl uses `'static`), even
    /// though the borrowed data only lives as long as this `BorrowCodec`. Because `Deref`/`AsRef`
    /// hand out `&T`, borrowed fields can be copied out with the longer lifetime and outlive the
    /// buffer. Confirm whether this function should be `unsafe` or otherwise restricted.
    pub fn from_bytes(bytes: BytesMut, content_type: ContentType) -> Result<Self, CodecRejection> {
        let data = Codec::<T>::from_bytes(
            // SAFETY: the pointer and length both come from `bytes`, so the slice covers exactly
            // its initialized contents. The slice's lifetime is unbounded. It is kept valid by
            // moving `bytes` into `Self` below, where the `data` field is declared (and therefore
            // dropped) before the `bytes` field.
            // NOTE(review): this relies on moving a `BytesMut` leaving its contents at the same
            // address — confirm against the `bytes` crate's guarantees.
            unsafe { std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) },
            content_type,
        )?
        .into_inner();

        Ok(Self { data, bytes })
    }
}
330
331impl<T, S> FromRequest<S> for BorrowCodec<T>
332where
333 T: CodecDecode<'static>,
334 S: Send + Sync + 'static,
335{
336 type Rejection = Response;
337
338 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
339 let (mut parts, body) = req.into_parts();
340 let accept = Accept::from_request_parts(&mut parts, state).await.unwrap();
341
342 let req = Request::from_parts(parts, body);
343
344 let content_type = req
345 .headers()
346 .get(header::CONTENT_TYPE)
347 .and_then(ContentType::from_header)
348 .unwrap_or_default();
349
350 let bytes = match () {
351 #[cfg(feature = "form")]
352 () if content_type == ContentType::Form && req.method() == axum::http::Method::GET => {
353 req.uri().query().map_or_else(BytesMut::new, BytesMut::from)
354 }
355 () => BytesMut::from_request(req, state)
356 .await
357 .map_err(|e| CodecRejection::from(e).into_codec_response(accept.into()))?,
358 };
359
360 let data =
361 Self::from_bytes(bytes, content_type).map_err(|e| e.into_codec_response(accept.into()))?;
362
363 #[cfg(feature = "validator")]
364 data
365 .as_ref()
366 .validate()
367 .map_err(|e| CodecRejection::from(e).into_codec_response(accept.into()))?;
368
369 Ok(data)
370 }
371}
372
#[cfg(feature = "aide")]
impl<T> aide::operation::OperationInput for BorrowCodec<T>
where
    T: schemars::JsonSchema,
{
    /// Documents the request input exactly as `axum::Json<T>` would.
    fn operation_input(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) {
        <axum::Json<T> as aide::operation::OperationInput>::operation_input(ctx, operation);
    }

    /// Reuses the inferred early responses reported for `axum::Json<T>`.
    fn inferred_early_responses(
        ctx: &mut aide::generate::GenContext,
        operation: &mut aide::openapi::Operation,
    ) -> Vec<(Option<u16>, aide::openapi::Response)> {
        <axum::Json<T> as aide::operation::OperationInput>::inferred_early_responses(ctx, operation)
    }
}
392
#[cfg(test)]
mod test {
    use super::{Codec, ContentType};

    #[crate::apply(decode)]
    #[derive(Debug, PartialEq, Eq)]
    struct Data {
        hello: String,
    }

    /// The value that both fixtures below encode.
    fn world() -> Data {
        Data {
            hello: String::from("world"),
        }
    }

    #[test]
    fn test_json_codec() {
        let payload = b"{\"hello\": \"world\"}";

        let Codec(decoded) = Codec::<Data>::from_bytes(payload, ContentType::Json).unwrap();

        assert_eq!(decoded, world());
    }

    #[test]
    fn test_msgpack_codec() {
        // One-entry MessagePack map: fixmap(1), fixstr(5) "hello", fixstr(5) "world".
        let payload = b"\x81\xa5hello\xa5world";

        let Codec(decoded) = Codec::<Data>::from_bytes(payload, ContentType::MsgPack).unwrap();

        assert_eq!(decoded, world());
    }
}
425
#[cfg(any(test, miri))]
mod miri {
    use std::borrow::Cow;

    use bytes::Bytes;

    use super::*;

    #[crate::apply(decode, crate = "crate")]
    #[derive(Debug, PartialEq, Eq)]
    struct BorrowData<'a> {
        #[serde(borrow)]
        hello: Cow<'a, str>,
    }

    /// Decoding through `BorrowCodec` must borrow string data from the owned buffer rather than
    /// allocating a copy.
    #[test]
    fn test_zero_copy() {
        let bytes = b"{\"hello\": \"world\"}".to_vec();
        let data =
            BorrowCodec::<BorrowData>::from_bytes(BytesMut::from(Bytes::from(bytes)), ContentType::Json)
                .unwrap();

        assert_eq!(data.hello, "world");

        // `Cow`'s `PartialEq` compares contents only, so an owned copy would also satisfy the
        // assertion above. Check the variant explicitly, and check that the borrowed slice
        // points into the codec's own buffer.
        match &data.hello {
            Cow::Borrowed(hello) => {
                let buffer = data.bytes.as_ptr_range();
                assert!(buffer.contains(&hello.as_ptr()));
            }
            Cow::Owned(_) => panic!("expected `hello` to borrow from the request buffer"),
        }
    }
}