1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
// Compile-time selection of the default TLS backend, driven by cargo features:
// `rustls` wins whenever it is enabled (even alongside `openssl`), `openssl` is
// used only when it is the sole TLS feature, and `NullTlsDriver` (every one of
// whose operations returns an error) is the fallback when neither is enabled.
/// The default TLS driver for this build: the OpenSSL backend.
#[cfg(all(feature = "openssl", not(feature = "rustls")))]
pub type Ssl = crate::common::openssl::OpensslDriver;
/// The default TLS driver for this build: the rustls backend.
#[cfg(feature = "rustls")]
pub type Ssl = crate::common::rustls::RustlsDriver;
/// The default TLS driver for this build: no backend, all TLS operations fail.
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
pub type Ssl = NullTlsDriver;
15
/// A TLS backend able to upgrade an established [`Stream`] to TLS, on either
/// the client or the server side.
///
/// Client setup is two-phase: [`TlsDriver::init_client`] turns
/// [`TlsParameters`] into backend-specific [`TlsDriver::ClientParams`], which
/// [`TlsDriver::upgrade_client`] then consumes to perform the handshake.
/// [`TlsDriver::upgrade_server`] instead receives a
/// [`TlsServerParameterProvider`]; presumably implementations call
/// [`TlsDriver::init_server`] on the parameters it yields (TODO confirm
/// against the concrete drivers).
pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
    /// The TLS-wrapped stream produced by a successful upgrade.
    type Stream: Stream + Send;
    /// Backend-specific client configuration produced by `init_client`.
    type ClientParams: Unpin + Send;
    /// Backend-specific server configuration produced by `init_server`.
    type ServerParams: Unpin + Send;

    /// Prepares client-side configuration from `params`.
    ///
    /// `name` is the optional server name for the connection (presumably used
    /// for SNI and/or hostname verification; TODO confirm).
    #[allow(unused)]
    fn init_client(
        params: &TlsParameters,
        name: Option<ServerName>,
    ) -> Result<Self::ClientParams, SslError>;
    /// Prepares server-side configuration from `params`.
    #[allow(unused)]
    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;

    /// Performs the client-side handshake over `stream` using `params`.
    /// On success, returns the TLS stream plus details of the handshake.
    fn upgrade_client<S: Stream>(
        params: Self::ClientParams,
        stream: S,
    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
    /// Performs the server-side handshake over `stream`, obtaining the server
    /// parameters from `params`. On success, returns the TLS stream plus
    /// details of the handshake.
    fn upgrade_server<S: Stream>(
        params: TlsServerParameterProvider,
        stream: S,
    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
}
38
/// A [`TlsDriver`] with no TLS backend: every operation returns an error.
///
/// Used as the `Ssl` driver when neither the `openssl` nor the `rustls`
/// feature is enabled.
#[derive(Default)]
pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44 type Stream = BaseStream;
45 type ClientParams = ();
46 type ServerParams = ();
47
48 fn init_client(
49 params: &TlsParameters,
50 name: Option<ServerName>,
51 ) -> Result<Self::ClientParams, SslError> {
52 Err(SslError::SslUnsupportedByClient)
53 }
54
55 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56 Err(SslError::SslUnsupportedByClient)
57 }
58
59 async fn upgrade_client<S: Stream>(
60 params: Self::ClientParams,
61 stream: S,
62 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63 Err(SslError::SslUnsupportedByClient)
64 }
65
66 async fn upgrade_server<S: Stream>(
67 params: TlsServerParameterProvider,
68 stream: S,
69 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70 Err(SslError::SslUnsupportedByClient)
71 }
72}
73
/// Server-certificate verification mode (the `server_cert_verify` field of
/// [`TlsParameters`]). Defaults to [`TlsServerCertVerify::VerifyFull`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Presumably disables server certificate verification entirely
    /// (TODO confirm against the driver implementations).
    Insecure,
    /// Presumably verifies the certificate chain but skips the hostname check
    /// (TODO confirm against the driver implementations).
    IgnoreHostname,
    /// Presumably full verification of chain and hostname (TODO confirm).
    #[default]
    VerifyFull,
}
102
/// Root-certificate selection (the `root_cert` field of [`TlsParameters`]),
/// presumably the trust anchors used to verify the server. Defaults to
/// [`TlsCert::System`].
///
/// The `Debug` output of the certificate-carrying variants shows only the
/// number of certificates, never their contents.
#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
pub enum TlsCert {
    /// Presumably the platform's trust store (TODO confirm in the drivers).
    #[default]
    System,
    /// Presumably `System` plus the listed extra certificates.
    #[debug("SystemPlus([{} cert(s)])", _0.len())]
    SystemPlus(Vec<CertificateDer<'static>>),
    /// Presumably the bundled webpki root set (TODO confirm in the drivers).
    Webpki,
    /// Presumably `Webpki` plus the listed extra certificates.
    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
    WebpkiPlus(Vec<CertificateDer<'static>>),
    /// Presumably only the listed certificates, with no built-in roots.
    #[debug("Custom([{} cert(s)])", _0.len())]
    Custom(Vec<CertificateDer<'static>>),
}
122
/// Client-side TLS configuration, consumed by [`TlsDriver::init_client`].
///
/// The derived default uses `TlsServerCertVerify::VerifyFull`,
/// `TlsCert::System`, no `cert`/`key`, no CRLs, no version bounds and an empty
/// ALPN list. The `Debug` output prints `cert`/`key` only as `Some(...)`/`None`
/// and summarizes `crl` by its length, so no key material is logged.
#[derive(Default, derive_more::Debug, PartialEq, Eq)]
pub struct TlsParameters {
    /// How the server's certificate is verified.
    pub server_cert_verify: TlsServerCertVerify,
    /// Presumably the client certificate for mutual TLS (TODO confirm).
    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub cert: Option<CertificateDer<'static>>,
    /// Presumably the private key matching `cert` (TODO confirm).
    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub key: Option<PrivateKeyDer<'static>>,
    /// Which root certificates to trust.
    pub root_cert: TlsCert,
    /// Certificate revocation lists.
    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
    pub crl: Vec<CertificateRevocationListDer<'static>>,
    /// Lowest TLS version to negotiate; `None` presumably means backend default.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest TLS version to negotiate; `None` presumably means backend default.
    pub max_protocol_version: Option<SslVersion>,
    /// Presumably enables logging of TLS session keys (TODO confirm).
    pub enable_keylog: bool,
    /// Presumably replaces the server name sent via SNI (TODO confirm).
    pub sni_override: Option<Cow<'static, str>>,
    /// ALPN protocols to offer.
    pub alpn: TlsAlpn,
}
139
140impl TlsParameters {
141 pub fn insecure() -> Self {
142 Self {
143 server_cert_verify: TlsServerCertVerify::Insecure,
144 ..Default::default()
145 }
146 }
147}
148
/// A TLS protocol version, used for the optional minimum/maximum version
/// bounds in [`TlsParameters`] and [`TlsServerParameters`].
///
/// Note: no `PartialOrd`/`Ord` is derived, so versions cannot be compared with `<`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
156
/// Client-certificate policy for a TLS server (the `client_cert_verify` field
/// of [`TlsServerParameters`]). Defaults to [`TlsClientCertVerify::Ignore`].
#[derive(Default, Debug, PartialEq, Eq)]
pub enum TlsClientCertVerify {
    /// Presumably: client certificates are neither requested nor checked.
    #[default]
    Ignore,
    /// Presumably: a client certificate is optional; if presented, it is
    /// validated against the listed certificates (TODO confirm).
    Optional(Vec<CertificateDer<'static>>),
    /// Presumably: a client certificate is required and must validate against
    /// the listed certificates (TODO confirm).
    Validate(Vec<CertificateDer<'static>>),
}
168
/// A private key paired with a certificate, e.g. the server identity in
/// [`TlsServerParameters::server_certificate`].
///
/// Built with the `derive_more::Constructor`-generated `TlsKey::new(key, cert)`.
/// The argument order follows the field order below, so do not reorder the
/// fields. `Debug` never prints the key or certificate bytes.
#[derive(derive_more::Debug, derive_more::Constructor)]
pub struct TlsKey {
    #[debug("key(...)")]
    pub(crate) key: PrivateKeyDer<'static>,
    #[debug("cert(...)")]
    pub(crate) cert: CertificateDer<'static>,
}
176
/// Supplies [`TlsServerParameters`] to [`TlsDriver::upgrade_server`]: either
/// one fixed parameter set or a callback keyed on an optional [`ServerName`].
///
/// Cloning is cheap because the underlying data is held behind an `Arc`.
#[derive(Debug, Clone)]
pub struct TlsServerParameterProvider {
    inner: TlsServerParameterProviderInner,
}
181
182impl TlsServerParameterProvider {
183 pub fn new(params: TlsServerParameters) -> Self {
184 Self {
185 inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
186 }
187 }
188
189 pub fn with_lookup(
190 lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
191 ) -> Self {
192 Self {
193 inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
194 }
195 }
196
197 pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
198 match &self.inner {
199 TlsServerParameterProviderInner::Static(params) => params.clone(),
200 TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
201 }
202 }
203}
204
/// Storage behind [`TlsServerParameterProvider`].
#[derive(derive_more::Debug, Clone)]
enum TlsServerParameterProviderInner {
    /// One shared parameter set, returned for every lookup.
    Static(Arc<TlsServerParameters>),
    /// A callback invoked on every lookup with the optional server name.
    /// It is printed as `Lookup(...)` because closures have no `Debug` form.
    #[debug("Lookup(...)")]
    #[allow(clippy::type_complexity)]
    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
}
212
/// Server-side TLS configuration, consumed by [`TlsDriver::init_server`] and
/// handed out by [`TlsServerParameterProvider`].
#[derive(Debug)]
pub struct TlsServerParameters {
    /// Policy for client certificates.
    pub client_cert_verify: TlsClientCertVerify,
    /// Lowest TLS version to accept; `None` presumably means backend default.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest TLS version to accept; `None` presumably means backend default.
    pub max_protocol_version: Option<SslVersion>,
    /// The server's private key and certificate.
    pub server_certificate: TlsKey,
    /// ALPN protocols the server supports.
    pub alpn: TlsAlpn,
}
221
/// An ordered list of ALPN protocol identifiers (e.g. `h2`, `http/1.1`),
/// stored as byte strings. The default is an empty list.
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    // The constructors store an owned list of borrowed `'static` identifiers;
    // `Cow` at both levels also permits fully borrowed or fully owned data.
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}
227
228impl std::fmt::Debug for TlsAlpn {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 if self.alpn_parts.is_empty() {
231 write!(f, "[]")
232 } else {
233 for (i, part) in self.alpn_parts.iter().enumerate() {
234 if i == 0 {
235 write!(f, "[")?;
236 } else {
237 write!(f, ", ")?;
238 }
239 let mut s = String::new();
241 s.push_str("b\"");
242 for &b in part.iter() {
243 for c in b.escape_ascii() {
244 s.push(c as char);
245 }
246 }
247 s.push_str("\"");
248 write!(f, "{}", s)?;
249 }
250 write!(f, "]")?;
251 Ok(())
252 }
253 }
254}
255
256impl TlsAlpn {
257 pub fn new(alpn: &'static [&'static [u8]]) -> Self {
258 let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
259 Self {
260 alpn_parts: Cow::Owned(alpn),
261 }
262 }
263
264 pub fn new_str(alpn: &'static [&'static str]) -> Self {
265 let alpn = alpn
266 .iter()
267 .map(|s| Cow::Borrowed(s.as_bytes()))
268 .collect::<Vec<_>>();
269 Self {
270 alpn_parts: Cow::Owned(alpn),
271 }
272 }
273
274 pub fn is_empty(&self) -> bool {
275 self.alpn_parts.is_empty()
276 }
277
278 pub fn as_bytes(&self) -> Vec<u8> {
279 let mut bytes = Vec::with_capacity(self.alpn_parts.len() * 2);
280 for part in self.alpn_parts.iter() {
281 bytes.push(part.len() as u8);
282 bytes.extend_from_slice(part.as_ref());
283 }
284 bytes
285 }
286
287 pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
288 let mut vec = Vec::with_capacity(self.alpn_parts.len());
289 for part in self.alpn_parts.iter() {
290 vec.push(part.to_vec());
291 }
292 vec
293 }
294}
295
/// Details of a completed handshake, returned together with the upgraded
/// stream by [`TlsDriver::upgrade_client`] and [`TlsDriver::upgrade_server`].
#[derive(Debug, Clone, Default)]
pub struct TlsHandshake {
    /// Presumably the ALPN protocol negotiated during the handshake, if any.
    pub alpn: Option<Cow<'static, [u8]>>,
    /// Presumably the server name (SNI) associated with the connection, if any.
    pub sni: Option<Cow<'static, str>>,
    /// Presumably the peer's certificate, if one was presented (TODO confirm).
    pub cert: Option<CertificateDer<'static>>,
}
302
// Unit tests pinning the `Debug` formatting of the TLS parameter types and the
// `TlsAlpn` encodings.
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    // `TlsParameters`' `Debug` output must hide key material and summarize
    // certificate/CRL lists instead of printing raw DER bytes.
    #[test]
    fn test_tls_parameters_debug() {
        // Defaults: `VerifyFull`, `System` roots, everything else empty/None.
        let params = TlsParameters::default();
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        // Populated: `cert`/`key` print as `Some(...)`, lists print as counts,
        // and ALPN entries print as byte-string literals.
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    // Wire encoding (length-prefixed), owned copies, emptiness and `Debug`
    // rendering of `TlsAlpn`, for both a populated and an empty list.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{:?}", alpn), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{:?}", empty_alpn), "[]");
    }

    // Field access and derived `Debug`/`Default` of `TlsHandshake`. Note that
    // the derived `Debug` prints `alpn` as a list of byte values.
    #[test]
    fn test_tls_handshake() {
        let handshake = TlsHandshake {
            alpn: Some(Cow::Borrowed(b"h2")),
            sni: Some(Cow::Borrowed("example.com")),
            cert: None,
        };
        assert_eq!(handshake.alpn, Some(Cow::Borrowed(b"h2".as_slice())));
        assert_eq!(handshake.sni, Some(Cow::Borrowed("example.com")));
        assert_eq!(handshake.cert, None);

        assert_eq!(
            format!("{:?}", handshake),
            "TlsHandshake { alpn: Some([104, 50]), sni: Some(\"example.com\"), cert: None }"
        );

        let default_handshake = TlsHandshake::default();
        assert_eq!(default_handshake.alpn, None);
        assert_eq!(default_handshake.sni, None);
        assert_eq!(default_handshake.cert, None);
    }
}