rama_tls_rustls/client/
connector.rs

1use super::TlsConnectorData;
2use crate::dep::tokio_rustls::{TlsConnector as RustlsConnector, client::TlsStream};
3use crate::types::TlsTunnel;
4use crate::{RamaInto, RamaTryFrom};
5use pin_project_lite::pin_project;
6use private::{ConnectorKindAuto, ConnectorKindSecure, ConnectorKindTunnel};
7use rama_core::error::ErrorContext;
8use rama_core::error::{BoxError, ErrorExt, OpaqueError};
9use rama_core::{Context, Layer, Service};
10use rama_net::address::Host;
11use rama_net::client::{ConnectorService, EstablishedClientConnection};
12use rama_net::stream::Stream;
13use rama_net::tls::ApplicationProtocol;
14use rama_net::tls::client::NegotiatedTlsParameters;
15use rama_net::transport::TryRefIntoTransportContext;
16use std::fmt;
17use tokio::io::{AsyncRead, AsyncWrite};
18
19/// A [`Layer`] which wraps the given service with a [`TlsConnector`].
20///
21/// See [`TlsConnector`] for more information.
22pub struct TlsConnectorLayer<K = ConnectorKindAuto> {
23    connector_data: Option<TlsConnectorData>,
24    kind: K,
25}
26
27impl<K: fmt::Debug> std::fmt::Debug for TlsConnectorLayer<K> {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("TlsConnectorLayer")
30            .field("connector_data", &self.connector_data)
31            .field("kind", &self.kind)
32            .finish()
33    }
34}
35
36impl<K: Clone> Clone for TlsConnectorLayer<K> {
37    fn clone(&self) -> Self {
38        Self {
39            connector_data: self.connector_data.clone(),
40            kind: self.kind.clone(),
41        }
42    }
43}
44
45impl<K> TlsConnectorLayer<K> {
46    /// Attach [`TlsConnectorData`] to this [`TlsConnectorLayer`],
47    /// to be used instead of a globally shared [`TlsConnectorData::default`].
48    pub fn with_connector_data(mut self, connector_data: TlsConnectorData) -> Self {
49        self.connector_data = Some(connector_data);
50        self
51    }
52
53    /// Maybe attach [`TlsConnectorData`] to this [`TlsConnectorLayer`],
54    /// to be used if `Some` instead of a globally shared [`TlsConnectorData::default`].
55    pub fn maybe_with_connector_data(mut self, connector_data: Option<TlsConnectorData>) -> Self {
56        self.connector_data = connector_data;
57        self
58    }
59
60    /// Attach [`TlsConnectorData`] to this [`TlsConnectorLayer`],
61    /// to be used instead of a globally shared default client config.
62    pub fn set_connector_data(&mut self, connector_data: TlsConnectorData) -> &mut Self {
63        self.connector_data = Some(connector_data);
64        self
65    }
66}
67
68impl TlsConnectorLayer<ConnectorKindAuto> {
69    /// Creates a new [`TlsConnectorLayer`] which will establish
70    /// a secure connection if the request demands it,
71    /// otherwise it will forward the pre-established inner connection.
72    pub fn auto() -> Self {
73        Self {
74            connector_data: None,
75            kind: ConnectorKindAuto,
76        }
77    }
78}
79
80impl TlsConnectorLayer<ConnectorKindSecure> {
81    /// Creates a new [`TlsConnectorLayer`] which will always
82    /// establish a secure connection regardless of the request it is for.
83    pub fn secure() -> Self {
84        Self {
85            connector_data: None,
86            kind: ConnectorKindSecure,
87        }
88    }
89}
90
91impl TlsConnectorLayer<ConnectorKindTunnel> {
92    /// Creates a new [`TlsConnectorLayer`] which will establish
93    /// a secure connection if the request is to be tunneled.
94    pub fn tunnel(host: Option<Host>) -> Self {
95        Self {
96            connector_data: None,
97            kind: ConnectorKindTunnel { host },
98        }
99    }
100}
101
102impl<K: Clone, S> Layer<S> for TlsConnectorLayer<K> {
103    type Service = TlsConnector<S, K>;
104
105    fn layer(&self, inner: S) -> Self::Service {
106        TlsConnector {
107            inner,
108            connector_data: self.connector_data.clone(),
109            kind: self.kind.clone(),
110        }
111    }
112
113    fn into_layer(self, inner: S) -> Self::Service {
114        TlsConnector {
115            inner,
116            connector_data: self.connector_data,
117            kind: self.kind,
118        }
119    }
120}
121
122impl Default for TlsConnectorLayer<ConnectorKindAuto> {
123    fn default() -> Self {
124        Self::auto()
125    }
126}
127
128/// A connector which can be used to establish a connection to a server.
129///
130/// By default it will created in auto mode ([`TlsConnector::auto`]),
131/// which will perform the Tls handshake on the underlying stream,
132/// only if the request requires a secure connection. You can instead use
133/// [`TlsConnector::secure_only`] to force the connector to always
134/// establish a secure connection.
135pub struct TlsConnector<S, K = ConnectorKindAuto> {
136    inner: S,
137    connector_data: Option<TlsConnectorData>,
138    kind: K,
139}
140
141impl<S: fmt::Debug, K: fmt::Debug> fmt::Debug for TlsConnector<S, K> {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        f.debug_struct("TlsConnector")
144            .field("inner", &self.inner)
145            .field("connector_data", &self.connector_data)
146            .field("kind", &self.kind)
147            .finish()
148    }
149}
150
151impl<S: Clone, K: Clone> Clone for TlsConnector<S, K> {
152    fn clone(&self) -> Self {
153        Self {
154            inner: self.inner.clone(),
155            connector_data: self.connector_data.clone(),
156            kind: self.kind.clone(),
157        }
158    }
159}
160
161impl<S, K> TlsConnector<S, K> {
162    /// Creates a new [`TlsConnector`].
163    pub const fn new(inner: S, kind: K) -> Self {
164        Self {
165            inner,
166            connector_data: None,
167            kind,
168        }
169    }
170
171    /// Attach [`TlsConnectorData`] to this [`TlsConnector`],
172    /// to be used instead of a globally shared [`TlsConnectorData::default`].
173    ///
174    /// NOTE: for a smooth interaction with HTTP you most likely do want to
175    /// create tls connector data to at the very least define the ALPN's correctly.
176    ///
177    /// E.g. if you create an auto client, you want to make sure your ALPN can handle all.
178    /// It will be then also be the [`TlsConnector`] that sets the request http version correctly.
179    pub fn with_connector_data(mut self, connector_data: TlsConnectorData) -> Self {
180        self.connector_data = Some(connector_data);
181        self
182    }
183
184    /// Maybe attach [`TlsConnectorData`] to this [`TlsConnector`],
185    /// to be used if `Some` instead of a globally shared [`TlsConnectorData::default`].
186    pub fn maybe_with_connector_data(mut self, connector_data: Option<TlsConnectorData>) -> Self {
187        self.connector_data = connector_data;
188        self
189    }
190
191    /// Attach [`TlsConnectorData`] to this [`TlsConnector`],
192    /// to be used instead of a globally shared default client config.
193    pub fn set_connector_data(&mut self, connector_data: TlsConnectorData) -> &mut Self {
194        self.connector_data = Some(connector_data);
195        self
196    }
197}
198
199impl<S> TlsConnector<S, ConnectorKindAuto> {
200    /// Creates a new [`TlsConnector`] which will establish
201    /// a secure connection if the request demands it,
202    /// otherwise it will forward the pre-established inner connection.
203    pub fn auto(inner: S) -> Self {
204        Self::new(inner, ConnectorKindAuto)
205    }
206}
207
208impl<S> TlsConnector<S, ConnectorKindSecure> {
209    /// Creates a new [`TlsConnector`] which will always
210    /// establish a secure connection regardless of the request it is for.
211    pub fn secure(inner: S) -> Self {
212        Self::new(inner, ConnectorKindSecure)
213    }
214}
215
216impl<S> TlsConnector<S, ConnectorKindTunnel> {
217    /// Creates a new [`TlsConnector`] which will establish
218    /// a secure connection if the request is to be tunneled.
219    pub fn tunnel(inner: S, host: Option<Host>) -> Self {
220        Self::new(inner, ConnectorKindTunnel { host })
221    }
222}
223
// Each connector kind gets its own `Service` implementation below. This avoids a hacky macro
// at the cost of some duplication; a cleaner way to share this logic would be welcome.
225
226impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindAuto>
227where
228    S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
229    State: Clone + Send + Sync + 'static,
230    Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
231        + Send
232        + 'static,
233{
234    type Response = EstablishedClientConnection<AutoTlsStream<S::Connection>, State, Request>;
235    type Error = BoxError;
236
237    async fn serve(
238        &self,
239        ctx: Context<State>,
240        req: Request,
241    ) -> Result<Self::Response, Self::Error> {
242        let EstablishedClientConnection { mut ctx, req, conn } =
243            self.inner.connect(ctx, req).await.map_err(Into::into)?;
244        let transport_ctx = ctx
245            .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
246            .map_err(|err| {
247                OpaqueError::from_boxed(err.into())
248                    .context("TlsConnector(auto): compute transport context")
249            })?
250            .clone();
251
252        if !transport_ctx
253            .app_protocol
254            .as_ref()
255            .map(|p| p.is_secure())
256            .unwrap_or_default()
257        {
258            tracing::trace!(
259                authority = %transport_ctx.authority,
260                "TlsConnector(auto): protocol not secure, return inner connection",
261            );
262            return Ok(EstablishedClientConnection {
263                ctx,
264                req,
265                conn: AutoTlsStream {
266                    inner: AutoTlsStreamData::Plain { inner: conn },
267                },
268            });
269        }
270
271        let server_host = transport_ctx.authority.host().clone();
272
273        tracing::trace!(
274            authority = %transport_ctx.authority,
275            app_protocol = ?transport_ctx.app_protocol,
276            "TlsConnector(auto): attempt to secure inner connection",
277        );
278
279        let connector_data = ctx.get::<TlsConnectorData>().cloned();
280        let (stream, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
281
282        tracing::trace!(
283            authority = %transport_ctx.authority,
284            app_protocol = ?transport_ctx.app_protocol,
285            "TlsConnector(auto): protocol secure, established tls connection",
286        );
287
288        ctx.insert(negotiated_params);
289
290        Ok(EstablishedClientConnection {
291            ctx,
292            req,
293            conn: AutoTlsStream {
294                inner: AutoTlsStreamData::Secure { inner: stream },
295            },
296        })
297    }
298}
299
300impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindSecure>
301where
302    S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
303    State: Clone + Send + Sync + 'static,
304    Request: TryRefIntoTransportContext<State, Error: Into<BoxError> + Send + Sync + 'static>
305        + Send
306        + 'static,
307{
308    type Response = EstablishedClientConnection<TlsStream<S::Connection>, State, Request>;
309    type Error = BoxError;
310
311    async fn serve(
312        &self,
313        ctx: Context<State>,
314        req: Request,
315    ) -> Result<Self::Response, Self::Error> {
316        let EstablishedClientConnection { mut ctx, req, conn } =
317            self.inner.connect(ctx, req).await.map_err(Into::into)?;
318
319        let transport_ctx = ctx
320            .get_or_try_insert_with_ctx(|ctx| req.try_ref_into_transport_ctx(ctx))
321            .map_err(|err| {
322                OpaqueError::from_boxed(err.into())
323                    .context("TlsConnector(auto): compute transport context")
324            })?;
325        tracing::trace!(
326            authority = %transport_ctx.authority,
327            app_protocol = ?transport_ctx.app_protocol,
328            "TlsConnector(secure): attempt to secure inner connection",
329        );
330
331        let server_host = transport_ctx.authority.host().clone();
332
333        let connector_data = ctx.get::<TlsConnectorData>().cloned();
334        let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
335        ctx.insert(negotiated_params);
336
337        Ok(EstablishedClientConnection { ctx, req, conn })
338    }
339}
340
341impl<S, State, Request> Service<State, Request> for TlsConnector<S, ConnectorKindTunnel>
342where
343    S: ConnectorService<State, Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
344    State: Clone + Send + Sync + 'static,
345    Request: Send + 'static,
346{
347    type Response = EstablishedClientConnection<AutoTlsStream<S::Connection>, State, Request>;
348    type Error = BoxError;
349
350    async fn serve(
351        &self,
352        ctx: Context<State>,
353        req: Request,
354    ) -> Result<Self::Response, Self::Error> {
355        let EstablishedClientConnection { mut ctx, req, conn } =
356            self.inner.connect(ctx, req).await.map_err(Into::into)?;
357
358        let server_host = match ctx
359            .get::<TlsTunnel>()
360            .as_ref()
361            .map(|t| &t.server_host)
362            .or(self.kind.host.as_ref())
363        {
364            Some(host) => host.clone(),
365            None => {
366                tracing::trace!(
367                    "TlsConnector(tunnel): return inner connection: no Tls tunnel is requested"
368                );
369                return Ok(EstablishedClientConnection {
370                    ctx,
371                    req,
372                    conn: AutoTlsStream {
373                        inner: AutoTlsStreamData::Plain { inner: conn },
374                    },
375                });
376            }
377        };
378
379        let connector_data = ctx.get::<TlsConnectorData>().cloned();
380        let (conn, negotiated_params) = self.handshake(connector_data, server_host, conn).await?;
381        ctx.insert(negotiated_params);
382
383        tracing::trace!("TlsConnector(tunnel): connection secured");
384        Ok(EstablishedClientConnection {
385            ctx,
386            req,
387            conn: AutoTlsStream {
388                inner: AutoTlsStreamData::Secure { inner: conn },
389            },
390        })
391    }
392}
393
394impl<S, K> TlsConnector<S, K> {
395    async fn handshake<T>(
396        &self,
397        connector_data: Option<TlsConnectorData>,
398        server_host: Host,
399        stream: T,
400    ) -> Result<(TlsStream<T>, NegotiatedTlsParameters), BoxError>
401    where
402        T: Stream + Unpin,
403    {
404        let connector_data = connector_data
405            .or(self.connector_data.clone())
406            .unwrap_or(TlsConnectorData::new_http_auto()?);
407
408        let server_name = rustls_pki_types::ServerName::rama_try_from(
409            connector_data.server_name.unwrap_or(server_host),
410        )?;
411
412        let connector = RustlsConnector::from(connector_data.client_config);
413
414        let stream = connector.connect(server_name, stream).await?;
415
416        let (_, conn_data_ref) = stream.get_ref();
417
418        let server_certificate_chain = if connector_data.store_server_certificate_chain {
419            conn_data_ref.peer_certificates().map(RamaInto::rama_into)
420        } else {
421            None
422        };
423
424        let params = NegotiatedTlsParameters {
425            protocol_version: conn_data_ref
426                .protocol_version()
427                .context("no protocol version available")?
428                .rama_into(),
429            application_layer_protocol: conn_data_ref
430                .alpn_protocol()
431                .map(ApplicationProtocol::from),
432            peer_certificate_chain: server_certificate_chain,
433        };
434
435        Ok((stream, params))
436    }
437}
438
pin_project! {
    /// A stream which is either secured with TLS or passed through as a plain stream.
    ///
    /// Returned by the auto and tunnel kinds of [`TlsConnector`], which only
    /// perform the TLS handshake when required. Reads and writes are delegated
    /// to whichever stream is wrapped.
    pub struct AutoTlsStream<S> {
        #[pin]
        inner: AutoTlsStreamData<S>,
    }
}
446
447impl<S: fmt::Debug> fmt::Debug for AutoTlsStream<S> {
448    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449        f.debug_struct("AutoTlsStream")
450            .field("inner", &self.inner)
451            .finish()
452    }
453}
454
pin_project! {
    #[project = AutoTlsStreamDataProj]
    /// The inner state of an [`AutoTlsStream`]: either a secure or a plain stream.
    enum AutoTlsStreamData<S> {
        /// A stream secured with TLS (a tokio-rustls client stream).
        Secure{ #[pin] inner: TlsStream<S> },
        /// A plain stream, forwarded as established by the inner connector.
        Plain { #[pin] inner: S },
    }
}
465
466impl<S: fmt::Debug> fmt::Debug for AutoTlsStreamData<S> {
467    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468        match self {
469            AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(),
470            AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(),
471        }
472    }
473}
474
475impl<S> AsyncRead for AutoTlsStream<S>
476where
477    S: Stream + Unpin,
478{
479    fn poll_read(
480        self: std::pin::Pin<&mut Self>,
481        cx: &mut std::task::Context<'_>,
482        buf: &mut tokio::io::ReadBuf<'_>,
483    ) -> std::task::Poll<std::io::Result<()>> {
484        match self.project().inner.project() {
485            AutoTlsStreamDataProj::Secure { inner } => inner.poll_read(cx, buf),
486            AutoTlsStreamDataProj::Plain { inner } => inner.poll_read(cx, buf),
487        }
488    }
489}
490
491impl<S> AsyncWrite for AutoTlsStream<S>
492where
493    S: Stream + Unpin,
494{
495    fn poll_write(
496        self: std::pin::Pin<&mut Self>,
497        cx: &mut std::task::Context<'_>,
498        buf: &[u8],
499    ) -> std::task::Poll<Result<usize, std::io::Error>> {
500        match self.project().inner.project() {
501            AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf),
502            AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf),
503        }
504    }
505
506    fn poll_flush(
507        self: std::pin::Pin<&mut Self>,
508        cx: &mut std::task::Context<'_>,
509    ) -> std::task::Poll<Result<(), std::io::Error>> {
510        match self.project().inner.project() {
511            AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx),
512            AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx),
513        }
514    }
515
516    fn poll_shutdown(
517        self: std::pin::Pin<&mut Self>,
518        cx: &mut std::task::Context<'_>,
519    ) -> std::task::Poll<Result<(), std::io::Error>> {
520        match self.project().inner.project() {
521            AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx),
522            AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx),
523        }
524    }
525}
526
527mod private {
528    use rama_net::address::Host;
529
530    #[derive(Debug, Clone)]
531    /// A connector which can be used to establish a connection to a server
532    /// in function of the Request, meaning either it will be a seucre
533    /// connector or it will be a plain connector.
534    ///
535    /// This connector can be handy as it allows to have a single layer
536    /// which will work both for plain and secure connections.
537    pub struct ConnectorKindAuto;
538
539    #[derive(Debug, Clone)]
540    /// A connector which can _only_ be used to establish a secure connection,
541    /// regardless of the scheme of the request URI.
542    pub struct ConnectorKindSecure;
543
544    #[derive(Debug, Clone)]
545    /// A connector which can be used to use this connector to support
546    /// secure tls tunnel connections.
547    ///
548    /// The connections will only be done if the [`TlsTunnel`]
549    /// is present in the context for optional versions,
550    /// and using the hardcoded host otherwise.
551    /// Context always overwrites though.
552    ///
553    /// [`TlsTunnel`]: crate::TlsTunnel
554    pub struct ConnectorKindTunnel {
555        pub host: Option<Host>,
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    // Aliased: the test functions below reuse the helper names.
    use rama_utils::test_helpers::{assert_send as require_send, assert_sync as require_sync};

    /// The default (auto) layer can be moved across threads.
    #[test]
    fn assert_send() {
        require_send::<TlsConnectorLayer>();
    }

    /// The default (auto) layer can be shared across threads.
    #[test]
    fn assert_sync() {
        require_sync::<TlsConnectorLayer>();
    }
}