1use axum::extract::{FromRequest, FromRequestParts};
21use axum::response::{IntoResponse, Response};
22
/// Generates `Deref` and `DerefMut` impls for a single-field tuple wrapper
/// `$name<T>(T)`, letting the wrapper be used wherever a `&T` / `&mut T` is
/// expected.
macro_rules! impl_extractor_deref {
    ($name:ident) => {
        impl<T> ::std::ops::Deref for $name<T> {
            type Target = T;

            fn deref(&self) -> &T {
                let $name(inner) = self;
                inner
            }
        }

        impl<T> ::std::ops::DerefMut for $name<T> {
            fn deref_mut(&mut self) -> &mut T {
                let $name(inner) = self;
                inner
            }
        }
    };
}
40
41#[derive(Debug, Clone, Copy, Default)]
46pub struct Form<T>(pub T);
47
48impl_extractor_deref!(Form);
49
50impl<S, T> FromRequest<S> for Form<T>
51where
52 S: Send + Sync,
53 axum::extract::Form<T>: FromRequest<S, Rejection = axum::extract::rejection::FormRejection>,
54{
55 type Rejection = crate::AutumnError;
56
57 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
58 axum::extract::Form::from_request(req, state)
59 .await
60 .map(|axum::extract::Form(value)| Self(value))
61 .map_err(|err| rejection_to_error(err.status(), err.body_text()))
62 }
63}
64
65#[derive(Debug, Clone, Copy, Default)]
71pub struct Json<T>(pub T);
72
73impl_extractor_deref!(Json);
74
75impl<S, T> FromRequest<S> for Json<T>
76where
77 S: Send + Sync,
78 axum::extract::Json<T>: FromRequest<S, Rejection = axum::extract::rejection::JsonRejection>,
79{
80 type Rejection = crate::AutumnError;
81
82 async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
83 axum::extract::Json::from_request(req, state)
84 .await
85 .map(|axum::extract::Json(value)| Self(value))
86 .map_err(|err| rejection_to_error(err.status(), err.body_text()))
87 }
88}
89
90impl<T> IntoResponse for Json<T>
91where
92 axum::Json<T>: IntoResponse,
93{
94 fn into_response(self) -> Response {
95 axum::Json(self.0).into_response()
96 }
97}
98
99#[derive(Debug, Clone, Copy, Default)]
104pub struct Path<T>(pub T);
105
106impl_extractor_deref!(Path);
107
108impl<S, T> FromRequestParts<S> for Path<T>
109where
110 S: Send + Sync,
111 axum::extract::Path<T>:
112 FromRequestParts<S, Rejection = axum::extract::rejection::PathRejection>,
113{
114 type Rejection = crate::AutumnError;
115
116 async fn from_request_parts(
117 parts: &mut axum::http::request::Parts,
118 state: &S,
119 ) -> Result<Self, Self::Rejection> {
120 axum::extract::Path::from_request_parts(parts, state)
121 .await
122 .map(|axum::extract::Path(value)| Self(value))
123 .map_err(|err| rejection_to_error(err.status(), err.body_text()))
124 }
125}
126
127#[derive(Debug, Clone, Copy, Default)]
132pub struct Query<T>(pub T);
133
134impl_extractor_deref!(Query);
135
136impl<S, T> FromRequestParts<S> for Query<T>
137where
138 S: Send + Sync,
139 axum::extract::Query<T>:
140 FromRequestParts<S, Rejection = axum::extract::rejection::QueryRejection>,
141{
142 type Rejection = crate::AutumnError;
143
144 async fn from_request_parts(
145 parts: &mut axum::http::request::Parts,
146 state: &S,
147 ) -> Result<Self, Self::Rejection> {
148 axum::extract::Query::from_request_parts(parts, state)
149 .await
150 .map(|axum::extract::Query(value)| Self(value))
151 .map_err(|err| rejection_to_error(err.status(), err.body_text()))
152 }
153}
154
155fn rejection_to_error(status: http::StatusCode, body_text: String) -> crate::AutumnError {
156 crate::AutumnError::bad_request_msg(body_text).with_status(status)
157}
158
/// Multipart form extractor that applies the application's upload policy.
///
/// Wraps [`axum::extract::Multipart`] together with the
/// [`UploadConfig`](crate::security::config::UploadConfig) taken from the request
/// extensions, or its default when none is present. The config is used as follows:
/// - `max_request_size_bytes` bounds the whole request body.
/// - `allowed_mime_types` filters file fields in [`Multipart::next_field`].
/// - `max_file_size_bytes` becomes the initial cap of every [`MultipartField`].
#[cfg(feature = "multipart")]
pub struct Multipart {
    inner: axum::extract::Multipart,
    config: crate::security::config::UploadConfig,
}

