Skip to main content

compio_btls/
lib.rs

//! Async TLS streams backed by BoringSSL.
//!
//! This crate provides a wrapper around the [`btls`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`compio`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
//! blocking [`std::io::Read`] and [`std::io::Write`] traits.
6
7use btls::{
8    error::ErrorStack,
9    ssl::{self, ErrorCode, Ssl, SslRef, SslStream as SslStreamCore},
10};
11use compio::buf::{IoBuf, IoBufMut};
12use compio::BufResult;
13use compio_io::{compat::SyncStream, AsyncRead, AsyncWrite};
14use std::error::Error;
15use std::pin::Pin;
16use std::task::Context;
17use std::task::Poll;
18use std::{fmt, io};
19
20fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
21    match r {
22        Ok(v) => Poll::Ready(Ok(v)),
23        Err(e) => match e.code() {
24            ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
25            _ => Poll::Ready(Err(e)),
26        },
27    }
28}
29
30/// An asynchronous version of [`btls::ssl::SslStream`].
31#[derive(Debug)]
32pub struct SslStream<S>(SslStreamCore<SyncStream<S>>);
33
34impl<S: AsyncRead + AsyncWrite> SslStream<S> {
35    #[inline]
36    /// Like [`SslStream::new`](ssl::SslStream::new).
37    pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
38        SslStreamCore::new(ssl, SyncStream::new(stream)).map(SslStream)
39    }
40
41    #[inline]
42    /// Like [`SslStream::connect`](ssl::SslStream::connect).
43    pub fn poll_connect(
44        self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46    ) -> Poll<Result<(), HandshakeError>> {
47        self.with_context(cx, |s| cvt_ossl(s.connect()))
48            .map_err(HandshakeError::Ssl)
49    }
50
51    #[inline]
52    /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
53    pub async fn connect(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
54        self.drive_handshake(|s| s.connect()).await
55    }
56
57    #[inline]
58    /// Like [`SslStream::accept`](ssl::SslStream::accept).
59    pub fn poll_accept(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62    ) -> Poll<Result<(), HandshakeError>> {
63        self.with_context(cx, |s| cvt_ossl(s.accept()))
64            .map_err(HandshakeError::Ssl)
65    }
66
67    #[inline]
68    /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
69    pub async fn accept(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
70        self.drive_handshake(|s| s.accept()).await
71    }
72
73    #[inline]
74    /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
75    pub fn poll_do_handshake(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78    ) -> Poll<Result<(), HandshakeError>> {
79        self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
80            .map_err(HandshakeError::Ssl)
81    }
82
83    #[inline]
84    /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
85    pub async fn do_handshake(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
86        self.drive_handshake(|s| s.do_handshake()).await
87    }
88
89    async fn drive_handshake<F>(mut self: Pin<&mut Self>, mut f: F) -> Result<(), HandshakeError>
90    where
91        F: FnMut(&mut SslStreamCore<SyncStream<S>>) -> Result<(), ssl::Error>,
92    {
93        loop {
94            let res = {
95                let this = unsafe { self.as_mut().get_unchecked_mut() };
96                f(&mut this.0)
97            };
98
99            match res {
100                Ok(()) => {
101                    // Ensure handshake records are pushed out before returning.
102                    self.as_mut()
103                        .flush_write_buf()
104                        .await
105                        .map_err(HandshakeError::Io)?;
106
107                    return Ok(());
108                }
109                Err(e) => match e.code() {
110                    ErrorCode::WANT_WRITE => {
111                        self.as_mut()
112                            .flush_write_buf()
113                            .await
114                            .map_err(HandshakeError::Io)?;
115                    }
116                    ErrorCode::WANT_READ => {
117                        self.as_mut()
118                            .flush_write_buf()
119                            .await
120                            .map_err(HandshakeError::Io)?;
121
122                        self.as_mut()
123                            .fill_read_buf()
124                            .await
125                            .map_err(HandshakeError::Io)?;
126                    }
127                    _ => return Err(HandshakeError::Ssl(e)),
128                },
129            }
130        }
131    }
132}
133
134impl<S: AsyncRead + AsyncWrite> SslStream<S> {
135    async fn fill_read_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
136        let this = unsafe { self.as_mut().get_unchecked_mut() };
137        this.0.get_mut().fill_read_buf().await
138    }
139
140    async fn flush_write_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
141        let this = unsafe { self.as_mut().get_unchecked_mut() };
142        this.0.get_mut().flush_write_buf().await
143    }
144}
145
146impl<S> SslStream<S> {
147    #[inline]
148    /// Returns a shared reference to the `Ssl` object associated with this stream.
149    pub fn ssl(&self) -> &SslRef {
150        self.0.ssl()
151    }
152
153    #[inline]
154    /// Returns a shared reference to the underlying stream.
155    pub fn get_ref(&self) -> &S {
156        self.0.get_ref().get_ref()
157    }
158
159    #[inline]
160    /// Returns a mutable reference to the underlying stream.
161    pub fn get_mut(&mut self) -> &mut S {
162        self.0.get_mut().get_mut()
163    }
164
165    #[inline]
166    /// Returns a pinned mutable reference to the underlying stream.
167    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
168        unsafe {
169            let this = self.get_unchecked_mut();
170            Pin::new_unchecked(this.0.get_mut().get_mut())
171        }
172    }
173
174    fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
175    where
176        F: FnOnce(&mut SslStreamCore<SyncStream<S>>) -> R,
177    {
178        let this = unsafe { self.get_unchecked_mut() };
179        this.0.ssl_mut().set_task_waker(Some(ctx.waker().clone()));
180        let r = f(&mut this.0);
181        this.0.ssl_mut().set_task_waker(None);
182        r
183    }
184}
185
186impl<S> AsyncRead for SslStream<S>
187where
188    S: AsyncRead + AsyncWrite,
189{
190    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
191        let slice = buf.as_uninit();
192        loop {
193            // SAFETY: read_uninit does not de-initialize the buffer.
194            match self.0.read_uninit(slice) {
195                Ok(res) => {
196                    // SAFETY: read_uninit guarantees that nread bytes have been initialized.
197                    unsafe { buf.advance_to(res) };
198                    return BufResult(Ok(res), buf);
199                }
200                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
201                    match self.0.get_mut().fill_read_buf().await {
202                        Ok(_) => continue,
203                        Err(e) => return BufResult(Err(e), buf),
204                    }
205                }
206                res => return BufResult(res, buf),
207            }
208        }
209    }
210}
211
212impl<S> AsyncWrite for SslStream<S>
213where
214    S: AsyncRead + AsyncWrite,
215{
216    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
217        let slice = buf.as_init();
218        loop {
219            let res = io::Write::write(&mut self.0, slice);
220            match res {
221                Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
222                    Ok(_) => continue,
223                    Err(e) => return BufResult(Err(e), buf),
224                },
225                _ => return BufResult(res, buf),
226            }
227        }
228    }
229
230    async fn flush(&mut self) -> io::Result<()> {
231        loop {
232            match io::Write::flush(&mut self.0) {
233                Ok(()) => break,
234                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
235                    self.0.get_mut().flush_write_buf().await?;
236                }
237                Err(e) => return Err(e),
238            }
239        }
240        self.0.get_mut().flush_write_buf().await?;
241        Ok(())
242    }
243
244    async fn shutdown(&mut self) -> io::Result<()> {
245        self.flush().await?;
246        self.0.get_mut().get_mut().shutdown().await
247    }
248}
249
250/// The error type returned after a failed handshake.
251pub enum HandshakeError {
252    /// An error that occurred during the SSL handshake.
253    Ssl(ssl::Error),
254    /// An I/O error that occurred during the handshake.
255    Io(io::Error),
256}
257
258impl HandshakeError {
259    /// Returns the error code, if any.
260    #[must_use]
261    pub fn code(&self) -> Option<ErrorCode> {
262        match self {
263            HandshakeError::Ssl(e) => Some(e.code()),
264            _ => None,
265        }
266    }
267
268    /// Returns a reference to the inner I/O error, if any.
269    #[must_use]
270    pub fn as_io_error(&self) -> Option<&io::Error> {
271        match self {
272            HandshakeError::Ssl(e) => e.io_error(),
273            HandshakeError::Io(e) => Some(e),
274        }
275    }
276}
277
278impl fmt::Debug for HandshakeError {
279    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
280        match self {
281            HandshakeError::Ssl(e) => fmt::Debug::fmt(e, fmt),
282            HandshakeError::Io(e) => fmt::Debug::fmt(e, fmt),
283        }
284    }
285}
286
287impl fmt::Display for HandshakeError {
288    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
289        match self {
290            HandshakeError::Ssl(e) => fmt::Display::fmt(e, fmt),
291            HandshakeError::Io(e) => fmt::Display::fmt(e, fmt),
292        }
293    }
294}
295
296impl Error for HandshakeError {
297    fn source(&self) -> Option<&(dyn Error + 'static)> {
298        match self {
299            HandshakeError::Ssl(e) => e.source(),
300            HandshakeError::Io(e) => Some(e),
301        }
302    }
303}