1#![doc = include_str!("../README.md")]
2
3use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use axum::{
11 body::Bytes,
12 extract::{FromRequest, Request},
13 http::{
14 header::{HeaderValue, ACCEPT, CONTENT_LENGTH, CONTENT_TYPE},
15 StatusCode,
16 },
17 response::{IntoResponse, Response},
18 Extension,
19};
20use tower::Service;
21
// Both JSON backends handle `application/json`, so only one of them may be enabled.
#[cfg(all(feature = "json", feature = "simd-json"))]
compile_error!("json and simd-json features are mutually exclusive");
// Only one fallback encoding can be selected at a time.
#[cfg(all(feature = "default-json", feature = "default-cbor"))]
compile_error!("default-json and default-cbor features are mutually exclusive");

/// Media type used when a request does not state one (JSON flavour).
#[cfg(feature = "default-json")]
static DEFAULT_CONTENT_TYPE_VALUE: &str = "application/json";

/// Media type used when a request does not state one (CBOR flavour).
#[cfg(feature = "default-cbor")]
static DEFAULT_CONTENT_TYPE_VALUE: &str = "application/cbor";

// Without a default there is nothing to fall back to, so refuse to build.
#[cfg(not(any(feature = "default-json", feature = "default-cbor")))]
compile_error!("A default-* feature must be enabled for fallback encoding");

/// Header form of [`DEFAULT_CONTENT_TYPE_VALUE`], used when `Accept` is absent.
static DEFAULT_CONTENT_TYPE: HeaderValue = HeaderValue::from_static(DEFAULT_CONTENT_TYPE_VALUE);