#[cfg(feature = "multipart")]
impl Multipart {
    /// Yields the next field of the form, or `Ok(None)` once the form is exhausted.
    ///
    /// For file fields (those with a file name), and only when
    /// `allowed_mime_types` is non-empty, the field's `Content-Type` must be present
    /// and match an allowlist entry, compared ASCII case-insensitively. A match is
    /// accepted either against the full header value or against its media-type
    /// essence (the part before any `;` parameters).
    ///
    /// # Errors
    ///
    /// - The underlying multipart stream fails. The axum error's status and text
    ///   are forwarded.
    /// - A file field has no content type while an allowlist is configured.
    /// - A file field's content type is not allowlisted.
    pub async fn next_field(&mut self) -> crate::AutumnResult<Option<MultipartField<'_>>> {
        let Some(field) = self
            .inner
            .next_field()
            .await
            .map_err(|err| multipart_error_to_error(&err))?
        else {
            return Ok(None);
        };

        // Plain (non-file) fields are never MIME-filtered; an empty allowlist disables the check.
        if field.file_name().is_some() && !self.config.allowed_mime_types.is_empty() {
            let Some(content_type) = field.content_type().map(str::to_owned) else {
                return Err(crate::AutumnError::bad_request_msg(
                    "missing content type on uploaded file",
                ));
            };
            // `text/plain; charset=utf-8` must satisfy an allowlist entry of `text/plain`,
            // so also compare the essence. The full-value comparison is kept for entries
            // that spell out parameters explicitly.
            let essence = content_type.split(';').next().unwrap_or_default().trim();
            let permitted = self.config.allowed_mime_types.iter().any(|allowed| {
                allowed.eq_ignore_ascii_case(&content_type) || allowed.eq_ignore_ascii_case(essence)
            });
            if !permitted {
                return Err(crate::AutumnError::bad_request_msg(format!(
                    "unsupported upload content type: {content_type}"
                )));
            }
        }

        Ok(Some(MultipartField {
            inner: field,
            max_file_size_bytes: self.config.max_file_size_bytes,
        }))
    }
}

#[cfg(feature = "multipart")]
impl<S> axum::extract::FromRequest<S> for Multipart
where
    S: Send + Sync,
    axum::extract::Multipart:
        axum::extract::FromRequest<S, Rejection = axum::extract::multipart::MultipartRejection>,
{
    type Rejection = crate::AutumnError;

    /// Resolves the upload config and applies its request-size limit to this
    /// request, then delegates to axum's multipart extractor.
    ///
    /// The config comes from the request extensions, falling back to
    /// `UploadConfig::default()`.
    async fn from_request(
        mut req: axum::extract::Request,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        let config = req
            .extensions()
            .get::<crate::security::config::UploadConfig>()
            .cloned()
            .unwrap_or_default();
        // Install the upload-specific body limit on this request before axum reads the body.
        axum::extract::DefaultBodyLimit::max(config.max_request_size_bytes).apply(&mut req);
        let inner = axum::extract::Multipart::from_request(req, state)
            .await
            .map_err(|err| multipart_rejection_to_error(&err))?;
        Ok(Self { inner, config })
    }
}
245
/// One field of a multipart form, carrying the byte cap enforced on its content.
#[cfg(feature = "multipart")]
pub struct MultipartField<'a> {
    inner: axum::extract::multipart::Field<'a>,
    max_file_size_bytes: usize,
}

/// State threaded through the `futures::stream::unfold` in
/// `MultipartField::save_to_blob_store`.
#[cfg(all(feature = "multipart", feature = "storage"))]
struct MultipartFieldStreamState<'a> {
    inner: axum::extract::multipart::Field<'a>,
    // Bytes yielded so far.
    total: usize,
    // Cap copied from `MultipartField::max_file_size_bytes`.
    max: usize,
    // Set once an error item has been yielded, so the stream ends right after it.
    errored: bool,
}

