use core::{
future::Future,
ops::DerefMut,
pin::Pin,
task::{Context, Poll},
};
use std::io;
pub use rustls_crate::*;
use xitca_io::io::{AsyncIo, Interest, Ready};
/// A TLS session (`conn`) layered on top of a transport (`io`).
///
/// TLS records are exchanged with `io`, while plaintext is read and written
/// through this type's `io::Read` / `io::Write` impls. The fields are private;
/// the constructor visible in this file is `TlsStream::handshake`.
pub struct TlsStream<C, Io> {
    // rustls connection state; the impls below require it to deref to `ConnectionCommon<S>`.
    conn: C,
    // Underlying transport that carries the encrypted TLS records.
    io: Io,
}
impl<C, S, Io> TlsStream<C, Io>
where
    C: DerefMut<Target = ConnectionCommon<S>>,
    S: SideData,
    Io: io::Read + io::Write,
{
    /// Hand any ciphertext received so far to rustls for processing.
    ///
    /// If rustls rejects the data, a best-effort attempt is made to push any
    /// TLS records it queued out to the transport. The rustls error is then
    /// returned wrapped as `io::ErrorKind::InvalidData`.
    fn process_new_packets(&mut self) -> io::Result<()> {
        if let Err(err) = self.conn.process_new_packets() {
            // The result is deliberately ignored: the processing error is the
            // one worth surfacing to the caller.
            let _ = self.write_tls();
            return Err(io::Error::new(io::ErrorKind::InvalidData, err));
        }
        Ok(())
    }

    /// Move queued TLS records from the session into the transport.
    /// Returns the number of bytes the transport accepted.
    fn write_tls(&mut self) -> io::Result<usize> {
        let Self { conn, io: transport } = self;
        conn.write_tls(transport)
    }

    /// Pull TLS records from the transport into the session's input buffer.
    /// Returns the number of bytes read; `0` means the transport hit EOF.
    fn read_tls(&mut self) -> io::Result<usize> {
        let Self { conn, io: transport } = self;
        conn.read_tls(transport)
    }
}
impl<C, S, Io> TlsStream<C, Io>
where
    C: DerefMut<Target = ConnectionCommon<S>> + Unpin,
    S: SideData,
    Io: AsyncIo,
{
    /// Borrow the wrapped rustls connection, e.g. to query session state.
    pub fn session(&self) -> &C {
        &self.conn
    }

    /// Drive the TLS handshake of `conn` over `io` to completion and return
    /// the established stream.
    ///
    /// Whenever the transport reports `WouldBlock`, this waits for the
    /// readiness that rustls currently needs (read, write or both) before
    /// retrying.
    ///
    /// # Errors
    ///
    /// Returns any error from `complete_io` other than `WouldBlock`, any error
    /// from waiting on transport readiness, and `WriteZero` if the transport
    /// accepts no bytes while the final handshake records are being flushed.
    pub async fn handshake(mut io: Io, mut conn: C) -> io::Result<Self> {
        while conn.is_handshaking() {
            if let Err(e) = conn.complete_io(&mut io) {
                if !matches!(e.kind(), io::ErrorKind::WouldBlock) {
                    return Err(e);
                }
                let interest = match (conn.wants_read(), conn.wants_write()) {
                    (true, true) => Interest::READABLE | Interest::WRITABLE,
                    (true, false) => Interest::READABLE,
                    (false, true) => Interest::WRITABLE,
                    // NOTE(review): assumed impossible while handshaking; confirm
                    // against the rustls version in use.
                    (false, false) => unreachable!(),
                };
                io.ready(interest).await?;
            }
        }

        // `complete_io` may hit `WouldBlock` while writing the last handshake
        // flight *after* `is_handshaking()` has already become false. In that
        // case the loop above exits with those records still queued. Drain
        // them here so a peer that is waiting on them is not left stalled.
        while conn.wants_write() {
            match conn.write_tls(&mut io) {
                Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    io.ready(Interest::WRITABLE).await?;
                }
                Err(e) => return Err(e),
            }
        }

        Ok(TlsStream { io, conn })
    }
}
impl<C, Io, S> AsyncIo for TlsStream<C, Io>
where
    C: DerefMut<Target = ConnectionCommon<S>> + Unpin,
    S: SideData,
    Io: AsyncIo,
{
    /// Readiness of the TLS stream is the readiness of its transport.
    #[inline]
    fn ready(&mut self, interest: Interest) -> impl Future<Output = io::Result<Ready>> + Send {
        let transport = &mut self.io;
        transport.ready(interest)
    }

    /// Poll-based counterpart of `ready`; forwarded to the transport.
    #[inline]
    fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll<io::Result<Ready>> {
        AsyncIo::poll_ready(&mut self.io, interest, cx)
    }

    /// Reports whether the transport prefers vectored writes.
    fn is_vectored_write(&self) -> bool {
        AsyncIo::is_vectored_write(&self.io)
    }

    /// Shut down the underlying transport.
    ///
    /// NOTE(review): nothing here queues a TLS close_notify before shutting
    /// the transport down; confirm callers handle that if it is required.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let transport = Pin::new(&mut this.io);
        AsyncIo::poll_shutdown(transport, cx)
    }
}
impl<C, Io, S> io::Read for TlsStream<C, Io>
where
    C: DerefMut<Target = ConnectionCommon<S>>,
    S: SideData,
    Io: AsyncIo,
{
    /// Read decrypted application data into `buf`.
    ///
    /// While rustls asks for more input, TLS records are pulled from the
    /// transport and processed. This stops once rustls no longer wants input
    /// or the transport reports EOF. Transport errors (including
    /// `WouldBlock`) and processing errors are propagated. The result of
    /// reading from the session's plaintext reader is returned as-is.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            if !self.conn.wants_read() {
                break;
            }
            let received = self.read_tls()?;
            // Process even when `received == 0`, so that rustls observes the EOF.
            self.process_new_packets()?;
            if received == 0 {
                break;
            }
        }
        self.conn.reader().read(buf)
    }
}
impl<C, Io, S> io::Write for TlsStream<C, Io>
where
    C: DerefMut<Target = ConnectionCommon<S>>,
    S: SideData,
    Io: AsyncIo,
{
    /// Hand `buf` to the rustls session as plaintext.
    ///
    /// If the session accepts nothing while TLS output is pending, the pending
    /// output is flushed first and the write is retried.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        write_with(self, |writer| writer.write(buf))
    }

    /// Vectored counterpart of `write`, with the same flush-and-retry behaviour.
    #[inline]
    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
        write_with(self, |writer| writer.write_vectored(bufs))
    }

    /// Push every queued TLS record to the transport, then flush the transport.
    ///
    /// # Errors
    ///
    /// Returns `WriteZero` if the transport accepts no bytes while records are
    /// pending. Otherwise propagates transport errors, including `WouldBlock`.
    fn flush(&mut self) -> io::Result<()> {
        while self.conn.wants_write() {
            if self.write_tls()? == 0 {
                return Err(io::ErrorKind::WriteZero.into());
            }
        }
        // Forward the flush so a buffering transport (e.g. a nested
        // `TlsStream`) actually emits the records just handed to it.
        io::Write::flush(&mut self.io)
    }
}
/// Run `func` against the session's plaintext writer, making room when needed.
///
/// When `func` reports that nothing was accepted and the session still has TLS
/// output queued, that output is flushed and `func` is tried again. Any
/// non-zero count, or a zero count with nothing queued, is returned as-is.
fn write_with<C, Io, S, F>(stream: &mut TlsStream<C, Io>, mut func: F) -> io::Result<usize>
where
    Io: AsyncIo,
    C: DerefMut<Target = ConnectionCommon<S>>,
    S: SideData,
    F: for<'r> FnMut(&mut Writer<'r>) -> io::Result<usize>,
{
    loop {
        let accepted = func(&mut stream.conn.writer())?;
        if accepted != 0 || !stream.conn.wants_write() {
            return Ok(accepted);
        }
        // Nothing fit: push pending records out so the next attempt can progress.
        io::Write::flush(stream)?;
    }
}