use crate::url::ParsedUrl;
use rustls::crypto::aws_lc_rs;
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned};
use std::io::{self, Read, Write};
use std::net::{IpAddr, TcpStream, ToSocketAddrs};
use std::sync::{Arc, OnceLock};
use std::time::Duration;
/// Reference-counted rustls [`ClientConfig`], so one configuration can be
/// handed to many connections by cloning the `Arc`.
pub(crate) type TlsClientConfig = Arc<ClientConfig>;
/// Returns the process-wide crypto provider handle.
///
/// The provider is built from `aws_lc_rs::default_provider()` on the first
/// call and stored in a function-local `OnceLock`; every later call returns a
/// reference to that same `Arc`.
pub(crate) fn shared_crypto_provider() -> &'static Arc<rustls::crypto::CryptoProvider> {
    static CELL: OnceLock<Arc<rustls::crypto::CryptoProvider>> = OnceLock::new();

    /// Builds the provider; only invoked by the `OnceLock` initializer.
    fn build() -> Arc<rustls::crypto::CryptoProvider> {
        Arc::new(aws_lc_rs::default_provider())
    }

    CELL.get_or_init(build)
}
/// Client-side byte stream that is either a bare TCP socket or a rustls
/// session layered on top of one.
pub(crate) enum ClientStream {
/// Unencrypted TCP connection.
Plain(TcpStream),
/// rustls client connection bundled with the TCP socket it owns.
// NOTE(review): presumably boxed to keep the enum small — confirm.
Tls(Box<StreamOwned<ClientConnection, TcpStream>>),
}
/// Forwards reads to whichever transport this stream wraps.
impl Read for ClientStream {
    /// Reads into `buffer` from the underlying plain or TLS stream.
    fn read(&mut self, buffer: &mut [u8]) -> io::Result<usize> {
        // Pick the concrete reader once, then delegate through it.
        let source: &mut dyn Read = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        source.read(buffer)
    }
}
/// Forwards writes and flushes to whichever transport this stream wraps.
impl Write for ClientStream {
    /// Writes `buffer` to the underlying plain or TLS stream.
    fn write(&mut self, buffer: &[u8]) -> io::Result<usize> {
        let sink: &mut dyn Write = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        sink.write(buffer)
    }

    /// Flushes the underlying plain or TLS stream.
    fn flush(&mut self) -> io::Result<()> {
        let sink: &mut dyn Write = match self {
            Self::Plain(tcp) => tcp,
            Self::Tls(tls) => &mut **tls,
        };
        sink.flush()
    }
}
/// Builds a shared TLS client configuration trusting the platform's native
/// root certificates.
///
/// The configuration uses the provider from [`shared_crypto_provider`], the
/// safe default protocol versions, and no client authentication.
///
/// # Errors
///
/// Returns an error when not a single native root certificate could be added
/// to the store. Any errors reported by the native certificate loader are
/// appended to the message so the failure can be diagnosed. Also returns an
/// error when the protocol-version step of the builder fails.
pub(crate) fn load_tls_client_config() -> io::Result<TlsClientConfig> {
    let native = rustls_native_certs::load_native_certs();
    let mut roots = RootCertStore::empty();
    let (added, _ignored) = roots.add_parsable_certificates(native.certs);
    if added == 0 {
        // Surface the loader's own errors instead of dropping them; without
        // them the caller only learns that the store ended up empty.
        let mut message = String::from("failed to load any native tls root certificates");
        for error in &native.errors {
            message.push_str("; ");
            message.push_str(&error.to_string());
        }
        return Err(io::Error::other(message));
    }
    let config = ClientConfig::builder_with_provider(Arc::clone(shared_crypto_provider()))
        .with_safe_default_protocol_versions()
        .map_err(io::Error::other)?
        .with_root_certificates(roots)
        .with_no_client_auth();
    Ok(Arc::new(config))
}
/// Resolves `host:port` and opens a TCP connection configured with the given
/// timeouts.
///
/// Every resolved address is tried in order until one connects, so a host
/// whose first address is unreachable can still be reached through a later
/// one. `connect_timeout` applies to each attempt individually. The returned
/// stream has `TCP_NODELAY` enabled and the read/write timeouts applied.
///
/// Returns `Ok(None)` when resolution succeeds but yields no addresses.
///
/// # Errors
///
/// Returns the resolution error if the host cannot be resolved, the error of
/// the last connection attempt if every address fails, or the error from
/// configuring the connected socket.
pub(crate) fn connect_tcp_stream(
    host: &str,
    port: u16,
    connect_timeout: Duration,
    read_timeout: Duration,
    write_timeout: Duration,
) -> io::Result<Option<TcpStream>> {
    let mut last_error = None;
    for addr in (host, port).to_socket_addrs()? {
        match TcpStream::connect_timeout(&addr, connect_timeout) {
            Ok(stream) => {
                stream.set_nodelay(true)?;
                stream.set_read_timeout(Some(read_timeout))?;
                stream.set_write_timeout(Some(write_timeout))?;
                return Ok(Some(stream));
            }
            // Remember the failure and fall through to the next address.
            Err(error) => last_error = Some(error),
        }
    }
    // No error recorded means the resolver returned no addresses at all.
    match last_error {
        Some(error) => Err(error),
        None => Ok(None),
    }
}
pub(crate) fn wrap_client_stream(
url: &ParsedUrl,
stream: TcpStream,
tls_config: Option<&TlsClientConfig>,
) -> io::Result<ClientStream> {
if !url.uses_tls() {
return Ok(ClientStream::Plain(stream));
}
let server_name = server_name(&url.host)?;
let connection = ClientConnection::new(
Arc::clone(tls_config.ok_or_else(|| io::Error::other("missing tls client config"))?),
server_name,
)
.map_err(io::Error::other)?;
Ok(ClientStream::Tls(Box::new(StreamOwned::new(
connection, stream,
))))
}
/// Converts a URL host into an owned TLS server name.
///
/// A host that parses as an IP address becomes `ServerName::IpAddress`; any
/// other host goes through `ServerName::try_from`.
///
/// # Errors
///
/// Returns `InvalidInput` when the host is neither an IP address nor accepted
/// by `ServerName::try_from`.
fn server_name(host: &str) -> io::Result<ServerName<'static>> {
    match host.parse::<IpAddr>() {
        Ok(address) => Ok(ServerName::IpAddress(address.into())),
        Err(_) => match ServerName::try_from(host.to_owned()) {
            Ok(name) => Ok(name),
            Err(_) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid tls server name in url host",
            )),
        },
    }
}