1use crate::cache::{SessionCache, SessionKey};
2use crate::{key_index, HttpsLayerSettings, MaybeHttpsStream};
3use antidote::Mutex;
4use boring::error::ErrorStack;
5use boring::ssl::{
6 ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslRef,
7 SslSessionCacheMode,
8};
9use http1::uri::Scheme;
10use http1::Uri;
11use hyper1::rt::{Read, ReadBufCursor, Write};
12use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
13use hyper_util::rt::TokioIo;
14use std::error::Error;
15use std::fmt;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::{io, net};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tokio::net::TcpStream;
23#[cfg(all(feature = "runtime", feature = "hyper1-runtime"))]
24use tower::util::MapResponse;
25#[cfg(all(feature = "runtime", feature = "hyper1-runtime"))]
26use tower::ServiceExt;
27use tower_layer::Layer;
28use tower_service::Service;
29
30#[derive(Clone)]
32pub struct HttpsConnector<T> {
33 http: T,
34 inner: Inner,
35}
36
/// Specialized [`HttpConnector`] whose streams are wrapped once more in
/// [`TokioIo`], so they satisfy the Tokio `AsyncRead`/`AsyncWrite` bounds that
/// [`HttpsConnector`] requires of its inner connector's responses.
///
/// Gated on the same features as the `MapResponse` import it is built from;
/// without this gate the alias cannot resolve when those features are disabled.
#[cfg(all(feature = "runtime", feature = "hyper1-runtime"))]
pub type TokioHttpConnector =
    MapResponse<HttpConnector, fn(TokioIo<TcpStream>) -> TokioIo<TokioIo<TcpStream>>>;
