async_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`futures-io`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![warn(missing_docs)]
7
8use futures_io::{AsyncRead, AsyncWrite};
9use openssl::{
10    error::ErrorStack,
11    ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
12};
13use std::{
14    fmt, future,
15    io::{self, Read, Write},
16    pin::Pin,
17    task::{Context, Poll},
18};
19
20#[cfg(test)]
21mod test;
22
/// Pairs the wrapped asynchronous stream with the address of the task `Context` that is
/// currently driving it.
///
/// The blocking `Read`/`Write` impls on this type poll `stream` with that context and report
/// `Poll::Pending` as `io::ErrorKind::WouldBlock`.
struct StreamWrapper<S> {
    /// The underlying asynchronous stream.
    stream: S,
    /// Address of a `Context`, stored as an integer; `0` means no context is installed.
    context: usize,
}
27
28impl<S> fmt::Debug for StreamWrapper<S>
29where
30    S: fmt::Debug,
31{
32    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33        self.stream.fmt(fmt)
34    }
35}
36
37impl<S> StreamWrapper<S> {
38    /// # Safety
39    ///
40    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
41    /// wrapper must be pinned in memory.
42    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
43        debug_assert_ne!(self.context, 0);
44        let stream = Pin::new_unchecked(&mut self.stream);
45        let context = &mut *(self.context as *mut _);
46        (stream, context)
47    }
48}
49
50impl<S> Read for StreamWrapper<S>
51where
52    S: AsyncRead,
53{
54    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
55        let (stream, cx) = unsafe { self.parts() };
56        match stream.poll_read(cx, buf)? {
57            Poll::Ready(nread) => Ok(nread),
58            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
59        }
60    }
61}
62
63impl<S> Write for StreamWrapper<S>
64where
65    S: AsyncWrite,
66{
67    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68        let (stream, cx) = unsafe { self.parts() };
69        match stream.poll_write(cx, buf) {
70            Poll::Ready(r) => r,
71            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
72        }
73    }
74
75    fn flush(&mut self) -> io::Result<()> {
76        let (stream, cx) = unsafe { self.parts() };
77        match stream.poll_flush(cx) {
78            Poll::Ready(r) => r,
79            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
80        }
81    }
82}
83
/// Converts a blocking-style I/O result into a `Poll`.
///
/// An error whose kind is `WouldBlock` becomes `Poll::Pending`; every other result, success or
/// failure, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
91
92fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
93    match r {
94        Ok(v) => Poll::Ready(Ok(v)),
95        Err(e) => match e.code() {
96            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
97            _ => Poll::Ready(Err(e)),
98        },
99    }
100}
101
/// An asynchronous version of [`openssl::ssl::SslStream`].
///
/// Internally this is an OpenSSL stream over a private wrapper around `S`; the wrapper carries
/// the task context supplied to each poll call so that OpenSSL's blocking-style I/O can poll
/// the underlying stream.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
105
106impl<S> SslStream<S>
107where
108    S: AsyncRead + AsyncWrite,
109{
110    /// Like [`SslStream::new`](ssl::SslStream::new).
111    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
112        ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
113    }
114
115    /// Like [`SslStream::connect`](ssl::SslStream::connect).
116    pub fn poll_connect(
117        self: Pin<&mut Self>,
118        cx: &mut Context<'_>,
119    ) -> Poll<Result<(), ssl::Error>> {
120        self.with_context(cx, |s| cvt_ossl(s.connect()))
121    }
122
123    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
124    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
125        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
126    }
127
128    /// Like [`SslStream::accept`](ssl::SslStream::accept).
129    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
130        self.with_context(cx, |s| cvt_ossl(s.accept()))
131    }
132
133    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
134    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
135        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
136    }
137
138    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
139    pub fn poll_do_handshake(
140        self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142    ) -> Poll<Result<(), ssl::Error>> {
143        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
144    }
145
146    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
147    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
148        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
149    }
150
151    /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
152    pub fn poll_peek(
153        self: Pin<&mut Self>,
154        cx: &mut Context<'_>,
155        buf: &mut [u8],
156    ) -> Poll<Result<usize, ssl::Error>> {
157        self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
158    }
159
160    /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
161    pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
162        future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
163    }
164
165    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
166    #[cfg(ossl111)]
167    pub fn poll_read_early_data(
168        self: Pin<&mut Self>,
169        cx: &mut Context<'_>,
170        buf: &mut [u8],
171    ) -> Poll<Result<usize, ssl::Error>> {
172        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
173    }
174
175    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
176    #[cfg(ossl111)]
177    pub async fn read_early_data(
178        mut self: Pin<&mut Self>,
179        buf: &mut [u8],
180    ) -> Result<usize, ssl::Error> {
181        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
182    }
183
184    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
185    #[cfg(ossl111)]
186    pub fn poll_write_early_data(
187        self: Pin<&mut Self>,
188        cx: &mut Context<'_>,
189        buf: &[u8],
190    ) -> Poll<Result<usize, ssl::Error>> {
191        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
192    }
193
194    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
195    #[cfg(ossl111)]
196    pub async fn write_early_data(
197        mut self: Pin<&mut Self>,
198        buf: &[u8],
199    ) -> Result<usize, ssl::Error> {
200        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
201    }
202}
203
204impl<S> SslStream<S> {
205    /// Returns a shared reference to the `Ssl` object associated with this stream.
206    pub fn ssl(&self) -> &SslRef {
207        self.0.ssl()
208    }
209
210    /// Returns a shared reference to the underlying stream.
211    pub fn get_ref(&self) -> &S {
212        &self.0.get_ref().stream
213    }
214
215    /// Returns a mutable reference to the underlying stream.
216    pub fn get_mut(&mut self) -> &mut S {
217        &mut self.0.get_mut().stream
218    }
219
220    /// Returns a pinned mutable reference to the underlying stream.
221    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
222        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
223    }
224
225    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
226    where
227        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
228    {
229        let this = unsafe { self.get_unchecked_mut() };
230        this.0.get_mut().context = ctx as *mut _ as usize;
231        let r = f(&mut this.0);
232        this.0.get_mut().context = 0;
233        r
234    }
235}
236
237impl<S> AsyncRead for SslStream<S>
238where
239    S: AsyncRead + AsyncWrite,
240{
241    fn poll_read(
242        self: Pin<&mut Self>,
243        ctx: &mut Context<'_>,
244        buf: &mut [u8],
245    ) -> Poll<io::Result<usize>> {
246        self.with_context(ctx, |s| cvt(s.read(buf)))
247    }
248}
249
250impl<S> AsyncWrite for SslStream<S>
251where
252    S: AsyncRead + AsyncWrite,
253{
254    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
255        self.with_context(ctx, |s| cvt(s.write(buf)))
256    }
257
258    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
259        self.with_context(ctx, |s| cvt(s.flush()))
260    }
261
262    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
263        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
264            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
265            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
266            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
267                return Poll::Pending;
268            }
269            Err(e) => {
270                return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
271            }
272        }
273
274        self.get_pin_mut().poll_close(ctx)
275    }
276}