fluvio_future/
native_tls.rs

1use crate::net::TcpStream;
2
3pub use async_native_tls::TlsAcceptor;
4pub use async_native_tls::TlsConnector;
5pub use async_native_tls::TlsStream;
6
7// server both cliennt and server and same but use same pattern as rustls
8pub type DefaultServerTlsStream = TlsStream<TcpStream>;
9pub type DefaultClientTlsStream = TlsStream<TcpStream>;
10
11pub use connector::*;
12
13pub use crate::net::certs::CertBuilder;
14
15mod split {
16
17    use async_net::TcpStream;
18    use futures_util::AsyncReadExt;
19
20    use super::*;
21    use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
22
23    impl SplitConnection for TlsStream<TcpStream> {
24        fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
25            let (read, write) = self.split();
26            (Box::new(write), Box::new(read))
27        }
28    }
29}
30
31mod connector {
32    use std::io::Error as IoError;
33    use std::io::ErrorKind;
34    use std::sync::Arc;
35
36    use async_trait::async_trait;
37    use tracing::debug;
38
39    use crate::net::{
40        AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
41        SplitConnection, TcpDomainConnector,
42        tcp_stream::{SocketOpts, stream, stream_with_opts},
43    };
44
45    use super::*;
46
47    /// connect as anonymous client
48    #[derive(Clone)]
49    pub struct TlsAnonymousConnector(Arc<TlsConnector>);
50
51    impl From<TlsConnector> for TlsAnonymousConnector {
52        fn from(connector: TlsConnector) -> Self {
53            Self(Arc::new(connector))
54        }
55    }
56
57    #[async_trait]
58    impl TcpDomainConnector for TlsAnonymousConnector {
59        async fn connect(
60            &self,
61            domain: &str,
62        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
63            let tcp_stream = stream(domain).await?;
64            let fd = tcp_stream.as_connection_fd();
65            let (write, read) = self
66                .0
67                .connect(domain, tcp_stream)
68                .await
69                .map_err(|e| {
70                    IoError::new(
71                        ErrorKind::ConnectionRefused,
72                        format!("failed to connect: {e}"),
73                    )
74                })?
75                .split_connection();
76            Ok((write, read, fd))
77        }
78
79        fn new_domain(&self, _domain: String) -> DomainConnector {
80            Box::new(self.clone())
81        }
82
83        fn domain(&self) -> &str {
84            "localhost"
85        }
86
87        fn clone_box(&self) -> DomainConnector {
88            Box::new(self.clone())
89        }
90    }
91
92    /// Connect to TLS
93    #[derive(Clone)]
94    pub struct TlsDomainConnector {
95        domain: String,
96        connector: Arc<TlsConnector>,
97    }
98
99    impl TlsDomainConnector {
100        pub fn new(connector: TlsConnector, domain: String) -> Self {
101            Self {
102                domain,
103                connector: Arc::new(connector),
104            }
105        }
106
107        pub fn domain(&self) -> &str {
108            &self.domain
109        }
110
111        pub fn connector(&self) -> &TlsConnector {
112            &self.connector
113        }
114    }
115
116    #[async_trait]
117    impl TcpDomainConnector for TlsDomainConnector {
118        async fn connect(
119            &self,
120            addr: &str,
121        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
122            debug!("connect to tls addr: {}", addr);
123            let socket_opts = SocketOpts {
124                keepalive: Some(Default::default()),
125                nodelay: Some(true),
126            };
127            let tcp_stream = stream_with_opts(addr, Some(socket_opts)).await?;
128            let fd = tcp_stream.as_connection_fd();
129
130            debug!("connect to tls domain: {}", self.domain);
131            let (write, read) = self
132                .connector
133                .connect(&self.domain, tcp_stream)
134                .await
135                .map_err(|e| {
136                    IoError::new(
137                        ErrorKind::ConnectionRefused,
138                        format!("failed to connect: {e}"),
139                    )
140                })?
141                .split_connection();
142            Ok((write, read, fd))
143        }
144
145        fn new_domain(&self, domain: String) -> DomainConnector {
146            let mut connector = self.clone();
147            connector.domain = domain;
148            Box::new(connector)
149        }
150
151        fn domain(&self) -> &str {
152            &self.domain
153        }
154
155        fn clone_box(&self) -> DomainConnector {
156            Box::new(self.clone())
157        }
158    }
159}
160
161pub use cert::*;
162
163mod cert {
164    use crate::net::certs::CertBuilder;
165    use anyhow::{Context, Result};
166    use native_tls::Certificate as NativeCertificate;
167    use native_tls::Identity;
168    use openssl::pkcs12::Pkcs12;
169    use openssl::pkey::Private;
170
171    pub type Certificate = openssl::x509::X509;
172    pub type PrivateKey = openssl::pkey::PKey<Private>;
173
174    pub struct X509PemBuilder(Vec<u8>);
175
176    impl CertBuilder for X509PemBuilder {
177        fn new(bytes: Vec<u8>) -> Self {
178            Self(bytes)
179        }
180    }
181
182    impl X509PemBuilder {
183        pub fn build(self) -> Result<Certificate> {
184            let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
185            Ok(cert)
186        }
187
188        pub fn build_native(self) -> Result<NativeCertificate> {
189            NativeCertificate::from_pem(&self.0).context("invalid pem file")
190        }
191    }
192
193    pub struct PrivateKeyBuilder(Vec<u8>);
194
195    impl CertBuilder for PrivateKeyBuilder {
196        fn new(bytes: Vec<u8>) -> Self {
197            Self(bytes)
198        }
199    }
200
201    impl PrivateKeyBuilder {
202        pub fn build(self) -> Result<PrivateKey> {
203            let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
204            Ok(key)
205        }
206    }
207
208    const PASSWORD: &str = "test";
209
210    pub struct IdentityBuilder(Vec<u8>);
211
212    impl CertBuilder for IdentityBuilder {
213        fn new(bytes: Vec<u8>) -> Self {
214            Self(bytes)
215        }
216    }
217
218    impl IdentityBuilder {
219        /// load pk12 from x509 certs
220        pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
221            let server_key = key.build()?;
222            let server_crt = x509.build()?;
223            let p12 = Pkcs12::builder()
224                .name("")
225                .pkey(&server_key)
226                .cert(&server_crt)
227                .build2(PASSWORD)
228                .context("Failed to create Pkcs12")?;
229
230            let der = p12.to_der()?;
231            Ok(Self(der))
232        }
233
234        pub fn build(self) -> Result<Identity> {
235            Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
236        }
237    }
238}
239
240pub use builder::*;
241
242mod builder {
243    use anyhow::{Context, Result};
244    use native_tls::Identity;
245    use native_tls::TlsAcceptor as NativeTlsAcceptor;
246
247    use super::IdentityBuilder;
248    use super::TlsAcceptor;
249    use super::TlsConnector;
250    use super::X509PemBuilder;
251
252    pub struct ConnectorBuilder(TlsConnector);
253
254    impl ConnectorBuilder {
255        pub fn identity(builder: IdentityBuilder) -> Result<Self> {
256            let identity = builder.build()?;
257            let connector = TlsConnector::new().identity(identity);
258            Ok(Self(connector))
259        }
260
261        pub fn anonymous() -> Self {
262            let connector = TlsConnector::new()
263                .danger_accept_invalid_certs(true)
264                .danger_accept_invalid_hostnames(true);
265            Self(connector)
266        }
267
268        pub fn no_cert_verification(self) -> Self {
269            let connector = self.0.danger_accept_invalid_certs(true);
270            Self(connector)
271        }
272
273        pub fn danger_accept_invalid_hostnames(self) -> Self {
274            let connector = self.0.danger_accept_invalid_hostnames(true);
275            Self(connector)
276        }
277
278        pub fn use_sni(self, use_sni: bool) -> Self {
279            let connector = self.0.use_sni(use_sni);
280            Self(connector)
281        }
282
283        pub fn add_root_certificate(self, builder: X509PemBuilder) -> Result<Self> {
284            let certificate = builder.build_native()?;
285            let connector = self.0.add_root_certificate(certificate);
286            Ok(Self(connector))
287        }
288
289        pub fn build(self) -> TlsConnector {
290            self.0
291        }
292    }
293
294    pub struct AcceptorBuilder(Identity);
295
296    impl AcceptorBuilder {
297        pub fn identity(builder: IdentityBuilder) -> Result<Self> {
298            let identity = builder.build()?;
299            Ok(Self(identity))
300        }
301
302        pub fn build(self) -> Result<TlsAcceptor> {
303            let acceptor = NativeTlsAcceptor::new(self.0).context("invalid cert")?;
304            Ok(acceptor.into())
305        }
306    }
307}
308
#[cfg(test)]
mod test {