#[cfg(feature = "multipart")]
// `'a` is named explicitly because `save_to_blob_store` relies on it in its `'a: 'b` bound.
#[allow(clippy::elidable_lifetime_names)]
impl<'a> MultipartField<'a> {
    /// The field's `name` from its `Content-Disposition`, if any.
    #[must_use]
    pub fn name(&self) -> Option<&str> {
        self.inner.name()
    }

    /// The uploaded file name, if the field is a file upload.
    #[must_use]
    pub fn file_name(&self) -> Option<&str> {
        self.inner.file_name()
    }

    /// The field's declared `Content-Type`, if any.
    #[must_use]
    pub fn content_type(&self) -> Option<&str> {
        self.inner.content_type()
    }

    /// Lowers the byte cap for this field to `max`.
    ///
    /// The cap can only shrink: a `max` larger than the current cap is ignored.
    #[must_use]
    pub fn with_max_bytes(mut self, max: usize) -> Self {
        self.max_file_size_bytes = self.max_file_size_bytes.min(max);
        self
    }

    /// Buffers the whole field into memory.
    ///
    /// # Errors
    ///
    /// - The multipart stream fails. The axum status and text are forwarded.
    /// - The content exceeds the cap. Returns `413 Payload Too Large`, raised
    ///   before the offending chunk is buffered.
    pub async fn bytes_limited(mut self) -> crate::AutumnResult<Vec<u8>> {
        let mut out = Vec::new();
        while let Some(chunk) = self
            .inner
            .chunk()
            .await
            .map_err(|err| multipart_error_to_error(&err))?
        {
            if out.len().saturating_add(chunk.len()) > self.max_file_size_bytes {
                return Err(file_too_large_error(self.max_file_size_bytes));
            }
            out.extend_from_slice(&chunk);
        }
        Ok(out)
    }

