Skip to main content

async_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`futures_io`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
7#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
8#![warn(
9    clippy::must_use_candidate,
10    clippy::unwrap_in_result,
11    clippy::panic_in_result_fn
12)]
13
14use futures_io::{AsyncRead, AsyncWrite};
15use openssl::{
16    error::ErrorStack,
17    ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
18};
19use std::{
20    fmt, future,
21    io::{self, Read, Write},
22    pin::Pin,
23    task::{Context, Poll, Waker},
24};
25
26#[cfg(test)]
27mod test;
28
/// Pairs a stream with the waker that should be used the next time it is polled.
struct StreamWrapper<S>
where
    S: Unpin,
{
    /// The wrapped stream.
    stream: S,
    /// Waker to poll `stream` with; `None` until one has been stored.
    waker: Option<Waker>,
}
33
34impl<S> fmt::Debug for StreamWrapper<S>
35where
36    S: fmt::Debug + Unpin,
37{
38    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
39        self.stream.fmt(fmt)
40    }
41}
42
43impl<S: Unpin> StreamWrapper<S> {
44    fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
45        let stream = Pin::new(&mut self.stream);
46        let context = Context::from_waker(self.waker.as_ref().unwrap_or(Waker::noop()));
47        (stream, context)
48    }
49}
50
51impl<S> Read for StreamWrapper<S>
52where
53    S: AsyncRead + Unpin,
54{
55    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56        let (stream, mut cx) = self.parts();
57        match stream.poll_read(&mut cx, buf)? {
58            Poll::Ready(nread) => Ok(nread),
59            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
60        }
61    }
62}
63
64impl<S> Write for StreamWrapper<S>
65where
66    S: AsyncWrite + Unpin,
67{
68    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
69        let (stream, mut cx) = self.parts();
70        match stream.poll_write(&mut cx, buf) {
71            Poll::Ready(r) => r,
72            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
73        }
74    }
75
76    fn flush(&mut self) -> io::Result<()> {
77        let (stream, mut cx) = self.parts();
78        match stream.poll_flush(&mut cx) {
79            Poll::Ready(r) => r,
80            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
81        }
82    }
83}
84
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An error of kind [`io::ErrorKind::WouldBlock`] becomes [`Poll::Pending`]; every other
/// result, success or failure, is returned as [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
92
93fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
94    match r {
95        Ok(v) => Poll::Ready(Ok(v)),
96        Err(e) => match e.code() {
97            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
98            _ => Poll::Ready(Err(e)),
99        },
100    }
101}
102
103/// An asynchronous version of [`openssl::ssl::SslStream`].
104#[derive(Debug)]
105pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
106
107impl<S> SslStream<S>
108where
109    S: AsyncRead + AsyncWrite + Unpin,
110{
111    /// Like [`SslStream::new`](ssl::SslStream::new).
112    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
113        ssl::SslStream::new(
114            ssl,
115            StreamWrapper {
116                stream,
117                waker: None,
118            },
119        )
120        .map(SslStream)
121    }
122
123    /// Like [`SslStream::connect`](ssl::SslStream::connect).
124    pub fn poll_connect(
125        self: Pin<&mut Self>,
126        cx: &mut Context<'_>,
127    ) -> Poll<Result<(), ssl::Error>> {
128        self.with_context(cx, |s| cvt_ossl(s.connect()))
129    }
130
131    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
132    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
133        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
134    }
135
136    /// Like [`SslStream::accept`](ssl::SslStream::accept).
137    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
138        self.with_context(cx, |s| cvt_ossl(s.accept()))
139    }
140
141    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
142    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
143        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
144    }
145
146    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
147    pub fn poll_do_handshake(
148        self: Pin<&mut Self>,
149        cx: &mut Context<'_>,
150    ) -> Poll<Result<(), ssl::Error>> {
151        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
152    }
153
154    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
155    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
156        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
157    }
158
159    /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
160    pub fn poll_peek(
161        self: Pin<&mut Self>,
162        cx: &mut Context<'_>,
163        buf: &mut [u8],
164    ) -> Poll<Result<usize, ssl::Error>> {
165        self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
166    }
167
168    /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
169    pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
170        future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
171    }
172
173    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
174    #[cfg(ossl111)]
175    pub fn poll_read_early_data(
176        self: Pin<&mut Self>,
177        cx: &mut Context<'_>,
178        buf: &mut [u8],
179    ) -> Poll<Result<usize, ssl::Error>> {
180        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
181    }
182
183    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
184    #[cfg(ossl111)]
185    pub async fn read_early_data(
186        mut self: Pin<&mut Self>,
187        buf: &mut [u8],
188    ) -> Result<usize, ssl::Error> {
189        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
190    }
191
192    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
193    #[cfg(ossl111)]
194    pub fn poll_write_early_data(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        buf: &[u8],
198    ) -> Poll<Result<usize, ssl::Error>> {
199        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
200    }
201
202    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
203    #[cfg(ossl111)]
204    pub async fn write_early_data(
205        mut self: Pin<&mut Self>,
206        buf: &[u8],
207    ) -> Result<usize, ssl::Error> {
208        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
209    }
210}
211
212impl<S: Unpin> SslStream<S> {
213    /// Returns a shared reference to the `Ssl` object associated with this stream.
214    pub fn ssl(&self) -> &SslRef {
215        self.0.ssl()
216    }
217
218    /// Returns a shared reference to the underlying stream.
219    pub fn get_ref(&self) -> &S {
220        &self.0.get_ref().stream
221    }
222
223    /// Returns a mutable reference to the underlying stream.
224    pub fn get_mut(&mut self) -> &mut S {
225        &mut self.0.get_mut().stream
226    }
227
228    /// Returns a pinned mutable reference to the underlying stream.
229    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
230        Pin::new(&mut self.get_mut().0.get_mut().stream)
231    }
232
233    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
234    where
235        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
236    {
237        let this = self.get_mut();
238        this.0.get_mut().waker = Some(ctx.waker().clone());
239        f(&mut this.0)
240    }
241}
242
243impl<S> AsyncRead for SslStream<S>
244where
245    S: AsyncRead + AsyncWrite + Unpin,
246{
247    fn poll_read(
248        self: Pin<&mut Self>,
249        ctx: &mut Context<'_>,
250        buf: &mut [u8],
251    ) -> Poll<io::Result<usize>> {
252        self.with_context(ctx, |s| cvt(s.read(buf)))
253    }
254}
255
256impl<S> AsyncWrite for SslStream<S>
257where
258    S: AsyncRead + AsyncWrite + Unpin,
259{
260    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
261        self.with_context(ctx, |s| cvt(s.write(buf)))
262    }
263
264    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
265        self.with_context(ctx, |s| cvt(s.flush()))
266    }
267
268    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
269        // We send close_notify but do not wait for the peer's reply before closing the
270        // underlying stream. This is permitted by RFC 8446 ยง6.1 and avoids a half-close
271        // deadlock, but it means any in-flight data from the peer is silently discarded.
272        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
273            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
274            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
275            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
276                return Poll::Pending;
277            }
278            Err(e) => {
279                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
280            }
281        }
282
283        self.get_pin_mut().poll_close(ctx)
284    }
285}