1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
7#[cfg(all(feature = "openssl", not(feature = "rustls")))]
10pub type Ssl = crate::common::openssl::OpensslDriver;
11#[cfg(feature = "rustls")]
12pub type Ssl = crate::common::rustls::RustlsDriver;
13#[cfg(not(any(feature = "openssl", feature = "rustls")))]
14pub type Ssl = NullTlsDriver;
15
16pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
17 type Stream: Stream + Send;
18 type ClientParams: Unpin + Send;
19 type ServerParams: Unpin + Send;
20
21 #[allow(unused)]
22 fn init_client(
23 params: &TlsParameters,
24 name: Option<ServerName>,
25 ) -> Result<Self::ClientParams, SslError>;
26 #[allow(unused)]
27 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
28
29 fn upgrade_client<S: Stream>(
30 params: Self::ClientParams,
31 stream: S,
32 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
33 fn upgrade_server<S: Stream>(
34 params: TlsServerParameterProvider,
35 stream: S,
36 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37}
38
39#[derive(Default)]
40pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44 type Stream = BaseStream;
45 type ClientParams = ();
46 type ServerParams = ();
47
48 fn init_client(
49 params: &TlsParameters,
50 name: Option<ServerName>,
51 ) -> Result<Self::ClientParams, SslError> {
52 Err(SslError::SslUnsupportedByClient)
53 }
54
55 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56 Err(SslError::SslUnsupportedByClient)
57 }
58
59 async fn upgrade_client<S: Stream>(
60 params: Self::ClientParams,
61 stream: S,
62 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63 Err(SslError::SslUnsupportedByClient)
64 }
65
66 async fn upgrade_server<S: Stream>(
67 params: TlsServerParameterProvider,
68 stream: S,
69 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70 Err(SslError::SslUnsupportedByClient)
71 }
72}
73
/// How a TLS client verifies the certificate presented by the server.
///
/// Used as [`TlsParameters::server_cert_verify`]. The default is
/// [`TlsServerCertVerify::VerifyFull`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Selected by [`TlsParameters::insecure`].
    /// NOTE(review): presumably skips server certificate verification entirely; confirm in
    /// the driver implementations.
    Insecure,
    /// NOTE(review): presumably verifies the certificate chain but not that it matches the
    /// requested host name; confirm in the driver implementations.
    IgnoreHostname,
    /// The default.
    /// NOTE(review): presumably verifies both the chain and the host name; confirm.
    #[default]
    VerifyFull,
}
102
/// Root certificates a client trusts when verifying the server.
///
/// Used as [`TlsParameters::root_cert`]. The default is [`TlsCert::System`]. The `Debug`
/// output of the variants carrying certificates prints only how many there are.
#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
pub enum TlsCert {
    /// NOTE(review): presumably the operating system's trust store; confirm in the drivers.
    #[default]
    System,
    /// [`TlsCert::System`] plus the listed additional certificates.
    #[debug("SystemPlus([{} cert(s)])", _0.len())]
    SystemPlus(Vec<CertificateDer<'static>>),
    /// NOTE(review): presumably a bundled webpki root set (e.g. `webpki-roots`); confirm.
    Webpki,
    /// [`TlsCert::Webpki`] plus the listed additional certificates.
    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
    WebpkiPlus(Vec<CertificateDer<'static>>),
    /// Only the listed certificates.
    #[debug("Custom([{} cert(s)])", _0.len())]
    Custom(Vec<CertificateDer<'static>>),
}
122
/// Client-side TLS configuration, consumed by [`TlsDriver::init_client`].
///
/// `Debug` never prints key or certificate material. It shows only whether `cert`/`key`
/// are set and how many CRLs are present.
#[derive(Default, derive_more::Debug, PartialEq, Eq)]
pub struct TlsParameters {
    /// How the server's certificate is checked (default: `VerifyFull`).
    pub server_cert_verify: TlsServerCertVerify,
    /// NOTE(review): presumably the client certificate for mutual TLS; confirm.
    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub cert: Option<CertificateDer<'static>>,
    /// NOTE(review): presumably the private key matching `cert`; confirm.
    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub key: Option<PrivateKeyDer<'static>>,
    /// Root certificates to trust (default: `System`).
    pub root_cert: TlsCert,
    /// Certificate revocation lists; `Debug` prints only their count.
    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
    pub crl: Vec<CertificateRevocationListDer<'static>>,
    /// Lowest allowed protocol version.
    /// NOTE(review): `None` presumably means the backend default.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest allowed protocol version.
    /// NOTE(review): `None` presumably means the backend default.
    pub max_protocol_version: Option<SslVersion>,
    /// NOTE(review): presumably enables TLS key logging (e.g. `SSLKEYLOGFILE`); confirm.
    pub enable_keylog: bool,
    /// NOTE(review): presumably replaces the server name sent via SNI; confirm.
    pub sni_override: Option<Cow<'static, str>>,
    /// ALPN protocol identifiers offered to the server.
    pub alpn: TlsAlpn,
}
139
140impl TlsParameters {
141 pub fn insecure() -> Self {
142 Self {
143 server_cert_verify: TlsServerCertVerify::Insecure,
144 ..Default::default()
145 }
146 }
147}
148
/// A TLS protocol version, used for the minimum/maximum bounds in [`TlsParameters`] and
/// [`TlsServerParameters`].
///
/// Variants are declared oldest to newest, so the derived `Ord` orders versions
/// chronologically (`Tls1 < Tls1_1 < Tls1_2 < Tls1_3`). That makes range checks such as
/// `min <= max` possible. Do not reorder the variants.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum SslVersion {
    /// TLS 1.0 (`TLSv1`).
    Tls1,
    /// TLS 1.1 (`TLSv1.1`).
    Tls1_1,
    /// TLS 1.2 (`TLSv1.2`).
    Tls1_2,
    /// TLS 1.3 (`TLSv1.3`).
    Tls1_3,
}
156
157impl std::fmt::Display for SslVersion {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 let s = match self {
160 SslVersion::Tls1 => "TLSv1",
161 SslVersion::Tls1_1 => "TLSv1.1",
162 SslVersion::Tls1_2 => "TLSv1.2",
163 SslVersion::Tls1_3 => "TLSv1.3",
164 };
165 f.write_str(s)
166 }
167}
168
/// Error returned when text does not name a supported [`SslVersion`].
///
/// Holds the rejected input exactly as given. `#[error(not(source))]` stops derive_more from
/// treating the `String` field as the error's `source`.
// NOTE(review): the derived `Display` presumably prints only the rejected input, with no
// explanatory prefix; confirm against derive_more's single-field behavior.
#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
pub struct SslVersionParseError(#[error(not(source))] pub String);
171
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    /// Serializes the version as the string produced by its `Display` impl (e.g. `"TLSv1.2"`).
    ///
    /// Delegating to `Display` keeps a single source of truth for the names. Previously the
    /// table was duplicated here and could silently drift from `Display`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
186
187impl TryFrom<Cow<'_, str>> for SslVersion {
188 type Error = SslVersionParseError;
189 fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
190 Ok(match value.to_lowercase().as_ref() {
191 "tls_1" | "tlsv1" => SslVersion::Tls1,
192 "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
193 "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
194 "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
195 _ => return Err(SslVersionParseError(value.to_string())),
196 })
197 }
198}
199
/// How a TLS server treats certificates presented by connecting clients.
///
/// Used as [`TlsServerParameters::client_cert_verify`]. The default is
/// [`TlsClientCertVerify::Ignore`].
#[derive(Default, Debug, PartialEq, Eq)]
pub enum TlsClientCertVerify {
    /// NOTE(review): presumably no client certificate is requested or checked; confirm.
    #[default]
    Ignore,
    /// NOTE(review): presumably a client certificate is optional but, if presented, is
    /// validated against these roots; confirm in the drivers.
    Optional(Vec<CertificateDer<'static>>),
    /// NOTE(review): presumably a client certificate is required and validated against these
    /// roots; confirm in the drivers.
    Validate(Vec<CertificateDer<'static>>),
}
211
/// A private key paired with its certificate.
///
/// Used as [`TlsServerParameters::server_certificate`]. `derive_more::Constructor` provides
/// `TlsKey::new(key, cert)`. `Debug` redacts both fields.
#[derive(derive_more::Debug, derive_more::Constructor)]
pub struct TlsKey {
    /// Private key material; never printed by `Debug`.
    #[debug("key(...)")]
    pub(crate) key: PrivateKeyDer<'static>,
    /// Certificate for `key`; never printed by `Debug`.
    #[debug("cert(...)")]
    pub(crate) cert: CertificateDer<'static>,
}
219
220#[derive(Debug, Clone)]
221pub struct TlsServerParameterProvider {
222 inner: TlsServerParameterProviderInner,
223}
224
225impl TlsServerParameterProvider {
226 pub fn new(params: TlsServerParameters) -> Self {
227 Self {
228 inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
229 }
230 }
231
232 pub fn with_lookup(
233 lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
234 ) -> Self {
235 Self {
236 inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
237 }
238 }
239
240 pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
241 match &self.inner {
242 TlsServerParameterProviderInner::Static(params) => params.clone(),
243 TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
244 }
245 }
246}
247
/// Storage behind [`TlsServerParameterProvider`].
#[derive(derive_more::Debug, Clone)]
enum TlsServerParameterProviderInner {
    /// One parameter set shared by every lookup (clones only bump the `Arc` count).
    Static(Arc<TlsServerParameters>),
    /// A callback invoked on every lookup. `Debug` prints it as `Lookup(...)` because
    /// closures have no `Debug` impl.
    #[debug("Lookup(...)")]
    #[allow(clippy::type_complexity)]
    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
}
255
/// Server-side TLS configuration, handed out by [`TlsServerParameterProvider`] and accepted by
/// [`TlsDriver::init_server`].
#[derive(Debug)]
pub struct TlsServerParameters {
    /// How client certificates are handled.
    pub client_cert_verify: TlsClientCertVerify,
    /// Lowest allowed protocol version.
    /// NOTE(review): `None` presumably means the backend default.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest allowed protocol version.
    /// NOTE(review): `None` presumably means the backend default.
    pub max_protocol_version: Option<SslVersion>,
    /// Key and certificate the server presents; redacted in `Debug` output.
    pub server_certificate: TlsKey,
    /// ALPN protocol identifiers the server supports.
    pub alpn: TlsAlpn,
}
264
/// An ordered list of ALPN protocol identifiers (RFC 7301), such as `h2` or `http/1.1`.
///
/// Every identifier is guaranteed to be 1 to 255 bytes long. That is the range RFC 7301
/// allows, and the range the one-byte length prefix written by [`TlsAlpn::as_bytes`] can
/// describe.
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// Protocol identifiers, in the order given to the constructor.
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

impl std::fmt::Debug for TlsAlpn {
    /// Formats as a list of byte-string literals, e.g. `[b"h2", b"http/1.1"]`.
    ///
    /// Non-printable bytes, quotes and backslashes use Rust's ASCII escapes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        for (i, part) in self.alpn_parts.iter().enumerate() {
            if i > 0 {
                f.write_str(", ")?;
            }
            write!(f, "b\"{}\"", part.escape_ascii())?;
        }
        f.write_str("]")
    }
}

