use std::{error::Error, future::Future};
use http::header::{self, HeaderName, HeaderValue};
use motore::{layer::Layer, service::Service};
use crate::{
client::{Target, target::RemoteHost, utils::is_default_port},
context::ClientContext,
error::client::{Result, builder_error},
request::Request,
};
/// A layer that inserts one fixed header into every request passing through it.
///
/// Existing values under the same header name are replaced, because the
/// produced [`HeaderService`] uses `HeaderMap::insert`.
#[derive(Clone, Debug)]
pub struct Header {
    key: HeaderName,
    val: HeaderValue,
}

impl Header {
    /// Creates a layer that sets header `key` to `val`.
    pub fn new(key: HeaderName, val: HeaderValue) -> Self {
        Self { key, val }
    }

    /// Creates a layer from values convertible into a [`HeaderName`] and a
    /// [`HeaderValue`].
    ///
    /// # Errors
    ///
    /// Returns a builder error if `key` cannot be converted into a
    /// [`HeaderName`] or `val` cannot be converted into a [`HeaderValue`].
    /// The name is converted first, so its error wins when both are invalid.
    pub fn try_new<K, V>(key: K, val: V) -> Result<Self>
    where
        K: TryInto<HeaderName>,
        K::Error: Error + Send + Sync + 'static,
        V: TryInto<HeaderValue>,
        V::Error: Error + Send + Sync + 'static,
    {
        let name: HeaderName = key.try_into().map_err(builder_error)?;
        let value: HeaderValue = val.try_into().map_err(builder_error)?;
        Ok(Self {
            key: name,
            val: value,
        })
    }
}

impl<S> Layer<S> for Header {
    type Service = HeaderService<S>;

    fn layer(self, inner: S) -> Self::Service {
        let Self { key, val } = self;
        HeaderService { inner, key, val }
    }
}
/// Service produced by the [`Header`] layer; wraps `inner`.
pub struct HeaderService<S> {
    inner: S,
    key: HeaderName,
    val: HeaderValue,
}

impl<Cx, B, S> Service<Cx, Request<B>> for HeaderService<S>
where
    S: Service<Cx, Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Sets the configured header on `req` and forwards the request to the
    /// inner service.
    ///
    /// `HeaderMap::insert` replaces any values already stored under the same
    /// name.
    fn call(
        &self,
        cx: &mut Cx,
        mut req: Request<B>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
        let Self { inner, key, val } = self;
        let headers = req.headers_mut();
        headers.insert(key.clone(), val.clone());
        inner.call(cx, req)
    }
}
/// Policy for populating the `Host` header of outgoing requests.
///
/// Used as a [`Layer`]; the default is [`Host::Auto`].
#[derive(Clone, Debug, Default)]
pub enum Host {
    /// Leave the request's `Host` header untouched.
    None,
    /// If the request has no `Host` header, derive one from the target stored
    /// in the [`ClientContext`]; if no value can be derived, none is set.
    #[default]
    Auto,
    /// Always set `Host` to this value, replacing any existing one.
    Force(HeaderValue),
    /// Set `Host` to this value only if the request has none.
    Fallback(HeaderValue),
}

impl<S> Layer<S> for Host {
    type Service = HostService<S>;

    fn layer(self, inner: S) -> Self::Service {
        let config = self;
        HostService { inner, config }
    }
}

/// Service produced by the [`Host`] layer; applies `config` before calling
/// `inner`.
pub struct HostService<S> {
    inner: S,
    config: Host,
}
/// `Host` value used when the target is a Unix domain socket (`Target::Local`),
/// where there is no network host or port to derive one from.
#[cfg(target_family = "unix")]
const UDS_HOST: HeaderValue = HeaderValue::from_static("unix-domain-socket");

