Skip to main content

axum_bootstrap/util/
extractor.rs

1//! # HTTP 请求提取器模块
2//!
3//! 提供各种自定义的 Axum extractor,用于从 HTTP 请求中提取常用信息
4//!
5//! # 示例
6//!
7//! ```no_run
8//! use axum::{Router, routing::get};
9//! use axum_bootstrap::util::extractor::Host;
10//!
11//! async fn handler(Host(host): Host) -> String {
12//!     format!("Request host: {}", host)
13//! }
14//!
15//! let app = Router::new().route("/", get(handler));
16//! ```
17
18use std::io;
19
20use axum::{extract::FromRequestParts, http::request::Parts};
21
22use crate::error::AppError;
23
24/// Host extractor
25///
26/// 从 HTTP 请求中提取 Host 信息,兼容 HTTP/1.x 和 HTTP/2
27///
28/// # 工作原理
29///
30/// - **HTTP/1.x**: 从 `Host` header 中读取
31/// - **HTTP/2**: 优先从 `:authority` pseudo-header 读取,回退到 `Host` header
32///
33/// # 示例
34///
35/// ```no_run
36/// use axum::{Router, routing::get};
37/// use axum_bootstrap::util::extractor::Host;
38///
39/// async fn show_host(Host(host): Host) -> String {
40///     format!("Your host is: {}", host)
41/// }
42///
43/// let app = Router::new().route("/", get(show_host));
44/// ```
45///
46/// # 错误处理
47///
48/// 如果请求中没有 Host 信息,将返回 500 错误
49///
50/// # 可选 Host 提取
51///
52/// 如果你希望 Host 是可选的(不存在时不报错),可以使用 `Option<Host>`:
53///
54/// ```no_run
55/// use axum::{Router, routing::get};
56/// use axum_bootstrap::util::extractor::Host;
57///
58/// async fn show_host(host: Option<Host>) -> String {
59///     match host {
60///         Some(Host(h)) => format!("Your host is: {}", h),
61///         None => "No host provided".to_string(),
62///     }
63/// }
64///
65/// let app = Router::new().route("/", get(show_host));
66/// ```
67#[derive(Debug, Clone)]
68pub struct Host(pub String);
69
70impl<S> FromRequestParts<S> for Host
71where
72    S: Send + Sync,
73{
74    type Rejection = AppError;
75
76    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
77        // 以Host header为优先
78        // HTTP1 要求必须传递 Host header
79        // HTTP2 中用于特殊情况的需求,例如反向代理指定 Host
80        if let Some(host) = parts.headers.get("host") {
81            if let Ok(host_str) = host.to_str() {
82                return Ok(Host(host_str.to_string()));
83            }
84        }
85
86        // HTTP/2 使用 :authority pseudo-header
87        // 在 Axum/Hyper 中,:authority 会被转换为 URI 的 authority 部分
88        if let Some(authority) = parts.uri.authority() {
89            return Ok(Host(authority.to_string()));
90        }
91
92        // 无法获取 Host 信息
93        Err(AppError::new(io::Error::new(io::ErrorKind::InvalidInput, "Missing Host information in request")))
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, Request};

    /// Runs the `Host` extractor against the head of `req` with unit state.
    async fn extract(req: Request<()>) -> Result<Host, AppError> {
        let (mut parts, _) = req.into_parts();
        Host::from_request_parts(&mut parts, &()).await
    }

    #[tokio::test]
    async fn test_host_from_header() {
        let req = Request::builder().uri("/test").header("host", "example.com:8080").body(()).unwrap();

        let host = extract(req).await.unwrap();

        assert_eq!(host.0, "example.com:8080");
    }

    #[tokio::test]
    async fn test_host_from_authority() {
        let req = Request::builder().uri("http://example.com:8080/test").body(()).unwrap();

        let host = extract(req).await.unwrap();

        assert_eq!(host.0, "example.com:8080");
    }

    #[tokio::test]
    async fn test_missing_host() {
        let req = Request::builder().uri("/test").body(()).unwrap();

        assert!(extract(req).await.is_err());
    }

    #[tokio::test]
    async fn test_host_header_precedence() {
        // When both a URI authority and a Host header are present, the
        // extractor returns the Host header.
        let req = Request::builder()
            .uri("http://authority.com:8080/test")
            .header("host", "header.com:9090")
            .body(())
            .unwrap();

        let host = extract(req).await.unwrap();

        assert_eq!(host.0, "header.com:9090");
    }

    #[tokio::test]
    async fn test_unreadable_host_header_falls_back_to_authority() {
        // 0xFF is accepted as a header value byte but cannot be read as a
        // string, so the extractor must fall through to the URI authority.
        let unreadable = HeaderValue::from_bytes(b"\xff").unwrap();
        let req = Request::builder()
            .uri("http://authority.com:8080/test")
            .header("host", unreadable)
            .body(())
            .unwrap();

        let host = extract(req).await.unwrap();

        assert_eq!(host.0, "authority.com:8080");
    }
}