protwrap 0.4.3

Thin protocol wrapper for network applications.
Documentation
//! Utility functions for establishing connections for common stream types.

use std::str::FromStr;

#[cfg(unix)]
use {std::path::PathBuf, tokio::net::UnixStream};

use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use {
  std::{net::SocketAddr, sync::Arc},
  tokio_rustls::{
    rustls::{
      self,
      pki_types::{pem::PemObject, CertificateDer, ServerName}
    },
    TlsConnector
  }
};


use super::Stream;

use crate::{err::Error, pspec::ParsedSpec};

/// Context used to establish plain TCP connections.
///
/// Can be built from a connection specification string (via [`FromStr`]) or
/// from an already parsed specification (via [`TryFrom`]).
#[derive(Clone, Debug)]
pub struct TcpConnInfo {
  /// Socket address to connect to; handed unchanged to
  /// `TcpStream::connect` (for example `127.0.0.1:8080`).
  pub addr: String
}

impl TryFrom<&ParsedSpec<'_>> for TcpConnInfo {
  type Error = Error;

  fn try_from(ps: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
    Ok(Self {
      addr: ps.addr.to_string()
    })
  }
}

impl FromStr for TcpConnInfo {
  type Err = Error;

  /// Parse a connection specification string and keep its TCP address.
  ///
  /// # Errors
  /// Fails if `s` is not a valid connection specification.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let spec = ParsedSpec::parse(s)?;
    let info = Self::try_from(&spec);
    info
  }
}


/// Context used to establish unix local domain connections.
///
/// Only available on unix targets.
#[cfg(unix)]
#[derive(Clone, Debug)]
pub struct UdsConnInfo {
  /// Filesystem path of the socket to connect to; handed to
  /// `UnixStream::connect`.
  pub fname: PathBuf
}

#[cfg(unix)]
impl TryFrom<&ParsedSpec<'_>> for UdsConnInfo {
  type Error = Error;

  fn try_from(ps: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
    Ok(Self {
      fname: PathBuf::from(ps.addr)
    })
  }
}

#[cfg(unix)]
impl FromStr for UdsConnInfo {
  type Err = Error;

  /// Use the entire input string as the socket's filesystem path.
  ///
  /// Unlike the TCP and TLS parsers, the input is not run through the
  /// connection specification parser.  This never fails.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let fname = PathBuf::from(s);
    Ok(Self { fname })
  }
}


/// Context used to establish TLS (based on TCP) connections.
///
/// Only available with the `tls` feature.
#[cfg(feature = "tls")]
#[derive(Clone, Debug)]
pub struct TlsTcpConnInfo {
  /// Socket address of the TCP connection the TLS session runs over.
  pub addr: String,
  /// Name (DNS name or IP address) the server is expected to present during
  /// the TLS handshake; converted to a rustls `ServerName` when connecting.
  pub host: String,
  /// PEM-encoded CA certificate(s) trusted as roots when verifying the
  /// server.
  pub ca_cert_pem: Vec<u8>
}

#[cfg(feature = "tls")]
impl TryFrom<&ParsedSpec<'_>> for TlsTcpConnInfo {
  type Error = Error;

  /// Build TLS connection info from a parsed specification.
  ///
  /// Recognized specification arguments:
  /// - `cafile` (required): path of a PEM file holding the CA certificate(s)
  ///   used to verify the server.  The file is read immediately (blocking).
  /// - `host` (optional): name the server is expected to present.  If
  ///   absent, it is derived from the specification's address.
  ///
  /// # Errors
  /// - An error built with `Error::pki` if no `cafile` argument is given.
  /// - An error converted from [`std::io::Error`] if the CA file can not be
  ///   read.
  fn try_from(ps: &ParsedSpec<'_>) -> Result<Self, Self::Error> {
    let fname = ps
      .args
      .get("cafile")
      .ok_or_else(|| Error::pki("No CA file specified"))?;
    let ca_cert_pem = std::fs::read(fname)?;

    // An explicitly configured host name takes precedence; otherwise derive
    // one from the target address.
    let host = ps
      .args
      .get("host")
      .map_or_else(|| default_tls_host(ps.addr), ToString::to_string);

    Ok(Self {
      addr: ps.addr.to_string(),
      host,
      ca_cert_pem
    })
  }
}

