1use crate::common::tls_state::TlsState;
2
3use crate::client;
4
5use futures_io::{AsyncRead, AsyncWrite};
6use rustls::{ClientConfig, ClientConnection, OwnedTrustAnchor, RootCertStore, ServerName};
7use std::convert::TryFrom;
8use std::future::Future;
9use std::io;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::task::{Context, Poll};
13
14#[derive(Clone)]
39pub struct TlsConnector {
40 inner: Arc<ClientConfig>,
41 #[cfg(feature = "early-data")]
42 early_data: bool,
43}
44
45impl From<Arc<ClientConfig>> for TlsConnector {
46 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
47 TlsConnector {
48 inner,
49 #[cfg(feature = "early-data")]
50 early_data: false,
51 }
52 }
53}
54
55impl From<ClientConfig> for TlsConnector {
56 fn from(inner: ClientConfig) -> TlsConnector {
57 TlsConnector {
58 inner: Arc::new(inner),
59 #[cfg(feature = "early-data")]
60 early_data: false,
61 }
62 }
63}
64
65impl Default for TlsConnector {
66 fn default() -> Self {
67 let mut root_certs = RootCertStore::empty();
68 root_certs.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
69 OwnedTrustAnchor::from_subject_spki_name_constraints(
70 ta.subject,
71 ta.spki,
72 ta.name_constraints,
73 )
74 }));
75 let config = ClientConfig::builder()
76 .with_safe_defaults()
77 .with_root_certificates(root_certs)
78 .with_no_client_auth();
79 Arc::new(config).into()
80 }
81}
82
83impl TlsConnector {
84 pub fn new() -> Self {
88 Default::default()
89 }
90
91 #[cfg(feature = "early-data")]
95 pub fn early_data(mut self, flag: bool) -> TlsConnector {
96 self.early_data = flag;
97 self
98 }
99
100 #[inline]
106 pub fn connect<'a, IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
107 where
108 IO: AsyncRead + AsyncWrite + Unpin,
109 {
110 self.connect_with(domain, stream, |_| ())
111 }
112
113 fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
116 where
117 IO: AsyncRead + AsyncWrite + Unpin,
118 F: FnOnce(&mut ClientConnection),
119 {
120 let domain = match ServerName::try_from(domain.as_ref()) {
121 Ok(domain) => domain,
122 Err(_) => {
123 return Connect(ConnectInner::Error(Some(io::Error::new(
124 io::ErrorKind::InvalidInput,
125 "invalid domain",
126 ))))
127 }
128 };
129
130 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
131 Ok(session) => session,
132 Err(_) => {
133 return Connect(ConnectInner::Error(Some(io::Error::new(
134 io::ErrorKind::Other,
135 "invalid connection",
136 ))))
137 }
138 };
139
140 f(&mut session);
141
142 #[cfg(not(feature = "early-data"))]
143 {
144 Connect(ConnectInner::Handshake(client::MidHandshake::Handshaking(
145 client::TlsStream {
146 session,
147 io: stream,
148 state: TlsState::Stream,
149 },
150 )))
151 }
152
153 #[cfg(feature = "early-data")]
154 {
155 Connect(ConnectInner::Handshake(if self.early_data {
156 client::MidHandshake::EarlyData(client::TlsStream {
157 session,
158 io: stream,
159 state: TlsState::EarlyData,
160 early_data: (0, Vec::new()),
161 })
162 } else {
163 client::MidHandshake::Handshaking(client::TlsStream {
164 session,
165 io: stream,
166 state: TlsState::Stream,
167 early_data: (0, Vec::new()),
168 })
169 }))
170 }
171 }
172}
173
174pub struct Connect<IO>(ConnectInner<IO>);
177
178enum ConnectInner<IO> {
179 Error(Option<io::Error>),
180 Handshake(client::MidHandshake<IO>),
181}
182
183impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
184 type Output = io::Result<client::TlsStream<IO>>;
185
186 #[inline]
187 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
188 match self.0 {
189 ConnectInner::Error(ref mut err) => {
190 Poll::Ready(Err(err.take().expect("Polled twice after being Ready")))
191 }
192 ConnectInner::Handshake(ref mut handshake) => Pin::new(handshake).poll(cx),
193 }
194 }
195}