use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
str::FromStr,
sync::Arc,
};
use crate::logger::tracing::info;
use http::Uri;
use hyper::client::connect::dns::Name;
use rustls_native_certs::load_native_certs;
use tokio::net::TcpStream;
use tokio_rustls::{
client::TlsStream,
rustls::{self},
TlsConnector as TlsConnectorTokio,
};
use tower_service::Service;
use crate::triple::transport::resolver::{dns::DnsResolver, Resolve};
/// A connector that opens TCP connections and wraps them in TLS.
///
/// Host names are resolved through the resolver `R`, which defaults to
/// [`DnsResolver`]. `Debug` is only implemented when `R: Debug`.
#[derive(Clone, Debug, Default)]
pub struct HttpsConnector<R = DnsResolver> {
    /// Resolver used for hosts that are not IP literals.
    resolver: R,
}
impl HttpsConnector {
    /// Creates a connector backed by a default-configured [`DnsResolver`].
    pub fn new() -> Self {
        let default_resolver = DnsResolver::default();
        Self::new_with_resolver(default_resolver)
    }
}
impl<R> HttpsConnector<R> {
pub fn new_with_resolver(resolver: R) -> HttpsConnector<R> {
Self { resolver }
}
}
impl<R> Service<Uri> for HttpsConnector<R>
where
    R: Resolve + Clone + Send + Sync + 'static,
    R::Future: Send,
{
    type Response = TlsStream<TcpStream>;
    type Error = crate::Error;
    type Future = crate::BoxFuture<Self::Response, Self::Error>;

    /// Readiness is delegated entirely to the resolver.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.resolver.poll_ready(cx).map_err(|err| err.into())
    }

    /// Starts connecting to `uri` and returns a boxed future yielding the TLS stream.
    fn call(&mut self, uri: Uri) -> Self::Future {
        // `poll_ready` readied the resolver held by `self`, not a fresh clone.
        // Move the readied connector into the future and leave an un-polled
        // clone in its place for the next `poll_ready`/`call` cycle, so the
        // resolver that is actually driven is the one that reported readiness.
        let fresh = self.clone();
        let mut ready = std::mem::replace(self, fresh);
        Box::pin(async move { ready.call_async(uri).await })
    }
}
impl<R> HttpsConnector<R>
where
R: Resolve + Send + Sync + 'static,
{
async fn call_async(&mut self, uri: Uri) -> Result<TlsStream<TcpStream>, crate::Error> {
let host = uri.host().unwrap();
let port = uri.port_u16().unwrap();
let addr = if let Ok(addr) = host.parse::<Ipv4Addr>() {
info!("host is ip address: {:?}", host);
SocketAddr::V4(SocketAddrV4::new(addr, port))
} else {
info!("host is dns: {:?}", host);
let addrs = self
.resolver
.resolve(Name::from_str(host).unwrap())
.await
.map_err(|err| err.into())?;
let addrs: Vec<SocketAddr> = addrs
.map(|mut addr| {
addr.set_port(port);
addr
})
.collect();
addrs[0]
};
let mut root_store = rustls::RootCertStore::empty();
for cert in load_native_certs()? {
root_store.add(&rustls::Certificate(cert.0))?;
}
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
let connector = TlsConnectorTokio::from(Arc::new(config));
let stream = TcpStream::connect(&addr).await?;
let domain = rustls::ServerName::try_from(host).map_err(|err| {
crate::status::Status::new(crate::status::Code::Internal, err.to_string())
})?;
let stream = connector.connect(domain, stream).await?;
Ok(stream)
}
}