Skip to main content

static_web_server/
tls.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
6//! The module handles requests over TLS via [Rustls](tokio_rustls::rustls).
7//!
8
9// Most of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/tls.rs
10
11use futures_util::ready;
12use hyper::server::accept::Accept;
13use hyper::server::conn::{AddrIncoming, AddrStream};
14use rustls_pki_types::pem::PemObject;
15use rustls_pki_types::{CertificateDer, PrivateKeyDer};
16use std::fs::File;
17use std::future::Future;
18use std::io::{self, BufReader, Cursor, Read};
19use std::net::SocketAddr;
20use std::path::{Path, PathBuf};
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
25use tokio_rustls::rustls::{Error as TlsError, ServerConfig};
26
27use crate::transport::Transport;
28
29/// Represents errors that can occur building the TlsConfig
30#[derive(Debug)]
31pub enum TlsConfigError {
32    /// Error type for I/O operations
33    Io(io::Error),
34    /// An Error parsing the Certificate
35    CertParseError,
36    /// Identity PEM is invalid
37    InvalidIdentityPem,
38    /// An error from an empty key
39    EmptyKey,
40    /// Unknown private key format
41    UnknownPrivateKeyFormat,
42    /// An error from an invalid key
43    InvalidKey(TlsError),
44    /// Illegal section start in PEM
45    IllegalSectionStart(Vec<u8>),
46    /// Illegal section end in PEM
47    IllegalSectionEnd(Vec<u8>),
48}
49
50impl std::fmt::Display for TlsConfigError {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        match self {
53            TlsConfigError::Io(err) => err.fmt(f),
54            TlsConfigError::CertParseError => write!(f, "failed to parse certificate"),
55            TlsConfigError::InvalidIdentityPem => write!(f, "the identity PEM provided is invalid"),
56            TlsConfigError::UnknownPrivateKeyFormat => {
57                write!(f, "the private key format is unknown")
58            }
59            TlsConfigError::EmptyKey => write!(f, "the key provided is probably missing or empty"),
60            TlsConfigError::InvalidKey(err) => write!(f, "the key provided is invalid, {err}"),
61            TlsConfigError::IllegalSectionStart(line) => {
62                let line = String::from_utf8(line.clone()).unwrap_or_default();
63                write!(f, "illegal section start in PEM at '{line}'")
64            }
65            TlsConfigError::IllegalSectionEnd(end_marker) => {
66                let end_marker = String::from_utf8(end_marker.clone()).unwrap_or_default();
67                write!(f, "illegal section end in PEM at '{end_marker}'")
68            }
69        }
70    }
71}
72
73impl std::error::Error for TlsConfigError {}
74
75/// Builder to set the configuration for the Tls server.
76pub struct TlsConfigBuilder {
77    cert: Box<dyn Read + Send + Sync>,
78    key: Box<dyn Read + Send + Sync>,
79}
80
81impl std::fmt::Debug for TlsConfigBuilder {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
83        f.debug_struct("TlsConfigBuilder").finish()
84    }
85}
86
87impl TlsConfigBuilder {
88    /// Create a new TlsConfigBuilder
89    pub fn new() -> TlsConfigBuilder {
90        TlsConfigBuilder {
91            key: Box::new(io::empty()),
92            cert: Box::new(io::empty()),
93        }
94    }
95
96    /// sets the Tls key via File Path, returns `TlsConfigError::IoError` if the file cannot be open
97    pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
98        self.key = Box::new(LazyFile {
99            path: path.as_ref().into(),
100            file: None,
101        });
102        self
103    }
104
105    /// sets the Tls key via bytes slice
106    pub fn key(mut self, key: &[u8]) -> Self {
107        self.key = Box::new(Cursor::new(Vec::from(key)));
108        self
109    }
110
111    /// Specify the file path for the TLS certificate to use.
112    pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
113        self.cert = Box::new(LazyFile {
114            path: path.as_ref().into(),
115            file: None,
116        });
117        self
118    }
119
120    /// sets the Tls certificate via bytes slice
121    pub fn cert(mut self, cert: &[u8]) -> Self {
122        self.cert = Box::new(Cursor::new(Vec::from(cert)));
123        self
124    }
125
126    /// Builds TLS configuration.
127    pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
128        let mut cert_rdr = BufReader::new(self.cert);
129        let cert = CertificateDer::pem_reader_iter(&mut cert_rdr)
130            .collect::<Result<Vec<_>, _>>()
131            .map_err(|_e| TlsConfigError::CertParseError)?;
132
133        // convert it to Vec<u8> to allow reading it again if key is RSA
134        let mut key_buf = Vec::new();
135        self.key
136            .read_to_end(&mut key_buf)
137            .map_err(TlsConfigError::Io)?;
138
139        if key_buf.is_empty() {
140            return Err(TlsConfigError::EmptyKey);
141        }
142
143        let reader = Cursor::new(key_buf);
144        let key = PrivateKeyDer::from_pem_reader(reader).map_err(|err| match err {
145            rustls_pki_types::pem::Error::Base64Decode(_) => TlsConfigError::InvalidIdentityPem,
146            rustls_pki_types::pem::Error::NoItemsFound => TlsConfigError::EmptyKey,
147            rustls_pki_types::pem::Error::IllegalSectionStart { line } => {
148                TlsConfigError::IllegalSectionStart(line)
149            }
150            rustls_pki_types::pem::Error::MissingSectionEnd { end_marker } => {
151                TlsConfigError::IllegalSectionEnd(end_marker)
152            }
153            rustls_pki_types::pem::Error::Io(err) => {
154                TlsConfigError::Io(io::Error::new(io::ErrorKind::InvalidData, err))
155            }
156            _ => TlsConfigError::InvalidIdentityPem,
157        })?;
158
159        let mut config = ServerConfig::builder()
160            .with_no_client_auth()
161            .with_single_cert(cert, key)
162            .map_err(TlsConfigError::InvalidKey)?;
163        config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
164        Ok(config)
165    }
166}
167
168impl Default for TlsConfigBuilder {
169    fn default() -> Self {
170        Self::new()
171    }
172}
173
/// A file-backed reader that defers opening `path` until the first read.
///
/// If opening fails, `file` stays `None`, so a later read attempts the open again.
struct LazyFile {
    path: PathBuf,
    file: Option<File>,
}

