1use crate::{SslError, Stream, StreamMetadata};
2use rustls_pki_types::{
3 CertificateDer, CertificateRevocationListDer, DnsName, PrivateKeyDer, ServerName,
4};
5use std::{borrow::Cow, future::Future, sync::Arc};
6
7use super::BaseStream;
8
9#[cfg(all(feature = "openssl", not(feature = "rustls")))]
13pub type Ssl = crate::common::openssl::OpensslDriver;
14#[cfg(feature = "rustls")]
15pub type Ssl = crate::common::rustls::RustlsDriver;
16#[cfg(not(any(feature = "openssl", feature = "rustls")))]
17pub type Ssl = NullTlsDriver;
18
19#[doc(hidden)]
21pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
22 type Stream: Stream + Send;
23 type ClientParams: Unpin + Send;
24 type ServerParams: Unpin + Send;
25 const DRIVER_NAME: &'static str;
26
27 #[allow(unused)]
28 fn init_client(
29 params: &TlsParameters,
30 name: Option<ServerName>,
31 ) -> Result<Self::ClientParams, SslError>;
32 #[allow(unused)]
33 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
34
35 fn upgrade_client<S: Stream>(
36 params: Self::ClientParams,
37 stream: S,
38 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
39 fn upgrade_server<S: Stream>(
40 params: TlsServerParameterProvider,
41 stream: S,
42 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
43 fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream>;
44
45 fn is<D: TlsDriver>() -> bool {
46 D::DRIVER_NAME == Self::DRIVER_NAME
47 }
48}
49
50#[derive(Default)]
52pub struct NullTlsDriver;
53
54#[allow(unused)]
55impl TlsDriver for NullTlsDriver {
56 type Stream = BaseStream;
57 type ClientParams = ();
58 type ServerParams = ();
59 const DRIVER_NAME: &'static str = "null";
60
61 fn init_client(
62 params: &TlsParameters,
63 name: Option<ServerName>,
64 ) -> Result<Self::ClientParams, SslError> {
65 Err(SslError::SslUnsupported)
66 }
67
68 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
69 Err(SslError::SslUnsupported)
70 }
71
72 async fn upgrade_client<S: Stream>(
73 params: Self::ClientParams,
74 stream: S,
75 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
76 Err(SslError::SslUnsupported)
77 }
78
79 async fn upgrade_server<S: Stream>(
80 params: TlsServerParameterProvider,
81 stream: S,
82 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
83 Err(SslError::SslUnsupported)
84 }
85
86 fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
87 Ok(())
89 }
90}
91
/// How strictly a TLS client checks the server's certificate
/// (the `server_cert_verify` field of `TlsParameters`).
// NOTE(review): the variant semantics below are inferred from their names —
// confirm against the driver implementations.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Presumably performs no server certificate verification at all.
    Insecure,
    /// Presumably verifies the certificate chain but not that it matches the hostname.
    IgnoreHostname,
    /// Full verification. This is the default.
    #[default]
    VerifyFull,
}
120
/// Source of trusted root certificates used to verify the server
/// (the `root_cert` field of `TlsParameters`).
///
/// For the variants carrying certificates, `Debug` prints only how many there are.
// NOTE(review): "System"/"Webpki" semantics are inferred from the names — confirm
// in the drivers.
#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
pub enum TlsCert {
    /// Presumably the platform's native trust store. This is the default.
    #[default]
    System,
    /// Presumably the system trust store plus the given extra certificates.
    #[debug("SystemPlus([{} cert(s)])", _0.len())]
    SystemPlus(Vec<CertificateDer<'static>>),
    /// Presumably the bundled webpki root set.
    Webpki,
    /// Presumably the webpki root set plus the given extra certificates.
    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
    WebpkiPlus(Vec<CertificateDer<'static>>),
    /// Presumably only the given certificates are trusted.
    #[debug("Custom([{} cert(s)])", _0.len())]
    Custom(Vec<CertificateDer<'static>>),
}
140
/// Client-side TLS configuration, consumed by `TlsDriver::init_client`.
///
/// `Debug` prints `Some(...)`/`None` instead of the certificate and key bytes, and
/// only the number of CRLs.
#[derive(Default, derive_more::Debug, PartialEq, Eq)]
pub struct TlsParameters {
    /// How the server's certificate is verified (defaults to full verification).
    pub server_cert_verify: TlsServerCertVerify,
    /// Optional certificate presented by this side (presumably for client-certificate auth).
    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub cert: Option<CertificateDer<'static>>,
    /// Private key matching `cert`.
    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub key: Option<PrivateKeyDer<'static>>,
    /// Trusted roots used to verify the server.
    pub root_cert: TlsCert,
    /// Certificate revocation lists.
    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
    pub crl: Vec<CertificateRevocationListDer<'static>>,
    /// Lowest protocol version allowed; `None` presumably means the driver default.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest protocol version allowed; `None` presumably means the driver default.
    pub max_protocol_version: Option<SslVersion>,
    /// Presumably enables TLS key logging for debugging — confirm in the drivers.
    pub enable_keylog: bool,
    /// Presumably replaces the server name sent/checked during the handshake.
    pub sni_override: Option<Cow<'static, str>>,
    /// ALPN protocols to offer.
    pub alpn: TlsAlpn,
}
157
158impl TlsParameters {
159 pub fn insecure() -> Self {
160 Self {
161 server_cert_verify: TlsServerCertVerify::Insecure,
162 ..Default::default()
163 }
164 }
165}
166
/// A TLS protocol version, e.g. for bounding which versions a connection may use.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
    /// TLS 1.0 (`"TLSv1"`).
    Tls1,
    /// TLS 1.1 (`"TLSv1.1"`).
    Tls1_1,
    /// TLS 1.2 (`"TLSv1.2"`).
    Tls1_2,
    /// TLS 1.3 (`"TLSv1.3"`).
    Tls1_3,
}

impl SslVersion {
    /// Returns the canonical textual name of this version, e.g. `"TLSv1.2"`.
    ///
    /// This is exactly the text produced by the `Display` implementation.
    pub const fn as_str(&self) -> &'static str {
        match self {
            SslVersion::Tls1 => "TLSv1",
            SslVersion::Tls1_1 => "TLSv1.1",
            SslVersion::Tls1_2 => "TLSv1.2",
            SslVersion::Tls1_3 => "TLSv1.3",
        }
    }
}

impl std::fmt::Display for SslVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) honours width/fill/alignment flags, so
        // `format!("{:>8}", v)` lines up as callers expect.
        f.pad(self.as_str())
    }
}
186
187#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
188pub struct SslVersionParseError(#[error(not(source))] pub String);
189
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    /// Serializes the version as its textual name (`"TLSv1"` through `"TLSv1.3"`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let name = match *self {
            Self::Tls1 => "TLSv1",
            Self::Tls1_1 => "TLSv1.1",
            Self::Tls1_2 => "TLSv1.2",
            Self::Tls1_3 => "TLSv1.3",
        };
        serializer.serialize_str(name)
    }
}
204
205impl TryFrom<Cow<'_, str>> for SslVersion {
206 type Error = SslVersionParseError;
207 fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
208 Ok(match value.to_lowercase().as_ref() {
209 "tls_1" | "tlsv1" => SslVersion::Tls1,
210 "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
211 "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
212 "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
213 _ => return Err(SslVersionParseError(value.to_string())),
214 })
215 }
216}
217
/// Server-side policy for client certificates
/// (the `client_cert_verify` field of `TlsServerParameters`).
// NOTE(review): semantics are inferred from the variant names — confirm in the drivers.
#[derive(Default, Debug, PartialEq, Eq)]
pub enum TlsClientCertVerify {
    /// Client certificates are presumably neither requested nor checked. This is the default.
    #[default]
    Ignore,
    /// A client certificate may be presented; presumably it is then validated
    /// against these trusted certificates.
    Optional(Vec<CertificateDer<'static>>),
    /// A client certificate is presumably required and validated against these
    /// trusted certificates.
    Validate(Vec<CertificateDer<'static>>),
}
229
230#[derive(derive_more::Debug, derive_more::Constructor)]
231pub struct TlsKey {
232 #[debug("key(...)")]
233 pub(crate) key: PrivateKeyDer<'static>,
234 #[debug("cert(...)")]
235 pub(crate) cert: CertificateDer<'static>,
236}
237
238impl TlsKey {
239 #[cfg(feature = "pem")]
241 pub fn new_pem(mut key: &[u8], mut cert: &[u8]) -> Result<Self, std::io::Error> {
242 let cert = rustls_pemfile::certs(&mut cert)
243 .next()
244 .ok_or(std::io::Error::new(
245 std::io::ErrorKind::InvalidData,
246 "No certificate found",
247 ))??;
248 let key = rustls_pemfile::private_key(&mut key)?.ok_or(std::io::Error::new(
249 std::io::ErrorKind::InvalidData,
250 "No key found",
251 ))?;
252 Ok(Self { cert, key })
253 }
254
255 pub fn clone_key(&self) -> Self {
257 Self {
258 key: self.key.clone_key(),
259 cert: self.cert.clone(),
260 }
261 }
262}
263
264#[derive(Debug, Clone)]
265pub struct TlsServerParameterProvider {
266 inner: TlsServerParameterProviderInner,
267}
268
269impl TlsServerParameterProvider {
270 pub fn new(params: TlsServerParameters) -> Self {
271 Self {
272 inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
273 }
274 }
275
276 pub fn with_lookup(
277 lookup: impl Fn(Option<DnsName>, &dyn StreamMetadata) -> Arc<TlsServerParameters>
278 + Send
279 + Sync
280 + 'static,
281 ) -> Self {
282 Self {
283 inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
284 }
285 }
286
287 pub fn lookup(
288 &self,
289 name: Option<DnsName>,
290 stream: &dyn StreamMetadata,
291 ) -> Arc<TlsServerParameters> {
292 match &self.inner {
293 TlsServerParameterProviderInner::Static(params) => params.clone(),
294 TlsServerParameterProviderInner::Lookup(lookup) => lookup(name, stream),
295 }
296 }
297}
298
/// Signature of a per-connection server parameter lookup: it receives the optional
/// DNS name and the stream's metadata and returns the parameters to use.
pub type TlsServerParameterLookupFn = dyn Fn(Option<DnsName>, &dyn StreamMetadata) -> Arc<TlsServerParameters>
    + Send
    + Sync
    + 'static;

/// Storage behind `TlsServerParameterProvider`.
#[derive(derive_more::Debug, Clone)]
enum TlsServerParameterProviderInner {
    /// One parameter set shared by every connection.
    Static(Arc<TlsServerParameters>),
    /// Callback consulted per connection. A `dyn Fn` has no `Debug`, so this
    /// variant renders as `Lookup(...)`.
    #[debug("Lookup(...)")]
    #[allow(clippy::type_complexity)]
    Lookup(Arc<TlsServerParameterLookupFn>),
}
313
314#[derive(Debug)]
315pub struct TlsServerParameters {
316 pub client_cert_verify: TlsClientCertVerify,
317 pub min_protocol_version: Option<SslVersion>,
318 pub max_protocol_version: Option<SslVersion>,
319 pub server_certificate: TlsKey,
320 pub alpn: TlsAlpn,
321}
322
323impl TlsServerParameters {
324 pub fn new_with_certificate(server_certificate: TlsKey) -> Self {
325 Self {
326 client_cert_verify: TlsClientCertVerify::default(),
327 min_protocol_version: None,
328 max_protocol_version: None,
329 server_certificate,
330 alpn: TlsAlpn::default(),
331 }
332 }
333}
334
/// An ordered list of ALPN (Application-Layer Protocol Negotiation) protocol ids,
/// such as `h2` or `http/1.1`.
///
/// `Debug` renders the list like a slice of byte-string literals, e.g.
/// `[b"h2", b"http/1.1"]`, and supports `{:#?}` pretty-printing.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct TlsAlpn {
    // Protocol ids in the order given; each may borrow `'static` data or own its bytes.
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

impl std::fmt::Debug for TlsAlpn {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Formats one protocol id as a byte-string literal such as `b"h2"`,
        /// escaping quotes and non-printable bytes.
        struct ByteLiteral<'a>(&'a [u8]);

        impl std::fmt::Debug for ByteLiteral<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "b\"{}\"", self.0.escape_ascii())
            }
        }

        // `debug_list` yields `[a, b]` (`[]` when empty) and the multi-line form
        // for `{:#?}`, without building intermediate strings.
        f.debug_list()
            .entries(self.alpn_parts.iter().map(|part| ByteLiteral(part.as_ref())))
            .finish()
    }
}

