use boring::ssl::*;
pub use boring::*;
mod bridge;
mod callbacks;
use futures::{AsyncRead, AsyncWrite, Stream};
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use bridge::*;
pub use callbacks::SslContextBuilderExt;
pub use boring::ssl::{
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
BoxSelectCertFuture, ExDataFuture,
};
/// Runs the client side of a TLS handshake over `stream` and resolves to the
/// established [`SslStream`].
///
/// `domain` is forwarded unchanged to [`ConnectConfiguration::setup_connect`].
///
/// # Errors
///
/// If `setup_connect` fails, its error is returned wrapped as
/// `ssl::HandshakeError::SetupFailure`. Otherwise, whatever error the handshake
/// itself produces is returned.
pub async fn connect<S>(
    config: ConnectConfiguration,
    domain: &str,
    stream: S,
) -> Result<SslStream<S>, HandshakeError<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let bridge = AsyncStreamBridge::new(stream);
    match config.setup_connect(domain, bridge) {
        Ok(mid_handshake) => HandshakeFuture(Some(mid_handshake)).await,
        Err(setup_err) => Err(HandshakeError(ssl::HandshakeError::SetupFailure(
            setup_err,
        ))),
    }
}
/// Runs the server side of a TLS handshake on `stream` using `acceptor` and
/// resolves to the established [`SslStream`].
///
/// # Errors
///
/// If [`SslAcceptor::setup_accept`] fails, its error is returned wrapped as
/// `ssl::HandshakeError::SetupFailure`. Otherwise, whatever error the handshake
/// itself produces is returned.
pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let mid_handshake = match acceptor.setup_accept(AsyncStreamBridge::new(stream)) {
        Ok(mid) => mid,
        Err(setup_err) => {
            return Err(HandshakeError(ssl::HandshakeError::SetupFailure(setup_err)));
        }
    };

    HandshakeFuture(Some(mid_handshake)).await
}
/// Converts the result of a non-blocking, `std::io`-style call into a [`Poll`].
///
/// An `ErrorKind::WouldBlock` error means no progress could be made yet, so it
/// becomes `Poll::Pending`. Every other outcome, whether success or a genuine
/// error, is passed through as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);
    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
/// Pairs a TLS session with a transport `S` before the handshake direction is chosen.
///
/// Use [`SslStreamBuilder::ssl_mut`] to adjust the session, then call
/// [`SslStreamBuilder::accept`] or [`SslStreamBuilder::connect`].
pub struct SslStreamBuilder<S> {
    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
}

impl<S> SslStreamBuilder<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `stream` so it can be driven by the given `ssl` session.
    pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
        let bridge = AsyncStreamBridge::new(stream);
        let inner = ssl::SslStreamBuilder::new(ssl, bridge);
        Self { inner }
    }

    /// Performs the server side of the handshake and resolves to the TLS stream.
    pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
        HandshakeFuture(Some(self.inner.setup_accept())).await
    }

    /// Performs the client side of the handshake and resolves to the TLS stream.
    pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
        HandshakeFuture(Some(self.inner.setup_connect())).await
    }
}

impl<S> SslStreamBuilder<S> {
    /// Borrows the session that will be used for the handshake.
    pub fn ssl(&self) -> &SslRef {
        self.inner.ssl()
    }

    /// Mutably borrows the session that will be used for the handshake.
    pub fn ssl_mut(&mut self) -> &mut SslRef {
        self.inner.ssl_mut()
    }
}
/// Asynchronous TLS stream over a transport `S` that implements the `futures`
/// [`AsyncRead`] / [`AsyncWrite`] traits.
///
/// Each I/O call drives the synchronous `boring` stream through an
/// `AsyncStreamBridge`, with the caller's task context installed for the
/// duration of that call.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);

impl<S> SslStream<S> {
    /// Returns a shared reference to the TLS session.
    pub fn ssl(&self) -> &SslRef {
        self.0.ssl()
    }

    /// Returns a mutable reference to the TLS session.
    pub fn ssl_mut(&mut self) -> &mut SslRef {
        self.0.ssl_mut()
    }

    /// Returns a shared reference to the underlying transport.
    pub fn get_ref(&self) -> &S {
        let bridge = self.0.get_ref();
        &bridge.stream
    }

    /// Returns a mutable reference to the underlying transport.
    pub fn get_mut(&mut self) -> &mut S {
        let bridge = self.0.get_mut();
        &mut bridge.stream
    }

