gel_stream/common/
tls.rs

1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
7// Note that we choose rustls when both openssl and rustls are enabled.
8
9/// The default TLS driver.
10#[cfg(all(feature = "openssl", not(feature = "rustls")))]
11pub type Ssl = crate::common::openssl::OpensslDriver;
12#[cfg(feature = "rustls")]
13pub type Ssl = crate::common::rustls::RustlsDriver;
14#[cfg(not(any(feature = "openssl", feature = "rustls")))]
15pub type Ssl = NullTlsDriver;
16
17/// A trait for TLS drivers.
18#[doc(hidden)]
19pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
20    type Stream: Stream + Send;
21    type ClientParams: Unpin + Send;
22    type ServerParams: Unpin + Send;
23    const DRIVER_NAME: &'static str;
24
25    #[allow(unused)]
26    fn init_client(
27        params: &TlsParameters,
28        name: Option<ServerName>,
29    ) -> Result<Self::ClientParams, SslError>;
30    #[allow(unused)]
31    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
32
33    fn upgrade_client<S: Stream>(
34        params: Self::ClientParams,
35        stream: S,
36    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37    fn upgrade_server<S: Stream>(
38        params: TlsServerParameterProvider,
39        stream: S,
40    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
41    fn unclean_shutdown(this: Self::Stream) -> Result<(), Self::Stream>;
42
43    fn is<D: TlsDriver>() -> bool {
44        D::DRIVER_NAME == Self::DRIVER_NAME
45    }
46}
47
48/// A TLS driver that fails when TLS is requested.
49#[derive(Default)]
50pub struct NullTlsDriver;
51
52#[allow(unused)]
53impl TlsDriver for NullTlsDriver {
54    type Stream = BaseStream;
55    type ClientParams = ();
56    type ServerParams = ();
57    const DRIVER_NAME: &'static str = "null";
58
59    fn init_client(
60        params: &TlsParameters,
61        name: Option<ServerName>,
62    ) -> Result<Self::ClientParams, SslError> {
63        Err(SslError::SslUnsupportedByClient)
64    }
65
66    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
67        Err(SslError::SslUnsupportedByClient)
68    }
69
70    async fn upgrade_client<S: Stream>(
71        params: Self::ClientParams,
72        stream: S,
73    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
74        Err(SslError::SslUnsupportedByClient)
75    }
76
77    async fn upgrade_server<S: Stream>(
78        params: TlsServerParameterProvider,
79        stream: S,
80    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
81        Err(SslError::SslUnsupportedByClient)
82    }
83
84    fn unclean_shutdown(_this: Self::Stream) -> Result<(), Self::Stream> {
85        // Do nothing
86        Ok(())
87    }
88}
89
/// Verification modes for TLS that are a superset of both PostgreSQL and EdgeDB/Gel.
///
/// Postgres offers six levels: `disable`, `allow`, `prefer`, `require`,
/// `verify-ca` and `verify-full`.
///
/// EdgeDB/Gel offers three levels: `insecure`, `no_host_verification` and
/// `strict`.
///
/// This table maps the various levels (the Postgres levels `disable`, `allow`
/// and `prefer` have no entry here):
///
/// | Postgres      | EdgeDB/Gel             | `TlsServerCertVerify` enum |
/// | ------------- | ---------------------- | -------------------------- |
/// | `require`     | `insecure`             | `Insecure`                 |
/// | `verify-ca`   | `no_host_verification` | `IgnoreHostname`           |
/// | `verify-full` | `strict`               | `VerifyFull`               |
///
/// The default is `VerifyFull`.
///
/// Note that both EdgeDB/Gel and Postgres may alter certificate validation levels
/// when custom root certificates are provided. This must be done in the
/// `TlsParameters` struct by the caller.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Do not verify the server's certificate. Only confirm that the server is
    /// using TLS.
    Insecure,
    /// Verify the server's certificate using the CA (ignore hostname).
    IgnoreHostname,
    /// Verify the server's certificate using the CA and hostname.
    #[default]
    VerifyFull,
}
118
119#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
120pub enum TlsCert {
121    /// Use the system's default certificate.
122    #[default]
123    System,
124    /// Use the system's default certificate and a set of custom root
125    /// certificates.
126    #[debug("SystemPlus([{} cert(s)])", _0.len())]
127    SystemPlus(Vec<CertificateDer<'static>>),
128    /// Use the webpki-roots default certificate.
129    Webpki,
130    /// Use the webpki-roots default certificate and a set of custom root
131    /// certificates.
132    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
133    WebpkiPlus(Vec<CertificateDer<'static>>),
134    /// Use a custom root certificate only.
135    #[debug("Custom([{} cert(s)])", _0.len())]
136    Custom(Vec<CertificateDer<'static>>),
137}
138
139#[derive(Default, derive_more::Debug, PartialEq, Eq)]
140pub struct TlsParameters {
141    pub server_cert_verify: TlsServerCertVerify,
142    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
143    pub cert: Option<CertificateDer<'static>>,
144    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
145    pub key: Option<PrivateKeyDer<'static>>,
146    pub root_cert: TlsCert,
147    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
148    pub crl: Vec<CertificateRevocationListDer<'static>>,
149    pub min_protocol_version: Option<SslVersion>,
150    pub max_protocol_version: Option<SslVersion>,
151    pub enable_keylog: bool,
152    pub sni_override: Option<Cow<'static, str>>,
153    pub alpn: TlsAlpn,
154}
155
156impl TlsParameters {
157    pub fn insecure() -> Self {
158        Self {
159            server_cert_verify: TlsServerCertVerify::Insecure,
160            ..Default::default()
161        }
162    }
163}
164
/// A TLS protocol version.
///
/// Variants are declared from oldest to newest, so the derived ordering
/// compares versions chronologically (`Tls1 < Tls1_1 < Tls1_2 < Tls1_3`).
/// This allows range checks such as `min_protocol_version <= max_protocol_version`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
172
173impl std::fmt::Display for SslVersion {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        let s = match self {
176            SslVersion::Tls1 => "TLSv1",
177            SslVersion::Tls1_1 => "TLSv1.1",
178            SslVersion::Tls1_2 => "TLSv1.2",
179            SslVersion::Tls1_3 => "TLSv1.3",
180        };
181        f.write_str(s)
182    }
183}
184
185#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
186pub struct SslVersionParseError(#[error(not(source))] pub String);
187
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    /// Serializes the version as its `Display` name (e.g. `"TLSv1.3"`).
    ///
    /// Delegating to `Display` keeps a single source of truth for the version
    /// names, so the serialized form cannot drift from the human-readable one.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
202
203impl TryFrom<Cow<'_, str>> for SslVersion {
204    type Error = SslVersionParseError;
205    fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
206        Ok(match value.to_lowercase().as_ref() {
207            "tls_1" | "tlsv1" => SslVersion::Tls1,
208            "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
209            "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
210            "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
211            _ => return Err(SslVersionParseError(value.to_string())),
212        })
213    }
214}
215
216#[derive(Default, Debug, PartialEq, Eq)]
217pub enum TlsClientCertVerify {
218    /// Do not verify the client's certificate, just ignore it.
219    #[default]
220    Ignore,
221    /// If a client certificate is provided, validate it.
222    Optional(Vec<CertificateDer<'static>>),
223    /// Validate that a client certificate exists and is valid. This configuration
224    /// may not be ideal, because it does not fail the client-side handshake.
225    Validate(Vec<CertificateDer<'static>>),
226}
227
228#[derive(derive_more::Debug, derive_more::Constructor)]
229pub struct TlsKey {
230    #[debug("key(...)")]
231    pub(crate) key: PrivateKeyDer<'static>,
232    #[debug("cert(...)")]
233    pub(crate) cert: CertificateDer<'static>,
234}
235
236impl TlsKey {
237    /// Create a new TlsKey from a PEM-encoded certificate and key.
238    #[cfg(feature = "pem")]
239    pub fn new_pem(mut key: &[u8], mut cert: &[u8]) -> Result<Self, std::io::Error> {
240        let cert = rustls_pemfile::certs(&mut cert)
241            .next()
242            .ok_or(std::io::Error::new(
243                std::io::ErrorKind::InvalidData,
244                "No certificate found",
245            ))??;
246        let key = rustls_pemfile::private_key(&mut key)?.ok_or(std::io::Error::new(
247            std::io::ErrorKind::InvalidData,
248            "No key found",
249        ))?;
250        Ok(Self { cert, key })
251    }
252}
253
254#[derive(Debug, Clone)]
255pub struct TlsServerParameterProvider {
256    inner: TlsServerParameterProviderInner,
257}
258
259impl TlsServerParameterProvider {
260    pub fn new(params: TlsServerParameters) -> Self {
261        Self {
262            inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
263        }
264    }
265
266    pub fn with_lookup(
267        lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
268    ) -> Self {
269        Self {
270            inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
271        }
272    }
273
274    pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
275        match &self.inner {
276            TlsServerParameterProviderInner::Static(params) => params.clone(),
277            TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
278        }
279    }
280}
281
282#[derive(derive_more::Debug, Clone)]
283enum TlsServerParameterProviderInner {
284    Static(Arc<TlsServerParameters>),
285    #[debug("Lookup(...)")]
286    #[allow(clippy::type_complexity)]
287    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
288}
289
290#[derive(Debug)]
291pub struct TlsServerParameters {
292    pub client_cert_verify: TlsClientCertVerify,
293    pub min_protocol_version: Option<SslVersion>,
294    pub max_protocol_version: Option<SslVersion>,
295    pub server_certificate: TlsKey,
296    pub alpn: TlsAlpn,
297}
298
299impl TlsServerParameters {
300    pub fn new_with_certificate(server_certificate: TlsKey) -> Self {
301        Self {
302            client_cert_verify: TlsClientCertVerify::default(),
303            min_protocol_version: None,
304            max_protocol_version: None,
305            server_certificate,
306            alpn: TlsAlpn::default(),
307        }
308    }
309}
310
/// An ordered list of ALPN protocol identifiers, e.g. `h2` and `http/1.1`.
///
/// The `Default` value is an empty list.
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// Each protocol identifier as raw bytes, without length prefixes
    /// (e.g. `[b"h2", b"http/1.1"]`).
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}
316
317impl std::fmt::Debug for TlsAlpn {
318    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319        if self.alpn_parts.is_empty() {
320            write!(f, "[]")
321        } else {
322            for (i, part) in self.alpn_parts.iter().enumerate() {
323                if i == 0 {
324                    write!(f, "[")?;
325                } else {
326                    write!(f, ", ")?;
327                }
328                // Print as binary literal with appropriate escaping
329                let mut s = String::new();
330                s.push_str("b\"");
331                for &b in part.iter() {
332                    for c in b.escape_ascii() {
333                        s.push(c as char);
334                    }
335                }
336                s.push('"');
337                write!(f, "{}", s)?;
338            }
339            write!(f, "]")?;
340            Ok(())
341        }
342    }
343}
344
345impl TlsAlpn {
346    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
347        let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
348        Self {
349            alpn_parts: Cow::Owned(alpn),
350        }
351    }
352
353    pub fn new_str(alpn: &'static [&'static str]) -> Self {
354        let alpn = alpn
355            .iter()
356            .map(|s| Cow::Borrowed(s.as_bytes()))
357            .collect::<Vec<_>>();
358        Self {
359            alpn_parts: Cow::Owned(alpn),
360        }
361    }
362
363    pub fn is_empty(&self) -> bool {
364        self.alpn_parts.is_empty()
365    }
366
367    pub fn as_bytes(&self) -> Vec<u8> {
368        let mut bytes = Vec::with_capacity(self.alpn_parts.len() * 2);
369        for part in self.alpn_parts.iter() {
370            bytes.push(part.len() as u8);
371            bytes.extend_from_slice(part.as_ref());
372        }
373        bytes
374    }
375
376    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
377        let mut vec = Vec::with_capacity(self.alpn_parts.len());
378        for part in self.alpn_parts.iter() {
379            vec.push(part.to_vec());
380        }
381        vec
382    }
383}
384
385/// Negotiated TLS handshake information.
386#[derive(Debug, Clone, Default)]
387pub struct TlsHandshake {
388    /// The negotiated ALPN protocol.
389    pub alpn: Option<Cow<'static, [u8]>>,
390    /// The SNI hostname if provided.
391    pub sni: Option<Cow<'static, str>>,
392    /// The client certificate, if any.
393    pub cert: Option<CertificateDer<'static>>,
394    /// The negotiated TLS version.
395    pub version: Option<SslVersion>,
396}
397
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    /// `Debug` for `TlsParameters` must redact key material and summarize lists.
    #[test]
    fn test_tls_parameters_debug() {
        let params = TlsParameters::default();
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    /// `TlsParameters::insecure` only relaxes server certificate verification.
    #[test]
    fn test_tls_parameters_insecure() {
        let params = TlsParameters::insecure();
        assert_eq!(params.server_cert_verify, TlsServerCertVerify::Insecure);
        assert_eq!(
            params,
            TlsParameters {
                server_cert_verify: TlsServerCertVerify::Insecure,
                ..Default::default()
            }
        );
    }

    /// Every `Display` name parses back (case-insensitively) to the same version,
    /// the `tls_1.x` spelling is accepted, and unknown names are rejected verbatim.
    #[test]
    fn test_ssl_version_round_trip() {
        for version in [
            SslVersion::Tls1,
            SslVersion::Tls1_1,
            SslVersion::Tls1_2,
            SslVersion::Tls1_3,
        ] {
            let name = version.to_string();
            assert_eq!(SslVersion::try_from(Cow::Owned(name.clone())), Ok(version));
            assert_eq!(SslVersion::try_from(Cow::Owned(name.to_uppercase())), Ok(version));
        }
        assert_eq!(
            SslVersion::try_from(Cow::Borrowed("tls_1.2")),
            Ok(SslVersion::Tls1_2)
        );
        assert_eq!(
            SslVersion::try_from(Cow::Borrowed("SSLv3")),
            Err(SslVersionParseError("SSLv3".to_string()))
        );
    }

    /// ALPN wire encoding, splitting and `Debug` output, including the empty list.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{:?}", alpn), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{:?}", empty_alpn), "[]");
    }
}