impl LazyFile {
    /// Opens the file on first use (reusing the handle afterwards) and reads into `buf`.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let handle = match self.file.take() {
            Some(already_open) => already_open,
            None => File::open(&self.path)?,
        };
        self.file.insert(handle).read(buf)
    }
}

impl Read for LazyFile {
    /// Delegates to `lazy_read`, prefixing any error message with the file path while
    /// preserving the original error kind.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.lazy_read(buf) {
            Ok(count) => Ok(count),
            Err(cause) => Err(io::Error::new(
                cause.kind(),
                format!("error reading file ({:?}): {}", self.path.display(), cause),
            )),
        }
    }
}
200
201impl Transport for TlsStream {
202    fn remote_addr(&self) -> Option<SocketAddr> {
203        Some(self.remote_addr)
204    }
205}
206
207enum State {
208    Handshaking(tokio_rustls::Accept<AddrStream>),
209    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
210}
211
212/// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first.
213///
214/// tokio_rustls::server::TlsStream doesn't expose constructor methods,
215/// so we have to TlsAcceptor::accept and handshake to have access to it.
216pub struct TlsStream {
217    state: State,
218    remote_addr: SocketAddr,
219}
220
221impl TlsStream {
222    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
223        let remote_addr = stream.remote_addr();
224        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
225        TlsStream {
226            state: State::Handshaking(accept),
227            remote_addr,
228        }
229    }
230}
231
232impl AsyncRead for TlsStream {
233    fn poll_read(
234        self: Pin<&mut Self>,
235        cx: &mut Context<'_>,
236        buf: &mut ReadBuf<'_>,
237    ) -> Poll<io::Result<()>> {
238        let pin = self.get_mut();
239        match pin.state {
240            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
241                Ok(mut stream) => {
242                    let result = Pin::new(&mut stream).poll_read(cx, buf);
243                    pin.state = State::Streaming(stream);
244                    result
245                }
246                Err(err) => Poll::Ready(Err(err)),
247            },
248            State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
249        }
250    }
251}
252
253impl AsyncWrite for TlsStream {
254    fn poll_write(
255        self: Pin<&mut Self>,
256        cx: &mut Context<'_>,
257        buf: &[u8],
258    ) -> Poll<io::Result<usize>> {
259        let pin = self.get_mut();
260        match pin.state {
261            State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
262                Ok(mut stream) => {
263                    let result = Pin::new(&mut stream).poll_write(cx, buf);
264                    pin.state = State::Streaming(stream);
265                    result
266                }
267                Err(err) => Poll::Ready(Err(err)),
268            },
269            State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
270        }
271    }
272
273    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
274        match self.state {
275            State::Handshaking(_) => Poll::Ready(Ok(())),
276            State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
277        }
278    }
279
280    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281        match self.state {
282            State::Handshaking(_) => Poll::Ready(Ok(())),
283            State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
284        }
285    }
286}
287
288/// Type to intercept Tls incoming connections.
289pub struct TlsAcceptor {
290    config: Arc<ServerConfig>,
291    incoming: AddrIncoming,
292}
293
294impl TlsAcceptor {
295    /// Creates a new Tls interceptor.
296    pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
297        TlsAcceptor {
298            config: Arc::new(config),
299            incoming,
300        }
301    }
302}
303
304impl Accept for TlsAcceptor {
305    type Conn = TlsStream;
306    type Error = io::Error;
307
308    fn poll_accept(
309        self: Pin<&mut Self>,
310        cx: &mut Context<'_>,
311    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
312        let pin = self.get_mut();
313        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
314            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
315            Some(Err(e)) => Poll::Ready(Some(Err(e))),
316            None => Poll::Ready(None),
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn file_cert_key_rsa_pkcs1() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.rsa_pkcs1.pem")
            .key_path("tests/tls/local.dev_key.rsa_pkcs1.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_rsa_pkcs1() {
        let cert = include_str!("../tests/tls/local.dev_cert.rsa_pkcs1.pem");
        let key = include_str!("../tests/tls/local.dev_key.rsa_pkcs1.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_pkcs8() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .key_path("tests/tls/local.dev_key.pkcs8.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_pkcs8() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let key = include_str!("../tests/tls/local.dev_key.pkcs8.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    #[test]
    fn file_cert_key_sec1_ec() {
        TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.sec1_ec.pem")
            .key_path("tests/tls/local.dev_key.sec1_ec.pem")
            .build()
            .unwrap();
    }

    #[test]
    fn bytes_cert_key_sec1_ec() {
        let cert = include_str!("../tests/tls/local.dev_cert.sec1_ec.pem");
        let key = include_str!("../tests/tls/local.dev_key.sec1_ec.pem");

        TlsConfigBuilder::new()
            .key(key.as_bytes())
            .cert(cert.as_bytes())
            .build()
            .unwrap();
    }

    // Negative cases: `build` must fail with a specific error rather than succeed.
    // `.err().expect(..)` is used because `ServerConfig` is not required to be `Debug`.

    #[test]
    fn missing_key_is_empty_key_error() {
        let err = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .build()
            .err()
            .expect("building without a key must fail");
        assert!(matches!(err, TlsConfigError::EmptyKey), "got {err:?}");
    }

    #[test]
    fn nonexistent_key_file_is_io_not_found() {
        let err = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pkcs8.pem")
            .key_path("tests/tls/this_key_file_does_not_exist.pem")
            .build()
            .err()
            .expect("building with a missing key file must fail");
        match err {
            TlsConfigError::Io(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::NotFound),
            other => panic!("expected an I/O error, got {other:?}"),
        }
    }

    #[test]
    fn non_pem_key_is_rejected() {
        let cert = include_str!("../tests/tls/local.dev_cert.pkcs8.pem");
        let err = TlsConfigBuilder::new()
            .cert(cert.as_bytes())
            .key(b"this is not a PEM private key")
            .build()
            .err()
            .expect("building with a non-PEM key must fail");
        assert!(matches!(err, TlsConfigError::EmptyKey), "got {err:?}");
    }
}