/// Builds a `Host` header value for `target`.
///
/// - `Target::None` yields `None`.
/// - `Target::Local` (Unix only) yields [`UDS_HOST`].
/// - A remote target yields its host, followed by `:port` unless
///   `is_default_port` reports the port as the scheme's default. IPv6
///   addresses are always wrapped in brackets.
///
/// Returns `None` if the formatted string is not a valid [`HeaderValue`].
pub(super) fn gen_host(target: &Target) -> Option<HeaderValue> {
    let remote = match target {
        Target::None => return None,
        Target::Remote(remote) => remote,
        #[cfg(target_family = "unix")]
        Target::Local(_) => return Some(UDS_HOST.clone()),
    };
    let port = remote.port;
    let omit_port = is_default_port(&remote.scheme, port);

    let authority = match &remote.host {
        // A hostname on its default port is used verbatim, with no formatting.
        RemoteHost::Name(name) if omit_port => return HeaderValue::from_str(name).ok(),
        RemoteHost::Name(name) => format!("{name}:{port}"),
        // IPv6 literals are bracketed in an authority, with or without a port.
        RemoteHost::Ip(ip) if omit_port && ip.is_ipv6() => format!("[{ip}]"),
        RemoteHost::Ip(ip) if omit_port => ip.to_string(),
        // `SocketAddr`'s `Display` renders `a.b.c.d:port` or `[v6]:port`
        // (its scope id is 0 here, so no `%scope` suffix appears).
        RemoteHost::Ip(ip) => std::net::SocketAddr::new(*ip, port).to_string(),
    };
    HeaderValue::from_str(&authority).ok()
}
impl<B, S> Service<ClientContext, Request<B>> for HostService<S>
where
    S: Service<ClientContext, Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Applies the configured [`Host`] policy to `req`, then forwards the
    /// request to the inner service.
    ///
    /// - `Host::None`: the `Host` header is left as-is.
    /// - `Host::Auto`: if `req` has no `Host`, one is generated from
    ///   `cx.target()` via `gen_host`. Nothing is set when that yields `None`.
    /// - `Host::Force(v)`: `Host` is set to `v`, replacing any existing values.
    /// - `Host::Fallback(v)`: `Host` is set to `v` only if `req` has none.
    fn call(
        &self,
        cx: &mut ClientContext,
        mut req: Request<B>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
        match &self.config {
            Host::None => {}
            Host::Auto => {
                // The entry API hashes the key once (instead of `contains_key` and
                // then `insert`), and `gen_host` only runs when the header is missing.
                if let header::Entry::Vacant(slot) = req.headers_mut().entry(header::HOST) {
                    if let Some(val) = gen_host(cx.target()) {
                        slot.insert(val);
                    }
                }
            }
            Host::Force(val) => {
                req.headers_mut().insert(header::HOST, val.clone());
            }
            Host::Fallback(val) => {
                // The value is only cloned when the header is actually absent.
                req.headers_mut()
                    .entry(header::HOST)
                    .or_insert_with(|| val.clone());
            }
        }
        self.inner.call(cx, req)
    }
}
/// Value used by [`UserAgent::auto`]: `<crate name>/<crate version>`, read from
/// Cargo's build-time environment.
const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"));

/// A layer that sets the `User-Agent` header on requests that do not already
/// carry one.
///
/// Derives `Clone` and `Debug` like the sibling layer configs `Header` and
/// `Host`, so it can be reused and inspected the same way.
#[derive(Clone, Debug)]
pub struct UserAgent {
    val: HeaderValue,
}

impl UserAgent {
    /// Creates a layer that uses `val` as the `User-Agent` value.
    pub fn new(val: HeaderValue) -> Self {
        Self { val }
    }

    /// Creates a layer whose `User-Agent` is this crate's `<name>/<version>`.
    pub fn auto() -> Self {
        Self {
            val: HeaderValue::from_static(PKG_NAME_WITH_VER),
        }
    }
}
impl<S> Layer<S> for UserAgent {
    type Service = UserAgentService<S>;