    /// Calls `f` on the synchronous TLS stream while `ctx` is installed in the bridge.
    ///
    /// The context is removed again once `f` returns. There is no guard, so if `f`
    /// panics it is not removed.
    fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
    where
        F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
    {
        let tls = &mut self.0;
        tls.get_mut().set_waker(Some(ctx));
        let outcome = f(&mut *tls);
        tls.get_mut().set_waker(None);
        outcome
    }
}

impl<S> AsyncRead for SslStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        // `cvt` turns the bridge's WouldBlock into Pending.
        this.run_in_context(ctx, |tls| cvt(tls.read(buf)))
    }
}

impl<S> AsyncWrite for SslStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::get_mut(self);
        this.run_in_context(ctx, |tls| cvt(tls.write(buf)))
    }

    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);
        this.run_in_context(ctx, |tls| cvt(tls.flush()))
    }

    /// Shuts down the TLS session, then closes the underlying transport.
    fn poll_close(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::get_mut(self);

        // Step 1: TLS shutdown.
        // - `Ok(Sent)`, `Ok(Received)` and a ZERO_RETURN error all count as done.
        // - WANT_READ / WANT_WRITE mean the bridge was not ready, so report Pending.
        // - Any other error is surfaced as an io::Error.
        //
        // NOTE(review): if the transport's `poll_close` below returns Pending, the next
        // call runs `shutdown()` again. Per BoringSSL's SSL_shutdown docs, a second
        // shutdown waits for the peer's close_notify. TODO confirm this cannot stall
        // closing.
        if let Err(tls_err) = this.run_in_context(ctx, |tls| tls.shutdown()) {
            let code = tls_err.code();
            let not_ready = code == ErrorCode::WANT_READ || code == ErrorCode::WANT_WRITE;
            if not_ready {
                return Poll::Pending;
            }
            if code != ErrorCode::ZERO_RETURN {
                let io_err = tls_err
                    .into_io_error()
                    .unwrap_or_else(|other| io::Error::new(io::ErrorKind::Other, other));
                return Poll::Ready(Err(io_err));
            }
        }

        // Step 2: close the underlying transport.
        Pin::new(this.get_mut()).poll_close(ctx)
    }
}
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
impl<S> HandshakeError<S> {
pub fn ssl(&self) -> Option<&SslRef> {
match &self.0 {
ssl::HandshakeError::Failure(s) => Some(s.ssl()),
_ => None,
}
}
pub fn into_source_stream(self) -> Option<S> {
match self.0 {
ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
_ => None,
}
}
pub fn as_source_stream(&self) -> Option<&S> {
match &self.0 {
ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
_ => None,
}
}
pub fn code(&self) -> Option<ErrorCode> {
match &self.0 {
ssl::HandshakeError::Failure(s) => Some(s.error().code()),
_ => None,
}
}
pub fn as_io_error(&self) -> Option<&io::Error> {
match &self.0 {
ssl::HandshakeError::Failure(s) => s.error().io_error(),
_ => None,
}
}
}
impl<S> fmt::Debug for HandshakeError<S>
where
S: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl<S> fmt::Display for HandshakeError<S> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, fmt)
}
}
impl<S> std::error::Error for HandshakeError<S>
where
S: fmt::Debug,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.0.source()
}
}
/// Future that drives an in-progress TLS handshake to completion.
///
/// The inner `Option` is `None` once the future has produced its output.
/// Polling it again after that panics with "future polled after completion".
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);

impl<S> Future for HandshakeFuture<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<SslStream<S>, HandshakeError<S>>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut mid_handshake = self.0.take().expect("future polled after completion");

        // Install the current task's context in the I/O bridge, and a clone of its
        // waker on the session, for this handshake step only. Every branch below
        // removes both again.
        mid_handshake.get_mut().set_waker(Some(ctx));
        mid_handshake
            .ssl_mut()
            .set_task_waker(Some(ctx.waker().clone()));

        match mid_handshake.handshake() {
            Ok(mut stream) => {
                stream.get_mut().set_waker(None);
                stream.ssl_mut().set_task_waker(None);
                Poll::Ready(Ok(SslStream(stream)))
            }
            Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
                mid_handshake.get_mut().set_waker(None);
                mid_handshake.ssl_mut().set_task_waker(None);
                // Keep the stream so the next poll resumes the handshake.
                self.0 = Some(mid_handshake);
                Poll::Pending
            }
            Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
                mid_handshake.get_mut().set_waker(None);
                // The session is handed back to the caller inside the error. Clear the
                // task waker as the other branches do, so the error does not keep this
                // task's waker clone alive.
                mid_handshake.ssl_mut().set_task_waker(None);
                Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
                    mid_handshake,
                ))))
            }
            Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
                Poll::Ready(Err(HandshakeError(err)))
            }
        }
    }
}
/// Accepts TLS connections from a stream of incoming transport connections.
pub struct SslListener<S> {
    incoming: S,
    acceptor: SslAcceptor,
}

