trillium_rustls/
client.rs

1use crate::crypto_provider;
2use futures_rustls::{
3    client::TlsStream,
4    rustls::{
5        client::danger::ServerCertVerifier, crypto::CryptoProvider, pki_types::ServerName,
6        ClientConfig, ClientConnection,
7    },
8    TlsConnector,
9};
10use std::{
11    fmt::{self, Debug, Formatter},
12    future::Future,
13    io::{Error, ErrorKind, IoSlice, Result},
14    net::SocketAddr,
15    pin::Pin,
16    sync::Arc,
17    task::{Context, Poll},
18};
19use trillium_server_common::{async_trait, AsyncRead, AsyncWrite, Connector, Transport, Url};
20use RustlsClientTransportInner::{Tcp, Tls};
21
/// Cloneable wrapper around a shared rustls [`ClientConfig`].
///
/// Obtain one through [`Default`], or by converting a [`ClientConfig`] or an
/// `Arc<ClientConfig>` with [`From`]/[`Into`].
#[derive(Clone, Debug)]
pub struct RustlsClientConfig(Arc<ClientConfig>);
24
/**
Client configuration for RustlsConnector

Pairs the rustls client configuration with the configuration of the inner
connector that supplies the underlying transport.
*/
#[derive(Clone, Default)]
pub struct RustlsConfig<Config> {
    /// configuration for rustls itself
    pub rustls_config: RustlsClientConfig,