/// Rejection returned when a request body cannot be deserialized.
static MALFORMED_RESPONSE: (StatusCode, &str) = (StatusCode::BAD_REQUEST, "Malformed request body");
41
/// Wrapper that decodes request bodies and encodes response bodies in a
/// negotiated format.
///
/// As an extractor it deserializes the body according to the request's
/// `Content-Type`. As a response it hands the wrapped value to
/// [`NegotiateLayer`], which serializes it in the encoding chosen from the
/// request's `Accept` header.
#[derive(Debug, Clone)]
pub struct Negotiate<T>(
    /// The wrapped payload.
    pub T,
);
73
74impl<T, S> FromRequest<S> for Negotiate<T>
79where
80 T: serde::de::DeserializeOwned,
81 S: Send + Sync,
82{
83 type Rejection = Response;
84
85 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
86 let content_type = req
87 .headers()
88 .get(CONTENT_TYPE)
89 .and_then(|h| h.to_str().ok())
90 .unwrap_or(DEFAULT_CONTENT_TYPE_VALUE);
91
92 let content_type = content_type
94 .split(';')
95 .next()
96 .map(str::trim)
97 .unwrap_or_default();
98
99 match content_type {
100 #[cfg(feature = "simd-json")]
101 "application/json" => {
102 let mut body = Bytes::from_request(req, state)
103 .await
104 .map_err(|e| {
105 tracing::error!(error = %e, "failed to ready request body as bytes");
106 e.into_response()
107 })?
108 .to_vec();
109
110 let body = simd_json::from_slice(&mut body).map_err(|e| {
111 tracing::error!(error = %e, "failed to deserialize request body as json");
112 MALFORMED_RESPONSE.into_response()
113 })?;
114
115 Ok(Self(body))
116 }
117 #[cfg(feature = "json")]
118 "application/json" => {
119 let body = Bytes::from_request(req, state).await.map_err(|e| {
120 tracing::error!(error = %e, "failed to ready request body as bytes");
121 e.into_response()
122 })?;
123
124 let body = serde_json::from_slice(&body).map_err(|e| {
125 tracing::error!(error = %e, "failed to deserialize request body as json");
126 MALFORMED_RESPONSE.into_response()
127 })?;
128
129 Ok(Self(body))
130 }
131
132 #[cfg(feature = "cbor")]
133 "application/cbor" => {
134 let body = Bytes::from_request(req, state).await.map_err(|e| {
135 tracing::error!(error = %e, "failed to ready request body as bytes");
136 e.into_response()
137 })?;
138
139 let body = cbor4ii::serde::from_slice(&body).map_err(|e| {
140 tracing::error!(error = %e, "failed to deserialize request body as json");
141 MALFORMED_RESPONSE.into_response()
142 })?;
143
144 Ok(Self(body))
145 }
146
147 _ => {
148 tracing::error!("unsupported content-type header: {:?}", content_type);
149 Err((
150 StatusCode::NOT_ACCEPTABLE,
151 "Invalid content type on request",
152 )
153 .into_response())
154 }
155 }
156 }
157}
158
159#[derive(Clone)]
163struct ErasedNegotiate(Arc<Box<dyn erased_serde::Serialize + Send + Sync>>);
164
165impl<T> From<T> for ErasedNegotiate
166where
167 T: serde::Serialize + Send + Sync + 'static,
168{
169 fn from(value: T) -> Self {
170 Self(Arc::new(Box::from(value)))
171 }
172}
173
174impl<T> IntoResponse for Negotiate<T>
178where
179 T: serde::Serialize + Send + Sync + 'static,
180{
181 fn into_response(self) -> Response {
182 let data: ErasedNegotiate = self.0.into();
183 (
184 StatusCode::UNSUPPORTED_MEDIA_TYPE,
185 Extension(data),
186 "Misconfigured service layer",
187 )
188 .into_response()
189 }
190}
191
/// Tower layer that wraps a service in [`NegotiateService`].
#[derive(Clone)]
pub struct NegotiateLayer;
197
198impl<S> tower::Layer<S> for NegotiateLayer {
199 type Service = NegotiateService<S>;
200
201 fn layer(&self, inner: S) -> Self::Service {
202 NegotiateService(inner)
203 }
204}
205
206trait SupportedEncodingExt {
207 fn supported_encoding(&self) -> Option<&'static str>;
208}
209
210impl SupportedEncodingExt for &[u8] {
211 fn supported_encoding(&self) -> Option<&'static str> {
212 match *self {
213 #[cfg(any(feature = "simd-json", feature = "json"))]
214 b"application/json" => Some("application/json"),
215 #[cfg(feature = "cbor")]
216 b"application/cbor" => Some("application/cbor"),
217 b"*/*" => Some(DEFAULT_CONTENT_TYPE_VALUE),
218 _ => None,
219 }
220 }
221}
222
223trait AcceptExt {
224 fn negotiate(&self) -> Option<&'static str>;
225}
226
227impl AcceptExt for axum::http::HeaderMap {
228 fn negotiate(&self) -> Option<&'static str> {
229 let accept = self.get(ACCEPT).unwrap_or(&DEFAULT_CONTENT_TYPE);
230 let precise_mime = accept.as_bytes().supported_encoding();
231
232 if precise_mime.is_some() {
234 return precise_mime;
235 }
236
237 accept
238 .to_str()
239 .ok()?
240 .split(',')
241 .map(str::trim)
242 .filter_map(|s| {
243 let mut segments = s.split(';').map(str::trim);
244 let mime = segments.next().unwrap_or(s);
245
246 let mime_type = mime.as_bytes().supported_encoding()?;
248
249 let q = segments
251 .find_map(|s| {
252 let value = s.strip_prefix("q=")?;
253 Some(value.parse::<f32>().unwrap_or(0.0))
254 })
255 .unwrap_or(1.0);
256 Some((mime_type, q))
257 })
258 .min_by(|(_, a), (_, b)| b.total_cmp(a))
259 .map(|(mime, _)| mime)
260 }
261}
262
/// Service that serializes [`Negotiate`] responses of the wrapped service in
/// the encoding selected from the request's `Accept` header.
#[derive(Clone)]
pub struct NegotiateService<S>(S);
266
267impl<T> Service<Request> for NegotiateService<T>
268where
269 T: Service<Request>,
270 T::Response: IntoResponse,
271 T::Future: Send + 'static,
272{
273 type Response = axum::response::Response;
274 type Error = T::Error;
275 type Future =
276 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
277
278 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279 self.0.poll_ready(cx)
280 }
281
282 fn call(&mut self, request: Request) -> Self::Future {
283 let accept = request.headers().negotiate();
284
285 let Some(encoding) = accept else {
286 return Box::pin(async move {
287 let response: Response = (
288 StatusCode::NOT_ACCEPTABLE,
289 "Invalid content type on request",
290 )
291 .into_response();
292 Ok(response)
293 });
294 };
295
296 let future = self.0.call(request);
297
298 Box::pin(async move {
299 let inner_service = future.await?;
300 let response: Response = inner_service.into_response();
301 let data = response.extensions().get::<ErasedNegotiate>();
302
303 let Some(ErasedNegotiate(payload)) = data else {
304 return Ok(response);
305 };
306
307 let body = match encoding {
308 #[cfg(any(feature = "simd-json", feature = "json"))]
309 "application/json" => {
310 let mut body = Vec::new();
311 {
312 let mut serializer = serde_json::Serializer::new(&mut body);
313 let mut serializer = <dyn erased_serde::Serializer>::erase(&mut serializer);
314 if let Err(e) = payload.erased_serialize(&mut serializer) {
315 tracing::error!(error = %e, "failed to deserialize request body as json");
316
317 let response: Response = (
318 StatusCode::INTERNAL_SERVER_ERROR,
319 "Failed to serialize response",
320 )
321 .into_response();
322 return Ok(response);
323 }
324 }
325 body
326 }
327 #[cfg(feature = "cbor")]
328 "application/cbor" => {
329 let mut body = cbor4ii::core::utils::BufWriter::new(Vec::new());
330 {
331 let mut serializer = cbor4ii::serde::Serializer::new(&mut body);
332 let mut serializer = <dyn erased_serde::Serializer>::erase(&mut serializer);
333 if let Err(e) = payload.erased_serialize(&mut serializer) {
334 tracing::error!(error = %e, "failed to deserialize request body as cbor");
335
336 let response: Response = (
337 StatusCode::INTERNAL_SERVER_ERROR,
338 "Failed to serialize response",
339 )
340 .into_response();
341 return Ok(response);
342 }
343 }
344 body.into_inner()
345 }
346 _ => vec![],
347 };
348
349 let (mut parts, _) = response.into_parts();
350 if parts.status == StatusCode::UNSUPPORTED_MEDIA_TYPE {
351 parts.status = StatusCode::OK;
352 }
353 parts
354 .headers
355 .insert(CONTENT_TYPE, HeaderValue::from_static(encoding));
356 parts.headers.remove(CONTENT_LENGTH);
357
358 Ok(Response::from_parts(parts, body.into()))
359 })
360 }
361}
362
363#[cfg(test)]
364mod test {
365 use crate::Negotiate;
366
367 use axum::{
368 body::Body,
369 http::{
370 header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE},
371 Request, StatusCode,
372 },
373 response::IntoResponse,
374 routing::post,
375 Router,
376 };
377 use http_body_util::BodyExt;
378 use tower::ServiceExt;
379
380 use crate::NegotiateLayer;
381
    // Minimal payload used by the tests as both request and response body.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct Example {
        message: String,
    }
