use std::error::Error as StdError;
use std::fmt;
use std::pin::Pin;
use openssl::error::ErrorStack;
use openssl::ssl;
use openssl::x509::X509VerifyResult;
use futures::io::{AsyncRead, AsyncWrite};
use futures::prelude::*;
use futures::task::{Context, Poll};
use async_stdio::*;
use crate::SslStream;
pub struct MidHandshakeSslStream<S>(HandshakeInner<S>);
impl<S: fmt::Debug> fmt::Debug for MidHandshakeSslStream<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("MidHandshakeSslStream")
.field(&self.0)
.finish()
}
}
impl<S> MidHandshakeSslStream<S> {
pub(crate) fn new(inner: ssl::MidHandshakeSslStream<AsStdIo<S>>, ctrl: WakerCtrl) -> Self {
MidHandshakeSslStream(HandshakeInner {
inner: Some((inner, ctrl)),
})
}
}
/// State of an in-progress handshake.
///
/// `inner` pairs the OpenSSL mid-handshake stream with its waker controller.
/// It is `Some` when constructed. It is taken (left as `None`) by the `Future`
/// impl of `MidHandshakeSslStream` once that future resolves.
// Derived so the debug label is `HandshakeInner`. The previous hand-written
// impl mislabelled this struct as "SslStream".
#[derive(Debug)]
struct HandshakeInner<S> {
    inner: Option<(ssl::MidHandshakeSslStream<AsStdIo<S>>, WakerCtrl)>,
}
pub enum HandshakeError<S> {
SetupFailure(ErrorStack),
Failure(MidHandshakeSslStream<S>),
}
impl<S: fmt::Debug> fmt::Debug for HandshakeError<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
HandshakeError::SetupFailure(ref e) => f.debug_tuple("SetupFailure").field(e).finish(),
HandshakeError::Failure(ref s) => f.debug_tuple("Failure").field(s).finish(),
}
}
}
impl<S: fmt::Debug> StdError for HandshakeError<S> {
fn description(&self) -> &str {
match *self {
HandshakeError::SetupFailure(_) => "stream setup failed",
HandshakeError::Failure(_) => "the handshake failed",
}
}
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match *self {
HandshakeError::SetupFailure(ref e) => Some(e),
HandshakeError::Failure(ref s) => {
s.0.inner.as_ref().map(|s| s.0.error() as &dyn StdError)
}
}
}
}
impl<S: fmt::Debug> fmt::Display for HandshakeError<S> {
    /// Formats the failure stage followed by the underlying OpenSSL details.
    ///
    /// Output forms:
    /// - `stream setup failed: <error stack>`
    /// - `the handshake failed[: <ssl error>[: <x509 verify result>]]`, where
    ///   the verify result is appended only when it is not `OK`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The stage messages are written directly rather than via the
        // deprecated `Error::description`, which would emit a deprecation
        // warning. The text is unchanged.
        match self {
            HandshakeError::SetupFailure(e) => write!(f, "stream setup failed: {}", e),
            HandshakeError::Failure(s) => {
                f.write_str("the handshake failed")?;
                if let Some((mid, _)) = s.0.inner.as_ref() {
                    write!(f, ": {}", mid.error())?;
                    let verify = mid.ssl().verify_result();
                    if verify != X509VerifyResult::OK {
                        write!(f, ": {}", verify)?;
                    }
                }
                Ok(())
            }
        }
    }
}
impl<S> From<ErrorStack> for HandshakeError<S> {
fn from(other: ErrorStack) -> Self {
HandshakeError::SetupFailure(other)
}
}
impl<S> HandshakeError<S> {
    /// Converts an OpenSSL handshake error into this crate's `HandshakeError`.
    ///
    /// `SetupFailure` maps to `SetupFailure`. `Failure` maps to `Failure`,
    /// re-wrapping the stream together with `ctrl`.
    ///
    /// Any other variant (such as `WouldBlock`, which the `Future` impl handles
    /// itself) yields `None`, and `ctrl` is dropped.
    pub(crate) fn from_ssl(err: ssl::HandshakeError<AsStdIo<S>>, ctrl: WakerCtrl) -> Option<Self> {
        match err {
            ssl::HandshakeError::SetupFailure(stack) => Some(Self::SetupFailure(stack)),
            ssl::HandshakeError::Failure(mid) => {
                Some(Self::Failure(MidHandshakeSslStream::new(mid, ctrl)))
            }
            _ => None,
        }
    }
}
impl<S: AsyncRead + AsyncWrite + Unpin> Future for MidHandshakeSslStream<S> {
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    /// Resumes the OpenSSL handshake.
    ///
    /// Returns `Poll::Pending` if OpenSSL reports `WouldBlock`; the handshake
    /// state is then stored back for the next poll. Otherwise it resolves to the
    /// established `SslStream` or to a `HandshakeError`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`, because the
    /// handshake state has been moved out by then.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `handshake` is only written through, so the binding needs no `mut`.
        let handshake = &mut Pin::get_mut(self).0;
        let (mid, ctrl) = handshake
            .inner
            .take()
            .expect("handshake polled after completion");
        // NOTE(review): presumably this stores the current task's waker so the
        // `AsStdIo` adapter can wake it when the stream becomes ready.
        // Confirm against `async_stdio::WakerCtrl`.
        ctrl.register(cx.waker());
        match mid.handshake() {
            Ok(inner) => Poll::Ready(Ok(SslStream { inner, ctrl })),
            Err(ssl::HandshakeError::WouldBlock(mid)) => {
                // Restore the state so the next poll resumes where this one stopped.
                handshake.inner = Some((mid, ctrl));
                Poll::Pending
            }
            Err(err) => {
                let err = HandshakeError::from_ssl(err, ctrl).expect(
                    "from_ssl maps every non-WouldBlock handshake error, and WouldBlock is handled above",
                );
                Poll::Ready(Err(err))
            }
        }
    }
}