1use super::TlsConnectorData;
2use crate::dep::tokio_rustls::{TlsConnector as RustlsConnector, client::TlsStream};
3use crate::types::TlsTunnel;
4use crate::{RamaInto, RamaTryFrom};
5use pin_project_lite::pin_project;
6use private::{ConnectorKindAuto, ConnectorKindSecure, ConnectorKindTunnel};
7use rama_core::error::ErrorContext;
8use rama_core::error::{BoxError, ErrorExt, OpaqueError};
9use rama_core::{Context, Layer, Service};
10use rama_net::address::Host;
11use rama_net::client::{ConnectorService, EstablishedClientConnection};
12use rama_net::stream::Stream;
13use rama_net::tls::ApplicationProtocol;
14use rama_net::tls::client::NegotiatedTlsParameters;
15use rama_net::transport::TryRefIntoTransportContext;
16use std::fmt;
17use tokio::io::{AsyncRead, AsyncWrite};
18
/// A [`Layer`] which wraps an inner connector service into a [`TlsConnector`].
///
/// The type parameter `K` selects the connector kind (`ConnectorKindAuto` by
/// default, otherwise `ConnectorKindSecure` or `ConnectorKindTunnel`), which
/// decides when the inner connection gets secured.
pub struct TlsConnectorLayer<K = ConnectorKindAuto> {
    /// Connector data handed to every [`TlsConnector`] produced by this layer.
    ///
    /// It is used for the handshake unless the request context carries its own
    /// [`TlsConnectorData`]; when neither is available a default is created.
    connector_data: Option<TlsConnectorData>,
    /// The connector kind, cloned into every produced [`TlsConnector`].
    kind: K,
}
26
27impl<K: fmt::Debug> std::fmt::Debug for TlsConnectorLayer<K> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("TlsConnectorLayer")
30 .field("connector_data", &self.connector_data)
31 .field("kind", &self.kind)
32 .finish()
33 }
34}
35
36impl<K: Clone> Clone for TlsConnectorLayer<K> {
37 fn clone(&self) -> Self {
38 Self {
39 connector_data: self.connector_data.clone(),
40 kind: self.kind.clone(),
41 }
42 }
43}
44
45impl<K> TlsConnectorLayer<K> {
46 pub fn with_connector_data(mut self, connector_data: TlsConnectorData) -> Self {
49 self.connector_data = Some(connector_data);
50 self
51 }
52
53 pub fn maybe_with_connector_data(mut self, connector_data: Option<TlsConnectorData>) -> Self {
56 self.connector_data = connector_data;
57 self
58 }
59
60 pub fn set_connector_data(&mut self, connector_data: TlsConnectorData) -> &mut Self {
63 self.connector_data = Some(connector_data);
64 self
65 }
66}
67
68impl TlsConnectorLayer<ConnectorKindAuto> {
69 pub fn auto() -> Self {
73 Self {
74 connector_data: None,
75 kind: ConnectorKindAuto,
76 }
77 }
78}
79
80impl TlsConnectorLayer<ConnectorKindSecure> {
81 pub fn secure() -> Self {
84 Self {
85 connector_data: None,
86 kind: ConnectorKindSecure,
87 }
88 }
89}
90
91impl TlsConnectorLayer<ConnectorKindTunnel> {
92 pub fn tunnel(host: Option<Host>) -> Self {
95 Self {
96 connector_data: None,
97 kind: ConnectorKindTunnel { host },
98 }
99 }
100}
101
102impl<K: Clone, S> Layer<S> for TlsConnectorLayer<K> {
103 type Service = TlsConnector<S, K>;
104
105 fn layer(&self, inner: S) -> Self::Service {
106 TlsConnector {
107 inner,
108 connector_data: self.connector_data.clone(),
109 kind: self.kind.clone(),
110 }
111 }
112
113 fn into_layer(self, inner: S) -> Self::Service {
114 TlsConnector {
115 inner,
116 connector_data: self.connector_data,
117 kind: self.kind,
118 }
119 }
120}
121
122impl Default for TlsConnectorLayer<ConnectorKindAuto> {
123 fn default() -> Self {
124 Self::auto()
125 }
126}
127
/// A connector service that establishes a connection through its inner
/// connector and, depending on the connector kind `K`, secures it with a
/// rustls-based TLS handshake.
pub struct TlsConnector<S, K = ConnectorKindAuto> {
    /// The inner connector service producing the (plain) transport connection.
    inner: S,
    /// Default connector data, used unless the request context provides its own
    /// [`TlsConnectorData`]; when neither is available a default is created.
    connector_data: Option<TlsConnectorData>,
    /// The connector kind, deciding when the connection gets secured.
    kind: K,
}
140
141impl<S: fmt::Debug, K: fmt::Debug> fmt::Debug for TlsConnector<S, K> {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 f.debug_struct("TlsConnector")
144 .field("inner", &self.inner)
145 .field("connector_data", &self.connector_data)
146 .field("kind", &self.kind)
147 .finish()
148 }
149}
150
151impl<S: Clone, K: Clone> Clone for TlsConnector<S, K> {
152 fn clone(&self) -> Self {
153 Self {
154 inner: self.inner.clone(),
155 connector_data: self.connector_data.clone(),
156 kind: self.kind.clone(),
157 }
158 }
159}
160
161impl<S, K> TlsConnector<S, K> {
162 pub const fn new(inner: S, kind: K) -> Self {
164 Self {
165 inner,
166 connector_data: None,
167 kind,
168 }
169 }
170
171 pub fn with_connector_data(mut self, connector_data: TlsConnectorData) -> Self {
180 self.connector_data = Some(connector_data);
181 self
182 }
183
184 pub fn maybe_with_connector_data(mut self, connector_data: Option<TlsConnectorData>) -> Self {
187 self.connector_data = connector_data;
188 self
189 }
190
191 pub fn set_connector_data(&mut self, connector_data: TlsConnectorData) -> &mut Self {
194 self.connector_data = Some(connector_data);
195 self
196 }
197}
198
199impl<S> TlsConnector<S, ConnectorKindAuto> {
200 pub fn auto(inner: S) -> Self {
204 Self::new(inner, ConnectorKindAuto)
205 }
206}
207
208impl<S> TlsConnector<S, ConnectorKindSecure> {
209 pub fn secure(inner: S) -> Self {
212 Self::new(inner, ConnectorKindSecure)
213 }
214}
215
216impl<S> TlsConnector<S, ConnectorKindTunnel> {
217 pub fn tunnel(inner: S, host: Option<Host>) -> Self {
220 Self::new(inner, ConnectorKindTunnel { host })
221 }
222}
223
224impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindAuto>
227where
228 S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
229 State: Clone + Send + Sync + 'static,
230 Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
231 + Send
232 + 'static,
233{
234 type Response = EstablishedClientConnection<AutoTlsStream<S::Connection>, State, Request>;
235 type Error = BoxError;
236
237 async fn serve(
238 &self,
239 ctx: Context<State>,
240 req: Request,
241 ) -> Result<Self::Response, Self::Error> {
242 let EstablishedClientConnection { mut ctx, req, conn } =
243 self.inner.connect(ctx, req).await.map_err(Into::into)?;
244 let transport_ctx = ctx
245 .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
246 .map_err(|err| {
247 OpaqueError::from_boxed(err.into())
248 .context("TlsConnector(auto): compute transport context")
249 })?
250 .clone();
251
252 if !transport_ctx
253 .app_protocol
254 .as_ref()
255 .map(|p| p.is_secure())
256 .unwrap_or_default()
257 {
258 tracing::trace!(
259 authority = %transport_ctx.authority,
260 "TlsConnector(auto): protocol not secure, return inner connection",
261 );
262 return Ok(EstablishedClientConnection {
263 ctx,
264 req,
265 conn: AutoTlsStream {
266 inner: AutoTlsStreamData::Plain { inner: conn },
267 },
268 });
269 }
270
271 let server_host = transport_ctx.authority.host().clone();
272
273 tracing::trace!(
274 authority = %transport_ctx.authority,
275 app_protocol = ?transport_ctx.app_protocol,
276 "TlsConnector(auto): attempt to secure inner connection",
277 );
278
279 let connector_data = ctx.get::<TlsConnectorData>().cloned();
280 let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
281
282 tracing::trace!(
283 authority = %transport_ctx.authority,
284 app_protocol = ?transport_ctx.app_protocol,
285 "TlsConnector(auto): protocol secure, established tls connection",
286 );
287
288 ctx.insert(negotiated_params);
289
290 Ok(EstablishedClientConnection {
291 ctx,
292 req,
293 conn: AutoTlsStream {
294 inner: AutoTlsStreamData::Secure { inner: stream },
295 },
296 })
297 }
298}
299
300impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindSecure>
301where
302 S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
303 State: Clone + Send + Sync + 'static,
304 Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
305 + Send
306 + 'static,
307{
308 type Response = EstablishedClientConnection<TlsStream<S::Connection>, State, Request>;
309 type Error = BoxError;
310
311 async fn serve(
312 &self,
313 ctx: Context<State>,
314 req: Request,
315 ) -> Result<Self::Response, Self::Error> {
316 let EstablishedClientConnection { mut ctx, req, conn } =
317 self.inner.connect(ctx, req).await.map_err(Into::into)?;
318
319 let transport_ctx = ctx
320 .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
321 .map_err(|err| {
322 OpaqueError::from_boxed(err.into())
323 .context("TlsConnector(auto): compute transport context")
324 })?;
325 tracing::trace!(
326 authority = %transport_ctx.authority,
327 app_protocol = ?transport_ctx.app_protocol,
328 "TlsConnector(secure): attempt to secure inner connection",
329 );
330
331 let server_host = transport_ctx.authority.host().clone();
332
333 let connector_data = ctx.get::<TlsConnectorData>().cloned();
334 let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
335 ctx.insert(negotiated_params);
336
337 Ok(EstablishedClientConnection { ctx, req, conn })
338 }
339}
340
341impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindTunnel>
342where
343 S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
344 State: Clone + Send + Sync + 'static,
345 Request: Send + 'static,
346{
347 type Response = EstablishedClientConnection<AutoTlsStream<S::Connection>, State, Request>;
348 type Error = BoxError;
349
350 async fn serve(
351 &self,
352 ctx: Context<State>,
353 req: Request,
354 ) -> Result<Self::Response, Self::Error> {
355 let EstablishedClientConnection { mut ctx, req, conn } =
356 self.inner.connect(ctx, req).await.map_err(Into::into)?;
357
358 let server_host = match ctx
359 .get::<TlsTunnel>()
360 .as_ref()
361 .map(|t| &t.server_host)
362 .or(self.kind.host.as_ref())
363 {
364 Some(host) => host.clone(),
365 None => {
366 tracing::trace!(
367 "TlsConnector(tunnel): return inner connection: no Tls tunnel is requested"
368 );
369 return Ok(EstablishedClientConnection {
370 ctx,
371 req,
372 conn: AutoTlsStream {
373 inner: AutoTlsStreamData::Plain { inner: conn },
374 },
375 });
376 }
377 };
378
379 let connector_data = ctx.get::<TlsConnectorData>().cloned();
380 let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
381 ctx.insert(negotiated_params);
382
383 tracing::trace!("TlsConnector(tunnel): connection secured");
384 Ok(EstablishedClientConnection {
385 ctx,
386 req,
387 conn: AutoTlsStream {
388 inner: AutoTlsStreamData::Secure { inner: conn },
389 },
390 })
391 }
392}
393
394impl<S, K> TlsConnector<S, K> {
395 async fn handshake<T>(
396 &self,
397 connector_data: Option<TlsConnectorData>,
398 server_host: Host,
399 stream: T,
400 ) -> Result<(TlsStream<T>, NegotiatedTlsParameters), BoxError>
401 where
402 T: Stream + Unpin,
403 {
404 let connector_data = connector_data
405 .or(self.connector_data.clone())
406 .unwrap_or(TlsConnectorData::new_http_auto()?);
407
408 let server_name = rustls_pki_types::ServerName::rama_try_from(
409 connector_data.server_name.unwrap_or(server_host),
410 )?;
411
412 let connector = RustlsConnector::from(connector_data.client_config);
413
414 let stream = connector.connect(server_name, stream).await?;
415
416 let (_, conn_data_ref) = stream.get_ref();
417
418 let server_certificate_chain = if connector_data.store_server_certificate_chain {
419 conn_data_ref.peer_certificates().map(RamaInto::rama_into)
420 } else {
421 None
422 };
423
424 let params = NegotiatedTlsParameters {
425 protocol_version: conn_data_ref
426 .protocol_version()
427 .context("no protocol version available")?
428 .rama_into(),
429 application_layer_protocol: conn_data_ref
430 .alpn_protocol()
431 .map(ApplicationProtocol::from),
432 peer_certificate_chain: server_certificate_chain,
433 };
434
435 Ok((stream, params))
436 }
437}
438
pin_project! {
    /// A connection stream that is either secured with TLS or left plain,
    /// as returned by the auto and tunnel [`TlsConnector`] kinds.
    pub struct AutoTlsStream<S> {
        // Either the TLS-wrapped or the plain inner stream; pinned so that
        // I/O calls can be projected onto whichever variant is active.
        #[pin]
        inner: AutoTlsStreamData<S>,
    }
}
446
447impl<S: fmt::Debug> fmt::Debug for AutoTlsStream<S> {
448 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449 f.debug_struct("AutoTlsStream")
450 .field("inner", &self.inner)
451 .finish()
452 }
453}
454
pin_project! {
    // Internal state of an `AutoTlsStream`: the connection is either secured
    // with TLS or passed through as-is. `AutoTlsStreamDataProj` is the pinned
    // projection used to forward I/O calls to the active variant.
    #[project = AutoTlsStreamDataProj]
    enum AutoTlsStreamData<S> {
        // The inner stream wrapped in a rustls client `TlsStream`.
        Secure{ #[pin] inner: TlsStream<S> },
        // The inner stream, left untouched.
        Plain { #[pin] inner: S },
    }
}
465
466impl<S: fmt::Debug> fmt::Debug for AutoTlsStreamData<S> {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 match self {
469 AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(),
470 AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(),
471 }
472 }
473}
474
475impl<S> AsyncRead for AutoTlsStream<S>
476where
477 S: Stream + Unpin,
478{
479 fn poll_read(
480 self: std::pin::Pin<&mut Self>,
481 cx: &mut std::task::Context<'_>,
482 buf: &mut tokio::io::ReadBuf<'_>,
483 ) -> std::task::Poll<std::io::Result<()>> {
484 match self.project().inner.project() {
485 AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf),
486 AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf),
487 }
488 }
489}
490
491impl<S> AsyncWrite for AutoTlsStream<S>
492where
493 S: Stream + Unpin,
494{
495 fn poll_write(
496 self: std::pin::Pin<&mut Self>,
497 cx: &mut std::task::Context<'_>,
498 buf: &[u8],
499 ) -> std::task::Poll<Result<usize, std::io::Error>> {
500 match self.project().inner.project() {
501 AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf),
502 AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf),
503 }
504 }
505
506 fn poll_flush(
507 self: std::pin::Pin<&mut Self>,
508 cx: &mut std::task::Context<'_>,
509 ) -> std::task::Poll<Result<(), std::io::Error>> {
510 match self.project().inner.project() {
511 AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx),
512 AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx),
513 }
514 }
515
516 fn poll_shutdown(
517 self: std::pin::Pin<&mut Self>,
518 cx: &mut std::task::Context<'_>,
519 ) -> std::task::Poll<Result<(), std::io::Error>> {
520 match self.project().inner.project() {
521 AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx),
522 AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx),
523 }
524 }
525}
526
// Connector kind markers. They are `pub` inside a private module so they can
// appear as type parameters of the public connector types while only the
// constructors in this file create them.
mod private {
    use rama_net::address::Host;

