use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use nexus_async_rt::{AsyncRead, AsyncWrite, TcpStream};
#[cfg(feature = "tls")]
use nexus_net::buf::{ReadBuf, WriteBuf};
/// Size in bytes of the scratch buffer that each TLS socket read lands in (8 KiB).
#[cfg(feature = "tls")]
const TMP_SIZE: usize = 8 * 1024;

// Compile-time guard: the scratch buffer must not exceed 16 KiB until the
// handshake-piggyback fix named in the message below is in place.
#[cfg(feature = "tls")]
const _: () = assert!(
    TMP_SIZE <= 16_384,
    "TMP_SIZE > 16 KiB requires handshake-piggyback fix (0.7.0)"
);
/// A TCP connection that is either plaintext or wrapped in TLS.
///
/// `AsyncRead` and `AsyncWrite` are implemented on this type, so callers can use either
/// variant through the same interface. The `Tls` variant only exists when the crate is
/// built with the `tls` feature.
pub enum MaybeTls {
    /// Unencrypted TCP; reads and writes are delegated directly to the stream.
    Plain(TcpStream),
    /// TLS over TCP. Boxed — presumably to keep `MaybeTls` small, since `TlsInner`
    /// holds the codec and its buffers inline (confirm).
    #[cfg(feature = "tls")]
    Tls(Box<TlsInner>),
}
/// State of a TLS-wrapped connection: the socket, the TLS codec, and the buffers
/// used to move ciphertext between them.
#[cfg(feature = "tls")]
pub struct TlsInner {
    /// Underlying transport that carries ciphertext.
    pub(crate) stream: TcpStream,
    /// TLS codec: decrypts incoming records into plaintext and encrypts outgoing plaintext.
    pub(crate) codec: nexus_net::tls::TlsCodec,
    // Ciphertext already read from the socket that the codec has not consumed yet.
    pending_read: ReadBuf,
    // Ciphertext produced by the codec that has not been written to the socket yet.
    pending_write: WriteBuf,
    // Fixed-size scratch buffer that each socket read lands in before going to the codec.
    tmp: Box<[u8; TMP_SIZE]>,
}
#[cfg(feature = "tls")]
impl TlsInner {
    /// Size of the per-read scratch buffer (the module-level `TMP_SIZE`).
    pub(crate) const TMP_SIZE: usize = TMP_SIZE;
    /// Default capacity of the outbound ciphertext staging buffer (64 KiB).
    pub(crate) const DEFAULT_PENDING_WRITE_CAPACITY: usize = 65_536;

    /// Creates TLS connection state with default buffer sizes.
    ///
    /// `pending_read` gets exactly [`Self::TMP_SIZE`] bytes of capacity and
    /// `pending_write` gets [`Self::DEFAULT_PENDING_WRITE_CAPACITY`] bytes.
    // Convenience constructor that may be unused in some build configurations; the
    // allow keeps the dead-code lint quiet there.
    #[allow(dead_code)]
    pub(crate) fn new(stream: TcpStream, codec: nexus_net::tls::TlsCodec) -> Self {
        Self::with_capacities(
            stream,
            codec,
            Self::TMP_SIZE,
            Self::DEFAULT_PENDING_WRITE_CAPACITY,
        )
    }

