Skip to main content

trillium_rustls/
client.rs

1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3#[cfg(feature = "dangerous")]
4use futures_rustls::rustls::{
5    DigitallySignedStruct, SignatureScheme,
6    client::danger::{HandshakeSignatureValid, ServerCertVerified},
7    crypto::{verify_tls12_signature, verify_tls13_signature},
8    pki_types::{CertificateDer, UnixTime},
9};
10use futures_rustls::{
11    TlsConnector,
12    client::TlsStream,
13    rustls::{
14        ClientConfig, ClientConnection, RootCertStore,
15        client::{WebPkiServerVerifier, danger::ServerCertVerifier},
16        crypto::CryptoProvider,
17        pki_types::ServerName,
18    },
19};
20use std::{
21    fmt::{self, Debug, Formatter},
22    io::{Error, ErrorKind, IoSlice, Result},
23    net::SocketAddr,
24    pin::Pin,
25    sync::Arc,
26    task::{Context, Poll},
27};
28use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
29
30/// Rustls [`ClientConfig`] wrapper used by [`RustlsConfig`].
31///
32/// [`RustlsClientConfig::default`] trusts the platform or webpki roots (depending on the
33/// `platform-verifier` feature). Use [`RustlsClientConfig::from_root_cert_pem`] to trust a specific
34/// private or self-signed certificate instead, or convert an existing [`ClientConfig`] via
35/// [`From`].
36#[derive(Clone, Debug)]
37pub struct RustlsClientConfig(Arc<ClientConfig>);
38
39/// Client configuration for RustlsConnector
40#[derive(Clone, Default)]
41pub struct RustlsConfig<Config> {
42    /// configuration for rustls itself
43    pub rustls_config: RustlsClientConfig,
44
45    /// configuration for the inner transport
46    pub tcp_config: Config,
47}
48
49impl<C: Connector> RustlsConfig<C> {
50    /// build a new default rustls config with this tcp config
51    pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
52        Self {
53            rustls_config: rustls_config.into(),
54            tcp_config,
55        }
56    }
57}
58
59impl Default for RustlsClientConfig {
60    fn default() -> Self {
61        Self(Arc::new(default_client_config()))
62    }
63}
64
/// Server-certificate verifier backed by the platform's trust roots (`platform-verifier`
/// feature).
///
/// # Panics
///
/// Panics if the platform verifier cannot be constructed with `provider`. This is reached from
/// [`RustlsClientConfig::default`], which has no way to return an error.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform = rustls_platform_verifier::Verifier::new(provider)
        .expect("failed to construct the rustls platform certificate verifier");
    Arc::new(platform)
}
69
70#[cfg(not(feature = "platform-verifier"))]
71fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
72    let roots = Arc::new(RootCertStore::from_iter(
73        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
74    ));
75    WebPkiServerVerifier::builder_with_provider(roots, provider)
76        .build()
77        .unwrap()
78}
79
80fn client_config_with_verifier(verifier: Arc<dyn ServerCertVerifier>) -> ClientConfig {
81    let mut config = ClientConfig::builder_with_provider(crypto_provider())
82        .with_safe_default_protocol_versions()
83        .expect("crypto provider did not support safe default protocol versions")
84        .dangerous()
85        .with_custom_certificate_verifier(verifier)
86        .with_no_client_auth();
87
88    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
89
90    config
91}
92
93fn default_client_config() -> ClientConfig {
94    client_config_with_verifier(verifier(crypto_provider()))
95}
96
97impl RustlsClientConfig {
98    /// Build a client configuration that trusts exactly the certificate(s) in `pem`.
99    ///
100    /// Unlike [`RustlsClientConfig::default`], this consults neither the platform trust store nor
101    /// the webpki root bundle — the provided roots are the only trust anchors. Server
102    /// authentication is otherwise unchanged: certificate chains, signatures, expiry, and server
103    /// name are all still verified against these roots. This is the right tool for talking to a
104    /// service that presents a private or self-signed certificate.
105    ///
106    /// The crate's configured crypto provider and default ALPN protocol list (`h2`, `http/1.1`)
107    /// are reused.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if `pem` contains no certificates or cannot be parsed, or if the resulting
112    /// trust anchors are rejected by the verifier builder.
113    pub fn from_root_cert_pem(pem: &[u8]) -> Result<Self> {
114        let mut roots = RootCertStore::empty();
115        let mut reader = pem;
116        for cert in rustls_pemfile::certs(&mut reader) {
117            roots.add(cert?).map_err(Error::other)?;
118        }
119
120        if roots.is_empty() {
121            return Err(Error::new(
122                ErrorKind::InvalidInput,
123                "no certificates found in pem",
124            ));
125        }
126
127        let verifier =
128            WebPkiServerVerifier::builder_with_provider(Arc::new(roots), crypto_provider())
129                .build()
130                .map_err(Error::other)?;
131
132        Ok(Self(Arc::new(client_config_with_verifier(verifier))))
133    }
134}
135
136impl From<ClientConfig> for RustlsClientConfig {
137    fn from(rustls_config: ClientConfig) -> Self {
138        Self(Arc::new(rustls_config))
139    }
140}
141
142impl From<Arc<ClientConfig>> for RustlsClientConfig {
143    fn from(rustls_config: Arc<ClientConfig>) -> Self {
144        Self(rustls_config)
145    }
146}
147
/// A [`ServerCertVerifier`] that accepts every server certificate without validating it.
///
/// Chain building, trust anchors, expiry, and server name are never examined. Only the TLS 1.2
/// and TLS 1.3 handshake signatures are checked, using the signature verification algorithms
/// of the wrapped [`CryptoProvider`].
#[cfg(feature = "dangerous")]
#[derive(Debug)]
struct AcceptAnyServerCert(Arc<CryptoProvider>);

