#![allow(clippy::await_holding_refcell_ref)]
use core::{
cell::{Ref, RefCell},
fmt,
};
use std::{io, net::Shutdown};
pub use openssl::*;
use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream};
use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut};
use crate::bridge::{self, SyncBridge};
pub struct TlsStream<Io> {
io: Io,
tls: RefCell<SslStream<SyncBridge>>,
}
impl<Io> TlsStream<Io>
where
Io: AsyncBufRead + AsyncBufWrite,
{
pub async fn accept(ssl: Ssl, io: Io) -> Result<Self, Error> {
Self::handshake(ssl, io, |tls| tls.accept()).await
}
pub async fn connect(ssl: Ssl, io: Io) -> Result<Self, Error> {
Self::handshake(ssl, io, |tls| tls.connect()).await
}
async fn handshake<F>(ssl: Ssl, io: Io, mut func: F) -> Result<Self, Error>
where
F: FnMut(&mut SslStream<SyncBridge>) -> Result<(), openssl::ssl::Error>,
{
let bridge = SyncBridge::new();
let mut tls = SslStream::new(ssl, bridge)?;
loop {
match func(&mut tls) {
Ok(_) => {
bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?;
return Ok(TlsStream {
io,
tls: RefCell::new(tls),
});
}
Err(ref e) if e.code() == ErrorCode::WANT_WRITE => {
bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?;
}
Err(ref e) if e.code() == ErrorCode::WANT_READ => {
bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?;
bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?;
}
Err(e) => return Err(Error::Tls(e)),
}
}
}
pub fn session(&self) -> Ref<'_, SslRef> {
let tls = self.tls.borrow();
Ref::map(tls, |tls| tls.ssl())
}
}
impl<Io> AsyncBufRead for TlsStream<Io>
where
Io: AsyncBufRead + AsyncBufWrite,
{
async fn read<B>(&self, mut buf: B) -> (io::Result<usize>, B)
where
B: BoundedBufMut,
{
let spare = unsafe { bridge::spare_capacity_mut(&mut buf) };
let res = self.read_tls(spare).await;
if let Ok(n) = &res {
let init = buf.bytes_init();
unsafe { buf.set_init(init + n) };
}
(res, buf)
}
}
impl<Io> TlsStream<Io>
where
Io: AsyncBufRead + AsyncBufWrite,
{
async fn read_tls(&self, buf: &mut [u8]) -> io::Result<usize> {
let mut tls = self.tls.borrow_mut();
loop {
match io::Read::read(&mut *tls, buf) {
Ok(n) => return Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
let mut read_buf = tls.get_mut().take_read_buf();
let len = read_buf.len();
read_buf.reserve(4096);
drop(tls);
let (res, buf) = self.io.read(read_buf.slice(len..)).await;
tls = self.tls.borrow_mut();
tls.get_mut().set_read_buf(buf.into_inner());
match res {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(_) => {}
Err(e) => return Err(e),
}
}
Err(e) => return Err(e),
}
}
}
}
impl<Io> AsyncBufWrite for TlsStream<Io>
where
Io: AsyncBufRead + AsyncBufWrite,
{
async fn write<B>(&self, buf: B) -> (io::Result<usize>, B)
where
B: BoundedBuf,
{
let data = buf.chunk();
let res = self.write_tls(data).await;
(res, buf)
}
async fn shutdown(mut self, direction: Shutdown) -> io::Result<()> {
let res = self.tls_shutdown().await;
let shutdown_res = self.io.shutdown(direction).await;
res?;
shutdown_res
}
}
impl<Io> TlsStream<Io>
where
Io: AsyncBufRead + AsyncBufWrite,
{
async fn write_tls(&self, buf: &[u8]) -> io::Result<usize> {
let mut tls = self.tls.borrow_mut();
loop {
match io::Write::write(&mut *tls, buf) {
Ok(n) => {
let buf = tls.get_mut().take_write_buf();
drop(tls);
let (res, buf) = bridge::drain_write(&self.io, buf).await;
tls = self.tls.borrow_mut();
tls.get_mut().set_write_buf(buf);
return res.map(|_| n);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
let buf = tls.get_mut().take_write_buf();
drop(tls);
let (res, buf) = bridge::drain_write(&self.io, buf).await;
tls = self.tls.borrow_mut();
tls.get_mut().set_write_buf(buf);
res?;
}
Err(e) => return Err(e),
}
}
}
async fn tls_shutdown(&mut self) -> io::Result<()> {
self.write_tls(&[]).await?;
let tls = self.tls.get_mut();
loop {
match tls.shutdown() {
Ok(ssl::ShutdownResult::Sent) | Ok(ssl::ShutdownResult::Received) => {
bridge::drain_write_buf(&self.io, tls.get_mut()).await?;
return Ok(());
}
Err(ref e) if e.code() == ErrorCode::WANT_WRITE => {
bridge::drain_write_buf(&self.io, tls.get_mut()).await?;
}
Err(ref e) if e.code() == ErrorCode::WANT_READ => {
bridge::drain_write_buf(&self.io, tls.get_mut()).await?;
bridge::fill_read_buf(&self.io, tls.get_mut()).await?;
}
Err(e) => {
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
}
}
}
}
}
#[derive(Debug)]
pub enum Error {
Io(io::Error),
Tls(openssl::ssl::Error),
}
impl From<openssl::error::ErrorStack> for Error {
fn from(e: openssl::error::ErrorStack) -> Self {
Self::Tls(e.into())
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Io(e) => fmt::Display::fmt(e, f),
Self::Tls(e) => fmt::Display::fmt(e, f),
}
}
}
impl std::error::Error for Error {}