1use crate::cache::{SessionCache, SessionKey};
2use crate::{key_index, HttpsLayerSettings, MaybeHttpsStream};
3use antidote::Mutex;
4use boring::error::ErrorStack;
5use boring::ssl::{
6 ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslRef,
7 SslSessionCacheMode,
8};
9use http_old::uri::Scheme;
10use hyper_old::client::connect::{Connected, Connection};
11use hyper_old::client::HttpConnector;
12use hyper_old::service::Service;
13use hyper_old::Uri;
14use std::error::Error;
15use std::future::Future;
16use std::net;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::{fmt, io};
21use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
22use tower_layer::Layer;
23
24#[derive(Clone)]
26pub struct HttpsConnector<T> {
27 http: T,
28 inner: Inner,
29}
30
#[cfg(feature = "runtime")]
impl HttpsConnector<HttpConnector> {
    /// Creates a connector from a default [`HttpConnector`] and a default [`HttpsLayer`].
    ///
    /// The inner connector is built with `enforce_http(false)`, since this connector hands it
    /// URIs whose scheme is not plain `http`.
    ///
    /// # Errors
    ///
    /// Returns the [`ErrorStack`] reported by [`HttpsLayer::new`] when the TLS configuration
    /// cannot be built.
    pub fn new() -> Result<Self, ErrorStack> {
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);

