1#![allow(clippy::needless_doctest_main)]
20
21#[macro_use]
22extern crate log;
23
24use hyper::client::connect::{Connected, Connection};
25use hyper::{service::Service, Uri};
26use rustls::client::WantsTransparencyPolicyOrClientCert;
27use rustls::{self, ConfigBuilder, OwnedTrustAnchor, ServerName, WantsCipherSuites};
28use std::convert::TryFrom;
29use std::{
30 fmt,
31 future::Future,
32 io,
33 net::{self, ToSocketAddrs},
34 pin::Pin,
35 sync::Arc,
36 task::{Context, Poll},
37};
38use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
39use tokio::net::TcpStream;
40use tokio_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector};
41
42#[derive(Clone)]
45pub struct AlpnConnector {
46 config: Option<Arc<ClientConfig>>,
47 config_builder: ConfigBuilder<ClientConfig, WantsTransparencyPolicyOrClientCert>,
48}
49
50impl AlpnConnector {
51 fn build_config(&mut self) {
53 if self.config.is_some() {
54 return;
55 }
56
57 let mut config = self.config_builder.clone().with_no_client_auth();
58 config.alpn_protocols.push("h2".as_bytes().to_vec());
59 self.config = Some(Arc::new(config));
60 }
61
62 fn build_config_with_certificate(
64 &mut self,
65 cert_chain: Vec<rustls::Certificate>,
66 key_der: Vec<u8>,
67 ) -> Result<(), rustls::Error> {
68 if self.config.is_some() {
69 return Ok(());
70 }
71
72 let config = self
73 .config_builder
74 .clone()
75 .with_single_cert(cert_chain, rustls::PrivateKey(key_der));
76 match config {
77 Ok(mut c) => {
78 c.alpn_protocols.push("h2".as_bytes().to_vec());
79 self.config = Some(Arc::new(c));
80 Ok(())
81 }
82 Err(e) => Err(e),
83 }
84 }
85}
86
87impl Default for AlpnConnector {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
/// A TLS-over-TCP client stream produced by [`AlpnConnector`].
///
/// Newtype over the `tokio_rustls` client `TlsStream<TcpStream>`; the tokio I/O traits are
/// forwarded to the inner stream.
#[derive(Debug)]
pub struct AlpnStream(TlsStream<TcpStream>);
95
96impl AsyncRead for AlpnStream {
97 #[inline]
98 fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf) -> Poll<Result<(), io::Error>> {
99 Pin::new(&mut Pin::get_mut(self).0).poll_read(cx, buf)
100 }
101}
102
103impl AsyncWrite for AlpnStream {
104 #[inline]
105 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
106 Pin::new(&mut Pin::get_mut(self).0).poll_write(cx, buf)
107 }
108
109 #[inline]
110 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
111 Pin::new(&mut Pin::get_mut(self).0).poll_flush(cx)
112 }
113
114 #[inline]
115 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
116 Pin::new(&mut Pin::get_mut(self).0).poll_shutdown(cx)
117 }
118}
119
120impl Connection for AlpnStream {
121 fn connected(&self) -> Connected {
122 Connected::new()
123 }
124}
125
126impl AlpnConnector {
127 pub fn new() -> Self {
129 Self::with_client_config(ClientConfig::builder())
130 }
131
132 pub fn with_client_cert(cert_pem: &[u8], key_pem: &[u8]) -> Result<Self, io::Error> {
168 let parsed_keys = rustls_pemfile::pkcs8_private_keys(&mut io::BufReader::new(key_pem)).or({
169 trace!("AlpnConnector::with_client_cert error reading private key");
170 Err(io::Error::new(io::ErrorKind::InvalidData, "private key"))
171 })?;
172
173 if let Some(key) = parsed_keys.first() {
174 let parsed_cert = rustls_pemfile::certs(&mut io::BufReader::new(cert_pem))
175 .or({
176 trace!("AlpnConnector::with_client_cert error reading private key");
177 Err(io::Error::new(io::ErrorKind::InvalidData, "private key"))
178 })?
179 .into_iter()
180 .map(rustls::Certificate)
181 .collect::<Vec<rustls::Certificate>>();
182
183 let mut c = Self::with_client_config(ClientConfig::builder());
184 c.build_config_with_certificate(parsed_cert, key.clone()).or({
185 trace!("AlpnConnector::build_config_with_certificate invalid key");
186 Err(io::Error::new(io::ErrorKind::InvalidData, "key"))
187 })?;
188
189 Ok(c)
190 } else {
191 trace!("AlpnConnector::with_client_cert no private keys found from the given PEM");
192 Err(io::Error::new(io::ErrorKind::InvalidData, "private key"))
193 }
194 }
195
196 fn with_client_config(config: ConfigBuilder<ClientConfig, WantsCipherSuites>) -> Self {
197 let mut root_cert_store = rustls::RootCertStore::empty();
198
199 root_cert_store.add_server_trust_anchors(
200 webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
201 OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)
202 }),
203 );
204
205 let config = config.with_safe_defaults().with_root_certificates(root_cert_store);
206
207 AlpnConnector {
208 config: None,
209 config_builder: config,
210 }
211 }
212
213 async fn resolve(dst: Uri) -> std::io::Result<net::SocketAddr> {
214 let port = dst.port_u16().unwrap_or(443);
215 let host = dst.host().unwrap_or("localhost").to_string();
216
217 let mut addrs = tokio::task::spawn_blocking(move || (host.as_str(), port).to_socket_addrs())
218 .await
219 .unwrap()
220 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("Couldn't resolve host: {:?}", e)))?;
221
222 addrs.next().ok_or_else(|| {
223 io::Error::new(
224 io::ErrorKind::InvalidInput,
225 "Could not resolve host: no address(es) returned".to_string(),
226 )
227 })
228 }
229}
230
231impl fmt::Debug for AlpnConnector {
232 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
233 f.debug_struct("AlpnConnector").finish()
234 }
235}
236
237impl Service<Uri> for AlpnConnector {
238 type Response = AlpnStream;
239 type Error = io::Error;
240 type Future = AlpnConnecting;
241
242 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
243 Poll::Ready(Ok(()))
244 }
245
246 fn call(&mut self, dst: Uri) -> Self::Future {
247 trace!("AlpnConnector::call ({:?})", dst);
248
249 let host = dst.host().unwrap_or("localhost");
250 let host = match ServerName::try_from(host) {
251 Ok(host) => host,
252 Err(err) => {
253 let err = io::Error::new(io::ErrorKind::InvalidInput, format!("invalid url: {:?}", err));
254
255 return AlpnConnecting(Box::pin(async { Err(err) }));
256 }
257 };
258
259 if self.config.is_none() {
261 self.build_config()
262 }
263
264 let config = self.config.clone().unwrap();
265
266 let fut = async move {
267 let socket = Self::resolve(dst).await?;
268 let tcp = TcpStream::connect(&socket).await?;
269
270 trace!("AlpnConnector::call got TCP, trying TLS");
271
272 let connector = TlsConnector::from(config);
273
274 match connector.connect(host, tcp).await {
275 Ok(tls) => Ok(AlpnStream(tls)),
276 Err(e) => {
277 trace!("AlpnConnector::call got error forming a TLS connection.");
278 Err(io::Error::new(io::ErrorKind::Other, e))
279 }
280 }
281 };
282
283 AlpnConnecting(Box::pin(fut))
284 }
285}
286
/// Type-erased, heap-allocated future resolving to a connected [`AlpnStream`] or an I/O error.
///
/// Marked `Send` so the future can be moved between threads (for example by a multi-threaded
/// executor).
type BoxedFut = Pin<Box<dyn Future<Output = io::Result<AlpnStream>> + Send>>;

