1use std::fmt;
2use std::io::Error as IoError;
3use std::io::ErrorKind;
4use std::path::Path;
5
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures_lite::io::{AsyncRead, AsyncWrite};
9use openssl::ssl;
10use openssl::x509::verify::X509VerifyFlags;
11use tracing::debug;
12
13use crate::net::{
14 AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
15 SplitConnection, TcpDomainConnector,
16 tcp_stream::{SocketOpts, stream, stream_with_opts},
17};
18
19use super::async_to_sync_wrapper::AsyncToSyncWrapper;
20use super::certificate::Certificate;
21use super::handshake::HandshakeFuture;
22use super::stream::TlsStream;
23
24pub mod certs {
27
28 use anyhow::{Context, Result};
29 use openssl::pkcs12::Pkcs12;
30 use openssl::pkey::Private;
31
32 use super::Certificate;
33 use crate::net::certs::CertBuilder;
34
35 pub type PrivateKey = openssl::pkey::PKey<Private>;
36
37 use identity_impl::Identity;
39
40 mod identity_impl {
42
43 use anyhow::{Result, anyhow};
44 use openssl::pkcs12::Pkcs12;
45 use openssl::pkey::{PKey, Private};
46 use openssl::x509::X509;
47
48 #[derive(Clone)]
49 pub struct Identity {
50 pkey: PKey<Private>,
51 cert: X509,
52 chain: Vec<X509>,
53 }
54
55 impl Identity {
56 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity> {
57 let pkcs12 = Pkcs12::from_der(buf)?;
58
59 let parsed = pkcs12
60 .parse2(pass)
61 .map_err(|err| anyhow!("Couldn't read pkcs12 {err}"))?;
62 let pkey = parsed.pkey.ok_or(anyhow!("Missing private key"))?;
63 let cert = parsed.cert.ok_or(anyhow!("Missing cert"))?;
64 Ok(Identity {
65 pkey,
66 cert,
67 chain: parsed.ca.into_iter().flatten().collect(),
68 })
69 }
70
71 pub fn cert(&self) -> &X509 {
72 &self.cert
73 }
74
75 pub fn pkey(&self) -> &PKey<Private> {
76 &self.pkey
77 }
78
79 pub fn chain(&self) -> &Vec<X509> {
80 &self.chain
81 }
82 }
83 }
84
85 pub struct X509PemBuilder(Vec<u8>);
86
87 impl CertBuilder for X509PemBuilder {
88 fn new(bytes: Vec<u8>) -> Self {
89 Self(bytes)
90 }
91 }
92
93 impl X509PemBuilder {
94 pub fn build(self) -> Result<Certificate> {
95 let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
96 Ok(cert)
97 }
98 }
99
100 const PASSWORD: &str = "test";
101
102 pub struct PrivateKeyBuilder(Vec<u8>);
103
104 impl CertBuilder for PrivateKeyBuilder {
105 fn new(bytes: Vec<u8>) -> Self {
106 Self(bytes)
107 }
108 }
109
110 impl PrivateKeyBuilder {
111 pub fn build(self) -> Result<PrivateKey> {
112 let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
113 Ok(key)
114 }
115 }
116
117 pub struct IdentityBuilder(Vec<u8>);
118
119 impl CertBuilder for IdentityBuilder {
120 fn new(bytes: Vec<u8>) -> Self {
121 Self(bytes)
122 }
123 }
124
125 impl IdentityBuilder {
126 pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
128 let server_key = key.build()?;
129 let server_crt = x509.build()?;
130 let p12 = Pkcs12::builder()
131 .name("")
132 .pkey(&server_key)
133 .cert(server_crt.inner())
134 .build2(PASSWORD)
135 .context("Failed to create Pkcs12")?;
136
137 let der = p12.to_der()?;
138 Ok(Self(der))
139 }
140
141 pub fn build(self) -> Result<Identity> {
142 Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
143 }
144 }
145}
146
147#[derive(Clone, Debug)]
148pub struct TlsConnector {
149 pub inner: ssl::SslConnector,
150 pub verify_hostname: bool,
151 pub allow_partial: bool,
152}
153
154impl TlsConnector {
155 pub fn builder() -> Result<TlsConnectorBuilder> {
156 let inner = ssl::SslConnector::builder(ssl::SslMethod::tls())?;
157 Ok(TlsConnectorBuilder {
158 inner,
159 verify_hostname: true,
160 allow_partial: true,
161 })
162 }
163
164 pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>>
165 where
166 S: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + Sync + 'static,
167 {
168 debug!("tls connecting to: {}", domain);
169 let mut client_configuration = self
170 .inner
171 .configure()?
172 .verify_hostname(self.verify_hostname);
173
174 if self.allow_partial {
175 let params = client_configuration.param_mut();
176 params.set_flags(X509VerifyFlags::PARTIAL_CHAIN)?;
177 }
178
179 HandshakeFuture::Initial(
180 move |stream| client_configuration.connect(domain, stream),
181 AsyncToSyncWrapper::new(stream),
182 )
183 .await
184 }
185}
186
187pub struct TlsConnectorBuilder {
188 inner: ssl::SslConnectorBuilder,
189 verify_hostname: bool,
190 allow_partial: bool,
191}
192
193impl TlsConnectorBuilder {
194 pub fn with_hostname_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
195 self.verify_hostname = false;
196 Ok(self)
197 }
198
199 pub fn with_certificate_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
200 self.inner.set_verify(ssl::SslVerifyMode::NONE);
201 Ok(self)
202 }
203
204 pub fn with_certifiate_and_key_from_pem_files<P: AsRef<Path>>(
205 mut self,
206 cert_file: P,
207 key_file: P,
208 ) -> Result<TlsConnectorBuilder> {
209 self.inner
210 .set_certificate_file(cert_file, ssl::SslFiletype::PEM)?;
211 self.inner
212 .set_private_key_file(key_file, ssl::SslFiletype::PEM)?;
213 Ok(self)
214 }
215
216 pub fn with_ca_from_pem_file<P: AsRef<Path>>(
217 mut self,
218 ca_file: P,
219 ) -> Result<TlsConnectorBuilder> {
220 self.inner.set_ca_file(ca_file)?;
221 Ok(self)
222 }
223
224 pub fn add_root_certificate(mut self, cert: Certificate) -> Result<TlsConnectorBuilder> {
225 self.inner.cert_store_mut().add_cert(cert.0)?;
226 Ok(self)
227 }
228
229 pub fn with_identity(mut self, builder: certs::IdentityBuilder) -> Result<Self> {
231 let identity = builder.build().context("failed to build identity")?;
232 self.inner.set_certificate(identity.cert())?;
233 self.inner.set_private_key(identity.pkey())?;
234 for cert in identity.chain().iter().rev() {
235 self.inner.add_extra_chain_cert(cert.to_owned())?;
236 }
237 Ok(self)
238 }
239
240 pub fn build(self) -> TlsConnector {
241 TlsConnector {
242 inner: self.inner.build(),
243 verify_hostname: self.verify_hostname,
244 allow_partial: self.allow_partial,
245 }
246 }
247}
248
249#[derive(Clone)]
251pub struct TlsAnonymousConnector(TlsConnector);
252
253impl From<TlsConnector> for TlsAnonymousConnector {
254 fn from(connector: TlsConnector) -> Self {
255 Self(connector)
256 }
257}
258
259#[async_trait]
260impl TcpDomainConnector for TlsAnonymousConnector {
261 async fn connect(
262 &self,
263 domain: &str,
264 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
265 debug!("tcp connect: {}", domain);
266 let socket_opts = SocketOpts {
267 keepalive: Some(Default::default()),
268 nodelay: Some(true),
269 };
270 let tcp_stream = stream_with_opts(domain, Some(socket_opts)).await?;
271 let fd = tcp_stream.as_connection_fd();
272
273 let (write, read) = self
274 .0
275 .connect(domain, tcp_stream)
276 .await
277 .map_err(|e| {
278 IoError::new(
279 ErrorKind::ConnectionRefused,
280 format!("failed to connect: {e}"),
281 )
282 })?
283 .split_connection();
284
285 Ok((write, read, fd))
286 }
287
288 fn new_domain(&self, _domain: String) -> DomainConnector {
289 Box::new(self.clone())
290 }
291
292 fn domain(&self) -> &str {
293 "localhost"
294 }
295
296 fn clone_box(&self) -> DomainConnector {
297 Box::new(self.clone())
298 }
299}
300
301#[derive(Clone)]
302pub struct TlsDomainConnector {
303 domain: String,
304 connector: TlsConnector,
305}
306
307impl TlsDomainConnector {
308 pub fn new(connector: TlsConnector, domain: String) -> Self {
309 Self { domain, connector }
310 }
311}
312
313#[async_trait]
314impl TcpDomainConnector for TlsDomainConnector {
315 async fn connect(
316 &self,
317 addr: &str,
318 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
319 debug!("connect to tls addr: {}", addr);
320 let tcp_stream = stream(addr).await?;
321 let fd = tcp_stream.as_connection_fd();
322
323 let (write, read) = self
324 .connector
325 .connect(&self.domain, tcp_stream)
326 .await
327 .map_err(|e| {
328 IoError::new(
329 ErrorKind::ConnectionRefused,
330 format!("failed to connect: {e}"),
331 )
332 })?
333 .split_connection();
334
335 debug!("connect to tls domain: {}", self.domain);
336 Ok((write, read, fd))
337 }
338
339 fn new_domain(&self, domain: String) -> DomainConnector {
340 debug!("setting new domain: {}", domain);
341 let mut connector = self.clone();
342 connector.domain = domain;
343 Box::new(connector)
344 }
345
346 fn domain(&self) -> &str {
347 &self.domain
348 }
349
350 fn clone_box(&self) -> DomainConnector {
351 Box::new(self.clone())
352 }
353}