/// Derive a default TLS server name from a connection address.
///
/// If `addr` parses as a socket address, its IP part is returned.
/// Otherwise the port is stripped at the *last* `:`, so that everything
/// before the port is kept, even if it contains `:` itself.  This matches
/// how `"host:port"` strings are split when resolved.  An address without
/// any `:` is returned unchanged.
#[cfg(feature = "tls")]
fn default_tls_host(addr: &str) -> String {
  addr.parse::<SocketAddr>().map_or_else(
    |_| {
      addr
        .rsplit_once(':')
        .map_or(addr, |(host, _port)| host)
        .to_string()
    },
    |sa| sa.ip().to_string()
  )
}

#[cfg(feature = "tls")]
impl FromStr for TlsTcpConnInfo {
  type Err = Error;

  /// Parse a TLS connection specification string.
  ///
  /// The `cafile` argument is required; `host` is optional and overrides
  /// the server name derived from the address.  Example:
  /// `{[::1]:8443,cafile=/tmp/cacert.pem,host=localhost}`
  ///
  /// # Errors
  /// Fails if the specification can not be parsed, if no `cafile` argument
  /// is given, or if the CA file can not be read.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let ps = ParsedSpec::parse(s)?;
    Self::try_from(&ps)
  }
}


/// Protocol-specific connector helper.
///
/// Each variant holds what is needed to open one kind of stream; call
/// [`Connector::connect`] to establish the connection.
#[derive(Clone, Debug)]
pub enum Connector {
  /// Plain, unencrypted TCP.
  Tcp(TcpConnInfo),

  /// Unix local domain socket (unix targets only).
  #[cfg(unix)]
  Uds(UdsConnInfo),

  /// TLS over TCP (requires the `tls` feature).
  #[cfg(feature = "tls")]
  TlsTcp(TlsTcpConnInfo)
}

impl Connector {
  /// Create a TCP connector from a string.
  ///
  /// # Errors
  /// This function will fail if the target address specification could not be
  /// parsed.
  pub fn tcp(s: &str) -> Result<Self, Error> {
    Ok(Self::Tcp(TcpConnInfo::from_str(s)?))
  }

  /// Create an unix domain socket connector from a string.
  ///
  /// # Errors
  /// This function will fail if the target address specification could not be
  /// parsed.
  #[cfg(unix)]
  pub fn uds(s: &str) -> Result<Self, Error> {
    Ok(Self::Uds(UdsConnInfo::from_str(s)?))
  }

  /// Create an TCP/TLS socket connector from a string.
  ///
  /// # Errors
  /// This function will fail if the target address specification could not be
  /// parsed.
  #[cfg(feature = "tls")]
  pub fn tls_tcp(s: &str) -> Result<Self, Error> {
    Ok(Self::TlsTcp(TlsTcpConnInfo::from_str(s)?))
  }
}

impl Connector {
  /// Return a displayable string representation of the `Connector`'s target
  /// address.
  ///
  /// The returned string is not intended to be reversible.
  #[must_use]
  pub fn display_addr(&self) -> String {
    match self {
      Self::Tcp(tci) => tci.addr.clone(),
      #[cfg(unix)]
      Self::Uds(uci) => uci.fname.display().to_string(),
      #[cfg(feature = "tls")]
      Self::TlsTcp(ttci) => ttci.addr.clone()
    }
  }
}

impl FromStr for Connector {
  type Err = Error;

  /// Parse a connection specification and select the matching transport.
  ///
  /// Detection order: TLS (with the `tls` feature), then unix domain
  /// sockets (on unix), then plain TCP as the fallback.
  ///
  /// NOTE(review): without the `tls` feature, a specification that asks for
  /// TLS is not detected here and ends up as UDS/TCP; confirm this is
  /// intended.
  ///
  /// # Errors
  /// Fails if the specification can not be parsed, if TLS detection fails,
  /// or if building the selected connection info fails.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    let spec = ParsedSpec::parse(s)?;

    #[cfg(feature = "tls")]
    if spec.is_tls()? {
      return TlsTcpConnInfo::try_from(&spec).map(Self::TlsTcp);
    }

