async_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams using [Rustls](https://github.com/ctz/rustls).
2
3#![doc(
4    html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
5)]
6#![doc(
7    html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
8)]
9#![deprecated(
10    since = "0.4.2",
11    note = "This crate is now deprecated in favor of [futures-rustls](https://crates.io/crates/futures-rustls)."
12)]
13
/// Unwraps a `Poll::Ready(value)` into `value`, or makes the enclosing function
/// return `Poll::Pending` immediately.
///
/// Behaves like `std::task::ready!`.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
22
23pub mod client;
24mod common;
25pub mod server;
26
27use common::{MidHandshake, Stream, TlsState};
28use futures_io::{AsyncRead, AsyncWrite};
29use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
30use std::future::Future;
31use std::io;
32#[cfg(unix)]
33use std::os::unix::io::{AsRawFd, RawFd};
34#[cfg(windows)]
35use std::os::windows::io::{AsRawSocket, RawSocket};
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{Context, Poll};
39
40pub use rustls;
41
/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
///
/// Cloning is cheap: the configuration is shared through an `Arc`.
#[derive(Clone)]
pub struct TlsConnector {
    // Shared client configuration; cloned (as an `Arc`) into every new `ClientConnection`.
    inner: Arc<ClientConfig>,
    // Whether connections should start in the early-data (0-RTT) state when possible.
    #[cfg(feature = "early-data")]
    early_data: bool,
}
49
/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
///
/// Cloning is cheap: the configuration is shared through an `Arc`.
#[derive(Clone)]
pub struct TlsAcceptor {
    // Shared server configuration; cloned (as an `Arc`) into every new `ServerConnection`.
    inner: Arc<ServerConfig>,
}
55
56impl From<Arc<ClientConfig>> for TlsConnector {
57    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
58        TlsConnector {
59            inner,
60            #[cfg(feature = "early-data")]
61            early_data: false,
62        }
63    }
64}
65
66impl From<Arc<ServerConfig>> for TlsAcceptor {
67    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
68        TlsAcceptor { inner }
69    }
70}
71
72impl TlsConnector {
73    /// Enable 0-RTT.
74    ///
75    /// If you want to use 0-RTT,
76    /// You must also set `ClientConfig.enable_early_data` to `true`.
77    #[cfg(feature = "early-data")]
78    pub fn early_data(mut self, flag: bool) -> TlsConnector {
79        self.early_data = flag;
80        self
81    }
82
83    #[inline]
84    pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
85    where
86        IO: AsyncRead + AsyncWrite + Unpin,
87    {
88        self.connect_with(domain, stream, |_| ())
89    }
90
91    pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
92    where
93        IO: AsyncRead + AsyncWrite + Unpin,
94        F: FnOnce(&mut ClientConnection),
95    {
96        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
97            Ok(session) => session,
98            Err(error) => {
99                return Connect(MidHandshake::Error {
100                    io: stream,
101                    // TODO(eliza): should this really return an `io::Error`?
102                    // Probably not...
103                    error: io::Error::new(io::ErrorKind::Other, error),
104                });
105            }
106        };
107        f(&mut session);
108
109        Connect(MidHandshake::Handshaking(client::TlsStream {
110            io: stream,
111
112            #[cfg(not(feature = "early-data"))]
113            state: TlsState::Stream,
114
115            #[cfg(feature = "early-data")]
116            state: if self.early_data && session.early_data().is_some() {
117                TlsState::EarlyData(0, Vec::new())
118            } else {
119                TlsState::Stream
120            },
121
122            #[cfg(feature = "early-data")]
123            early_waker: None,
124
125            session,
126        }))
127    }
128}
129
130impl TlsAcceptor {
131    #[inline]
132    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
133    where
134        IO: AsyncRead + AsyncWrite + Unpin,
135    {
136        self.accept_with(stream, |_| ())
137    }
138
139    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
140    where
141        IO: AsyncRead + AsyncWrite + Unpin,
142        F: FnOnce(&mut ServerConnection),
143    {
144        let mut session = match ServerConnection::new(self.inner.clone()) {
145            Ok(session) => session,
146            Err(error) => {
147                return Accept(MidHandshake::Error {
148                    io: stream,
149                    // TODO(eliza): should this really return an `io::Error`?
150                    // Probably not...
151                    error: io::Error::new(io::ErrorKind::Other, error),
152                });
153            }
154        };
155        f(&mut session);
156
157        Accept(MidHandshake::Handshaking(server::TlsStream {
158            session,
159            io: stream,
160            state: TlsState::Stream,
161        }))
162    }
163}
164
/// Future that reads the peer's ClientHello from `io` before a `ServerConfig` has
/// been chosen.
///
/// It resolves to a [`StartHandshake`], which exposes the ClientHello and can then be
/// turned into a regular [`Accept`] future with the chosen configuration.
pub struct LazyConfigAcceptor<IO> {
    // rustls acceptor: bytes are fed in with `read_tls`, and `accept` is queried for a ClientHello.
    acceptor: rustls::server::Acceptor,
    // The transport; becomes `None` once it has been moved into `StartHandshake`.
    io: Option<IO>,
}
169
170impl<IO> LazyConfigAcceptor<IO>
171where
172    IO: AsyncRead + AsyncWrite + Unpin,
173{
174    #[inline]
175    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
176        Self {
177            acceptor,
178            io: Some(io),
179        }
180    }
181}
182
183impl<IO> Future for LazyConfigAcceptor<IO>
184where
185    IO: AsyncRead + AsyncWrite + Unpin,
186{
187    type Output = Result<StartHandshake<IO>, io::Error>;
188
189    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190        let this = self.get_mut();
191        loop {
192            let io = match this.io.as_mut() {
193                Some(io) => io,
194                None => {
195                    panic!("Acceptor cannot be polled after acceptance.");
196                }
197            };
198
199            let mut reader = common::SyncReadAdapter { io, cx };
200            match this.acceptor.read_tls(&mut reader) {
201                Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
202                Ok(_) => {}
203                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
204                Err(e) => return Poll::Ready(Err(e)),
205            }
206
207            match this.acceptor.accept() {
208                Ok(Some(accepted)) => {
209                    let io = this.io.take().unwrap();
210                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
211                }
212                Ok(None) => continue,
213                Err(err) => {
214                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
215                }
216            }
217        }
218    }
219}
220
/// An accepted ClientHello together with the transport it was read from.
///
/// Produced by [`LazyConfigAcceptor`]. Inspect it with `client_hello`, then finish
/// the handshake with `into_stream` or `into_stream_with`.
pub struct StartHandshake<IO> {
    // Result of `rustls::server::Acceptor::accept`; consumed by `into_connection`.
    accepted: rustls::server::Accepted,
    // The transport the ClientHello arrived on.
    io: IO,
}
225
226impl<IO> StartHandshake<IO>
227where
228    IO: AsyncRead + AsyncWrite + Unpin,
229{
230    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
231        self.accepted.client_hello()
232    }
233
234    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
235        self.into_stream_with(config, |_| ())
236    }
237
238    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
239    where
240        F: FnOnce(&mut ServerConnection),
241    {
242        let mut conn = match self.accepted.into_connection(config) {
243            Ok(conn) => conn,
244            Err(error) => {
245                return Accept(MidHandshake::Error {
246                    io: self.io,
247                    // TODO(eliza): should this really return an `io::Error`?
248                    // Probably not...
249                    error: io::Error::new(io::ErrorKind::Other, error),
250                });
251            }
252        };
253        f(&mut conn);
254
255        Accept(MidHandshake::Handshaking(server::TlsStream {
256            session: conn,
257            io: self.io,
258            state: TlsState::Stream,
259        }))
260    }
261}
262
/// Future returned from [`TlsConnector::connect`] and [`TlsConnector::connect_with`],
/// which resolves once the client handshake has finished.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Future returned from [`TlsAcceptor::accept`], [`TlsAcceptor::accept_with`] and
/// [`StartHandshake::into_stream`], which resolves once the server handshake has finished.
pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);

