1use std::fmt;
2use std::io::Error as IoError;
3use std::io::ErrorKind;
4use std::path::Path;
5
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures_lite::io::{AsyncRead, AsyncWrite};
9use openssl::ssl;
10use openssl::x509::verify::X509VerifyFlags;
11use tracing::debug;
12
13use crate::net::{
14 tcp_stream::{stream, stream_with_opts, SocketOpts},
15 AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
16 SplitConnection, TcpDomainConnector,
17};
18
19use super::async_to_sync_wrapper::AsyncToSyncWrapper;
20use super::certificate::Certificate;
21use super::handshake::HandshakeFuture;
22use super::stream::TlsStream;
23
24pub mod certs {
27
28 use anyhow::{Context, Result};
29 use openssl::pkcs12::Pkcs12;
30 use openssl::pkey::Private;
31
32 use super::Certificate;
33 use crate::net::certs::CertBuilder;
34
35 pub type PrivateKey = openssl::pkey::PKey<Private>;
36
37 use identity_impl::Identity;
39
40 mod identity_impl {
42
43 use anyhow::{anyhow, Result};
44 use openssl::pkcs12::Pkcs12;
45 use openssl::pkey::{PKey, Private};
46 use openssl::x509::X509;
47
48 #[derive(Clone)]
49 pub struct Identity {
50 pkey: PKey<Private>,
51 cert: X509,
52 chain: Vec<X509>,
53 }
54
55 impl Identity {
56 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity> {
57 let pkcs12 = Pkcs12::from_der(buf)?;
58 let parsed = pkcs12
59 .parse2(pass)
60 .map_err(|err| anyhow!("Couldn't read pkcs12 {err}"))?;
61 let pkey = parsed.pkey.ok_or(anyhow!("Missing private key"))?;
62 let cert = parsed.cert.ok_or(anyhow!("Missing cert"))?;
63 Ok(Identity {
64 pkey,
65 cert,
66 chain: parsed.ca.into_iter().flatten().collect(),
67 })
68 }
69
70 pub fn cert(&self) -> &X509 {
71 &self.cert
72 }
73
74 pub fn pkey(&self) -> &PKey<Private> {
75 &self.pkey
76 }
77
78 pub fn chain(&self) -> &Vec<X509> {
79 &self.chain
80 }
81 }
82 }
83
84 pub struct X509PemBuilder(Vec<u8>);
85
86 impl CertBuilder for X509PemBuilder {
87 fn new(bytes: Vec<u8>) -> Self {
88 Self(bytes)
89 }
90 }
91
92 impl X509PemBuilder {
93 pub fn build(self) -> Result<Certificate> {
94 let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
95 Ok(cert)
96 }
97 }
98
99 const PASSWORD: &str = "test";
100
101 pub struct PrivateKeyBuilder(Vec<u8>);
102
103 impl CertBuilder for PrivateKeyBuilder {
104 fn new(bytes: Vec<u8>) -> Self {
105 Self(bytes)
106 }
107 }
108
109 impl PrivateKeyBuilder {
110 pub fn build(self) -> Result<PrivateKey> {
111 let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
112 Ok(key)
113 }
114 }
115
116 pub struct IdentityBuilder(Vec<u8>);
117
118 impl CertBuilder for IdentityBuilder {
119 fn new(bytes: Vec<u8>) -> Self {
120 Self(bytes)
121 }
122 }
123
124 impl IdentityBuilder {
125 pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
127 let server_key = key.build()?;
128 let server_crt = x509.build()?;
129 let p12 = Pkcs12::builder()
130 .name("")
131 .pkey(&server_key)
132 .cert(server_crt.inner())
133 .build2(PASSWORD)
134 .context("Failed to create Pkcs12")?;
135
136 let der = p12.to_der()?;
137 Ok(Self(der))
138 }
139
140 pub fn build(self) -> Result<Identity> {
141 Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
142 }
143 }
144}
145
146#[derive(Clone, Debug)]
147pub struct TlsConnector {
148 pub inner: ssl::SslConnector,
149 pub verify_hostname: bool,
150 pub allow_partial: bool,
151}
152
153impl TlsConnector {
154 pub fn builder() -> Result<TlsConnectorBuilder> {
155 let inner = ssl::SslConnector::builder(ssl::SslMethod::tls())?;
156 Ok(TlsConnectorBuilder {
157 inner,
158 verify_hostname: true,
159 allow_partial: true,
160 })
161 }
162
163 pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>>
164 where
165 S: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + Sync + 'static,
166 {
167 debug!("tls connecting to: {}", domain);
168 let mut client_configuration = self
169 .inner
170 .configure()?
171 .verify_hostname(self.verify_hostname);
172
173 if self.allow_partial {
174 let params = client_configuration.param_mut();
175 params.set_flags(X509VerifyFlags::PARTIAL_CHAIN)?;
176 }
177
178 HandshakeFuture::Initial(
179 move |stream| client_configuration.connect(domain, stream),
180 AsyncToSyncWrapper::new(stream),
181 )
182 .await
183 }
184}
185
186pub struct TlsConnectorBuilder {
187 inner: ssl::SslConnectorBuilder,
188 verify_hostname: bool,
189 allow_partial: bool,
190}
191
192impl TlsConnectorBuilder {
193 pub fn with_hostname_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
194 self.verify_hostname = false;
195 Ok(self)
196 }
197
198 pub fn with_certificate_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
199 self.inner.set_verify(ssl::SslVerifyMode::NONE);
200 Ok(self)
201 }
202
203 pub fn with_certifiate_and_key_from_pem_files<P: AsRef<Path>>(
204 mut self,
205 cert_file: P,
206 key_file: P,
207 ) -> Result<TlsConnectorBuilder> {
208 self.inner
209 .set_certificate_file(cert_file, ssl::SslFiletype::PEM)?;
210 self.inner
211 .set_private_key_file(key_file, ssl::SslFiletype::PEM)?;
212 Ok(self)
213 }
214
215 pub fn with_ca_from_pem_file<P: AsRef<Path>>(
216 mut self,
217 ca_file: P,
218 ) -> Result<TlsConnectorBuilder> {
219 self.inner.set_ca_file(ca_file)?;
220 Ok(self)
221 }
222
223 pub fn add_root_certificate(mut self, cert: Certificate) -> Result<TlsConnectorBuilder> {
224 self.inner.cert_store_mut().add_cert(cert.0)?;
225 Ok(self)
226 }
227
228 pub fn with_identity(mut self, builder: certs::IdentityBuilder) -> Result<Self> {
230 let identity = builder.build().context("failed to build identity")?;
231 self.inner.set_certificate(identity.cert())?;
232 self.inner.set_private_key(identity.pkey())?;
233 for cert in identity.chain().iter().rev() {
234 self.inner.add_extra_chain_cert(cert.to_owned())?;
235 }
236 Ok(self)
237 }
238
239 pub fn build(self) -> TlsConnector {
240 TlsConnector {
241 inner: self.inner.build(),
242 verify_hostname: self.verify_hostname,
243 allow_partial: self.allow_partial,
244 }
245 }
246}
247
248#[derive(Clone)]
250pub struct TlsAnonymousConnector(TlsConnector);
251
252impl From<TlsConnector> for TlsAnonymousConnector {
253 fn from(connector: TlsConnector) -> Self {
254 Self(connector)
255 }
256}
257
258#[async_trait]
259impl TcpDomainConnector for TlsAnonymousConnector {
260 async fn connect(
261 &self,
262 domain: &str,
263 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
264 debug!("tcp connect: {}", domain);
265 let socket_opts = SocketOpts {
266 keepalive: Some(Default::default()),
267 nodelay: Some(true),
268 };
269 let tcp_stream = stream_with_opts(domain, Some(socket_opts)).await?;
270 let fd = tcp_stream.as_connection_fd();
271
272 let (write, read) = self
273 .0
274 .connect(domain, tcp_stream)
275 .await
276 .map_err(|e| {
277 IoError::new(
278 ErrorKind::ConnectionRefused,
279 format!("failed to connect: {}", e),
280 )
281 })?
282 .split_connection();
283
284 Ok((write, read, fd))
285 }
286
287 fn new_domain(&self, _domain: String) -> DomainConnector {
288 Box::new(self.clone())
289 }
290
291 fn domain(&self) -> &str {
292 "localhost"
293 }
294
295 fn clone_box(&self) -> DomainConnector {
296 Box::new(self.clone())
297 }
298}
299
300#[derive(Clone)]
301pub struct TlsDomainConnector {
302 domain: String,
303 connector: TlsConnector,
304}
305
306impl TlsDomainConnector {
307 pub fn new(connector: TlsConnector, domain: String) -> Self {
308 Self { domain, connector }
309 }
310}
311
312#[async_trait]
313impl TcpDomainConnector for TlsDomainConnector {
314 async fn connect(
315 &self,
316 addr: &str,
317 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
318 debug!("connect to tls addr: {}", addr);
319 let tcp_stream = stream(addr).await?;
320 let fd = tcp_stream.as_connection_fd();
321
322 let (write, read) = self
323 .connector
324 .connect(&self.domain, tcp_stream)
325 .await
326 .map_err(|e| {
327 IoError::new(
328 ErrorKind::ConnectionRefused,
329 format!("failed to connect: {}", e),
330 )
331 })?
332 .split_connection();
333
334 debug!("connect to tls domain: {}", self.domain);
335 Ok((write, read, fd))
336 }
337
338 fn new_domain(&self, domain: String) -> DomainConnector {
339 debug!("setting new domain: {}", domain);
340 let mut connector = self.clone();
341 connector.domain = domain;
342 Box::new(connector)
343 }
344
345 fn domain(&self) -> &str {
346 &self.domain
347 }
348
349 fn clone_box(&self) -> DomainConnector {
350 Box::new(self.clone())
351 }
352}