impl TlsAlpn {
    /// Longest identifier the one-byte length prefix of the wire format can describe.
    const MAX_PROTOCOL_LEN: usize = u8::MAX as usize;

    /// Builds an ALPN list from byte-string protocol identifiers.
    ///
    /// # Panics
    ///
    /// Panics if any identifier is empty or longer than 255 bytes. RFC 7301 forbids both,
    /// and an over-long entry would otherwise corrupt the encoding produced by `as_bytes`.
    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
        Self::from_parts(alpn.iter().map(|s| Cow::Borrowed(*s)).collect())
    }

    /// Builds an ALPN list from string protocol identifiers (encoded as their UTF-8 bytes).
    ///
    /// # Panics
    ///
    /// Panics if any identifier is empty or longer than 255 bytes (see [`TlsAlpn::new`]).
    pub fn new_str(alpn: &'static [&'static str]) -> Self {
        Self::from_parts(alpn.iter().map(|s| Cow::Borrowed(s.as_bytes())).collect())
    }

    /// Validates identifier lengths once, so every other method can rely on them.
    fn from_parts(parts: Vec<Cow<'static, [u8]>>) -> Self {
        for part in &parts {
            assert!(
                (1..=Self::MAX_PROTOCOL_LEN).contains(&part.len()),
                "ALPN protocol identifiers must be 1 to 255 bytes long, got {} bytes",
                part.len()
            );
        }
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }

    /// Returns `true` if no protocols are configured.
    pub fn is_empty(&self) -> bool {
        self.alpn_parts.is_empty()
    }

    /// Encodes the list in the ALPN wire format: each identifier is preceded by its length
    /// as a single byte.
    pub fn as_bytes(&self) -> Vec<u8> {
        let total: usize = self.alpn_parts.iter().map(|part| 1 + part.len()).sum();
        let mut bytes = Vec::with_capacity(total);
        for part in self.alpn_parts.iter() {
            let len = u8::try_from(part.len())
                .expect("TlsAlpn constructors reject identifiers longer than 255 bytes");
            bytes.push(len);
            bytes.extend_from_slice(part);
        }
        bytes
    }

    /// Returns an owned copy of each identifier, in order.
    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
        self.alpn_parts.iter().map(|part| part.to_vec()).collect()
    }
}
338
/// Details of a completed TLS upgrade.
///
/// Returned alongside the upgraded stream by [`TlsDriver::upgrade_client`] and
/// [`TlsDriver::upgrade_server`].
#[derive(Debug, Clone, Default)]
pub struct TlsHandshake {
    /// NOTE(review): presumably the negotiated ALPN protocol, if any; confirm in the drivers.
    pub alpn: Option<Cow<'static, [u8]>>,
    /// NOTE(review): presumably the SNI server name seen during the handshake; confirm.
    pub sni: Option<Cow<'static, str>>,
    /// NOTE(review): presumably the peer's certificate, if one was presented; confirm.
    pub cert: Option<CertificateDer<'static>>,
}
345
// Unit tests for the `Debug` formats and helper methods of the TLS parameter types.
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    // `Debug` for `TlsParameters` must keep key/cert material out of the output.
    #[test]
    fn test_tls_parameters_debug() {
        // Defaults: every optional field prints as `None`, collections as `[]`.
        let params = TlsParameters::default();
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        // Populated: secrets print as `Some(...)`, lists print only their counts.
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    // Wire encoding, owned copies and `Debug` of `TlsAlpn`, for a populated and an empty list.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{:?}", alpn), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{:?}", empty_alpn), "[]");
    }

    // Field access and the derived `Debug`/`Default` of `TlsHandshake`.
    #[test]
    fn test_tls_handshake() {
        let handshake = TlsHandshake {
            alpn: Some(Cow::Borrowed(b"h2")),
            sni: Some(Cow::Borrowed("example.com")),
            cert: None,
        };
        assert_eq!(handshake.alpn, Some(Cow::Borrowed(b"h2".as_slice())));
        assert_eq!(handshake.sni, Some(Cow::Borrowed("example.com")));
        assert_eq!(handshake.cert, None);

        // The derived `Debug` prints the ALPN bytes as a plain number list.
        assert_eq!(
            format!("{:?}", handshake),
            "TlsHandshake { alpn: Some([104, 50]), sni: Some(\"example.com\"), cert: None }"
        );

        let default_handshake = TlsHandshake::default();
        assert_eq!(default_handshake.alpn, None);
        assert_eq!(default_handshake.sni, None);
        assert_eq!(default_handshake.cert, None);
    }
}