use std::io::{self, Read, Write};
#[cfg(feature = "tokio")]
use crate::buf::{ReadBuf, WriteBuf};
use super::codec::TlsCodec;
/// Size in bytes of the per-stream scratch buffer (`TlsStream::tmp`).
///
/// The same value is the default capacity of `pending_read` in `TlsStream::new`
/// and the minimum `pending_read_cap` accepted by `TlsStream::with_capacities`.
#[cfg(feature = "tokio")]
const TMP_SIZE: usize = 8192;
// Compile-time guard: the build fails if `TMP_SIZE` is ever raised above 16 KiB.
// NOTE(review): the message refers to a "handshake-piggyback fix (0.7.0)" that is
// not visible in this file — confirm what has to change before lifting the limit.
#[cfg(feature = "tokio")]
const _: () = assert!(
TMP_SIZE <= 16 * 1024,
"TMP_SIZE > 16 KiB requires handshake-piggyback fix (0.7.0)"
);
/// A transport stream `S` paired with the [`TlsCodec`] that encrypts and decrypts
/// the traffic flowing over it.
///
/// With the `tokio` feature enabled the struct additionally carries the buffers
/// used by the `AsyncRead` / `AsyncWrite` implementations.
pub struct TlsStream<S> {
/// Underlying transport that carries the TLS bytes.
stream: S,
/// TLS state machine that TLS bytes and plaintext are fed through.
codec: TlsCodec,
/// TLS bytes already read from `stream` that the codec has not consumed yet.
#[cfg(feature = "tokio")]
pending_read: ReadBuf,
/// TLS bytes produced by the codec that still have to be written to `stream`.
#[cfg(feature = "tokio")]
pending_write: WriteBuf,
/// Scratch space that `poll_read` reads transport bytes into.
#[cfg(feature = "tokio")]
tmp: Box<[u8; TMP_SIZE]>,
}
// Construction, accessors and configuration; none of these require `S` to
// implement any I/O trait.
impl<S> TlsStream<S> {
/// Size in bytes of the internal scratch buffer, and the smallest
/// `pending_read_cap` that [`TlsStream::with_capacities`] accepts.
#[cfg(feature = "tokio")]
pub const TMP_SIZE: usize = TMP_SIZE;
/// Capacity handed to the pending-write buffer by [`TlsStream::new`].
#[cfg(feature = "tokio")]
pub const DEFAULT_PENDING_WRITE_CAPACITY: usize = 65_536;
/// Wraps `stream` and `codec` using the default buffer capacities.
///
/// Only builds the value: no I/O happens and no handshake is started here.
pub fn new(stream: S, codec: TlsCodec) -> Self {
Self {
stream,
codec,
#[cfg(feature = "tokio")]
pending_read: ReadBuf::with_capacity(Self::TMP_SIZE),
// NOTE(review): the meaning of the second `WriteBuf::new` argument (`0`)
// is not visible in this file — confirm against `crate::buf`.
#[cfg(feature = "tokio")]
pending_write: WriteBuf::new(Self::DEFAULT_PENDING_WRITE_CAPACITY, 0),
#[cfg(feature = "tokio")]
tmp: Box::new([0u8; TMP_SIZE]),
}
}
/// Wraps `stream` and `codec` with caller-chosen capacities for the pending
/// read and pending write buffers.
///
/// # Panics
///
/// Panics if `pending_read_cap` is smaller than [`TlsStream::TMP_SIZE`].
///
/// NOTE(review): `pending_write_cap` is not validated. `drain_codec_to_pending`
/// returns early whenever the write buffer has no spare room, so a capacity of
/// zero looks like it would prevent codec output from ever being sent — confirm
/// against `WriteBuf`.
#[cfg(feature = "tokio")]
pub fn with_capacities(
stream: S,
codec: TlsCodec,
pending_read_cap: usize,
pending_write_cap: usize,
) -> Self {
// `poll_read` copies up to one scratch buffer's worth of unconsumed bytes
// into `pending_read`, so it has to be at least that large.
assert!(
pending_read_cap >= Self::TMP_SIZE,
"pending_read_cap ({pending_read_cap}) must be >= TMP_SIZE ({})",
Self::TMP_SIZE,
);
Self {
stream,
codec,
pending_read: ReadBuf::with_capacity(pending_read_cap),
pending_write: WriteBuf::new(pending_write_cap, 0),
tmp: Box::new([0u8; TMP_SIZE]),
}
}
/// Returns a shared reference to the underlying transport.
pub fn stream(&self) -> &S {
&self.stream
}
/// Returns a mutable reference to the underlying transport.
pub fn stream_mut(&mut self) -> &mut S {
&mut self.stream
}
/// Returns a shared reference to the TLS codec.
pub fn codec(&self) -> &TlsCodec {
&self.codec
}
/// Returns a mutable reference to the TLS codec.
pub fn codec_mut(&mut self) -> &mut TlsCodec {
&mut self.codec
}
/// Consumes the wrapper and returns the transport and the codec.
///
/// Any bytes still held in the pending read/write buffers (`tokio` feature)
/// are dropped along with the wrapper.
pub fn into_parts(self) -> (S, TlsCodec) {
(self.stream, self.codec)
}
/// Forwards `limit` to [`TlsCodec::set_buffer_limit`].
///
/// NOTE(review): presumably `None` removes the limit — confirm in the codec.
pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
self.codec.set_buffer_limit(limit);
}
}
impl<S: Read + Write> TlsStream<S> {
pub fn handshake(&mut self) -> Result<(), super::TlsError> {
while self.codec.is_handshaking() {
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
if self.codec.wants_read() {
self.codec.read_tls_from(&mut self.stream)?;
self.codec.process_new_packets()?;
}
}
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
Ok(())
}
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `buf`.
    ///
    /// Plaintext already buffered in the codec is returned first. Otherwise TLS
    /// bytes are pulled from the stream and processed until some plaintext is
    /// available.
    ///
    /// Returns `Ok(0)` when `buf` is empty or when the underlying stream
    /// reports end-of-file.
    ///
    /// # Errors
    ///
    /// Propagates stream I/O errors and codec errors (converted with
    /// `tls_to_io`).
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // With no room for plaintext the loop below could never finish with
        // data; it would just keep consuming TLS bytes until EOF or an error.
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
            if n > 0 {
                return Ok(n);
            }
            // No plaintext yet: feed the codec more TLS bytes.
            if self.codec.read_tls_from(&mut self.stream)? == 0 {
                // Transport reached end-of-file.
                return Ok(0);
            }
            self.codec.process_new_packets().map_err(tls_to_io)?;
        }
    }
}
impl<S: Read + Write> Write for TlsStream<S> {
    /// Encrypts all of `buf` and writes the resulting TLS bytes to the stream.
    ///
    /// On success the whole input has been handed to the codec, so the return
    /// value is always `buf.len()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // `encrypt` is deprecated in the codec API but is still what the
        // blocking path uses.
        #[allow(deprecated)]
        let encrypted = self.codec.encrypt(buf);
        encrypted.map_err(tls_to_io)?;
        loop {
            if !self.codec.wants_write() {
                break;
            }
            self.codec.write_tls_to(&mut self.stream)?;
        }
        Ok(buf.len())
    }

    /// Writes any TLS bytes still queued in the codec, then flushes the stream.
    fn flush(&mut self) -> io::Result<()> {
        loop {
            if !self.codec.wants_write() {
                break;
            }
            self.codec.write_tls_to(&mut self.stream)?;
        }
        self.stream.flush()
    }
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> TlsStream<S> {
    /// Drives the TLS handshake to completion over an async stream.
    ///
    /// Each round sends and flushes every TLS byte the codec has queued, then —
    /// if the codec asks for input — reads from the stream and hands the bytes
    /// to the codec. Output still queued once the handshake is done is sent and
    /// flushed before returning.
    ///
    /// The stream's own scratch buffer is used for staging, so no separate
    /// buffer is allocated or kept inside the future.
    ///
    /// # Errors
    ///
    /// Returns any error raised by the stream or the codec. If the peer closes
    /// the connection before the handshake finishes, returns `TlsError::Io`
    /// with kind [`io::ErrorKind::UnexpectedEof`].
    ///
    /// NOTE(review): bytes are taken out of the codec before `write_all`
    /// completes, so dropping this future mid-write presumably loses them —
    /// treat the handshake as not cancellation-safe unless confirmed otherwise.
    pub async fn handshake_async(&mut self) -> Result<(), super::TlsError> {
        use tokio::io::AsyncReadExt;

        while self.codec.is_handshaking() {
            self.send_queued_tls().await?;
            if self.codec.wants_read() {
                let n = self.stream.read(&mut self.tmp[..]).await?;
                if n == 0 {
                    return Err(super::TlsError::Io(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "connection closed during TLS handshake",
                    )));
                }
                self.codec.read_and_process_tls(&self.tmp[..n])?;
            }
        }
        // Push out whatever the final handshake step left queued.
        self.send_queued_tls().await
    }

    /// Writes every TLS byte the codec currently has queued, then flushes the stream.
    async fn send_queued_tls(&mut self) -> Result<(), super::TlsError> {
        use tokio::io::AsyncWriteExt;

        while self.codec.wants_write() {
            // Stage one chunk of codec output in the scratch buffer.
            let n = self.codec.write_tls_to(&mut &mut self.tmp[..])?;
            self.stream.write_all(&self.tmp[..n]).await?;
        }
        self.stream.flush().await?;
        Ok(())
    }
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> tokio::io::AsyncRead
    for TlsStream<S>
{
    /// Polls for decrypted application data.
    ///
    /// Sources are tried in this order on every loop iteration:
    /// 1. plaintext already buffered in the codec;
    /// 2. TLS bytes left over in `pending_read` from an earlier transport read;
    /// 3. a fresh read from the transport into the scratch buffer.
    ///
    /// Returns `Ready(Ok(()))` with nothing appended when `buf` has no
    /// remaining capacity or when the transport reports end-of-file.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        use std::task::Poll;

        let this = self.get_mut();

        // With no room for plaintext the loop below could never complete with
        // data; it would only keep consuming transport bytes.
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        loop {
            let n = this
                .codec
                .read_plaintext(buf.initialize_unfilled())
                .map_err(tls_to_io)?;
            if n > 0 {
                buf.advance(n);
                return Poll::Ready(Ok(()));
            }

            // Feed leftover TLS bytes to the codec before touching the transport.
            // NOTE(review): if `read_tls_step` can return 0 here without
            // producing plaintext, this loop would spin — confirm in the codec.
            if !this.pending_read.is_empty() {
                let consumed = this
                    .codec
                    .read_tls_step(this.pending_read.data())
                    .map_err(tls_to_io)?;
                this.pending_read.advance(consumed);
                continue;
            }

            let filled = {
                let mut tmp_buf = tokio::io::ReadBuf::new(&mut this.tmp[..]);
                match std::pin::Pin::new(&mut this.stream).poll_read(cx, &mut tmp_buf) {
                    Poll::Ready(Ok(())) => tmp_buf.filled().len(),
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            };
            if filled == 0 {
                // Transport reached end-of-file.
                return Poll::Ready(Ok(()));
            }

            let consumed = this
                .codec
                .read_tls_step(&this.tmp[..filled])
                .map_err(tls_to_io)?;
            if consumed < filled {
                // Keep what the codec did not take. `pending_read` is empty at
                // this point and was sized to hold at least one scratch buffer.
                let rem_len = filled - consumed;
                let spare = this.pending_read.spare();
                spare[..rem_len].copy_from_slice(&this.tmp[consumed..filled]);
                this.pending_read.filled(rem_len);
            }
        }
    }
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite
    for TlsStream<S>
{
    /// Encrypts as much of `buf` as the codec accepts and starts sending it.
    ///
    /// Previously queued output (first `pending_write`, then the codec's own
    /// queue) has to be handed to the transport before new plaintext is
    /// accepted; while that backlog remains the call returns `Pending`.
    ///
    /// Returns `Ready(Ok(0))` immediately for an empty `buf`.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        use std::task::Poll;

        let this = self.get_mut();

        // An empty write would otherwise land in the `consumed == 0` branch
        // below and re-wake itself forever.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        // Clear the backlog from earlier calls first.
        drain_pending(this, cx)?;
        if !this.pending_write.is_empty() {
            return Poll::Pending;
        }
        drain_codec_to_pending(this, cx)?;
        drain_pending(this, cx)?;
        if !this.pending_write.is_empty() {
            return Poll::Pending;
        }

        let consumed = this.codec.try_encrypt(buf).map_err(tls_to_io)?;
        if consumed == 0 {
            // The codec took nothing and the transport did not register a
            // waker on this path, so schedule another poll explicitly.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }

        // Start sending the new records. Leftovers stay queued for the next
        // poll_write / poll_flush call.
        drain_codec_to_pending(this, cx)?;
        drain_pending(this, cx)?;
        Poll::Ready(Ok(consumed))
    }

    /// Sends all queued TLS output, then flushes the transport.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        use std::task::Poll;

        let this = self.get_mut();
        drain_codec_to_pending(this, cx)?;
        drain_pending(this, cx)?;
        if !this.pending_write.is_empty() {
            return Poll::Pending;
        }
        std::pin::Pin::new(&mut this.stream).poll_flush(cx)
    }

    /// Queues a TLS `close_notify`, sends all queued output, then shuts the
    /// transport down.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        use std::task::Poll;

        let this = self.get_mut();
        // NOTE(review): this runs on every poll, including re-polls after
        // `Pending`. That is only correct if the codec treats repeated
        // `send_close_notify` calls as a no-op — confirm in `TlsCodec`.
        this.codec.send_close_notify();
        drain_codec_to_pending(this, cx)?;
        drain_pending(this, cx)?;
        if !this.pending_write.is_empty() {
            return Poll::Pending;
        }
        std::pin::Pin::new(&mut this.stream).poll_shutdown(cx)
    }
}
/// Writes as much of `pending_write` to the transport as it will take right now.
///
/// Returns `Ok(())` both when the buffer has been emptied and when the
/// transport reports `Pending`; callers tell the two apart by checking
/// `pending_write.is_empty()` afterwards.
///
/// # Errors
///
/// Returns the transport's error, or a `WriteZero` error if the transport
/// accepts zero bytes.
#[cfg(feature = "tokio")]
fn drain_pending<S: tokio::io::AsyncWrite + Unpin>(
    this: &mut TlsStream<S>,
    cx: &mut std::task::Context<'_>,
) -> io::Result<()> {
    use std::task::Poll;

    loop {
        if this.pending_write.is_empty() {
            return Ok(());
        }
        let outcome =
            std::pin::Pin::new(&mut this.stream).poll_write(cx, this.pending_write.data());
        let written = match outcome {
            // Transport is not ready; the remaining bytes stay queued.
            Poll::Pending => return Ok(()),
            Poll::Ready(result) => result?,
        };
        if written == 0 {
            return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0"));
        }
        this.pending_write.advance(written);
    }
}
/// Moves TLS output from the codec into `pending_write`, pushing it on to the
/// transport as it goes.
///
/// Stops when the codec has nothing more to write, or when `pending_write` is
/// full and the transport will not take any of it right now.
///
/// # Errors
///
/// Returns transport/codec I/O errors, or a `WriteZero` error if the codec
/// claims to have output but writes nothing into a non-empty buffer.
#[cfg(feature = "tokio")]
fn drain_codec_to_pending<S: tokio::io::AsyncWrite + Unpin>(
    this: &mut TlsStream<S>,
    cx: &mut std::task::Context<'_>,
) -> io::Result<()> {
    loop {
        if !this.codec.wants_write() {
            return Ok(());
        }

        // Make room first when the staging buffer is full.
        if this.pending_write.spare().is_empty() {
            drain_pending(this, cx)?;
            if this.pending_write.spare().is_empty() {
                // Still full: the transport is not accepting data right now.
                return Ok(());
            }
        }

        let produced = this.codec.write_tls_to(&mut this.pending_write.spare())?;
        if produced == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "rustls reported wants_write but produced 0 bytes \
                 into a non-empty buffer",
            ));
        }
        this.pending_write.filled(produced);
        drain_pending(this, cx)?;
    }
}
/// Converts a `TlsError` into an [`io::Error`].
///
/// An `Io` variant is unwrapped to the error it carries; every other variant
/// is wrapped with [`io::Error::other`].
fn tls_to_io(e: super::TlsError) -> io::Error {
    if let super::TlsError::Io(inner) = e {
        return inner;
    }
    io::Error::other(e)
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::tls::TlsConfig;

    /// The constructors put no bounds on `S`, so the unit type stands in for a transport.
    type UnitStream = TlsStream<()>;

    /// Builds a client codec for "localhost" with certificate verification disabled.
    fn make_codec() -> TlsCodec {
        let cfg = TlsConfig::builder().danger_no_verify().build().unwrap();
        TlsCodec::new(&cfg, "localhost").unwrap()
    }

    /// Calls `with_capacities` with the given read capacity and the default write capacity.
    fn build_with_read_cap(pending_read_cap: usize) -> UnitStream {
        UnitStream::with_capacities(
            (),
            make_codec(),
            pending_read_cap,
            UnitStream::DEFAULT_PENDING_WRITE_CAPACITY,
        )
    }

    #[test]
    fn with_capacities_at_minimum_succeeds() {
        let _ = build_with_read_cap(UnitStream::TMP_SIZE);
    }

    #[test]
    fn new_uses_default_capacities() {
        let _ = UnitStream::new((), make_codec());
    }

    #[test]
    #[should_panic(expected = "TMP_SIZE")]
    fn with_capacities_panics_on_undersized_pending_read() {
        let _ = build_with_read_cap(UnitStream::TMP_SIZE - 1);
    }
}