        let layer = HttpsLayer::new()?;
        Ok(layer.layer(connector))
    }
}
46
47impl<S, T> HttpsConnector<S>
48where
49 S: Service<Uri, Response = T> + Send,
50 S::Error: Into<Box<dyn Error + Send + Sync>>,
51 S::Future: Unpin + Send + 'static,
52 T: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
53{
54 pub fn with_connector(
58 http: S,
59 ssl: SslConnectorBuilder,
60 ) -> Result<HttpsConnector<S>, ErrorStack> {
61 HttpsLayer::with_connector(ssl).map(|l| l.layer(http))
62 }
63
64 pub fn set_callback<F>(&mut self, callback: F)
70 where
71 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
72 {
73 self.inner.callback = Some(Arc::new(callback));
74 }
75
76 pub fn set_ssl_callback<F>(&mut self, callback: F)
78 where
79 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
80 {
81 self.inner.ssl_callback = Some(Arc::new(callback));
82 }
83}
84
85pub struct HttpsLayer {
87 inner: Inner,
88}
89
90#[derive(Clone)]
91struct Inner {
92 ssl: SslConnector,
93 cache: Arc<Mutex<SessionCache>>,
94 callback: Option<Callback>,
95 ssl_callback: Option<SslCallback>,
96}
97
98type Callback =
99 Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
100type SslCallback = Arc<dyn Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
101
102impl HttpsLayer {
103 pub fn new() -> Result<HttpsLayer, ErrorStack> {
107 let mut ssl = SslConnector::builder(SslMethod::tls())?;
108
109 ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;
110
111 Self::with_connector(ssl)
112 }
113
114 pub fn with_connector(ssl: SslConnectorBuilder) -> Result<HttpsLayer, ErrorStack> {
118 Self::with_connector_and_settings(ssl, Default::default())
119 }
120
121 pub fn with_connector_and_settings(
123 mut ssl: SslConnectorBuilder,
124 settings: HttpsLayerSettings,
125 ) -> Result<HttpsLayer, ErrorStack> {
126 let cache = Arc::new(Mutex::new(SessionCache::with_capacity(
127 settings.session_cache_capacity,
128 )));
129
130 ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);
131
132 ssl.set_new_session_callback({
133 let cache = cache.clone();
134 move |ssl, session| {
135 if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
136 cache.lock().insert(key.clone(), session);
137 }
138 }
139 });
140
141 Ok(HttpsLayer {
142 inner: Inner {
143 ssl: ssl.build(),
144 cache,
145 callback: None,
146 ssl_callback: None,
147 },
148 })
149 }
150
151 pub fn set_callback<F>(&mut self, callback: F)
157 where
158 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
159 {
160 self.inner.callback = Some(Arc::new(callback));
161 }
162
163 pub fn set_ssl_callback<F>(&mut self, callback: F)
165 where
166 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
167 {
168 self.inner.ssl_callback = Some(Arc::new(callback));
169 }
170}
171
172impl<S> Layer<S> for HttpsLayer {
173 type Service = HttpsConnector<S>;
174
175 fn layer(&self, inner: S) -> HttpsConnector<S> {
176 HttpsConnector {
177 http: inner,
178 inner: self.inner.clone(),
179 }
180 }
181}
182
183impl Inner {
184 fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<Ssl, ErrorStack> {
185 let mut conf = self.ssl.configure()?;
186
187 if let Some(ref callback) = self.callback {
188 callback(&mut conf, uri)?;
189 }
190
191 let key = SessionKey {
192 host: host.to_string(),
193 port: uri.port_u16().unwrap_or(443),
194 };
195
196 if let Some(session) = self.cache.lock().get(&key) {
197 unsafe {
198 conf.set_session(&session)?;
199 }
200 }
201
202 let idx = key_index()?;
203 conf.set_ex_data(idx, key);
204
205 let mut ssl = conf.into_ssl(host)?;
206
207 if let Some(ref ssl_callback) = self.ssl_callback {
208 ssl_callback(&mut ssl, uri)?;
209 }
210
211 Ok(ssl)
212 }
213}
214
215impl<S> Service<Uri> for HttpsConnector<S>
216where
217 S: Service<Uri> + Send,
218 S::Error: Into<Box<dyn Error + Send + Sync>>,
219 S::Future: Unpin + Send + 'static,
220 S::Response: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
221{
222 type Response = MaybeHttpsStream<S::Response>;
223 type Error = Box<dyn Error + Sync + Send>;
224 #[allow(clippy::type_complexity)]
225 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
226
227 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
228 self.http.poll_ready(cx).map_err(Into::into)
229 }
230
231 fn call(&mut self, uri: Uri) -> Self::Future {
232 let is_tls_scheme = uri
233 .scheme()
234 .map(|s| s == &Scheme::HTTPS || s.as_str() == "wss")
235 .unwrap_or(false);
236
237 let tls_setup = if is_tls_scheme {
238 Some((self.inner.clone(), uri.clone()))
239 } else {
240 None
241 };
242
243 let connect = self.http.call(uri);
244
245 let f = async {
246 let conn = connect.await.map_err(Into::into)?;
247
248 let (inner, uri) = match tls_setup {
249 Some((inner, uri)) => (inner, uri),
250 None => return Ok(MaybeHttpsStream::Http(conn)),
251 };
252
253 let mut host = uri.host().ok_or("URI missing host")?;
254
255 if !host.is_empty() {
259 let last = host.len() - 1;
260 let mut chars = host.chars();
261
262 if (chars.next(), chars.last()) == (Some('['), Some(']'))
263 && host[1..last].parse::<net::Ipv6Addr>().is_ok()
264 {
265 host = &host[1..last];
266 }
267 }
268
269 let ssl = inner.setup_ssl(&uri, host)?;
270 let stream = tokio_boring::SslStreamBuilder::new(ssl, conn)
271 .connect()
272 .await?;
273
274 Ok(MaybeHttpsStream::Https(stream))
275 };
276
277 Box::pin(f)
278 }
279}
280
281impl<T> Connection for MaybeHttpsStream<T>
282where
283 T: Connection,
284{
285 fn connected(&self) -> Connected {
286 match self {
287 MaybeHttpsStream::Http(s) => s.connected(),
288 MaybeHttpsStream::Https(s) => {
289 let mut connected = s.get_ref().connected();
290
291 if s.ssl().selected_alpn_protocol() == Some(b"h2") {
292 connected = connected.negotiated_h2();
293 }
294
295 connected
296 }
297 }
298 }
299}
300
301impl<T> AsyncRead for MaybeHttpsStream<T>
302where
303 T: AsyncRead + AsyncWrite + Unpin,
304{
305 fn poll_read(
306 mut self: Pin<&mut Self>,
307 ctx: &mut Context<'_>,
308 buf: &mut ReadBuf<'_>,
309 ) -> Poll<io::Result<()>> {
310 match &mut *self {
311 MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(ctx, buf),
312 MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(ctx, buf),
313 }
314 }
315}
316
317impl<T> AsyncWrite for MaybeHttpsStream<T>
318where
319 T: AsyncRead + AsyncWrite + Unpin,
320{
321 fn poll_write(
322 mut self: Pin<&mut Self>,
323 ctx: &mut Context<'_>,
324 buf: &[u8],
325 ) -> Poll<io::Result<usize>> {
326 match &mut *self {
327 MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(ctx, buf),
328 MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(ctx, buf),
329 }
330 }
331
332 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
333 match &mut *self {
334 MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(ctx),
335 MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(ctx),
336 }
337 }
338
339 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
340 match &mut *self {
341 MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(ctx),
342 MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(ctx),
343 }
344 }
345}