1use crate::cache::{SessionCache, SessionKey};
2use crate::{key_index, HttpsLayerSettings, MaybeHttpsStream};
3use antidote::Mutex;
4use boring::error::ErrorStack;
5use boring::ssl::{
6 ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslRef,
7 SslSessionCacheMode,
8};
9use http::uri::Scheme;
10use http::Uri;
11use hyper::rt::{Read, ReadBufCursor, Write};
12use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector};
13use hyper_util::rt::TokioIo;
14use std::error::Error;
15use std::fmt;
16use std::future::Future;
17use std::pin::Pin;
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::{io, net};
21use tokio::io::{AsyncRead, AsyncWrite};
22use tower_layer::Layer;
23use tower_service::Service;
24
25#[derive(Clone)]
27pub struct HttpsConnector<T> {
28 http: T,
29 inner: Inner,
30}
31
32impl HttpsConnector<HttpConnector> {
33 pub fn new() -> Result<HttpsConnector<HttpConnector>, ErrorStack> {
38 let mut http = HttpConnector::new();
39 http.enforce_http(false);
40
41 HttpsLayer::new().map(|l| l.layer(http))
42 }
43}
44
45impl<S, T> HttpsConnector<S>
46where
47 S: Service<Uri, Response = TokioIo<T>> + Send,
48 S::Error: Into<Box<dyn Error + Send + Sync>>,
49 S::Future: Unpin + Send + 'static,
50 T: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
51{
52 pub fn with_connector(
56 http: S,
57 ssl: SslConnectorBuilder,
58 ) -> Result<HttpsConnector<S>, ErrorStack> {
59 HttpsLayer::with_connector(ssl).map(|l| l.layer(http))
60 }
61
62 pub fn set_callback<F>(&mut self, callback: F)
68 where
69 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
70 {
71 self.inner.callback = Some(Arc::new(callback));
72 }
73
74 pub fn set_ssl_callback<F>(&mut self, callback: F)
76 where
77 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
78 {
79 self.inner.ssl_callback = Some(Arc::new(callback));
80 }
81}
82
83pub struct HttpsLayer {
85 inner: Inner,
86}
87
88#[derive(Clone)]
89struct Inner {
90 ssl: SslConnector,
91 cache: Arc<Mutex<SessionCache>>,
92 callback: Option<Callback>,
93 ssl_callback: Option<SslCallback>,
94}
95
96type Callback =
97 Arc<dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
98type SslCallback = Arc<dyn Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + Sync + Send>;
99
100impl HttpsLayer {
101 pub fn new() -> Result<HttpsLayer, ErrorStack> {
105 let mut ssl = SslConnector::builder(SslMethod::tls())?;
106
107 ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;
108
109 Self::with_connector(ssl)
110 }
111
112 pub fn with_connector(ssl: SslConnectorBuilder) -> Result<HttpsLayer, ErrorStack> {
116 Self::with_connector_and_settings(ssl, HttpsLayerSettings::default())
117 }
118
119 pub fn with_connector_and_settings(
121 mut ssl: SslConnectorBuilder,
122 settings: HttpsLayerSettings,
123 ) -> Result<HttpsLayer, ErrorStack> {
124 let cache = Arc::new(Mutex::new(SessionCache::with_capacity(
125 settings.session_cache_capacity,
126 )));
127
128 ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);
129
130 ssl.set_new_session_callback({
131 let cache = cache.clone();
132 move |ssl, session| {
133 if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
134 cache.lock().insert(key.clone(), session);
135 }
136 }
137 });
138
139 Ok(HttpsLayer {
140 inner: Inner {
141 ssl: ssl.build(),
142 cache,
143 callback: None,
144 ssl_callback: None,
145 },
146 })
147 }
148
149 pub fn set_callback<F>(&mut self, callback: F)
155 where
156 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
157 {
158 self.inner.callback = Some(Arc::new(callback));
159 }
160
161 pub fn set_ssl_callback<F>(&mut self, callback: F)
163 where
164 F: Fn(&mut SslRef, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
165 {
166 self.inner.ssl_callback = Some(Arc::new(callback));
167 }
168}
169
170impl<S> Layer<S> for HttpsLayer {
171 type Service = HttpsConnector<S>;
172
173 fn layer(&self, inner: S) -> HttpsConnector<S> {
174 HttpsConnector {
175 http: inner,
176 inner: self.inner.clone(),
177 }
178 }
179}
180
181impl Inner {
182 fn setup_ssl(&self, uri: &Uri, host: &str) -> Result<Ssl, ErrorStack> {
183 let mut conf = self.ssl.configure()?;
184
185 if let Some(ref callback) = self.callback {
186 callback(&mut conf, uri)?;
187 }
188
189 let key = SessionKey {
190 host: host.to_string(),
191 port: uri.port_u16().unwrap_or(443),
192 };
193
194 if let Some(session) = self.cache.lock().get(&key) {
195 unsafe {
196 conf.set_session(&session)?;
197 }
198 }
199
200 let idx = key_index()?;
201 conf.set_ex_data(idx, key);
202
203 let mut ssl = conf.into_ssl(host)?;
204
205 if let Some(ref ssl_callback) = self.ssl_callback {
206 ssl_callback(&mut ssl, uri)?;
207 }
208
209 Ok(ssl)
210 }
211}
212
213impl<T, S> Service<Uri> for HttpsConnector<S>
214where
215 S: Service<Uri, Response = TokioIo<T>> + Send,
216 S::Error: Into<Box<dyn Error + Send + Sync>>,
217 S::Future: Unpin + Send + 'static,
218 T: AsyncRead + AsyncWrite + Connection + Unpin + fmt::Debug + Sync + Send + 'static,
219{
220 type Response = MaybeHttpsStream<T>;
221 type Error = Box<dyn Error + Sync + Send>;
222 #[allow(clippy::type_complexity)]
223 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
224
225 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226 self.http.poll_ready(cx).map_err(Into::into)
227 }
228
229 fn call(&mut self, uri: Uri) -> Self::Future {
230 let is_tls_scheme = uri
231 .scheme()
232 .map(|s| s == &Scheme::HTTPS || s.as_str() == "wss")
233 .unwrap_or(false);
234
235 let tls_setup = if is_tls_scheme {
236 Some((self.inner.clone(), uri.clone()))
237 } else {
238 None
239 };
240
241 let connect = self.http.call(uri);
242
243 let f = async {
244 let conn = connect.await.map_err(Into::into)?.into_inner();
245
246 let Some((inner, uri)) = tls_setup else {
247 return Ok(MaybeHttpsStream::Http(conn));
248 };
249
250 let mut host = uri.host().ok_or("URI missing host")?;
251
252 if let Some(ipv6) = host
256 .strip_prefix('[')
257 .and_then(|h| h.strip_suffix(']'))
258 .filter(|h| h.parse::<net::Ipv6Addr>().is_ok())
259 {
260 host = ipv6;
261 }
262
263 let ssl = inner.setup_ssl(&uri, host)?;
264 let stream = tokio_boring::SslStreamBuilder::new(ssl, conn)
265 .connect()
266 .await?;
267
268 Ok(MaybeHttpsStream::Https(stream))
269 };
270
271 Box::pin(f)
272 }
273}
274
275impl<T> Connection for MaybeHttpsStream<T>
276where
277 T: Connection,
278{
279 fn connected(&self) -> Connected {
280 match self {
281 MaybeHttpsStream::Http(s) => s.connected(),
282 MaybeHttpsStream::Https(s) => {
283 let mut connected = s.get_ref().connected();
284
285 if s.ssl().selected_alpn_protocol() == Some(b"h2") {
286 connected = connected.negotiated_h2();
287 }
288
289 connected
290 }
291 }
292 }
293}
294
295impl<T> Read for MaybeHttpsStream<T>
296where
297 T: AsyncRead + AsyncWrite + Unpin,
298{
299 fn poll_read(
300 mut self: Pin<&mut Self>,
301 cx: &mut Context<'_>,
302 buf: ReadBufCursor<'_>,
303 ) -> Poll<Result<(), std::io::Error>> {
304 match &mut *self {
305 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_read(cx, buf),
306 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_read(cx, buf),
307 }
308 }
309}
310
311impl<T> Write for MaybeHttpsStream<T>
312where
313 T: AsyncRead + AsyncWrite + Unpin,
314{
315 fn poll_write(
316 mut self: Pin<&mut Self>,
317 ctx: &mut Context<'_>,
318 buf: &[u8],
319 ) -> Poll<io::Result<usize>> {
320 match &mut *self {
321 MaybeHttpsStream::Http(inner) => {
322 Pin::new(&mut TokioIo::new(inner)).poll_write(ctx, buf)
323 }
324 MaybeHttpsStream::Https(inner) => {
325 Pin::new(&mut TokioIo::new(inner)).poll_write(ctx, buf)
326 }
327 }
328 }
329
330 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
331 match &mut *self {
332 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_flush(ctx),
333 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_flush(ctx),
334 }
335 }
336
337 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
338 match &mut *self {
339 MaybeHttpsStream::Http(inner) => Pin::new(&mut TokioIo::new(inner)).poll_shutdown(ctx),
340 MaybeHttpsStream::Https(inner) => Pin::new(&mut TokioIo::new(inner)).poll_shutdown(ctx),
341 }
342 }
343}