fluvio_future/openssl/
connector.rs

1use std::fmt;
2use std::io::Error as IoError;
3use std::io::ErrorKind;
4use std::path::Path;
5
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use futures_lite::io::{AsyncRead, AsyncWrite};
9use openssl::ssl;
10use openssl::x509::verify::X509VerifyFlags;
11use tracing::debug;
12
13use crate::net::{
14    AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
15    SplitConnection, TcpDomainConnector,
16    tcp_stream::{SocketOpts, stream, stream_with_opts},
17};
18
19use super::async_to_sync_wrapper::AsyncToSyncWrapper;
20use super::certificate::Certificate;
21use super::handshake::HandshakeFuture;
22use super::stream::TlsStream;
23
24// same code but without native builder
25// TODO: simplification
26pub mod certs {
27
28    use anyhow::{Context, Result};
29    use openssl::pkcs12::Pkcs12;
30    use openssl::pkey::Private;
31
32    use super::Certificate;
33    use crate::net::certs::CertBuilder;
34
35    pub type PrivateKey = openssl::pkey::PKey<Private>;
36
37    // use identity_impl::Certificate;
38    use identity_impl::Identity;
39
40    // copied from https://github.com/sfackler/rust-native-tls/blob/master/src/imp/openssl.rs
41    mod identity_impl {
42
43        use anyhow::{Result, anyhow};
44        use openssl::pkcs12::Pkcs12;
45        use openssl::pkey::{PKey, Private};
46        use openssl::x509::X509;
47
48        #[derive(Clone)]
49        pub struct Identity {
50            pkey: PKey<Private>,
51            cert: X509,
52            chain: Vec<X509>,
53        }
54
55        impl Identity {
56            pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity> {
57                let pkcs12 = Pkcs12::from_der(buf)?;
58
59                let parsed = pkcs12
60                    .parse2(pass)
61                    .map_err(|err| anyhow!("Couldn't read pkcs12 {err}"))?;
62                let pkey = parsed.pkey.ok_or(anyhow!("Missing private key"))?;
63                let cert = parsed.cert.ok_or(anyhow!("Missing cert"))?;
64                Ok(Identity {
65                    pkey,
66                    cert,
67                    chain: parsed.ca.into_iter().flatten().collect(),
68                })
69            }
70
71            pub fn cert(&self) -> &X509 {
72                &self.cert
73            }
74
75            pub fn pkey(&self) -> &PKey<Private> {
76                &self.pkey
77            }
78
79            pub fn chain(&self) -> &Vec<X509> {
80                &self.chain
81            }
82        }
83    }
84
85    pub struct X509PemBuilder(Vec<u8>);
86
87    impl CertBuilder for X509PemBuilder {
88        fn new(bytes: Vec<u8>) -> Self {
89            Self(bytes)
90        }
91    }
92
93    impl X509PemBuilder {
94        pub fn build(self) -> Result<Certificate> {
95            let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
96            Ok(cert)
97        }
98    }
99
100    const PASSWORD: &str = "test";
101
102    pub struct PrivateKeyBuilder(Vec<u8>);
103
104    impl CertBuilder for PrivateKeyBuilder {
105        fn new(bytes: Vec<u8>) -> Self {
106            Self(bytes)
107        }
108    }
109
110    impl PrivateKeyBuilder {
111        pub fn build(self) -> Result<PrivateKey> {
112            let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
113            Ok(key)
114        }
115    }
116
117    pub struct IdentityBuilder(Vec<u8>);
118
119    impl CertBuilder for IdentityBuilder {
120        fn new(bytes: Vec<u8>) -> Self {
121            Self(bytes)
122        }
123    }
124
125    impl IdentityBuilder {
126        /// load pk12 from x509 certs
127        pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
128            let server_key = key.build()?;
129            let server_crt = x509.build()?;
130            let p12 = Pkcs12::builder()
131                .name("")
132                .pkey(&server_key)
133                .cert(server_crt.inner())
134                .build2(PASSWORD)
135                .context("Failed to create Pkcs12")?;
136
137            let der = p12.to_der()?;
138            Ok(Self(der))
139        }
140
141        pub fn build(self) -> Result<Identity> {
142            Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
143        }
144    }
145}
146
147#[derive(Clone, Debug)]
148pub struct TlsConnector {
149    pub inner: ssl::SslConnector,
150    pub verify_hostname: bool,
151    pub allow_partial: bool,
152}
153
154impl TlsConnector {
155    pub fn builder() -> Result<TlsConnectorBuilder> {
156        let inner = ssl::SslConnector::builder(ssl::SslMethod::tls())?;
157        Ok(TlsConnectorBuilder {
158            inner,
159            verify_hostname: true,
160            allow_partial: true,
161        })
162    }
163
164    pub async fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>>
165    where
166        S: AsyncRead + AsyncWrite + fmt::Debug + Unpin + Send + Sync + 'static,
167    {
168        debug!("tls connecting to: {}", domain);
169        let mut client_configuration = self
170            .inner
171            .configure()?
172            .verify_hostname(self.verify_hostname);
173
174        if self.allow_partial {
175            let params = client_configuration.param_mut();
176            params.set_flags(X509VerifyFlags::PARTIAL_CHAIN)?;
177        }
178
179        HandshakeFuture::Initial(
180            move |stream| client_configuration.connect(domain, stream),
181            AsyncToSyncWrapper::new(stream),
182        )
183        .await
184    }
185}
186
187pub struct TlsConnectorBuilder {
188    inner: ssl::SslConnectorBuilder,
189    verify_hostname: bool,
190    allow_partial: bool,
191}
192
193impl TlsConnectorBuilder {
194    pub fn with_hostname_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
195        self.verify_hostname = false;
196        Ok(self)
197    }
198
199    pub fn with_certificate_verification_disabled(mut self) -> Result<TlsConnectorBuilder> {
200        self.inner.set_verify(ssl::SslVerifyMode::NONE);
201        Ok(self)
202    }
203
204    pub fn with_certifiate_and_key_from_pem_files<P: AsRef<Path>>(
205        mut self,
206        cert_file: P,
207        key_file: P,
208    ) -> Result<TlsConnectorBuilder> {
209        self.inner
210            .set_certificate_file(cert_file, ssl::SslFiletype::PEM)?;
211        self.inner
212            .set_private_key_file(key_file, ssl::SslFiletype::PEM)?;
213        Ok(self)
214    }
215
216    pub fn with_ca_from_pem_file<P: AsRef<Path>>(
217        mut self,
218        ca_file: P,
219    ) -> Result<TlsConnectorBuilder> {
220        self.inner.set_ca_file(ca_file)?;
221        Ok(self)
222    }
223
224    pub fn add_root_certificate(mut self, cert: Certificate) -> Result<TlsConnectorBuilder> {
225        self.inner.cert_store_mut().add_cert(cert.0)?;
226        Ok(self)
227    }
228
229    /// set identity
230    pub fn with_identity(mut self, builder: certs::IdentityBuilder) -> Result<Self> {
231        let identity = builder.build().context("failed to build identity")?;
232        self.inner.set_certificate(identity.cert())?;
233        self.inner.set_private_key(identity.pkey())?;
234        for cert in identity.chain().iter().rev() {
235            self.inner.add_extra_chain_cert(cert.to_owned())?;
236        }
237        Ok(self)
238    }
239
240    pub fn build(self) -> TlsConnector {
241        TlsConnector {
242            inner: self.inner.build(),
243            verify_hostname: self.verify_hostname,
244            allow_partial: self.allow_partial,
245        }
246    }
247}
248
249/// connect as anonymous client
250#[derive(Clone)]
251pub struct TlsAnonymousConnector(TlsConnector);
252
253impl From<TlsConnector> for TlsAnonymousConnector {
254    fn from(connector: TlsConnector) -> Self {
255        Self(connector)
256    }
257}
258
259#[async_trait]
260impl TcpDomainConnector for TlsAnonymousConnector {
261    async fn connect(
262        &self,
263        domain: &str,
264    ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
265        debug!("tcp connect: {}", domain);
266        let socket_opts = SocketOpts {
267            keepalive: Some(Default::default()),
268            nodelay: Some(true),
269        };
270        let tcp_stream = stream_with_opts(domain, Some(socket_opts)).await?;
271        let fd = tcp_stream.as_connection_fd();
272
273        let (write, read) = self
274            .0
275            .connect(domain, tcp_stream)
276            .await
277            .map_err(|e| {
278                IoError::new(
279                    ErrorKind::ConnectionRefused,
280                    format!("failed to connect: {e}"),
281                )
282            })?
283            .split_connection();
284
285        Ok((write, read, fd))
286    }
287
288    fn new_domain(&self, _domain: String) -> DomainConnector {
289        Box::new(self.clone())
290    }
291
292    fn domain(&self) -> &str {
293        "localhost"
294    }
295
296    fn clone_box(&self) -> DomainConnector {
297        Box::new(self.clone())
298    }
299}
300
301#[derive(Clone)]
302pub struct TlsDomainConnector {
303    domain: String,
304    connector: TlsConnector,
305}
306
307impl TlsDomainConnector {
308    pub fn new(connector: TlsConnector, domain: String) -> Self {
309        Self { domain, connector }
310    }
311}
312
313#[async_trait]
314impl TcpDomainConnector for TlsDomainConnector {
315    async fn connect(
316        &self,
317        addr: &str,
318    ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
319        debug!("connect to tls addr: {}", addr);
320        let tcp_stream = stream(addr).await?;
321        let fd = tcp_stream.as_connection_fd();
322
323        let (write, read) = self
324            .connector
325            .connect(&self.domain, tcp_stream)
326            .await
327            .map_err(|e| {
328                IoError::new(
329                    ErrorKind::ConnectionRefused,
330                    format!("failed to connect: {e}"),
331                )
332            })?
333            .split_connection();
334
335        debug!("connect to tls domain: {}", self.domain);
336        Ok((write, read, fd))
337    }
338
339    fn new_domain(&self, domain: String) -> DomainConnector {
340        debug!("setting new domain: {}", domain);
341        let mut connector = self.clone();
342        connector.domain = domain;
343        Box::new(connector)
344    }
345
346    fn domain(&self) -> &str {
347        &self.domain
348    }
349
350    fn clone_box(&self) -> DomainConnector {
351        Box::new(self.clone())
352    }
353}