1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
/// The TLS driver compiled into this build: OpenSSL, used only when the
/// `openssl` feature is enabled and `rustls` is not.
#[cfg(all(feature = "openssl", not(feature = "rustls")))]
pub type Ssl = crate::common::openssl::OpensslDriver;
/// The TLS driver compiled into this build: rustls, used whenever the `rustls`
/// feature is enabled (it takes precedence over `openssl`).
#[cfg(feature = "rustls")]
pub type Ssl = crate::common::rustls::RustlsDriver;
/// The TLS driver compiled into this build: [`NullTlsDriver`], used when neither
/// `openssl` nor `rustls` is enabled. Its initialization and upgrade methods all
/// return an error.
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
pub type Ssl = NullTlsDriver;
16
17#[doc(hidden)]
19pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
20 type Stream: Stream + Send;
21 type ClientParams: Unpin + Send;
22 type ServerParams: Unpin + Send;
23 const DRIVER_NAME: &'static str;
24
25 #[allow(unused)]
26 fn init_client(
27 params: &TlsParameters,
28 name: Option<ServerName>,
29 ) -> Result<Self::ClientParams, SslError>;
30 #[allow(unused)]
31 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
32
33 fn upgrade_client<S: Stream>(
34 params: Self::ClientParams,
35 stream: S,
36 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37 fn upgrade_server<S: Stream>(
38 params: TlsServerParameterProvider,
39 stream: S,
40 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
41 fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream>;
42
43 fn is<D: TlsDriver>() -> bool {
44 D::DRIVER_NAME == Self::DRIVER_NAME
45 }
46}
47
48#[derive(Default)]
50pub struct NullTlsDriver;
51
52#[allow(unused)]
53impl TlsDriver for NullTlsDriver {
54 type Stream = BaseStream;
55 type ClientParams = ();
56 type ServerParams = ();
57 const DRIVER_NAME: &'static str = "null";
58
59 fn init_client(
60 params: &TlsParameters,
61 name: Option<ServerName>,
62 ) -> Result<Self::ClientParams, SslError> {
63 Err(SslError::SslUnsupportedByClient)
64 }
65
66 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
67 Err(SslError::SslUnsupportedByClient)
68 }
69
70 async fn upgrade_client<S: Stream>(
71 params: Self::ClientParams,
72 stream: S,
73 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
74 Err(SslError::SslUnsupportedByClient)
75 }
76
77 async fn upgrade_server<S: Stream>(
78 params: TlsServerParameterProvider,
79 stream: S,
80 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
81 Err(SslError::SslUnsupportedByClient)
82 }
83
84 fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
85 Ok(())
87 }
88}
89
/// How strictly a client checks the server's certificate; stored in
/// [`TlsParameters::server_cert_verify`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Presumably skips server certificate verification entirely (TODO confirm
    /// in the drivers). This is the value selected by [`TlsParameters::insecure`].
    Insecure,
    /// Presumably verifies the certificate chain but not that the certificate
    /// matches the requested hostname (TODO confirm in the drivers).
    IgnoreHostname,
    /// Presumably full verification of both chain and hostname. This is the
    /// default.
    #[default]
    VerifyFull,
}
118
/// Which root certificates to trust; stored in [`TlsParameters::root_cert`].
///
/// The `Debug` output of the certificate-carrying variants shows only how many
/// certificates they hold, never their contents.
#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
pub enum TlsCert {
    /// Presumably the platform's native trust store (TODO confirm in the
    /// drivers). This is the default.
    #[default]
    System,
    /// Presumably the system trust store extended with the listed certificates.
    #[debug("SystemPlus([{} cert(s)])", _0.len())]
    SystemPlus(Vec<CertificateDer<'static>>),
    /// Presumably the bundled webpki root set instead of the system store.
    Webpki,
    /// Presumably the webpki root set extended with the listed certificates.
    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
    WebpkiPlus(Vec<CertificateDer<'static>>),
    /// Presumably only the listed certificates, with no built-in roots.
    #[debug("Custom([{} cert(s)])", _0.len())]
    Custom(Vec<CertificateDer<'static>>),
}
138
/// Client-side TLS configuration, consumed by [`TlsDriver::init_client`].
///
/// The `Debug` output prints only `Some(...)`/`None` for `cert` and `key` and
/// only a count for `crl`, so secret material and raw DER never appear in logs.
#[derive(Default, derive_more::Debug, PartialEq, Eq)]
pub struct TlsParameters {
    /// How strictly the server certificate is checked (default: `VerifyFull`).
    pub server_cert_verify: TlsServerCertVerify,
    /// Optional client certificate (presumably for mutual TLS, paired with `key`).
    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub cert: Option<CertificateDer<'static>>,
    /// Optional private key belonging to `cert`.
    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub key: Option<PrivateKeyDer<'static>>,
    /// Trusted root certificates (default: `TlsCert::System`).
    pub root_cert: TlsCert,
    /// Certificate revocation lists (presumably applied during server
    /// verification; TODO confirm in the drivers).
    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
    pub crl: Vec<CertificateRevocationListDer<'static>>,
    /// Lowest TLS version to allow; `None` leaves it to the driver.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest TLS version to allow; `None` leaves it to the driver.
    pub max_protocol_version: Option<SslVersion>,
    /// Presumably enables TLS key logging (TODO confirm mechanism in the drivers).
    pub enable_keylog: bool,
    /// Presumably replaces the server name sent via SNI (TODO confirm).
    pub sni_override: Option<Cow<'static, str>>,
    /// ALPN protocols to offer (empty by default).
    pub alpn: TlsAlpn,
}
155
156impl TlsParameters {
157 pub fn insecure() -> Self {
158 Self {
159 server_cert_verify: TlsServerCertVerify::Insecure,
160 ..Default::default()
161 }
162 }
163}
164
/// A TLS protocol version.
///
/// It displays (and, with the `serde` feature, serializes) as `TLSv1`,
/// `TLSv1.1`, `TLSv1.2` or `TLSv1.3`. Parsing via `TryFrom<Cow<str>>` is
/// case-insensitive and also accepts the `tls_1`, `tls_1.1`, … spellings.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
172
173impl std::fmt::Display for SslVersion {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 let s = match self {
176 SslVersion::Tls1 => "TLSv1",
177 SslVersion::Tls1_1 => "TLSv1.1",
178 SslVersion::Tls1_2 => "TLSv1.2",
179 SslVersion::Tls1_3 => "TLSv1.3",
180 };
181 f.write_str(s)
182 }
183}
184
/// Error returned when a string is not a recognised [`SslVersion`] spelling.
/// It carries the rejected input verbatim.
///
/// `#[error(not(source))]` tells `derive_more` not to treat the `String` field
/// as the error's `source()`.
#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
pub struct SslVersionParseError(#[error(not(source))] pub String);
187
/// Serializes an [`SslVersion`] as a string: `"TLSv1"`, `"TLSv1.1"`,
/// `"TLSv1.2"` or `"TLSv1.3"` (the same names `Display` produces).
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let name = match *self {
            Self::Tls1 => "TLSv1",
            Self::Tls1_1 => "TLSv1.1",
            Self::Tls1_2 => "TLSv1.2",
            Self::Tls1_3 => "TLSv1.3",
        };
        serializer.serialize_str(name)
    }
}
202
203impl TryFrom<Cow<'_, str>> for SslVersion {
204 type Error = SslVersionParseError;
205 fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
206 Ok(match value.to_lowercase().as_ref() {
207 "tls_1" | "tlsv1" => SslVersion::Tls1,
208 "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
209 "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
210 "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
211 _ => return Err(SslVersionParseError(value.to_string())),
212 })
213 }
214}
215
/// Server-side policy for client certificates; stored in
/// [`TlsServerParameters::client_cert_verify`].
#[derive(Default, Debug, PartialEq, Eq)]
pub enum TlsClientCertVerify {
    /// Presumably: client certificates are neither requested nor checked.
    /// This is the default.
    #[default]
    Ignore,
    /// Presumably: a client certificate is optional, but if one is presented it
    /// is verified against these roots (TODO confirm in the drivers).
    Optional(Vec<CertificateDer<'static>>),
    /// Presumably: a client certificate is required and must verify against
    /// these roots (TODO confirm in the drivers).
    Validate(Vec<CertificateDer<'static>>),
}
227
/// A private key paired with its certificate, used as the server identity in
/// [`TlsServerParameters::server_certificate`].
///
/// `derive_more::Constructor` generates `TlsKey::new(key, cert)`, with arguments
/// in field order. `Debug` prints only placeholders, never the key or
/// certificate bytes.
#[derive(derive_more::Debug, derive_more::Constructor)]
pub struct TlsKey {
    #[debug("key(...)")]
    pub(crate) key: PrivateKeyDer<'static>,
    #[debug("cert(...)")]
    pub(crate) cert: CertificateDer<'static>,
}
235
236impl TlsKey {
237 #[cfg(feature = "pem")]
239 pub fn new_pem(mut key: &[u8], mut cert: &[u8]) -> Result<Self, std::io::Error> {
240 let cert = rustls_pemfile::certs(&mut cert)
241 .next()
242 .ok_or(std::io::Error::new(
243 std::io::ErrorKind::InvalidData,
244 "No certificate found",
245 ))??;
246 let key = rustls_pemfile::private_key(&mut key)?.ok_or(std::io::Error::new(
247 std::io::ErrorKind::InvalidData,
248 "No key found",
249 ))?;
250 Ok(Self { cert, key })
251 }
252}
253
254#[derive(Debug, Clone)]
255pub struct TlsServerParameterProvider {
256 inner: TlsServerParameterProviderInner,
257}
258
259impl TlsServerParameterProvider {
260 pub fn new(params: TlsServerParameters) -> Self {
261 Self {
262 inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
263 }
264 }
265
266 pub fn with_lookup(
267 lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
268 ) -> Self {
269 Self {
270 inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
271 }
272 }
273
274 pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
275 match &self.inner {
276 TlsServerParameterProviderInner::Static(params) => params.clone(),
277 TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
278 }
279 }
280}
281
/// Storage behind [`TlsServerParameterProvider`]: either one shared parameter
/// set or a callback that picks one from the optional server name.
#[derive(derive_more::Debug, Clone)]
enum TlsServerParameterProviderInner {
    /// Parameters returned regardless of the server name.
    Static(Arc<TlsServerParameters>),
    /// Callback invoked with the optional server name. `Debug` shows only
    /// `Lookup(...)` because closures have no useful `Debug` output.
    #[debug("Lookup(...)")]
    // The trait-object type is spelled out inline, which trips clippy's
    // `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
}
289
/// Server-side TLS configuration. Drivers receive it through
/// [`TlsServerParameterProvider`] and [`TlsDriver::init_server`].
#[derive(Debug)]
pub struct TlsServerParameters {
    /// Policy for client certificates (default: `TlsClientCertVerify::Ignore`).
    pub client_cert_verify: TlsClientCertVerify,
    /// Lowest TLS version to accept; `None` leaves it to the driver.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest TLS version to accept; `None` leaves it to the driver.
    pub max_protocol_version: Option<SslVersion>,
    /// The server's own key and certificate.
    pub server_certificate: TlsKey,
    /// ALPN protocols the server supports (empty by default).
    pub alpn: TlsAlpn,
}
298
299impl TlsServerParameters {
300 pub fn new_with_certificate(server_certificate: TlsKey) -> Self {
301 Self {
302 client_cert_verify: TlsClientCertVerify::default(),
303 min_protocol_version: None,
304 max_protocol_version: None,
305 server_certificate,
306 alpn: TlsAlpn::default(),
307 }
308 }
309}
310
/// An ordered list of ALPN protocol names.
///
/// The names are stored as `Cow`s so that lists built from `'static` data
/// (`TlsAlpn::new` / `TlsAlpn::new_str`) borrow each name instead of copying it.
/// The default value is the empty list.
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}
316
317impl std::fmt::Debug for TlsAlpn {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 if self.alpn_parts.is_empty() {
320 write!(f, "[]")
321 } else {
322 for (i, part) in self.alpn_parts.iter().enumerate() {
323 if i == 0 {
324 write!(f, "[")?;
325 } else {
326 write!(f, ", ")?;
327 }
328 let mut s = String::new();
330 s.push_str("b\"");
331 for &b in part.iter() {
332 for c in b.escape_ascii() {
333 s.push(c as char);
334 }
335 }
336 s.push('"');
337 write!(f, "{}", s)?;
338 }
339 write!(f, "]")?;
340 Ok(())
341 }
342 }
343}
344
345impl TlsAlpn {
346 pub fn new(alpn: &'static [&'static [u8]]) -> Self {
347 let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
348 Self {
349 alpn_parts: Cow::Owned(alpn),
350 }
351 }
352
353 pub fn new_str(alpn: &'static [&'static str]) -> Self {
354 let alpn = alpn
355 .iter()
356 .map(|s| Cow::Borrowed(s.as_bytes()))
357 .collect::<Vec<_>>();
358 Self {
359 alpn_parts: Cow::Owned(alpn),
360 }
361 }
362
363 pub fn is_empty(&self) -> bool {
364 self.alpn_parts.is_empty()
365 }
366
367 pub fn as_bytes(&self) -> Vec<u8> {
368 let mut bytes = Vec::with_capacity(self.alpn_parts.len() * 2);
369 for part in self.alpn_parts.iter() {
370 bytes.push(part.len() as u8);
371 bytes.extend_from_slice(part.as_ref());
372 }
373 bytes
374 }
375
376 pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
377 let mut vec = Vec::with_capacity(self.alpn_parts.len());
378 for part in self.alpn_parts.iter() {
379 vec.push(part.to_vec());
380 }
381 vec
382 }
383}
384
/// Summary of a completed TLS handshake. It is returned alongside the upgraded
/// stream by [`TlsDriver::upgrade_client`] and [`TlsDriver::upgrade_server`].
///
/// Every field is optional. Presumably a driver leaves a field `None` when that
/// information was not negotiated or is unavailable (TODO confirm per driver).
#[derive(Debug, Clone, Default)]
pub struct TlsHandshake {
    /// Presumably the negotiated ALPN protocol.
    pub alpn: Option<Cow<'static, [u8]>>,
    /// Presumably the server name (SNI) used in the handshake.
    pub sni: Option<Cow<'static, str>>,
    /// Presumably the peer's certificate (TODO confirm which side's certificate
    /// each driver reports).
    pub cert: Option<CertificateDer<'static>>,
    /// Presumably the negotiated protocol version.
    pub version: Option<SslVersion>,
}
397
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    // Checks the `Debug` output of `TlsParameters`. Key, certificate and CRL
    // fields print as placeholders or counts rather than raw bytes, and ALPN
    // entries print as byte-string literals.
    #[test]
    fn test_tls_parameters_debug() {
        let params = TlsParameters::default();
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    // Checks the ALPN wire encoding (length-prefixed names), the owned copy,
    // emptiness and `Debug`, for both a populated list and the default empty list.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{:?}", alpn), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{:?}", empty_alpn), "[]");
    }
}