protwrap 0.4.3

Thin protocol wrapper for network applications.
Documentation
//! Helpers for working on the end-points initiating connection requests.

pub mod connector;

use std::{
  future::Future,
  pin::Pin,
  task::{Context, Poll}
};

use tokio::{
  io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, Result},
  net::TcpStream
};

use futures::future::BoxFuture;

#[cfg(unix)]
use tokio::net::UnixStream;

#[cfg(feature = "tls")]
use tokio_rustls::client::TlsStream;

pub use connector::{Connector, TcpConnInfo};

#[cfg(unix)]
pub use connector::UdsConnInfo;

#[cfg(feature = "tls")]
pub use connector::TlsTcpConnInfo;


/// Representation of a stream acting as a client end-point (actively
/// established connection).
///
/// Which variants exist depends on the build: `Uds` is only available on Unix
/// targets and `TlsTcp` only when the `tls` cargo feature is enabled.
// NOTE(review): the lint is silenced rather than boxing the large variant —
// presumably `TlsTcp`, which carries TLS session state next to the socket;
// confirm the per-value size trade-off is intentional.
#[allow(clippy::large_enum_variant)]
pub enum Stream {
  /// TCP-based client stream.
  Tcp(TcpStream),

  /// Unix local domain client stream.
  #[cfg(unix)]
  Uds(UnixStream),

  /// TLS, based on TCP, client stream.
  #[cfg(feature = "tls")]
  TlsTcp(TlsStream<TcpStream>)
}

impl Stream {
  /// Unwrap the underlying [`TcpStream`].
  ///
  /// # Errors
  /// Returns `self` unchanged if the variant isn't [`Stream::Tcp`], so the
  /// caller keeps ownership of the connection.
  // The error type is the whole `Stream` on purpose (nothing is lost on a
  // mismatch), which is what trips `result_large_err`.
  #[allow(clippy::result_large_err)]
  pub fn try_into_tcp(self) -> std::result::Result<TcpStream, Self> {
    if let Self::Tcp(strm) = self {
      Ok(strm)
    } else {
      Err(self)
    }
  }

  /// Unwrap the underlying [`UnixStream`].
  ///
  /// # Errors
  /// Returns `self` unchanged if the variant isn't [`Stream::Uds`], so the
  /// caller keeps ownership of the connection.
  #[cfg(unix)]
  #[allow(clippy::result_large_err)]
  pub fn try_into_uds(self) -> std::result::Result<UnixStream, Self> {
    if let Self::Uds(strm) = self {
      Ok(strm)
    } else {
      Err(self)
    }
  }

  /// Unwrap the underlying TLS-over-TCP stream.
  ///
  /// Only available with the `tls` feature, since that is what gates both the
  /// [`Stream::TlsTcp`] variant and the `TlsStream` type.
  ///
  /// # Errors
  /// Returns `self` unchanged if the variant isn't [`Stream::TlsTcp`], so the
  /// caller keeps ownership of the connection.
  #[cfg(feature = "tls")]
  #[allow(clippy::result_large_err)]
  pub fn try_into_tlstcp(
    self
  ) -> std::result::Result<TlsStream<TcpStream>, Self> {
    if let Self::TlsTcp(strm) = self {
      Ok(strm)
    } else {
      Err(self)
    }
  }
}

impl Stream {
  /// Whether this stream wants an explicit flush.
  ///
  /// Only the TLS variant reports `true`; plain TCP and Unix-domain streams
  /// report `false`.
  // NOTE(review): the meaning "requires flush" is inferred from the name.
  #[inline]
  pub const fn reqflush(&self) -> bool {
    match self {
      #[cfg(feature = "tls")]
      Self::TlsTcp(_) => true,
      #[cfg(unix)]
      Self::Uds(_) => false,
      Self::Tcp(_) => false
    }
  }

  /// Name of the negotiated TLS cipher suite.
  ///
  /// Yields `None` for non-TLS streams, and for TLS streams on which no
  /// cipher suite has been negotiated.  The returned text is the `Debug`
  /// rendering of the suite identifier.
  pub fn ciphersuite(&self) -> Option<String> {
    match self {
      #[cfg(feature = "tls")]
      Self::TlsTcp(tls) => {
        let session = tls.get_ref().1;
        session
          .negotiated_cipher_suite()
          .map(|negotiated| format!("{:?}", negotiated.suite()))
      }
      _ => None
    }
  }
}

// Forward a pinned method call to whichever transport a `Stream` wraps.
//
// Expands to a `match` over the `Stream` variants that re-pins the inner
// stream and invokes `$method` on it with the given arguments (zero or more
// identifiers).  The projection uses the safe `Pin::get_mut` / `Pin::new`
// pair, which relies on every wrapped stream type being `Unpin`.  The
// compiler enforces that, so a future `!Unpin` variant becomes a build error
// rather than unsound `unsafe` pinning.
macro_rules! delegate_call {
  ($self:ident.$method:ident($($args:ident),*)) => {
    match Pin::get_mut($self) {
      Self::Tcp(s) => Pin::new(s).$method($($args),*),
      #[cfg(unix)]
      Self::Uds(s) => Pin::new(s).$method($($args),*),
      #[cfg(feature = "tls")]
      Self::TlsTcp(s) => Pin::new(s).$method($($args),*),
    }
  };
}

