use std::io;
use axum::{extract::FromRequestParts, http::request::Parts};
use crate::error::AppError;
/// The target host of an incoming request, produced by this type's
/// `FromRequestParts` implementation.
///
/// The wrapped string is taken verbatim from the request, so it may carry a
/// port suffix (e.g. `example.com:8080`).
#[derive(Clone, Debug)]
pub struct Host(pub String);
impl<S> FromRequestParts<S> for Host
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(host) = parts.headers.get("host") {
if let Ok(host_str) = host.to_str() {
return Ok(Host(host_str.to_string()));
}
}
if let Some(authority) = parts.uri.authority() {
return Ok(Host(authority.to_string()));
}
Err(AppError::new(io::Error::new(io::ErrorKind::InvalidInput, "Missing Host information in request")))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::Request;

    /// Runs the `Host` extractor against `req` with a unit state.
    ///
    /// Returns the extracted host string, or `None` when the extractor
    /// rejects the request.
    async fn extract_host(req: Request<()>) -> Option<String> {
        let (mut parts, _body) = req.into_parts();
        Host::from_request_parts(&mut parts, &())
            .await
            .ok()
            .map(|Host(value)| value)
    }

    #[tokio::test]
    async fn test_host_from_header() {
        let req = Request::builder()
            .uri("/test")
            .header("host", "example.com:8080")
            .body(())
            .unwrap();
        assert_eq!(extract_host(req).await.as_deref(), Some("example.com:8080"));
    }

    #[tokio::test]
    async fn test_host_from_authority() {
        let req = Request::builder()
            .uri("http://example.com:8080/test")
            .body(())
            .unwrap();
        assert_eq!(extract_host(req).await.as_deref(), Some("example.com:8080"));
    }

    #[tokio::test]
    async fn test_missing_host() {
        // Origin-form URI and no Host header: there is nothing to extract.
        let req = Request::builder().uri("/test").body(()).unwrap();
        assert!(extract_host(req).await.is_none());
    }

    #[tokio::test]
    async fn test_authority_precedence() {
        // Both sources are present; the URI authority is expected to win.
        let req = Request::builder()
            .uri("http://authority.com:8080/test")
            .header("host", "header.com:9090")
            .body(())
            .unwrap();
        assert_eq!(
            extract_host(req).await.as_deref(),
            Some("authority.com:8080")
        );
    }
}