use std::str::FromStr;
#[cfg(unix)]
use {std::path::PathBuf, tokio::net::UnixStream};
use tokio::net::TcpStream;
#[cfg(feature = "tls")]
use {
std::{net::SocketAddr, sync::Arc},
tokio_rustls::{
rustls::{
self,
pki_types::{pem::PemObject, CertificateDer, ServerName}
},
TlsConnector
}
};
use super::Stream;
use crate::{err::Error, pspec::ParsedSpec};
#[derive(Clone, Debug)]
pub struct TcpConnInfo {
pub addr: String
}
impl TryFrom<&ParsedSpec<'_>> for TcpConnInfo {
type Error = Error;
fn try_from(ps: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
Ok(Self {
addr: ps.addr.to_string()
})
}
}
impl FromStr for TcpConnInfo {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let ps = ParsedSpec::parse(s)?;
Self::try_from(&ps)
}
}
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct UdsConnInfo {
pub fname: PathBuf
}
#[cfg(unix)]
impl TryFrom<&ParsedSpec<'_>> for UdsConnInfo {
type Error = Error;
fn try_from(ps: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
Ok(Self {
fname: PathBuf::from(ps.addr)
})
}
}
#[cfg(unix)]
impl FromStr for UdsConnInfo {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
fname: PathBuf::from(s)
})
}
}
/// Connection parameters for a TLS-over-TCP endpoint.
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub struct TlsTcpConnInfo {
  /// Address handed to `TcpStream::connect`.
  pub addr: String,
  /// Server name used for certificate verification. It comes from the
  /// `host` spec argument when present, otherwise it is derived from `addr`.
  pub host: String,
  /// Raw contents of the file named by the `cafile` spec argument, parsed
  /// later as PEM certificates.
  pub ca_cert_pem: Vec<u8>
}

#[cfg(feature = "tls")]
impl TryFrom<&ParsedSpec<'_>> for TlsTcpConnInfo {
  type Error = Error;

  /// Build TLS parameters from a parsed spec.
  ///
  /// # Errors
  /// - `cafile` argument missing: a PKI error.
  /// - The CA file cannot be read: the I/O error, converted into [`Error`].
  fn try_from(spec: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
    let cafile = spec
      .args
      .get("cafile")
      .ok_or_else(|| Error::pki("No CA file specified"))?;
    let ca_cert_pem = std::fs::read(cafile)?;

    // An explicit `host` wins. Otherwise use the IP of a literal socket
    // address. Failing that, use the text before the first ':' (or the
    // whole address when there is no ':').
    let host = match spec.args.get("host") {
      Some(explicit) => explicit.to_string(),
      None => match spec.addr.parse::<SocketAddr>() {
        Ok(sockaddr) => sockaddr.ip().to_string(),
        Err(_) => match spec.addr.split_once(':') {
          Some((name, _port)) => name.to_string(),
          None => spec.addr.to_string()
        }
      }
    };

    Ok(Self {
      addr: spec.addr.to_string(),
      host,
      ca_cert_pem
    })
  }
}

#[cfg(feature = "tls")]
impl FromStr for TlsTcpConnInfo {
  type Err = Error;

  /// Parse `s` as a connection spec, then extract the TLS parameters.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    Self::try_from(&ParsedSpec::parse(s)?)
  }
}
/// Describes how to reach a server. The transport is plain TCP, a Unix
/// domain socket (Unix targets only), or TLS over TCP (`tls` feature only).
#[derive(Debug, Clone)]
pub enum Connector {
  /// Plain TCP.
  Tcp(TcpConnInfo),
  /// Unix domain socket.
  #[cfg(unix)]
  Uds(UdsConnInfo),
  /// TLS layered on TCP.
  #[cfg(feature = "tls")]
  TlsTcp(TlsTcpConnInfo)
}
impl Connector {
pub fn tcp(s: &str) -> Result<Self, Error> {
Ok(Self::Tcp(TcpConnInfo::from_str(s)?))
}
#[cfg(unix)]
pub fn uds(s: &str) -> Result<Self, Error> {
Ok(Self::Uds(UdsConnInfo::from_str(s)?))
}
#[cfg(feature = "tls")]
pub fn tls_tcp(s: &str) -> Result<Self, Error> {
Ok(Self::TlsTcp(TlsTcpConnInfo::from_str(s)?))
}
}
impl Connector {
#[must_use]
pub fn display_addr(&self) -> String {
match self {
Self::Tcp(tci) => tci.addr.clone(),
#[cfg(unix)]
Self::Uds(uci) => uci.fname.display().to_string(),
#[cfg(feature = "tls")]
Self::TlsTcp(ttci) => ttci.addr.clone()
}
}
}
impl FromStr for Connector {
  type Err = Error;

