use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use nexus_async_rt::{AsyncRead, AsyncWrite, TcpStream};
/// A byte stream that is either a raw TCP socket or a TLS session layered on
/// top of one.
///
/// The `Tls` variant exists only when the `tls` feature is enabled. It is
/// boxed, presumably to keep the enum small since `TlsInner` carries a codec
/// and a write buffer.
pub enum MaybeTls {
    /// Unencrypted TCP: reads and writes go straight to the socket.
    Plain(TcpStream),
    /// TLS over TCP: application bytes pass through the TLS codec before
    /// reaching the socket.
    #[cfg(feature = "tls")]
    Tls(Box<TlsInner>),
}
/// State for a TLS connection.
///
/// It holds the TCP socket, the TLS codec that encrypts outgoing and decrypts
/// incoming application data, and a buffer of encrypted bytes waiting to be
/// written to the socket.
#[cfg(feature = "tls")]
pub struct TlsInner {
    /// Transport socket that carries the encrypted TLS bytes.
    pub(crate) stream: TcpStream,
    /// TLS state machine: fed raw socket bytes on read, asked for ciphertext on write.
    pub(crate) codec: nexus_net::tls::TlsCodec,
    /// Ciphertext emitted by `codec` that the socket has not accepted yet.
    /// Writers drain this before encrypting more data.
    pending_write: Vec<u8>,
}
#[cfg(feature = "tls")]
impl TlsInner {
    /// Bundles a TCP `stream` with the TLS `codec` that will drive it.
    ///
    /// The outgoing ciphertext buffer starts empty. It is pre-sized to 16 KiB,
    /// presumably around one maximum-size TLS record.
    pub(crate) fn new(stream: TcpStream, codec: nexus_net::tls::TlsCodec) -> Self {
        let pending_write = Vec::with_capacity(16 * 1024);
        Self {
            stream,
            codec,
            pending_write,
        }
    }
}
impl MaybeTls {
    /// Reports whether this connection is TLS-encrypted.
    ///
    /// Returns `false` for plain TCP. It is always `false` when the `tls`
    /// feature is disabled, because `Plain` is then the only variant.
    pub fn is_tls(&self) -> bool {
        !matches!(self, Self::Plain(_))
    }
}
impl AsyncRead for MaybeTls {
    /// Reads application bytes into `buf`.
    ///
    /// Plain TCP delegates directly to the socket. For TLS, the steps are:
    /// 1. Return any plaintext the codec already holds.
    /// 2. Otherwise perform one socket read of up to 8 KiB, feed it to the codec,
    ///    and try again.
    /// 3. If that read produced no plaintext yet, the task wakes itself and
    ///    returns `Pending` so it is polled again.
    ///
    /// A zero-byte socket read is reported as `Ok(0)`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // Serve already-decrypted data before touching the socket.
                let buffered = inner.codec.read_plaintext(buf).map_err(tls_to_io)?;
                if buffered > 0 {
                    return Poll::Ready(Ok(buffered));
                }

                let mut ciphertext = [0u8; 8192];
                let Poll::Ready(outcome) =
                    Pin::new(&mut inner.stream).poll_read(cx, &mut ciphertext)
                else {
                    return Poll::Pending;
                };
                let received = outcome?;
                if received == 0 {
                    // Transport EOF and nothing decrypted to hand out.
                    return Poll::Ready(Ok(0));
                }

                inner
                    .codec
                    .read_tls(&ciphertext[..received])
                    .map_err(tls_to_io)?;
                inner.codec.process_new_packets().map_err(tls_to_io)?;

                let decrypted = inner.codec.read_plaintext(buf).map_err(tls_to_io)?;
                if decrypted == 0 {
                    // The bytes did not complete any application data (e.g. a
                    // partial record); reschedule ourselves to read more.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                Poll::Ready(Ok(decrypted))
            }
        }
    }
}
impl AsyncWrite for MaybeTls {
    /// Writes `buf` to the connection.
    ///
    /// For TLS, the plaintext is encrypted by the codec and the resulting
    /// ciphertext is appended to `pending_write`, then drained to the socket as
    /// far as it will accept.
    ///
    /// The full `buf.len()` is reported once the data is queued. Bytes still
    /// buffered are pushed out by later writes, `poll_flush`, or `poll_shutdown`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // Backpressure: accept no new plaintext while older ciphertext is
                // still queued. `drain_pending` only stops early when the socket
                // returned `Pending`, so a waker is already registered.
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                inner.codec.encrypt(buf).map_err(tls_to_io)?;
                inner
                    .codec
                    .write_tls_to(&mut inner.pending_write)
                    .map_err(io::Error::other)?;
                // Push out as much as the socket takes now; the rest stays queued.
                drain_pending(inner, cx)?;
                Poll::Ready(Ok(buf.len()))
            }
        }
    }

    /// Flushes all buffered output.
    ///
    /// For TLS this includes ciphertext still held by the codec or in
    /// `pending_write`; the socket itself is then flushed.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => poll_flush_tls(inner, cx),
        }
    }

    /// Shuts down the write side of the connection.
    ///
    /// For TLS, all queued ciphertext is flushed first. `poll_write` reports
    /// success as soon as data is buffered, so closing the socket without
    /// flushing would silently drop the tail of the stream.
    ///
    /// NOTE(review): no TLS close_notify is sent here; the codec API for that
    /// is not visible in this file — confirm whether one is needed.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => match poll_flush_tls(inner, cx) {
                Poll::Ready(Ok(())) => Pin::new(&mut inner.stream).poll_shutdown(cx),
                not_done => not_done,
            },
        }
    }
}

/// Moves codec output into `pending_write`, drains it to the socket, then
/// flushes the socket.
///
/// Returns `Pending` (with a waker registered by the socket) while ciphertext
/// remains queued.
#[cfg(feature = "tls")]
fn poll_flush_tls(inner: &mut TlsInner, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    if inner.codec.wants_write() {
        inner
            .codec
            .write_tls_to(&mut inner.pending_write)
            .map_err(io::Error::other)?;
    }
    drain_pending(inner, cx)?;
    if !inner.pending_write.is_empty() {
        return Poll::Pending;
    }
    Pin::new(&mut inner.stream).poll_flush(cx)
}
/// Writes as much of `inner.pending_write` to the socket as it will take.
///
/// Returns `Ok(())` either when the buffer is empty or when the socket reports
/// `Pending`; in the latter case bytes remain queued and the socket has
/// registered the waker from `cx`.
///
/// # Errors
/// Propagates socket write errors. Returns `WriteZero` if the socket accepts
/// zero bytes.
#[cfg(feature = "tls")]
fn drain_pending(inner: &mut TlsInner, cx: &mut Context<'_>) -> io::Result<()> {
    loop {
        if inner.pending_write.is_empty() {
            return Ok(());
        }
        let written = match Pin::new(&mut inner.stream).poll_write(cx, &inner.pending_write) {
            Poll::Pending => return Ok(()),
            Poll::Ready(result) => result?,
        };
        if written == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "transport write returned 0",
            ));
        }
        // Drop the prefix the socket accepted; the remainder is retried next iteration.
        inner.pending_write.drain(..written);
    }
}
/// Converts a codec error into an `io::Error`.
///
/// An I/O error wrapped in `TlsError::Io` is passed through unchanged, so its
/// `ErrorKind` is preserved. Any other TLS error is wrapped with
/// `io::Error::other`.
#[cfg(feature = "tls")]
fn tls_to_io(err: nexus_net::tls::TlsError) -> io::Error {
    if let nexus_net::tls::TlsError::Io(source) = err {
        source
    } else {
        io::Error::other(err)
    }
}