routinator/utils/
tls.rs

1//! Utility functions for dealing with TLS.
2
3use std::io;
4use std::fs::File;
5use std::path::Path;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use log::error;
9use futures::{pin_mut, ready, TryFuture};
10use futures::future::Either;
11use pin_project_lite::pin_project;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::TcpStream;
14use tokio_rustls::{Accept, TlsAcceptor};
15use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
16use tokio_rustls::server::TlsStream;
17use crate::error::ExitError;
18
19pub use tokio_rustls::rustls::ServerConfig;
20
21
22//------------ create_server_config -----------------------------------------
23
24/// Creates the TLS server config.
25///
26/// The service this config is for should be given through `service`. This is
27/// used for logging.
28pub fn create_server_config(
29    service: &str, key_path: &Path, cert_path: &Path
30) -> Result<ServerConfig, ExitError> {
31
32    ServerConfig::builder()
33        .with_no_client_auth()
34        .with_single_cert(read_certs(cert_path)?, read_key(key_path)?)
35        .map_err(|err| {
36            error!("Failed to create {service} TLS server config: {err}");
37            ExitError::Generic
38        })
39}
40
41/// Reads the certificates from the given PEM file.
42fn read_certs(
43    cert_path: &Path
44) -> Result<Vec<CertificateDer<'static>>, ExitError> {
45    rustls_pemfile::certs(
46        &mut io::BufReader::new(
47            File::open(cert_path).map_err(|err| {
48                error!(
49                    "Failed to open TLS certificate file '{}': {}.",
50                    cert_path.display(), err
51                );
52                ExitError::Generic
53            })?
54        )
55    ).collect::<Result<_, _>>().map_err(|err| {
56        error!(
57            "Failed to read TLS certificate file '{}': {}.",
58            cert_path.display(), err
59        );
60        ExitError::Generic
61    })
62}
63
64/// Reads the first private key from the given PEM file.
65///
66/// The key may be a PKCS#1 RSA private key, a PKCS#8 private key, or a
67/// SEC1 encoded EC private key. All other PEM items are ignored.
68///
69/// Errors out if opening or reading the file fails or if there isn’t exactly
70/// one private key in the file.
71fn read_key(key_path: &Path) -> Result<PrivateKeyDer<'static>, ExitError> {
72    use rustls_pemfile::Item::*;
73
74    let mut key_file = io::BufReader::new(
75        File::open(key_path).map_err(|err| {
76            error!(
77                "Failed to open TLS key file '{}': {}.",
78                key_path.display(), err
79            );
80            ExitError::Generic
81        })?
82    );
83
84    let mut key = None;
85
86    while let Some(item) =
87        rustls_pemfile::read_one(&mut key_file).transpose()
88    {
89        let item = item.map_err(|err| {
90            error!(
91                "Failed to read TLS key file '{}': {}.",
92                key_path.display(), err
93            );
94            ExitError::Generic
95        })?;
96
97        let bits = match item {
98            Pkcs1Key(bits) => bits.into(),
99            Pkcs8Key(bits) => bits.into(),
100            Sec1Key(bits) => bits.into(),
101            _ => continue,
102        };
103        if key.is_some() {
104            error!(
105                "TLS key file '{}' contains multiple keys.",
106                key_path.display()
107            );
108            return Err(ExitError::Generic)
109        }
110        key = Some(bits)
111    }
112
113    match key {
114        Some(key) => Ok(key),
115        None => {
116             error!(
117                "TLS key file '{}' does not contain any usable keys.",
118                key_path.display()
119            );
120            Err(ExitError::Generic)
121       }
122    }
123}
124
125
126//------------ TlsTcpStream --------------------------------------------------
127
128pin_project! {
129    /// A TLS stream that behaves like a regular TCP stream.
130    ///
131    /// Specifically, `AsyncRead` and `AsyncWrite` will return `Poll::Pending`
132    /// until the TLS accept machinery has concluded.
133    #[project = TlsTcpStreamProj]
134    enum TlsTcpStream {
135        /// The TLS handshake is going on.
136        Accept { #[pin] fut: Accept<TcpStream> },
137
138        /// We have a working TLS stream.
139        Stream { #[pin] fut: TlsStream<TcpStream> },
140
141        /// TLS handshake has failed.
142        ///
143        /// Because hyper still wants to do a clean flush and shutdown, we
144        /// need to still work in this state. For read and write, we just
145        /// keep returning the clean shutdown indication of zero length
146        /// operations.
147        Empty,
148    }
149}
150
151impl TlsTcpStream {
152    fn new(sock: TcpStream, tls: &TlsAcceptor) -> Self {
153        Self::Accept { fut: tls.accept(sock) }
154    }
155
156    fn poll_accept(
157        mut self: Pin<&mut Self>,
158        cx: &mut Context<'_>,
159    ) -> Poll<Result<Pin<&mut Self>, io::Error>> {
160        match self.as_mut().project() {
161            TlsTcpStreamProj::Accept { fut } => {
162                match ready!(fut.try_poll(cx)) {
163                    Ok(fut) => {
164                        self.set(Self::Stream { fut });
165                        Poll::Ready(Ok(self))
166                    }
167                    Err(err) => {
168                        self.set(Self::Empty);
169                        Poll::Ready(Err(err))
170                    }
171                }
172            }
173            _ => Poll::Ready(Ok(self)),
174        }
175    }
176}
177
178impl AsyncRead for TlsTcpStream {
179    fn poll_read(
180        self: Pin<&mut Self>,
181        cx: &mut Context<'_>,
182        buf: &mut ReadBuf<'_>
183    ) -> Poll<Result<(), io::Error>> {
184        let mut this = match ready!(self.poll_accept(cx)) {
185            Ok(this) => this,
186            Err(err) => return Poll::Ready(Err(err))
187        };
188        match this.as_mut().project() {
189            TlsTcpStreamProj::Stream { fut } => {
190                fut.poll_read(cx, buf)
191            }
192            TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
193            _ => unreachable!()
194        }
195    }
196}
197
198impl AsyncWrite for TlsTcpStream {
199    fn poll_write(
200        self: Pin<&mut Self>,
201        cx: &mut Context<'_>,
202        buf: &[u8]
203    ) -> Poll<Result<usize, io::Error>> {
204        let mut this = match ready!(self.poll_accept(cx)) {
205            Ok(this) => this,
206            Err(err) => return Poll::Ready(Err(err))
207        };
208        match this.as_mut().project() {
209            TlsTcpStreamProj::Stream { fut } => {
210                fut.poll_write(cx, buf)
211            }
212            TlsTcpStreamProj::Empty => { Poll::Ready(Ok(0)) }
213            _ => unreachable!()
214        }
215    }
216
217    fn poll_flush(
218        self: Pin<&mut Self>,
219        cx: &mut Context<'_>
220    ) -> Poll<Result<(), io::Error>> {
221        let mut this = match ready!(self.poll_accept(cx)) {
222            Ok(this) => this,
223            Err(err) => return Poll::Ready(Err(err))
224        };
225        match this.as_mut().project() {
226            TlsTcpStreamProj::Stream { fut } => {
227                fut.poll_flush(cx)
228            }
229            TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
230            _ => unreachable!()
231        }
232    }
233
234    fn poll_shutdown(
235        self: Pin<&mut Self>,
236        cx: &mut Context<'_>
237    ) -> Poll<Result<(), io::Error>> {
238        let mut this = match ready!(self.poll_accept(cx)) {
239            Ok(this) => this,
240            Err(err) => return Poll::Ready(Err(err))
241        };
242        match this.as_mut().project() {
243            TlsTcpStreamProj::Stream { fut } => {
244                fut.poll_shutdown(cx)
245            }
246            TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
247            _ => unreachable!()
248        }
249    }
250}
251
252
253//------------ MaybeTlsTcpStream ---------------------------------------------
254
255/// A TCP stream that may or may not use TLS.
256pub struct MaybeTlsTcpStream {
257    sock: Either<TcpStream, TlsTcpStream>,
258}
259
260impl MaybeTlsTcpStream {
261    /// Creates a new stream.
262    ///
263    /// If `tls` is some, the stream will be a TLS stream, otherwise it
264    /// will be a plain TCP stream.
265    pub fn new(sock: TcpStream, tls: Option<&TlsAcceptor>) -> Self {
266        MaybeTlsTcpStream {
267            sock: match tls {
268                Some(tls) => Either::Right(TlsTcpStream::new(sock, tls)),
269                None => Either::Left(sock)
270            }
271        }
272    }
273}
274
275impl AsyncRead for MaybeTlsTcpStream {
276    fn poll_read(
277        mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf
278    ) -> Poll<Result<(), io::Error>> {
279        match self.sock {
280            Either::Left(ref mut sock) => {
281                pin_mut!(sock);
282                sock.poll_read(cx, buf)
283            }
284            Either::Right(ref mut sock) => {
285                pin_mut!(sock);
286                sock.poll_read(cx, buf)
287            }
288        }
289    }
290}
291
292
293impl AsyncWrite for MaybeTlsTcpStream {
294    fn poll_write(
295        mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]
296    ) -> Poll<Result<usize, io::Error>> {
297        match self.sock {
298            Either::Left(ref mut sock) => {
299                pin_mut!(sock);
300                sock.poll_write(cx, buf)
301            }
302            Either::Right(ref mut sock) => {
303                pin_mut!(sock);
304                sock.poll_write(cx, buf)
305            }
306        }
307    }
308
309    fn poll_flush(
310        mut self: Pin<&mut Self>, cx: &mut Context
311    ) -> Poll<Result<(), io::Error>> {
312        match self.sock {
313            Either::Left(ref mut sock) => {
314                pin_mut!(sock);
315                sock.poll_flush(cx)
316            }
317            Either::Right(ref mut sock) => {
318                pin_mut!(sock);
319                sock.poll_flush(cx)
320            }
321        }
322    }
323
324    fn poll_shutdown(
325        mut self: Pin<&mut Self>, cx: &mut Context
326    ) -> Poll<Result<(), io::Error>> {
327        match self.sock {
328            Either::Left(ref mut sock) => {
329                pin_mut!(sock);
330                sock.poll_shutdown(cx)
331            }
332            Either::Right(ref mut sock) => {
333                pin_mut!(sock);
334                sock.poll_shutdown(cx)
335            }
336        }
337    }
338}
339