  /// Parse a connection spec and pick the transport.
  ///
  /// The spec is checked for TLS first (if the `tls` feature is enabled),
  /// then for a Unix domain socket (on Unix). Anything else is treated as
  /// plain TCP.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let spec = ParsedSpec::parse(s)?;

    #[cfg(feature = "tls")]
    if spec.is_tls()? {
      return TlsTcpConnInfo::try_from(&spec).map(Self::from);
    }

    #[cfg(unix)]
    if spec.is_uds() {
      return UdsConnInfo::try_from(&spec).map(Self::from);
    }

    TcpConnInfo::try_from(&spec).map(Self::from)
  }
}
impl From<TcpConnInfo> for Connector {
fn from(info: TcpConnInfo) -> Self {
Self::Tcp(info)
}
}
#[cfg(unix)]
impl From<UdsConnInfo> for Connector {
fn from(info: UdsConnInfo) -> Self {
Self::Uds(info)
}
}
#[cfg(feature = "tls")]
impl From<TlsTcpConnInfo> for Connector {
fn from(info: TlsTcpConnInfo) -> Self {
Self::TlsTcp(info)
}
}
impl Connector {
  /// Open a connection to the endpoint described by this connector.
  ///
  /// # Errors
  /// Returns an [`Error`] in these cases:
  /// - The underlying TCP or Unix socket cannot be connected.
  /// - For TLS: the CA certificates cannot be loaded.
  /// - For TLS: the CA file contains no certificates.
  /// - For TLS: the server name is invalid.
  /// - For TLS: the handshake fails.
  pub async fn connect(&self) -> Result<Stream, Error> {
    match self {
      Self::Tcp(info) => {
        let strm = TcpStream::connect(&info.addr).await?;
        Ok(Stream::Tcp(strm))
      }
      #[cfg(unix)]
      Self::Uds(info) => {
        let strm = UnixStream::connect(&info.fname).await?;
        Ok(Stream::Uds(strm))
      }
      #[cfg(feature = "tls")]
      Self::TlsTcp(info) => self.connect_tcp_tls(info).await
    }
  }

  /// Build a TLS connector that trusts exactly the PEM certificates in
  /// `ca_cert_pem`, without client authentication.
  #[cfg(feature = "tls")]
  fn tls_connector(ca_cert_pem: &[u8]) -> Result<TlsConnector, Error> {
    let mut root_cert_store = rustls::RootCertStore::empty();
    for cert in CertificateDer::pem_slice_iter(ca_cert_pem) {
      let cert = cert.map_err(|e| {
        let msg = format!("Unable to deserialize certificate ({e})");
        Error::pki(msg)
      })?;
      root_cert_store.add(cert).map_err(|e| {
        let msg =
          format!("Unable to add certificate to certificate store ({e})");
        Error::pki(msg)
      })?;
    }
    // A file with no CERTIFICATE sections would leave the trust store empty.
    // Every handshake would then fail with an opaque verification error, so
    // report the real cause here instead.
    if root_cert_store.is_empty() {
      return Err(Error::pki("No certificates found in CA file"));
    }
    let config = rustls::ClientConfig::builder()
      .with_root_certificates(root_cert_store)
      .with_no_client_auth();
    Ok(TlsConnector::from(Arc::new(config)))
  }

  /// Establish a TCP connection to `info.addr` and perform a TLS handshake
  /// that verifies the server as `info.host`.
  #[cfg(feature = "tls")]
  async fn connect_tcp_tls(
    &self,
    info: &TlsTcpConnInfo
  ) -> Result<Stream, Error> {
    let connector = Self::tls_connector(&info.ca_cert_pem)?;
    // Validate the server name before opening the socket. Otherwise an
    // invalid name opens a TCP connection that is immediately dropped.
    let domain = ServerName::try_from(info.host.clone()).map_err(|e| {
      let msg = format!("Invalid server name ({e})");
      Error::pki(msg)
    })?;
    let raw_stream = TcpStream::connect(&info.addr).await?;
    let strm = connector.connect(domain, raw_stream).await?;
    Ok(Stream::TlsTcp(strm))
  }
}