async_std_openssl/
lib.rs

1//! Async TLS streams backed by OpenSSL.
2//!
3//! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`async_std`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5//! blocking [`Read`] and [`Write`] traits.
6#![warn(missing_docs)]
7
8use async_std::io::{Read as AsyncRead, Write as AsyncWrite};
9use futures_util::future;
10use openssl::error::ErrorStack;
11use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef};
12use std::fmt;
13use std::io::{self, Read, Write};
14use std::pin::Pin;
15// use std::slice;
16use std::task::{Context, Poll};
17
18pub mod tls_stream_wrapper;
19pub use crate::tls_stream_wrapper::SslStreamWrapper;
20
21#[cfg(test)]
22mod test;
23
/// Adapter that exposes an async stream through std's blocking `Read`/`Write` traits so that
/// OpenSSL can drive it; the task context for the current poll is smuggled in via `context`.
struct StreamWrapper<S> {
    /// The wrapped async transport.
    stream: S,
    /// Address of the `Context` installed for the poll currently in progress, or 0 when none is.
    /// NOTE(review): stored as `usize` rather than a raw pointer — presumably so the wrapper
    /// stays `Send`/`Sync`; confirm before changing the type.
    context: usize,
}
28
29impl<S> fmt::Debug for StreamWrapper<S>
30where
31    S: fmt::Debug,
32{
33    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
34        fmt::Debug::fmt(&self.stream, fmt)
35    }
36}
37
38impl<S> StreamWrapper<S> {
39    /// # Safety
40    ///
41    /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
42    /// wrapper must be pinned in memory.
43    unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
44        debug_assert_ne!(self.context, 0);
45        let stream = Pin::new_unchecked(&mut self.stream);
46        let context = &mut *(self.context as *mut _);
47        (stream, context)
48    }
49}
50
51impl<S> Read for StreamWrapper<S>
52where
53    S: AsyncRead,
54{
55    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56        let (stream, cx) = unsafe { self.parts() };
57        // let mut buf = ReadBuf::new(buf);
58        match stream.poll_read(cx, buf)? {
59            Poll::Ready(num_bytes_read) => Ok(num_bytes_read),
60            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
61        }
62    }
63}
64
65impl<S> Write for StreamWrapper<S>
66where
67    S: AsyncWrite,
68{
69    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
70        let (stream, cx) = unsafe { self.parts() };
71        match stream.poll_write(cx, buf) {
72            Poll::Ready(r) => r,
73            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
74        }
75    }
76
77    fn flush(&mut self) -> io::Result<()> {
78        let (stream, cx) = unsafe { self.parts() };
79        match stream.poll_flush(cx) {
80            Poll::Ready(r) => r,
81            Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
82        }
83    }
84}
85
/// Converts a blocking-style I/O result into a `Poll`.
///
/// A `WouldBlock` error becomes `Poll::Pending`; every other outcome, success or error, is
/// returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        ready => Poll::Ready(ready),
    }
}
93
94fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
95    match r {
96        Ok(v) => Poll::Ready(Ok(v)),
97        Err(e) => match e.code() {
98            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
99            _ => Poll::Ready(Err(e)),
100        },
101    }
102}
103
104/// An asynchronous version of [`openssl::ssl::SslStream`].
105#[derive(Debug)]
106pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
107
108impl<S> SslStream<S>
109where
110    S: AsyncRead + AsyncWrite,
111{
112    /// Like [`SslStream::new`](ssl::SslStream::new).
113    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
114        ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
115    }
116
117    /// Like [`SslStream::connect`](ssl::SslStream::connect).
118    pub fn poll_connect(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121    ) -> Poll<Result<(), ssl::Error>> {
122        self.with_context(cx, |s| cvt_ossl(s.connect()))
123    }
124
125    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
126    pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
127        future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
128    }
129
130    /// Like [`SslStream::accept`](ssl::SslStream::accept).
131    pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
132        self.with_context(cx, |s| cvt_ossl(s.accept()))
133    }
134
135    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
136    pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
137        future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
138    }
139
140    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
141    pub fn poll_do_handshake(
142        self: Pin<&mut Self>,
143        cx: &mut Context<'_>,
144    ) -> Poll<Result<(), ssl::Error>> {
145        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
146    }
147
148    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
149    pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
150        future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
151    }
152
153    /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
154    #[cfg(ossl111)]
155    pub fn poll_read_early_data(
156        self: Pin<&mut Self>,
157        cx: &mut Context<'_>,
158        buf: &mut [u8],
159    ) -> Poll<Result<usize, ssl::Error>> {
160        self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
161    }
162
163    /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
164    #[cfg(ossl111)]
165    pub async fn read_early_data(
166        mut self: Pin<&mut Self>,
167        buf: &mut [u8],
168    ) -> Result<usize, ssl::Error> {
169        future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
170    }
171
172    /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
173    #[cfg(ossl111)]
174    pub fn poll_write_early_data(
175        self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<Result<usize, ssl::Error>> {
179        self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
180    }
181
182    /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
183    #[cfg(ossl111)]
184    pub async fn write_early_data(
185        mut self: Pin<&mut Self>,
186        buf: &[u8],
187    ) -> Result<usize, ssl::Error> {
188        future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
189    }
190}
191
192impl<S> SslStream<S> {
193    /// Returns a shared reference to the `Ssl` object associated with this stream.
194    pub fn ssl(&self) -> &SslRef {
195        self.0.ssl()
196    }
197
198    /// Returns a shared reference to the underlying stream.
199    pub fn get_ref(&self) -> &S {
200        &self.0.get_ref().stream
201    }
202
203    /// Returns a mutable reference to the underlying stream.
204    pub fn get_mut(&mut self) -> &mut S {
205        &mut self.0.get_mut().stream
206    }
207
208    /// Returns a pinned mutable reference to the underlying stream.
209    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
210        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
211    }
212
213    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
214    where
215        F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
216    {
217        let this = unsafe { self.get_unchecked_mut() };
218        this.0.get_mut().context = ctx as *mut _ as usize;
219        let r = f(&mut this.0);
220        this.0.get_mut().context = 0;
221        r
222    }
223}
224
225impl<S> AsyncRead for SslStream<S>
226where
227    S: AsyncRead + AsyncWrite,
228{
229    fn poll_read(
230        self: Pin<&mut Self>,
231        ctx: &mut Context<'_>,
232        buf: &mut [u8],
233    ) -> Poll<io::Result<usize>> {
234        self.with_context(ctx, |s| {
235            // This isn't really "proper", but rust-openssl doesn't currently expose a suitable interface even though
236            // OpenSSL itself doesn't require the buffer to be initialized. So this is good enough for now.
237            // let slice = unsafe {
238            // let buf = buf.unfilled_mut();
239            // slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), buf.len())
240            // };
241
242            // Read into slice from OpenSSL
243            match cvt(s.read(buf))? {
244                Poll::Ready(nread) => {
245                    // unsafe {
246                    //     buf.assume_init(nread);
247                    // }
248                    // buf.advance(nread);
249                    Poll::Ready(Ok(nread))
250                }
251                Poll::Pending => Poll::Pending,
252            }
253        })
254    }
255}
256
257impl<S> AsyncWrite for SslStream<S>
258where
259    S: AsyncRead + AsyncWrite,
260{
261    fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
262        self.with_context(ctx, |s| cvt(s.write(buf)))
263    }
264
265    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
266        self.with_context(ctx, |s| cvt(s.flush()))
267    }
268
269    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
270        match self.as_mut().with_context(ctx, |s| s.shutdown()) {
271            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
272            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
273            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
274                return Poll::Pending;
275            }
276            Err(e) => {
277                return Poll::Ready(Err(e
278                    .into_io_error()
279                    .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
280            }
281        }
282
283        self.get_pin_mut().poll_close(ctx)
284    }
285}