use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use nexus_async_rt::{AsyncRead, AsyncWrite, TcpStream};
#[cfg(feature = "tls")]
use nexus_net::buf::{ReadBuf, WriteBuf};
#[cfg(feature = "tls")]
use nexus_net::tls::{TlsBufferCapacities, TlsCodec, TlsError};
/// A transport stream that is either a plain [`TcpStream`] or, when the
/// `tls` feature is enabled, a TLS session layered over one.
pub enum MaybeTls {
/// Unencrypted TCP stream; I/O calls are forwarded to it directly.
Plain(TcpStream),
/// TLS session state wrapping a TCP stream (only with the `tls` feature).
/// Boxed, presumably to keep the enum small — TODO confirm.
#[cfg(feature = "tls")]
Tls(Box<TlsInner>),
}
#[cfg(feature = "tls")]
/// State for one TLS session carried over a [`TcpStream`].
pub struct TlsInner {
/// Underlying TCP transport that TLS bytes are read from and written to.
pub(crate) stream: TcpStream,
/// TLS codec: fed TLS bytes via `read_tls`, emits TLS bytes via
/// `write_tls_to`, and converts to/from plaintext.
pub(crate) codec: TlsCodec,
/// TLS bytes read from `stream` that the codec has not consumed yet.
pending_read: ReadBuf,
/// TLS bytes produced by the codec that have not been written to `stream` yet.
pending_write: WriteBuf,
}
#[cfg(feature = "tls")]
impl TlsInner {
    /// Wraps `stream` in a TLS session and drives the handshake to completion.
    ///
    /// When `capacities` provides a plaintext limit it is applied to `codec`
    /// before anything else. The read and write staging buffers are sized
    /// from `capacities`.
    ///
    /// # Errors
    ///
    /// Propagates any [`TlsError`] produced while driving the handshake.
    #[allow(clippy::future_not_send)]
    pub async fn connect(
        stream: TcpStream,
        mut codec: TlsCodec,
        capacities: TlsBufferCapacities,
    ) -> Result<Self, TlsError> {
        if let Some(limit) = capacities.rustls_plaintext_limit() {
            codec.set_buffer_limit(Some(limit));
        }

        let pending_read = ReadBuf::with_capacity(capacities.read_chunk());
        let pending_write = WriteBuf::new(capacities.pending_write(), 0);
        let mut session = Self {
            stream,
            codec,
            pending_read,
            pending_write,
        };

        session.drive_handshake().await?;
        Ok(session)
    }