386
387 fn content_length(headers: &axum::http::HeaderMap) -> usize {
388 headers
389 .get(CONTENT_LENGTH)
390 .map(|v| v.to_str().unwrap().parse::<usize>().unwrap())
391 .unwrap()
392 }
393
394 mod general {
395 use super::*;
396
        // CBOR encoding of the map `{"message": "Hello, test!"}`, used as the
        // expected response body in the CBOR assertions.
        #[cfg(feature = "cbor")]
        pub fn expected_cbor_body() -> Vec<u8> {
            use cbor4ii::core::{enc::Encode, utils::BufWriter, Value};

            let mut writer = BufWriter::new(Vec::new());
            Value::Map(vec![(
                Value::Text("message".to_string()),
                Value::Text("Hello, test!".to_string()),
            )])
            .encode(&mut writer)
            .unwrap();
            writer.into_inner()
        }
410
411 mod input {
412 use super::*;
413
            // An unsupported `Content-Type` is rejected with 406 and the handler never runs.
            #[tokio::test]
            async fn test_does_not_process_handler_if_content_type_is_not_supported() {
                #[axum::debug_handler]
                async fn handler(_: Negotiate<Example>) -> impl IntoResponse {
                    unimplemented!("This should not be called");
                    #[allow(unreachable_code)]
                    ()
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .header(CONTENT_TYPE, "non-supported")
                            .method("POST")
                            .body(Body::from("really-cool-format"))
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 406);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Invalid content type on request"
                );
            }
445 }
446
447 mod output {
448 use super::*;
449
            // Without `NegotiateLayer` the placeholder 415 response reaches the client.
            #[tokio::test]
            async fn test_inform_error_when_misconfigured() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new().route("/", post(handler));

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 415);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Misconfigured service layer"
                );
            }