impl TlsAlpn {
    /// Creates an ALPN list that borrows the given `'static` protocol ids without
    /// copying their bytes.
    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
        let parts: Vec<Cow<'static, [u8]>> = alpn.iter().copied().map(Cow::Borrowed).collect();
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }

    /// Creates an ALPN list from `'static` string protocol ids (their UTF-8 bytes).
    pub fn new_str(alpn: &[&'static str]) -> Self {
        let parts: Vec<Cow<'static, [u8]>> = alpn
            .iter()
            .map(|s| Cow::Borrowed(s.as_bytes()))
            .collect();
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }

    /// Returns `true` if no protocol ids are configured.
    pub fn is_empty(&self) -> bool {
        self.alpn_parts.is_empty()
    }

    /// Encodes the list in the TLS wire format: each protocol id prefixed by its
    /// one-byte length (RFC 7301, section 3.1).
    ///
    /// Ids that cannot be represented in that format (empty ones, or ones longer
    /// than 255 bytes) are skipped. Emitting them with a zero or wrapped length byte
    /// would produce a malformed list and misalign every following entry.
    pub fn as_bytes(&self) -> Vec<u8> {
        // Upper bound: one length byte plus the bytes of every id.
        let capacity: usize = self.alpn_parts.iter().map(|part| part.len() + 1).sum();
        let mut bytes = Vec::with_capacity(capacity);
        for part in self.alpn_parts.iter() {
            let len = match u8::try_from(part.len()) {
                Ok(len) if len > 0 => len,
                _ => continue,
            };
            bytes.push(len);
            bytes.extend_from_slice(part);
        }
        bytes
    }

    /// Returns each protocol id as its own owned byte vector, in order, unfiltered.
    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
        self.alpn_parts.iter().map(|part| part.to_vec()).collect()
    }
}