    /// Runs the handshake loop: flush whatever the codec wants to send, then
    /// read more bytes from the socket and feed them to the codec, until the
    /// codec reports that it is no longer handshaking. Any output still
    /// queued in the codec afterwards is staged and written out before
    /// returning.
    ///
    /// # Errors
    ///
    /// * `WriteZero` — the write staging buffer has no spare room even after
    ///   draining, or the codec produced zero bytes into a non-empty buffer.
    /// * `InvalidData` — the read staging buffer is full while the handshake
    ///   is still in progress.
    /// * `UnexpectedEof` — the socket read returned zero bytes mid-handshake.
    /// * Any error from the codec or the transport.
    #[allow(clippy::future_not_send)]
    async fn drive_handshake(&mut self) -> Result<(), TlsError> {
        /// Wraps an `io::Error` of the given kind and message in `TlsError::Io`.
        fn io_failure(kind: io::ErrorKind, msg: &'static str) -> TlsError {
            TlsError::Io(io::Error::new(kind, msg))
        }

        while self.codec.is_handshaking() {
            // Outbound half: stage codec output and write it after every chunk.
            while self.codec.wants_write() {
                if self.pending_write.spare().is_empty() {
                    handshake_drain_pending(self).await?;
                    if self.pending_write.spare().is_empty() {
                        return Err(io_failure(
                            io::ErrorKind::WriteZero,
                            "pending_write full and socket cannot accept \
                             during handshake",
                        ));
                    }
                }
                match self.codec.write_tls_to(&mut self.pending_write.spare())? {
                    0 => {
                        return Err(io_failure(
                            io::ErrorKind::WriteZero,
                            "rustls reported wants_write but produced 0 bytes \
                             into a non-empty buffer during handshake",
                        ));
                    }
                    staged => {
                        self.pending_write.filled(staged);
                    }
                }
                handshake_drain_pending(self).await?;
            }
            // Drains and flushes the transport even when nothing new was staged.
            handshake_drain_pending(self).await?;

            if !self.codec.is_handshaking() {
                break;
            }

            // Inbound half: pull more bytes from the socket.
            if self.pending_read.spare().is_empty() {
                return Err(io_failure(
                    io::ErrorKind::InvalidData,
                    "pending_read full mid-handshake but rustls cannot \
                     decode a record",
                ));
            }
            let received = handshake_read_into_spare(self).await?;
            if received == 0 {
                return Err(io_failure(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed during TLS handshake",
                ));
            }

            // Feed buffered bytes to the codec until it stops consuming or the
            // handshake finishes; leftover bytes stay in `pending_read`.
            loop {
                if self.pending_read.is_empty() || !self.codec.is_handshaking() {
                    break;
                }
                match self.codec.read_tls(self.pending_read.data())? {
                    0 => break,
                    used => {
                        self.pending_read.advance(used);
                    }
                }
            }
        }

        // Post-handshake: stage anything the codec still wants to send. Unlike
        // the loop above, the socket is only drained when the buffer is full
        // and once more at the very end.
        while self.codec.wants_write() {
            if self.pending_write.spare().is_empty() {
                handshake_drain_pending(self).await?;
                if self.pending_write.spare().is_empty() {
                    return Err(io_failure(
                        io::ErrorKind::WriteZero,
                        "pending_write full and socket cannot accept \
                         during handshake",
                    ));
                }
            }
            match self.codec.write_tls_to(&mut self.pending_write.spare())? {
                0 => {
                    return Err(io_failure(
                        io::ErrorKind::WriteZero,
                        "rustls reported wants_write but produced 0 bytes \
                         into a non-empty buffer during handshake flush",
                    ));
                }
                staged => {
                    self.pending_write.filled(staged);
                }
            }
        }
        handshake_drain_pending(self).await?;
        Ok(())
    }
}
impl MaybeTls {
    /// Reports whether this stream is the TLS variant.
    ///
    /// Always `false` for [`MaybeTls::Plain`]; the TLS variant only exists
    /// when the `tls` feature is enabled.
    pub fn is_tls(&self) -> bool {
        match *self {
            Self::Plain(..) => false,
            #[cfg(feature = "tls")]
            Self::Tls(..) => true,
        }
    }
}
impl AsyncRead for MaybeTls {
    /// Reads plaintext into `buf`.
    ///
    /// * Plain: forwarded to the TCP stream.
    /// * TLS: returns `Ok(0)` immediately for an empty `buf`. Otherwise it
    ///   loops: take plaintext from the codec if any is available; else feed
    ///   buffered TLS bytes to the codec; else read more TLS bytes from the
    ///   socket into `pending_read`.
    ///
    /// Fails with `InvalidData` when `pending_read` has no spare room and the
    /// codec consumed nothing from it. A zero-byte transport read is returned
    /// as `Ok(0)`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                if buf.is_empty() {
                    return Poll::Ready(Ok(0));
                }
                loop {
                    // 1. Plaintext already decoded by the codec wins.
                    let produced = inner.codec.read_plaintext(buf).map_err(tls_to_io)?;
                    if produced != 0 {
                        return Poll::Ready(Ok(produced));
                    }

                    // 2. Otherwise hand buffered TLS bytes to the codec and retry.
                    if !inner.pending_read.is_empty() {
                        let used = inner
                            .codec
                            .read_tls(inner.pending_read.data())
                            .map_err(tls_to_io)?;
                        if used != 0 {
                            inner.pending_read.advance(used);
                            continue;
                        }
                        // The codec consumed nothing: fall through and read more.
                    }

                    // 3. Need more bytes from the socket.
                    if inner.pending_read.spare().is_empty() {
                        return Poll::Ready(Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            "pending_read full but rustls cannot decode \
                             a record",
                        )));
                    }
                    let received = match Pin::new(&mut inner.stream)
                        .poll_read(cx, inner.pending_read.spare())
                    {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        // NOTE(review): transport EOF is surfaced as a clean
                        // EOF even if `pending_read` still holds bytes the
                        // codec could not consume — confirm that a truncated
                        // stream is meant to look like this.
                        Poll::Ready(Ok(0)) => return Poll::Ready(Ok(0)),
                        Poll::Ready(Ok(received)) => received,
                    };
                    inner.pending_read.filled(received);
                }
            }
        }
    }
}
impl AsyncWrite for MaybeTls {
    /// Writes plaintext from `buf`.
    ///
    /// * Plain: forwarded to the TCP stream.
    /// * TLS: returns `Ok(0)` immediately for an empty `buf`. Otherwise any
    ///   TLS bytes already staged in `pending_write` (and output still queued
    ///   in the codec) are pushed to the socket first; `Pending` is returned
    ///   while staged bytes remain. Then `buf` is handed to the codec for
    ///   encryption and the resulting TLS bytes are pushed toward the socket
    ///   without waiting for them to be fully written.
    ///
    /// Returns the number of plaintext bytes the codec accepted. Fails with
    /// `WriteZero` if the codec accepts none of a non-empty `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                if buf.is_empty() {
                    return Poll::Ready(Ok(0));
                }
                // Back-pressure: do not encrypt more until earlier output left.
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                let consumed = inner.codec.encrypt(buf).map_err(tls_to_io)?;
                if consumed == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "rustls plaintext queue limit smaller than \
                         remaining input — raise via \
                         TlsBufferCapacities::rustls_plaintext_limit \
                         or chunk the write into smaller pieces",
                    )));
                }
                // Best effort: whatever does not fit now is sent by a later
                // write/flush call.
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                Poll::Ready(Ok(consumed))
            }
        }
    }

    /// Flushes all buffered output.
    ///
    /// * Plain: forwarded to the TCP stream.
    /// * TLS: completes only after the codec has no more queued output AND
    ///   `pending_write` is empty, then flushes the transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => match poll_drain_tls_output(inner, cx) {
                Poll::Ready(Ok(())) => Pin::new(&mut inner.stream).poll_flush(cx),
                not_done => not_done,
            },
        }
    }

    /// Shuts the stream down.
    ///
    /// * Plain: forwarded to the TCP stream.
    /// * TLS: asks the codec to queue a `close_notify`, waits until all codec
    ///   output has been written to the socket, then shuts the transport down.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // NOTE(review): called on every poll; assumes the codec treats
                // repeated `send_close_notify` calls as a no-op — confirm.
                inner.codec.send_close_notify();
                match poll_drain_tls_output(inner, cx) {
                    Poll::Ready(Ok(())) => Pin::new(&mut inner.stream).poll_shutdown(cx),
                    not_done => not_done,
                }
            }
        }
    }
}