478
            // An unsupported `Accept` value yields 406 and the handler never runs.
            #[tokio::test]
            async fn test_does_not_process_handler_if_accept_is_not_supported() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    unimplemented!("This should not be called");
                    #[allow(unreachable_code)]
                    ()
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .header(ACCEPT, "non-supported")
                            .method("POST")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 406);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Invalid content type on request"
                );
            }
510 }
511 }
512
513 #[cfg(any(feature = "simd-json", feature = "json"))]
514 mod json {
515 use serde_json::json;
516
517 use super::*;
518
519 mod input {
520 use super::*;
521
            // With JSON as the default, a body without `Content-Type` is parsed as JSON.
            #[cfg(feature = "default-json")]
            #[tokio::test]
            async fn test_can_read_input_without_content_type_by_default() {
                #[axum::debug_handler]
                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
                    format!("Hello, {}!", input.message)
                }

                let app = Router::new().route("/", post(handler));

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(json!({ "message": "test" }).to_string())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Hello, test!"
                );
            }
549
            // An explicit `application/json` content type is parsed as JSON.
            #[tokio::test]
            async fn test_can_read_input_with_specified_header() {
                #[axum::debug_handler]
                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
                    format!("Hello, {}!", input.message)
                }

                let app = Router::new().route("/", post(handler));

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .header(CONTENT_TYPE, "application/json")
                            .method("POST")
                            .body(json!({ "message": "test" }).to_string())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Hello, test!"
                );
            }
577
            // Parameters after the media type (here `charset`) do not affect parsing.
            #[tokio::test]
            async fn test_can_read_input_with_charset_in_header() {
                #[axum::debug_handler]
                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
                    format!("Hello, {}!", input.message)
                }

                let app = Router::new().route("/", post(handler));

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .header(CONTENT_TYPE, "application/json; charset=utf-8")
                            .method("POST")
                            .body(json!({ "message": "test" }).to_string())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Hello, test!"
                );
            }
605
            // A JSON body that does not match `Example` is rejected with 400.
            #[tokio::test]
            async fn test_does_not_accept_invalid_inputs() {
                #[axum::debug_handler]
                async fn handler(_: Negotiate<Example>) -> impl IntoResponse {
                    unimplemented!("This should not be called");
                    #[allow(unreachable_code)]
                    ()
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(CONTENT_TYPE, "application/json")
                            .body(json!({ "not": true }).to_string())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 400);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Malformed request body"
                );
            }
637 }
638
639 mod output {
640 use super::*;
641
            // `Accept: application/json` produces a JSON body with matching headers.
            #[tokio::test]
            async fn test_encode_as_requested() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "application/json")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                let expected_body = json!({ "message": "Hello, test!" }).to_string();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/json"
                );
                assert_eq!(content_length(response.headers()), expected_body.len());
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    expected_body,
                );
            }
680
            // Unsupported entries in the `Accept` list are skipped in favour of JSON.
            #[tokio::test]
            async fn test_encode_as_requested_multi() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "not-supported, application/json;q=5,something-else")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                let expected_body = json!({ "message": "Hello, test!" }).to_string();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/json"
                );
                assert_eq!(content_length(response.headers()), expected_body.len());
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    expected_body,
                );
            }
719
            // The entry with the higher `q` weight (CBOR at 0.9) wins over JSON at 0.8.
            #[cfg(feature = "cbor")]
            #[tokio::test]
            async fn test_encode_as_requested_multi_w_q() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(
                                ACCEPT,
                                "application/json;q=0.8;other;stuff,application/cbor;q=0.9",
                            )
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
            }
755
            // With equal weights the entry listed first (CBOR) is chosen.
            #[cfg(feature = "cbor")]
            #[tokio::test]
            async fn test_encode_as_requested_multi_w_q_same_weights() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(
                                ACCEPT,
                                "application/cbor;q=0.9,application/json;q=0.9;other;stuff",
                            )
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
            }
