gel_stream/common/
tls.rs

1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
7// Note that we choose rustls when both openssl and rustls are enabled.
8
9#[cfg(all(feature = "openssl", not(feature = "rustls")))]
10pub type Ssl = crate::common::openssl::OpensslDriver;
11#[cfg(feature = "rustls")]
12pub type Ssl = crate::common::rustls::RustlsDriver;
13#[cfg(not(any(feature = "openssl", feature = "rustls")))]
14pub type Ssl = NullTlsDriver;
15
16pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
17    type Stream: Stream + Send;
18    type ClientParams: Unpin + Send;
19    type ServerParams: Unpin + Send;
20
21    #[allow(unused)]
22    fn init_client(
23        params: &TlsParameters,
24        name: Option<ServerName>,
25    ) -> Result<Self::ClientParams, SslError>;
26    #[allow(unused)]
27    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
28
29    fn upgrade_client<S: Stream>(
30        params: Self::ClientParams,
31        stream: S,
32    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
33    fn upgrade_server<S: Stream>(
34        params: TlsServerParameterProvider,
35        stream: S,
36    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37}
38
39#[derive(Default)]
40pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44    type Stream = BaseStream;
45    type ClientParams = ();
46    type ServerParams = ();
47
48    fn init_client(
49        params: &TlsParameters,
50        name: Option<ServerName>,
51    ) -> Result<Self::ClientParams, SslError> {
52        Err(SslError::SslUnsupportedByClient)
53    }
54
55    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56        Err(SslError::SslUnsupportedByClient)
57    }
58
59    async fn upgrade_client<S: Stream>(
60        params: Self::ClientParams,
61        stream: S,
62    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63        Err(SslError::SslUnsupportedByClient)
64    }
65
66    async fn upgrade_server<S: Stream>(
67        params: TlsServerParameterProvider,
68        stream: S,
69    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70        Err(SslError::SslUnsupportedByClient)
71    }
72}
73
/// Server certificate verification modes, covering the levels offered by both
/// PostgreSQL and EdgeDB/Gel.
///
/// Postgres offers six levels: `disable`, `allow`, `prefer`, `require`,
/// `verify-ca` and `verify-full`.
///
/// EdgeDB/Gel offers three levels: `insecure`, `no_host_verification` and
/// `strict`.
///
/// This table maps the levels onto each other:
///
/// | Postgres      | EdgeDB/Gel             | `TlsServerCertVerify` |
/// | ------------- | ---------------------- | --------------------- |
/// | `require`     | `insecure`             | `Insecure`            |
/// | `verify-ca`   | `no_host_verification` | `IgnoreHostname`      |
/// | `verify-full` | `strict`               | `VerifyFull`          |
///
/// The Postgres levels `disable`, `allow` and `prefer` have no entry in this
/// table.
///
/// Note that both EdgeDB/Gel and Postgres may alter certificate validation
/// levels when custom root certificates are provided. The caller must apply
/// any such adjustment when building `TlsParameters`.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Do not verify the server's certificate; only confirm that the server
    /// is using TLS.
    Insecure,
    /// Verify the server's certificate using the CA, but ignore the hostname.
    IgnoreHostname,
    /// Verify the server's certificate using the CA and check the hostname.
    /// This is the default.
    #[default]
    VerifyFull,
}
102
103#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
104pub enum TlsCert {
105    /// Use the system's default certificate.
106    #[default]
107    System,
108    /// Use the system's default certificate and a set of custom root
109    /// certificates.
110    #[debug("SystemPlus([{} cert(s)])", _0.len())]
111    SystemPlus(Vec<CertificateDer<'static>>),
112    /// Use the webpki-roots default certificate.
113    Webpki,
114    /// Use the webpki-roots default certificate and a set of custom root
115    /// certificates.
116    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
117    WebpkiPlus(Vec<CertificateDer<'static>>),
118    /// Use a custom root certificate only.
119    #[debug("Custom([{} cert(s)])", _0.len())]
120    Custom(Vec<CertificateDer<'static>>),
121}
122
123#[derive(Default, derive_more::Debug, PartialEq, Eq)]
124pub struct TlsParameters {
125    pub server_cert_verify: TlsServerCertVerify,
126    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
127    pub cert: Option<CertificateDer<'static>>,
128    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
129    pub key: Option<PrivateKeyDer<'static>>,
130    pub root_cert: TlsCert,
131    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
132    pub crl: Vec<CertificateRevocationListDer<'static>>,
133    pub min_protocol_version: Option<SslVersion>,
134    pub max_protocol_version: Option<SslVersion>,
135    pub enable_keylog: bool,
136    pub sni_override: Option<Cow<'static, str>>,
137    pub alpn: TlsAlpn,
138}
139
140impl TlsParameters {
141    pub fn insecure() -> Self {
142        Self {
143            server_cert_verify: TlsServerCertVerify::Insecure,
144            ..Default::default()
145        }
146    }
147}
148
/// A TLS protocol version, as used for `min_protocol_version` /
/// `max_protocol_version`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SslVersion {
    /// TLS 1.0 (`TLSv1`).
    Tls1,
    /// TLS 1.1 (`TLSv1.1`).
    Tls1_1,
    /// TLS 1.2 (`TLSv1.2`).
    Tls1_2,
    /// TLS 1.3 (`TLSv1.3`).
    Tls1_3,
}

impl SslVersion {
    /// Returns the canonical name of this version (e.g. `"TLSv1.2"`) without
    /// allocating. This is the same text produced by `Display`.
    pub const fn as_str(self) -> &'static str {
        match self {
            SslVersion::Tls1 => "TLSv1",
            SslVersion::Tls1_1 => "TLSv1.1",
            SslVersion::Tls1_2 => "TLSv1.2",
            SslVersion::Tls1_3 => "TLSv1.3",
        }
    }
}

impl std::fmt::Display for SslVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
168
169#[derive(Debug, Clone, derive_more::Error, derive_more::Display, Eq, PartialEq)]
170pub struct SslVersionParseError(#[error(not(source))] pub String);
171
#[cfg(feature = "serde")]
impl serde::Serialize for SslVersion {
    /// Serializes the version as its `Display` text (e.g. `"TLSv1.2"`).
    ///
    /// Reusing `Display` keeps a single source of truth for version names
    /// instead of a second, independently maintained table. The text is
    /// accepted back, case-insensitively, by `TryFrom<Cow<str>>`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
186
187impl TryFrom<Cow<'_, str>> for SslVersion {
188    type Error = SslVersionParseError;
189    fn try_from(value: Cow<str>) -> Result<SslVersion, Self::Error> {
190        Ok(match value.to_lowercase().as_ref() {
191            "tls_1" | "tlsv1" => SslVersion::Tls1,
192            "tls_1.1" | "tlsv1.1" => SslVersion::Tls1_1,
193            "tls_1.2" | "tlsv1.2" => SslVersion::Tls1_2,
194            "tls_1.3" | "tlsv1.3" => SslVersion::Tls1_3,
195            _ => return Err(SslVersionParseError(value.to_string())),
196        })
197    }
198}
199
200#[derive(Default, Debug, PartialEq, Eq)]
201pub enum TlsClientCertVerify {
202    /// Do not verify the client's certificate, just ignore it.
203    #[default]
204    Ignore,
205    /// If a client certificate is provided, validate it.
206    Optional(Vec<CertificateDer<'static>>),
207    /// Validate that a client certificate exists and is valid. This configuration
208    /// may not be ideal, because it does not fail the client-side handshake.
209    Validate(Vec<CertificateDer<'static>>),
210}
211
212#[derive(derive_more::Debug, derive_more::Constructor)]
213pub struct TlsKey {
214    #[debug("key(...)")]
215    pub(crate) key: PrivateKeyDer<'static>,
216    #[debug("cert(...)")]
217    pub(crate) cert: CertificateDer<'static>,
218}
219
220#[derive(Debug, Clone)]
221pub struct TlsServerParameterProvider {
222    inner: TlsServerParameterProviderInner,
223}
224
225impl TlsServerParameterProvider {
226    pub fn new(params: TlsServerParameters) -> Self {
227        Self {
228            inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
229        }
230    }
231
232    pub fn with_lookup(
233        lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
234    ) -> Self {
235        Self {
236            inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
237        }
238    }
239
240    pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
241        match &self.inner {
242            TlsServerParameterProviderInner::Static(params) => params.clone(),
243            TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
244        }
245    }
246}
247
248#[derive(derive_more::Debug, Clone)]
249enum TlsServerParameterProviderInner {
250    Static(Arc<TlsServerParameters>),
251    #[debug("Lookup(...)")]
252    #[allow(clippy::type_complexity)]
253    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
254}
255
256#[derive(Debug)]
257pub struct TlsServerParameters {
258    pub client_cert_verify: TlsClientCertVerify,
259    pub min_protocol_version: Option<SslVersion>,
260    pub max_protocol_version: Option<SslVersion>,
261    pub server_certificate: TlsKey,
262    pub alpn: TlsAlpn,
263}
264
/// An ordered list of ALPN protocol identifiers (e.g. `h2`, `http/1.1`).
#[derive(Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// The individual identifiers, in order (e.g. `["h2", "http/1.1"]`).
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

/// Formats the list like a slice of byte-string literals, e.g.
/// `[b"h2", b"http/1.1"]`, escaping non-printable bytes.
impl std::fmt::Debug for TlsAlpn {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        for (i, part) in self.alpn_parts.iter().enumerate() {
            if i > 0 {
                f.write_str(", ")?;
            }
            // `escape_ascii` renders directly into the formatter, so no
            // temporary `String` is needed per identifier.
            write!(f, "b\"{}\"", part.escape_ascii())?;
        }
        f.write_str("]")
    }
}

impl TlsAlpn {
    /// Builds an ALPN list from static byte-string identifiers.
    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
        let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
        Self {
            alpn_parts: Cow::Owned(alpn),
        }
    }

    /// Builds an ALPN list from static string identifiers.
    pub fn new_str(alpn: &'static [&'static str]) -> Self {
        let alpn = alpn
            .iter()
            .map(|s| Cow::Borrowed(s.as_bytes()))
            .collect::<Vec<_>>();
        Self {
            alpn_parts: Cow::Owned(alpn),
        }
    }

    /// Returns `true` when no protocols are configured.
    pub fn is_empty(&self) -> bool {
        self.alpn_parts.is_empty()
    }

    /// Encodes the list in ALPN wire format: each identifier is preceded by
    /// its length as a single byte.
    ///
    /// # Panics
    ///
    /// Panics if an identifier is longer than 255 bytes. Such a length cannot
    /// be encoded in one byte (RFC 7301 limits identifiers to 255 bytes).
    /// Truncating it would silently produce a corrupt protocol list.
    pub fn as_bytes(&self) -> Vec<u8> {
        // One length byte plus the identifier bytes for each entry.
        let capacity = self.alpn_parts.iter().map(|part| part.len() + 1).sum();
        let mut bytes = Vec::with_capacity(capacity);
        for part in self.alpn_parts.iter() {
            let len = u8::try_from(part.len())
                .expect("ALPN protocol identifiers must be at most 255 bytes (RFC 7301)");
            bytes.push(len);
            bytes.extend_from_slice(part.as_ref());
        }
        bytes
    }

    /// Returns each identifier as an owned byte vector.
    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
        self.alpn_parts.iter().map(|part| part.to_vec()).collect()
    }
}
338
339#[derive(Debug, Clone, Default)]
340pub struct TlsHandshake {
341    pub alpn: Option<Cow<'static, [u8]>>,
342    pub sni: Option<Cow<'static, str>>,
343    pub cert: Option<CertificateDer<'static>>,
344}
345
#[cfg(test)]
mod tests {
    use rustls_pki_types::PrivatePkcs1KeyDer;