/// Like [`Connect`], but on failure it resolves to the error *together with* the `IO`
/// transport instead of dropping it. Obtained via [`Connect::into_fallible`].
pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);

/// Like [`Accept`], but on failure it resolves to the error *together with* the `IO`
/// transport instead of dropping it. Obtained via [`Accept::into_fallible`].
pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
276
277impl<IO> Connect<IO> {
278    #[inline]
279    pub fn into_fallible(self) -> FallibleConnect<IO> {
280        FallibleConnect(self.0)
281    }
282
283    pub fn get_ref(&self) -> Option<&IO> {
284        match &self.0 {
285            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
286            MidHandshake::Error { io, .. } => Some(io),
287            MidHandshake::End => None,
288        }
289    }
290
291    pub fn get_mut(&mut self) -> Option<&mut IO> {
292        match &mut self.0 {
293            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
294            MidHandshake::Error { io, .. } => Some(io),
295            MidHandshake::End => None,
296        }
297    }
298}
299
300impl<IO> Accept<IO> {
301    #[inline]
302    pub fn into_fallible(self) -> FallibleAccept<IO> {
303        FallibleAccept(self.0)
304    }
305
306    pub fn get_ref(&self) -> Option<&IO> {
307        match &self.0 {
308            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
309            MidHandshake::Error { io, .. } => Some(io),
310            MidHandshake::End => None,
311        }
312    }
313
314    pub fn get_mut(&mut self) -> Option<&mut IO> {
315        match &mut self.0 {
316            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
317            MidHandshake::Error { io, .. } => Some(io),
318            MidHandshake::End => None,
319        }
320    }
321}
322
323impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
324    type Output = io::Result<client::TlsStream<IO>>;
325
326    #[inline]
327    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
328        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
329    }
330}
331
332impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
333    type Output = io::Result<server::TlsStream<IO>>;
334
335    #[inline]
336    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
337        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
338    }
339}
340
341impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
342    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
343
344    #[inline]
345    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346        Pin::new(&mut self.0).poll(cx)
347    }
348}
349
350impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
351    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
352
353    #[inline]
354    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
355        Pin::new(&mut self.0).poll(cx)
356    }
357}
358
/// A unified TLS stream type.
///
/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so a
/// single type can hold both client-side and server-side TLS connections.
#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
#[derive(Debug)]
pub enum TlsStream<T> {
    /// The client side of a connection, as produced by [`Connect`].
    Client(client::TlsStream<T>),
    /// The server side of a connection, as produced by [`Accept`].
    Server(server::TlsStream<T>),
}
369
370impl<T> TlsStream<T> {
371    pub fn get_ref(&self) -> (&T, &CommonState) {
372        use TlsStream::*;
373        match self {
374            Client(io) => {
375                let (io, session) = io.get_ref();
376                (io, session)
377            }
378            Server(io) => {
379                let (io, session) = io.get_ref();
380                (io, session)
381            }
382        }
383    }
384
385    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
386        use TlsStream::*;
387        match self {
388            Client(io) => {
389                let (io, session) = io.get_mut();
390                (io, &mut *session)
391            }
392            Server(io) => {
393                let (io, session) = io.get_mut();
394                (io, &mut *session)
395            }
396        }
397    }
398}
399
400impl<T> From<client::TlsStream<T>> for TlsStream<T> {
401    fn from(s: client::TlsStream<T>) -> Self {
402        Self::Client(s)
403    }
404}
405
406impl<T> From<server::TlsStream<T>> for TlsStream<T> {
407    fn from(s: server::TlsStream<T>) -> Self {
408        Self::Server(s)
409    }
410}
411
412#[cfg(unix)]
413impl<S> AsRawFd for TlsStream<S>
414where
415    S: AsRawFd,
416{
417    #[inline]
418    fn as_raw_fd(&self) -> RawFd {
419        self.get_ref().0.as_raw_fd()
420    }
421}
422
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying transport.
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _state) = self.get_ref();
        io.as_raw_socket()
    }
}
433
434impl<T> AsyncRead for TlsStream<T>
435where
436    T: AsyncRead + AsyncWrite + Unpin,
437{
438    #[inline]
439    fn poll_read(
440        self: Pin<&mut Self>,
441        cx: &mut Context<'_>,
442        buf: &mut [u8],
443    ) -> Poll<io::Result<usize>> {
444        match self.get_mut() {
445            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
446            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
447        }
448    }
449}
450
451impl<T> AsyncWrite for TlsStream<T>
452where
453    T: AsyncRead + AsyncWrite + Unpin,
454{
455    #[inline]
456    fn poll_write(
457        self: Pin<&mut Self>,
458        cx: &mut Context<'_>,
459        buf: &[u8],
460    ) -> Poll<io::Result<usize>> {
461        match self.get_mut() {
462            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
463            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
464        }
465    }
466
467    #[inline]
468    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
469        match self.get_mut() {
470            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
471            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
472        }
473    }
474
475    #[inline]
476    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
477        match self.get_mut() {
478            TlsStream::Client(x) => Pin::new(x).poll_close(cx),
479            TlsStream::Server(x) => Pin::new(x).poll_close(cx),
480        }
481    }
482}