791
            // With JSON as the default, a request without `Accept` gets a JSON response.
            #[cfg(feature = "default-json")]
            #[tokio::test]
            async fn test_use_default_encoding_without_headers() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/json"
                );
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    json!({ "message": "Hello, test!" }).to_string()
                );
            }
827
            // A status set by the handler (201) survives negotiation; the body
            // assertions depend on which default encoding is enabled.
            #[tokio::test]
            async fn test_retain_handler_status_code() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    (
                        StatusCode::CREATED,
                        Negotiate(Example {
                            message: "Hello, test!".to_string(),
                        }),
                    )
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), StatusCode::CREATED);
                #[cfg(feature = "default-json")]
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/json"
                );
                #[cfg(feature = "default-json")]
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    json!({ "message": "Hello, test!" }).to_string()
                );
                #[cfg(feature = "default-cbor")]
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                #[cfg(feature = "default-cbor")]
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    general::expected_cbor_body()
                );
            }
877 }
878 }
879
880 #[cfg(feature = "cbor")]
881 mod cbor {
882 use cbor4ii::core::{enc::Encode, utils::BufWriter, Value};
883
884 use super::*;
885
886 mod input {
887 use super::*;
888
            // With CBOR as the default, a body without `Content-Type` is parsed as CBOR.
            #[cfg(feature = "default-cbor")]
            #[tokio::test]
            async fn test_can_read_input_without_content_type_by_default() {
                #[axum::debug_handler]
                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
                    format!("Hello, {}!", input.message)
                }

                let app = Router::new().route("/", post(handler));
                let body = {
                    let mut writer = BufWriter::new(Vec::new());
                    Value::Map(vec![(
                        Value::Text("message".to_string()),
                        Value::Text("test".to_string()),
                    )])
                    .encode(&mut writer)
                    .unwrap();
                    writer.into_inner()
                };

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(Body::from(body))
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Hello, test!"
                );
            }
926
            // An explicit `application/cbor` content type is parsed as CBOR.
            #[tokio::test]
            async fn test_can_read_input_with_specified_header() {
                #[axum::debug_handler]
                async fn handler(Negotiate(input): Negotiate<Example>) -> impl IntoResponse {
                    format!("Hello, {}!", input.message)
                }

                let app = Router::new().route("/", post(handler));
                let body = {
                    let mut writer = BufWriter::new(Vec::new());
                    Value::Map(vec![(
                        Value::Text("message".to_string()),
                        Value::Text("test".to_string()),
                    )])
                    .encode(&mut writer)
                    .unwrap();
                    writer.into_inner()
                };

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .header(CONTENT_TYPE, "application/cbor")
                            .method("POST")
                            .body(Body::from(body))
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    "Hello, test!"
                );
            }
964 }
965
966 mod output {
967 use super::*;
968
            // `Accept: application/cbor` produces a CBOR body with matching headers.
            #[tokio::test]
            async fn test_encode_as_requested() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "application/cbor")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                let expected_body = general::expected_cbor_body();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                assert_eq!(content_length(response.headers()), expected_body.len());
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    expected_body,
                );
            }
1007
            // An unsupported entry in the `Accept` list is skipped in favour of CBOR.
            #[tokio::test]
            async fn test_encode_as_requested_multi() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    Negotiate(Example {
                        message: "Hello, test!".to_string(),
                    })
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "something-else;q=0.5,application/cbor")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                let expected_body = general::expected_cbor_body();

                assert_eq!(response.status(), 200);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                assert_eq!(content_length(response.headers()), expected_body.len());
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    expected_body,
                );
            }
