gel_stream/common/
tls.rs

1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
7// Note that we choose rustls when both openssl and rustls are enabled.
8
9#[cfg(all(feature = "openssl", not(feature = "rustls")))]
10pub type Ssl = crate::common::openssl::OpensslDriver;
11#[cfg(feature = "rustls")]
12pub type Ssl = crate::common::rustls::RustlsDriver;
13#[cfg(not(any(feature = "openssl", feature = "rustls")))]
14pub type Ssl = NullTlsDriver;
15
16pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
17    type Stream: Stream + Send;
18    type ClientParams: Unpin + Send;
19    type ServerParams: Unpin + Send;
20
21    #[allow(unused)]
22    fn init_client(
23        params: &TlsParameters,
24        name: Option<ServerName>,
25    ) -> Result<Self::ClientParams, SslError>;
26    #[allow(unused)]
27    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
28
29    fn upgrade_client<S: Stream>(
30        params: Self::ClientParams,
31        stream: S,
32    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
33    fn upgrade_server<S: Stream>(
34        params: TlsServerParameterProvider,
35        stream: S,
36    ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37}
38
39#[derive(Default)]
40pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44    type Stream = BaseStream;
45    type ClientParams = ();
46    type ServerParams = ();
47
48    fn init_client(
49        params: &TlsParameters,
50        name: Option<ServerName>,
51    ) -> Result<Self::ClientParams, SslError> {
52        Err(SslError::SslUnsupportedByClient)
53    }
54
55    fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56        Err(SslError::SslUnsupportedByClient)
57    }
58
59    async fn upgrade_client<S: Stream>(
60        params: Self::ClientParams,
61        stream: S,
62    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63        Err(SslError::SslUnsupportedByClient)
64    }
65
66    async fn upgrade_server<S: Stream>(
67        params: TlsServerParameterProvider,
68        stream: S,
69    ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70        Err(SslError::SslUnsupportedByClient)
71    }
72}
73
/// Server certificate verification modes shared by PostgreSQL and EdgeDB/Gel.
///
/// Postgres offers six levels: `disable`, `allow`, `prefer`, `require`,
/// `verify-ca` and `verify-full`.
///
/// EdgeDB/Gel offers three levels: `insecure`, `no_host_verification` and
/// `strict`.
///
/// This table maps the various levels:
///
/// | Postgres      | EdgeDB/Gel             | `TlsServerCertVerify` enum |
/// | ------------- | ---------------------- | -------------------------- |
/// | `require`     | `insecure`             | `Insecure`                 |
/// | `verify-ca`   | `no_host_verification` | `IgnoreHostname`           |
/// | `verify-full` | `strict`               | `VerifyFull`               |
///
/// The Postgres `disable`, `allow` and `prefer` levels have no row in this
/// table.
///
/// Note that both EdgeDB/Gel and Postgres may alter certificate validation
/// levels when custom root certificates are provided. This must be done in
/// the `TlsParameters` struct by the caller.
///
/// The default is `VerifyFull`.
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum TlsServerCertVerify {
    /// Do not verify the server's certificate. Only confirm that the server is
    /// using TLS.
    Insecure,
    /// Verify the server's certificate using the CA (ignore hostname).
    IgnoreHostname,
    /// Verify the server's certificate using the CA and hostname.
    #[default]
    VerifyFull,
}
102
103#[derive(Debug, Clone, Default, PartialEq, Eq)]
104pub enum TlsCert {
105    /// Use the system's default certificate.
106    #[default]
107    System,
108    /// Use the system's default certificate and a set of custom root
109    /// certificates.
110    SystemPlus(Vec<CertificateDer<'static>>),
111    /// Use the webpki-roots default certificate.
112    Webpki,
113    /// Use the webpki-roots default certificate and a set of custom root
114    /// certificates.
115    WebpkiPlus(Vec<CertificateDer<'static>>),
116    /// Use a custom root certificate only.
117    Custom(Vec<CertificateDer<'static>>),
118}
119
120#[derive(Default, Debug, PartialEq, Eq)]
121pub struct TlsParameters {
122    pub server_cert_verify: TlsServerCertVerify,
123    pub cert: Option<CertificateDer<'static>>,
124    pub key: Option<PrivateKeyDer<'static>>,
125    pub root_cert: TlsCert,
126    pub crl: Vec<CertificateRevocationListDer<'static>>,
127    pub min_protocol_version: Option<SslVersion>,
128    pub max_protocol_version: Option<SslVersion>,
129    pub enable_keylog: bool,
130    pub sni_override: Option<Cow<'static, str>>,
131    pub alpn: TlsAlpn,
132}
133
134impl TlsParameters {
135    pub fn insecure() -> Self {
136        Self {
137            server_cert_verify: TlsServerCertVerify::Insecure,
138            ..Default::default()
139        }
140    }
141}
142
/// TLS protocol versions, declared from oldest to newest.
///
/// The declaration order is significant: the derived `PartialOrd`/`Ord`
/// compare versions chronologically. For example, a `min_protocol_version`
/// can be checked against a `max_protocol_version` with `<=`. Do not reorder
/// the variants.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
150
151#[derive(Default, Debug, PartialEq, Eq)]
152pub enum TlsClientCertVerify {
153    /// Do not verify the client's certificate, just ignore it.
154    #[default]
155    Ignore,
156    /// If a client certificate is provided, validate it.
157    Optional(Vec<CertificateDer<'static>>),
158    /// Validate that a client certificate exists and is valid. This configuration
159    /// may not be ideal, because it does not fail the client-side handshake.
160    Validate(Vec<CertificateDer<'static>>),
161}
162
163#[derive(derive_more::Debug, derive_more::Constructor)]
164pub struct TlsKey {
165    #[debug("key(...)")]
166    pub(crate) key: PrivateKeyDer<'static>,
167    #[debug("cert(...)")]
168    pub(crate) cert: CertificateDer<'static>,
169}
170
171#[derive(Debug, Clone)]
172pub struct TlsServerParameterProvider {
173    inner: TlsServerParameterProviderInner,
174}
175
176impl TlsServerParameterProvider {
177    pub fn new(params: TlsServerParameters) -> Self {
178        Self {
179            inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
180        }
181    }
182
183    pub fn with_lookup(
184        lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
185    ) -> Self {
186        Self {
187            inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
188        }
189    }
190
191    pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
192        match &self.inner {
193            TlsServerParameterProviderInner::Static(params) => params.clone(),
194            TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
195        }
196    }
197}
198
199#[derive(derive_more::Debug, Clone)]
200enum TlsServerParameterProviderInner {
201    Static(Arc<TlsServerParameters>),
202    #[debug("Lookup(...)")]
203    #[allow(clippy::type_complexity)]
204    Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
205}
206
207#[derive(Debug)]
208pub struct TlsServerParameters {
209    pub client_cert_verify: TlsClientCertVerify,
210    pub min_protocol_version: Option<SslVersion>,
211    pub max_protocol_version: Option<SslVersion>,
212    pub server_certificate: TlsKey,
213    pub alpn: TlsAlpn,
214}
215
/// An ordered list of ALPN protocol names.
///
/// The names are stored in split form, one entry per protocol (e.g.
/// `["AB", "ABCD"]`). [`TlsAlpn::as_bytes`] produces the length-prefixed
/// encoding, and [`TlsAlpn::as_vec_vec`] produces one owned vector per name.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// One entry per protocol name (e.g. `["AB", "ABCD"]`).
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

impl TlsAlpn {
    /// Builds the list from static byte-string protocol names.
    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
        Self::from_parts(alpn.iter().map(|part| Cow::Borrowed(*part)))
    }

    /// Builds the list from static string protocol names.
    pub fn new_str(alpn: &'static [&'static str]) -> Self {
        Self::from_parts(alpn.iter().map(|part| Cow::Borrowed(part.as_bytes())))
    }

    /// Collects the borrowed protocol names into an owned outer list.
    fn from_parts(parts: impl Iterator<Item = Cow<'static, [u8]>>) -> Self {
        Self {
            alpn_parts: Cow::Owned(parts.collect()),
        }
    }

    /// Returns `true` if no protocol names are configured.
    pub fn is_empty(&self) -> bool {
        self.alpn_parts.is_empty()
    }

    /// Encodes the list as each protocol name preceded by its length in a
    /// single byte (e.g. `["AB", "ABCD"]` becomes `\x02AB\x04ABCD`).
    ///
    /// # Panics
    ///
    /// Panics if a protocol name is longer than 255 bytes, because its length
    /// cannot be represented in the one-byte prefix. Silently truncating the
    /// length would produce a corrupt encoding.
    pub fn as_bytes(&self) -> Vec<u8> {
        // Exact size: one length byte plus the name, per protocol.
        let total: usize = self.alpn_parts.iter().map(|part| 1 + part.len()).sum();
        let mut bytes = Vec::with_capacity(total);
        for part in self.alpn_parts.iter() {
            let len = part.len();
            assert!(
                len <= usize::from(u8::MAX),
                "ALPN protocol name must be at most 255 bytes long, got {} bytes",
                len
            );
            // Cannot truncate: bounded by the assertion above.
            bytes.push(len as u8);
            bytes.extend_from_slice(part);
        }
        bytes
    }

    /// Returns each protocol name as its own owned byte vector, in order.
    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
        self.alpn_parts.iter().map(|part| part.to_vec()).collect()
    }
}
261
262#[derive(Debug, Clone, Default)]
263pub struct TlsHandshake {
264    pub alpn: Option<Cow<'static, [u8]>>,
265    pub sni: Option<Cow<'static, str>>,
266    pub cert: Option<CertificateDer<'static>>,
267}