41
#[cfg(all(feature = "runtime", feature = "hyper1-runtime"))]
impl HttpsConnector<TokioHttpConnector> {
    /// Creates an `HttpsConnector` on top of hyper-util's `HttpConnector`,
    /// using the default [`HttpsLayer`] configuration.
    ///
    /// `enforce_http(false)` is set so the TCP connector also accepts
    /// non-`http` schemes such as `https` and `wss`. Its streams are wrapped in
    /// `TokioIo` to regain Tokio I/O trait implementations.
    ///
    /// Requires the `runtime` and `hyper1-runtime` features.
    ///
    /// # Errors
    ///
    /// Returns the BoringSSL error stack if the TLS connector cannot be built.
    pub fn new() -> Result<Self, ErrorStack> {
        let mut tcp = HttpConnector::new();
        tcp.enforce_http(false);

        let layer = HttpsLayer::new()?;
        let tokio_tcp: TokioHttpConnector = tcp.map_response(TokioIo::new as _);

        Ok(layer.layer(tokio_tcp))
    }
}
57
58impl<S, T> HttpsConnector<S>
59where
60 S: Service<Uri, Response = T> + Send,
61 S::Error: Into<Box<dyn Error + Send + Sync>>,
62 S::Future: Unpin + Send + 'static,
63 T: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
64{
65 pub fn with_connector(
73 http: S,
74 ssl: SslConnectorBuilder,
75 ) -> Result<HttpsConnector<S>, ErrorStack> {
76 HttpsLayer::with_connector(ssl).map(|l| l.layer(http))
77 }
78
79 pub fn set_callback<F>(&mut self, callback: F)
85 where
86 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
87 {
88 self.inner.callback = Some(Arc::new(callback));
89 }
90
91 pub fn set_ssl_callback<F>(&mut self, callback: F)
93 where
94 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
95 {
96 self.inner.ssl_callback = Some(Arc::new(callback));
97 }
98}
99
100pub struct HttpsLayer {
102 inner: Inner,
103}
104
105#[derive(Clone)]
106struct Inner {
107 ssl: SslConnector,
108 cache: Arc<Mutex<SessionCache>>,
109 callback: Option<Callback>,
110 ssl_callback: Option<SslCallback>,
111}
112
113type Callback =
114 Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
115type SslCallback = Arc<dyn Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
116
117impl HttpsLayer {
118 pub fn new() -> Result<HttpsLayer, ErrorStack> {
122 let mut ssl = SslConnector::builder(SslMethod::tls())?;
123
124 ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;
125
126 Self::with_connector(ssl)
127 }
128
129 pub fn with_connector(ssl: SslConnectorBuilder) -> Result<HttpsLayer, ErrorStack> {
133 Self::with_connector_and_settings(ssl, Default::default())
134 }
135
136 pub fn with_connector_and_settings(
138 mut ssl: SslConnectorBuilder,
139 settings: HttpsLayerSettings,
140 ) -> Result<HttpsLayer, ErrorStack> {
141 let cache = Arc::new(Mutex::new(SessionCache::with_capacity(
142 settings.session_cache_capacity,
143 )));
144
145 ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);
146
147 ssl.set_new_session_callback({
148 let cache = cache.clone();
149 move |ssl, session| {
150 if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
151 cache.lock().insert(key.clone(), session);
152 }
153 }
154 });
155
156 Ok(HttpsLayer {
157 inner: Inner {
158 ssl: ssl.build(),
159 cache,
160 callback: None,
161 ssl_callback: None,
162 },
163 })
164 }
165
166 pub fn set_callback<F>(&mut self, callback: F)
172 where
173 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
174 {
175 self.inner.callback = Some(Arc::new(callback));
176 }
177
178 pub fn set_ssl_callback<F>(&mut self, callback: F)
180 where
181 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
182 {
183 self.inner.ssl_callback = Some(Arc::new(callback));
184 }
185}
186
187impl<S> Layer<S> for HttpsLayer {
188 type Service = HttpsConnector<S>;
189
190 fn layer(&self, inner: S) -> HttpsConnector<S> {
191 HttpsConnector {
192 http: inner,
193 inner: self.inner.clone(),
194 }
195 }
196}
197
198impl Inner {
199 fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<Ssl, ErrorStack> {
200 let mut conf = self.ssl.configure()?;
201
202 if let Some(ref callback) = self.callback {
203 callback(&mut conf, uri)?;
204 }
205
206 let key = SessionKey {
207 host: host.to_string(),
208 port: uri.port_u16().unwrap_or(443),
209 };
210
211 if let Some(session) = self.cache.lock().get(&key) {
212 unsafe {
213 conf.set_session(&session)?;
214 }
215 }
216
217 let idx = key_index()?;
218 conf.set_ex_data(idx, key);
219
220 let mut ssl = conf.into_ssl(host)?;
221
222 if let Some(ref ssl_callback) = self.ssl_callback {
223 ssl_callback(&mut ssl, uri)?;
224 }
225
226 Ok(ssl)
227 }
228}
229
230impl<T, S> Service<Uri> for HttpsConnector<S>
231where
232 S: Service<Uri, Response = T> + Send,
233 S::Error: Into<Box<dyn Error + Send + Sync>>,
234 S::Future: Unpin + Send + 'static,
235 T: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
236{
237 type Response = MaybeHttpsStream<T>;
238 type Error = Box<dyn Error + Sync + Send>;
239 #[allow(clippy::type_complexity)]
240 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
241
242 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
243 self.http.poll_ready(cx).map_err(Into::into)
244 }
245
246 fn call(&mut self, uri: Uri) -> Self::Future {
247 let is_tls_scheme = uri
248 .scheme()
249 .map(|s| s == &Scheme::HTTPS || s.as_str() == "wss")
250 .unwrap_or(false);
251
252 let tls_setup = if is_tls_scheme {
253 Some((self.inner.clone(), uri.clone()))
254 } else {
255 None
256 };
257
258 let connect = self.http.call(uri);
259
260 let f = async {
261 let conn = connect.await.map_err(Into::into)?;
262
263 let (inner, uri) = match tls_setup {
264 Some((inner, uri)) => (inner, uri),
265 None => return Ok(MaybeHttpsStream::Http(conn)),
266 };
267
268 let mut host = uri.host().ok_or("URI missing host")?;
269
270 if !host.is_empty() {
274 let last = host.len() - 1;
275 let mut chars = host.chars();
276
277 if let (Some('['), Some(']')) = (chars.next(), chars.last()) {
278 if host[1..last].parse::<net::Ipv6Addr>().is_ok() {
279 host = &host[1..last];
280 }
281 }
282 }
283
284 let ssl = inner.setup_ssl(&uri, host)?;
285 let stream = tokio_boring::SslStreamBuilder::new(ssl, conn)
286 .connect()
287 .await?;
288
289 Ok(MaybeHttpsStream::Https(stream))
290 };
291
292 Box::pin(f)
293 }
294}
295
296impl<T> Connection for MaybeHttpsStream<T>
297where
298 T: Connection,
299{
300 fn connected(&self) -> Connected {
301 match self {
302 MaybeHttpsStream::Http(s) => s.connected(),
303 MaybeHttpsStream::Https(s) => {
304 let mut connected = s.get_ref().connected();
305
306 if s.ssl().selected_alpn_protocol() == Some(b"h2") {
307 connected = connected.negotiated_h2();
308 }
309
310 connected
311 }
312 }
313 }
314}
315
316impl<T> Read for MaybeHttpsStream<T>
317where
318 T: AsyncRead + AsyncWrite + Unpin,
319{
320 fn poll_read(
321 mut self: Pin<&mut Self>,
322 cx: &mut Context<'_>,
323 buf: ReadBufCursor<'_>,
324 ) -> Poll<Result<(), std::io::Error>> {
325 match &mut *self {
326 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_read(cx, buf),
327 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_read(cx, buf),
328 }
329 }
330}
331
332impl<T> Write for MaybeHttpsStream<T>
333where
334 T: AsyncRead + AsyncWrite + Unpin,
335{
336 fn poll_write(
337 mut self: Pin<&mut Self>,
338 ctx: &mut Context<'_>,
339 buf: &[u8],
340 ) -> Poll<io::Result<usize>> {
341 match &mut *self {
342 MaybeHttpsStream::Http(inner) => {
343 Pin::new(&mut TokioIo::new(inner)).poll_write(ctx, buf)
344 }
345 MaybeHttpsStream::Https(inner) => {
346 Pin::new(&mut TokioIo::new(inner)).poll_write(ctx, buf)
347 }
348 }
349 }
350
351 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
352 match &mut *self {
353 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_flush(ctx),
354 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_flush(ctx),
355 }
356 }
357
358 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
359 match &mut *self {
360 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_shutdown(ctx),
361 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_shutdown(ctx),
362 }
363 }
364}