    use super::*;

    #[test]
    fn test_tls_parameters_debug() {
        let params = TlsParameters::default();
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: VerifyFull, cert: None, key: None, \
            root_cert: System, crl: [], min_protocol_version: None, max_protocol_version: None, \
            enable_keylog: false, sni_override: None, alpn: [] }"
        );
        let params = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![
                1, 2, 3,
            ]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            min_protocol_version: None,
            max_protocol_version: None,
            enable_keylog: false,
            sni_override: None,
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
        };
        assert_eq!(
            format!("{:?}", params),
            "TlsParameters { server_cert_verify: Insecure, cert: Some(...), key: Some(...), \
            root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, \
            alpn: [b\"h2\", b\"http/1.1\"] }"
        );
    }

    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);
        assert_eq!(
            alpn.as_bytes(),
            vec![2, b'h', b'2', 8, b'h', b't', b't', b'p', b'/', b'1', b'.', b'1']
        );
        assert_eq!(
            alpn.as_vec_vec(),
            vec![b"h2".to_vec(), b"http/1.1".to_vec()]
        );
        assert!(!alpn.is_empty());
        assert_eq!(format!("{:?}", alpn), "[b\"h2\", b\"http/1.1\"]");

        let empty_alpn = TlsAlpn::default();
        assert!(empty_alpn.is_empty());
        assert_eq!(empty_alpn.as_bytes(), Vec::<u8>::new());
        assert_eq!(empty_alpn.as_vec_vec(), Vec::<Vec<u8>>::new());
        assert_eq!(format!("{:?}", empty_alpn), "[]");
    }

    #[test]
    fn test_tls_handshake() {
        let handshake = TlsHandshake {
            alpn: Some(Cow::Borrowed(b"h2")),
            sni: Some(Cow::Borrowed("example.com")),
            cert: None,
        };
        assert_eq!(handshake.alpn, Some(Cow::Borrowed(b"h2".as_slice())));
        assert_eq!(handshake.sni, Some(Cow::Borrowed("example.com")));
        assert_eq!(handshake.cert, None);

        assert_eq!(
            format!("{:?}", handshake),
            "TlsHandshake { alpn: Some([104, 50]), sni: Some(\"example.com\"), cert: None }"
        );

        let default_handshake = TlsHandshake::default();
        assert_eq!(default_handshake.alpn, None);
        assert_eq!(default_handshake.sni, None);
        assert_eq!(default_handshake.cert, None);
    }

    #[test]
    fn test_ssl_version_parse() {
        // Both spellings are accepted, regardless of case.
        let cases = [
            ("tls_1", SslVersion::Tls1),
            ("TLSv1", SslVersion::Tls1),
            ("tls_1.1", SslVersion::Tls1_1),
            ("TLSV1.1", SslVersion::Tls1_1),
            ("Tls_1.2", SslVersion::Tls1_2),
            ("tlsv1.2", SslVersion::Tls1_2),
            ("TLS_1.3", SslVersion::Tls1_3),
            ("TLSv1.3", SslVersion::Tls1_3),
        ];
        for (input, expected) in cases {
            assert_eq!(
                SslVersion::try_from(Cow::Borrowed(input)),
                Ok(expected),
                "input: {}",
                input
            );
        }

        // The error carries the original, un-lowercased input.
        assert_eq!(
            SslVersion::try_from(Cow::Borrowed("SSLv3")),
            Err(SslVersionParseError("SSLv3".to_string()))
        );
    }

    #[test]
    fn test_ssl_version_display_round_trips() {
        for version in [
            SslVersion::Tls1,
            SslVersion::Tls1_1,
            SslVersion::Tls1_2,
            SslVersion::Tls1_3,
        ] {
            let text = version.to_string();
            assert_eq!(SslVersion::try_from(Cow::Owned(text)), Ok(version));
        }
        assert_eq!(SslVersion::Tls1_2.to_string(), "TLSv1.2");
    }

    #[test]
    fn test_server_cert_verify_defaults() {
        assert_eq!(
            TlsServerCertVerify::default(),
            TlsServerCertVerify::VerifyFull
        );
        assert_eq!(
            TlsParameters::default().server_cert_verify,
            TlsServerCertVerify::VerifyFull
        );
        let insecure = TlsParameters::insecure();
        assert_eq!(insecure.server_cert_verify, TlsServerCertVerify::Insecure);
        assert_eq!(insecure.root_cert, TlsCert::System);
    }
}