    use std::net::SocketAddr;
    use std::time;

    use anyhow::Result;
    use async_native_tls::TlsAcceptor;
    use async_native_tls::TlsConnector;
    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use crate::net::TcpListener;
    use crate::net::certs::CertBuilder;
    use crate::net::tcp_stream::stream;
    use crate::test_async;
    use crate::timer::sleep;

    use super::{
        AcceptorBuilder, ConnectorBuilder, IdentityBuilder, PrivateKeyBuilder, X509PemBuilder,
    };

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    const SERVER_IDENTITY: &str = "certs/test-certs/server.pfx";
    const CLIENT_IDENTITY: &str = "certs/test-certs/client.pfx";
    const X509_SERVER: &str = "certs/test-certs/server.crt";
    const X509_SERVER_KEY: &str = "certs/test-certs/server.key";
    const X509_CLIENT: &str = "certs/test-certs/client.crt";
    const X509_CLIENT_KEY: &str = "certs/test-certs/client.key";

    /// First frame the server sends (and the client expects) on every iteration.
    const FIRST_FRAME: [u8; 3] = [0x05, 0x0a, 0x63];
    /// Second frame the server sends (and the client expects) on every iteration.
    const SECOND_FRAME: [u8; 2] = [0x20, 0x11];

    /// Wrap an owned byte vector as a `Bytes` frame for `BytesCodec`.
    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        Bytes::from(bytes)
    }

    #[test_async]
    async fn test_native_tls_pk12() -> Result<()> {
        const PK12_PORT: u16 = 9900;

        let acceptor = AcceptorBuilder::identity(
            IdentityBuilder::from_path(SERVER_IDENTITY).expect("identity"),
        )
        .expect("identity:")
        .build()
        .expect("acceptor");

        let connector = ConnectorBuilder::identity(IdentityBuilder::from_path(CLIENT_IDENTITY)?)
            .expect("connector")
            .danger_accept_invalid_hostnames()
            .no_cert_verification()
            .build();

        test_tls(PK12_PORT, acceptor, connector)
            .await
            .expect("no client cert test failed");

        let acceptor = AcceptorBuilder::identity(
            IdentityBuilder::from_path(SERVER_IDENTITY).expect("identity"),
        )
        .expect("identity:")
        .build()
        .expect("acceptor");

        let connector = ConnectorBuilder::identity(IdentityBuilder::from_path(CLIENT_IDENTITY)?)
            .expect("connector")
            .no_cert_verification()
            .build();

        test_tls(PK12_PORT, acceptor, connector)
            .await
            .expect("client cert test fail");

        Ok(())
    }

    #[test_async]
    #[cfg(not(windows))]
    async fn test_native_tls_x509() -> Result<()> {
        const X509_PORT: u16 = 9910;

        let acceptor = AcceptorBuilder::identity(
            IdentityBuilder::from_x509(
                X509PemBuilder::from_path(X509_SERVER).expect("read"),
                PrivateKeyBuilder::from_path(X509_SERVER_KEY).expect("file"),
            )
            .expect("identity"),
        )
        .expect("identity:")
        .build()
        .expect("acceptor");

        let connector = ConnectorBuilder::identity(
            IdentityBuilder::from_x509(
                X509PemBuilder::from_path(X509_CLIENT).expect("read"),
                PrivateKeyBuilder::from_path(X509_CLIENT_KEY).expect("read"),
            )
            .expect("509"),
        )
        .expect("connector")
        .danger_accept_invalid_hostnames()
        .no_cert_verification()
        .build();

        test_tls(X509_PORT, acceptor, connector)
            .await
            .expect("no client cert test failed");

        let acceptor = AcceptorBuilder::identity(
            IdentityBuilder::from_x509(
                X509PemBuilder::from_path(X509_SERVER).expect("read"),
                PrivateKeyBuilder::from_path(X509_SERVER_KEY).expect("file"),
            )
            .expect("identity"),
        )
        .expect("identity:")
        .build()
        .expect("acceptor");

        let connector = ConnectorBuilder::identity(
            IdentityBuilder::from_x509(
                X509PemBuilder::from_path(X509_CLIENT).expect("read"),
                PrivateKeyBuilder::from_path(X509_CLIENT_KEY).expect("read"),
            )
            .expect("509"),
        )
        .expect("connector")
        .add_root_certificate(X509PemBuilder::from_path(CA_PATH).expect("cert"))
        .expect("root")
        .no_cert_verification() // for mac
        .build();

        test_tls(X509_PORT, acceptor, connector)
            .await
            .expect("root certificate test failed");

        Ok(())
    }

    /// Run a TLS server and client concurrently on `127.0.0.1:port`.
    ///
    /// The server accepts one connection with `acceptor` and sends
    /// `TEST_ITERATION` pairs of frames. The client connects with `connector`,
    /// using TLS name `"localhost"`, and checks every frame it receives.
    ///
    /// # Errors
    /// Returns the error of whichever side fails; most failures panic via
    /// `expect`/`assert_eq!` instead.
    async fn test_tls(port: u16, acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        const TEST_ITERATION: u16 = 2;

        let addr = format!("127.0.0.1:{port}")
            .parse::<SocketAddr>()
            .expect("parse");

        let server_ft = async {
            debug!("server: binding");
            let listener = TcpListener::bind(&addr).await.expect("listener failed");
            debug!("server: successfully binding. waiting for incoming");

            let mut incoming = listener.incoming();
            let incoming_stream = incoming.next().await.expect("incoming");

            debug!("server: got connection from client");
            let tcp_stream = incoming_stream.expect("no stream");

            let tls_stream = acceptor.accept(tcp_stream).await.unwrap();
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for _ in 0..TEST_ITERATION {
                debug!("server: sending values to client");
                framed
                    .send(to_bytes(FIRST_FRAME.to_vec()))
                    .await
                    .expect("send failed");
                sleep(time::Duration::from_micros(1)).await;
                debug!("server: sending 2nd value to client");
                framed
                    .send(to_bytes(SECOND_FRAME.to_vec()))
                    .await
                    .expect("2nd send failed");
            }

            // Keep the connection open for a second so it is not torn down
            // before the client has read every frame.
            sleep(time::Duration::from_secs(1)).await;

            Ok(()) as Result<()>
        };

        let client_ft = async {
            debug!("client: sleep to give server chance to come up");
            sleep(time::Duration::from_millis(100)).await;
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost", tcp_stream)
                .await
                .expect("tls failed");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..TEST_ITERATION {
                let value = framed.next().await.expect("frame");
                debug!("{} client :received first value from server", i);
                let first = value.expect("invalid value");
                assert_eq!(&first[..], &FIRST_FRAME[..]);

                let value2 = framed.next().await.expect("frame");
                debug!("client: received 2nd value from server");
                let second = value2.expect("packet decoding works");
                assert_eq!(&second[..], &SECOND_FRAME[..]);
            }

            Ok(()) as Result<()>
        };

        // Surface errors from both halves instead of silently discarding them.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}