1use std::sync::Arc;
4
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6
7#[derive(Clone)]
11pub struct ClientAuth {
12 pub certificates: Vec<CertificateDer<'static>>,
14 pub key: Arc<PrivateKeyDer<'static>>,
16}
17
18impl ClientAuth {
19 pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
21 Self {
22 certificates,
23 key: Arc::new(key),
24 }
25 }
26}
27
28impl std::fmt::Debug for ClientAuth {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ClientAuth")
31 .field("certificates_count", &self.certificates.len())
32 .field("has_key", &true)
33 .finish()
34 }
35}
36
37#[derive(Clone, Debug)]
39pub struct TlsConfig {
40 pub trust_server_certificate: bool,
44
45 pub root_certificates: Vec<CertificateDer<'static>>,
53
54 pub client_auth: Option<ClientAuth>,
56
57 pub server_name: Option<String>,
61
62 pub min_protocol_version: TlsVersion,
64
65 pub max_protocol_version: TlsVersion,
67
68 pub strict_mode: bool,
70
71 pub alpn_protocols: Vec<Vec<u8>>,
73}
74
75impl Default for TlsConfig {
76 fn default() -> Self {
77 Self {
78 trust_server_certificate: false,
79 root_certificates: Vec::new(),
80 client_auth: None,
81 server_name: None,
82 min_protocol_version: TlsVersion::Tls12,
83 max_protocol_version: TlsVersion::Tls13,
84 strict_mode: false,
85 alpn_protocols: Vec::new(),
86 }
87 }
88}
89
90impl TlsConfig {
91 #[must_use]
93 pub fn new() -> Self {
94 Self::default()
95 }
96
97 #[must_use]
101 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
102 self.trust_server_certificate = trust;
103 self
104 }
105
106 #[must_use]
108 pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
109 self.root_certificates.push(cert);
110 self
111 }
112
113 #[must_use]
115 pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
116 self.root_certificates = certs;
117 self
118 }
119
120 #[must_use]
122 pub fn with_client_auth(
123 mut self,
124 certs: Vec<CertificateDer<'static>>,
125 key: PrivateKeyDer<'static>,
126 ) -> Self {
127 self.client_auth = Some(ClientAuth::new(certs, key));
128 self
129 }
130
131 #[must_use]
133 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
134 self.server_name = Some(name.into());
135 self
136 }
137
138 #[must_use]
140 pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
141 self.min_protocol_version = version;
142 self
143 }
144
145 #[must_use]
147 pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
148 self.max_protocol_version = version;
149 self
150 }
151
152 #[must_use]
154 pub fn strict_mode(mut self, enabled: bool) -> Self {
155 self.strict_mode = enabled;
156 self
157 }
158
159 #[must_use]
161 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
162 self.alpn_protocols = protocols;
163 self
164 }
165
166 #[must_use]
168 pub fn has_client_auth(&self) -> bool {
169 self.client_auth.is_some()
170 }
171
172 #[must_use]
178 pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
179 self.add_root_certificate(CertificateDer::from(der_bytes))
180 }
181
182 #[must_use]
190 pub fn with_client_auth_der(
191 self,
192 cert_chain_der: Vec<Vec<u8>>,
193 private_key_der: Vec<u8>,
194 ) -> Self {
195 let certs = cert_chain_der
196 .into_iter()
197 .map(CertificateDer::from)
198 .collect();
199 let key = PrivateKeyDer::Pkcs8(private_key_der.into());
200 self.with_client_auth(certs, key)
201 }
202}
203
/// TLS protocol versions that can be configured as a lower or upper bound.
///
/// Variants are declared oldest-first, so the derived `Ord` ranks
/// `Tls12 < Tls13`. Marked `#[non_exhaustive]` so more versions can be added
/// later without breaking downstream matches.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum TlsVersion {
    /// TLS 1.2 — the `Default` variant.
    #[default]
    Tls12,
    /// TLS 1.3.
    Tls13,
}
214
215impl TlsVersion {
216 #[must_use]
218 pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
219 match self {
220 Self::Tls12 => &rustls::version::TLS12,
221 Self::Tls13 => &rustls::version::TLS13,
222 }
223 }
224}