/// Reading from a `Stream` reads from the transport it wraps: each poll is
/// handed, unchanged, to the active variant.
impl AsyncRead for Stream {
  fn poll_read(
    self: Pin<&mut Self>,
    ctx: &mut Context<'_>,
    dst: &mut ReadBuf<'_>
  ) -> Poll<Result<()>> {
    // Same context and destination buffer, passed straight through.
    delegate_call!(self.poll_read(ctx, dst))
  }
}

/// Writing to a `Stream` writes to the transport it wraps.
///
/// Every method, including the vectored-write pair, is forwarded to the
/// active variant.  If `poll_write_vectored`/`is_write_vectored` were not
/// forwarded, the trait defaults would hide the inner stream's vectored-I/O
/// support.
impl AsyncWrite for Stream {
  fn poll_write(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    buf: &[u8]
  ) -> Poll<Result<usize>> {
    delegate_call!(self.poll_write(cx, buf))
  }

  fn poll_write_vectored(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>,
    bufs: &[std::io::IoSlice<'_>]
  ) -> Poll<Result<usize>> {
    delegate_call!(self.poll_write_vectored(cx, bufs))
  }

  fn is_write_vectored(&self) -> bool {
    // Report the capability of the concrete stream actually in use, so it
    // stays consistent with the forwarded `poll_write_vectored`.
    match self {
      Self::Tcp(s) => s.is_write_vectored(),
      #[cfg(unix)]
      Self::Uds(s) => s.is_write_vectored(),
      #[cfg(feature = "tls")]
      Self::TlsTcp(s) => s.is_write_vectored()
    }
  }

  fn poll_flush(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>
  ) -> Poll<Result<()>> {
    delegate_call!(self.poll_flush(cx))
  }

  fn poll_shutdown(
    self: Pin<&mut Self>,
    cx: &mut Context<'_>
  ) -> Poll<Result<()>> {
    delegate_call!(self.poll_shutdown(cx))
  }
}


/// Wrapper which forces shutdown.
///
/// Runs `f` on the connection, then flushes it and shuts down its write side.
///
/// rustls is picky about wanting `close_notify` being sent before closing the
/// write side of the connection.  Because we don't have `AsyncDrop` (yet?),
/// this wrapper can be used instead to make it a little more difficult to
/// forget to perform the cleanup.
///
/// The shutdown is attempted even if the closure or the flush failed.
///
/// # Errors
/// Returns the closure's error if the closure failed.  Otherwise, returns the
/// error from flushing or from shutting down the stream, if either failed.
pub async fn with_conn<F, T>(mut strm: Stream, f: F) -> Result<T>
where
  F: for<'a> FnOnce(&'a mut Stream) -> BoxFuture<'a, Result<T>>,
  T: Send
{
  // Run application closure
  let res = f(&mut strm).await;

  // Flush, then always attempt the explicit shutdown (this should trigger
  // `close_notify`), even when the flush failed, so cleanup is not skipped.
  let flushed = strm.flush().await;
  let shut = strm.shutdown().await;

  // The application's own error is the most informative, so it wins.
  // Otherwise surface any cleanup failure, since buffered data may not have
  // reached the peer.
  let val = res?;
  flushed?;
  shut?;
  Ok(val)
}

/// Wrapper which forces shutdown.
///
/// This serves the same role as [`with_conn()`], but passes ownership of the
/// connection to the application's closure.  The closure must return the
/// connection ownership when done.  The returned connection is then flushed
/// and its write side shut down.
///
/// The shutdown is attempted even if the closure or the flush failed.
///
/// # Errors
/// Returns the closure's error if the closure failed.  Otherwise, returns the
/// error from flushing or from shutting down the stream, if either failed.
pub async fn with_conn_owned<F, Fut, T>(strm: Stream, f: F) -> Result<T>
where
  F: FnOnce(Stream) -> Fut,
  Fut: Future<Output = (Stream, Result<T>)>,
  T: Send
{
  // Run application closure
  let (mut strm, res) = f(strm).await;

  // Flush, then always attempt the explicit shutdown (this should trigger
  // `close_notify`), even when the flush failed, so cleanup is not skipped.
  let flushed = strm.flush().await;
  let shut = strm.shutdown().await;

  // The application's own error is the most informative, so it wins.
  // Otherwise surface any cleanup failure, since buffered data may not have
  // reached the peer.
  let val = res?;
  flushed?;
  shut?;
  Ok(val)
}

// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 :