/// Future returned by `AlpnConnector::call`; wraps the boxed connection future.
pub struct AlpnConnecting(BoxedFut);
290
291impl Future for AlpnConnecting {
292 type Output = Result<AlpnStream, io::Error>;
293
294 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295 Pin::new(&mut self.0).poll(cx)
296 }
297}
298
299impl fmt::Debug for AlpnConnecting {
300 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
301 f.pad("AlpnConnecting")
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::AlpnConnector;
    use hyper::Uri;
    use std::net::SocketAddr;

    /// An IPv4 literal resolves without DNS and keeps the explicit port.
    #[tokio::test]
    async fn test_resolving_ip_literal_with_port() {
        let dst: Uri = "https://127.0.0.1:8443".parse().unwrap();
        let expected: SocketAddr = "127.0.0.1:8443".parse().unwrap();

        assert_eq!(expected, AlpnConnector::resolve(dst).await.unwrap());
    }

    /// Without an explicit port, resolution falls back to 443.
    #[tokio::test]
    async fn test_resolving_default_port() {
        let dst: Uri = "https://127.0.0.1".parse().unwrap();
        let expected: SocketAddr = "127.0.0.1:443".parse().unwrap();

        assert_eq!(expected, AlpnConnector::resolve(dst).await.unwrap());
    }

    /// Resolves a public host name through the system resolver.
    ///
    /// Ignored by default: it needs network access and pins the host's current IP address,
    /// which can change at any time. Run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_resolving() {
        let dst: Uri = "http://theinstituteforendoticresearch.org:80".parse().unwrap();
        let expected: SocketAddr = "162.213.255.73:80".parse().unwrap();

        assert_eq!(expected, AlpnConnector::resolve(dst).await.unwrap());
    }
}
318}