1use std::fs::File;
2use std::io::{self, BufReader, Cursor, Read};
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures::{future::FutureExt, join, select};
8use hyper::server::conn::Http;
9use tokio::{
10 net::TcpListener,
11 sync::mpsc::{self, Receiver},
12};
13use tokio_rustls::{
14 rustls::{Certificate, PrivateKey, ServerConfig},
15 TlsAcceptor,
16};
17
18use crate::constants::CERTS_WATCH_DELAY_SECS;
19use crate::errors::*;
20use crate::{DoH, LocalExecutor};
21
22pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor>
23where
24 P: AsRef<Path>,
25 P2: AsRef<Path>,
26{
27 let certs: Vec<_> = {
28 let certs_path_str = certs_path.as_ref().display().to_string();
29 let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
30 io::Error::new(
31 e.kind(),
32 format!("Unable to load the certificates [{certs_path_str}]: {e}"),
33 )
34 })?);
35 rustls_pemfile::certs(&mut reader).map_err(|_| {
36 io::Error::new(
37 io::ErrorKind::InvalidInput,
38 "Unable to parse the certificates",
39 )
40 })?
41 }
42 .drain(..)
43 .map(Certificate)
44 .collect();
45 let certs_keys: Vec<_> = {
46 let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
47 let encoded_keys = {
48 let mut encoded_keys = vec![];
49 File::open(certs_keys_path)
50 .map_err(|e| {
51 io::Error::new(
52 e.kind(),
53 format!("Unable to load the certificate keys [{certs_keys_path_str}]: {e}"),
54 )
55 })?
56 .read_to_end(&mut encoded_keys)?;
57 encoded_keys
58 };
59 let mut reader = Cursor::new(encoded_keys);
60 let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
61 io::Error::new(
62 io::ErrorKind::InvalidInput,
63 "Unable to parse the certificates private keys (PKCS8)",
64 )
65 })?;
66 reader.set_position(0);
67 let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
68 io::Error::new(
69 io::ErrorKind::InvalidInput,
70 "Unable to parse the certificates private keys (RSA)",
71 )
72 })?;
73 let mut keys = pkcs8_keys;
74 keys.append(&mut rsa_keys);
75 if keys.is_empty() {
76 return Err(io::Error::new(
77 io::ErrorKind::InvalidInput,
78 "No private keys found - Make sure that they are in PKCS#8/PEM format",
79 ));
80 }
81 keys.drain(..).map(PrivateKey).collect()
82 };
83
84 let mut server_config = certs_keys
85 .into_iter()
86 .find_map(|certs_key| {
87 let server_config_builder = ServerConfig::builder()
88 .with_safe_defaults()
89 .with_no_client_auth();
90 server_config_builder
91 .with_single_cert(certs.clone(), certs_key)
92 .ok()
93 })
94 .ok_or_else(|| {
95 io::Error::new(
96 io::ErrorKind::InvalidInput,
97 "Unable to find a valid certificate and key",
98 )
99 })?;
100 server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
101 Ok(TlsAcceptor::from(Arc::new(server_config)))
102}
103
104impl DoH {
105 async fn start_https_service(
106 self,
107 mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
108 listener: TcpListener,
109 server: Http<LocalExecutor>,
110 ) -> Result<(), DoHError> {
111 let mut tls_acceptor: Option<TlsAcceptor> = None;
112 let listener_service = async {
113 loop {
114 select! {
115 tcp_cnx = listener.accept().fuse() => {
116 if tls_acceptor.is_none() || tcp_cnx.is_err() {
117 continue;
118 }
119 let (raw_stream, client_addr) = tcp_cnx.unwrap();
120 if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
121 let mut doh = self.clone();
122 doh.remote_addr = Some(client_addr);
123 doh.client_serve(stream, server.clone()).await
124 }
125 }
126 new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
127 if new_tls_acceptor.is_none() {
128 break;
129 }
130 tls_acceptor = new_tls_acceptor;
131 }
132 complete => break
133 }
134 }
135 Ok(()) as Result<(), DoHError>
136 };
137 listener_service.await?;
138 Ok(())
139 }
140
141 pub async fn start_with_tls(
142 self,
143 listener: TcpListener,
144 server: Http<LocalExecutor>,
145 ) -> Result<(), DoHError> {
146 let certs_path = self
147 .globals
148 .tls_cert_path
149 .as_ref()
150 .ok_or_else(|| {
151 DoHError::Io(std::io::Error::new(
152 std::io::ErrorKind::NotFound,
153 "TLS certificate path not provided",
154 ))
155 })?
156 .clone();
157 let certs_keys_path = self
158 .globals
159 .tls_cert_key_path
160 .as_ref()
161 .ok_or_else(|| {
162 DoHError::Io(std::io::Error::new(
163 std::io::ErrorKind::NotFound,
164 "TLS certificate key path not provided",
165 ))
166 })?
167 .clone();
168 let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
169 let https_service = self.start_https_service(tls_acceptor_receiver, listener, server);
170 let cert_service = async {
171 loop {
172 match create_tls_acceptor(&certs_path, &certs_keys_path) {
173 Ok(tls_acceptor) => {
174 if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
175 break;
176 }
177 }
178 Err(e) => eprintln!("TLS certificates error: {e}"),
179 }
180 tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
181 }
182 Ok::<_, DoHError>(())
183 };
184 join!(https_service, cert_service).0
185 }
186}