impl<S> SslListener<S> {
    /// Creates a listener that runs a server-side TLS handshake with `acceptor`
    /// on every connection yielded by `incoming`.
    pub fn on(incoming: S, acceptor: SslAcceptor) -> Self {
        Self { incoming, acceptor }
    }

    /// Waits for the next incoming connection and completes its TLS handshake.
    ///
    /// The handshake is awaited inline, so the next connection is not taken from
    /// `incoming` until this call returns.
    ///
    /// # Errors
    ///
    /// * The inner stream yields an error: returned as `io::ErrorKind::Other`
    ///   carrying that error's message.
    /// * The TLS handshake fails: logged with `log::error!`, then returned as
    ///   `io::ErrorKind::Other`.
    /// * The inner stream has ended: returned as `io::ErrorKind::BrokenPipe`.
    pub async fn accept<I, E>(&mut self) -> std::io::Result<SslStream<I>>
    where
        S: Stream<Item = Result<I, E>> + Unpin,
        I: AsyncRead + AsyncWrite + Unpin,
        E: Error,
    {
        match self.next_connection::<I, E>().await? {
            Some(stream) => Ok(stream),
            None => Err(Self::inner_stream_ended()),
        }
    }

    /// Turns the listener into a stream of accepted TLS connections.
    ///
    /// Errors from individual accepts (inner-stream errors, failed handshakes) are
    /// yielded as items and the stream keeps going. When the inner stream ends, a
    /// single `io::ErrorKind::BrokenPipe` error is yielded and this stream then
    /// terminates, rather than repeating that error forever.
    pub fn into_incoming<I, E>(self) -> impl Stream<Item = io::Result<SslStream<I>>> + Unpin
    where
        S: Stream<Item = Result<I, E>> + Unpin,
        I: AsyncRead + AsyncWrite + Unpin,
        E: Error,
    {
        // State is `None` once the end of the inner stream has been reported.
        Box::pin(futures::stream::unfold(Some(self), |state| async move {
            match state {
                None => None,
                Some(mut listener) => {
                    let next = listener.next_connection::<I, E>().await;
                    match next {
                        Ok(Some(stream)) => Some((Ok(stream), Some(listener))),
                        Ok(None) => Some((Err(Self::inner_stream_ended()), None)),
                        Err(err) => Some((Err(err), Some(listener))),
                    }
                }
            }
        }))
    }

    /// Takes the next connection from `incoming` and runs the TLS handshake on it.
    ///
    /// Returns `Ok(None)` when `incoming` is exhausted. Errors are converted
    /// exactly as documented on [`SslListener::accept`].
    async fn next_connection<I, E>(&mut self) -> io::Result<Option<SslStream<I>>>
    where
        S: Stream<Item = Result<I, E>> + Unpin,
        I: AsyncRead + AsyncWrite + Unpin,
        E: Error,
    {
        use futures::TryStreamExt;

        let next = self
            .incoming
            .try_next()
            .await
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;

        let stream = match next {
            Some(stream) => stream,
            None => return Ok(None),
        };

        // Errors are stringified because `E` and the handshake error are not
        // required to be `Send + Sync + 'static`.
        let tls_stream = accept(&self.acceptor, stream).await.map_err(|err| {
            let err = std::io::Error::new(std::io::ErrorKind::Other, err.to_string());
            log::error!("{}", err);
            err
        })?;

        Ok(Some(tls_stream))
    }

    /// Error reported once the inner connection stream has ended.
    fn inner_stream_ended() -> io::Error {
        std::io::Error::new(
            io::ErrorKind::BrokenPipe,
            "Ssl listener inner stream broken.",
        )
    }
}