generic_tls/
server.rs

1use crate::stream::Stream;
2
3use rustls_pemfile::Item;
4use std::future::{Future, Ready};
5use std::path::Path;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::{future, io, iter, mem};
10use thiserror::Error;
11use tokio::fs;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig};
14use tokio_rustls::server::TlsStream;
15use tokio_rustls::{Accept, TlsAcceptor};
16
17pub trait Acceptor: Clone + Sync + Send + Unpin + 'static {
18    type Stream<T: Stream>: Stream;
19    type Accept<T: Stream>: Future<Output = Result<Self::Stream<T>, io::Error>> + Unpin;
20
21    fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T>;
22}
23
24impl Acceptor for TlsAcceptor {
25    type Stream<T: Stream> = TlsStream<T>;
26    type Accept<T: Stream> = Accept<T>;
27
28    #[inline(always)]
29    fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T> {
30        TlsAcceptor::accept(self, stream)
31    }
32}
33
34#[derive(Clone, Copy, Debug)]
35pub struct RawAcceptor;
36
37impl Acceptor for RawAcceptor {
38    type Stream<T: Stream> = T;
39    type Accept<T: Stream> = Ready<Result<T, io::Error>>;
40
41    #[inline(always)]
42    fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T> {
43        future::ready(Ok(stream))
44    }
45}
46
47#[derive(Clone, Copy, Debug)]
48pub struct MaybeAcceptor<T>(pub Option<T>);
49
50impl<T: Acceptor> Acceptor for MaybeAcceptor<T> {
51    type Stream<U: Stream> = MaybeStream<T::Stream<U>, U>;
52    type Accept<U: Stream> = MaybeAccept<T, U>;
53
54    #[inline(always)]
55    fn accept<U: Stream>(&self, stream: U) -> Self::Accept<U> {
56        match &self.0 {
57            Some(acceptor) => MaybeAccept::Tls(acceptor.accept(stream)),
58            None => MaybeAccept::Raw(stream),
59        }
60    }
61}
62
/// A stream that was either upgraded by an inner acceptor (`Tls`) or left as-is (`Raw`).
///
/// This is the stream type produced by `MaybeAcceptor`. `Debug` is derived (conditional on
/// `T: Debug, U: Debug`) for consistency with the other public acceptor types.
#[derive(Debug)]
pub enum MaybeStream<T, U> {
    /// Stream produced by the inner acceptor.
    Tls(T),
    /// The original stream, returned unchanged because no inner acceptor was configured.
    Raw(U),
}
67
68impl<T: AsyncRead + Unpin, U: AsyncRead + Unpin> AsyncRead for MaybeStream<T, U> {
69    #[inline(always)]
70    fn poll_read(
71        self: Pin<&mut Self>,
72        context: &mut Context<'_>,
73        buffer: &mut ReadBuf<'_>,
74    ) -> Poll<Result<(), io::Error>> {
75        match self.get_mut() {
76            Self::Tls(tls) => Pin::new(tls).poll_read(context, buffer),
77            Self::Raw(raw) => Pin::new(raw).poll_read(context, buffer),
78        }
79    }
80}
81
82impl<T: AsyncWrite + Unpin, U: AsyncWrite + Unpin> AsyncWrite for MaybeStream<T, U> {
83    #[inline(always)]
84    fn poll_write(
85        self: Pin<&mut Self>,
86        context: &mut Context<'_>,
87        data: &[u8],
88    ) -> Poll<Result<usize, io::Error>> {
89        match self.get_mut() {
90            Self::Tls(tls) => Pin::new(tls).poll_write(context, data),
91            Self::Raw(raw) => Pin::new(raw).poll_write(context, data),
92        }
93    }
94
95    #[inline(always)]
96    fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
97        match self.get_mut() {
98            Self::Tls(tls) => Pin::new(tls).poll_flush(context),
99            Self::Raw(raw) => Pin::new(raw).poll_flush(context),
100        }
101    }
102
103    #[inline(always)]
104    fn poll_shutdown(
105        self: Pin<&mut Self>,
106        context: &mut Context<'_>,
107    ) -> Poll<Result<(), io::Error>> {
108        match self.get_mut() {
109            Self::Tls(tls) => Pin::new(tls).poll_shutdown(context),
110            Self::Raw(raw) => Pin::new(raw).poll_shutdown(context),
111        }
112    }
113}
114
115pub enum MaybeAccept<T: Acceptor, U: Stream> {
116    Tls(T::Accept<U>),
117    Raw(U),
118    Done,
119}
120
121impl<T: Acceptor, U: Stream> Future for MaybeAccept<T, U> {
122    type Output = Result<MaybeStream<T::Stream<U>, U>, io::Error>;
123
124    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
125        let this = self.get_mut();
126        match mem::replace(this, Self::Done) {
127            Self::Tls(mut tls) => match Pin::new(&mut tls).poll(context) {
128                Poll::Ready(tls) => Poll::Ready(tls.map(MaybeStream::Tls)),
129                Poll::Pending => {
130                    *this = Self::Tls(tls);
131                    Poll::Pending
132                }
133            },
134            Self::Raw(raw) => Poll::Ready(Ok(MaybeStream::Raw(raw))),
135            Self::Done => panic!("MaybeAccept already polled to completion"),
136        }
137    }
138}
139
140#[derive(Error, Debug)]
141pub enum Error {
142    #[error(transparent)]
143    Rustls(#[from] rustls::Error),
144    #[error(transparent)]
145    Io(#[from] io::Error),
146    #[error("Multiple private keys provided")]
147    MultipleKeys,
148    #[error("No suitable private keys provided")]
149    NoKeys,
150}
151
152pub async fn configure(certificate: &Path, key: &Path) -> Result<TlsAcceptor, Error> {
153    enum LoadedItem {
154        Certificate(Vec<u8>),
155        Key(Vec<u8>),
156    }
157
158    let certificate = fs::read_to_string(certificate).await?;
159    let key = fs::read_to_string(key).await?;
160
161    let certificates_iter = iter::from_fn({
162        let mut buffer = certificate.as_bytes();
163
164        move || rustls_pemfile::read_one(&mut buffer).transpose()
165    })
166    .filter_map(|item| match item {
167        Ok(Item::X509Certificate(data)) => Some(Ok(LoadedItem::Certificate(data))),
168        Err(err) => Some(Err(err)),
169        _ => None,
170    });
171
172    let keys_iter = iter::from_fn({
173        let mut buffer = key.as_bytes();
174
175        move || rustls_pemfile::read_one(&mut buffer).transpose()
176    })
177    .filter_map(|item| match item {
178        Ok(Item::RSAKey(data)) | Ok(Item::PKCS8Key(data)) | Ok(Item::ECKey(data)) => {
179            Some(Ok(LoadedItem::Key(data)))
180        }
181        Err(err) => Some(Err(err)),
182        _ => None,
183    });
184
185    let mut certificates = Vec::new();
186    let mut key = None;
187
188    for item in certificates_iter.chain(keys_iter) {
189        let item = item?;
190
191        match item {
192            LoadedItem::Certificate(data) => certificates.push(Certificate(data)),
193            LoadedItem::Key(data) => {
194                if key.is_some() {
195                    return Err(Error::MultipleKeys);
196                }
197
198                key = Some(PrivateKey(data));
199            }
200        }
201    }
202
203    let key = key.ok_or(Error::NoKeys)?;
204
205    ServerConfig::builder()
206        .with_safe_defaults()
207        .with_no_client_auth()
208        .with_single_cert(certificates, key)
209        .map(Arc::new)
210        .map(Into::into)
211        .map_err(Into::into)
212}