#[cfg(feature = "dangerous")]
impl ServerCertVerifier for AcceptAnyServerCert {
    /// Unconditionally reports the server certificate as verified, ignoring every input.
    fn verify_server_cert(
        &self,
        _: &CertificateDer<'_>,
        _: &[CertificateDer<'_>],
        _: &ServerName<'_>,
        _: &[u8],
        _: UnixTime,
    ) -> std::result::Result<ServerCertVerified, futures_rustls::rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Advertises exactly the signature schemes the provider can verify.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
197
#[cfg(feature = "dangerous")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous")))]
impl RustlsClientConfig {
    /// Build a client configuration that performs **no** server certificate validation.
    ///
    /// ⚠️ Server authentication is switched off. Handshake signatures are still verified, but
    /// the certificate itself is never checked against any trust anchor, so anyone positioned
    /// between client and server can impersonate the server. This is intended for development
    /// against throwaway self-signed certificates and for `--insecure`-style CLI flags.
    ///
    /// If the service has a known private certificate, use
    /// [`RustlsClientConfig::from_root_cert_pem`] instead, which keeps authentication intact.
    ///
    /// This constructor is only compiled with the `dangerous` crate feature, and it emits a
    /// `warn`-level log record each time it is called.
    pub fn dangerously_accept_any_cert() -> Self {
        log::warn!(
            "constructing a rustls client config that accepts any server certificate; server \
             authentication is disabled and connections are vulnerable to interception"
        );
        let no_verification = Arc::new(AcceptAnyServerCert(crypto_provider()));
        let client_config = client_config_with_verifier(no_verification);
        Self(Arc::new(client_config))
    }
}
221
222impl<C: Connector> RustlsConfig<C> {
223    /// replace the tcp config
224    pub fn with_tcp_config(mut self, config: C) -> Self {
225        self.tcp_config = config;
226        self
227    }
228
229    /// Drop `h2` from the ALPN protocol list, forcing HTTP/1.1 over TLS.
230    ///
231    /// `RustlsConfig::default()` advertises `[h2, http/1.1]` so HTTP/2 is the preferred
232    /// protocol when the server supports it. Call this to opt out and pin the connection to
233    /// HTTP/1.1.
234    #[must_use]
235    pub fn without_http2(mut self) -> Self {
236        let config = Arc::make_mut(&mut self.rustls_config.0);
237        config.alpn_protocols.retain(|p| p != b"h2");
238        self
239    }
240}
241
242impl<Config: Debug> Debug for RustlsConfig<Config> {
243    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244        f.debug_struct("RustlsConfig")
245            .field("rustls_config", &format_args!(".."))
246            .field("tcp_config", &self.tcp_config)
247            .finish()
248    }
249}
250
251impl<C: Connector> Connector for RustlsConfig<C> {
252    type Runtime = C::Runtime;
253    type Transport = RustlsClientTransport<C::Transport>;
254    type Udp = C::Udp;
255
256    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
257        self.connect_to(Destination::from_url(url)?).await
258    }
259
260    async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
261        if !destination.secure() {
262            return self
263                .tcp_config
264                .connect_to(destination)
265                .await
266                .map(Into::into);
267        }
268
269        // A per-connection ALPN override replaces the config's default; absent one, the shared
270        // config is used as-is. Only clone the (otherwise shared) config when an override is
271        // present.
272        let rustls_config = if let Some(alpn) = destination.alpn() {
273            let mut config = (*self.rustls_config.0).clone();
274            config.alpn_protocols = alpn.iter().map(|p| p.to_vec()).collect();
275            Arc::new(config)
276        } else {
277            Arc::clone(&self.rustls_config.0)
278        };
279        let connector: TlsConnector = rustls_config.into();
280
281        // A domain destination's certificate identity (a DNS `ServerName`, sent via SNI) is fixed
282        // before the dial, so pre-resolved addresses can't influence validation. A host-less
283        // (bare-IP) destination has no SNI and validates against the address actually connected to,
284        // so its `IpAddress` server name is derived from the dialed stream below.
285        let domain_server_name = destination
286            .host()
287            .map(|domain| {
288                ServerName::try_from(domain.to_owned())
289                    .map_err(|e| Error::other(format!("invalid server name {domain:?}: {e}")))
290            })
291            .transpose()?;
292
293        let stream = self
294            .tcp_config
295            .connect_to(destination.with_secure(false))
296            .await?;
297
298        let server_name = match domain_server_name {
299            Some(server_name) => server_name,
300            None => {
301                let ip = stream
302                    .peer_addr()?
303                    .ok_or_else(|| Error::other("no peer address for bare-ip destination"))?
304                    .ip();
305                ServerName::IpAddress(ip.into())
306            }
307        };
308
309        connector
310            .connect(server_name, stream)
311            .await
312            .map_err(|e| Error::other(e.to_string()))
313            .map(Into::into)
314    }
315
316    fn runtime(&self) -> Self::Runtime {
317        self.tcp_config.runtime()
318    }
319
320    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
321        self.tcp_config.resolve(host, port).await
322    }
323}
324
325#[derive(Debug)]
326enum RustlsClientTransportInner<T> {
327    Tcp(T),
328    Tls(Box<TlsStream<T>>),
329}
330
331/// Transport for the rustls connector
332///
333/// This may represent either an encrypted tls connection or a plaintext
334/// connection, depending on the request schema
335#[derive(Debug)]
336pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
337impl<T> From<T> for RustlsClientTransport<T> {
338    fn from(value: T) -> Self {
339        Self(Tcp(value))
340    }
341}
342
343impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
344    fn from(value: TlsStream<T>) -> Self {
345        Self(Tls(Box::new(value)))
346    }
347}
348
349impl<C> AsyncRead for RustlsClientTransport<C>
350where
351    C: AsyncWrite + AsyncRead + Unpin,
352{
353    fn poll_read(
354        mut self: Pin<&mut Self>,
355        cx: &mut Context<'_>,
356        buf: &mut [u8],
357    ) -> Poll<Result<usize>> {
358        match &mut self.0 {
359            Tcp(c) => Pin::new(c).poll_read(cx, buf),
360            Tls(c) => Pin::new(c).poll_read(cx, buf),
361        }
362    }
363
364    fn poll_read_vectored(
365        mut self: Pin<&mut Self>,
366        cx: &mut Context<'_>,
367        bufs: &mut [std::io::IoSliceMut<'_>],
368    ) -> Poll<Result<usize>> {
369        match &mut self.0 {
370            Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
371            Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
372        }
373    }
374}
375
376impl<C> AsyncWrite for RustlsClientTransport<C>
377where
378    C: AsyncRead + AsyncWrite + Unpin,
379{
380    fn poll_write(
381        mut self: Pin<&mut Self>,
382        cx: &mut Context<'_>,
383        buf: &[u8],
384    ) -> Poll<Result<usize>> {
385        match &mut self.0 {
386            Tcp(c) => Pin::new(c).poll_write(cx, buf),
387            Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
388        }
389    }
390
391    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
392        match &mut self.0 {
393            Tcp(c) => Pin::new(c).poll_flush(cx),
394            Tls(c) => Pin::new(&mut *c).poll_flush(cx),
395        }
396    }
397
398    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
399        match &mut self.0 {
400            Tcp(c) => Pin::new(c).poll_close(cx),
401            Tls(c) => Pin::new(&mut *c).poll_close(cx),
402        }
403    }
404
405    fn poll_write_vectored(
406        mut self: Pin<&mut Self>,
407        cx: &mut Context<'_>,
408        bufs: &[IoSlice<'_>],
409    ) -> Poll<Result<usize>> {
410        match &mut self.0 {
411            Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
412            Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
413        }
414    }
415}
416
417impl<T: Transport> Transport for RustlsClientTransport<T> {
418    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
419        self.as_ref().peer_addr()
420    }
421
422    fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
423        self.tls_state()
424            .and_then(|conn| conn.alpn_protocol())
425            .map(std::borrow::Cow::Borrowed)
426    }
427}
428
429impl<T> AsRef<T> for RustlsClientTransport<T> {
430    fn as_ref(&self) -> &T {
431        match &self.0 {
432            Tcp(x) => x,
433            Tls(x) => x.get_ref().0,
434        }
435    }
436}
437
438impl<T> RustlsClientTransport<T> {
439    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
440    pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
441        match &mut self.0 {
442            Tls(x) => Some(x.get_mut().1),
443            _ => None,
444        }
445    }
446
447    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
448    pub fn tls_state(&self) -> Option<&ClientConnection> {
449        match &self.0 {
450            Tls(x) => Some(x.get_ref().1),
451            _ => None,
452        }
453    }
454}