compio_openssl/
lib.rs

//! A compio-based asynchronous wrapper around an OpenSSL stream.
//!
//! You can use [`SslStream::new`] to build a stream just like [`openssl::ssl::SslStream::new`](ssl::SslStream::new),
//! or set up a stream manually and convert it to [`SslStream`] using [`SslStream::from`].
5
6use std::io::{self, ErrorKind, Write};
7use std::result::Result;
8
9use compio::BufResult;
10use compio::buf::{IoBuf, IoBufMut};
11use compio::io::compat::SyncStream;
12use compio::io::{AsyncRead, AsyncWrite};
13use openssl::error::ErrorStack;
14use openssl::ssl::{self, ErrorCode, ShutdownResult, ShutdownState, Ssl, SslRef};
15
16#[cfg(test)]
17mod test;
18
19/// Compio asynchronous version of [`openssl:ssl::SslStream`](ssl::SslStream).
20#[derive(Debug)]
21pub struct SslStream<S> {
22    stream: ssl::SslStream<SyncStream<S>>,
23}
24
25impl<S: AsyncRead + AsyncWrite> SslStream<S> {
26    /// Create a new `SslStream`.
27    ///
28    /// Reference: [`SslStream::new`](ssl::SslStream::new)
29    pub fn new(ssl: Ssl, stream: S) -> Result<SslStream<S>, ErrorStack> {
30        let stream = ssl::SslStream::new(ssl, SyncStream::new(stream))?;
31        Ok(SslStream { stream })
32    }
33
34    /// Get a mutable reference to the underlying stream.
35    ///
36    /// # Warning
37    ///
38    /// Any read/write operation to the stream would most likely corrupt the SSL session.
39    #[inline(always)]
40    pub fn get_mut(&mut self) -> &mut S {
41        self.stream.get_mut().get_mut()
42    }
43
44    /// Returns a shared reference to the underlying stream.
45    #[inline(always)]
46    pub fn get_ref(&self) -> &S {
47        self.stream.get_ref().get_ref()
48    }
49
50    /// Returns a shared reference to the [`Ssl`] object associated with this stream.
51    #[inline(always)]
52    pub fn ssl(&self) -> &SslRef {
53        self.stream.ssl()
54    }
55
56    /// Initiates a server-side TLS handshake.
57    ///
58    /// Reference: [`SslStream::accept`](ssl::SslStream::accept)
59    pub async fn accept(&mut self) -> io::Result<()> {
60        self.ssl_async_do(|s| s.accept()).await
61    }
62
63    /// Initiates a server-side TLS handshake.
64    ///
65    /// Reference: [`SslStream::connect`](ssl::SslStream::connect)
66    pub async fn connect(&mut self) -> io::Result<()> {
67        self.ssl_async_do(|s| s.connect()).await
68    }
69
70    /// Read application data transmitted by a client before handshake completion.
71    ///
72    /// Useful for reducing latency, but vulnerable to replay attacks.
73    ///
74    /// Returns Ok(0) if all early data has been read.
75    ///
76    /// Reference: [`SslStream::read_early_data`](ssl::SslStream::read_early_data)
77    #[cfg(any(ossl111, libressl340))]
78    pub async fn read_realy_data(&mut self, buf: &mut [u8]) -> io::Result<usize> {
79        self.ssl_async_do(|s| s.read_early_data(buf)).await
80    }
81
82    /// Send data to the server without blocking on handshake completion.
83    ///
84    /// Useful for reducing latency, but vulnerable to replay attacks.
85    ///
86    /// Reference: [`SslStream::write_early_data`](ssl::SslStream::write_early_data)
87    #[cfg(any(ossl111, libressl340))]
88    pub async fn write_realy_data(&mut self, buf: &[u8]) -> io::Result<usize> {
89        self.ssl_async_do(|s| s.write_early_data(buf)).await
90    }
91
92    /// Reads data from the stream, without removing it from the queue.
93    ///
94    /// Reference: [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek)
95    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
96        self.ssl_async_do(|s| s.ssl_peek(buf)).await
97    }
98
99    /// Returns the session's shutdown state.
100    #[inline(always)]
101    pub fn get_shutdown(&mut self) -> ShutdownState {
102        self.stream.get_shutdown()
103    }
104
105    /// Sets the session's shutdown state.
106    ///
107    /// This can be used to tell OpenSSL that the session should be cached even if a full two-way shutdown was not completed.
108    #[inline(always)]
109    pub fn set_shutdown(&mut self, state: ShutdownState) {
110        self.stream.set_shutdown(state)
111    }
112
113    /// Perform a stateless server-side handshake.
114    ///
115    /// Requires that cookie generation and verification callbacks were
116    /// set on the SSL context.
117    ///
118    /// Returns `Ok(true)` if a complete ClientHello containing a valid cookie
119    /// was read, in which case the handshake should be continued via
120    /// `accept`. If a HelloRetryRequest containing a fresh cookie was
121    /// transmitted, `Ok(false)` is returned instead. If the handshake cannot
122    /// proceed at all, `Err` is returned.
123    #[inline(always)]
124    #[cfg(ossl111)]
125    pub async fn stateless(&mut self) -> Result<bool, ErrorStack> {
126        self.stream.stateless()
127    }
128
129    async fn ssl_async_do<R, F>(&mut self, mut f: F) -> io::Result<R>
130    where
131        F: FnMut(&mut ssl::SslStream<SyncStream<S>>) -> Result<R, ssl::Error>,
132    {
133        loop {
134            match f(&mut self.stream) {
135                Ok(n) => return Ok(n),
136                Err(e) => match e.code() {
137                    ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => {
138                        if self.stream.get_mut().flush_write_buf().await? == 0 {
139                            self.stream.get_mut().fill_read_buf().await?;
140                        }
141                    }
142                    _ => return Err(ssl_err_into_io(e)),
143                },
144            }
145        }
146    }
147}
148
149impl<S> From<ssl::SslStream<SyncStream<S>>> for SslStream<S> {
150    fn from(value: ssl::SslStream<SyncStream<S>>) -> Self {
151        SslStream { stream: value }
152    }
153}
154
155#[inline]
156fn ssl_err_into_io(err: openssl::ssl::Error) -> io::Error {
157    err.into_io_error().unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))
158}
159
160impl<S: AsyncRead> AsyncRead for SslStream<S> {
161    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
162        let read_buf = buf.as_mut_slice();
163        loop {
164            let ret = self.stream.ssl_read_uninit(read_buf);
165            match ret {
166                Ok(n) => {
167                    // SAFETY: the length we just read
168                    unsafe { buf.set_buf_init(n) };
169                    return BufResult(Ok(n), buf);
170                }
171                Err(e) if e.code() == ErrorCode::ZERO_RETURN => {
172                    return BufResult(Ok(0), buf);
173                }
174                Err(e) if e.code() == ErrorCode::WANT_READ => {
175                    match self.stream.get_mut().fill_read_buf().await {
176                        Ok(_) => continue,
177                        Err(e) => return BufResult(Err(e), buf),
178                    }
179                }
180                Err(e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {}
181                Err(e) => return BufResult(Err(ssl_err_into_io(e)), buf),
182            }
183        }
184    }
185
186    // OpenSSL does not support vectored reads
187}
188
189/// `AsyncRead` is needed for shutting down stream.
190impl<S: AsyncWrite + AsyncRead> AsyncWrite for SslStream<S> {
191    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
192        let slice = buf.as_slice();
193        loop {
194            let ret = self.stream.ssl_write(slice);
195            match ret {
196                Ok(n) => {
197                    let ret = self.stream.get_mut().flush_write_buf().await;
198                    return BufResult(ret.map(|_| n), buf);
199                }
200                Err(e) if e.code() == ErrorCode::WANT_WRITE => {
201                    match self.stream.get_mut().flush_write_buf().await {
202                        Ok(_) => continue,
203                        Err(e) => return BufResult(Err(e), buf),
204                    }
205                }
206                Err(e) => return BufResult(Err(ssl_err_into_io(e)), buf),
207            }
208        }
209    }
210
211    // OpenSSL does not support vectored writes
212
213    async fn flush(&mut self) -> io::Result<()> {
214        loop {
215            match self.stream.flush() {
216                Ok(_) => {
217                    self.stream.get_mut().flush_write_buf().await?;
218                    return Ok(());
219                }
220                Err(e) if e.kind() == ErrorKind::WouldBlock => {
221                    self.stream.get_mut().flush_write_buf().await?;
222                }
223                e => return e,
224            }
225        }
226    }
227
228    async fn shutdown(&mut self) -> io::Result<()> {
229        loop {
230            let ret = self.stream.shutdown();
231            match ret {
232                Ok(ShutdownResult::Sent) => {
233                    self.stream.get_mut().flush_write_buf().await?;
234                }
235                Ok(ShutdownResult::Received) => {
236                    break;
237                }
238                Err(e) if e.code() == ErrorCode::WANT_WRITE => {
239                    self.stream.get_mut().flush_write_buf().await?;
240                }
241                Err(e) if e.code() == ErrorCode::WANT_READ => {
242                    self.stream.get_mut().fill_read_buf().await?;
243                }
244                Err(e) if e.code() == ErrorCode::SYSCALL && e.io_error().is_none() => {
245                    break;
246                }
247                Err(e) => return Err(ssl_err_into_io(e)),
248            }
249        }
250        self.stream.get_mut().get_mut().shutdown().await
251    }
252}