fluvio_future/openssl/
connector.rs

1use std::fmt;
2use std::io::Error as IoError;
3use std::io::ErrorKind;
4use std::path::Path;
5
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures_lite::io::{AsyncRead, AsyncWrite};
9use openssl::ssl;
10use openssl::x509::verify::X509VerifyFlags;
11use tracing::debug;
12
13use crate::net::{
14    tcp_stream::{stream, stream_with_opts, SocketOpts},
15    AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
16    SplitConnection, TcpDomainConnector,
17};
18
19use super::async_to_sync_wrapper::AsyncToSyncWrapper;
20use super::certificate::Certificate;
21use super::handshake::HandshakeFuture;
22use super::stream::TlsStream;
23
24// same code but without native builder
25// TODO: simplification
26pub mod certs {
27
28    use anyhow::{Context, Result};
29    use openssl::pkcs12::Pkcs12;
30    use openssl::pkey::Private;
31
32    use super::Certificate;
33    use crate::net::certs::CertBuilder;
34
35    pub type PrivateKey = openssl::pkey::PKey<Private>;
36
37    // use identity_impl::Certificate;
38    use identity_impl::Identity;
39
40    // copied from https://github.com/sfackler/rust-native-tls/blob/master/src/imp/openssl.rs
41    mod identity_impl {
42
43        use anyhow::{anyhow, Result};
44        use openssl::pkcs12::Pkcs12;
45        use openssl::pkey::{PKey, Private};
46        use openssl::x509::X509;
47
48        #[derive(Clone)]
49        pub struct Identity {
50            pkey: PKey<Private>,
51            cert: X509,
52            chain: Vec<X509>,
53        }
54
55        impl Identity {
56            pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity> {
57                let pkcs12 = Pkcs12::from_der(buf)?;
58                let parsed = pkcs12
59                    .parse2(pass)
60                    .map_err(|err| anyhow!("Couldn't read pkcs12 {err}"))?;
61                let pkey = parsed.pkey.ok_or(anyhow!("Missing private key"))?;
62                let cert = parsed.cert.ok_or(anyhow!("Missing cert"))?;
63                Ok(Identity {
64                    pkey,
65                    cert,
66                    chain: parsed.ca.into_iter().flatten().collect(),
67                })
68            }
69
70            pub fn cert(&self) -> &X509 {
71                &self.cert
72            }
73
74            pub fn pkey(&self) -> &PKey<Private> {
75                &self.pkey
76            }
77
78            pub fn chain(&self) -> &Vec<X509> {
79                &self.chain
80            }
81        }
82    }
83
84    pub struct X509PemBuilder(Vec<u8>);
85
86    impl CertBuilder for X509PemBuilder {
87        fn new(bytes: Vec<u8>) -> Self {
88            Self(bytes)
89        }
90    }
91
92    impl X509PemBuilder {
93        pub fn build(self) -> Result<Certificate> {
94            let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
95            Ok(cert)
96        }
97    }
98
99    const PASSWORD: &str = "test";
100
101    pub struct PrivateKeyBuilder(Vec<u8>);
102
103    impl CertBuilder for PrivateKeyBuilder {
104        fn new(bytes: Vec<u8>) -> Self {
105            Self(bytes)
106        }
107    }
108
109    impl PrivateKeyBuilder {
110        pub fn build(self) -> Result<PrivateKey> {
111            let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
112            Ok(key)
113        }
114    }
115
116    pub struct IdentityBuilder(Vec<u8>);
117
118    impl CertBuilder for IdentityBuilder {
119        fn new(bytes: Vec<u8>) -> Self {
120            Self(bytes)
121        }
122    }
123
124    impl IdentityBuilder {
125        /// load pk12 from x509 certs
126        pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
127            let server_key = key.build()?;
128            let server_crt = x509.build()?;
129            let p12 = Pkcs12::builder()
130                .name("")
131                .pkey(&server_key)
132                .cert(server_crt.inner())
133                .build2(PASSWORD)
134                .context("Failed to create Pkcs12")?;
135
136            let der = p12.to_der()?;
137            Ok(Self(der))
138        }
139
140        pub fn build(self) -> Result<Identity> {
141            Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
142        }
143    }
144}
145
146#[derive(Clone, Debug)]
147pub struct TlsConnector {
148    pub inner: ssl::SslConnector,
149    pub verify_hostname: bool,
150    pub allow_partial: bool,
151}
152
153impl TlsConnector {
154    pub fn builder() -> Result<TlsConnectorBuilder> {
155        let inner = ssl::SslConnector::builder(ssl::SslMethod::tls())?;
156        Ok(TlsConnectorBuilder {
157            inner,
158            verify_hostname: true,
159            allow_partial: true,
160        })
161    }
162
163    pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>>
164    where
165        S: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + Sync + 'static,
166    {
167        debug!("tls connecting to: {}", domain);
168        let mut client_configuration = self
169            .inner
170            .configure()?
171            .verify_hostname(self.verify_hostname);
172
173        if self.allow_partial {
174            let params = client_configuration.param_mut();
175            params.set_flags(X509VerifyFlags::PARTIAL_CHAIN)?;
176        }
177
178        HandshakeFuture::Initial(
179            move |stream| client_configuration.connect(domain, stream),
180            AsyncToSyncWrapper::new(stream),
181        )
182        .await
183    }
184}
185
186pub struct TlsConnectorBuilder {
187    inner: ssl::SslConnectorBuilder,
188    verify_hostname: bool,
189    allow_partial: bool,
190}
191
192impl TlsConnectorBuilder {
193    pub fn with_hostname_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
194        self.verify_hostname = false;
195        Ok(self)
196    }
197
198    pub fn with_certificate_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
199        self.inner.set_verify(ssl::SslVerifyMode::NONE);
200        Ok(self)
201    }
202
203    pub fn with_certifiate_and_key_from_pem_files<P: AsRef<Path>>(
204        mut self,
205        cert_file: P,
206        key_file: P,
207    ) -> Result<TlsConnectorBuilder> {
208        self.inner
209            .set_certificate_file(cert_file, ssl::SslFiletype::PEM)?;
210        self.inner
211            .set_private_key_file(key_file, ssl::SslFiletype::PEM)?;
212        Ok(self)
213    }
214
215    pub fn with_ca_from_pem_file<P: AsRef<Path>>(
216        mut self,
217        ca_file: P,
218    ) -> Result<TlsConnectorBuilder> {
219        self.inner.set_ca_file(ca_file)?;
220        Ok(self)
221    }
222
223    pub fn add_root_certificate(mut self, cert: Certificate) -> Result<TlsConnectorBuilder> {
224        self.inner.cert_store_mut().add_cert(cert.0)?;
225        Ok(self)
226    }
227
228    /// set identity
229    pub fn with_identity(mut self, builder: certs::IdentityBuilder) -> Result<Self> {
230        let identity = builder.build().context("failed to build identity")?;
231        self.inner.set_certificate(identity.cert())?;
232        self.inner.set_private_key(identity.pkey())?;
233        for cert in identity.chain().iter().rev() {
234            self.inner.add_extra_chain_cert(cert.to_owned())?;
235        }
236        Ok(self)
237    }
238
239    pub fn build(self) -> TlsConnector {
240        TlsConnector {
241            inner: self.inner.build(),
242            verify_hostname: self.verify_hostname,
243            allow_partial: self.allow_partial,
244        }
245    }
246}
247
248/// connect as anonymous client
249#[derive(Clone)]
250pub struct TlsAnonymousConnector(TlsConnector);
251
252impl From<TlsConnector> for TlsAnonymousConnector {
253    fn from(connector: TlsConnector) -> Self {
254        Self(connector)
255    }
256}
257
258#[async_trait]
259impl TcpDomainConnector for TlsAnonymousConnector {
260    async fn connect(
261        &self,
262        domain: &str,
263    ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
264        debug!("tcp connect: {}", domain);
265        let socket_opts = SocketOpts {
266            keepalive: Some(Default::default()),
267            nodelay: Some(true),
268        };
269        let tcp_stream = stream_with_opts(domain, Some(socket_opts)).await?;
270        let fd = tcp_stream.as_connection_fd();
271
272        let (write, read) = self
273            .0
274            .connect(domain, tcp_stream)
275            .await
276            .map_err(|e| {
277                IoError::new(
278                    ErrorKind::ConnectionRefused,
279                    format!("failed to connect: {}", e),
280                )
281            })?
282            .split_connection();
283
284        Ok((write, read, fd))
285    }
286
287    fn new_domain(&self, _domain: String) -> DomainConnector {
288        Box::new(self.clone())
289    }
290
291    fn domain(&self) -> &str {
292        "localhost"
293    }
294
295    fn clone_box(&self) -> DomainConnector {
296        Box::new(self.clone())
297    }
298}
299
300#[derive(Clone)]
301pub struct TlsDomainConnector {
302    domain: String,
303    connector: TlsConnector,
304}
305
306impl TlsDomainConnector {
307    pub fn new(connector: TlsConnector, domain: String) -> Self {
308        Self { domain, connector }
309    }
310}
311
312#[async_trait]
313impl TcpDomainConnector for TlsDomainConnector {
314    async fn connect(
315        &self,
316        addr: &str,
317    ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
318        debug!("connect to tls addr: {}", addr);
319        let tcp_stream = stream(addr).await?;
320        let fd = tcp_stream.as_connection_fd();
321
322        let (write, read) = self
323            .connector
324            .connect(&self.domain, tcp_stream)
325            .await
326            .map_err(|e| {
327                IoError::new(
328                    ErrorKind::ConnectionRefused,
329                    format!("failed to connect: {}", e),
330                )
331            })?
332            .split_connection();
333
334        debug!("connect to tls domain: {}", self.domain);
335        Ok((write, read, fd))
336    }
337
338    fn new_domain(&self, domain: String) -> DomainConnector {
339        debug!("setting new domain: {}", domain);
340        let mut connector = self.clone();
341        connector.domain = domain;
342        Box::new(connector)
343    }
344
345    fn domain(&self) -> &str {
346        &self.domain
347    }
348
349    fn clone_box(&self) -> DomainConnector {
350        Box::new(self.clone())
351    }
352}