    /// Creates TLS connection state with explicit buffer capacities.
    ///
    /// `stream` carries ciphertext; `codec` performs TLS record processing.
    ///
    /// # Panics
    ///
    /// - If `pending_read_cap < TMP_SIZE`. A single socket read can leave up to
    ///   `TMP_SIZE` unconsumed bytes, and `poll_read` parks them in an empty
    ///   `pending_read`, so a smaller buffer could not hold them.
    /// - If `pending_write_cap == 0`. Encrypted records could never be staged, so
    ///   `drain_codec_to_pending` would give up immediately. Writes and flushes would
    ///   then report success while ciphertext stayed queued inside the codec.
    pub(crate) fn with_capacities(
        stream: TcpStream,
        codec: nexus_net::tls::TlsCodec,
        pending_read_cap: usize,
        pending_write_cap: usize,
    ) -> Self {
        assert!(
            pending_read_cap >= Self::TMP_SIZE,
            "pending_read_cap ({pending_read_cap}) must be >= TMP_SIZE ({})",
            Self::TMP_SIZE,
        );
        // A zero-capacity staging buffer silently drops outbound TLS data (see above).
        assert!(pending_write_cap > 0, "pending_write_cap must be > 0");
        Self {
            stream,
            codec,
            pending_read: ReadBuf::with_capacity(pending_read_cap),
            pending_write: WriteBuf::new(pending_write_cap, 0),
            tmp: Box::new([0u8; TMP_SIZE]),
        }
    }
}
impl MaybeTls {
pub fn is_tls(&self) -> bool {
match self {
Self::Plain(_) => false,
#[cfg(feature = "tls")]
Self::Tls(_) => true,
}
}
}
/// Reads decrypted application data.
///
/// Plain connections delegate to the underlying `TcpStream`. TLS connections loop over
/// three sources, in priority order, until plaintext is available or the socket
/// returns `Pending`:
/// 1. plaintext the codec has already decrypted (`read_plaintext`);
/// 2. ciphertext left over from an earlier socket read (`pending_read`);
/// 3. a fresh socket read into the fixed-size `tmp` scratch buffer.
impl AsyncRead for MaybeTls {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // A zero-length destination can never receive plaintext; answer at once
                // instead of driving the codec or the socket.
                if buf.is_empty() {
                    return Poll::Ready(Ok(0));
                }
                loop {
                    // 1. Hand out plaintext the codec already holds.
                    let n = inner.codec.read_plaintext(buf).map_err(tls_to_io)?;
                    if n > 0 {
                        return Poll::Ready(Ok(n));
                    }
                    // 2. Feed previously parked ciphertext before touching the socket,
                    //    so bytes reach the codec in arrival order.
                    if !inner.pending_read.is_empty() {
                        let consumed = inner
                            .codec
                            .read_tls_step(inner.pending_read.data())
                            .map_err(tls_to_io)?;
                        inner.pending_read.advance(consumed);
                        // NOTE(review): if `read_tls_step` could return 0 here while
                        // `read_plaintext` also yields nothing, this loop would spin
                        // without yielding — confirm the codec always makes progress.
                        continue;
                    }
                    // 3. Nothing buffered: pull more ciphertext from the socket.
                    let n = match Pin::new(&mut inner.stream).poll_read(cx, &mut inner.tmp[..]) {
                        // Transport EOF is surfaced to the caller as EOF.
                        // NOTE(review): the codec is not told about EOF on this path, so
                        // a missing close_notify may go undetected — confirm intended.
                        Poll::Ready(Ok(0)) => return Poll::Ready(Ok(0)),
                        Poll::Ready(Ok(n)) => n,
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    };
                    let consumed = inner
                        .codec
                        .read_tls_step(&inner.tmp[..n])
                        .map_err(tls_to_io)?;
                    // Park whatever the codec did not take. `pending_read` is empty here
                    // (step 2 would have run otherwise), and its capacity is >= TMP_SIZE
                    // >= n (asserted in `TlsInner::with_capacities`), so the slice fits.
                    // This assumes an empty `ReadBuf` exposes its full capacity as spare.
                    if consumed < n {
                        let rem_len = n - consumed;
                        let spare = inner.pending_read.spare();
                        spare[..rem_len].copy_from_slice(&inner.tmp[consumed..n]);
                        inner.pending_read.filled(rem_len);
                    }
                }
            }
        }
    }
}
impl AsyncWrite for MaybeTls {
    /// Writes application data and returns how many bytes of `buf` were accepted.
    ///
    /// Plain connections delegate to the `TcpStream`.
    ///
    /// For TLS, previously staged ciphertext is flushed first. New plaintext is
    /// accepted only once the staging buffer (`pending_write`) is empty, which keeps
    /// that buffer bounded. After encryption, the new records are staged and a
    /// best-effort flush is attempted.
    ///
    /// An empty `buf` completes immediately with `Ok(0)`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // Zero-length write: there is nothing to encrypt. Without this guard,
                // `try_encrypt` can only report 0 bytes consumed. The `consumed == 0`
                // branch below would then self-wake and return `Pending` on every poll,
                // a busy loop that never completes. Mirrors the guard in `poll_read`.
                if buf.is_empty() {
                    return Poll::Ready(Ok(0));
                }
                // Flush ciphertext left over from earlier calls. If the socket is still
                // blocked, the transport's `poll_write` has been given our waker.
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                // Stage and send any records the codec queued independently of this call.
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                let consumed = inner.codec.try_encrypt(buf).map_err(tls_to_io)?;
                if consumed == 0 {
                    // The codec has no output backlog at this point, yet accepted nothing.
                    // No socket readiness event will change that, so re-poll immediately.
                    // NOTE(review): this spins if the codec can never accept data here
                    // (e.g. a handshake that needs a read to progress) — confirm callers
                    // only write after the handshake completes.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }
                // Stage the fresh records and opportunistically send them. Anything not
                // yet on the wire stays in `pending_write` or the codec and is sent by a
                // later write or flush.
                // NOTE(review): an error here is reported even though the codec has
                // already accepted `consumed` bytes of `buf`.
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                Poll::Ready(Ok(consumed))
            }
        }
    }

    /// Flushes all outstanding data.
    ///
    /// For TLS, this moves every record the codec wants to send into `pending_write`,
    /// writes that buffer to the socket, and only then flushes the transport. It
    /// returns `Pending` while staged ciphertext remains unsent.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_flush(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                Pin::new(&mut inner.stream).poll_flush(cx)
            }
        }
    }

    /// Shuts down the write side.
    ///
    /// For TLS, this queues a close_notify alert, sends all outstanding ciphertext,
    /// and then shuts down the transport.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            MaybeTls::Plain(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            MaybeTls::Tls(inner) => {
                // NOTE(review): this runs again on every re-poll after `Pending`. That
                // relies on `TlsCodec::send_close_notify` being idempotent (rustls'
                // own implementation guards against repeats) — confirm for the wrapper.
                inner.codec.send_close_notify();
                drain_codec_to_pending(inner, cx)?;
                drain_pending(inner, cx)?;
                if !inner.pending_write.is_empty() {
                    return Poll::Pending;
                }
                Pin::new(&mut inner.stream).poll_shutdown(cx)
            }
        }
    }
}
/// Writes buffered ciphertext from `pending_write` to the socket.
///
/// Returns `Ok(())` once the buffer is empty, or as soon as the socket reports
/// `Pending`. In the `Pending` case data may still be buffered, and the socket's
/// `poll_write` has received `cx`'s waker. Callers check `pending_write.is_empty()`
/// to tell the two cases apart.
///
/// # Errors
///
/// Propagates transport errors. Returns `WriteZero` if the socket accepts 0 bytes.
#[cfg(feature = "tls")]
fn drain_pending(inner: &mut TlsInner, cx: &mut Context<'_>) -> io::Result<()> {
    loop {
        if inner.pending_write.is_empty() {
            return Ok(());
        }
        let written = match Pin::new(&mut inner.stream).poll_write(cx, inner.pending_write.data()) {
            Poll::Pending => return Ok(()),
            Poll::Ready(result) => result?,
        };
        if written == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "transport write returned 0",
            ));
        }
        inner.pending_write.advance(written);
    }
}
/// Moves TLS records the codec wants to send into `pending_write`, flushing to the
/// socket after each batch.
///
/// Returns `Ok(())` in two cases:
/// - the codec has nothing more to send, or
/// - `pending_write` is full and the socket cannot take more right now.
///
/// In the second case `pending_write` is left non-empty, so callers should return
/// `Pending`.
///
/// # Errors
///
/// Propagates codec and transport errors. Returns `WriteZero` if the codec claims to
/// have output but writes nothing into available space.
#[cfg(feature = "tls")]
fn drain_codec_to_pending(inner: &mut TlsInner, cx: &mut Context<'_>) -> io::Result<()> {
    loop {
        if !inner.codec.wants_write() {
            return Ok(());
        }
        // Out of room: try to push staged bytes to the socket first. If the socket is
        // still blocked, stop and let the caller wait.
        if inner.pending_write.spare().is_empty() {
            drain_pending(inner, cx)?;
            if inner.pending_write.spare().is_empty() {
                return Ok(());
            }
        }
        let produced = inner.codec.write_tls_to(&mut inner.pending_write.spare())?;
        if produced == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "rustls reported wants_write but produced 0 bytes \
                 into a non-empty buffer",
            ));
        }
        inner.pending_write.filled(produced);
        drain_pending(inner, cx)?;
    }
}
/// Converts a TLS codec error into an `io::Error`.
///
/// Transport errors wrapped in `TlsError::Io` are unwrapped and returned as-is, so
/// their `ErrorKind` is preserved. Every other TLS error is boxed with
/// `io::Error::other`.
#[cfg(feature = "tls")]
fn tls_to_io(e: nexus_net::tls::TlsError) -> io::Error {
    use nexus_net::tls::TlsError;
    match e {
        TlsError::Io(source) => source,
        protocol_err => io::Error::other(protocol_err),
    }
}
// Tests for the pending-read buffering strategy used by `MaybeTls::poll_read`. They run
// synchronously against an in-memory rustls server, with no sockets involved.
#[cfg(all(test, feature = "tls"))]
mod tests {
    use std::io::{Cursor, Write};
    use std::sync::Arc;
    use nexus_net::buf::ReadBuf;
    use nexus_net::tls::{TlsCodec, TlsConfig};

