1use crate::stream::Stream;
2
3use rustls_pemfile::Item;
4use std::future::{Future, Ready};
5use std::path::Path;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9use std::{future, io, iter, mem};
10use thiserror::Error;
11use tokio::fs;
12use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13use tokio_rustls::rustls::{self, Certificate, PrivateKey, ServerConfig};
14use tokio_rustls::server::TlsStream;
15use tokio_rustls::{Accept, TlsAcceptor};
16
17pub trait Acceptor: Clone + Sync + Send + Unpin + 'static {
18 type Stream<T: Stream>: Stream;
19 type Accept<T: Stream>: Future<Output = Result<Self::Stream<T>, io::Error>> + Unpin;
20
21 fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T>;
22}
23
24impl Acceptor for TlsAcceptor {
25 type Stream<T: Stream> = TlsStream<T>;
26 type Accept<T: Stream> = Accept<T>;
27
28 #[inline(always)]
29 fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T> {
30 TlsAcceptor::accept(self, stream)
31 }
32}
33
34#[derive(Clone, Copy, Debug)]
35pub struct RawAcceptor;
36
37impl Acceptor for RawAcceptor {
38 type Stream<T: Stream> = T;
39 type Accept<T: Stream> = Ready<Result<T, io::Error>>;
40
41 #[inline(always)]
42 fn accept<T: Stream>(&self, stream: T) -> Self::Accept<T> {
43 future::ready(Ok(stream))
44 }
45}
46
47#[derive(Clone, Copy, Debug)]
48pub struct MaybeAcceptor<T>(pub Option<T>);
49
50impl<T: Acceptor> Acceptor for MaybeAcceptor<T> {
51 type Stream<U: Stream> = MaybeStream<T::Stream<U>, U>;
52 type Accept<U: Stream> = MaybeAccept<T, U>;
53
54 #[inline(always)]
55 fn accept<U: Stream>(&self, stream: U) -> Self::Accept<U> {
56 match &self.0 {
57 Some(acceptor) => MaybeAccept::Tls(acceptor.accept(stream)),
58 None => MaybeAccept::Raw(stream),
59 }
60 }
61}
62
/// Connection produced by [`MaybeAcceptor`].
///
/// It is either the wrapped acceptor's stream or the untouched raw stream.
pub enum MaybeStream<T, U> {
    /// Stream produced by the wrapped acceptor.
    Tls(T),
    /// Original stream, returned unchanged because no acceptor was configured.
    Raw(U),
}
67
68impl<T: AsyncRead + Unpin, U: AsyncRead + Unpin> AsyncRead for MaybeStream<T, U> {
69 #[inline(always)]
70 fn poll_read(
71 self: Pin<&mut Self>,
72 context: &mut Context<'_>,
73 buffer: &mut ReadBuf<'_>,
74 ) -> Poll<Result<(), io::Error>> {
75 match self.get_mut() {
76 Self::Tls(tls) => Pin::new(tls).poll_read(context, buffer),
77 Self::Raw(raw) => Pin::new(raw).poll_read(context, buffer),
78 }
79 }
80}
81
82impl<T: AsyncWrite + Unpin, U: AsyncWrite + Unpin> AsyncWrite for MaybeStream<T, U> {
83 #[inline(always)]
84 fn poll_write(
85 self: Pin<&mut Self>,
86 context: &mut Context<'_>,
87 data: &[u8],
88 ) -> Poll<Result<usize, io::Error>> {
89 match self.get_mut() {
90 Self::Tls(tls) => Pin::new(tls).poll_write(context, data),
91 Self::Raw(raw) => Pin::new(raw).poll_write(context, data),
92 }
93 }
94
95 #[inline(always)]
96 fn poll_flush(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
97 match self.get_mut() {
98 Self::Tls(tls) => Pin::new(tls).poll_flush(context),
99 Self::Raw(raw) => Pin::new(raw).poll_flush(context),
100 }
101 }
102
103 #[inline(always)]
104 fn poll_shutdown(
105 self: Pin<&mut Self>,
106 context: &mut Context<'_>,
107 ) -> Poll<Result<(), io::Error>> {
108 match self.get_mut() {
109 Self::Tls(tls) => Pin::new(tls).poll_shutdown(context),
110 Self::Raw(raw) => Pin::new(raw).poll_shutdown(context),
111 }
112 }
113}
114
115pub enum MaybeAccept<T: Acceptor, U: Stream> {
116 Tls(T::Accept<U>),
117 Raw(U),
118 Done,
119}
120
121impl<T: Acceptor, U: Stream> Future for MaybeAccept<T, U> {
122 type Output = Result<MaybeStream<T::Stream<U>, U>, io::Error>;
123
124 fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
125 let this = self.get_mut();
126 match mem::replace(this, Self::Done) {
127 Self::Tls(mut tls) => match Pin::new(&mut tls).poll(context) {
128 Poll::Ready(tls) => Poll::Ready(tls.map(MaybeStream::Tls)),
129 Poll::Pending => {
130 *this = Self::Tls(tls);
131 Poll::Pending
132 }
133 },
134 Self::Raw(raw) => Poll::Ready(Ok(MaybeStream::Raw(raw))),
135 Self::Done => panic!("MaybeAccept already polled to completion"),
136 }
137 }
138}
139
140#[derive(Error, Debug)]
141pub enum Error {
142 #[error(transparent)]
143 Rustls(#[from] rustls::Error),
144 #[error(transparent)]
145 Io(#[from] io::Error),
146 #[error("Multiple private keys provided")]
147 MultipleKeys,
148 #[error("No suitable private keys provided")]
149 NoKeys,
150}
151
152pub async fn configure(certificate: &Path, key: &Path) -> Result<TlsAcceptor, Error> {
153 enum LoadedItem {
154 Certificate(Vec<u8>),
155 Key(Vec<u8>),
156 }
157
158 let certificate = fs::read_to_string(certificate).await?;
159 let key = fs::read_to_string(key).await?;
160
161 let certificates_iter = iter::from_fn({
162 let mut buffer = certificate.as_bytes();
163
164 move || rustls_pemfile::read_one(&mut buffer).transpose()
165 })
166 .filter_map(|item| match item {
167 Ok(Item::X509Certificate(data)) => Some(Ok(LoadedItem::Certificate(data))),
168 Err(err) => Some(Err(err)),
169 _ => None,
170 });
171
172 let keys_iter = iter::from_fn({
173 let mut buffer = key.as_bytes();
174
175 move || rustls_pemfile::read_one(&mut buffer).transpose()
176 })
177 .filter_map(|item| match item {
178 Ok(Item::RSAKey(data)) | Ok(Item::PKCS8Key(data)) | Ok(Item::ECKey(data)) => {
179 Some(Ok(LoadedItem::Key(data)))
180 }
181 Err(err) => Some(Err(err)),
182 _ => None,
183 });
184
185 let mut certificates = Vec::new();
186 let mut key = None;
187
188 for item in certificates_iter.chain(keys_iter) {
189 let item = item?;
190
191 match item {
192 LoadedItem::Certificate(data) => certificates.push(Certificate(data)),
193 LoadedItem::Key(data) => {
194 if key.is_some() {
195 return Err(Error::MultipleKeys);
196 }
197
198 key = Some(PrivateKey(data));
199 }
200 }
201 }
202
203 let key = key.ok_or(Error::NoKeys)?;
204
205 ServerConfig::builder()
206 .with_safe_defaults()
207 .with_no_client_auth()
208 .with_single_cert(certificates, key)
209 .map(Arc::new)
210 .map(Into::into)
211 .map_err(Into::into)
212}