gel_stream/common/
tls.rs

1use crate::{SslError, Stream, StreamMetadata};
2use rustls_pki_types::{
3    CertificateDer, CertificateRevocationListDer, DnsName, PrivateKeyDer, ServerName,
4};
5use std::{borrow::Cow, future::Future, sync::Arc};
6
7use super::BaseStream;
8
9// Note that we choose rustls when both openssl and rustls are enabled.
10
11/// The default TLS driver.
12#[cfg(all(feature = "openssl", not(feature = "rustls")))]
13pub type Ssl = crate::common::openssl::OpensslDriver;
14#[cfg(feature = "rustls")]
15pub type Ssl = crate::common::rustls::RustlsDriver;
16#[cfg(not(any(feature = "openssl", feature = "rustls")))]
17pub type Ssl = NullTlsDriver;
18
19/// A trait for TLS drivers.
20#[doc(hidden)]
21pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
22    type Stream: Stream + Send;
23    type ClientParams: Unpin + Send;
24    type ServerParams: Unpin + Send;
25    const DRIVER_NAME: &'static str;
26
27    #[allow(unused)]
28    fn init_client(
29        params: &TlsParameters,
30        name: Option<ServerName>,
31    ) -> Result<Self::ClientParams, SslError>;
32    #[allow(unused)]
33    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
34
35    fn upgrade_client<S: Stream>(
36        params: Self::ClientParams,
37        stream: S,
38    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
39    fn upgrade_server<S: Stream>(
40        params: TlsServerParameterProvider,
41        stream: S,
42    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
43    fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream>;
44
45    fn is<D: TlsDriver>() -> bool {
46        D::DRIVER_NAME == Self::DRIVER_NAME
47    }
48}
49
50/// A TLS driver that fails when TLS is requested.
51#[derive(Default)]
52pub struct NullTlsDriver;
53
54#[allow(unused)]
55impl TlsDriver for NullTlsDriver {
56    type Stream = BaseStream;
57    type ClientParams = ();
58    type ServerParams = ();
59    const DRIVER_NAME: &'static str = "null";
60
61    fn init_client(
62        params: &TlsParameters,
63        name: Option<ServerName>,
64    ) -> Result<Self::ClientParams, SslError> {
65        Err(SslError::SslUnsupported)
66    }
67
68    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
69        Err(SslError::SslUnsupported)
70    }
71
72    async fn upgrade_client<S: Stream>(
73        params: Self::ClientParams,
74        stream: S,
75    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
76        Err(SslError::SslUnsupported)
77    }
78
79    async fn upgrade_server<S: Stream>(
80        params: TlsServerParameterProvider,
81        stream: S,
82    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
83        Err(SslError::SslUnsupported)
84    }
85
86    fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
87        // Do nothing
88        Ok(())
89    }
90}
91
/// Verification modes for TLS that are a superset of both PostgreSQL and EdgeDB/Gel.
///
/// Postgres offers six levels: `disable`, `allow`, `prefer`, `require`, `verify-ca` and
/// `verify-full`.
///
/// EdgeDB/Gel offers three levels: `insecure`, `no_host_verification` and `strict`.
///
/// This table maps the various levels:
///
/// | Postgres      | EdgeDB/Gel             | `TlsServerCertVerify` variant |
/// | ------------- | ---------------------- | ----------------------------- |
/// | `require`     | `insecure`             | `Insecure`                    |
/// | `verify-ca`   | `no_host_verification` | `IgnoreHostname`              |
/// | `verify-full` | `strict`               | `VerifyFull`                  |
///
/// The Postgres levels `disable`, `allow` and `prefer` have no counterpart in
/// this enum.
///
/// Note that both EdgeDB/Gel and Postgres may alter certificate validation levels
/// when custom root certificates are provided. This must be done in the
/// `TlsParameters` struct by the caller.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Do not verify the server's certificate. Only confirm that the server is
    /// using TLS.
    Insecure,
    /// Verify the server's certificate using the CA (ignore hostname).
    IgnoreHostname,
    /// Verify the server's certificate using the CA and hostname. This is the
    /// default.
    #[default]
    VerifyFull,
}
120
/// Which root certificates to trust (used as `TlsParameters::root_cert`).
///
/// `Debug` prints only the number of custom certificates, never their bytes.
#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
pub enum TlsCert {
    /// Use the system's default root certificates. This is the default.
    #[default]
    System,
    /// Use the system's default root certificates plus a set of custom root
    /// certificates.
    #[debug("SystemPlus([{} cert(s)])", _0.len())]
    SystemPlus(Vec<CertificateDer<'static>>),
    /// Use the `webpki-roots` default root certificates.
    Webpki,
    /// Use the `webpki-roots` default root certificates plus a set of custom
    /// root certificates.
    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
    WebpkiPlus(Vec<CertificateDer<'static>>),
    /// Use only the given custom root certificates.
    #[debug("Custom([{} cert(s)])", _0.len())]
    Custom(Vec<CertificateDer<'static>>),
}
140
/// Client-side TLS configuration, consumed by [`TlsDriver::init_client`].
///
/// `Debug` output never prints key, certificate or CRL bytes: `cert`/`key`
/// render as `Some(...)`/`None` and `crl` as an item count.
#[derive(Default, derive_more::Debug, PartialEq, Eq)]
pub struct TlsParameters {
    /// How the server's certificate is verified (defaults to `VerifyFull`).
    pub server_cert_verify: TlsServerCertVerify,
    /// Optional certificate presented by this side (presumably a client
    /// certificate; confirm against the driver implementations).
    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub cert: Option<CertificateDer<'static>>,
    /// Private key for `cert` (presumably; confirm against the drivers).
    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
    pub key: Option<PrivateKeyDer<'static>>,
    /// Root certificates to trust (defaults to `System`).
    pub root_cert: TlsCert,
    /// Certificate revocation lists (NOTE(review): how they are applied is
    /// driver-defined).
    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
    pub crl: Vec<CertificateRevocationListDer<'static>>,
    /// Optional lower bound on the negotiated TLS version.
    pub min_protocol_version: Option<SslVersion>,
    /// Optional upper bound on the negotiated TLS version.
    pub max_protocol_version: Option<SslVersion>,
    /// Whether TLS key logging is enabled (NOTE(review): destination is
    /// driver-defined; confirm).
    pub enable_keylog: bool,
    /// Overrides the SNI name (presumably instead of the name passed to
    /// `init_client`; confirm).
    pub sni_override: Option<Cow<'static, str>>,
    /// ALPN protocols to offer.
    pub alpn: TlsAlpn,
}
157
158impl TlsParameters {
159    pub fn insecure() -> Self {
160        Self {
161            server_cert_verify: TlsServerCertVerify::Insecure,
162            ..Default::default()
163        }
164    }
165}
166
/// A TLS protocol version.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
    /// TLS 1.0 (displayed as `TLSv1`).
    Tls1,
    /// TLS 1.1 (displayed as `TLSv1.1`).
    Tls1_1,
    /// TLS 1.2 (displayed as `TLSv1.2`).
    Tls1_2,
    /// TLS 1.3 (displayed as `TLSv1.3`).
    Tls1_3,
}

impl std::fmt::Display for SslVersion {
    /// Writes the version name, e.g. `TLSv1.2`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Tls1 => "TLSv1",
            Self::Tls1_1 => "TLSv1.1",
            Self::Tls1_2 => "TLSv1.2",
            Self::Tls1_3 => "TLSv1.3",
        })
    }
}
186
/// Error returned when a string does not name a known [`SslVersion`].
///
/// The wrapped `String` is the rejected input, exactly as it was given.
#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
pub struct SslVersionParseError(#[error(not(source))] pub String);
189
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    /// Serializes the version as a string identical to its `Display` output
    /// (e.g. `"TLSv1.2"`).
    ///
    /// Delegating to `Display` keeps a single string table, so serialization
    /// and formatting cannot drift apart.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
204
205impl TryFrom<Cow<'_, str>> for SslVersion {
206    type Error = SslVersionParseError;
207    fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
208        Ok(match value.to_lowercase().as_ref() {
209            "tls_1" | "tlsv1" => SslVersion::Tls1,
210            "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
211            "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
212            "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
213            _ => return Err(SslVersionParseError(value.to_string())),
214        })
215    }
216}
217
/// How a TLS server treats client certificates (see
/// `TlsServerParameters::client_cert_verify`).
#[derive(Default, Debug, PartialEq, Eq)]
pub enum TlsClientCertVerify {
    /// Do not verify the client's certificate, just ignore it. This is the
    /// default.
    #[default]
    Ignore,
    /// If a client certificate is provided, validate it.
    ///
    /// NOTE(review): the certificates are presumably the trusted roots used
    /// for that validation; confirm against the drivers.
    Optional(Vec<CertificateDer<'static>>),
    /// Validate that a client certificate exists and is valid. This configuration
    /// may not be ideal, because it does not fail the client-side handshake.
    ///
    /// NOTE(review): as with `Optional`, the certificates are presumably the
    /// trusted roots; confirm.
    Validate(Vec<CertificateDer<'static>>),
}
229
/// A private key together with its certificate.
///
/// Construct with `TlsKey::new(key, cert)` (generated by
/// `derive_more::Constructor`) or, with the `pem` feature, `TlsKey::new_pem`.
/// `Debug` prints placeholders instead of the key and certificate bytes.
#[derive(derive_more::Debug, derive_more::Constructor)]
pub struct TlsKey {
    #[debug("key(...)")]
    pub(crate) key: PrivateKeyDer<'static>,
    #[debug("cert(...)")]
    pub(crate) cert: CertificateDer<'static>,
}
237
238impl TlsKey {
239    /// Create a new `TlsKey` from a PEM-encoded certificate and key.
240    #[cfg(feature = "pem")]
241    pub fn new_pem(mut key: &[u8], mut cert: &[u8]) -> Result<Self, std::io::Error> {
242        let cert = rustls_pemfile::certs(&mut cert)
243            .next()
244            .ok_or(std::io::Error::new(
245                std::io::ErrorKind::InvalidData,
246                "No certificate found",
247            ))??;
248        let key = rustls_pemfile::private_key(&mut key)?.ok_or(std::io::Error::new(
249            std::io::ErrorKind::InvalidData,
250            "No key found",
251        ))?;
252        Ok(Self { cert, key })
253    }
254
255    /// Create a clone of this private key and certificate.
256    pub fn clone_key(&self) -> Self {
257        Self {
258            key: self.key.clone_key(),
259            cert: self.cert.clone(),
260        }
261    }
262}
263
264#[derive(Debug, Clone)]
265pub struct TlsServerParameterProvider {
266    inner: TlsServerParameterProviderInner,
267}
268
269impl TlsServerParameterProvider {
270    pub fn new(params: TlsServerParameters) -> Self {
271        Self {
272            inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
273        }
274    }
275
276    pub fn with_lookup(
277        lookup: impl Fn(Option<DnsName>, &dyn StreamMetadata) -> Arc<TlsServerParameters>
278            + Send
279            + Sync
280            + 'static,
281    ) -> Self {
282        Self {
283            inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
284        }
285    }
286
287    pub fn lookup(
288        &self,
289        name: Option<DnsName>,
290        stream: &dyn StreamMetadata,
291    ) -> Arc<TlsServerParameters> {
292        match &self.inner {
293            TlsServerParameterProviderInner::Static(params) => params.clone(),
294            TlsServerParameterProviderInner::Lookup(lookup) => lookup(name, stream),
295        }
296    }
297}
298
299/// A function that looks up TLS server parameters based on the server name and/or
300/// stream metadata.
301pub type TlsServerParameterLookupFn = dyn Fn(Option<DnsName>, &dyn StreamMetadata) -> Arc<TlsServerParameters>
302    + Send
303    + Sync
304    + 'static;
305
306#[derive(derive_more::Debug, Clone)]
307enum TlsServerParameterProviderInner {
308    Static(Arc<TlsServerParameters>),
309    #[debug("Lookup(...)")]
310    #[allow(clippy::type_complexity)]
311    Lookup(Arc<TlsServerParameterLookupFn>),
312}
313
314#[derive(Debug)]
315pub struct TlsServerParameters {
316    pub client_cert_verify: TlsClientCertVerify,
317    pub min_protocol_version: Option<SslVersion>,
318    pub max_protocol_version: Option<SslVersion>,
319    pub server_certificate: TlsKey,
320    pub alpn: TlsAlpn,
321}
322
323impl TlsServerParameters {
324    pub fn new_with_certificate(server_certificate: TlsKey) -> Self {
325        Self {
326            client_cert_verify: TlsClientCertVerify::default(),
327            min_protocol_version: None,
328            max_protocol_version: None,
329            server_certificate,
330            alpn: TlsAlpn::default(),
331        }
332    }
333}
334
/// An ordered list of ALPN protocol names.
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// The split form (ie: ["AB", "ABCD"])
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

impl std::fmt::Debug for TlsAlpn {
    /// Prints the protocols as a list of byte-string literals, e.g.
    /// `[b"h2", b"http/1.1"]`, escaping quotes, backslashes and
    /// non-printable bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        for (index, part) in self.alpn_parts.iter().enumerate() {
            if index != 0 {
                f.write_str(", ")?;
            }
            write!(f, "b\"{}\"", part.escape_ascii())?;
        }
        f.write_str("]")
    }
}
368
369impl TlsAlpn {
370    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
371        let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
372        Self {
373            alpn_parts: Cow::Owned(alpn),
374        }
375    }
376
377    pub fn new_str(alpn: &[&'static str]) -> Self {
378        let alpn = alpn
379            .iter()
380            .map(|s| Cow::Borrowed(s.as_bytes()))
381            .collect::<Vec<_>>();
382        Self {
383            alpn_parts: Cow::Owned(alpn),
384        }
385    }
386
387    pub fn is_empty(&self) -> bool {
388        self.alpn_parts.is_empty()
389    }
390
391    pub fn as_bytes(&self) -> Vec<u8> {
392        let mut bytes = Vec::with_capacity(self.alpn_parts.len() * 2);
393        for part in self.alpn_parts.iter() {
394            bytes.push(part.len() as u8);
395            bytes.extend_from_slice(part.as_ref());
396        }
397        bytes
398    }
399
400    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
401        let mut vec = Vec::with_capacity(self.alpn_parts.len());
402        for part in self.alpn_parts.iter() {
403            vec.push(part.to_vec());
404        }
405        vec
406    }
407}
408
409impl<T, U> From<T> for TlsAlpn
410where
411    U: AsRef<[u8]>,
412    T: IntoIterator<Item = U>,
413{
414    fn from(alpn: T) -> Self {
415        Self {
416            alpn_parts: Cow::Owned(
417                alpn.into_iter()
418                    .map(|s| Cow::Owned(s.as_ref().to_vec()))
419                    .collect(),
420            ),
421        }
422    }
423}
424
/// Negotiated TLS handshake information.
///
/// Returned alongside the encrypted stream by `TlsDriver::upgrade_client` and
/// `TlsDriver::upgrade_server`. Every field is optional; `Default` leaves all of
/// them `None`.
#[derive(Debug, Clone, Default)]
pub struct TlsHandshake {
    /// The negotiated ALPN protocol.
    pub alpn: Option<Cow<'static, [u8]>>,
    /// The SNI hostname if provided.
    pub sni: Option<DnsName<'static>>,
    /// The client certificate, if any.
    pub cert: Option<CertificateDer<'static>>,
    /// The negotiated TLS version.
    pub version: Option<SslVersion>,
}
437
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    /// `TlsParameters`' `Debug` output must never print key, certificate or CRL
    /// bytes: secrets render as `Some(...)`, certificate lists and CRLs as counts.
    #[test]
    fn test_tls_parameters_debug() {
        // All-default parameters.
        let params = TlsParameters::default();
        assert_eq!(
            format!("{params:?}"),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        // Every redacted field populated with dummy bytes.
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{params:?}"),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    /// ALPN length-prefixed encoding, owned-list conversion and `Debug` output,
    /// for both a populated and an empty list.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        // Each name is preceded by its one-byte length.
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{alpn:?}"), "[b\"h2\", b\"http/1.1\"]");

        // An empty list encodes to nothing and prints as `[]`.
        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{empty_alpn:?}"), "[]");
    }
}