    /// configuration for the inner transport
    pub tcp_config: Config,
}
36
37impl<C: Connector> RustlsConfig<C> {
38    /// build a new default rustls config with this tcp config
39    pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
40        Self {
41            rustls_config: rustls_config.into(),
42            tcp_config,
43        }
44    }
45}
46
47impl Default for RustlsClientConfig {
48    fn default() -> Self {
49        Self(Arc::new(default_client_config()))
50    }
51}
52
/// Build the server certificate verifier for the default client config.
///
/// This variant is compiled when the `platform-verifier` feature is enabled
/// and returns a `rustls_platform_verifier::Verifier` configured with
/// `provider`.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new().with_provider(provider);
    Arc::new(platform_verifier)
}
57
58#[cfg(not(feature = "platform-verifier"))]
59fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
60    let roots = Arc::new(futures_rustls::rustls::RootCertStore::from_iter(
61        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
62    ));
63    futures_rustls::rustls::client::WebPkiServerVerifier::builder_with_provider(roots, provider)
64        .build()
65        .unwrap()
66}
67
68fn default_client_config() -> ClientConfig {
69    let provider = crypto_provider();
70    let verifier = verifier(Arc::clone(&provider));
71
72    ClientConfig::builder_with_provider(provider)
73        .with_safe_default_protocol_versions()
74        .expect("crypto provider did not support safe default protocol versions")
75        .dangerous()
76        .with_custom_certificate_verifier(verifier)
77        .with_no_client_auth()
78}
79
80impl From<ClientConfig> for RustlsClientConfig {
81    fn from(rustls_config: ClientConfig) -> Self {
82        Self(Arc::new(rustls_config))
83    }
84}
85
86impl From<Arc<ClientConfig>> for RustlsClientConfig {
87    fn from(rustls_config: Arc<ClientConfig>) -> Self {
88        Self(rustls_config)
89    }
90}
91
92impl<C: Connector> RustlsConfig<C> {
93    /// replace the tcp config
94    pub fn with_tcp_config(mut self, config: C) -> Self {
95        self.tcp_config = config;
96        self
97    }
98}
99
100impl<Config: Debug> Debug for RustlsConfig<Config> {
101    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102        f.debug_struct("RustlsConfig")
103            .field("rustls_config", &"..")
104            .field("tcp_config", &self.tcp_config)
105            .finish()
106    }
107}
108
109#[async_trait]
110impl<C: Connector> Connector for RustlsConfig<C> {
111    type Transport = RustlsClientTransport<C::Transport>;
112
113    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
114        match url.scheme() {
115            "https" => {
116                let mut http = url.clone();
117                http.set_scheme("http").ok();
118                http.set_port(url.port_or_known_default()).ok();
119
120                let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
121                let domain = url
122                    .domain()
123                    .and_then(|dns_name| ServerName::try_from(dns_name.to_string()).ok())
124                    .ok_or_else(|| Error::new(ErrorKind::Other, "missing domain"))?;
125
126                connector
127                    .connect(domain, self.tcp_config.connect(&http).await?)
128                    .await
129                    .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
130                    .map(Into::into)
131            }
132
133            "http" => self.tcp_config.connect(url).await.map(Into::into),
134
135            unknown => Err(Error::new(
136                ErrorKind::InvalidInput,
137                format!("unknown scheme {unknown}"),
138            )),
139        }
140    }
141
142    fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
143        self.tcp_config.spawn(fut)
144    }
145}
146
/// The two kinds of connection a [`RustlsClientTransport`] can hold.
#[derive(Debug)]
enum RustlsClientTransportInner<T> {
    /// the inner transport used directly, without tls
    Tcp(T),
    /// a tls stream over the inner transport (boxed, presumably to keep the
    /// enum small; TODO confirm)
    Tls(Box<TlsStream<T>>),
}
152
/**
Transport for the rustls connector

This may represent either an encrypted tls connection or a plaintext
connection, depending on the request scheme
*/
#[derive(Debug)]
pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
161impl<T> From<T> for RustlsClientTransport<T> {
162    fn from(value: T) -> Self {
163        Self(Tcp(value))
164    }
165}
166
167impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
168    fn from(value: TlsStream<T>) -> Self {
169        Self(Tls(Box::new(value)))
170    }
171}
172
173impl<C> AsyncRead for RustlsClientTransport<C>
174where
175    C: AsyncWrite + AsyncRead + Unpin,
176{
177    fn poll_read(
178        mut self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &mut [u8],
181    ) -> Poll<Result<usize>> {
182        match &mut self.0 {
183            Tcp(c) => Pin::new(c).poll_read(cx, buf),
184            Tls(c) => Pin::new(c).poll_read(cx, buf),
185        }
186    }
187}
188
189impl<C> AsyncWrite for RustlsClientTransport<C>
190where
191    C: AsyncRead + AsyncWrite + Unpin,
192{
193    fn poll_write(
194        mut self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        buf: &[u8],
197    ) -> Poll<Result<usize>> {
198        match &mut self.0 {
199            Tcp(c) => Pin::new(c).poll_write(cx, buf),
200            Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
201        }
202    }
203
204    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
205        match &mut self.0 {
206            Tcp(c) => Pin::new(c).poll_flush(cx),
207            Tls(c) => Pin::new(&mut *c).poll_flush(cx),
208        }
209    }
210
211    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
212        match &mut self.0 {
213            Tcp(c) => Pin::new(c).poll_close(cx),
214            Tls(c) => Pin::new(&mut *c).poll_close(cx),
215        }
216    }
217
218    fn poll_write_vectored(
219        mut self: Pin<&mut Self>,
220        cx: &mut Context<'_>,
221        bufs: &[IoSlice<'_>],
222    ) -> Poll<Result<usize>> {
223        match &mut self.0 {
224            Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
225            Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
226        }
227    }
228}
229
230impl<T: Transport> Transport for RustlsClientTransport<T> {
231    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232        self.as_ref().peer_addr()
233    }
234}
235
236impl<T> AsRef<T> for RustlsClientTransport<T> {
237    fn as_ref(&self) -> &T {
238        match &self.0 {
239            Tcp(x) => x,
240            Tls(x) => x.get_ref().0,
241        }
242    }
243}
244
245impl<T> RustlsClientTransport<T> {
246    /// Retrieve the tls [`CommonState`] if this transport is Tls
247    pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
248        match &mut self.0 {
249            Tls(x) => Some(x.get_mut().1),
250            _ => None,
251        }
252    }
253
254    /// Retrieve the tls [`CommonState`] if this transport is Tls
255    pub fn tls_state(&self) -> Option<&ClientConnection> {
256        match &self.0 {
257            Tls(x) => Some(x.get_ref().1),
258            _ => None,
259        }
260    }
261}