use alloc::vec::Vec;
use core::pin::Pin;
use core::task::{Context, Poll, ready};
use std::io;
use ::tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use super::{Connection, Error, Step};
fn ioerr(e: Error) -> io::Error {
io::Error::other(e)
}
/// Suspends the calling task exactly once.
///
/// On the first poll it wakes itself immediately and returns `Pending`, so the
/// executor re-polls it on a later turn. The second poll completes. This gives
/// other tasks a chance to run before the caller retries its work.
async fn yield_once() {
    let mut already_yielded = false;
    core::future::poll_fn(move |cx| {
        if already_yielded {
            return Poll::Ready(());
        }
        already_yielded = true;
        // Schedule a re-poll before reporting Pending; otherwise nothing would
        // ever wake this task again.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await;
}
/// An asynchronous TLS stream layered over a transport `S`.
///
/// Plaintext written through [`AsyncWrite`] is handed to the [`Connection`].
/// The bytes it produces are then written to `sock`. Bytes read from `sock`
/// are fed back into the connection, and the plaintext it yields is served
/// through [`AsyncRead`].
///
/// Values are created by [`TlsStream::handshake`].
pub struct TlsStream<S> {
    /// The TLS connection state driven by this stream.
    conn: Connection,
    /// The underlying transport.
    sock: S,
    /// Plaintext returned by `Connection::recv` that has not been handed to a
    /// reader yet.
    rbuf: Vec<u8>,
    /// Read offset into `rbuf`.
    rpos: usize,
    /// Outgoing bytes popped from `conn` that have not been written to `sock` yet.
    wbuf: Vec<u8>,
    /// Write offset into `wbuf`; bytes before it have already been written.
    wpos: usize,
    /// Set once `Connection::close` has been called and its output queued, so a
    /// re-polled `poll_shutdown` does not close the connection a second time.
    close_sent: bool,
}

/// Feeds every byte of `data` into `conn`, calling `Connection::feed` until
/// all of it has been accepted.
///
/// # Errors
///
/// Propagates any error from `feed` (wrapped by `ioerr`). Fails with
/// [`io::ErrorKind::Other`] if `feed` accepts zero bytes while input remains;
/// retrying with the same input would otherwise loop forever.
fn feed_all(conn: &mut Connection, mut data: &[u8]) -> io::Result<()> {
    while !data.is_empty() {
        let taken = conn.feed(data).map_err(ioerr)?;
        if taken == 0 {
            return Err(io::Error::other("TLS connection accepted no input bytes"));
        }
        // Clamp so that an over-reported count ends the loop instead of panicking.
        data = &data[taken.min(data.len())..];
    }
    Ok(())
}

impl<S: AsyncRead + AsyncWrite + Unpin> TlsStream<S> {
    /// Runs the handshake for `conn` over `sock` and returns the established stream.
    ///
    /// Repeatedly calls `Connection::drive` and services whatever step it reports:
    /// - `WantWrite`: write the output popped from `conn` to `sock`, then flush.
    /// - `WantRead`: read more bytes from `sock` and feed them into `conn`.
    /// - `WantSigner`: wait on the signer before driving again.
    ///
    /// Returns once `drive` reports `Step::Complete`.
    ///
    /// # Errors
    ///
    /// - I/O errors from `sock`, or from waiting on a signer readiness handle,
    ///   are returned as-is.
    /// - Errors from `conn` are wrapped as [`io::ErrorKind::Other`].
    /// - [`io::ErrorKind::UnexpectedEof`] if the peer closes the transport
    ///   before the handshake completes.
    pub async fn handshake(mut conn: Connection, mut sock: S) -> io::Result<Self> {
        let mut rd = [0u8; 16 * 1024];
        loop {
            match conn.drive().map_err(ioerr)? {
                Step::WantWrite => {
                    let out = conn.pop().map_err(ioerr)?;
                    if !out.is_empty() {
                        sock.write_all(&out).await?;
                        sock.flush().await?;
                    }
                }
                Step::WantRead => {
                    let n = sock.read(&mut rd).await?;
                    if n == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "peer closed during handshake",
                        ));
                    }
                    feed_all(&mut conn, &rd[..n])?;
                }
                Step::WantSigner(readiness) => {
                    // On Unix, a readiness handle lets the task sleep until it
                    // becomes readable instead of re-driving in a busy loop.
                    #[cfg(unix)]
                    if let Some(r) = readiness {
                        use ::tokio::io::Interest;
                        use ::tokio::io::unix::AsyncFd;
                        let afd = AsyncFd::with_interest(r, Interest::READABLE)?;
                        let mut guard = afd.readable().await?;
                        guard.clear_ready();
                        continue;
                    }
                    // Keeps `readiness` used on targets where the block above
                    // is compiled out.
                    let _ = &readiness;
                    // No handle to wait on: give other tasks a turn, then drive again.
                    yield_once().await;
                }
                Step::Complete => break,
            }
        }
        Ok(TlsStream {
            conn,
            sock,
            rbuf: Vec::new(),
            rpos: 0,
            wbuf: Vec::new(),
            wpos: 0,
            close_sent: false,
        })
    }

    /// Returns the protocol version reported by `Connection::negotiated_version`.
    pub fn negotiated_version(&self) -> Option<super::ProtocolVersion> {
        self.conn.negotiated_version()
    }

    /// Splits the stream back into its connection and transport.
    ///
    /// Any buffered plaintext not yet read, and any outgoing bytes not yet
    /// written to the transport, are discarded. Flush or shut down first if
    /// that output matters.
    pub fn into_inner(self) -> (Connection, S) {
        (self.conn, self.sock)
    }

    /// Writes the unsent tail of `wbuf` (from `wpos`) to the transport.
    ///
    /// Returns `Ready(Ok(()))` once everything is written; the buffer is then
    /// cleared and `wpos` reset. Returns `Pending` if the transport is not
    /// ready; progress is kept in `wpos`. A zero-length write is reported as
    /// [`io::ErrorKind::WriteZero`].
    fn flush_wbuf(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while self.wpos < self.wbuf.len() {
            match Pin::new(&mut self.sock).poll_write(cx, &self.wbuf[self.wpos..]) {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
                }
                Poll::Ready(Ok(n)) => self.wpos += n,
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
        self.wbuf.clear();
        self.wpos = 0;
        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<S> {
    /// Reads plaintext from the TLS connection into `buf`.
    ///
    /// Plaintext buffered from an earlier `recv` is served first. Otherwise
    /// the connection is asked for more. If it has none, transport bytes are
    /// read, fed into the connection, and the connection is asked again.
    /// Transport EOF is reported as EOF (nothing written to `buf`).
    ///
    /// NOTE(review): transport EOF is not distinguished here from an orderly
    /// TLS close; whether truncation is detected depends on `Connection` — confirm.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        loop {
            if this.rpos < this.rbuf.len() {
                let n = (this.rbuf.len() - this.rpos).min(buf.remaining());
                buf.put_slice(&this.rbuf[this.rpos..this.rpos + n]);
                this.rpos += n;
                if this.rpos == this.rbuf.len() {
                    this.rbuf.clear();
                    this.rpos = 0;
                }
                return Poll::Ready(Ok(()));
            }
            let pt = this.conn.recv().map_err(ioerr)?;
            if !pt.is_empty() {
                this.rbuf = pt;
                this.rpos = 0;
                continue;
            }
            let mut tmp = [0u8; 16 * 1024];
            let mut rb = ReadBuf::new(&mut tmp);
            match Pin::new(&mut this.sock).poll_read(cx, &mut rb) {
                Poll::Ready(Ok(())) => {
                    let filled = rb.filled();
                    if filled.is_empty() {
                        return Poll::Ready(Ok(()));
                    }
                    feed_all(&mut this.conn, filled)?;
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
    /// Hands `buf` to the TLS connection and queues the resulting output.
    ///
    /// Waits until output from earlier writes has drained, then sends `buf`
    /// through `conn` and starts writing its output. Reports the whole of
    /// `buf` as written even if that output is still pending. The remainder
    /// is drained by the next `poll_write`, `poll_flush` or `poll_shutdown`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        ready!(this.flush_wbuf(cx))?;
        this.conn.send(buf).map_err(ioerr)?;
        let out = this.conn.pop().map_err(ioerr)?;
        this.wbuf.extend_from_slice(&out);
        // Opportunistic drain: only an error is reported; Pending leaves the
        // bytes queued in `wbuf`.
        if let Poll::Ready(Err(e)) = this.flush_wbuf(cx) {
            return Poll::Ready(Err(e));
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// Writes all queued output to the transport, then flushes the transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        ready!(this.flush_wbuf(cx))?;
        Pin::new(&mut this.sock).poll_flush(cx)
    }

    /// Closes the TLS connection, writes its closing output, then shuts down
    /// the transport.
    ///
    /// Safe to re-poll after `Pending`: `Connection::close` is called at most
    /// once, tracked by `close_sent`. Later polls only finish draining the
    /// queued output and retry the transport shutdown.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Drain pending application output (or a partially written close).
        ready!(this.flush_wbuf(cx))?;
        if !this.close_sent {
            this.conn.close().map_err(ioerr)?;
            let out = this.conn.pop().map_err(ioerr)?;
            this.wbuf.extend_from_slice(&out);
            this.close_sent = true;
            ready!(this.flush_wbuf(cx))?;
        }
        Pin::new(&mut this.sock).poll_shutdown(cx)
    }
}