axum_bootstrap/util/
extractor.rs1use std::io;
19
20use axum::{extract::FromRequestParts, http::request::Parts};
21
22use crate::error::AppError;
23
/// The `host[:port]` a request was addressed to, e.g. `"example.com:8080"`.
///
/// Obtain it as a handler argument via its `FromRequestParts` implementation;
/// the wrapped `String` is exposed directly as field `.0`.
#[derive(Clone, Debug)]
pub struct Host(pub String);

70impl<S> FromRequestParts<S> for Host
71where
72 S: Send + Sync,
73{
74 type Rejection = AppError;
75
76 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
77 if let Some(host) = parts.headers.get("host") {
81 if let Ok(host_str) = host.to_str() {
82 return Ok(Host(host_str.to_string()));
83 }
84 }
85
86 if let Some(authority) = parts.uri.authority() {
89 return Ok(Host(authority.to_string()));
90 }
91
92 Err(AppError::new(io::Error::new(io::ErrorKind::InvalidInput, "Missing Host information in request")))
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{header, HeaderValue, Request};

    /// Builds the parts of a bodiless request to `uri`, optionally carrying a `Host` header.
    fn parts_for(uri: &str, host_header: Option<HeaderValue>) -> Parts {
        let (mut parts, ()) = Request::builder().uri(uri).body(()).unwrap().into_parts();
        if let Some(value) = host_header {
            parts.headers.insert(header::HOST, value);
        }
        parts
    }

    #[tokio::test]
    async fn test_host_from_header() {
        let mut parts = parts_for("/test", Some(HeaderValue::from_static("example.com:8080")));
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();

        assert_eq!(host.0, "example.com:8080");
    }

    #[tokio::test]
    async fn test_host_from_authority() {
        let mut parts = parts_for("http://example.com:8080/test", None);
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();

        assert_eq!(host.0, "example.com:8080");
    }

    #[tokio::test]
    async fn test_missing_host() {
        let mut parts = parts_for("/test", None);
        let result = Host::from_request_parts(&mut parts, &()).await;

        assert!(result.is_err());
    }

    /// An absolute-form request target must win over a conflicting `Host` header.
    #[tokio::test]
    async fn test_authority_precedence() {
        let mut parts = parts_for(
            "http://authority.com:8080/test",
            Some(HeaderValue::from_static("header.com:9090")),
        );
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();

        assert_eq!(host.0, "authority.com:8080");
    }

    /// A `Host` header with non-visible-ASCII bytes must not mask a usable URI authority.
    #[tokio::test]
    async fn test_non_ascii_host_header_falls_back_to_authority() {
        let undecodable = HeaderValue::from_bytes(b"bad\xffhost").unwrap();
        let mut parts = parts_for("http://example.com:8080/test", Some(undecodable));
        let host = Host::from_request_parts(&mut parts, &()).await.unwrap();

        assert_eq!(host.0, "example.com:8080");
    }

    /// With no URI authority, an undecodable `Host` header counts as missing host information.
    #[tokio::test]
    async fn test_non_ascii_host_header_without_authority_is_rejected() {
        let undecodable = HeaderValue::from_bytes(b"bad\xffhost").unwrap();
        let mut parts = parts_for("/test", Some(undecodable));
        let result = Host::from_request_parts(&mut parts, &()).await;

        assert!(result.is_err());
    }
}