    fn layer(self, inner: S) -> Self::Service {
        UserAgentService {
            inner,
            val: self.val,
        }
    }
}

/// Service produced by the [`UserAgent`] layer; wraps `inner`.
pub struct UserAgentService<S> {
    inner: S,
    val: HeaderValue,
}

impl<Cx, B, S> Service<Cx, Request<B>> for UserAgentService<S>
where
    S: Service<Cx, Request<B>>,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Sets `User-Agent` to the configured value if `req` has none, then
    /// forwards the request to the inner service. An existing `User-Agent`
    /// header is never overwritten.
    fn call(
        &self,
        cx: &mut Cx,
        mut req: Request<B>,
    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send {
        // Single lookup via the entry API instead of `contains_key` and then
        // `insert`. The value is only cloned when the header is absent.
        req.headers_mut()
            .entry(header::USER_AGENT)
            .or_insert_with(|| self.val.clone());
        self.inner.call(cx, req)
    }
}
#[cfg(test)]
mod layer_header_tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use faststr::FastStr;
    use http::uri::Scheme;

    use crate::client::{
        Target,
        layer::header::gen_host,
        target::{RemoteHost, RemoteTarget},
    };

    const IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
    const IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));

    /// Remote target addressed by hostname.
    const fn host_target(scheme: Scheme, host: &'static str, port: u16) -> Target {
        Target::Remote(RemoteTarget {
            scheme,
            host: RemoteHost::Name(FastStr::from_static_str(host)),
            port,
        })
    }

    /// Remote target addressed by IP literal.
    const fn ip_target(scheme: Scheme, ip: IpAddr, port: u16) -> Target {
        Target::Remote(RemoteTarget {
            scheme,
            host: RemoteHost::Ip(ip),
            port,
        })
    }

    /// Asserts that `gen_host` yields `expected` for each `(target, expected)` case.
    fn check(cases: &[(Target, &str)]) {
        for (index, (target, expected)) in cases.iter().enumerate() {
            assert_eq!(gen_host(target).unwrap(), *expected, "case #{index}");
        }
    }

    #[test]
    fn gen_host_test() {
        assert_eq!(gen_host(&Target::None), None);
        check(&[
            (host_target(Scheme::HTTP, "github.com", 80), "github.com"),
            (host_target(Scheme::HTTP, "github.com", 8000), "github.com:8000"),
            (host_target(Scheme::HTTP, "github.com", 443), "github.com:443"),
            (ip_target(Scheme::HTTP, IPV4, 80), "127.0.0.1"),
            (ip_target(Scheme::HTTP, IPV4, 8000), "127.0.0.1:8000"),
            (ip_target(Scheme::HTTP, IPV4, 443), "127.0.0.1:443"),
            (ip_target(Scheme::HTTP, IPV6, 80), "[::1]"),
            (ip_target(Scheme::HTTP, IPV6, 8000), "[::1]:8000"),
            (ip_target(Scheme::HTTP, IPV6, 443), "[::1]:443"),
        ]);
    }

    #[cfg(feature = "__tls")]
    #[test]
    fn gen_host_with_tls_test() {
        check(&[
            (host_target(Scheme::HTTPS, "github.com", 443), "github.com"),
            (host_target(Scheme::HTTPS, "github.com", 4430), "github.com:4430"),
            (host_target(Scheme::HTTPS, "github.com", 80), "github.com:80"),
            (ip_target(Scheme::HTTPS, IPV4, 443), "127.0.0.1"),
            (ip_target(Scheme::HTTPS, IPV4, 4430), "127.0.0.1:4430"),
            (ip_target(Scheme::HTTPS, IPV4, 80), "127.0.0.1:80"),
            (ip_target(Scheme::HTTPS, IPV6, 443), "[::1]"),
            (ip_target(Scheme::HTTPS, IPV6, 4430), "[::1]:4430"),
            (ip_target(Scheme::HTTPS, IPV6, 80), "[::1]:80"),
        ]);
    }
}