compio_quic/
builder.rs

1use std::{io, sync::Arc};
2
3use compio_net::ToSocketAddrsAsync;
4use quinn_proto::{
5    ClientConfig, ServerConfig,
6    crypto::rustls::{QuicClientConfig, QuicServerConfig},
7};
8
9use crate::Endpoint;
10
/// A [builder] that constructs an [`Endpoint`] for outgoing connections only.
///
/// Obtain one through the `new_with_xxx` constructors. The type parameter
/// tracks the configuration stage:
///
/// - a `ClientBuilder<rustls::RootCertStore>` collects trust anchors and is
///   turned into a `ClientBuilder<rustls::ClientConfig>` by `with_no_crls` or
///   `with_crls`;
/// - a `ClientBuilder<rustls::ClientConfig>` can then `build` a client config
///   or `bind` an [`Endpoint`].
///
/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html
#[derive(Debug)]
pub struct ClientBuilder<T>(T);
18
19impl ClientBuilder<rustls::RootCertStore> {
20    /// Create a builder with an empty [`rustls::RootCertStore`].
21    pub fn new_with_empty_roots() -> Self {
22        ClientBuilder(rustls::RootCertStore::empty())
23    }
24
25    /// Create a builder with [`rustls_native_certs`].
26    #[cfg(feature = "native-certs")]
27    pub fn new_with_native_certs() -> io::Result<Self> {
28        let mut roots = rustls::RootCertStore::empty();
29        let mut certs = rustls_native_certs::load_native_certs();
30        if certs.certs.is_empty() {
31            return Err(io::Error::other(
32                certs
33                    .errors
34                    .pop()
35                    .expect("certs and errors should not be both empty"),
36            ));
37        }
38        roots.add_parsable_certificates(certs.certs);
39        Ok(ClientBuilder(roots))
40    }
41
42    /// Create a builder with [`webpki_roots`].
43    #[cfg(feature = "webpki-roots")]
44    pub fn new_with_webpki_roots() -> Self {
45        let roots =
46            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
47        ClientBuilder(roots)
48    }
49
50    /// Add a custom certificate.
51    pub fn with_custom_certificate(
52        mut self,
53        der: rustls::pki_types::CertificateDer,
54    ) -> Result<Self, rustls::Error> {
55        self.0.add(der)?;
56        Ok(self)
57    }
58
59    /// Don't configure revocation.
60    pub fn with_no_crls(self) -> ClientBuilder<rustls::ClientConfig> {
61        ClientBuilder::new_with_root_certificates(self.0)
62    }
63
64    /// Verify the revocation state of presented client certificates against the
65    /// provided certificate revocation lists (CRLs).
66    pub fn with_crls(
67        self,
68        crls: impl IntoIterator<Item = rustls::pki_types::CertificateRevocationListDer<'static>>,
69    ) -> Result<ClientBuilder<rustls::ClientConfig>, rustls::client::VerifierBuilderError> {
70        let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(self.0))
71            .with_crls(crls)
72            .build()?;
73        Ok(ClientBuilder::new_with_webpki_verifier(verifier))
74    }
75}
76
77impl ClientBuilder<rustls::ClientConfig> {
78    /// Create a builder with the provided [`rustls::ClientConfig`].
79    pub fn new_with_rustls_client_config(
80        client_config: rustls::ClientConfig,
81    ) -> ClientBuilder<rustls::ClientConfig> {
82        ClientBuilder(client_config)
83    }
84
85    /// Do not verify the server's certificate. It is vulnerable to MITM
86    /// attacks, but convenient for testing.
87    pub fn new_with_no_server_verification() -> ClientBuilder<rustls::ClientConfig> {
88        ClientBuilder(
89            rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
90                .dangerous()
91                .with_custom_certificate_verifier(Arc::new(verifier::SkipServerVerification::new()))
92                .with_no_client_auth(),
93        )
94    }
95
96    /// Create a builder with [`rustls_platform_verifier`].
97    #[cfg(feature = "platform-verifier")]
98    pub fn new_with_platform_verifier() -> Result<ClientBuilder<rustls::ClientConfig>, rustls::Error>
99    {
100        use rustls_platform_verifier::BuilderVerifierExt;
101
102        Ok(ClientBuilder(
103            rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
104                .with_platform_verifier()?
105                .with_no_client_auth(),
106        ))
107    }
108
109    /// Create a builder with the provided [`rustls::RootCertStore`].
110    pub fn new_with_root_certificates(
111        roots: rustls::RootCertStore,
112    ) -> ClientBuilder<rustls::ClientConfig> {
113        ClientBuilder(
114            rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
115                .with_root_certificates(roots)
116                .with_no_client_auth(),
117        )
118    }
119
120    /// Create a builder with a custom [`rustls::client::WebPkiServerVerifier`].
121    pub fn new_with_webpki_verifier(
122        verifier: Arc<rustls::client::WebPkiServerVerifier>,
123    ) -> ClientBuilder<rustls::ClientConfig> {
124        ClientBuilder(
125            rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
126                .with_webpki_verifier(verifier)
127                .with_no_client_auth(),
128        )
129    }
130
131    /// Set the ALPN protocols to use.
132    pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self {
133        self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect();
134        self
135    }
136
137    /// Logging key material to a file for debugging. The file's name is given
138    /// by the `SSLKEYLOGFILE` environment variable.
139    ///
140    /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot
141    /// be written, this does nothing.
142    pub fn with_key_log(mut self) -> Self {
143        self.0.key_log = Arc::new(rustls::KeyLogFile::new());
144        self
145    }
146
147    /// Build a [`ClientConfig`].
148    pub fn build(mut self) -> ClientConfig {
149        self.0.enable_early_data = true;
150        ClientConfig::new(Arc::new(
151            QuicClientConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"),
152        ))
153    }
154
155    /// Create a new [`Endpoint`].
156    ///
157    /// See [`Endpoint::client`] for more information.
158    pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
159        let mut endpoint = Endpoint::client(addr).await?;
160        endpoint.default_client_config = Some(self.build());
161        Ok(endpoint)
162    }
163}
164
/// A [builder] that constructs an [`Endpoint`] for incoming connections.
///
/// Obtain one through the `new_with_xxx` constructors, optionally configure
/// it with the `with_xxx` methods, then call `build` to get a server config or
/// `bind` to get an [`Endpoint`].
///
/// [builder]: https://rust-unofficial.github.io/patterns/patterns/creational/builder.html
#[derive(Debug)]
pub struct ServerBuilder<T>(T);
172
173impl ServerBuilder<rustls::ServerConfig> {
174    /// Create a builder with the provided [`rustls::ServerConfig`].
175    pub fn new_with_rustls_server_config(server_config: rustls::ServerConfig) -> Self {
176        Self(server_config)
177    }
178
179    /// Create a builder with a single certificate chain and matching private
180    /// key. Using this method gets the same result as calling
181    /// [`ServerConfig::with_single_cert`].
182    pub fn new_with_single_cert(
183        cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
184        key_der: rustls::pki_types::PrivateKeyDer<'static>,
185    ) -> Result<Self, rustls::Error> {
186        let server_config =
187            rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
188                .with_no_client_auth()
189                .with_single_cert(cert_chain, key_der)?;
190        Ok(Self::new_with_rustls_server_config(server_config))
191    }
192
193    /// Set the ALPN protocols to use.
194    pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self {
195        self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect();
196        self
197    }
198
199    /// Logging key material to a file for debugging. The file's name is given
200    /// by the `SSLKEYLOGFILE` environment variable.
201    ///
202    /// If `SSLKEYLOGFILE` is not set, or such a file cannot be opened or cannot
203    /// be written, this does nothing.
204    pub fn with_key_log(mut self) -> Self {
205        self.0.key_log = Arc::new(rustls::KeyLogFile::new());
206        self
207    }
208
209    /// Build a [`ServerConfig`].
210    pub fn build(mut self) -> ServerConfig {
211        self.0.max_early_data_size = u32::MAX;
212        ServerConfig::with_crypto(Arc::new(
213            QuicServerConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"),
214        ))
215    }
216
217    /// Create a new [`Endpoint`].
218    ///
219    /// See [`Endpoint::server`] for more information.
220    pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
221        Endpoint::server(addr, self.build()).await
222    }
223}
224
225mod verifier {
226    use rustls::{
227        DigitallySignedStruct, Error, SignatureScheme,
228        client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
229        crypto::WebPkiSupportedAlgorithms,
230        pki_types::{CertificateDer, ServerName, UnixTime},
231    };
232
233    #[derive(Debug)]
234    pub struct SkipServerVerification(WebPkiSupportedAlgorithms);
235
236    impl SkipServerVerification {
237        pub fn new() -> Self {
238            Self(
239                rustls::crypto::CryptoProvider::get_default()
240                    .map(|provider| provider.signature_verification_algorithms)
241                    .unwrap_or_else(|| {
242                        #[cfg(feature = "aws-lc-rs")]
243                        use rustls::crypto::aws_lc_rs::default_provider;
244                        #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))]
245                        use rustls::crypto::ring::default_provider;
246                        default_provider().signature_verification_algorithms
247                    }),
248            )
249        }
250    }
251
252    impl ServerCertVerifier for SkipServerVerification {
253        fn verify_server_cert(
254            &self,
255            _end_entity: &CertificateDer<'_>,
256            _intermediates: &[CertificateDer<'_>],
257            _server_name: &ServerName<'_>,
258            _ocsp: &[u8],
259            _now: UnixTime,
260        ) -> Result<ServerCertVerified, Error> {
261            Ok(ServerCertVerified::assertion())
262        }
263
264        fn verify_tls12_signature(
265            &self,
266            message: &[u8],
267            cert: &CertificateDer<'_>,
268            dss: &DigitallySignedStruct,
269        ) -> Result<HandshakeSignatureValid, Error> {
270            rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0)
271        }
272
273        fn verify_tls13_signature(
274            &self,
275            message: &[u8],
276            cert: &CertificateDer<'_>,
277            dss: &DigitallySignedStruct,
278        ) -> Result<HandshakeSignatureValid, Error> {
279            rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0)
280        }
281
282        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
283            self.0.supported_schemes()
284        }
285    }
286}