1use std::io;
4use std::fs::File;
5use std::path::Path;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use log::error;
9use futures::{pin_mut, ready, TryFuture};
10use futures::future::Either;
11use pin_project_lite::pin_project;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio::net::TcpStream;
14use tokio_rustls::{Accept, TlsAcceptor};
15use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
16use tokio_rustls::server::TlsStream;
17use crate::error::ExitError;
18
19pub use tokio_rustls::rustls::ServerConfig;
20
21
22pub fn create_server_config(
29 service: &str, key_path: &Path, cert_path: &Path
30) -> Result<ServerConfig, ExitError> {
31
32 ServerConfig::builder()
33 .with_no_client_auth()
34 .with_single_cert(read_certs(cert_path)?, read_key(key_path)?)
35 .map_err(|err| {
36 error!("Failed to create {service} TLS server config: {err}");
37 ExitError::Generic
38 })
39}
40
41fn read_certs(
43 cert_path: &Path
44) -> Result<Vec<CertificateDer<'static>>, ExitError> {
45 rustls_pemfile::certs(
46 &mut io::BufReader::new(
47 File::open(cert_path).map_err(|err| {
48 error!(
49 "Failed to open TLS certificate file '{}': {}.",
50 cert_path.display(), err
51 );
52 ExitError::Generic
53 })?
54 )
55 ).collect::<Result<_, _>>().map_err(|err| {
56 error!(
57 "Failed to read TLS certificate file '{}': {}.",
58 cert_path.display(), err
59 );
60 ExitError::Generic
61 })
62}
63
64fn read_key(key_path: &Path) -> Result<PrivateKeyDer<'static>, ExitError> {
72 use rustls_pemfile::Item::*;
73
74 let mut key_file = io::BufReader::new(
75 File::open(key_path).map_err(|err| {
76 error!(
77 "Failed to open TLS key file '{}': {}.",
78 key_path.display(), err
79 );
80 ExitError::Generic
81 })?
82 );
83
84 let mut key = None;
85
86 while let Some(item) =
87 rustls_pemfile::read_one(&mut key_file).transpose()
88 {
89 let item = item.map_err(|err| {
90 error!(
91 "Failed to read TLS key file '{}': {}.",
92 key_path.display(), err
93 );
94 ExitError::Generic
95 })?;
96
97 let bits = match item {
98 Pkcs1Key(bits) => bits.into(),
99 Pkcs8Key(bits) => bits.into(),
100 Sec1Key(bits) => bits.into(),
101 _ => continue,
102 };
103 if key.is_some() {
104 error!(
105 "TLS key file '{}' contains multiple keys.",
106 key_path.display()
107 );
108 return Err(ExitError::Generic)
109 }
110 key = Some(bits)
111 }
112
113 match key {
114 Some(key) => Ok(key),
115 None => {
116 error!(
117 "TLS key file '{}' does not contain any usable keys.",
118 key_path.display()
119 );
120 Err(ExitError::Generic)
121 }
122 }
123}
124
125
126pin_project! {
129 #[project = TlsTcpStreamProj]
134 enum TlsTcpStream {
135 Accept { #[pin] fut: Accept<TcpStream> },
137
138 Stream { #[pin] fut: TlsStream<TcpStream> },
140
141 Empty,
148 }
149}
150
151impl TlsTcpStream {
152 fn new(sock: TcpStream, tls: &TlsAcceptor) -> Self {
153 Self::Accept { fut: tls.accept(sock) }
154 }
155
156 fn poll_accept(
157 mut self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 ) -> Poll<Result<Pin<&mut Self>, io::Error>> {
160 match self.as_mut().project() {
161 TlsTcpStreamProj::Accept { fut } => {
162 match ready!(fut.try_poll(cx)) {
163 Ok(fut) => {
164 self.set(Self::Stream { fut });
165 Poll::Ready(Ok(self))
166 }
167 Err(err) => {
168 self.set(Self::Empty);
169 Poll::Ready(Err(err))
170 }
171 }
172 }
173 _ => Poll::Ready(Ok(self)),
174 }
175 }
176}
177
178impl AsyncRead for TlsTcpStream {
179 fn poll_read(
180 self: Pin<&mut Self>,
181 cx: &mut Context<'_>,
182 buf: &mut ReadBuf<'_>
183 ) -> Poll<Result<(), io::Error>> {
184 let mut this = match ready!(self.poll_accept(cx)) {
185 Ok(this) => this,
186 Err(err) => return Poll::Ready(Err(err))
187 };
188 match this.as_mut().project() {
189 TlsTcpStreamProj::Stream { fut } => {
190 fut.poll_read(cx, buf)
191 }
192 TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
193 _ => unreachable!()
194 }
195 }
196}
197
198impl AsyncWrite for TlsTcpStream {
199 fn poll_write(
200 self: Pin<&mut Self>,
201 cx: &mut Context<'_>,
202 buf: &[u8]
203 ) -> Poll<Result<usize, io::Error>> {
204 let mut this = match ready!(self.poll_accept(cx)) {
205 Ok(this) => this,
206 Err(err) => return Poll::Ready(Err(err))
207 };
208 match this.as_mut().project() {
209 TlsTcpStreamProj::Stream { fut } => {
210 fut.poll_write(cx, buf)
211 }
212 TlsTcpStreamProj::Empty => { Poll::Ready(Ok(0)) }
213 _ => unreachable!()
214 }
215 }
216
217 fn poll_flush(
218 self: Pin<&mut Self>,
219 cx: &mut Context<'_>
220 ) -> Poll<Result<(), io::Error>> {
221 let mut this = match ready!(self.poll_accept(cx)) {
222 Ok(this) => this,
223 Err(err) => return Poll::Ready(Err(err))
224 };
225 match this.as_mut().project() {
226 TlsTcpStreamProj::Stream { fut } => {
227 fut.poll_flush(cx)
228 }
229 TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
230 _ => unreachable!()
231 }
232 }
233
234 fn poll_shutdown(
235 self: Pin<&mut Self>,
236 cx: &mut Context<'_>
237 ) -> Poll<Result<(), io::Error>> {
238 let mut this = match ready!(self.poll_accept(cx)) {
239 Ok(this) => this,
240 Err(err) => return Poll::Ready(Err(err))
241 };
242 match this.as_mut().project() {
243 TlsTcpStreamProj::Stream { fut } => {
244 fut.poll_shutdown(cx)
245 }
246 TlsTcpStreamProj::Empty => { Poll::Ready(Ok(())) }
247 _ => unreachable!()
248 }
249 }
250}
251
252
253pub struct MaybeTlsTcpStream {
257 sock: Either<TcpStream, TlsTcpStream>,
258}
259
260impl MaybeTlsTcpStream {
261 pub fn new(sock: TcpStream, tls: Option<&TlsAcceptor>) -> Self {
266 MaybeTlsTcpStream {
267 sock: match tls {
268 Some(tls) => Either::Right(TlsTcpStream::new(sock, tls)),
269 None => Either::Left(sock)
270 }
271 }
272 }
273}
274
275impl AsyncRead for MaybeTlsTcpStream {
276 fn poll_read(
277 mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf
278 ) -> Poll<Result<(), io::Error>> {
279 match self.sock {
280 Either::Left(ref mut sock) => {
281 pin_mut!(sock);
282 sock.poll_read(cx, buf)
283 }
284 Either::Right(ref mut sock) => {
285 pin_mut!(sock);
286 sock.poll_read(cx, buf)
287 }
288 }
289 }
290}
291
292
293impl AsyncWrite for MaybeTlsTcpStream {
294 fn poll_write(
295 mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]
296 ) -> Poll<Result<usize, io::Error>> {
297 match self.sock {
298 Either::Left(ref mut sock) => {
299 pin_mut!(sock);
300 sock.poll_write(cx, buf)
301 }
302 Either::Right(ref mut sock) => {
303 pin_mut!(sock);
304 sock.poll_write(cx, buf)
305 }
306 }
307 }
308
309 fn poll_flush(
310 mut self: Pin<&mut Self>, cx: &mut Context
311 ) -> Poll<Result<(), io::Error>> {
312 match self.sock {
313 Either::Left(ref mut sock) => {
314 pin_mut!(sock);
315 sock.poll_flush(cx)
316 }
317 Either::Right(ref mut sock) => {
318 pin_mut!(sock);
319 sock.poll_flush(cx)
320 }
321 }
322 }
323
324 fn poll_shutdown(
325 mut self: Pin<&mut Self>, cx: &mut Context
326 ) -> Poll<Result<(), io::Error>> {
327 match self.sock {
328 Either::Left(ref mut sock) => {
329 pin_mut!(sock);
330 sock.poll_shutdown(cx)
331 }
332 Either::Right(ref mut sock) => {
333 pin_mut!(sock);
334 sock.poll_shutdown(cx)
335 }
336 }
337 }
338}
339