1use std::{io, sync::Arc};
2
3use compio_net::ToSocketAddrsAsync;
4use quinn_proto::{
5 ClientConfig, ServerConfig,
6 crypto::rustls::{QuicClientConfig, QuicServerConfig},
7};
8
9use crate::Endpoint;
10
/// Builder for a QUIC client configuration.
///
/// The type parameter tracks the construction stage:
/// - `ClientBuilder<rustls::RootCertStore>`: trust anchors are still being collected;
/// - `ClientBuilder<rustls::ClientConfig>`: a TLS configuration exists and can be
///   customised, then turned into a QUIC client config with `build` or used
///   directly to create an endpoint with `bind`.
#[must_use = "a ClientBuilder does nothing unless it is built or bound"]
#[derive(Debug)]
pub struct ClientBuilder<T>(T);
18
19impl ClientBuilder<rustls::RootCertStore> {
20 pub fn new_with_empty_roots() -> Self {
22 ClientBuilder(rustls::RootCertStore::empty())
23 }
24
25 #[cfg(feature = "native-certs")]
27 pub fn new_with_native_certs() -> io::Result<Self> {
28 let mut roots = rustls::RootCertStore::empty();
29 let mut certs = rustls_native_certs::load_native_certs();
30 if certs.certs.is_empty() {
31 return Err(io::Error::other(
32 certs
33 .errors
34 .pop()
35 .expect("certs and errors should not be both empty"),
36 ));
37 }
38 roots.add_parsable_certificates(certs.certs);
39 Ok(ClientBuilder(roots))
40 }
41
42 #[cfg(feature = "webpki-roots")]
44 pub fn new_with_webpki_roots() -> Self {
45 let roots =
46 rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
47 ClientBuilder(roots)
48 }
49
50 pub fn with_custom_certificate(
52 mut self,
53 der: rustls::pki_types::CertificateDer,
54 ) -> Result<Self, rustls::Error> {
55 self.0.add(der)?;
56 Ok(self)
57 }
58
59 pub fn with_no_crls(self) -> ClientBuilder<rustls::ClientConfig> {
61 ClientBuilder::new_with_root_certificates(self.0)
62 }
63
64 pub fn with_crls(
67 self,
68 crls: impl IntoIterator<Item = rustls::pki_types::CertificateRevocationListDer<'static>>,
69 ) -> Result<ClientBuilder<rustls::ClientConfig>, rustls::client::VerifierBuilderError> {
70 let verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(self.0))
71 .with_crls(crls)
72 .build()?;
73 Ok(ClientBuilder::new_with_webpki_verifier(verifier))
74 }
75}
76
77impl ClientBuilder<rustls::ClientConfig> {
78 pub fn new_with_rustls_client_config(
80 client_config: rustls::ClientConfig,
81 ) -> ClientBuilder<rustls::ClientConfig> {
82 ClientBuilder(client_config)
83 }
84
85 pub fn new_with_no_server_verification() -> ClientBuilder<rustls::ClientConfig> {
88 ClientBuilder(
89 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
90 .dangerous()
91 .with_custom_certificate_verifier(Arc::new(verifier::SkipServerVerification::new()))
92 .with_no_client_auth(),
93 )
94 }
95
96 #[cfg(feature = "platform-verifier")]
98 pub fn new_with_platform_verifier() -> Result<ClientBuilder<rustls::ClientConfig>, rustls::Error>
99 {
100 use rustls_platform_verifier::BuilderVerifierExt;
101
102 Ok(ClientBuilder(
103 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
104 .with_platform_verifier()?
105 .with_no_client_auth(),
106 ))
107 }
108
109 pub fn new_with_root_certificates(
111 roots: rustls::RootCertStore,
112 ) -> ClientBuilder<rustls::ClientConfig> {
113 ClientBuilder(
114 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
115 .with_root_certificates(roots)
116 .with_no_client_auth(),
117 )
118 }
119
120 pub fn new_with_webpki_verifier(
122 verifier: Arc<rustls::client::WebPkiServerVerifier>,
123 ) -> ClientBuilder<rustls::ClientConfig> {
124 ClientBuilder(
125 rustls::ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
126 .with_webpki_verifier(verifier)
127 .with_no_client_auth(),
128 )
129 }
130
131 pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self {
133 self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect();
134 self
135 }
136
137 pub fn with_key_log(mut self) -> Self {
143 self.0.key_log = Arc::new(rustls::KeyLogFile::new());
144 self
145 }
146
147 pub fn build(mut self) -> ClientConfig {
149 self.0.enable_early_data = true;
150 ClientConfig::new(Arc::new(
151 QuicClientConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"),
152 ))
153 }
154
155 pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
159 let mut endpoint = Endpoint::client(addr).await?;
160 endpoint.default_client_config = Some(self.build());
161 Ok(endpoint)
162 }
163}
164
/// Builder for a QUIC server configuration.
///
/// Wraps a `rustls::ServerConfig` that can be customised before being turned
/// into a QUIC server config with `build`, or used directly to create an
/// endpoint with `bind`.
#[must_use = "a ServerBuilder does nothing unless it is built or bound"]
#[derive(Debug)]
pub struct ServerBuilder<T>(T);
172
173impl ServerBuilder<rustls::ServerConfig> {
174 pub fn new_with_rustls_server_config(server_config: rustls::ServerConfig) -> Self {
176 Self(server_config)
177 }
178
179 pub fn new_with_single_cert(
183 cert_chain: Vec<rustls::pki_types::CertificateDer<'static>>,
184 key_der: rustls::pki_types::PrivateKeyDer<'static>,
185 ) -> Result<Self, rustls::Error> {
186 let server_config =
187 rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
188 .with_no_client_auth()
189 .with_single_cert(cert_chain, key_der)?;
190 Ok(Self::new_with_rustls_server_config(server_config))
191 }
192
193 pub fn with_alpn_protocols(mut self, protocols: &[&str]) -> Self {
195 self.0.alpn_protocols = protocols.iter().map(|p| p.as_bytes().to_vec()).collect();
196 self
197 }
198
199 pub fn with_key_log(mut self) -> Self {
205 self.0.key_log = Arc::new(rustls::KeyLogFile::new());
206 self
207 }
208
209 pub fn build(mut self) -> ServerConfig {
211 self.0.max_early_data_size = u32::MAX;
212 ServerConfig::with_crypto(Arc::new(
213 QuicServerConfig::try_from(self.0).expect("should support TLS13_AES_128_GCM_SHA256"),
214 ))
215 }
216
217 pub async fn bind(self, addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
221 Endpoint::server(addr, self.build()).await
222 }
223}
224
225mod verifier {
226 use rustls::{
227 DigitallySignedStruct, Error, SignatureScheme,
228 client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
229 crypto::WebPkiSupportedAlgorithms,
230 pki_types::{CertificateDer, ServerName, UnixTime},
231 };
232
233 #[derive(Debug)]
234 pub struct SkipServerVerification(WebPkiSupportedAlgorithms);
235
236 impl SkipServerVerification {
237 pub fn new() -> Self {
238 Self(
239 rustls::crypto::CryptoProvider::get_default()
240 .map(|provider| provider.signature_verification_algorithms)
241 .unwrap_or_else(|| {
242 #[cfg(feature = "aws-lc-rs")]
243 use rustls::crypto::aws_lc_rs::default_provider;
244 #[cfg(all(not(feature = "aws-lc-rs"), feature = "ring"))]
245 use rustls::crypto::ring::default_provider;
246 default_provider().signature_verification_algorithms
247 }),
248 )
249 }
250 }
251
252 impl ServerCertVerifier for SkipServerVerification {
253 fn verify_server_cert(
254 &self,
255 _end_entity: &CertificateDer<'_>,
256 _intermediates: &[CertificateDer<'_>],
257 _server_name: &ServerName<'_>,
258 _ocsp: &[u8],
259 _now: UnixTime,
260 ) -> Result<ServerCertVerified, Error> {
261 Ok(ServerCertVerified::assertion())
262 }
263
264 fn verify_tls12_signature(
265 &self,
266 message: &[u8],
267 cert: &CertificateDer<'_>,
268 dss: &DigitallySignedStruct,
269 ) -> Result<HandshakeSignatureValid, Error> {
270 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0)
271 }
272
273 fn verify_tls13_signature(
274 &self,
275 message: &[u8],
276 cert: &CertificateDer<'_>,
277 dss: &DigitallySignedStruct,
278 ) -> Result<HandshakeSignatureValid, Error> {
279 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0)
280 }
281
282 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
283 self.0.supported_schemes()
284 }
285 }
286}