use http;
use http::HeaderValue;
use http::Uri;
use http::uri::Port;
use crate::service::http::HttpProtocol;
use super::http::HttpConnectionInfo;
/// Reports whether the URI's scheme is a secure one (`https` or `wss`).
///
/// Scheme names are case-insensitive (RFC 3986, section 3.1). The `http` crate
/// canonicalizes only `http`/`https`; other schemes such as `WSS` keep their
/// original spelling. The comparison therefore ignores ASCII case. A URI
/// without a scheme is treated as not secure.
fn is_schema_secure(uri: &Uri) -> bool {
    const SECURE_SCHEMES: [&str; 2] = ["https", "wss"];
    uri.scheme_str()
        .map(|scheme| {
            SECURE_SCHEMES
                .iter()
                .any(|secure| scheme.eq_ignore_ascii_case(secure))
        })
        .unwrap_or_default()
}
/// Returns the URI's explicit port, unless it is the default for the scheme.
///
/// The default is 443 when [`is_schema_secure`] reports the scheme as secure,
/// and 80 otherwise (including URIs without a scheme). The function returns
/// `None` both when the URI has no port and when its port equals that default.
fn get_non_default_port(uri: &Uri) -> Option<Port<&str>> {
    let default_port: u16 = if is_schema_secure(uri) { 443 } else { 80 };
    uri.port().filter(|port| port.as_u16() != default_port)
}
/// Adds a `Host` header derived from the request URI, unless one is already set.
///
/// The value is the URI host. It is followed by `:port` only when the URI
/// carries a port that is not the scheme default (see [`get_non_default_port`]).
/// If the URI has no host (e.g. an origin-form `/path` URI), the request is left
/// unchanged and a debug event is emitted.
///
/// # Panics
///
/// Panics if the host/port text is not a valid header value. The code treats
/// this as an invariant of a successfully parsed URI.
fn set_host_header<B>(request: &mut http::Request<B>) {
    let uri = request.uri();
    let hostname = match uri.host() {
        Some(hostname) => hostname,
        None => {
            tracing::debug!(uri=%uri, "request uri has no host");
            return;
        }
    };
    // Respect a Host header that the caller supplied explicitly.
    if request.headers().contains_key(http::header::HOST) {
        return;
    }
    // Build the owned header value while `uri` is still borrowed from `request`.
    // This avoids cloning the whole URI just to release the borrow before
    // mutating the header map.
    let value = match get_non_default_port(uri) {
        Some(port) => HeaderValue::from_str(&format!("{hostname}:{port}")),
        None => HeaderValue::from_str(hostname),
    }
    .expect("uri host is valid header value");
    request.headers_mut().insert(http::header::HOST, value);
}
/// Tower middleware that fills in a missing `Host` header on pre-HTTP/2 requests.
///
/// It is normally built by [`SetHostHeaderLayer`], or via `Default` when
/// `S: Default`. Requests that already carry a `Host` header pass through
/// unchanged.
#[derive(Clone, Debug, Default)]
pub struct SetHostHeader<S> {
    /// The wrapped service, which receives every (possibly modified) request.
    inner: S,
}
/// Service impl for bare requests.
///
/// The protocol version is read from the request itself. A `Host` header is
/// added only for versions older than HTTP/2.
impl<S, B> tower::Service<http::Request<B>> for SetHostHeader<S>
where
    S: tower::Service<http::Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        context: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(context)
    }

    fn call(&mut self, mut request: http::Request<B>) -> Self::Future {
        // NOTE(review): HTTP/2 and later presumably convey the authority via the
        // `:authority` pseudo-header, which is why they are skipped here.
        let before_http2 = request.version() < http::Version::HTTP_2;
        if before_http2 {
            set_host_header(&mut request);
        }
        self.inner.call(request)
    }
}
/// Service impl for `(connection, request)` pairs.
///
/// Here the protocol is taken from the connection info rather than from the
/// request. A `Host` header is added only when the connection reports
/// `HttpProtocol::Http1`.
impl<S, B, C> tower::Service<(C, http::Request<B>)> for SetHostHeader<S>
where
    S: tower::Service<(C, http::Request<B>)>,
    C: HttpConnectionInfo<B>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        context: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(context)
    }

    fn call(&mut self, input: (C, http::Request<B>)) -> Self::Future {
        let (connection, mut request) = input;
        let is_http1 = connection.version() == HttpProtocol::Http1;
        if is_http1 {
            set_host_header(&mut request);
        }
        // The connection info is forwarded untouched alongside the request.
        self.inner.call((connection, request))
    }
}
/// [`tower::layer::Layer`] that wraps services in [`SetHostHeader`].
#[derive(Clone, Debug, Default)]
pub struct SetHostHeaderLayer {
    // The private unit field prevents struct-literal construction outside this
    // module; presumably this keeps room for future fields. Use `new()` or
    // `Default`.
    _priv: (),
}
impl SetHostHeaderLayer {
    /// Creates the layer. Equivalent to `SetHostHeaderLayer::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<S> tower::layer::Layer<S> for SetHostHeaderLayer {
    type Service = SetHostHeader<S>;

