axum_content_negotiation/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{
4    future::Future,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use axum::{
11    body::Bytes,
12    extract::{FromRequest, Request},
13    http::{
14        header::{HeaderValue, ACCEPT, CONTENT_LENGTH, CONTENT_TYPE},
15        StatusCode,
16    },
17    response::{IntoResponse, Response},
18    Extension,
19};
20use tower::Service;
21
22#[cfg(all(feature = "json", feature = "simd-json"))]
23compile_error!("json and simd-json features are mutually exclusive");
24#[cfg(all(feature = "default-json", feature = "default-cbor"))]
25compile_error!("default-json and default-cbor features are mutually exclusive");
26
27#[cfg(feature = "default-json")]
28/// Default to application/json if the request does not have any accept header or uses */* when json is enabled
29static DEFAULT_CONTENT_TYPE_VALUE: &str = "application/json";
30
31#[cfg(feature = "default-cbor")]
32/// Default to application/cbor if the request does not have any accept header or uses */* when json is not enabled
33static DEFAULT_CONTENT_TYPE_VALUE: &str = "application/cbor";
34
35#[cfg(not(any(feature = "default-json", feature = "default-cbor")))]
36compile_error!("A default-* feature must be enabled for fallback encoding");
37
38static DEFAULT_CONTENT_TYPE: HeaderValue = HeaderValue::from_static(DEFAULT_CONTENT_TYPE_VALUE);
39
40static MALFORMED_RESPONSE: (StatusCode, &str) = (StatusCode::BAD_REQUEST, "Malformed request body");
41
/// Used either as an [extractor](axum::extract::FromRequest) or as a [response](axum::response::IntoResponse) to negotiate the serialization format.
///
/// When used as an extractor, it deserializes the request body into the target type according to the `Content-Type` header.
/// When used as a response, it serializes the target type into the response body according to the `Accept` header.
///
/// For the response case, the service must be wrapped in [`NegotiateLayer`], which is what actually performs the serialization.
/// If the [layer](tower::Layer) is not used, the response is a 415 Unsupported Media Type error.
///
/// ## Example
///
/// ```rust
/// use axum_content_negotiation::Negotiate;
///
/// #[derive(serde::Serialize, serde::Deserialize)]
/// struct Example {
///    message: String,
/// }
///
/// async fn handler(
///    Negotiate(input): Negotiate<Example>
/// ) -> impl axum::response::IntoResponse {
///   Negotiate(Example {
///     message: format!("Hello, {}!", input.message)
///   })
/// }
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Negotiate<T>(
    /// The stored content to be serialized/deserialized
    pub T,
);
73
74/// [Negotiate] implements [`FromRequest`] if the target type is deserializable.
75///
76/// It will attempt to deserialize the request body based on the `Content-Type` header.
77/// If the `Content-Type` header is not supported, it will return a 415 Unsupported Media Type response without running the handler.
78impl<T, S> FromRequest<S> for Negotiate<T>
79where
80    T: serde::de::DeserializeOwned,
81    S: Send + Sync,
82{
83    type Rejection = Response;
84
85    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
86        let content_type = req
87            .headers()
88            .get(CONTENT_TYPE)
89            .and_then(|h| h.to_str().ok())
90            .unwrap_or(DEFAULT_CONTENT_TYPE_VALUE);
91
92        // The header may include a charset or other info after `;`, if so, ignore it
93        let content_type = content_type
94            .split(';')
95            .next()
96            .map(str::trim)
97            .unwrap_or_default();
98
99        match content_type {
100            #[cfg(feature = "simd-json")]
101            "application/json" => {
102                let mut body = Bytes::from_request(req, state)
103                    .await
104                    .map_err(|e| {
105                        tracing::error!(error = %e, "failed to ready request body as bytes");
106                        e.into_response()
107                    })?
108                    .to_vec();
109
110                let body = simd_json::from_slice(&mut body).map_err(|e| {
111                    tracing::error!(error = %e, "failed to deserialize request body as json");
112                    MALFORMED_RESPONSE.into_response()
113                })?;
114
115                Ok(Self(body))
116            }
117            #[cfg(feature = "json")]
118            "application/json" => {
119                let body = Bytes::from_request(req, state).await.map_err(|e| {
120                    tracing::error!(error = %e, "failed to ready request body as bytes");
121                    e.into_response()
122                })?;
123
124                let body = serde_json::from_slice(&body).map_err(|e| {
125                    tracing::error!(error = %e, "failed to deserialize request body as json");
126                    MALFORMED_RESPONSE.into_response()
127                })?;
128
129                Ok(Self(body))
130            }
131
132            #[cfg(feature = "cbor")]
133            "application/cbor" => {
134                let body = Bytes::from_request(req, state).await.map_err(|e| {
135                    tracing::error!(error = %e, "failed to ready request body as bytes");
136                    e.into_response()
137                })?;
138
139                let body = cbor4ii::serde::from_slice(&body).map_err(|e| {
140                    tracing::error!(error = %e, "failed to deserialize request body as json");
141                    MALFORMED_RESPONSE.into_response()
142                })?;
143
144                Ok(Self(body))
145            }
146
147            _ => {
148                tracing::error!("unsupported content-type header: {:?}", content_type);
149                Err((
150                    StatusCode::NOT_ACCEPTABLE,
151                    "Invalid content type on request",
152                )
153                    .into_response())
154            }
155        }
156    }
157}
158
159/// Internal Negotiate object without the type parameter explicitly, in order to be able retrieve it as an extension on the [Layer](tower::Layer) response processing.
160///
161/// Considering [Extension]s are type safe, and we don't know ahead of time the type of the stored content, we must store it erased to dynamically dispatch for serialization latter.
162#[derive(Clone)]
163struct ErasedNegotiate(Arc<Box<dyn erased_serde::Serialize + Send + Sync>>);
164
165impl<T> From<T> for ErasedNegotiate
166where
167    T: serde::Serialize + Send + Sync + 'static,
168{
169    fn from(value: T) -> Self {
170        Self(Arc::new(Box::from(value)))
171    }
172}
173
174/// [Negotiate] implements [`IntoResponse`] if the internal content is serialiazable.
175///
176/// It will return convert it to a 415 Unsupported Media Type by default, which will be converted to the right response status on the [`NegotiateLayer`].
177impl<T> IntoResponse for Negotiate<T>
178where
179    T: serde::Serialize + Send + Sync + 'static,
180{
181    fn into_response(self) -> Response {
182        let data: ErasedNegotiate = self.0.into();
183        (
184            StatusCode::UNSUPPORTED_MEDIA_TYPE,
185            Extension(data),
186            "Misconfigured service layer",
187        )
188            .into_response()
189    }
190}
191
192/// Layer responsible to convert a [Negotiate] response into the right serialization format based on the `Accept` header.
193///
194/// If the `Accept` header is not supported, it will return a 406 Not Acceptable response without running the handler.
195#[derive(Clone)]
196pub struct NegotiateLayer;
197
198impl<S> tower::Layer<S> for NegotiateLayer {
199    type Service = NegotiateService<S>;
200
201    fn layer(&self, inner: S) -> Self::Service {
202        NegotiateService(inner)
203    }
204}
205
206trait SupportedEncodingExt {
207    fn supported_encoding(&self) -> Option<&'static str>;
208}
209
210impl SupportedEncodingExt for &[u8] {
211    fn supported_encoding(&self) -> Option<&'static str> {
212        match *self {
213            #[cfg(any(feature = "simd-json", feature = "json"))]
214            b"application/json" => Some("application/json"),
215            #[cfg(feature = "cbor")]
216            b"application/cbor" => Some("application/cbor"),
217            b"*/*" => Some(DEFAULT_CONTENT_TYPE_VALUE),
218            _ => None,
219        }
220    }
221}
222
223trait AcceptExt {
224    fn negotiate(&self) -> Option<&'static str>;
225}
226
227impl AcceptExt for axum::http::HeaderMap {
228    fn negotiate(&self) -> Option<&'static str> {
229        let accept = self.get(ACCEPT).unwrap_or(&DEFAULT_CONTENT_TYPE);
230        let precise_mime = accept.as_bytes().supported_encoding();
231
232        // Avoid iterations and splits if it's an exact match
233        if precise_mime.is_some() {
234            return precise_mime;
235        }
236
237        accept
238            .to_str()
239            .ok()?
240            .split(',')
241            .map(str::trim)
242            .filter_map(|s| {
243                let mut segments = s.split(';').map(str::trim);
244                let mime = segments.next().unwrap_or(s);
245
246                // See if it's a type we support
247                let mime_type = mime.as_bytes().supported_encoding()?;
248
249                // If we support it, parse or default the q value
250                let q = segments
251                    .find_map(|s| {
252                        let value = s.strip_prefix("q=")?;
253                        Some(value.parse::<f32>().unwrap_or(0.0))
254                    })
255                    .unwrap_or(1.0);
256                Some((mime_type, q))
257            })
258            .min_by(|(_, a), (_, b)| b.total_cmp(a))
259            .map(|(mime, _)| mime)
260    }
261}
262
263/// Serialize the stored [Extension] struct defined by a [Negotiate] into the right serialization format based on the `Accept` header.
264#[derive(Clone)]
265pub struct NegotiateService<S>(S);
266
267impl<T> Service<Request> for NegotiateService<T>
268where
269    T: Service<Request>,
270    T::Response: IntoResponse,
271    T::Future: Send + 'static,
272{
273    type Response = axum::response::Response;
274    type Error = T::Error;
275    type Future =
276        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
277
278    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279        self.0.poll_ready(cx)
280    }
281
282    fn call(&mut self, request: Request) -> Self::Future {
283        let accept = request.headers().negotiate();
284
285        let Some(encoding) = accept else {
286            return Box::pin(async move {
287                let response: Response = (
288                    StatusCode::NOT_ACCEPTABLE,
289                    "Invalid content type on request",
290                )
291                    .into_response();
292                Ok(response)
293            });
294        };
295
296        let future = self.0.call(request);
297
298        Box::pin(async move {
299            let inner_service = future.await?;
300            let response: Response = inner_service.into_response();
301            let data = response.extensions().get::<ErasedNegotiate>();
302
303            let Some(ErasedNegotiate(payload)) = data else {
304                return Ok(response);
305            };
306
307            let body = match encoding {
308                #[cfg(any(feature = "simd-json", feature = "json"))]
309                "application/json" => {
310                    let mut body = Vec::new();
311                    {
312                        let mut serializer = serde_json::Serializer::new(&mut body);
313                        let mut serializer = <dyn erased_serde::Serializer>::erase(&mut serializer);
314                        if let Err(e) = payload.erased_serialize(&mut serializer) {
315                            tracing::error!(error = %e, "failed to deserialize request body as json");
316
317                            let response: Response = (
318                                StatusCode::INTERNAL_SERVER_ERROR,
319                                "Failed to serialize response",
320                            )
321                                .into_response();
322                            return Ok(response);
323                        }
324                    }
325                    body
326                }
327                #[cfg(feature = "cbor")]
328                "application/cbor" => {
329                    let mut body = cbor4ii::core::utils::BufWriter::new(Vec::new());
330                    {
331                        let mut serializer = cbor4ii::serde::Serializer::new(&mut body);
332                        let mut serializer = <dyn erased_serde::Serializer>::erase(&mut serializer);
333                        if let Err(e) = payload.erased_serialize(&mut serializer) {
334                            tracing::error!(error = %e, "failed to deserialize request body as cbor");
335
336                            let response: Response = (
337                                StatusCode::INTERNAL_SERVER_ERROR,
338                                "Failed to serialize response",
339                            )
340                                .into_response();
341                            return Ok(response);
342                        }
343                    }
344                    body.into_inner()
345                }
346                _ => vec![],
347            };
348
349            let (mut parts, _) = response.into_parts();
350            if parts.status == StatusCode::UNSUPPORTED_MEDIA_TYPE {
351                parts.status = StatusCode::OK;
352            }
353            parts
354                .headers
355                .insert(CONTENT_TYPE, HeaderValue::from_static(encoding));
356            parts.headers.remove(CONTENT_LENGTH);
357
358            Ok(Response::from_parts(parts, body.into()))
359        })
360    }
361}
362
363#[cfg(test)]
364mod test {
365    use crate::Negotiate;
366
367    use axum::{
368        body::Body,
369        http::{
370            header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE},
371            Request, StatusCode,
372        },
373        response::IntoResponse,
374        routing::post,
375        Router,
376    };
377    use http_body_util::BodyExt;
378    use tower::ServiceExt;
379
380    use crate::NegotiateLayer;
381
382    #[derive(Debug, serde::Serialize, serde::Deserialize)]
383    struct Example {
384        message: String,
385    }
386
387    fn content_length(headers: &axum::http::HeaderMap) -> usize {
388        headers
389            .get(CONTENT_LENGTH)
390            .map(|v| v.to_str().unwrap().parse::<usize>().unwrap())
391            .unwrap()
392    }
393
394    mod general {
395        use super::*;
396
397        #[cfg(feature = "cbor")]
398        pub fn expected_cbor_body() -> Vec<u8> {
399            use cbor4ii::core::{enc::Encode, utils::BufWriter, Value};
400
401            let mut writer = BufWriter::new(Vec::new());
402            Value::Map(vec![(
403                Value::Text("message".to_string()),
404                Value::Text("Hello, test!".to_string()),
405            )])
406            .encode(&mut writer)
407            .unwrap();
408            writer.into_inner()
409        }
410
411        mod input {
412            use super::*;
413
414            #[tokio::test]
415            async fn test_does_not_process_handler_if_content_type_is_not_supported() {
416                #[axum::debug_handler]
417                async fn handler(_: Negotiate<Example>) -> impl IntoResponse {
418                    unimplemented!("This should not be called");
419                    #[allow(unreachable_code)]
420                    ()
421                }
422
423                let app = Router::new()
424                    .route("/", post(handler))
425                    .layer(NegotiateLayer);
426
427                let response = app
428                    .oneshot(
429                        Request::builder()
430                            .uri("/")
431                            .header(CONTENT_TYPE, "non-supported")
432                            .method("POST")
433                            .body(Body::from("really-cool-format"))
434                            .unwrap(),
435                    )
436                    .await
437                    .unwrap();
438
439                assert_eq!(response.status(), 406);
440                assert_eq!(
441                    response.into_body().collect().await.unwrap().to_bytes(),
442                    "Invalid content type on request"
443                );
444            }
445        }
446
447        mod output {
448            use super::*;
449
450            #[tokio::test]
451            async fn test_inform_error_when_misconfigured() {
452                #[axum::debug_handler]
453                async fn handler() -> impl IntoResponse {
454                    Negotiate(Example {
455                        message: "Hello, test!".to_string(),
456                    })
457                }
458
459                let app = Router::new().route("/", post(handler));
460
461                let response = app
462                    .oneshot(
463                        Request::builder()
464                            .uri("/")
465                            .method("POST")
466                            .body(Body::empty())
467                            .unwrap(),
468                    )
469                    .await
470                    .unwrap();
471
472                assert_eq!(response.status(), 415);
473                assert_eq!(
474                    response.into_body().collect().await.unwrap().to_bytes(),
475                    "Misconfigured service layer"
476                );
477            }
478
479            #[tokio::test]
480            async fn test_does_not_process_handler_if_accept_is_not_supported() {
481                #[axum::debug_handler]
482                async fn handler() -> impl IntoResponse {
483                    unimplemented!("This should not be called");
484                    #[allow(unreachable_code)]
485                    ()
486                }
487
488                let app = Router::new()
489                    .route("/", post(handler))
490                    .layer(NegotiateLayer);
491
492                let response = app
493                    .oneshot(
494                        Request::builder()
495                            .uri("/")
496                            .header(ACCEPT, "non-supported")
497                            .method("POST")
498                            .body(Body::empty())
499                            .unwrap(),
500                    )
501                    .await
502                    .unwrap();
503
504                assert_eq!(response.status(), 406);
505                assert_eq!(
506                    response.into_body().collect().await.unwrap().to_bytes(),
507                    "Invalid content type on request"
508                );
509            }
510        }
511    }
512
513    #[cfg(any(feature = "simd-json", feature = "json"))]
514    mod json {
515        use serde_json::json;
516
517        use super::*;
518
519        mod input {
520            use super::*;
521
522            #[cfg(feature = "default-json")]
523            #[tokio::test]
524            async fn test_can_read_input_without_content_type_by_default() {
525                #[axum::debug_handler]
526                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
527                    format!("Hello, {}!", input.message)
528                }
529
530                let app = Router::new().route("/", post(handler));
531
532                let response = app
533                    .oneshot(
534                        Request::builder()
535                            .uri("/")
536                            .method("POST")
537                            .body(json!({ "message": "test" }).to_string())
538                            .unwrap(),
539                    )
540                    .await
541                    .unwrap();
542
543                assert_eq!(response.status(), 200);
544                assert_eq!(
545                    response.into_body().collect().await.unwrap().to_bytes(),
546                    "Hello, test!"
547                );
548            }
549
550            #[tokio::test]
551            async fn test_can_read_input_with_specified_header() {
552                #[axum::debug_handler]
553                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
554                    format!("Hello, {}!", input.message)
555                }
556
557                let app = Router::new().route("/", post(handler));
558
559                let response = app
560                    .oneshot(
561                        Request::builder()
562                            .uri("/")
563                            .header(CONTENT_TYPE, "application/json")
564                            .method("POST")
565                            .body(json!({ "message": "test" }).to_string())
566                            .unwrap(),
567                    )
568                    .await
569                    .unwrap();
570
571                assert_eq!(response.status(), 200);
572                assert_eq!(
573                    response.into_body().collect().await.unwrap().to_bytes(),
574                    "Hello, test!"
575                );
576            }
577
578            #[tokio::test]
579            async fn test_can_read_input_with_charset_in_header() {
580                #[axum::debug_handler]
581                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
582                    format!("Hello, {}!", input.message)
583                }
584
585                let app = Router::new().route("/", post(handler));
586
587                let response = app
588                    .oneshot(
589                        Request::builder()
590                            .uri("/")
591                            .header(CONTENT_TYPE, "application/json;    charset=utf-8")
592                            .method("POST")
593                            .body(json!({ "message": "test" }).to_string())
594                            .unwrap(),
595                    )
596                    .await
597                    .unwrap();
598
599                assert_eq!(response.status(), 200);
600                assert_eq!(
601                    response.into_body().collect().await.unwrap().to_bytes(),
602                    "Hello, test!"
603                );
604            }
605
606            #[tokio::test]
607            async fn test_does_not_accept_invalid_inputs() {
608                #[axum::debug_handler]
609                async fn handler(_: Negotiate<Example>) -> impl IntoResponse {
610                    unimplemented!("This should not be called");
611                    #[allow(unreachable_code)]
612                    ()
613                }
614
615                let app = Router::new()
616                    .route("/", post(handler))
617                    .layer(NegotiateLayer);
618
619                let response = app
620                    .oneshot(
621                        Request::builder()
622                            .uri("/")
623                            .method("POST")
624                            .header(CONTENT_TYPE, "application/json")
625                            .body(json!({ "not": true }).to_string())
626                            .unwrap(),
627                    )
628                    .await
629                    .unwrap();
630
631                assert_eq!(response.status(), 400);
632                assert_eq!(
633                    response.into_body().collect().await.unwrap().to_bytes(),
634                    "Malformed request body"
635                );
636            }
637        }
638
639        mod output {
640            use super::*;
641
642            #[tokio::test]
643            async fn test_encode_as_requested() {
644                #[axum::debug_handler]
645                async fn handler() -> impl IntoResponse {
646                    Negotiate(Example {
647                        message: "Hello, test!".to_string(),
648                    })
649                }
650
651                let app = Router::new()
652                    .route("/", post(handler))
653                    .layer(NegotiateLayer);
654
655                let response = app
656                    .oneshot(
657                        Request::builder()
658                            .uri("/")
659                            .method("POST")
660                            .header(ACCEPT, "application/json")
661                            .body(Body::empty())
662                            .unwrap(),
663                    )
664                    .await
665                    .unwrap();
666
667                let expected_body = json!({ "message": "Hello, test!" }).to_string();
668
669                assert_eq!(response.status(), 200);
670                assert_eq!(
671                    response.headers().get(CONTENT_TYPE).unwrap(),
672                    "application/json"
673                );
674                assert_eq!(content_length(response.headers()), expected_body.len());
675                assert_eq!(
676                    response.into_body().collect().await.unwrap().to_bytes(),
677                    expected_body,
678                );
679            }
680
681            #[tokio::test]
682            async fn test_encode_as_requested_multi() {
683                #[axum::debug_handler]
684                async fn handler() -> impl IntoResponse {
685                    Negotiate(Example {
686                        message: "Hello, test!".to_string(),
687                    })
688                }
689
690                let app = Router::new()
691                    .route("/", post(handler))
692                    .layer(NegotiateLayer);
693
694                let response = app
695                    .oneshot(
696                        Request::builder()
697                            .uri("/")
698                            .method("POST")
699                            .header(ACCEPT, "not-supported, application/json;q=5,something-else")
700                            .body(Body::empty())
701                            .unwrap(),
702                    )
703                    .await
704                    .unwrap();
705
706                let expected_body = json!({ "message": "Hello, test!" }).to_string();
707
708                assert_eq!(response.status(), 200);
709                assert_eq!(
710                    response.headers().get(CONTENT_TYPE).unwrap(),
711                    "application/json"
712                );
713                assert_eq!(content_length(response.headers()), expected_body.len());
714                assert_eq!(
715                    response.into_body().collect().await.unwrap().to_bytes(),
716                    expected_body,
717                );
718            }
719
720            #[cfg(feature = "cbor")]
721            #[tokio::test]
722            async fn test_encode_as_requested_multi_w_q() {
723                #[axum::debug_handler]
724                async fn handler() -> impl IntoResponse {
725                    Negotiate(Example {
726                        message: "Hello, test!".to_string(),
727                    })
728                }
729
730                let app = Router::new()
731                    .route("/", post(handler))
732                    .layer(NegotiateLayer);
733
734                let response = app
735                    .oneshot(
736                        Request::builder()
737                            .uri("/")
738                            .method("POST")
739                            .header(
740                                ACCEPT,
741                                "application/json;q=0.8;other;stuff,application/cbor;q=0.9",
742                            )
743                            .body(Body::empty())
744                            .unwrap(),
745                    )
746                    .await
747                    .unwrap();
748
749                assert_eq!(response.status(), 200);
750                assert_eq!(
751                    response.headers().get(CONTENT_TYPE).unwrap(),
752                    "application/cbor"
753                );
754            }
755
756            #[cfg(feature = "cbor")]
757            #[tokio::test]
758            async fn test_encode_as_requested_multi_w_q_same_weights() {
759                #[axum::debug_handler]
760                async fn handler() -> impl IntoResponse {
761                    Negotiate(Example {
762                        message: "Hello, test!".to_string(),
763                    })
764                }
765
766                let app = Router::new()
767                    .route("/", post(handler))
768                    .layer(NegotiateLayer);
769
770                let response = app
771                    .oneshot(
772                        Request::builder()
773                            .uri("/")
774                            .method("POST")
775                            .header(
776                                ACCEPT,
777                                "application/cbor;q=0.9,application/json;q=0.9;other;stuff",
778                            )
779                            .body(Body::empty())
780                            .unwrap(),
781                    )
782                    .await
783                    .unwrap();
784
785                assert_eq!(response.status(), 200);
786                assert_eq!(
787                    response.headers().get(CONTENT_TYPE).unwrap(),
788                    "application/cbor"
789                );
790            }
791
792            #[cfg(feature = "default-json")]
793            #[tokio::test]
794            async fn test_use_default_encoding_without_headers() {
795                #[axum::debug_handler]
796                async fn handler() -> impl IntoResponse {
797                    Negotiate(Example {
798                        message: "Hello, test!".to_string(),
799                    })
800                }
801
802                let app = Router::new()
803                    .route("/", post(handler))
804                    .layer(NegotiateLayer);
805
806                let response = app
807                    .oneshot(
808                        Request::builder()
809                            .uri("/")
810                            .method("POST")
811                            .body(Body::empty())
812                            .unwrap(),
813                    )
814                    .await
815                    .unwrap();
816
817                assert_eq!(response.status(), 200);
818                assert_eq!(
819                    response.headers().get(CONTENT_TYPE).unwrap(),
820                    "application/json"
821                );
822                assert_eq!(
823                    response.into_body().collect().await.unwrap().to_bytes(),
824                    json!({ "message": "Hello, test!" }).to_string()
825                );
826            }
827
828            #[tokio::test]
829            async fn test_retain_handler_status_code() {
830                #[axum::debug_handler]
831                async fn handler() -> impl IntoResponse {
832                    (
833                        StatusCode::CREATED,
834                        Negotiate(Example {
835                            message: "Hello, test!".to_string(),
836                        }),
837                    )
838                }
839
840                let app = Router::new()
841                    .route("/", post(handler))
842                    .layer(NegotiateLayer);
843
844                let response = app
845                    .oneshot(
846                        Request::builder()
847                            .uri("/")
848                            .method("POST")
849                            .body(Body::empty())
850                            .unwrap(),
851                    )
852                    .await
853                    .unwrap();
854
855                assert_eq!(response.status(), StatusCode::CREATED);
856                #[cfg(feature = "default-json")]
857                assert_eq!(
858                    response.headers().get(CONTENT_TYPE).unwrap(),
859                    "application/json"
860                );
861                #[cfg(feature = "default-json")]
862                assert_eq!(
863                    response.into_body().collect().await.unwrap().to_bytes(),
864                    json!({ "message": "Hello, test!" }).to_string()
865                );
866                #[cfg(feature = "default-cbor")]
867                assert_eq!(
868                    response.headers().get(CONTENT_TYPE).unwrap(),
869                    "application/cbor"
870                );
871                #[cfg(feature = "default-cbor")]
872                assert_eq!(
873                    response.into_body().collect().await.unwrap().to_bytes(),
874                    general::expected_cbor_body()
875                );
876            }
877        }
878    }
879
880    #[cfg(feature = "cbor")]
881    mod cbor {
882        use cbor4ii::core::{enc::Encode, utils::BufWriter, Value};
883
884        use super::*;
885
886        mod input {
887            use super::*;
888
889            #[cfg(feature = "default-cbor")]
890            #[tokio::test]
891            async fn test_can_read_input_without_content_type_by_default() {
892                #[axum::debug_handler]
893                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
894                    format!("Hello, {}!", input.message)
895                }
896
897                let app = Router::new().route("/", post(handler));
898                let body = {
899                    let mut writer = BufWriter::new(Vec::new());
900                    Value::Map(vec![(
901                        Value::Text("message".to_string()),
902                        Value::Text("test".to_string()),
903                    )])
904                    .encode(&mut writer)
905                    .unwrap();
906                    writer.into_inner()
907                };
908
909                let response = app
910                    .oneshot(
911                        Request::builder()
912                            .uri("/")
913                            .method("POST")
914                            .body(Body::from(body))
915                            .unwrap(),
916                    )
917                    .await
918                    .unwrap();
919
920                assert_eq!(response.status(), 200);
921                assert_eq!(
922                    response.into_body().collect().await.unwrap().to_bytes(),
923                    "Hello, test!"
924                );
925            }
926
927            #[tokio::test]
928            async fn test_can_read_input_with_specified_header() {
929                #[axum::debug_handler]
930                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
931                    format!("Hello, {}!", input.message)
932                }
933
934                let app = Router::new().route("/", post(handler));
935                let body = {
936                    let mut writer = BufWriter::new(Vec::new());
937                    Value::Map(vec![(
938                        Value::Text("message".to_string()),
939                        Value::Text("test".to_string()),
940                    )])
941                    .encode(&mut writer)
942                    .unwrap();
943                    writer.into_inner()
944                };
945
946                let response = app
947                    .oneshot(
948                        Request::builder()
949                            .uri("/")
950                            .header(CONTENT_TYPE, "application/cbor")
951                            .method("POST")
952                            .body(Body::from(body))
953                            .unwrap(),
954                    )
955                    .await
956                    .unwrap();
957
958                assert_eq!(response.status(), 200);
959                assert_eq!(
960                    response.into_body().collect().await.unwrap().to_bytes(),
961                    "Hello, test!"
962                );
963            }
964        }
965
966        mod output {
967            use super::*;
968
969            #[tokio::test]
970            async fn test_encode_as_requested() {
971                #[axum::debug_handler]
972                async fn handler() -> impl IntoResponse {
973                    Negotiate(Example {
974                        message: "Hello, test!".to_string(),
975                    })
976                }
977
978                let app = Router::new()
979                    .route("/", post(handler))
980                    .layer(NegotiateLayer);
981
982                let response = app
983                    .oneshot(
984                        Request::builder()
985                            .uri("/")
986                            .method("POST")
987                            .header(ACCEPT, "application/cbor")
988                            .body(Body::empty())
989                            .unwrap(),
990                    )
991                    .await
992                    .unwrap();
993
994                let expected_body = general::expected_cbor_body();
995
996                assert_eq!(response.status(), 200);
997                assert_eq!(
998                    response.headers().get(CONTENT_TYPE).unwrap(),
999                    "application/cbor"
1000                );
1001                assert_eq!(content_length(response.headers()), expected_body.len());
1002                assert_eq!(
1003                    response.into_body().collect().await.unwrap().to_bytes(),
1004                    expected_body,
1005                );
1006            }
1007
1008            #[tokio::test]
1009            async fn test_encode_as_requested_multi() {
1010                #[axum::debug_handler]
1011                async fn handler() -> impl IntoResponse {
1012                    Negotiate(Example {
1013                        message: "Hello, test!".to_string(),
1014                    })
1015                }
1016
1017                let app = Router::new()
1018                    .route("/", post(handler))
1019                    .layer(NegotiateLayer);
1020
1021                let response = app
1022                    .oneshot(
1023                        Request::builder()
1024                            .uri("/")
1025                            .method("POST")
1026                            .header(ACCEPT, "something-else;q=0.5,application/cbor")
1027                            .body(Body::empty())
1028                            .unwrap(),
1029                    )
1030                    .await
1031                    .unwrap();
1032
1033                let expected_body = general::expected_cbor_body();
1034
1035                assert_eq!(response.status(), 200);
1036                assert_eq!(
1037                    response.headers().get(CONTENT_TYPE).unwrap(),
1038                    "application/cbor"
1039                );
1040                assert_eq!(content_length(response.headers()), expected_body.len());
1041                assert_eq!(
1042                    response.into_body().collect().await.unwrap().to_bytes(),
1043                    expected_body,
1044                );
1045            }
1046
1047            #[cfg(feature = "json")]
1048            #[tokio::test]
1049            async fn test_encode_as_requested_multi_without_q_using_default_weight() {
1050                #[axum::debug_handler]
1051                async fn handler() -> impl IntoResponse {
1052                    Negotiate(Example {
1053                        message: "Hello, test!".to_string(),
1054                    })
1055                }
1056
1057                let app = Router::new()
1058                    .route("/", post(handler))
1059                    .layer(NegotiateLayer);
1060
1061                let response = app
1062                    .oneshot(
1063                        Request::builder()
1064                            .uri("/")
1065                            .method("POST")
1066                            .header(ACCEPT, "application/cbor;q=0.2,application/json")
1067                            .body(Body::empty())
1068                            .unwrap(),
1069                    )
1070                    .await
1071                    .unwrap();
1072
1073                assert_eq!(response.status(), 200);
1074                assert_eq!(
1075                    response.headers().get(CONTENT_TYPE).unwrap(),
1076                    "application/json"
1077                );
1078            }
1079
1080            // Given equal q values, the first mime type should be selected
1081            #[cfg(feature = "json")]
1082            #[tokio::test]
1083            async fn test_encode_as_requested_equal_q() {
1084                #[axum::debug_handler]
1085                async fn handler() -> impl IntoResponse {
1086                    Negotiate(Example {
1087                        message: "Hello, test!".to_string(),
1088                    })
1089                }
1090
1091                let app = Router::new()
1092                    .route("/", post(handler))
1093                    .layer(NegotiateLayer);
1094
1095                let response = app
1096                    .oneshot(
1097                        Request::builder()
1098                            .uri("/")
1099                            .method("POST")
1100                            .header(ACCEPT, "application/cbor,application/json")
1101                            .body(Body::empty())
1102                            .unwrap(),
1103                    )
1104                    .await
1105                    .unwrap();
1106
1107                assert_eq!(response.status(), 200);
1108                assert_eq!(
1109                    response.headers().get(CONTENT_TYPE).unwrap(),
1110                    "application/cbor"
1111                );
1112            }
1113            // Given equal q values, the first mime type should be selected
1114            #[cfg(feature = "json")]
1115            #[tokio::test]
1116            async fn test_encode_as_requested_equal_q2() {
1117                #[axum::debug_handler]
1118                async fn handler() -> impl IntoResponse {
1119                    Negotiate(Example {
1120                        message: "Hello, test!".to_string(),
1121                    })
1122                }
1123
1124                let app = Router::new()
1125                    .route("/", post(handler))
1126                    .layer(NegotiateLayer);
1127
1128                let response = app
1129                    .oneshot(
1130                        Request::builder()
1131                            .uri("/")
1132                            .method("POST")
1133                            .header(ACCEPT, "application/json,application/cbor")
1134                            .body(Body::empty())
1135                            .unwrap(),
1136                    )
1137                    .await
1138                    .unwrap();
1139
1140                assert_eq!(response.status(), 200);
1141                assert_eq!(
1142                    response.headers().get(CONTENT_TYPE).unwrap(),
1143                    "application/json"
1144                );
1145            }
1146
1147            #[tokio::test]
1148            async fn test_retain_status_code() {
1149                #[axum::debug_handler]
1150                async fn handler() -> impl IntoResponse {
1151                    (
1152                        StatusCode::CREATED,
1153                        Negotiate(Example {
1154                            message: "Hello, test!".to_string(),
1155                        }),
1156                    )
1157                }
1158
1159                let app = Router::new()
1160                    .route("/", post(handler))
1161                    .layer(NegotiateLayer);
1162
1163                let response = app
1164                    .oneshot(
1165                        Request::builder()
1166                            .uri("/")
1167                            .method("POST")
1168                            .header(ACCEPT, "application/cbor")
1169                            .body(Body::empty())
1170                            .unwrap(),
1171                    )
1172                    .await
1173                    .unwrap();
1174
1175                assert_eq!(response.status(), StatusCode::CREATED);
1176                assert_eq!(
1177                    response.headers().get(CONTENT_TYPE).unwrap(),
1178                    "application/cbor"
1179                );
1180                assert_eq!(
1181                    response.into_body().collect().await.unwrap().to_bytes(),
1182                    general::expected_cbor_body()
1183                );
1184            }
1185
1186            #[cfg(feature = "default-cbor")]
1187            #[tokio::test]
1188            async fn test_default_encoding_without_header() {
1189                #[axum::debug_handler]
1190                async fn handler() -> impl IntoResponse {
1191                    (
1192                        StatusCode::CREATED,
1193                        Negotiate(Example {
1194                            message: "Hello, test!".to_string(),
1195                        }),
1196                    )
1197                }
1198
1199                let app = Router::new()
1200                    .route("/", post(handler))
1201                    .layer(NegotiateLayer);
1202
1203                let response = app
1204                    .oneshot(
1205                        Request::builder()
1206                            .uri("/")
1207                            .method("POST")
1208                            .body(Body::empty())
1209                            .unwrap(),
1210                    )
1211                    .await
1212                    .unwrap();
1213
1214                assert_eq!(response.status(), StatusCode::CREATED);
1215                assert_eq!(
1216                    response.headers().get(CONTENT_TYPE).unwrap(),
1217                    "application/cbor"
1218                );
1219                assert_eq!(
1220                    response.into_body().collect().await.unwrap().to_bytes(),
1221                    general::expected_cbor_body()
1222                );
1223            }
1224
1225            #[cfg(feature = "default-cbor")]
1226            #[tokio::test]
1227            async fn test_default_encoding_with_star() {
1228                #[axum::debug_handler]
1229                async fn handler() -> impl IntoResponse {
1230                    (
1231                        StatusCode::CREATED,
1232                        Negotiate(Example {
1233                            message: "Hello, test!".to_string(),
1234                        }),
1235                    )
1236                }
1237
1238                let app = Router::new()
1239                    .route("/", post(handler))
1240                    .layer(NegotiateLayer);
1241
1242                let response = app
1243                    .oneshot(
1244                        Request::builder()
1245                            .uri("/")
1246                            .method("POST")
1247                            .header(ACCEPT, "*/*")
1248                            .body(Body::empty())
1249                            .unwrap(),
1250                    )
1251                    .await
1252                    .unwrap();
1253
1254                assert_eq!(response.status(), StatusCode::CREATED);
1255                assert_eq!(
1256                    response.headers().get(CONTENT_TYPE).unwrap(),
1257                    "application/cbor"
1258                );
1259                assert_eq!(
1260                    response.into_body().collect().await.unwrap().to_bytes(),
1261                    general::expected_cbor_body()
1262                );
1263            }
1264        }
1265    }
1266}