hala_tls/
lib.rs

1use boring::ssl;
2pub use boring::ssl::*;
3
4mod bridge;
5mod callbacks;
6
7use futures::{AsyncRead, AsyncWrite};
8
9use std::fmt;
10use std::future::Future;
11use std::io::{self, Read, Write};
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use bridge::*;
16
17pub use callbacks::SslContextBuilderExt;
18
19pub use boring::ssl::{
20    AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
21    BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
22    BoxSelectCertFuture, ExDataFuture,
23};
24
25/// Asynchronously performs a client-side TLS handshake over the provided stream.
26///
27/// This function automatically sets the task waker on the `Ssl` from `config` to
28/// allow to make use of async callbacks provided by the boring crate.
29pub async fn connect<S>(
30    config: ConnectConfiguration,
31    domain: &str,
32    stream: S,
33) -> Result<SslStream<S>, HandshakeError<S>>
34where
35    S: AsyncRead + AsyncWrite + Unpin,
36{
37    let mid_handshake = config
38        .setup_connect(domain, AsyncStreamBridge::new(stream))
39        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
40
41    HandshakeFuture(Some(mid_handshake)).await
42}
43
44/// Asynchronously performs a server-side TLS handshake over the provided stream.
45///
46/// This function automatically sets the task waker on the `Ssl` from `config` to
47/// allow to make use of async callbacks provided by the boring crate.
48pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
49where
50    S: AsyncRead + AsyncWrite + Unpin,
51{
52    let mid_handshake = acceptor
53        .setup_accept(AsyncStreamBridge::new(stream))
54        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
55
56    HandshakeFuture(Some(mid_handshake)).await
57}
58
/// Converts a blocking-style I/O result into a `Poll`.
///
/// An error of kind `WouldBlock` becomes `Poll::Pending`. Every other result,
/// success or error, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        other => Poll::Ready(other),
    }
}
66
67/// A partially constructed `SslStream`, useful for unusual handshakes.
68pub struct SslStreamBuilder<S> {
69    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
70}
71
72impl<S> SslStreamBuilder<S>
73where
74    S: AsyncRead + AsyncWrite + Unpin,
75{
76    /// Begins creating an `SslStream` atop `stream`.
77    pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
78        Self {
79            inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
80        }
81    }
82
83    /// Initiates a client-side TLS handshake.
84    pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
85        let mid_handshake = self.inner.setup_accept();
86
87        HandshakeFuture(Some(mid_handshake)).await
88    }
89
90    /// Initiates a server-side TLS handshake.
91    pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
92        let mid_handshake = self.inner.setup_connect();
93
94        HandshakeFuture(Some(mid_handshake)).await
95    }
96}
97
98impl<S> SslStreamBuilder<S> {
99    /// Returns a shared reference to the `Ssl` object associated with this builder.
100    pub fn ssl(&self) -> &SslRef {
101        self.inner.ssl()
102    }
103
104    /// Returns a mutable reference to the `Ssl` object associated with this builder.
105    pub fn ssl_mut(&mut self) -> &mut SslRef {
106        self.inner.ssl_mut()
107    }
108}
109
110/// A wrapper around an underlying raw stream which implements the SSL
111/// protocol.
112///
113/// A `SslStream<S>` represents a handshake that has been completed successfully
114/// and both the server and the client are ready for receiving and sending
115/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
116/// to a `SslStream` are encrypted when passing through to `S`.
117#[derive(Debug)]
118pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
119
120impl<S> SslStream<S> {
121    /// Returns a shared reference to the `Ssl` object associated with this stream.
122    pub fn ssl(&self) -> &SslRef {
123        self.0.ssl()
124    }
125
126    /// Returns a mutable reference to the `Ssl` object associated with this stream.
127    pub fn ssl_mut(&mut self) -> &mut SslRef {
128        self.0.ssl_mut()
129    }
130
131    /// Returns a shared reference to the underlying stream.
132    pub fn get_ref(&self) -> &S {
133        &self.0.get_ref().stream
134    }
135
136    /// Returns a mutable reference to the underlying stream.
137    pub fn get_mut(&mut self) -> &mut S {
138        &mut self.0.get_mut().stream
139    }
140
141    fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
142    where
143        F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
144    {
145        self.0.get_mut().set_waker(Some(ctx));
146
147        let result = f(&mut self.0);
148
149        // NOTE(nox): This should also be executed when `f` panics,
150        // but it's not that important as boring segfaults on panics
151        // and we always set the context prior to doing anything with
152        // the inner async stream.
153        self.0.get_mut().set_waker(None);
154
155        result
156    }
157}
158
159impl<S> AsyncRead for SslStream<S>
160where
161    S: AsyncRead + AsyncWrite + Unpin,
162{
163    fn poll_read(
164        mut self: Pin<&mut Self>,
165        ctx: &mut Context<'_>,
166        buf: &mut [u8],
167    ) -> Poll<io::Result<usize>> {
168        self.run_in_context(ctx, |s| cvt(s.read(buf)))
169    }
170}
171
172impl<S> AsyncWrite for SslStream<S>
173where
174    S: AsyncRead + AsyncWrite + Unpin,
175{
176    fn poll_write(
177        mut self: Pin<&mut Self>,
178        ctx: &mut Context,
179        buf: &[u8],
180    ) -> Poll<io::Result<usize>> {
181        self.run_in_context(ctx, |s| cvt(s.write(buf)))
182    }
183
184    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
185        self.run_in_context(ctx, |s| cvt(s.flush()))
186    }
187
188    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
189        match self.run_in_context(ctx, |s| s.shutdown()) {
190            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
191            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
192            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
193                return Poll::Pending;
194            }
195            Err(e) => {
196                return Poll::Ready(Err(e
197                    .into_io_error()
198                    .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
199            }
200        }
201
202        Pin::new(&mut self.0.get_mut().stream).poll_close(ctx)
203    }
204}
205
206/// The error type returned after a failed handshake.
207pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
208
209impl<S> HandshakeError<S> {
210    /// Returns a shared reference to the `Ssl` object associated with this error.
211    pub fn ssl(&self) -> Option<&SslRef> {
212        match &self.0 {
213            ssl::HandshakeError::Failure(s) => Some(s.ssl()),
214            _ => None,
215        }
216    }
217
218    /// Converts error to the source data stream that was used for the handshake.
219    pub fn into_source_stream(self) -> Option<S> {
220        match self.0 {
221            ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
222            _ => None,
223        }
224    }
225
226    /// Returns a reference to the source data stream.
227    pub fn as_source_stream(&self) -> Option<&S> {
228        match &self.0 {
229            ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
230            _ => None,
231        }
232    }
233
234    /// Returns the error code, if any.
235    pub fn code(&self) -> Option<ErrorCode> {
236        match &self.0 {
237            ssl::HandshakeError::Failure(s) => Some(s.error().code()),
238            _ => None,
239        }
240    }
241
242    /// Returns a reference to the inner I/O error, if any.
243    pub fn as_io_error(&self) -> Option<&io::Error> {
244        match &self.0 {
245            ssl::HandshakeError::Failure(s) => s.error().io_error(),
246            _ => None,
247        }
248    }
249}
250
251impl<S> fmt::Debug for HandshakeError<S>
252where
253    S: fmt::Debug,
254{
255    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
256        fmt::Debug::fmt(&self.0, fmt)
257    }
258}
259
260impl<S> fmt::Display for HandshakeError<S> {
261    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
262        fmt::Display::fmt(&self.0, fmt)
263    }
264}
265
266impl<S> std::error::Error for HandshakeError<S>
267where
268    S: fmt::Debug,
269{
270    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
271        self.0.source()
272    }
273}
274
275/// Future for an ongoing TLS handshake.
276///
277/// See [`connect`] and [`accept`].
278pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
279
280impl<S> Future for HandshakeFuture<S>
281where
282    S: AsyncRead + AsyncWrite + Unpin,
283{
284    type Output = Result<SslStream<S>, HandshakeError<S>>;
285
286    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
287        let mut mid_handshake = self.0.take().expect("future polled after completion");
288
289        mid_handshake.get_mut().set_waker(Some(ctx));
290        mid_handshake
291            .ssl_mut()
292            .set_task_waker(Some(ctx.waker().clone()));
293
294        match mid_handshake.handshake() {
295            Ok(mut stream) => {
296                stream.get_mut().set_waker(None);
297                stream.ssl_mut().set_task_waker(None);
298
299                Poll::Ready(Ok(SslStream(stream)))
300            }
301            Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
302                mid_handshake.get_mut().set_waker(None);
303                mid_handshake.ssl_mut().set_task_waker(None);
304
305                self.0 = Some(mid_handshake);
306
307                Poll::Pending
308            }
309            Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
310                mid_handshake.get_mut().set_waker(None);
311
312                Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
313                    mid_handshake,
314                ))))
315            }
316            Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
317                Poll::Ready(Err(HandshakeError(err)))
318            }
319        }
320    }
321}