1use std::sync::Arc;
4
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6
7#[derive(Clone)]
11pub struct ClientAuth {
12 pub certificates: Vec<CertificateDer<'static>>,
14 pub key: Arc<PrivateKeyDer<'static>>,
16}
17
18impl ClientAuth {
19 pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
21 Self {
22 certificates,
23 key: Arc::new(key),
24 }
25 }
26}
27
28impl std::fmt::Debug for ClientAuth {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ClientAuth")
31 .field("certificates_count", &self.certificates.len())
32 .field("has_key", &true)
33 .finish()
34 }
35}
36
37#[derive(Clone, Debug)]
39pub struct TlsConfig {
40 pub trust_server_certificate: bool,
44
45 pub root_certificates: Vec<CertificateDer<'static>>,
49
50 pub client_auth: Option<ClientAuth>,
52
53 pub server_name: Option<String>,
57
58 pub min_protocol_version: TlsVersion,
60
61 pub max_protocol_version: TlsVersion,
63
64 pub strict_mode: bool,
66
67 pub alpn_protocols: Vec<Vec<u8>>,
69}
70
71impl Default for TlsConfig {
72 fn default() -> Self {
73 Self {
74 trust_server_certificate: false,
75 root_certificates: Vec::new(),
76 client_auth: None,
77 server_name: None,
78 min_protocol_version: TlsVersion::Tls12,
79 max_protocol_version: TlsVersion::Tls13,
80 strict_mode: false,
81 alpn_protocols: Vec::new(),
82 }
83 }
84}
85
86impl TlsConfig {
87 #[must_use]
89 pub fn new() -> Self {
90 Self::default()
91 }
92
93 #[must_use]
97 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
98 self.trust_server_certificate = trust;
99 self
100 }
101
102 #[must_use]
104 pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
105 self.root_certificates.push(cert);
106 self
107 }
108
109 #[must_use]
111 pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
112 self.root_certificates = certs;
113 self
114 }
115
116 #[must_use]
118 pub fn with_client_auth(
119 mut self,
120 certs: Vec<CertificateDer<'static>>,
121 key: PrivateKeyDer<'static>,
122 ) -> Self {
123 self.client_auth = Some(ClientAuth::new(certs, key));
124 self
125 }
126
127 #[must_use]
129 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
130 self.server_name = Some(name.into());
131 self
132 }
133
134 #[must_use]
136 pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
137 self.min_protocol_version = version;
138 self
139 }
140
141 #[must_use]
143 pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
144 self.max_protocol_version = version;
145 self
146 }
147
148 #[must_use]
150 pub fn strict_mode(mut self, enabled: bool) -> Self {
151 self.strict_mode = enabled;
152 self
153 }
154
155 #[must_use]
157 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
158 self.alpn_protocols = protocols;
159 self
160 }
161
162 #[must_use]
164 pub fn has_client_auth(&self) -> bool {
165 self.client_auth.is_some()
166 }
167
168 #[must_use]
174 pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
175 self.add_root_certificate(CertificateDer::from(der_bytes))
176 }
177
178 #[must_use]
186 pub fn with_client_auth_der(
187 self,
188 cert_chain_der: Vec<Vec<u8>>,
189 private_key_der: Vec<u8>,
190 ) -> Self {
191 let certs = cert_chain_der
192 .into_iter()
193 .map(CertificateDer::from)
194 .collect();
195 let key = PrivateKeyDer::Pkcs8(private_key_der.into());
196 self.with_client_auth(certs, key)
197 }
198}
199
/// TLS protocol versions selectable in the configuration.
///
/// Variants are declared oldest-first, so the derived ordering satisfies
/// `Tls12 < Tls13`. The enum is `#[non_exhaustive]`, so downstream crates
/// must include a wildcard arm when matching on it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum TlsVersion {
    /// TLS 1.2; this is the `Default` value.
    #[default]
    Tls12,
    /// TLS 1.3.
    Tls13,
}
210
211impl TlsVersion {
212 #[must_use]
214 pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
215 match self {
216 Self::Tls12 => &rustls::version::TLS12,
217 Self::Tls13 => &rustls::version::TLS13,
218 }
219 }
220}