    /// Connector kind: secure the connection only when the transport context's
    /// application protocol reports itself as secure; otherwise the inner
    /// connection is returned as a plain stream.
    #[derive(Debug, Clone)]
    pub struct ConnectorKindAuto;

    /// Connector kind: always perform a TLS handshake on the inner connection.
    #[derive(Debug, Clone)]
    pub struct ConnectorKindSecure;

    /// Connector kind: establish a TLS tunnel when a server host is known,
    /// otherwise return the inner connection as a plain stream.
    #[derive(Debug, Clone)]
    pub struct ConnectorKindTunnel {
        /// Fallback server host, used when the request context carries no
        /// `TlsTunnel` (which takes precedence when present).
        pub host: Option<Host>,
    }
}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts, via the `rama_utils` test helper, that the default
    /// `TlsConnectorLayer` is `Send`.
    #[test]
    fn assert_send() {
        // Imported locally: a module-level import would clash with this fn's name.
        use rama_utils::test_helpers::assert_send;

        assert_send::<TlsConnectorLayer>();
    }

    /// Asserts, via the `rama_utils` test helper, that the default
    /// `TlsConnectorLayer` is `Sync`.
    #[test]
    fn assert_sync() {
        // Imported locally: a module-level import would clash with this fn's name.
        use rama_utils::test_helpers::assert_sync;

        assert_sync::<TlsConnectorLayer>();
    }
}