/// Pushes every byte of codec output through `pending_write` to the socket.
///
/// `drain_codec_to_pending` returns early (with the codec still wanting to
/// write) when `pending_write` has no spare room. If the following
/// `drain_pending` then manages to empty the buffer, a single pass would
/// report "everything flushed" while TLS bytes are still queued in the codec.
/// Looping until both conditions hold closes that gap.
///
/// Returns `Pending` while staged bytes remain unwritten; the transport's
/// `poll_write` has registered the waker in that case.
#[cfg(feature = "tls")]
fn poll_drain_tls_output(inner: &mut TlsInner, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    loop {
        drain_codec_to_pending(inner, cx)?;
        drain_pending(inner, cx)?;
        if !inner.pending_write.is_empty() {
            return Poll::Pending;
        }
        if !inner.codec.wants_write() {
            return Poll::Ready(Ok(()));
        }
        // The buffer is empty yet offers no room: staging can never make
        // progress, so fail instead of spinning.
        if inner.pending_write.spare().is_empty() {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "pending_write has no capacity to stage TLS output",
            )));
        }
    }
}
impl nexus_net::WireStream for MaybeTls {
fn poll_fill_into<P: nexus_net::ParserSink>(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
sink: &mut P,
max: usize,
) -> Poll<io::Result<usize>> {
if max == 0 || sink.spare().is_empty() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"poll_fill_into called with no buffer space \
(max == 0 or sink.spare() is empty)",
)));
}
match self.get_mut() {
MaybeTls::Plain(s) => fill_via_nexus_async_read(Pin::new(s), cx, sink, max),
#[cfg(feature = "tls")]
MaybeTls::Tls(inner) => {
loop {
let mut limited = LimitedSink::new(sink, max);
let n = inner
.codec
.drain_plaintext_into(&mut limited)
.map_err(tls_to_io)?;
if n > 0 {
return Poll::Ready(Ok(n));
}
if !inner.pending_read.is_empty() {
let consumed = inner
.codec
.read_tls(inner.pending_read.data())
.map_err(tls_to_io)?;
if consumed == 0 {
} else {
inner.pending_read.advance(consumed);
continue;
}
}
if inner.pending_read.spare().is_empty() {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"pending_read full but rustls cannot decode \
a record",
)));
}
match Pin::new(&mut inner.stream).poll_read(cx, inner.pending_read.spare()) {
Poll::Ready(Ok(0)) => return Poll::Ready(Ok(0)), Poll::Ready(Ok(filled)) => {
inner.pending_read.filled(filled);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
}
}
}
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
<Self as AsyncWrite>::poll_write(self, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as AsyncWrite>::poll_flush(self, cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
<Self as AsyncWrite>::poll_shutdown(self, cx)
}
}
#[cfg(feature = "tls")]
/// A `ParserSink` adapter that exposes at most a fixed number of bytes of the
/// wrapped sink's spare space.
struct LimitedSink<'a, P: nexus_net::ParserSink> {
/// The wrapped sink that actually receives the bytes.
inner: &'a mut P,
/// How many more bytes may still be written through this adapter.
remaining: usize,
}
#[cfg(feature = "tls")]
impl<'a, P: nexus_net::ParserSink> LimitedSink<'a, P> {
    /// Wraps `inner` so that no more than `max` bytes can be filled through
    /// the returned adapter.
    fn new(inner: &'a mut P, max: usize) -> Self {
        let remaining = max;
        LimitedSink { inner, remaining }
    }
}
#[cfg(feature = "tls")]
impl<P: nexus_net::ParserSink> nexus_net::ParserSink for LimitedSink<'_, P> {
    /// Returns the wrapped sink's spare space, truncated to the remaining
    /// byte budget.
    fn spare(&mut self) -> &mut [u8] {
        let window = self.inner.spare();
        let visible = usize::min(window.len(), self.remaining);
        &mut window[..visible]
    }

    /// Records `n` filled bytes on the wrapped sink and deducts them from the
    /// budget; the budget saturates at zero rather than underflowing.
    fn filled(&mut self, n: usize) {
        self.inner.filled(n);
        let left = self.remaining.saturating_sub(n);
        self.remaining = left;
    }
}
/// Performs one `poll_read` from `stream` directly into the spare space of
/// `sink`, limited to `max` bytes.
///
/// On a successful non-empty read the sink is told how many bytes were
/// filled. The poll result is returned unchanged, including `Ok(0)`.
fn fill_via_nexus_async_read<S, P>(
    stream: Pin<&mut S>,
    cx: &mut Context<'_>,
    sink: &mut P,
    max: usize,
) -> Poll<io::Result<usize>>
where
    S: AsyncRead + ?Sized,
    P: nexus_net::ParserSink,
{
    let window = sink.spare();
    let limit = window.len().min(max);
    let outcome = stream.poll_read(cx, &mut window[..limit]);

    // Only a successful read of at least one byte advances the sink.
    if let Poll::Ready(Ok(received)) = &outcome {
        if *received > 0 {
            sink.filled(*received);
        }
    }
    outcome
}
/// Writes staged TLS bytes from `pending_write` to the transport until the
/// buffer is empty or the transport would block.
///
/// A `Pending` transport is NOT an error: the function returns `Ok(())` with
/// bytes still staged, so callers must check `pending_write.is_empty()`.
///
/// # Errors
///
/// `WriteZero` if the transport accepts zero bytes; otherwise any transport
/// write error.
#[cfg(feature = "tls")]
fn drain_pending(inner: &mut TlsInner, cx: &mut Context<'_>) -> io::Result<()> {
    loop {
        if inner.pending_write.is_empty() {
            return Ok(());
        }
        let written = match Pin::new(&mut inner.stream).poll_write(cx, inner.pending_write.data())
        {
            Poll::Pending => return Ok(()),
            Poll::Ready(result) => result?,
        };
        if written == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "transport write returned 0",
            ));
        }
        inner.pending_write.advance(written);
    }
}
/// Moves TLS output from the codec into `pending_write`, writing to the
/// transport after every staged chunk.
///
/// Returns `Ok(())` either when the codec has nothing left to write, or
/// early when `pending_write` has no spare room even after a drain attempt.
/// In the early case the codec may still want to write.
///
/// # Errors
///
/// `WriteZero` if the codec wants to write but produces zero bytes into a
/// non-empty buffer; otherwise any codec or transport error.
#[cfg(feature = "tls")]
fn drain_codec_to_pending(inner: &mut TlsInner, cx: &mut Context<'_>) -> io::Result<()> {
    loop {
        if !inner.codec.wants_write() {
            return Ok(());
        }

        // Make room first; give up for now if the transport cannot take more.
        if inner.pending_write.spare().is_empty() {
            drain_pending(inner, cx)?;
            if inner.pending_write.spare().is_empty() {
                return Ok(());
            }
        }

        let staged = match inner.codec.write_tls_to(&mut inner.pending_write.spare())? {
            0 => {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "rustls reported wants_write but produced 0 bytes \
                     into a non-empty buffer",
                ));
            }
            staged => staged,
        };
        inner.pending_write.filled(staged);
        drain_pending(inner, cx)?;
    }
}
/// Converts a [`TlsError`] into an [`io::Error`].
///
/// An `Io` variant is unwrapped to the `io::Error` it carries; every other
/// variant is wrapped with `io::Error::other`.
#[cfg(feature = "tls")]
fn tls_to_io(e: TlsError) -> io::Error {
    if let TlsError::Io(io_err) = e {
        io_err
    } else {
        io::Error::other(e)
    }
}
/// Async counterpart of `drain_pending` used during the handshake: writes
/// everything staged in `pending_write` to the transport, awaiting as
/// needed, then flushes the transport.
///
/// # Errors
///
/// `WriteZero` (wrapped in `TlsError::Io`) if the transport accepts zero
/// bytes; otherwise any transport write or flush error.
#[cfg(feature = "tls")]
#[allow(clippy::future_not_send)]
async fn handshake_drain_pending(inner: &mut TlsInner) -> Result<(), TlsError> {
    use std::future::poll_fn;

    loop {
        if inner.pending_write.is_empty() {
            break;
        }
        let attempt =
            poll_fn(|cx| Pin::new(&mut inner.stream).poll_write(cx, inner.pending_write.data()))
                .await;
        let written = match attempt {
            Ok(written) => written,
            Err(e) => return Err(TlsError::Io(e)),
        };
        if written == 0 {
            return Err(TlsError::Io(io::Error::new(
                io::ErrorKind::WriteZero,
                "transport write returned 0 during TLS handshake",
            )));
        }
        inner.pending_write.advance(written);
    }

    // The transport is flushed even when nothing was staged.
    match poll_fn(|cx| Pin::new(&mut inner.stream).poll_flush(cx)).await {
        Ok(()) => Ok(()),
        Err(e) => Err(TlsError::Io(e)),
    }
}
/// Awaits one transport read into the spare space of `pending_read` and
/// marks the received bytes as filled.
///
/// Returns the number of bytes read; zero is passed through to the caller,
/// which decides how to treat it. Transport errors are wrapped in
/// `TlsError::Io`.
#[cfg(feature = "tls")]
#[allow(clippy::future_not_send)]
async fn handshake_read_into_spare(inner: &mut TlsInner) -> Result<usize, TlsError> {
    use std::future::poll_fn;

    let attempt =
        poll_fn(|cx| Pin::new(&mut inner.stream).poll_read(cx, inner.pending_read.spare())).await;
    match attempt {
        Ok(received) => {
            inner.pending_read.filled(received);
            Ok(received)
        }
        Err(e) => Err(TlsError::Io(e)),
    }
}