    /// Generates a throwaway self-signed certificate for "localhost".
    ///
    /// Returns a one-element DER certificate chain and the DER-encoded private key.
    fn generate_self_signed() -> (Vec<rustls::pki_types::CertificateDer<'static>>, Vec<u8>) {
        let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
            .expect("cert generation");
        (
            vec![rustls::pki_types::CertificateDer::from(
                cert.cert.der().to_vec(),
            )],
            cert.key_pair.serialize_der(),
        )
    }

    /// Builds a client `TlsCodec` and a rustls server, then shuttles handshake bytes
    /// between them in memory until neither side is still handshaking.
    ///
    /// # Panics
    ///
    /// If the handshake has not completed after 64 exchange rounds.
    fn connected_pair() -> (TlsCodec, rustls::ServerConnection) {
        let (cert_chain, key_der) = generate_self_signed();
        let key = rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap();
        let server_config = Arc::new(
            rustls::ServerConfig::builder()
                .with_no_client_auth()
                .with_single_cert(cert_chain, key)
                .unwrap(),
        );
        let mut server = rustls::ServerConnection::new(server_config).unwrap();
        // Verification is disabled because the server certificate is self-signed test material.
        let client_config = TlsConfig::builder().danger_no_verify().build().unwrap();
        let mut client = TlsCodec::new(&client_config, "localhost").unwrap();
        // In-memory "wires": client-to-server and server-to-client.
        let mut c2s = Vec::new();
        let mut s2c = Vec::new();
        // Upper bound on rounds so a stuck handshake fails instead of hanging.
        for _ in 0..64 {
            // Client -> server.
            while client.wants_write() {
                client.write_tls_to(&mut c2s).unwrap();
            }
            if !c2s.is_empty() {
                // NOTE(review): any bytes `read_tls` leaves unread are discarded by
                // `clear()` below; assumed not to happen for handshake flights.
                server.read_tls(&mut Cursor::new(&c2s)).unwrap();
                server.process_new_packets().unwrap();
                c2s.clear();
            }
            // Server -> client.
            while server.wants_write() {
                server.write_tls(&mut s2c).unwrap();
            }
            if !s2c.is_empty() {
                client.read_and_process_tls(&s2c).unwrap();
                s2c.clear();
            }
            if !client.is_handshaking() && !server.is_handshaking() {
                return (client, server);
            }
        }
        panic!("TLS handshake did not complete");
    }