1046
1047 #[cfg(feature = "json")]
1048 #[tokio::test]
1049 async fn test_encode_as_requested_multi_without_q_using_default_weight() {
1050 #[axum::debug_handler]
1051 async fn handler() -> impl IntoResponse {
1052 Negotiate(Example {
1053 message: "Hello, test!".to_string(),
1054 })
1055 }
1056
1057 let app = Router::new()
1058 .route("/", post(handler))
1059 .layer(NegotiateLayer);
1060
1061 let response = app
1062 .oneshot(
1063 Request::builder()
1064 .uri("/")
1065 .method("POST")
1066 .header(ACCEPT, "application/cbor;q=0.2,application/json")
1067 .body(Body::empty())
1068 .unwrap(),
1069 )
1070 .await
1071 .unwrap();
1072
1073 assert_eq!(response.status(), 200);
1074 assert_eq!(
1075 response.headers().get(CONTENT_TYPE).unwrap(),
1076 "application/json"
1077 );
1078 }
1079
1080 #[cfg(feature = "json")]
1082 #[tokio::test]
1083 async fn test_encode_as_requested_equal_q() {
1084 #[axum::debug_handler]
1085 async fn handler() -> impl IntoResponse {
1086 Negotiate(Example {
1087 message: "Hello, test!".to_string(),
1088 })
1089 }
1090
1091 let app = Router::new()
1092 .route("/", post(handler))
1093 .layer(NegotiateLayer);
1094
1095 let response = app
1096 .oneshot(
1097 Request::builder()
1098 .uri("/")
1099 .method("POST")
1100 .header(ACCEPT, "application/cbor,application/json")
1101 .body(Body::empty())
1102 .unwrap(),
1103 )
1104 .await
1105 .unwrap();
1106
1107 assert_eq!(response.status(), 200);
1108 assert_eq!(
1109 response.headers().get(CONTENT_TYPE).unwrap(),
1110 "application/cbor"
1111 );
1112 }
1113 #[cfg(feature = "json")]
1115 #[tokio::test]
1116 async fn test_encode_as_requested_equal_q2() {
1117 #[axum::debug_handler]
1118 async fn handler() -> impl IntoResponse {
1119 Negotiate(Example {
1120 message: "Hello, test!".to_string(),
1121 })
1122 }
1123
1124 let app = Router::new()
1125 .route("/", post(handler))
1126 .layer(NegotiateLayer);
1127
1128 let response = app
1129 .oneshot(
1130 Request::builder()
1131 .uri("/")
1132 .method("POST")
1133 .header(ACCEPT, "application/json,application/cbor")
1134 .body(Body::empty())
1135 .unwrap(),
1136 )
1137 .await
1138 .unwrap();
1139
1140 assert_eq!(response.status(), 200);
1141 assert_eq!(
1142 response.headers().get(CONTENT_TYPE).unwrap(),
1143 "application/json"
1144 );
1145 }
1146
            // A status set by the handler (201) survives CBOR negotiation.
            #[tokio::test]
            async fn test_retain_status_code() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    (
                        StatusCode::CREATED,
                        Negotiate(Example {
                            message: "Hello, test!".to_string(),
                        }),
                    )
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "application/cbor")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), StatusCode::CREATED);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    general::expected_cbor_body()
                );
            }
1185
            // With CBOR as the default, a request without `Accept` gets a CBOR response.
            #[cfg(feature = "default-cbor")]
            #[tokio::test]
            async fn test_default_encoding_without_header() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    (
                        StatusCode::CREATED,
                        Negotiate(Example {
                            message: "Hello, test!".to_string(),
                        }),
                    )
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), StatusCode::CREATED);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    general::expected_cbor_body()
                );
            }
1224
            // With CBOR as the default, `Accept: */*` resolves to CBOR.
            #[cfg(feature = "default-cbor")]
            #[tokio::test]
            async fn test_default_encoding_with_star() {
                #[axum::debug_handler]
                async fn handler() -> impl IntoResponse {
                    (
                        StatusCode::CREATED,
                        Negotiate(Example {
                            message: "Hello, test!".to_string(),
                        }),
                    )
                }

                let app = Router::new()
                    .route("/", post(handler))
                    .layer(NegotiateLayer);

                let response = app
                    .oneshot(
                        Request::builder()
                            .uri("/")
                            .method("POST")
                            .header(ACCEPT, "*/*")
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();

                assert_eq!(response.status(), StatusCode::CREATED);
                assert_eq!(
                    response.headers().get(CONTENT_TYPE).unwrap(),
                    "application/cbor"
                );
                assert_eq!(
                    response.into_body().collect().await.unwrap().to_bytes(),
                    general::expected_cbor_body()
                );
            }
1264 }
1265 }
1266}