    /// Wraps `service` so that its requests get a `Host` header when needed.
    fn layer(&self, service: S) -> Self::Service {
        SetHostHeader::<S> { inner: service }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a body-less request for `uri` and runs `set_host_header` on it.
    /// Returns the resulting `Host` header as a string, or `None` if no header
    /// was set.
    fn host_header_for(uri: &str) -> Option<String> {
        let mut request = http::Request::new(());
        *request.uri_mut() = uri.parse().unwrap();
        set_host_header(&mut request);
        request
            .headers()
            .get(http::header::HOST)
            .map(|value| value.to_str().unwrap().to_owned())
    }

    #[test]
    fn test_set_host_header() {
        let cases = [
            ("http://example.com", "example.com"),
            ("http://example.com:8080", "example.com:8080"),
            ("https://example.com", "example.com"),
            ("https://example.com:8443", "example.com:8443"),
            // An explicit default port is dropped...
            ("http://example.com:80", "example.com"),
            ("https://example.com:443", "example.com"),
            // ...but only when it is the default for *this* scheme.
            ("http://example.com:443", "example.com:443"),
            ("https://example.com:80", "example.com:80"),
            // IPv6 literals keep their brackets.
            ("http://[::1]:8080", "[::1]:8080"),
        ];
        for (uri, expected) in cases {
            assert_eq!(host_header_for(uri).as_deref(), Some(expected), "uri: {uri}");
        }
    }

    #[test]
    fn test_set_host_header_keeps_existing_header() {
        let mut request = http::Request::new(());
        *request.uri_mut() = "http://example.com".parse().unwrap();
        request
            .headers_mut()
            .insert(http::header::HOST, HeaderValue::from_static("override.test"));
        set_host_header(&mut request);
        assert_eq!(
            request.headers().get(http::header::HOST).unwrap(),
            "override.test"
        );
    }

    #[test]
    fn test_set_host_header_without_uri_host() {
        assert_eq!(host_header_for("/just/a/path"), None);
    }

    #[test]
    fn test_is_schema_secure() {
        let cases = [
            ("http://example.com", false),
            ("https://example.com", true),
            ("ws://example.com", false),
            ("wss://example.com", true),
        ];
        for (uri, expected) in cases {
            let uri: Uri = uri.parse().unwrap();
            assert_eq!(is_schema_secure(&uri), expected, "uri: {uri}");
        }
    }

    #[test]
    fn test_get_non_default_port() {
        let cases = [
            ("http://example.com", None),
            ("http://example.com:8080", Some(8080)),
            ("https://example.com", None),
            ("https://example.com:8443", Some(8443)),
            ("ws://example.com:80", None),
            ("wss://example.com:443", None),
            ("wss://example.com:80", Some(80)),
        ];
        for (uri, expected) in cases {
            let uri: Uri = uri.parse().unwrap();
            assert_eq!(
                get_non_default_port(&uri).map(|p| p.as_u16()),
                expected,
                "uri: {uri}"
            );
        }
    }
}