futures_boring/
lib.rs

1use boring::ssl::*;
2pub use boring::*;
3
4mod bridge;
5mod callbacks;
6
7use futures::{AsyncRead, AsyncWrite, Stream};
8
9use std::error::Error;
10use std::fmt;
11use std::future::Future;
12use std::io::{self, Read, Write};
13use std::pin::Pin;
14
15use std::task::{Context, Poll};
16
17use bridge::*;
18
19pub use callbacks::SslContextBuilderExt;
20
21pub use boring::ssl::{
22    AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
23    BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
24    BoxSelectCertFuture, ExDataFuture,
25};
26
27/// Asynchronously performs a client-side TLS handshake over the provided stream.
28///
29/// This function automatically sets the task waker on the `Ssl` from `config` to
30/// allow to make use of async callbacks provided by the boring crate.
31pub async fn connect<S>(
32    config: ConnectConfiguration,
33    domain: &str,
34    stream: S,
35) -> Result<SslStream<S>, HandshakeError<S>>
36where
37    S: AsyncRead + AsyncWrite + Unpin,
38{
39    let mid_handshake = config
40        .setup_connect(domain, AsyncStreamBridge::new(stream))
41        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
42
43    HandshakeFuture(Some(mid_handshake)).await
44}
45
46/// Asynchronously performs a server-side TLS handshake over the provided stream.
47///
48/// This function automatically sets the task waker on the `Ssl` from `config` to
49/// allow to make use of async callbacks provided by the boring crate.
50pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
51where
52    S: AsyncRead + AsyncWrite + Unpin,
53{
54    let mid_handshake = acceptor
55        .setup_accept(AsyncStreamBridge::new(stream))
56        .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
57
58    HandshakeFuture(Some(mid_handshake)).await
59}
60
/// Adapts a blocking-style I/O result to the `Poll` protocol.
///
/// An error of kind `WouldBlock` means the operation could not make progress
/// yet and becomes `Poll::Pending`. Every other result, success or failure, is
/// passed through unchanged as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            return Poll::Pending;
        }
    }

    Poll::Ready(r)
}
68
69/// A partially constructed `SslStream`, useful for unusual handshakes.
70pub struct SslStreamBuilder<S> {
71    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
72}
73
74impl<S> SslStreamBuilder<S>
75where
76    S: AsyncRead + AsyncWrite + Unpin,
77{
78    /// Begins creating an `SslStream` atop `stream`.
79    pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
80        Self {
81            inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
82        }
83    }
84
85    /// Initiates a client-side TLS handshake.
86    pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
87        let mid_handshake = self.inner.setup_accept();
88
89        HandshakeFuture(Some(mid_handshake)).await
90    }
91
92    /// Initiates a server-side TLS handshake.
93    pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
94        let mid_handshake = self.inner.setup_connect();
95
96        HandshakeFuture(Some(mid_handshake)).await
97    }
98}
99
100impl<S> SslStreamBuilder<S> {
101    /// Returns a shared reference to the `Ssl` object associated with this builder.
102    pub fn ssl(&self) -> &SslRef {
103        self.inner.ssl()
104    }
105
106    /// Returns a mutable reference to the `Ssl` object associated with this builder.
107    pub fn ssl_mut(&mut self) -> &mut SslRef {
108        self.inner.ssl_mut()
109    }
110}
111
112/// A wrapper around an underlying raw stream which implements the SSL
113/// protocol.
114///
115/// A `SslStream<S>` represents a handshake that has been completed successfully
116/// and both the server and the client are ready for receiving and sending
117/// data. Bytes read from a `SslStream` are decrypted from `S` and bytes written
118/// to a `SslStream` are encrypted when passing through to `S`.
119#[derive(Debug)]
120pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
121
122impl<S> SslStream<S> {
123    /// Returns a shared reference to the `Ssl` object associated with this stream.
124    pub fn ssl(&self) -> &SslRef {
125        self.0.ssl()
126    }
127
128    /// Returns a mutable reference to the `Ssl` object associated with this stream.
129    pub fn ssl_mut(&mut self) -> &mut SslRef {
130        self.0.ssl_mut()
131    }
132
133    /// Returns a shared reference to the underlying stream.
134    pub fn get_ref(&self) -> &S {
135        &self.0.get_ref().stream
136    }
137
138    /// Returns a mutable reference to the underlying stream.
139    pub fn get_mut(&mut self) -> &mut S {
140        &mut self.0.get_mut().stream
141    }
142
143    fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
144    where
145        F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
146    {
147        self.0.get_mut().set_waker(Some(ctx));
148
149        let result = f(&mut self.0);
150
151        // NOTE(nox): This should also be executed when `f` panics,
152        // but it's not that important as boring segfaults on panics
153        // and we always set the context prior to doing anything with
154        // the inner async stream.
155        self.0.get_mut().set_waker(None);
156
157        result
158    }
159}
160
161impl<S> AsyncRead for SslStream<S>
162where
163    S: AsyncRead + AsyncWrite + Unpin,
164{
165    fn poll_read(
166        mut self: Pin<&mut Self>,
167        ctx: &mut Context<'_>,
168        buf: &mut [u8],
169    ) -> Poll<io::Result<usize>> {
170        self.run_in_context(ctx, |s| cvt(s.read(buf)))
171    }
172}
173
174impl<S> AsyncWrite for SslStream<S>
175where
176    S: AsyncRead + AsyncWrite + Unpin,
177{
178    fn poll_write(
179        mut self: Pin<&mut Self>,
180        ctx: &mut Context,
181        buf: &[u8],
182    ) -> Poll<io::Result<usize>> {
183        self.run_in_context(ctx, |s| cvt(s.write(buf)))
184    }
185
186    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
187        self.run_in_context(ctx, |s| cvt(s.flush()))
188    }
189
190    fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
191        match self.run_in_context(ctx, |s| s.shutdown()) {
192            Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
193            Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
194            Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
195                return Poll::Pending;
196            }
197            Err(e) => {
198                return Poll::Ready(Err(e
199                    .into_io_error()
200                    .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
201            }
202        }
203
204        Pin::new(&mut self.0.get_mut().stream).poll_close(ctx)
205    }
206}
207
208/// The error type returned after a failed handshake.
209pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
210
211impl<S> HandshakeError<S> {
212    /// Returns a shared reference to the `Ssl` object associated with this error.
213    pub fn ssl(&self) -> Option<&SslRef> {
214        match &self.0 {
215            ssl::HandshakeError::Failure(s) => Some(s.ssl()),
216            _ => None,
217        }
218    }
219
220    /// Converts error to the source data stream that was used for the handshake.
221    pub fn into_source_stream(self) -> Option<S> {
222        match self.0 {
223            ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
224            _ => None,
225        }
226    }
227
228    /// Returns a reference to the source data stream.
229    pub fn as_source_stream(&self) -> Option<&S> {
230        match &self.0 {
231            ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
232            _ => None,
233        }
234    }
235
236    /// Returns the error code, if any.
237    pub fn code(&self) -> Option<ErrorCode> {
238        match &self.0 {
239            ssl::HandshakeError::Failure(s) => Some(s.error().code()),
240            _ => None,
241        }
242    }
243
244    /// Returns a reference to the inner I/O error, if any.
245    pub fn as_io_error(&self) -> Option<&io::Error> {
246        match &self.0 {
247            ssl::HandshakeError::Failure(s) => s.error().io_error(),
248            _ => None,
249        }
250    }
251}
252
253impl<S> fmt::Debug for HandshakeError<S>
254where
255    S: fmt::Debug,
256{
257    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
258        fmt::Debug::fmt(&self.0, fmt)
259    }
260}
261
262impl<S> fmt::Display for HandshakeError<S> {
263    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
264        fmt::Display::fmt(&self.0, fmt)
265    }
266}
267
268impl<S> std::error::Error for HandshakeError<S>
269where
270    S: fmt::Debug,
271{
272    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
273        self.0.source()
274    }
275}
276
277/// Future for an ongoing TLS handshake.
278///
279/// See [`connect`] and [`accept`].
280pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
281
282impl<S> Future for HandshakeFuture<S>
283where
284    S: AsyncRead + AsyncWrite + Unpin,
285{
286    type Output = Result<SslStream<S>, HandshakeError<S>>;
287
288    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
289        let mut mid_handshake = self.0.take().expect("future polled after completion");
290
291        mid_handshake.get_mut().set_waker(Some(ctx));
292        mid_handshake
293            .ssl_mut()
294            .set_task_waker(Some(ctx.waker().clone()));
295
296        match mid_handshake.handshake() {
297            Ok(mut stream) => {
298                stream.get_mut().set_waker(None);
299                stream.ssl_mut().set_task_waker(None);
300
301                Poll::Ready(Ok(SslStream(stream)))
302            }
303            Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
304                mid_handshake.get_mut().set_waker(None);
305                mid_handshake.ssl_mut().set_task_waker(None);
306
307                self.0 = Some(mid_handshake);
308
309                Poll::Pending
310            }
311            Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
312                mid_handshake.get_mut().set_waker(None);
313
314                Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
315                    mid_handshake,
316                ))))
317            }
318            Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
319                Poll::Ready(Err(HandshakeError(err)))
320            }
321        }
322    }
323}
324
325pub struct SslListener<S> {
326    incoming: S,
327    acceptor: SslAcceptor,
328}
329
330impl<S> SslListener<S> {
331    /// Create new SslListener with incoming stream and `SslAcceptor`
332    pub fn on(incoming: S, acceptor: SslAcceptor) -> Self {
333        Self { incoming, acceptor }
334    }
335
336    pub async fn accept<I, E>(&mut self) -> std::io::Result<SslStream<I>>
337    where
338        S: Stream<Item = Result<I, E>> + Unpin,
339        I: AsyncRead + AsyncWrite + Unpin,
340        E: Error,
341    {
342        use futures::TryStreamExt;
343
344        while let Some(stream) = self
345            .incoming
346            .try_next()
347            .await
348            .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?
349        {
350            let stream = accept(&self.acceptor, stream).await.map_err(|err| {
351                let err = std::io::Error::new(std::io::ErrorKind::Other, err.to_string());
352
353                log::error!("{}", err);
354
355                err
356            })?;
357
358            return Ok(stream);
359        }
360
361        Err(std::io::Error::new(
362            io::ErrorKind::BrokenPipe,
363            "Ssl listener inner stream broken.",
364        ))
365    }
366
367    pub fn into_incoming<I, E>(self) -> impl Stream<Item = io::Result<SslStream<I>>> + Unpin
368    where
369        S: Stream<Item = Result<I, E>> + Unpin,
370        I: AsyncRead + AsyncWrite + Unpin,
371        E: Error,
372    {
373        Box::pin(futures::stream::unfold(self, |mut listener| async move {
374            let res = listener.accept().await;
375            Some((res, listener))
376        }))
377    }
378}