    /// Streams the field into `store` under `key` without buffering it fully.
    ///
    /// The blob's content type is the field's `Content-Type`, or
    /// `application/octet-stream` when absent. If the running total exceeds the
    /// cap, a `BlobStoreError::PayloadTooLarge` item is yielded and the stream
    /// ends. A multipart error is mapped by `blob_error_from_multipart`.
    ///
    /// Whether a partially-written blob is discarded on error depends on the
    /// store's `put_stream` implementation.
    ///
    /// # Errors
    ///
    /// Any error returned by `put_stream`, converted with
    /// `BlobStoreError::into_autumn_error`.
    #[cfg(feature = "storage")]
    pub async fn save_to_blob_store<'b>(
        self,
        store: &'b (dyn crate::storage::BlobStore + '_),
        key: impl Into<String>,
    ) -> crate::AutumnResult<crate::storage::Blob>
    where
        'a: 'b,
    {
        let key = key.into();
        let content_type = self
            .inner
            .content_type()
            .map_or_else(|| "application/octet-stream".to_owned(), str::to_owned);

        let state = MultipartFieldStreamState {
            inner: self.inner,
            total: 0,
            max: self.max_file_size_bytes,
            errored: false,
        };

        let stream = futures::stream::unfold(state, |mut state| async move {
            if state.errored {
                return None;
            }
            match state.inner.chunk().await {
                Ok(Some(chunk)) => {
                    state.total = state.total.saturating_add(chunk.len());
                    if state.total > state.max {
                        let err = crate::storage::BlobStoreError::PayloadTooLarge(format!(
                            "uploaded file exceeds limit of {} bytes",
                            state.max,
                        ));
                        state.errored = true;
                        Some((Err(err), state))
                    } else {
                        Some((Ok(chunk), state))
                    }
                }
                Ok(None) => None,
                Err(err) => {
                    state.errored = true;
                    let mapped = blob_error_from_multipart(&err);
                    Some((Err(mapped), state))
                }
            }
        });
        let stream: crate::storage::ByteStream<'b> = Box::pin(stream);

        store
            .put_stream(&key, &content_type, stream)
            .await
            .map_err(crate::storage::BlobStoreError::into_autumn_error)
    }

    /// Writes the field to a newly created (or truncated) file at `path`.
    ///
    /// Returns the number of bytes written. On any failure after the file has
    /// been created, the partial file is removed on a best-effort basis. This
    /// covers an oversized upload, a multipart stream error, and a write or
    /// flush error.
    ///
    /// # Errors
    ///
    /// - Creating, writing or flushing the file fails. Reported as an internal
    ///   server error.
    /// - The multipart stream fails. The axum status and text are forwarded.
    /// - The content exceeds the cap. Returns `413 Payload Too Large`.
    pub async fn save_to<P: AsRef<std::path::Path>>(
        mut self,
        path: P,
    ) -> crate::AutumnResult<usize> {
        let path = path.as_ref();
        let mut file = tokio::fs::File::create(path)
            .await
            .map_err(crate::AutumnError::internal_server_error)?;

        let result = self.write_chunks_to(&mut file).await;
        // Close the handle before touching the path so removal also works on
        // platforms that refuse to delete open files.
        drop(file);
        if result.is_err() {
            // Best effort: the original error is what the caller needs to see.
            let _ = tokio::fs::remove_file(path).await;
        }
        result
    }

    /// Copies every remaining chunk into `file`, enforcing the byte cap, then flushes.
    async fn write_chunks_to(
        &mut self,
        file: &mut tokio::fs::File,
    ) -> crate::AutumnResult<usize> {
        use tokio::io::AsyncWriteExt as _;

        let mut written = 0usize;
        while let Some(chunk) = self
            .inner
            .chunk()
            .await
            .map_err(|err| multipart_error_to_error(&err))?
        {
            written = written.saturating_add(chunk.len());
            if written > self.max_file_size_bytes {
                return Err(file_too_large_error(self.max_file_size_bytes));
            }
            file.write_all(&chunk)
                .await
                .map_err(crate::AutumnError::internal_server_error)?;
        }
        file.flush()
            .await
            .map_err(crate::AutumnError::internal_server_error)?;
        Ok(written)
    }
}
473
/// Converts axum's multipart extraction rejection into an [`crate::AutumnError`].
///
/// The rejection's status and body text are kept, exactly as for the other
/// extractor rejections.
#[cfg(feature = "multipart")]
fn multipart_rejection_to_error(
    err: &axum::extract::multipart::MultipartRejection,
) -> crate::AutumnError {
    rejection_to_error(err.status(), err.body_text())
}
480
/// Maps a multipart stream error onto the closest
/// [`crate::storage::BlobStoreError`] variant, keeping the error's body text.
///
/// The mapping is:
/// - `413 Payload Too Large` becomes `PayloadTooLarge`.
/// - Any other 4xx becomes `InvalidInput`.
/// - Everything else becomes `Io`.
// A single cfg suffices: `all(multipart, storage)` already implies `multipart`.
#[cfg(all(feature = "multipart", feature = "storage"))]
fn blob_error_from_multipart(
    err: &axum::extract::multipart::MultipartError,
) -> crate::storage::BlobStoreError {
    let status = err.status();
    let body = err.body_text();
    if status == http::StatusCode::PAYLOAD_TOO_LARGE {
        crate::storage::BlobStoreError::PayloadTooLarge(body)
    } else if status.is_client_error() {
        crate::storage::BlobStoreError::InvalidInput(body)
    } else {
        crate::storage::BlobStoreError::Io(body)
    }
}
503
/// Converts an error raised while reading multipart fields or chunks into an
/// [`crate::AutumnError`], keeping the error's status and body text.
#[cfg(feature = "multipart")]
fn multipart_error_to_error(err: &axum::extract::multipart::MultipartError) -> crate::AutumnError {
    let status = err.status();
    rejection_to_error(status, err.body_text())
}
508
/// Builds the `413 Payload Too Large` error used when a field's content exceeds
/// `max_file_size_bytes`.
#[cfg(feature = "multipart")]
fn file_too_large_error(max_file_size_bytes: usize) -> crate::AutumnError {
    let message = format!("uploaded file exceeds limit of {max_file_size_bytes} bytes");
    rejection_to_error(http::StatusCode::PAYLOAD_TOO_LARGE, message)
}
516
517pub use axum::extract::State;
518
#[cfg(all(test, feature = "multipart"))]
mod tests {
    use super::*;
    use axum::extract::FromRequest;
    use axum::http::Request;

