1use crate::{SslError, Stream};
2use rustls_pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer, ServerName};
3use std::{borrow::Cow, future::Future, sync::Arc};
4
5use super::BaseStream;
6
/// TLS driver selected at compile time: the OpenSSL driver, used only when the
/// `openssl` feature is enabled and `rustls` is not.
#[cfg(all(feature = "openssl", not(feature = "rustls")))]
pub type Ssl = crate::common::openssl::OpensslDriver;
/// TLS driver selected at compile time: the rustls driver, used whenever the
/// `rustls` feature is enabled (it takes precedence if `openssl` is enabled
/// as well).
#[cfg(feature = "rustls")]
pub type Ssl = crate::common::rustls::RustlsDriver;
/// TLS driver selected at compile time: with neither `openssl` nor `rustls`
/// enabled, the [`NullTlsDriver`] placeholder is used.
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
pub type Ssl = NullTlsDriver;
15
16pub trait TlsDriver: Default + Send + Sync + Unpin + 'static {
17 type Stream: Stream + Send;
18 type ClientParams: Unpin + Send;
19 type ServerParams: Unpin + Send;
20
21 #[allow(unused)]
22 fn init_client(
23 params: &TlsParameters,
24 name: Option<ServerName>,
25 ) -> Result<Self::ClientParams, SslError>;
26 #[allow(unused)]
27 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError>;
28
29 fn upgrade_client<S: Stream>(
30 params: Self::ClientParams,
31 stream: S,
32 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
33 fn upgrade_server<S: Stream>(
34 params: TlsServerParameterProvider,
35 stream: S,
36 ) -> impl Future<Output = Result<(Self::Stream, TlsHandshake), SslError>> + Send;
37}
38
39#[derive(Default)]
40pub struct NullTlsDriver;
41
42#[allow(unused)]
43impl TlsDriver for NullTlsDriver {
44 type Stream = BaseStream;
45 type ClientParams = ();
46 type ServerParams = ();
47
48 fn init_client(
49 params: &TlsParameters,
50 name: Option<ServerName>,
51 ) -> Result<Self::ClientParams, SslError> {
52 Err(SslError::SslUnsupportedByClient)
53 }
54
55 fn init_server(params: &TlsServerParameters) -> Result<Self::ServerParams, SslError> {
56 Err(SslError::SslUnsupportedByClient)
57 }
58
59 async fn upgrade_client<S: Stream>(
60 params: Self::ClientParams,
61 stream: S,
62 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
63 Err(SslError::SslUnsupportedByClient)
64 }
65
66 async fn upgrade_server<S: Stream>(
67 params: TlsServerParameterProvider,
68 stream: S,
69 ) -> Result<(Self::Stream, TlsHandshake), SslError> {
70 Err(SslError::SslUnsupportedByClient)
71 }
72}
73
/// How strictly a client checks the certificate presented by the server.
///
/// The default is [`TlsServerCertVerify::VerifyFull`]. The checks themselves
/// are performed by the selected driver, not in this file.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TlsServerCertVerify {
    /// NOTE(review): presumably disables certificate verification entirely --
    /// confirm against the driver implementations.
    Insecure,
    /// NOTE(review): presumably validates the certificate chain but not the
    /// hostname -- confirm against the driver implementations.
    IgnoreHostname,
    /// NOTE(review): presumably validates both the certificate chain and the
    /// hostname -- confirm against the driver implementations.
    #[default]
    VerifyFull,
}
102
103#[derive(Debug, Clone, Default, PartialEq, Eq)]
104pub enum TlsCert {
105 #[default]
107 System,
108 SystemPlus(Vec<CertificateDer<'static>>),
111 Webpki,
113 WebpkiPlus(Vec<CertificateDer<'static>>),
116 Custom(Vec<CertificateDer<'static>>),
118}
119
120#[derive(Default, Debug, PartialEq, Eq)]
121pub struct TlsParameters {
122 pub server_cert_verify: TlsServerCertVerify,
123 pub cert: Option<CertificateDer<'static>>,
124 pub key: Option<PrivateKeyDer<'static>>,
125 pub root_cert: TlsCert,
126 pub crl: Vec<CertificateRevocationListDer<'static>>,
127 pub min_protocol_version: Option<SslVersion>,
128 pub max_protocol_version: Option<SslVersion>,
129 pub enable_keylog: bool,
130 pub sni_override: Option<Cow<'static, str>>,
131 pub alpn: TlsAlpn,
132}
133
134impl TlsParameters {
135 pub fn insecure() -> Self {
136 Self {
137 server_cert_verify: TlsServerCertVerify::Insecure,
138 ..Default::default()
139 }
140 }
141}
142
/// TLS protocol versions that can be used as a minimum or maximum bound.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SslVersion {
    /// TLS 1.0.
    Tls1,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
150
151#[derive(Default, Debug, PartialEq, Eq)]
152pub enum TlsClientCertVerify {
153 #[default]
155 Ignore,
156 Optional(Vec<CertificateDer<'static>>),
158 Validate(Vec<CertificateDer<'static>>),
161}
162
163#[derive(derive_more::Debug, derive_more::Constructor)]
164pub struct TlsKey {
165 #[debug("key(...)")]
166 pub(crate) key: PrivateKeyDer<'static>,
167 #[debug("cert(...)")]
168 pub(crate) cert: CertificateDer<'static>,
169}
170
171#[derive(Debug, Clone)]
172pub struct TlsServerParameterProvider {
173 inner: TlsServerParameterProviderInner,
174}
175
176impl TlsServerParameterProvider {
177 pub fn new(params: TlsServerParameters) -> Self {
178 Self {
179 inner: TlsServerParameterProviderInner::Static(Arc::new(params)),
180 }
181 }
182
183 pub fn with_lookup(
184 lookup: impl Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static,
185 ) -> Self {
186 Self {
187 inner: TlsServerParameterProviderInner::Lookup(Arc::new(lookup)),
188 }
189 }
190
191 pub fn lookup(&self, name: Option<ServerName>) -> Arc<TlsServerParameters> {
192 match &self.inner {
193 TlsServerParameterProviderInner::Static(params) => params.clone(),
194 TlsServerParameterProviderInner::Lookup(lookup) => lookup(name),
195 }
196 }
197}
198
199#[derive(derive_more::Debug, Clone)]
200enum TlsServerParameterProviderInner {
201 Static(Arc<TlsServerParameters>),
202 #[debug("Lookup(...)")]
203 #[allow(clippy::type_complexity)]
204 Lookup(Arc<dyn Fn(Option<ServerName>) -> Arc<TlsServerParameters> + Send + Sync + 'static>),
205}
206
/// Server-side TLS configuration passed to [`TlsDriver::init_server`] and
/// handed out by [`TlsServerParameterProvider`].
#[derive(Debug)]
pub struct TlsServerParameters {
    /// How certificates presented by clients are treated.
    pub client_cert_verify: TlsClientCertVerify,
    /// Lowest protocol version to allow; `None` leaves it unset.
    pub min_protocol_version: Option<SslVersion>,
    /// Highest protocol version to allow; `None` leaves it unset.
    pub max_protocol_version: Option<SslVersion>,
    /// The server's private key and certificate.
    pub server_certificate: TlsKey,
    /// ALPN protocol identifiers (presumably those the server accepts --
    /// confirm against the drivers).
    pub alpn: TlsAlpn,
}
215
/// Ordered list of ALPN protocol identifiers, stored as raw bytes.
///
/// The default value is an empty list.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct TlsAlpn {
    /// Protocol identifiers in the order they were given to the constructor.
    alpn_parts: Cow<'static, [Cow<'static, [u8]>]>,
}

impl TlsAlpn {
    /// Longest identifier whose length still fits in the one-byte prefix
    /// written by [`TlsAlpn::as_bytes`].
    const MAX_PROTOCOL_LEN: usize = u8::MAX as usize;

    /// Creates a list from static byte-string identifiers, preserving order.
    pub fn new(alpn: &'static [&'static [u8]]) -> Self {
        let parts = alpn.iter().copied().map(Cow::Borrowed).collect::<Vec<_>>();
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }

    /// Creates a list from static string identifiers, using each string's
    /// UTF-8 bytes and preserving order.
    pub fn new_str(alpn: &'static [&'static str]) -> Self {
        let parts = alpn
            .iter()
            .map(|s| Cow::Borrowed(s.as_bytes()))
            .collect::<Vec<_>>();
        Self {
            alpn_parts: Cow::Owned(parts),
        }
    }

    /// Returns `true` when the list holds no identifiers.
    pub fn is_empty(&self) -> bool {
        self.alpn_parts.is_empty()
    }

    /// Encodes the list as a length-prefixed byte string: for each identifier,
    /// one length byte followed by the identifier's bytes.
    ///
    /// An identifier longer than 255 bytes cannot be described by a one-byte
    /// prefix. Such entries are skipped, because writing a wrapped-around
    /// length would corrupt every entry that follows it in the encoded list.
    pub fn as_bytes(&self) -> Vec<u8> {
        // One prefix byte plus the payload per entry. This is an upper bound
        // when oversized entries end up being skipped.
        let capacity: usize = self.alpn_parts.iter().map(|part| part.len() + 1).sum();
        let mut bytes = Vec::with_capacity(capacity);
        for part in self.alpn_parts.iter() {
            let len = part.len();
            if len > Self::MAX_PROTOCOL_LEN {
                continue;
            }
            // Lossless: `len` was just checked to fit in a `u8`.
            bytes.push(len as u8);
            bytes.extend_from_slice(part);
        }
        bytes
    }

    /// Returns every identifier as its own owned byte vector, in order.
    ///
    /// No length filtering is applied here; all stored identifiers are
    /// returned.
    pub fn as_vec_vec(&self) -> Vec<Vec<u8>> {
        self.alpn_parts.iter().map(|part| part.to_vec()).collect()
    }
}
261
262#[derive(Debug, Clone, Default)]
263pub struct TlsHandshake {
264 pub alpn: Option<Cow<'static, [u8]>>,
265 pub sni: Option<Cow<'static, str>>,
266 pub cert: Option<CertificateDer<'static>>,
267}