async_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`futures-io`](futures_io)'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![warn(missing_docs)]
7
8use futures_io::{AsyncRead, AsyncWrite};
9use openssl::{
10    error::ErrorStack,
11    ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
12};
13use std::{
14    fmt, future,
15    io::{self, Read, Write},
16    pin::Pin,
17    task::{Context, Poll, Waker},
18};
19
20#[cfg(test)]
21mod test;
22
/// Adapts an async stream `S` to std's blocking [`Read`]/[`Write`] traits so that
/// OpenSSL can perform I/O on it synchronously.
///
/// `waker` holds the waker of the most recent poll of the outer [`SslStream`]
/// (a no-op waker until the first poll). It is used to build the [`Context`]
/// handed to the inner stream's `poll_*` methods.
struct StreamWrapper<S>
where
    S: Unpin,
{
    stream: S,
    waker: Waker,
}
27
28impl<S> fmt::Debug for StreamWrapper<S>
29where
30    S: fmt::Debug + Unpin,
31{
32    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33        self.stream.fmt(fmt)
34    }
35}
36
37impl<S: Unpin> StreamWrapper<S> {
38    fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
39        let stream = Pin::new(&mut self.stream);
40        let context = Context::from_waker(&self.waker);
41        (stream, context)
42    }
43}
44
45impl<S> Read for StreamWrapper<S>
46where
47    S: AsyncRead + Unpin,
48{
49    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
50        let (stream, mut cx) = self.parts();
51        match stream.poll_read(&mut cx, buf)? {
52            Poll::Ready(nread) => Ok(nread),
53            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
54        }
55    }
56}
57
58impl<S> Write for StreamWrapper<S>
59where
60    S: AsyncWrite + Unpin,
61{
62    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
63        let (stream, mut cx) = self.parts();
64        match stream.poll_write(&mut cx, buf) {
65            Poll::Ready(r) => r,
66            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
67        }
68    }
69
70    fn flush(&mut self) -> io::Result<()> {
71        let (stream, mut cx) = self.parts();
72        match stream.poll_flush(&mut cx) {
73            Poll::Ready(r) => r,
74            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
75        }
76    }
77}
78
/// Turns the result of a std I/O call on the inner OpenSSL stream back into a
/// poll result.
///
/// An [`io::ErrorKind::WouldBlock`] error becomes [`Poll::Pending`]. Any other
/// result, success or error, is returned as [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        other => Poll::Ready(other),
    }
}
86
87fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
88    match r {
89        Ok(v) => Poll::Ready(Ok(v)),
90        Err(e) => match e.code() {
91            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
92            _ => Poll::Ready(Err(e)),
93        },
94    }
95}
96
97/// An asynchronous version of [`openssl::ssl::SslStream`].
98#[derive(Debug)]
99pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
100
101impl<S> SslStream<S>
102where
103    S: AsyncRead + AsyncWrite + Unpin,
104{
105    /// Like [`SslStream::new`](ssl::SslStream::new).
106    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
107        ssl::SslStream::new(
108            ssl,
109            StreamWrapper {
110                stream,
111                waker: Waker::noop().clone(),
112            },
113        )
114        .map(SslStream)
115    }
116
117    /// Like [`SslStream::connect`](ssl::SslStream::connect).
118    pub fn poll_connect(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121    ) -> Poll<Result<(), ssl::Error>> {
122        self.with_context(cx, |s| cvt_ossl(s.connect()))
123    }
124
125    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
126    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
127        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
128    }
129
130    /// Like [`SslStream::accept`](ssl::SslStream::accept).
131    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
132        self.with_context(cx, |s| cvt_ossl(s.accept()))
133    }
134
135    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
136    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
137        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
138    }
139
140    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
141    pub fn poll_do_handshake(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144    ) -> Poll<Result<(), ssl::Error>> {
145        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
146    }
147
148    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
149    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
150        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
151    }
152
153    /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
154    pub fn poll_peek(
155        self: Pin<&mut Self>,
156        cx: &mut Context<'_>,
157        buf: &mut [u8],
158    ) -> Poll<Result<usize, ssl::Error>> {
159        self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
160    }
161
162    /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
163    pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
164        future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
165    }
166
167    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
168    #[cfg(ossl111)]
169    pub fn poll_read_early_data(
170        self: Pin<&mut Self>,
171        cx: &mut Context<'_>,
172        buf: &mut [u8],
173    ) -> Poll<Result<usize, ssl::Error>> {
174        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
175    }
176
177    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
178    #[cfg(ossl111)]
179    pub async fn read_early_data(
180        mut self: Pin<&mut Self>,
181        buf: &mut [u8],
182    ) -> Result<usize, ssl::Error> {
183        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
184    }
185
186    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
187    #[cfg(ossl111)]
188    pub fn poll_write_early_data(
189        self: Pin<&mut Self>,
190        cx: &mut Context<'_>,
191        buf: &[u8],
192    ) -> Poll<Result<usize, ssl::Error>> {
193        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
194    }
195
196    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
197    #[cfg(ossl111)]
198    pub async fn write_early_data(
199        mut self: Pin<&mut Self>,
200        buf: &[u8],
201    ) -> Result<usize, ssl::Error> {
202        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
203    }
204}
205
206impl<S: Unpin> SslStream<S> {
207    /// Returns a shared reference to the `Ssl` object associated with this stream.
208    pub fn ssl(&self) -> &SslRef {
209        self.0.ssl()
210    }
211
212    /// Returns a shared reference to the underlying stream.
213    pub fn get_ref(&self) -> &S {
214        &self.0.get_ref().stream
215    }
216
217    /// Returns a mutable reference to the underlying stream.
218    pub fn get_mut(&mut self) -> &mut S {
219        &mut self.0.get_mut().stream
220    }
221
222    /// Returns a pinned mutable reference to the underlying stream.
223    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
224        Pin::new(&mut self.get_mut().0.get_mut().stream)
225    }
226
227    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
228    where
229        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
230    {
231        let this = self.get_mut();
232        this.0.get_mut().waker = ctx.waker().clone();
233        f(&mut this.0)
234    }
235}
236
237impl<S> AsyncRead for SslStream<S>
238where
239    S: AsyncRead + AsyncWrite + Unpin,
240{
241    fn poll_read(
242        self: Pin<&mut Self>,
243        ctx: &mut Context<'_>,
244        buf: &mut [u8],
245    ) -> Poll<io::Result<usize>> {
246        self.with_context(ctx, |s| cvt(s.read(buf)))
247    }
248}
249
250impl<S> AsyncWrite for SslStream<S>
251where
252    S: AsyncRead + AsyncWrite + Unpin,
253{
254    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
255        self.with_context(ctx, |s| cvt(s.write(buf)))
256    }
257
258    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
259        self.with_context(ctx, |s| cvt(s.flush()))
260    }
261
262    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
263        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
264            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
265            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
266            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
267                return Poll::Pending;
268            }
269            Err(e) => {
270                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
271            }
272        }
273
274        self.get_pin_mut().poll_close(ctx)
275    }
276}