use std::io::{self, Read, Write};
use super::codec::TlsCodec;
pub struct TlsStream<S> {
stream: S,
codec: TlsCodec,
#[cfg_attr(not(feature = "tokio"), allow(dead_code))]
pending_write: Vec<u8>,
}
impl<S> TlsStream<S> {
pub fn new(stream: S, codec: TlsCodec) -> Self {
Self {
stream,
codec,
pending_write: Vec::new(),
}
}
pub fn stream(&self) -> &S {
&self.stream
}
pub fn stream_mut(&mut self) -> &mut S {
&mut self.stream
}
pub fn codec(&self) -> &TlsCodec {
&self.codec
}
pub fn codec_mut(&mut self) -> &mut TlsCodec {
&mut self.codec
}
pub fn into_parts(self) -> (S, TlsCodec) {
(self.stream, self.codec)
}
}
impl<S: Read + Write> TlsStream<S> {
pub fn handshake(&mut self) -> Result<(), super::TlsError> {
while self.codec.is_handshaking() {
if self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
if self.codec.wants_read() {
self.codec.read_tls_from(&mut self.stream)?;
self.codec.process_new_packets()?;
}
}
if self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
Ok(())
}
}
impl<S: Read + Write> Read for TlsStream<S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
if n > 0 {
return Ok(n);
}
loop {
let tls_n = self.codec.read_tls_from(&mut self.stream)?;
if tls_n == 0 {
return Ok(0); }
self.codec.process_new_packets().map_err(tls_to_io)?;
let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
if n > 0 {
return Ok(n);
}
}
}
}
impl<S: Read + Write> Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.codec.encrypt(buf).map_err(tls_to_io)?;
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
self.stream.flush()
}
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> TlsStream<S> {
pub async fn handshake_async(&mut self) -> Result<(), super::TlsError> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut tmp = [0u8; 8192];
while self.codec.is_handshaking() {
if self.codec.wants_write() {
let n = self.codec.write_tls_to(&mut tmp.as_mut_slice())?;
self.stream.write_all(&tmp[..n]).await?;
self.stream.flush().await?;
}
if self.codec.wants_read() {
let n = self.stream.read(&mut tmp).await?;
if n == 0 {
return Err(super::TlsError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed during TLS handshake",
)));
}
self.codec.read_tls(&tmp[..n])?;
self.codec.process_new_packets()?;
}
}
if self.codec.wants_write() {
let n = self.codec.write_tls_to(&mut tmp.as_mut_slice())?;
self.stream.write_all(&tmp[..n]).await?;
self.stream.flush().await?;
}
Ok(())
}
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> tokio::io::AsyncRead
for TlsStream<S>
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<io::Result<()>> {
let this = self.get_mut();
let slice = buf.initialize_unfilled();
let n = this.codec.read_plaintext(slice).map_err(tls_to_io)?;
if n > 0 {
buf.advance(n);
return std::task::Poll::Ready(Ok(()));
}
let mut tmp = [0u8; 8192];
let mut tmp_buf = tokio::io::ReadBuf::new(&mut tmp);
match std::pin::Pin::new(&mut this.stream).poll_read(cx, &mut tmp_buf) {
std::task::Poll::Ready(Ok(())) => {
let filled = tmp_buf.filled().len();
if filled == 0 {
return std::task::Poll::Ready(Ok(())); }
this.codec.read_tls(&tmp[..filled]).map_err(tls_to_io)?;
this.codec.process_new_packets().map_err(tls_to_io)?;
let slice = buf.initialize_unfilled();
let pn = this.codec.read_plaintext(slice).map_err(tls_to_io)?;
if pn > 0 {
buf.advance(pn);
std::task::Poll::Ready(Ok(()))
} else {
cx.waker().wake_by_ref();
std::task::Poll::Pending
}
}
std::task::Poll::Ready(Err(e)) => std::task::Poll::Ready(Err(e)),
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}
#[cfg(feature = "tokio")]
impl<S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite
for TlsStream<S>
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<io::Result<usize>> {
let this = self.get_mut();
if let Err(e) = tokio_drain_pending(this, cx) {
return std::task::Poll::Ready(Err(e));
}
if !this.pending_write.is_empty() {
return std::task::Poll::Pending;
}
this.codec.encrypt(buf).map_err(tls_to_io)?;
let mut tmp = [0u8; 8192];
while this.codec.wants_write() {
let n = this.codec.write_tls_to(&mut tmp.as_mut_slice())?;
this.pending_write.extend_from_slice(&tmp[..n]);
}
if let Err(e) = tokio_drain_pending(this, cx) {
return std::task::Poll::Ready(Err(e));
}
std::task::Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
let this = self.get_mut();
let mut tmp = [0u8; 8192];
while this.codec.wants_write() {
let n = this.codec.write_tls_to(&mut tmp.as_mut_slice())?;
this.pending_write.extend_from_slice(&tmp[..n]);
}
if let Err(e) = tokio_drain_pending(this, cx) {
return std::task::Poll::Ready(Err(e));
}
if !this.pending_write.is_empty() {
return std::task::Poll::Pending;
}
std::pin::Pin::new(&mut this.stream).poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
std::pin::Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
}
}
#[cfg(feature = "tokio")]
fn tokio_drain_pending<S: tokio::io::AsyncWrite + Unpin>(
this: &mut TlsStream<S>,
cx: &mut std::task::Context<'_>,
) -> io::Result<()> {
while !this.pending_write.is_empty() {
match std::pin::Pin::new(&mut this.stream).poll_write(cx, &this.pending_write) {
std::task::Poll::Ready(Ok(0)) => {
return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0"));
}
std::task::Poll::Ready(Ok(n)) => {
this.pending_write.drain(..n);
}
std::task::Poll::Ready(Err(e)) => return Err(e),
std::task::Poll::Pending => return Ok(()), }
}
Ok(())
}
fn tls_to_io(e: super::TlsError) -> io::Error {
match e {
super::TlsError::Io(io) => io,
other => io::Error::other(other),
}
}