    /// Has the server encrypt `payload` and returns all resulting TLS records
    /// concatenated into one buffer.
    fn encrypt_server_payload(server: &mut rustls::ServerConnection, payload: &[u8]) -> Vec<u8> {
        server.writer().write_all(payload).unwrap();
        let mut ciphertext = Vec::new();
        while server.wants_write() {
            server.write_tls(&mut ciphertext).unwrap();
        }
        ciphertext
    }

    /// Replays the `poll_read` strategy without a socket.
    ///
    /// Ciphertext arrives in 32 KiB chunks, and unconsumed bytes are parked in a
    /// `ReadBuf` sized to exactly one chunk. Each iteration follows the same priority
    /// as `poll_read`: drain plaintext, then parked ciphertext, then a new chunk. The
    /// test asserts that the whole 64 KiB payload round-trips and nothing is left
    /// parked.
    #[test]
    fn pending_read_flow_drains_plaintext_before_more_ciphertext() {
        let (mut client, mut server) = connected_pair();
        let payload = vec![b'x'; 64 * 1024];
        let ciphertext = encrypt_server_payload(&mut server, &payload);
        // Plays the role of one socket read (`TMP_SIZE` in production).
        let chunk_size = 32 * 1024;
        let mut pending_read = ReadBuf::with_capacity(chunk_size);
        let mut plaintext = Vec::with_capacity(payload.len());
        let mut offset = 0;
        let mut dst = [0u8; 1024];
        // Safety bound so a lack of progress fails the assertions instead of hanging.
        for _ in 0..1_000_000 {
            // 1. Drain decrypted plaintext first.
            let n = client.read_plaintext(&mut dst).unwrap();
            if n > 0 {
                plaintext.extend_from_slice(&dst[..n]);
                if plaintext.len() == payload.len() {
                    break;
                }
                continue;
            }
            // 2. Then feed parked ciphertext.
            if !pending_read.is_empty() {
                let consumed = client.read_tls_step(pending_read.data()).unwrap();
                pending_read.advance(consumed);
                continue;
            }
            // 3. Only then "read" the next chunk, parking whatever the codec leaves.
            if offset < ciphertext.len() {
                let end = (offset + chunk_size).min(ciphertext.len());
                let chunk = &ciphertext[offset..end];
                let consumed = client.read_tls_step(chunk).unwrap();
                if consumed < chunk.len() {
                    let rem = &chunk[consumed..];
                    let spare = pending_read.spare();
                    spare[..rem.len()].copy_from_slice(rem);
                    pending_read.filled(rem.len());
                }
                offset = end;
                continue;
            }
            break;
        }
        assert_eq!(plaintext, payload);
        assert_eq!(offset, ciphertext.len());
        assert!(pending_read.is_empty());
    }
}