Skip to main content

async_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`futures_io`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
7#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
8#![warn(clippy::must_use_candidate, clippy::unwrap_in_result, clippy::panic_in_result_fn)]
9
10use futures_io::{AsyncRead, AsyncWrite};
11use openssl::{
12    error::ErrorStack,
13    ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
14};
15use std::{
16    fmt, future,
17    io::{self, Read, Write},
18    pin::Pin,
19    task::{Context, Poll, Waker},
20};
21
22#[cfg(test)]
23mod test;
24
/// Adapts an async stream to std's blocking [`Read`]/[`Write`] traits so OpenSSL can drive it.
///
/// Every blocking call polls `stream` exactly once using `waker`; a `Poll::Pending` from the
/// stream is reported to OpenSSL as [`io::ErrorKind::WouldBlock`].
struct StreamWrapper<S>
where
    S: Unpin,
{
    /// Waker handed to `stream` whenever it is polled; replaced before each OpenSSL call.
    waker: Waker,
    /// The wrapped async stream.
    stream: S,
}
29
30impl<S> fmt::Debug for StreamWrapper<S>
31where
32    S: fmt::Debug + Unpin,
33{
34    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
35        self.stream.fmt(fmt)
36    }
37}
38
39impl<S: Unpin> StreamWrapper<S> {
40    fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
41        let stream = Pin::new(&mut self.stream);
42        let context = Context::from_waker(&self.waker);
43        (stream, context)
44    }
45}
46
47impl<S> Read for StreamWrapper<S>
48where
49    S: AsyncRead + Unpin,
50{
51    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52        let (stream, mut cx) = self.parts();
53        match stream.poll_read(&mut cx, buf)? {
54            Poll::Ready(nread) => Ok(nread),
55            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
56        }
57    }
58}
59
60impl<S> Write for StreamWrapper<S>
61where
62    S: AsyncWrite + Unpin,
63{
64    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
65        let (stream, mut cx) = self.parts();
66        match stream.poll_write(&mut cx, buf) {
67            Poll::Ready(r) => r,
68            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
69        }
70    }
71
72    fn flush(&mut self) -> io::Result<()> {
73        let (stream, mut cx) = self.parts();
74        match stream.poll_flush(&mut cx) {
75            Poll::Ready(r) => r,
76            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
77        }
78    }
79}
80
/// Converts a blocking-style I/O result into a poll result: `WouldBlock` means the
/// underlying stream is not ready (`Poll::Pending`); anything else is final.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
88
89fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
90    match r {
91        Ok(v) => Poll::Ready(Ok(v)),
92        Err(e) => match e.code() {
93            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
94            _ => Poll::Ready(Err(e)),
95        },
96    }
97}
98
99/// An asynchronous version of [`openssl::ssl::SslStream`].
100#[derive(Debug)]
101pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
102
103impl<S> SslStream<S>
104where
105    S: AsyncRead + AsyncWrite + Unpin,
106{
107    /// Like [`SslStream::new`](ssl::SslStream::new).
108    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
109        ssl::SslStream::new(
110            ssl,
111            StreamWrapper {
112                stream,
113                waker: Waker::noop().clone(),
114            },
115        )
116        .map(SslStream)
117    }
118
119    /// Like [`SslStream::connect`](ssl::SslStream::connect).
120    pub fn poll_connect(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123    ) -> Poll<Result<(), ssl::Error>> {
124        self.with_context(cx, |s| cvt_ossl(s.connect()))
125    }
126
127    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
128    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
129        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
130    }
131
132    /// Like [`SslStream::accept`](ssl::SslStream::accept).
133    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
134        self.with_context(cx, |s| cvt_ossl(s.accept()))
135    }
136
137    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
138    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
139        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
140    }
141
142    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
143    pub fn poll_do_handshake(
144        self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146    ) -> Poll<Result<(), ssl::Error>> {
147        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
148    }
149
150    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
151    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
152        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
153    }
154
155    /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
156    pub fn poll_peek(
157        self: Pin<&mut Self>,
158        cx: &mut Context<'_>,
159        buf: &mut [u8],
160    ) -> Poll<Result<usize, ssl::Error>> {
161        self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
162    }
163
164    /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
165    pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
166        future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
167    }
168
169    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
170    #[cfg(ossl111)]
171    pub fn poll_read_early_data(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &mut [u8],
175    ) -> Poll<Result<usize, ssl::Error>> {
176        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
177    }
178
179    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
180    #[cfg(ossl111)]
181    pub async fn read_early_data(
182        mut self: Pin<&mut Self>,
183        buf: &mut [u8],
184    ) -> Result<usize, ssl::Error> {
185        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
186    }
187
188    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
189    #[cfg(ossl111)]
190    pub fn poll_write_early_data(
191        self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193        buf: &[u8],
194    ) -> Poll<Result<usize, ssl::Error>> {
195        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
196    }
197
198    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
199    #[cfg(ossl111)]
200    pub async fn write_early_data(
201        mut self: Pin<&mut Self>,
202        buf: &[u8],
203    ) -> Result<usize, ssl::Error> {
204        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
205    }
206}
207
208impl<S: Unpin> SslStream<S> {
209    /// Returns a shared reference to the `Ssl` object associated with this stream.
210    pub fn ssl(&self) -> &SslRef {
211        self.0.ssl()
212    }
213
214    /// Returns a shared reference to the underlying stream.
215    pub fn get_ref(&self) -> &S {
216        &self.0.get_ref().stream
217    }
218
219    /// Returns a mutable reference to the underlying stream.
220    pub fn get_mut(&mut self) -> &mut S {
221        &mut self.0.get_mut().stream
222    }
223
224    /// Returns a pinned mutable reference to the underlying stream.
225    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
226        Pin::new(&mut self.get_mut().0.get_mut().stream)
227    }
228
229    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
230    where
231        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
232    {
233        let this = self.get_mut();
234        this.0.get_mut().waker = ctx.waker().clone();
235        f(&mut this.0)
236    }
237}
238
239impl<S> AsyncRead for SslStream<S>
240where
241    S: AsyncRead + AsyncWrite + Unpin,
242{
243    fn poll_read(
244        self: Pin<&mut Self>,
245        ctx: &mut Context<'_>,
246        buf: &mut [u8],
247    ) -> Poll<io::Result<usize>> {
248        self.with_context(ctx, |s| cvt(s.read(buf)))
249    }
250}
251
252impl<S> AsyncWrite for SslStream<S>
253where
254    S: AsyncRead + AsyncWrite + Unpin,
255{
256    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
257        self.with_context(ctx, |s| cvt(s.write(buf)))
258    }
259
260    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
261        self.with_context(ctx, |s| cvt(s.flush()))
262    }
263
264    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
265        // We send close_notify but do not wait for the peer's reply before closing the
266        // underlying stream. This is permitted by RFC 8446 ยง6.1 and avoids a half-close
267        // deadlock, but it means any in-flight data from the peer is silently discarded.
268        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
269            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
270            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
271            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
272                return Poll::Pending;
273            }
274            Err(e) => {
275                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
276            }
277        }
278
279        self.get_pin_mut().poll_close(ctx)
280    }
281}