    /// A text file field `file` / `test.txt` with content `file content` (12 bytes).
    const FILE_CONTENT_BODY: &str = "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nfile content\r\n--boundary--\r\n";

    /// A `text/plain` file field with content `blob content`.
    #[cfg(feature = "storage")]
    const BLOB_CONTENT_BODY: &str = "--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\nContent-Type: text/plain\r\n\r\nblob content\r\n--boundary--\r\n";

    /// Wraps `body` in a `multipart/form-data` request (boundary `boundary`)
    /// and runs axum's multipart extractor over it.
    async fn parse_multipart(body: &'static str) -> axum::extract::Multipart {
        let request = Request::builder()
            .header("content-type", "multipart/form-data; boundary=boundary")
            .body(axum::body::Body::from(body))
            .unwrap();
        axum::extract::Multipart::from_request(request, &())
            .await
            .unwrap()
    }

    /// Takes the first field of `multipart` and wraps it with the given byte cap.
    async fn first_field(
        multipart: &mut axum::extract::Multipart,
        max_file_size_bytes: usize,
    ) -> MultipartField<'_> {
        let inner = multipart.next_field().await.unwrap().unwrap();
        MultipartField {
            inner,
            max_file_size_bytes,
        }
    }

    /// A local blob store rooted at `root`, served under `/blobs`.
    #[cfg(feature = "storage")]
    fn local_store(root: &std::path::Path) -> crate::storage::LocalBlobStore {
        crate::storage::LocalBlobStore::new(
            "local",
            root,
            "/blobs",
            std::time::Duration::from_secs(3600),
            crate::storage::local::SigningKey::random(),
            vec![],
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_multipart_field_bytes_limited_success() {
        let mut multipart = parse_multipart("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nhello\r\n--boundary--\r\n").await;
        let field = first_field(&mut multipart, 100).await;

        let bytes = field.bytes_limited().await.unwrap();
        assert_eq!(bytes, b"hello");
    }

    #[tokio::test]
    async fn test_multipart_field_bytes_limited_too_large() {
        let mut multipart = parse_multipart("--boundary\r\nContent-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\r\nhello world\r\n--boundary--\r\n").await;
        let field = first_field(&mut multipart, 5).await;

        let err = field.bytes_limited().await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[tokio::test]
    async fn test_multipart_field_save_to_success() {
        let mut multipart = parse_multipart(FILE_CONTENT_BODY).await;
        let field = first_field(&mut multipart, 100).await;
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("out.txt");

        let written = field.save_to(&target).await.unwrap();
        assert_eq!(written, 12);
        assert_eq!(std::fs::read_to_string(&target).unwrap(), "file content");
    }

    #[tokio::test]
    async fn test_multipart_field_save_to_too_large() {
        let mut multipart = parse_multipart(FILE_CONTENT_BODY).await;
        let field = first_field(&mut multipart, 4).await;
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("out_large.txt");

        let err = field.save_to(&target).await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);
        // The rejected upload must not leave a file behind.
        assert!(!target.exists());
    }

    #[cfg(feature = "storage")]
    #[tokio::test]
    async fn test_multipart_field_save_to_blob_store_success() {
        use crate::storage::BlobStore;

        let mut multipart = parse_multipart(BLOB_CONTENT_BODY).await;
        let field = first_field(&mut multipart, 100).await;
        let root = tempfile::tempdir().unwrap();
        let store = local_store(root.path());

        let blob = field.save_to_blob_store(&store, "myblob").await.unwrap();
        assert_eq!(blob.key, "myblob");
        assert_eq!(blob.content_type, "text/plain");

        let stored = store.get("myblob").await.unwrap();
        assert_eq!(&stored[..], b"blob content");
    }

    #[cfg(feature = "storage")]
    #[tokio::test]
    async fn test_multipart_field_save_to_blob_store_too_large() {
        use crate::storage::BlobStore;

        let mut multipart = parse_multipart(BLOB_CONTENT_BODY).await;
        let field = first_field(&mut multipart, 4).await;
        let root = tempfile::tempdir().unwrap();
        let store = local_store(root.path());

        let err = field
            .save_to_blob_store(&store, "myblob")
            .await
            .unwrap_err();
        assert_eq!(err.status(), http::StatusCode::PAYLOAD_TOO_LARGE);

        // Nothing may be stored under the key once the upload was rejected.
        let missing = store.get("myblob").await.unwrap_err();
        assert_eq!(missing.status(), http::StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn test_multipart_field_metadata() {
        let mut multipart = parse_multipart("--boundary\r\nContent-Disposition: form-data; name=\"custom_name\"; filename=\"custom_file.png\"\r\nContent-Type: image/png\r\n\r\npng\r\n--boundary--\r\n").await;
        let field = first_field(&mut multipart, 100).await;

        assert_eq!(field.name(), Some("custom_name"));
        assert_eq!(field.file_name(), Some("custom_file.png"));
        assert_eq!(field.content_type(), Some("image/png"));

        // Lowering the cap takes effect...
        let tighter = field.with_max_bytes(50);
        assert_eq!(tighter.max_file_size_bytes, 50);

        // ...but a larger value never raises it again.
        let not_tighter = tighter.with_max_bytes(200);
        assert_eq!(not_tighter.max_file_size_bytes, 50);
    }
}
738
739use crate::security::trusted_proxies::ResolvedClientIdentity;
742
743pub struct ClientAddr(pub std::net::IpAddr);
762
763impl ClientAddr {
764 #[must_use]
766 pub const fn ip(&self) -> std::net::IpAddr {
767 self.0
768 }
769}
770
771impl<S> FromRequestParts<S> for ClientAddr
772where
773 S: Send + Sync,
774{
775 type Rejection = (axum::http::StatusCode, &'static str);
776
777 async fn from_request_parts(
778 parts: &mut axum::http::request::Parts,
779 _state: &S,
780 ) -> Result<Self, Self::Rejection> {
781 parts
782 .extensions
783 .get::<ResolvedClientIdentity>()
784 .and_then(|id| id.addr)
785 .map(ClientAddr)
786 .ok_or((
787 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
788 "ClientAddr not resolved. Is the TrustedProxiesLayer installed?",
789 ))
790 }
791}
792
793impl<S> axum::extract::OptionalFromRequestParts<S> for ClientAddr
794where
795 S: Send + Sync,
796{
797 type Rejection = std::convert::Infallible;
798
799 async fn from_request_parts(
800 parts: &mut axum::http::request::Parts,
801 _state: &S,
802 ) -> Result<Option<Self>, Self::Rejection> {
803 Ok(parts
804 .extensions
805 .get::<ResolvedClientIdentity>()
806 .and_then(|id| id.addr)
807 .map(ClientAddr))
808 }
809}
810
811pub struct ClientHost(pub String);
820
821impl ClientHost {
822 #[must_use]
824 pub fn as_str(&self) -> &str {
825 &self.0
826 }
827}
828
829impl<S> FromRequestParts<S> for ClientHost
830where
831 S: Send + Sync,
832{
833 type Rejection = (axum::http::StatusCode, &'static str);
834
835 async fn from_request_parts(
836 parts: &mut axum::http::request::Parts,
837 _state: &S,
838 ) -> Result<Self, Self::Rejection> {
839 parts
840 .extensions
841 .get::<ResolvedClientIdentity>()
842 .and_then(|id| id.host.clone())
843 .map(ClientHost)
844 .ok_or((
845 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
846 "ClientHost not resolved. Is the TrustedProxiesLayer installed?",
847 ))
848 }
849}
850
851impl<S> axum::extract::OptionalFromRequestParts<S> for ClientHost
852where
853 S: Send + Sync,
854{
855 type Rejection = std::convert::Infallible;
856
857 async fn from_request_parts(
858 parts: &mut axum::http::request::Parts,
859 _state: &S,
860 ) -> Result<Option<Self>, Self::Rejection> {
861 Ok(parts
862 .extensions
863 .get::<ResolvedClientIdentity>()
864 .and_then(|id| id.host.clone())
865 .map(ClientHost))
866 }
867}
868
869pub struct ClientScheme(pub String);
878
879impl ClientScheme {
880 #[must_use]
882 pub fn as_str(&self) -> &str {
883 &self.0
884 }
885
886 #[must_use]
888 pub fn is_https(&self) -> bool {
889 self.0.eq_ignore_ascii_case("https")
890 }
891}
892
893impl<S> FromRequestParts<S> for ClientScheme
894where
895 S: Send + Sync,
896{
897 type Rejection = (axum::http::StatusCode, &'static str);
898
899 async fn from_request_parts(
900 parts: &mut axum::http::request::Parts,
901 _state: &S,
902 ) -> Result<Self, Self::Rejection> {
903 parts
904 .extensions
905 .get::<ResolvedClientIdentity>()
906 .map(|id| Self(id.scheme.clone().unwrap_or_else(|| "http".to_owned())))
907 .ok_or((
908 axum::http::StatusCode::INTERNAL_SERVER_ERROR,
909 "ClientScheme not resolved. Is the TrustedProxiesLayer installed?",
910 ))
911 }
912}
913
914impl<S> axum::extract::OptionalFromRequestParts<S> for ClientScheme
915where
916 S: Send + Sync,
917{
918 type Rejection = std::convert::Infallible;
919
920 async fn from_request_parts(
921 parts: &mut axum::http::request::Parts,
922 _state: &S,
923 ) -> Result<Option<Self>, Self::Rejection> {
924 Ok(parts
925 .extensions
926 .get::<ResolvedClientIdentity>()
927 .map(|id| Self(id.scheme.clone().unwrap_or_else(|| "http".to_owned()))))
928 }
929}
930
#[cfg(test)]
mod trusted_proxy_extractor_tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::routing::get;
    use tower::ServiceExt;

    fn make_identity(addr: &str, host: &str, scheme: &str) -> ResolvedClientIdentity {
        ResolvedClientIdentity {
            addr: Some(addr.parse().unwrap()),
            host: Some(host.to_owned()),
            scheme: Some(scheme.to_owned()),
        }
    }

    /// `GET /` with no identity attached.
    fn plain_request() -> axum::http::Request<Body> {
        axum::http::Request::builder()
            .uri("/")
            .body(Body::empty())
            .unwrap()
    }

    /// `GET /` carrying the identity `192.0.2.1` / `app.example` / `https`.
    fn identified_request() -> axum::http::Request<Body> {
        let mut request = plain_request();
        request
            .extensions_mut()
            .insert(make_identity("192.0.2.1", "app.example", "https"));
        request
    }

    /// Collects the response body, allowing at most 64 bytes.
    async fn body_bytes(response: axum::response::Response) -> axum::body::Bytes {
        axum::body::to_bytes(response.into_body(), 64).await.unwrap()
    }

    #[tokio::test]
    async fn client_addr_extractor_reads_from_extension() {
        async fn handler(ClientAddr(ip): ClientAddr) -> String {
            ip.to_string()
        }

        let response = Router::new()
            .route("/", get(handler))
            .oneshot(identified_request())
            .await
            .unwrap();
        assert_eq!(response.status(), axum::http::StatusCode::OK);
        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"192.0.2.1");
    }

    #[tokio::test]
    async fn client_host_extractor_reads_from_extension() {
        async fn handler(ClientHost(host): ClientHost) -> String {
            host
        }

        let response = Router::new()
            .route("/", get(handler))
            .oneshot(identified_request())
            .await
            .unwrap();
        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"app.example");
    }

    #[tokio::test]
    async fn client_scheme_extractor_reads_from_extension() {
        async fn handler(ClientScheme(scheme): ClientScheme) -> String {
            scheme
        }

        let response = Router::new()
            .route("/", get(handler))
            .oneshot(identified_request())
            .await
            .unwrap();
        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"https");
    }

    #[tokio::test]
    async fn client_addr_missing_returns_500() {
        async fn handler(_: ClientAddr) -> &'static str {
            "ok"
        }

        let response = Router::new()
            .route("/", get(handler))
            .oneshot(plain_request())
            .await
            .unwrap();
        assert_eq!(
            response.status(),
            axum::http::StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[tokio::test]
    async fn optional_client_addr_returns_none_when_missing() {
        async fn handler(addr: Option<ClientAddr>) -> String {
            match addr {
                Some(_) => "some",
                None => "none",
            }
            .to_owned()
        }

        let response = Router::new()
            .route("/", get(handler))
            .oneshot(plain_request())
            .await
            .unwrap();
        let body = body_bytes(response).await;
        assert_eq!(&body[..], b"none");
    }
}