// opentls 0.2.1
//
// TLS connections with OpenSSL.
// See the crate documentation for usage.
use crate::{
    async_io::{
        runtime::{AsyncRead, AsyncWrite},
        std_adapter::StdAdapter,
        TlsStream,
    },
    sync_io, Error, HandshakeError,
};
use openssl::ssl::MidHandshakeSslStream;
use std::{
    future::Future,
    io::{Read, Write},
    marker::Unpin,
    pin::Pin,
    ptr::null_mut,
    task::{Context, Poll},
};

/// Drives a TLS handshake over `stream` to completion.
///
/// `f` is invoked exactly once with `stream` wrapped in a `StdAdapter`. If it finishes
/// the handshake immediately, the resulting stream is returned. If it reports
/// `WouldBlock`, the partially-completed handshake is resumed by `MidHandshake` on
/// subsequent polls until it completes or fails.
///
/// # Errors
///
/// Returns the `Error` carried by `HandshakeError::Failure` from either the initial
/// attempt or any resumed attempt.
pub(crate) async fn handshake<F, S>(f: F, stream: S) -> Result<TlsStream<S>, Error>
where
    F: FnOnce(StdAdapter<S>) -> Result<sync_io::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>> + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // First attempt: a failure here propagates directly.
    let started = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream })).await?;

    match started {
        StartedHandshake::Done(s) => Ok(s),
        // The handshake needs more I/O; keep driving it until it resolves.
        StartedHandshake::Mid(s) => MidHandshake(Some(s)).await,
    }
}

/// Future that resumes an in-progress TLS handshake until it completes or fails.
///
/// The stream is kept in an `Option` so `poll` can take ownership of it to call
/// `handshake()` (which consumes it) and put it back when the handshake would block.
/// It is `None` once the future has resolved; polling again then panics.
struct MidHandshake<S>(Option<MidHandshakeSslStream<StdAdapter<S>>>);

/// Outcome of the first handshake attempt made by `StartedHandshakeFuture`.
enum StartedHandshake<S> {
    /// The handshake finished on the first attempt.
    Done(TlsStream<S>),
    /// The handshake reported `WouldBlock`; it must be continued with `MidHandshake`.
    Mid(MidHandshakeSslStream<StdAdapter<S>>),
}

/// Future that makes the first handshake attempt by calling the user-supplied closure.
///
/// The inner state is taken on the first `poll`, and that poll always returns
/// `Poll::Ready`; polling again panics.
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// State consumed by the single poll of `StartedHandshakeFuture`.
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that starts the handshake on the `StdAdapter`-wrapped stream.
    f: F,
    /// The underlying stream, not yet wrapped in a `StdAdapter`.
    stream: S,
}

impl<F, S> Future for StartedHandshakeFuture<F, S>
where
    F: FnOnce(StdAdapter<S>) -> Result<sync_io::TlsStream<StdAdapter<S>>, HandshakeError<StdAdapter<S>>> + Unpin,
    S: Unpin,
    StdAdapter<S>: Read + Write,
{
    type Output = Result<StartedHandshake<S>, Error>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<StartedHandshake<S>, Error>> {
        let inner = self.0.take().expect("future polled after completion");
        let stream = StdAdapter {
            inner: inner.stream,
            context: ctx as *mut _ as *mut (),
        };

        match (inner.f)(stream) {
            Ok(mut s) => {
                s.get_mut().context = null_mut();
                Poll::Ready(Ok(StartedHandshake::Done(TlsStream::new(s))))
            }
            Err(HandshakeError::WouldBlock(mut s)) => {
                s.get_mut().context = null_mut();
                Poll::Ready(Ok(StartedHandshake::Mid(s)))
            }
            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshake<S> {
    type Output = Result<TlsStream<S>, Error>;

    /// Attempts to make progress on the in-flight handshake.
    ///
    /// Returns `Poll::Ready(Ok(..))` when the handshake completes, `Poll::Ready(Err(..))`
    /// on failure, and `Poll::Pending` (storing the stream for the next poll) when it
    /// would block.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already resolved.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut_self = self.get_mut();
        let mut s = mut_self.0.take().expect("future polled after completion");

        // Expose the task context to the blocking-style adapter for this call only.
        s.get_mut().context = cx as *mut _ as *mut ();
        match s.handshake().map_err(HandshakeError::from) {
            Ok(mut stream) => {
                // `cx` is only valid for this poll; clear it so the finished stream never
                // carries a dangling pointer (mirrors the success path of the initial
                // handshake attempt).
                stream.get_mut().context = null_mut();
                Poll::Ready(Ok(TlsStream::new(sync_io::TlsStream(stream))))
            }
            Err(HandshakeError::Failure(e)) => Poll::Ready(Err(e)),
            Err(HandshakeError::WouldBlock(mut s)) => {
                s.get_mut().context = null_mut();
                mut_self.0 = Some(s);
                Poll::Pending
            }
        }
    }
}