// gel_stream/common/tls.rs

1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
7// Note that we choose rustls when both openssl and rustls are enabled.
8
9#[cfg(all(feature = "openssl", not(feature = "rustls")))]
10pub type Ssl = crate::common::openssl::OpensslDriver;
11#[cfg(feature = "rustls")]
12pub type Ssl = crate::common::rustls::RustlsDriver;
13#[cfg(not(any(feature = "openssl", feature = "rustls")))]
14pub type Ssl = NullTlsDriver;
15
16pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
17    type Stream: Stream + Send;
18    type ClientParams: Unpin + Send;
19    type ServerParams: Unpin + Send;
20
21    #[allow(unused)]
22    fn init_client(
23        params: &TlsParameters,
24        name: Option<ServerName>,
25    ) -> Result<Self::ClientParams, SslError>;
26    #[allow(unused)]
27    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
28
29    fn upgrade_client<S: Stream>(
30        params: Self::ClientParams,
31        stream: S,
32    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
33    fn upgrade_server<S: Stream>(
34        params: TlsServerParameterProvider,
35        stream: S,
36    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37}
38
39#[derive(Default)]
40pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44    type Stream = BaseStream;
45    type ClientParams = ();
46    type ServerParams = ();
47
48    fn init_client(
49        params: &TlsParameters,
50        name: Option<ServerName>,
51    ) -> Result<Self::ClientParams, SslError> {
52        Err(SslError::SslUnsupportedByClient)
53    }
54
55    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56        Err(SslError::SslUnsupportedByClient)
57    }
58
59    async fn upgrade_client<S: Stream>(
60        params: Self::ClientParams,
61        stream: S,
62    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63        Err(SslError::SslUnsupportedByClient)
64    }
65
66    async fn upgrade_server<S: Stream>(
67        params: TlsServerParameterProvider,
68        stream: S,
69    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70        Err(SslError::SslUnsupportedByClient)
71    }
72}
73
/// Verification modes for TLS that are a superset of both PostgreSQL and EdgeDB/Gel.
///
/// Postgres offers six levels: `disable`, `allow`, `prefer`, `require`, `verify-ca` and
/// `verify-full`.
///
/// EdgeDB/Gel offers three levels: `insecure`, `no_host_verification` and `strict`.
///
/// This table maps the levels. Postgres's `disable`, `allow` and `prefer` have no
/// counterpart here because they do not make TLS mandatory.
///
/// | Postgres      | EdgeDB/Gel             | `TlsServerCertVerify` enum |
/// | ------------- | ---------------------- | -------------------------- |
/// | `require`     | `insecure`             | `Insecure`                 |
/// | `verify-ca`   | `no_host_verification` | `IgnoreHostname`           |
/// | `verify-full` | `strict`               | `VerifyFull`               |
///
/// Note that both EdgeDB/Gel and Postgres may alter certificate validation levels
/// when custom root certificates are provided. The caller must apply any such
/// adjustment in the `TlsParameters` struct.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Do not verify the server's certificate. Only confirm that the server is
    /// using TLS.
    Insecure,
    /// Verify the server's certificate using the CA, but ignore the hostname.
    IgnoreHostname,
    /// Verify the server's certificate using the CA and check the hostname
    /// (the default).
    #[default]
    VerifyFull,
}
102
103#[derive(Clone, derive_more::Debug, Default, PartialEq, Eq)]
104pub enum TlsCert {
105    /// Use the system's default certificate.
106    #[default]
107    System,
108    /// Use the system's default certificate and a set of custom root
109    /// certificates.
110    #[debug("SystemPlus([{} cert(s)])", _0.len())]
111    SystemPlus(Vec<CertificateDer<'static>>),
112    /// Use the webpki-roots default certificate.
113    Webpki,
114    /// Use the webpki-roots default certificate and a set of custom root
115    /// certificates.
116    #[debug("WebpkiPlus([{} cert(s)])", _0.len())]
117    WebpkiPlus(Vec<CertificateDer<'static>>),
118    /// Use a custom root certificate only.
119    #[debug("Custom([{} cert(s)])", _0.len())]
120    Custom(Vec<CertificateDer<'static>>),
121}
122
123#[derive(Default, derive_more::Debug, PartialEq, Eq)]
124pub struct TlsParameters {
125    pub server_cert_verify: TlsServerCertVerify,
126    #[debug("{}", cert.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
127    pub cert: Option<CertificateDer<'static>>,
128    #[debug("{}", key.as_ref().map(|_| "Some(...)").unwrap_or("None"))]
129    pub key: Option<PrivateKeyDer<'static>>,
130    pub root_cert: TlsCert,
131    #[debug("{}", if crl.is_empty() { "[]".to_string() } else { format!("[{} item(s)]", crl.len()) })]
132    pub crl: Vec<CertificateRevocationListDer<'static>>,
133    pub min_protocol_version: Option<SslVersion>,
134    pub max_protocol_version: Option<SslVersion>,
135    pub enable_keylog: bool,
136    pub sni_override: Option<Cow<'static, str>>,
137    pub alpn: TlsAlpn,
138}
139
140impl TlsParameters {
141    pub fn insecure() -> Self {
142        Self {
143            server_cert_verify: TlsServerCertVerify::Insecure,
144            ..Default::default()
145        }
146    }
147}
148
/// A TLS protocol version, used to bound the negotiated version range.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
156
157#[derive(Default, Debug, PartialEq, Eq)]
158pub enum TlsClientCertVerify {
159    /// Do not verify the client's certificate, just ignore it.
160    #[default]
161    Ignore,
162    /// If a client certificate is provided, validate it.
163    Optional(Vec<CertificateDer<'static>>),
164    /// Validate that a client certificate exists and is valid. This configuration
165    /// may not be ideal, because it does not fail the client-side handshake.
166    Validate(Vec<CertificateDer<'static>>),
167}
168
169#[derive(derive_more::Debug, derive_more::Constructor)]
170pub struct TlsKey {
171    #[debug("key(...)")]
172    pub(crate) key: PrivateKeyDer<'static>,
173    #[debug("cert(...)")]
174    pub(crate) cert: CertificateDer<'static>,
175}
176
177#[derive(Debug, Clone)]
178pub struct TlsServerParameterProvider {
179    inner: TlsServerParameterProviderInner,
180}
181
182impl TlsServerParameterProvider {
183    pub fn new(params: TlsServerParameters) -> Self {
184        Self {
185            inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
186        }
187    }
188
189    pub fn with_lookup(
190        lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
191    ) -> Self {
192        Self {
193            inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
194        }
195    }
196
197    pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
198        match &self.inner {
199            TlsServerParameterProviderInner::Static(params) => params.clone(),
200            TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
201        }
202    }
203}
204
205#[derive(derive_more::Debug, Clone)]
206enum TlsServerParameterProviderInner {
207    Static(Arc<TlsServerParameters>),
208    #[debug("Lookup(...)")]
209    #[allow(clippy::type_complexity)]
210    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
211}
212
213#[derive(Debug)]
214pub struct TlsServerParameters {
215    pub client_cert_verify: TlsClientCertVerify,
216    pub min_protocol_version: Option<SslVersion>,
217    pub max_protocol_version: Option<SslVersion>,
218    pub server_certificate: TlsKey,
219    pub alpn: TlsAlpn,
220}
221
/// An ordered list of ALPN protocol identifiers (empty by default).
#[derive(PartialEq, Eq, Default)]
pub struct TlsAlpn {
    /// One byte string per protocol identifier, e.g. `["h2", "http/1.1"]`,
    /// stored without wire-format length prefixes.
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}
227
228impl std::fmt::Debug for TlsAlpn {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        if self.alpn_parts.is_empty() {
231            write!(f, "[]")
232        } else {
233            for (i, part) in self.alpn_parts.iter().enumerate() {
234                if i == 0 {
235                    write!(f, "[")?;
236                } else {
237                    write!(f, ", ")?;
238                }
239                // Print as binary literal with appropriate escaping
240                let mut s = String::new();
241                s.push_str("b\"");
242                for &b in part.iter() {
243                    for c in b.escape_ascii() {
244                        s.push(c as char);
245                    }
246                }
247                s.push_str("\"");
248                write!(f, "{}", s)?;
249            }
250            write!(f, "]")?;
251            Ok(())
252        }
253    }
254}
255
256impl TlsAlpn {
257    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
258        let alpn = alpn.iter().map(|s| Cow::Borrowed(*s)).collect::<Vec<_>>();
259        Self {
260            alpn_parts: Cow::Owned(alpn),
261        }
262    }
263
264    pub fn new_str(alpn: &'static [&'static str]) -> Self {
265        let alpn = alpn
266            .iter()
267            .map(|s| Cow::Borrowed(s.as_bytes()))
268            .collect::<Vec<_>>();
269        Self {
270            alpn_parts: Cow::Owned(alpn),
271        }
272    }
273
274    pub fn is_empty(&self) -> bool {
275        self.alpn_parts.is_empty()
276    }
277
278    pub fn as_bytes(&self) -> Vec<u8> {
279        let mut bytes = Vec::with_capacity(self.alpn_parts.len() * 2);
280        for part in self.alpn_parts.iter() {
281            bytes.push(part.len() as u8);
282            bytes.extend_from_slice(part.as_ref());
283        }
284        bytes
285    }
286
287    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
288        let mut vec = Vec::with_capacity(self.alpn_parts.len());
289        for part in self.alpn_parts.iter() {
290            vec.push(part.to_vec());
291        }
292        vec
293    }
294}
295
296#[derive(Debug, Clone, Default)]
297pub struct TlsHandshake {
298    pub alpn: Option<Cow<'static, [u8]>>,
299    pub sni: Option<Cow<'static, str>>,
300    pub cert: Option<CertificateDer<'static>>,
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use rustls_pki_types::PrivatePkcs1KeyDer;

    /// `TlsParameters`' `Debug` output must redact key material and summarise
    /// collections, both for the defaults and for a populated value.
    #[test]
    fn test_tls_parameters_debug() {
        let defaults = TlsParameters::default();
        let expected_defaults = "TlsParameters { server_cert_verify: VerifyFull, cert: None, \
            key: None, root_cert: System, crl: [], min_protocol_version: None, \
            max_protocol_version: None, enable_keylog: false, sni_override: None, alpn: [] }";
        assert_eq!(format!("{defaults:?}"), expected_defaults);

        let populated = TlsParameters {
            server_cert_verify: TlsServerCertVerify::Insecure,
            cert: Some(CertificateDer::from_slice(&[1, 2, 3])),
            key: Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(vec![1, 2, 3]))),
            root_cert: TlsCert::SystemPlus(vec![CertificateDer::from_slice(&[1, 2, 3])]),
            crl: vec![CertificateRevocationListDer::from(vec![1, 2, 3])],
            alpn: TlsAlpn::new_str(&["h2", "http/1.1"]),
            ..TlsParameters::default()
        };
        let expected_populated = "TlsParameters { server_cert_verify: Insecure, cert: Some(...), \
            key: Some(...), root_cert: SystemPlus([1 cert(s)]), crl: [1 item(s)], \
            min_protocol_version: None, max_protocol_version: None, enable_keylog: false, \
            sni_override: None, alpn: [b\"h2\", b\"http/1.1\"] }";
        assert_eq!(format!("{populated:?}"), expected_populated);
    }

    /// ALPN wire encoding, owned copies, emptiness and `Debug` rendering.
    #[test]
    fn test_tls_alpn() {
        let alpn = TlsAlpn::new_str(&["h2", "http/1.1"]);

        let mut wire = vec![2u8];
        wire.extend_from_slice(b"h2");
        wire.push(8);
        wire.extend_from_slice(b"http/1.1");
        assert_eq!(alpn.as_bytes(), wire);

        let parts: Vec<Vec<u8>> = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
        assert_eq!(alpn.as_vec_vec(), parts);
        assert!(!alpn.is_empty());
        assert_eq!(format!("{alpn:?}"), "[b\"h2\", b\"http/1.1\"]");

        let empty = TlsAlpn::default();
        assert!(empty.is_empty());
        assert!(empty.as_bytes().is_empty());
        assert!(empty.as_vec_vec().is_empty());
        assert_eq!(format!("{empty:?}"), "[]");
    }

    /// Field access, `Debug` output and defaults of `TlsHandshake`.
    #[test]
    fn test_tls_handshake() {
        let handshake = TlsHandshake {
            alpn: Some(Cow::Borrowed(b"h2")),
            sni: Some(Cow::Borrowed("example.com")),
            cert: None,
        };
        assert_eq!(handshake.alpn.as_deref(), Some(&b"h2"[..]));
        assert_eq!(handshake.sni.as_deref(), Some("example.com"));
        assert!(handshake.cert.is_none());
        assert_eq!(
            format!("{handshake:?}"),
            "TlsHandshake { alpn: Some([104, 50]), sni: Some(\"example.com\"), cert: None }"
        );

        let default_handshake = TlsHandshake::default();
        assert!(default_handshake.alpn.is_none());
        assert!(default_handshake.sni.is_none());
        assert!(default_handshake.cert.is_none());
    }
}