1use std::sync::Arc;
4
5use rustls::pki_types::{CertificateDer, PrivateKeyDer};
6
7#[derive(Clone)]
11pub struct ClientAuth {
12 pub certificates: Vec<CertificateDer<'static>>,
14 pub key: Arc<PrivateKeyDer<'static>>,
16}
17
18impl ClientAuth {
19 pub fn new(certificates: Vec<CertificateDer<'static>>, key: PrivateKeyDer<'static>) -> Self {
21 Self {
22 certificates,
23 key: Arc::new(key),
24 }
25 }
26}
27
28impl std::fmt::Debug for ClientAuth {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("ClientAuth")
31 .field("certificates_count", &self.certificates.len())
32 .field("has_key", &true)
33 .finish()
34 }
35}
36
37#[derive(Clone, Debug)]
39pub struct TlsConfig {
40 pub trust_server_certificate: bool,
44
45 pub root_certificates: Vec<CertificateDer<'static>>,
50
51 pub client_auth: Option<ClientAuth>,
53
54 pub server_name: Option<String>,
58
59 pub min_protocol_version: TlsVersion,
61
62 pub max_protocol_version: TlsVersion,
64
65 pub strict_mode: bool,
67
68 pub alpn_protocols: Vec<Vec<u8>>,
70}
71
72impl Default for TlsConfig {
73 fn default() -> Self {
74 Self {
75 trust_server_certificate: false,
76 root_certificates: Vec::new(),
77 client_auth: None,
78 server_name: None,
79 min_protocol_version: TlsVersion::Tls12,
80 max_protocol_version: TlsVersion::Tls13,
81 strict_mode: false,
82 alpn_protocols: Vec::new(),
83 }
84 }
85}
86
87impl TlsConfig {
88 #[must_use]
90 pub fn new() -> Self {
91 Self::default()
92 }
93
94 #[must_use]
98 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
99 self.trust_server_certificate = trust;
100 self
101 }
102
103 #[must_use]
105 pub fn add_root_certificate(mut self, cert: CertificateDer<'static>) -> Self {
106 self.root_certificates.push(cert);
107 self
108 }
109
110 #[must_use]
112 pub fn with_root_certificates(mut self, certs: Vec<CertificateDer<'static>>) -> Self {
113 self.root_certificates = certs;
114 self
115 }
116
117 #[must_use]
119 pub fn with_client_auth(
120 mut self,
121 certs: Vec<CertificateDer<'static>>,
122 key: PrivateKeyDer<'static>,
123 ) -> Self {
124 self.client_auth = Some(ClientAuth::new(certs, key));
125 self
126 }
127
128 #[must_use]
130 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
131 self.server_name = Some(name.into());
132 self
133 }
134
135 #[must_use]
137 pub fn min_protocol_version(mut self, version: TlsVersion) -> Self {
138 self.min_protocol_version = version;
139 self
140 }
141
142 #[must_use]
144 pub fn max_protocol_version(mut self, version: TlsVersion) -> Self {
145 self.max_protocol_version = version;
146 self
147 }
148
149 #[must_use]
151 pub fn strict_mode(mut self, enabled: bool) -> Self {
152 self.strict_mode = enabled;
153 self
154 }
155
156 #[must_use]
158 pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
159 self.alpn_protocols = protocols;
160 self
161 }
162
163 #[must_use]
165 pub fn has_client_auth(&self) -> bool {
166 self.client_auth.is_some()
167 }
168
169 #[must_use]
175 pub fn add_root_certificate_der(self, der_bytes: Vec<u8>) -> Self {
176 self.add_root_certificate(CertificateDer::from(der_bytes))
177 }
178
179 #[must_use]
187 pub fn with_client_auth_der(
188 self,
189 cert_chain_der: Vec<Vec<u8>>,
190 private_key_der: Vec<u8>,
191 ) -> Self {
192 let certs = cert_chain_der
193 .into_iter()
194 .map(CertificateDer::from)
195 .collect();
196 let key = PrivateKeyDer::Pkcs8(private_key_der.into());
197 self.with_client_auth(certs, key)
198 }
199}
200
/// TLS protocol versions usable as bounds in a TLS configuration.
///
/// Variants are declared oldest first, so the derived ordering is chronological
/// (`Tls12 < Tls13`). The default is TLS 1.2.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[non_exhaustive]
pub enum TlsVersion {
    /// TLS 1.2.
    #[default]
    Tls12,
    /// TLS 1.3.
    Tls13,
}
211
212impl TlsVersion {
213 #[must_use]
215 pub fn to_rustls(&self) -> &'static rustls::SupportedProtocolVersion {
216 match self {
217 Self::Tls12 => &rustls::version::TLS12,
218 Self::Tls13 => &rustls::version::TLS13,
219 }
220 }
221}