use btls::{
error::ErrorStack,
ssl::{self, ErrorCode, Ssl, SslRef, SslStream as SslStreamCore},
};
use compio::buf::{IoBuf, IoBufMut};
use compio::BufResult;
use compio_io::{compat::SyncStream, AsyncRead, AsyncWrite};
use std::error::Error;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::{fmt, io};
fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
match r {
Ok(v) => Poll::Ready(Ok(v)),
Err(e) => match e.code() {
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
_ => Poll::Ready(Err(e)),
},
}
}
#[derive(Debug)]
pub struct SslStream<S>(SslStreamCore<SyncStream<S>>);
impl<S: AsyncRead + AsyncWrite> SslStream<S> {
#[inline]
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
SslStreamCore::new(ssl, SyncStream::new(stream)).map(SslStream)
}
#[inline]
pub fn poll_connect(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), HandshakeError>> {
self.with_context(cx, |s| cvt_ossl(s.connect()))
.map_err(HandshakeError::Ssl)
}
#[inline]
pub async fn connect(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
self.drive_handshake(|s| s.connect()).await
}
#[inline]
pub fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), HandshakeError>> {
self.with_context(cx, |s| cvt_ossl(s.accept()))
.map_err(HandshakeError::Ssl)
}
#[inline]
pub async fn accept(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
self.drive_handshake(|s| s.accept()).await
}
#[inline]
pub fn poll_do_handshake(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), HandshakeError>> {
self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
.map_err(HandshakeError::Ssl)
}
#[inline]
pub async fn do_handshake(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
self.drive_handshake(|s| s.do_handshake()).await
}
async fn drive_handshake<F>(mut self: Pin<&mut Self>, mut f: F) -> Result<(), HandshakeError>
where
F: FnMut(&mut SslStreamCore<SyncStream<S>>) -> Result<(), ssl::Error>,
{
loop {
let res = {
let this = unsafe { self.as_mut().get_unchecked_mut() };
f(&mut this.0)
};
match res {
Ok(()) => {
self.as_mut()
.flush_write_buf()
.await
.map_err(HandshakeError::Io)?;
return Ok(());
}
Err(e) => match e.code() {
ErrorCode::WANT_WRITE => {
self.as_mut()
.flush_write_buf()
.await
.map_err(HandshakeError::Io)?;
}
ErrorCode::WANT_READ => {
self.as_mut()
.flush_write_buf()
.await
.map_err(HandshakeError::Io)?;
self.as_mut()
.fill_read_buf()
.await
.map_err(HandshakeError::Io)?;
}
_ => return Err(HandshakeError::Ssl(e)),
},
}
}
}
}
impl<S: AsyncRead + AsyncWrite> SslStream<S> {
async fn fill_read_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
this.0.get_mut().fill_read_buf().await
}
async fn flush_write_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
let this = unsafe { self.as_mut().get_unchecked_mut() };
this.0.get_mut().flush_write_buf().await
}
}
impl<S> SslStream<S> {
#[inline]
pub fn ssl(&self) -> &SslRef {
self.0.ssl()
}
#[inline]
pub fn get_ref(&self) -> &S {
self.0.get_ref().get_ref()
}
#[inline]
pub fn get_mut(&mut self) -> &mut S {
self.0.get_mut().get_mut()
}
#[inline]
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
unsafe {
let this = self.get_unchecked_mut();
Pin::new_unchecked(this.0.get_mut().get_mut())
}
}
fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
where
F: FnOnce(&mut SslStreamCore<SyncStream<S>>) -> R,
{
let this = unsafe { self.get_unchecked_mut() };
this.0.ssl_mut().set_task_waker(Some(ctx.waker().clone()));
let r = f(&mut this.0);
this.0.ssl_mut().set_task_waker(None);
r
}
}
impl<S> AsyncRead for SslStream<S>
where
S: AsyncRead + AsyncWrite,
{
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
let slice = buf.as_uninit();
loop {
match self.0.read_uninit(slice) {
Ok(res) => {
unsafe { buf.advance_to(res) };
return BufResult(Ok(res), buf);
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
match self.0.get_mut().fill_read_buf().await {
Ok(_) => continue,
Err(e) => return BufResult(Err(e), buf),
}
}
res => return BufResult(res, buf),
}
}
}
}
impl<S> AsyncWrite for SslStream<S>
where
S: AsyncRead + AsyncWrite,
{
async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
let slice = buf.as_init();
loop {
let res = io::Write::write(&mut self.0, slice);
match res {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
Ok(_) => continue,
Err(e) => return BufResult(Err(e), buf),
},
_ => return BufResult(res, buf),
}
}
}
async fn flush(&mut self) -> io::Result<()> {
loop {
match io::Write::flush(&mut self.0) {
Ok(()) => break,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
self.0.get_mut().flush_write_buf().await?;
}
Err(e) => return Err(e),
}
}
self.0.get_mut().flush_write_buf().await?;
Ok(())
}
async fn shutdown(&mut self) -> io::Result<()> {
self.flush().await?;
self.0.get_mut().get_mut().shutdown().await
}
}
pub enum HandshakeError {
Ssl(ssl::Error),
Io(io::Error),
}
impl HandshakeError {
#[must_use]
pub fn code(&self) -> Option<ErrorCode> {
match self {
HandshakeError::Ssl(e) => Some(e.code()),
_ => None,
}
}
#[must_use]
pub fn as_io_error(&self) -> Option<&io::Error> {
match self {
HandshakeError::Ssl(e) => e.io_error(),
HandshakeError::Io(e) => Some(e),
}
}
}
impl fmt::Debug for HandshakeError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HandshakeError::Ssl(e) => fmt::Debug::fmt(e, fmt),
HandshakeError::Io(e) => fmt::Debug::fmt(e, fmt),
}
}
}
impl fmt::Display for HandshakeError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
HandshakeError::Ssl(e) => fmt::Display::fmt(e, fmt),
HandshakeError::Io(e) => fmt::Display::fmt(e, fmt),
}
}
}
impl Error for HandshakeError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
HandshakeError::Ssl(e) => e.source(),
HandshakeError::Io(e) => Some(e),
}
}
}