    #[cfg(unix)]
    if spec.is_uds() {
      return UdsConnInfo::try_from(&spec).map(Self::Uds);
    }

    // Nothing more specific matched; treat it as plain TCP.
    TcpConnInfo::try_from(&spec).map(Self::Tcp)
  }
}

impl From<TcpConnInfo> for Connector {
  fn from(info: TcpConnInfo) -> Self {
    Self::Tcp(info)
  }
}

#[cfg(unix)]
impl From<UdsConnInfo> for Connector {
  fn from(info: UdsConnInfo) -> Self {
    Self::Uds(info)
  }
}

#[cfg(feature = "tls")]
impl From<TlsTcpConnInfo> for Connector {
  /// Wrap TLS-over-TCP connection info in a [`Connector`].
  fn from(ttci: TlsTcpConnInfo) -> Self {
    Connector::TlsTcp(ttci)
  }
}

impl Connector {
  /// Establish a connection to the configured target.
  ///
  /// TCP and unix domain connectors return the raw stream.  TLS connectors
  /// open a TCP connection and then perform a TLS handshake on it.
  ///
  /// # Errors
  /// - [`Error::IO`] indicates failure to establish the connection, including
  ///   a failed TLS handshake.
  /// - For TLS connectors, an error built with `Error::pki` is returned if
  ///   the CA certificates can not be decoded or added, if the CA data
  ///   contains no certificates, or if the configured host is not a valid
  ///   server name.
  ///
  /// # Panics
  /// For TLS connectors, rustls' `ClientConfig::builder()` panics if no
  /// process-level crypto provider is installed and none can be selected
  /// from rustls' enabled crate features.
  pub async fn connect(&self) -> Result<Stream, Error> {
    match self {
      Self::Tcp(info) => {
        let strm = TcpStream::connect(&info.addr).await?;
        Ok(Stream::Tcp(strm))
      }

      #[cfg(unix)]
      Self::Uds(info) => {
        let strm = UnixStream::connect(&info.fname).await?;
        Ok(Stream::Uds(strm))
      }

      #[cfg(feature = "tls")]
      Self::TlsTcp(info) => self.connect_tcp_tls(info).await
    }
  }

  /// Open a TCP connection to `info.addr` and run a TLS handshake on it.
  ///
  /// The server is verified against the CA certificates in
  /// `info.ca_cert_pem` and the name in `info.host`.
  #[cfg(feature = "tls")]
  async fn connect_tcp_tls(
    &self,
    info: &TlsTcpConnInfo
  ) -> Result<Stream, Error> {
    let root_cert_store = Self::load_root_store(&info.ca_cert_pem)?;

    // Validate the server name before opening a socket, so a bad `host`
    // value is reported without a needless connection attempt.
    let domain = ServerName::try_from(info.host.clone()).map_err(|e| {
      let msg = format!("Invalid server name ({e})");
      Error::pki(msg)
    })?;

    let config = rustls::ClientConfig::builder()
      .with_root_certificates(root_cert_store)
      .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(config));

    let raw_stream = TcpStream::connect(&info.addr).await?;
    let strm = connector.connect(domain, raw_stream).await?;

    Ok(Stream::TlsTcp(strm))
  }

  /// Build a root certificate store from PEM-encoded CA certificates.
  ///
  /// Fails if any certificate can not be decoded or added.  It also fails
  /// if `pem` contains no certificates at all; an empty store would make
  /// every handshake fail with a much less descriptive error.
  #[cfg(feature = "tls")]
  fn load_root_store(pem: &[u8]) -> Result<rustls::RootCertStore, Error> {
    let mut root_cert_store = rustls::RootCertStore::empty();
    for cert in CertificateDer::pem_slice_iter(pem) {
      let cert = cert.map_err(|e| {
        let msg = format!("Unable to deserialize certificate ({e})");
        Error::pki(msg)
      })?;
      root_cert_store.add(cert).map_err(|e| {
        let msg =
          format!("Unable to add certificate to certificate store ({e})");
        Error::pki(msg)
      })?;
    }

    if root_cert_store.is_empty() {
      return Err(Error::pki("No CA certificates found in CA file"));
    }

    Ok(root_cert_store)
  }
}

// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 :