libdoh/
tls.rs

use std::fs::File;
use std::io::{self, BufReader, Cursor, Read};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;

use futures::{
    future::FutureExt,
    join, select,
    stream::{FuturesUnordered, StreamExt},
};
use hyper::server::conn::Http;
use tokio::{
    net::TcpListener,
    sync::mpsc::{self, Receiver},
};
use tokio_rustls::{
    rustls::{Certificate, PrivateKey, ServerConfig},
    TlsAcceptor,
};

use crate::constants::CERTS_WATCH_DELAY_SECS;
use crate::errors::*;
use crate::{DoH, LocalExecutor};
21
22pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor>
23where
24    P: AsRef<Path>,
25    P2: AsRef<Path>,
26{
27    let certs: Vec<_> = {
28        let certs_path_str = certs_path.as_ref().display().to_string();
29        let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
30            io::Error::new(
31                e.kind(),
32                format!("Unable to load the certificates [{certs_path_str}]: {e}"),
33            )
34        })?);
35        rustls_pemfile::certs(&mut reader).map_err(|_| {
36            io::Error::new(
37                io::ErrorKind::InvalidInput,
38                "Unable to parse the certificates",
39            )
40        })?
41    }
42    .drain(..)
43    .map(Certificate)
44    .collect();
45    let certs_keys: Vec<_> = {
46        let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
47        let encoded_keys = {
48            let mut encoded_keys = vec![];
49            File::open(certs_keys_path)
50                .map_err(|e| {
51                    io::Error::new(
52                        e.kind(),
53                        format!("Unable to load the certificate keys [{certs_keys_path_str}]: {e}"),
54                    )
55                })?
56                .read_to_end(&mut encoded_keys)?;
57            encoded_keys
58        };
59        let mut reader = Cursor::new(encoded_keys);
60        let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
61            io::Error::new(
62                io::ErrorKind::InvalidInput,
63                "Unable to parse the certificates private keys (PKCS8)",
64            )
65        })?;
66        reader.set_position(0);
67        let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
68            io::Error::new(
69                io::ErrorKind::InvalidInput,
70                "Unable to parse the certificates private keys (RSA)",
71            )
72        })?;
73        let mut keys = pkcs8_keys;
74        keys.append(&mut rsa_keys);
75        if keys.is_empty() {
76            return Err(io::Error::new(
77                io::ErrorKind::InvalidInput,
78                "No private keys found - Make sure that they are in PKCS#8/PEM format",
79            ));
80        }
81        keys.drain(..).map(PrivateKey).collect()
82    };
83
84    let mut server_config = certs_keys
85        .into_iter()
86        .find_map(|certs_key| {
87            let server_config_builder = ServerConfig::builder()
88                .with_safe_defaults()
89                .with_no_client_auth();
90            server_config_builder
91                .with_single_cert(certs.clone(), certs_key)
92                .ok()
93        })
94        .ok_or_else(|| {
95            io::Error::new(
96                io::ErrorKind::InvalidInput,
97                "Unable to find a valid certificate and key",
98            )
99        })?;
100    server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
101    Ok(TlsAcceptor::from(Arc::new(server_config)))
102}
103
104impl DoH {
105    async fn start_https_service(
106        self,
107        mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
108        listener: TcpListener,
109        server: Http<LocalExecutor>,
110    ) -> Result<(), DoHError> {
111        let mut tls_acceptor: Option<TlsAcceptor> = None;
112        let listener_service = async {
113            loop {
114                select! {
115                    tcp_cnx = listener.accept().fuse() => {
116                        if tls_acceptor.is_none() || tcp_cnx.is_err() {
117                            continue;
118                        }
119                        let (raw_stream, client_addr) = tcp_cnx.unwrap();
120                        if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
121                            let mut doh = self.clone();
122                            doh.remote_addr = Some(client_addr);
123                            doh.client_serve(stream, server.clone()).await
124                        }
125                    }
126                    new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
127                        if new_tls_acceptor.is_none() {
128                            break;
129                        }
130                        tls_acceptor = new_tls_acceptor;
131                    }
132                    complete => break
133                }
134            }
135            Ok(()) as Result<(), DoHError>
136        };
137        listener_service.await?;
138        Ok(())
139    }
140
141    pub async fn start_with_tls(
142        self,
143        listener: TcpListener,
144        server: Http<LocalExecutor>,
145    ) -> Result<(), DoHError> {
146        let certs_path = self
147            .globals
148            .tls_cert_path
149            .as_ref()
150            .ok_or_else(|| {
151                DoHError::Io(std::io::Error::new(
152                    std::io::ErrorKind::NotFound,
153                    "TLS certificate path not provided",
154                ))
155            })?
156            .clone();
157        let certs_keys_path = self
158            .globals
159            .tls_cert_key_path
160            .as_ref()
161            .ok_or_else(|| {
162                DoHError::Io(std::io::Error::new(
163                    std::io::ErrorKind::NotFound,
164                    "TLS certificate key path not provided",
165                ))
166            })?
167            .clone();
168        let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
169        let https_service = self.start_https_service(tls_acceptor_receiver, listener, server);
170        let cert_service = async {
171            loop {
172                match create_tls_acceptor(&certs_path, &certs_keys_path) {
173                    Ok(tls_acceptor) => {
174                        if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
175                            break;
176                        }
177                    }
178                    Err(e) => eprintln!("TLS certificates error: {e}"),
179                }
180                tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
181            }
182            Ok::<_, DoHError>(())
183        };
184        join!(https_service, cert_service).0
185    }
186}