impl<T, U> From<T> for TlsAlpn
where
    U: AsRef<[u8]>,
    T: IntoIterator<Item = U>,
{
    /// Copies every protocol id yielded by `alpn` into an owned list.
    fn from(alpn: T) -> Self {
        let parts: Vec<Cow<'static, [u8]>> = alpn
            .into_iter()
            .map(|s| Cow::Owned(s.as_ref().to_vec()))
            .collect();
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }
}
424
/// Details of a completed handshake, returned alongside the upgraded stream by
/// `TlsDriver::upgrade_client` / `TlsDriver::upgrade_server`.
#[derive(Debug, Clone, Default)]
pub struct TlsHandshake {
    /// ALPN protocol agreed on, presumably `None` when none was negotiated.
    pub alpn: Option<Cow<'static, [u8]>>,
    /// Server name indication, presumably the name requested by the client.
    pub sni: Option<DnsName<'static>>,
    /// A certificate from the handshake.
    // NOTE(review): presumably the peer's certificate — confirm per driver.
    pub cert: Option<CertificateDer<'static>>,
    /// Protocol version of the connection, if known.
    pub version: Option<SslVersion>,
}
437
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    /// `Debug` for `TlsParameters` must redact key material and summarize lists.
    #[test]
    fn test_tls_parameters_debug() {
        let params = TlsParameters::default();
        assert_eq!(
            format!("{params:?}"),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{params:?}"),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    /// `insecure()` differs from the default only in `server_cert_verify`.
    #[test]
    fn test_insecure_parameters() {
        let insecure = TlsParameters::insecure();
        assert_eq!(insecure.server_cert_verify, TlsServerCertVerify::Insecure);
        let restored = TlsParameters {
            server_cert_verify: TlsServerCertVerify::VerifyFull,
            ..insecure
        };
        assert_eq!(restored, TlsParameters::default());
    }

    /// Every version's `Display` text parses back (in any ASCII case), both
    /// spellings are accepted, and unknown names report the original input.
    #[test]
    fn test_ssl_version_parse() {
        let versions = [
            SslVersion::Tls1,
            SslVersion::Tls1_1,
            SslVersion::Tls1_2,
            SslVersion::Tls1_3,
        ];
        for version in versions {
            let text = version.to_string();
            assert_eq!(SslVersion::try_from(Cow::Borrowed(text.as_str())), Ok(version));
            assert_eq!(
                SslVersion::try_from(Cow::Owned(text.to_ascii_uppercase())),
                Ok(version)
            );
        }
        assert_eq!(
            SslVersion::try_from(Cow::Borrowed("tls_1.2")),
            Ok(SslVersion::Tls1_2)
        );
        assert_eq!(
            SslVersion::try_from(Cow::Borrowed("TLSv1.4")),
            Err(SslVersionParseError("TLSv1.4".to_string()))
        );
    }

    /// ALPN wire encoding, per-protocol vectors and `Debug` rendering.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{alpn:?}"), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{empty_alpn:?}"), "[]");
    }
}