fluvio_future/
rust_tls.rs

1use crate::net::TcpStream;
2
3pub use futures_rustls::TlsAcceptor;
4pub use futures_rustls::TlsConnector;
5pub use futures_rustls::client::TlsStream as ClientTlsStream;
6pub use futures_rustls::server::TlsStream as ServerTlsStream;
7
8pub type DefaultServerTlsStream = ServerTlsStream<TcpStream>;
9pub type DefaultClientTlsStream = ClientTlsStream<TcpStream>;
10
11pub use builder::*;
12pub use cert::*;
13pub use connector::*;
14
15mod split {
16
17    use futures_util::AsyncReadExt;
18
19    use super::*;
20    use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
21
22    impl SplitConnection for DefaultClientTlsStream {
23        fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
24            let (read, write) = self.split();
25            (Box::new(write), Box::new(read))
26        }
27    }
28
29    impl SplitConnection for DefaultServerTlsStream {
30        fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
31            let (read, write) = self.split();
32            (Box::new(write), Box::new(read))
33        }
34    }
35}
36
37mod cert {
38    use std::fs::File;
39    use std::io::BufRead;
40    use std::io::BufReader;
41    use std::path::Path;
42
43    use anyhow::{Context, Result, anyhow};
44    use futures_rustls::rustls::RootCertStore;
45    use futures_rustls::rustls::pki_types::CertificateDer;
46    use futures_rustls::rustls::pki_types::PrivateKeyDer;
47    use rustls_pemfile::certs;
48    use rustls_pemfile::pkcs8_private_keys;
49
50    pub fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
51        load_certs_from_reader(&mut BufReader::new(File::open(path)?))
52    }
53
54    pub fn load_certs_from_reader(rd: &mut dyn BufRead) -> Result<Vec<CertificateDer<'static>>> {
55        certs(rd).map(|r| r.context("invalid cert")).collect()
56    }
57
58    /// Load the passed keys file
59    pub fn load_keys<P: AsRef<Path>>(path: P) -> Result<Vec<PrivateKeyDer<'static>>> {
60        load_keys_from_reader(&mut BufReader::new(File::open(path)?))
61    }
62
63    pub fn load_keys_from_reader(rd: &mut dyn BufRead) -> Result<Vec<PrivateKeyDer<'static>>> {
64        pkcs8_private_keys(rd)
65            .map(|r| r.map(|p| p.into()).context("invalid key"))
66            .collect()
67    }
68
69    pub(crate) fn load_first_key<P: AsRef<Path>>(path: P) -> Result<PrivateKeyDer<'static>> {
70        load_first_key_from_reader(&mut BufReader::new(File::open(path)?))
71    }
72
73    pub(crate) fn load_first_key_from_reader(
74        rd: &mut dyn BufRead,
75    ) -> Result<PrivateKeyDer<'static>> {
76        let mut keys = load_keys_from_reader(rd)?;
77
78        if keys.is_empty() {
79            Err(anyhow!("no keys found"))
80        } else {
81            Ok(keys.remove(0))
82        }
83    }
84
85    pub fn load_root_ca<P: AsRef<Path>>(path: P) -> Result<RootCertStore> {
86        let certs = load_certs(path).map_err(|err| err.context("invalid ca crt"))?;
87
88        let mut root_store = RootCertStore::empty();
89
90        for cert in certs {
91            root_store.add(cert).context("invalid ca crt")?;
92        }
93
94        Ok(root_store)
95    }
96}
97
98mod connector {
99    use std::io::Error as IoError;
100    use std::io::ErrorKind;
101
102    use async_trait::async_trait;
103    use futures_rustls::rustls::pki_types::ServerName;
104    use tracing::debug;
105
106    use crate::net::{
107        AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
108        SplitConnection, TcpDomainConnector, tcp_stream::stream,
109    };
110
111    use super::TlsConnector;
112
113    pub type TlsError = IoError;
114
115    /// connect as anonymous client
116    #[derive(Clone)]
117    pub struct TlsAnonymousConnector(TlsConnector);
118
119    impl From<TlsConnector> for TlsAnonymousConnector {
120        fn from(connector: TlsConnector) -> Self {
121            Self(connector)
122        }
123    }
124
125    #[async_trait]
126    impl TcpDomainConnector for TlsAnonymousConnector {
127        async fn connect(
128            &self,
129            domain: &str,
130        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
131            let tcp_stream = stream(domain).await?;
132            let fd = tcp_stream.as_connection_fd();
133
134            let server_name = ServerName::try_from(domain).map_err(|err| {
135                IoError::new(
136                    ErrorKind::InvalidInput,
137                    format!("Invalid Dns Name: {}", err),
138                )
139            })?;
140
141            let (write, read) = self
142                .0
143                .connect(server_name.to_owned(), tcp_stream)
144                .await?
145                .split_connection();
146            Ok((write, read, fd))
147        }
148
149        fn new_domain(&self, _domain: String) -> DomainConnector {
150            Box::new(self.clone())
151        }
152
153        fn domain(&self) -> &str {
154            "localhost"
155        }
156
157        fn clone_box(&self) -> DomainConnector {
158            Box::new(self.clone())
159        }
160    }
161
162    #[derive(Clone)]
163    pub struct TlsDomainConnector {
164        domain: String,
165        connector: TlsConnector,
166    }
167
168    impl TlsDomainConnector {
169        pub fn new(connector: TlsConnector, domain: String) -> Self {
170            Self { domain, connector }
171        }
172    }
173
174    #[async_trait]
175    impl TcpDomainConnector for TlsDomainConnector {
176        async fn connect(
177            &self,
178            addr: &str,
179        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
180            debug!("connect to tls addr: {}", addr);
181            let tcp_stream = stream(addr).await?;
182            let fd = tcp_stream.as_connection_fd();
183            debug!("connect to tls domain: {}", self.domain);
184
185            let server_name = ServerName::try_from(self.domain.as_str()).map_err(|err| {
186                IoError::new(
187                    ErrorKind::InvalidInput,
188                    format!("Invalid Dns Name: {}", err),
189                )
190            })?;
191
192            let (write, read) = self
193                .connector
194                .connect(server_name.to_owned(), tcp_stream)
195                .await?
196                .split_connection();
197            Ok((write, read, fd))
198        }
199
200        fn new_domain(&self, domain: String) -> DomainConnector {
201            let mut connector = self.clone();
202            connector.domain = domain;
203            Box::new(connector)
204        }
205
206        fn domain(&self) -> &str {
207            &self.domain
208        }
209
210        fn clone_box(&self) -> DomainConnector {
211            Box::new(self.clone())
212        }
213    }
214}
215
216mod builder {
217
218    use std::io::Cursor;
219    use std::path::Path;
220    use std::sync::Arc;
221
222    use futures_rustls::TlsAcceptor;
223    use futures_rustls::TlsConnector;
224    use futures_rustls::pki_types::UnixTime;
225    use futures_rustls::rustls::ClientConfig;
226    use futures_rustls::rustls::ConfigBuilder;
227    use futures_rustls::rustls::Error as TlsError;
228    use futures_rustls::rustls::RootCertStore;
229    use futures_rustls::rustls::ServerConfig;
230    use futures_rustls::rustls::SignatureScheme;
231    use futures_rustls::rustls::WantsVerifier;
232    use futures_rustls::rustls::client::WantsClientCert;
233    use futures_rustls::rustls::client::danger::HandshakeSignatureValid;
234    use futures_rustls::rustls::client::danger::ServerCertVerified;
235    use futures_rustls::rustls::client::danger::ServerCertVerifier;
236    use futures_rustls::rustls::pki_types::CertificateDer;
237    use futures_rustls::rustls::pki_types::PrivateKeyDer;
238    use futures_rustls::rustls::pki_types::ServerName;
239    use futures_rustls::rustls::server::WantsServerCert;
240    use futures_rustls::rustls::server::WebPkiClientVerifier;
241
242    use anyhow::{Context, Result};
243    use tracing::info;
244
245    use super::load_root_ca;
246    use super::{load_certs, load_first_key_from_reader};
247    use super::{load_certs_from_reader, load_first_key};
248
249    pub type ClientConfigBuilder<Stage> = ConfigBuilder<ClientConfig, Stage>;
250
251    pub struct ConnectorBuilder;
252
253    impl ConnectorBuilder {
254        pub fn with_safe_defaults() -> ConnectorBuilderStage<WantsVerifier> {
255            ConnectorBuilderStage(ClientConfig::builder())
256        }
257    }
258
259    pub struct ConnectorBuilderStage<S>(ConfigBuilder<ClientConfig, S>);
260
261    impl ConnectorBuilderStage<WantsVerifier> {
262        pub fn load_ca_cert<P: AsRef<Path>>(
263            self,
264            path: P,
265        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
266            let certs = load_certs(path)?;
267            self.with_root_certificates(certs)
268        }
269
270        pub fn load_ca_cert_from_bytes(
271            self,
272            buffer: &[u8],
273        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
274            let certs = load_certs_from_reader(&mut Cursor::new(buffer))?;
275            self.with_root_certificates(certs)
276        }
277
278        pub fn no_cert_verification(self) -> ConnectorBuilderWithConfig {
279            let config = self
280                .0
281                .dangerous()
282                .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
283                .with_no_client_auth();
284
285            ConnectorBuilderWithConfig(config)
286        }
287
288        fn with_root_certificates(
289            self,
290            certs: Vec<CertificateDer>,
291        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
292            let mut root_store = RootCertStore::empty();
293
294            for cert in certs {
295                root_store.add(cert).context("invalid ca crt")?;
296            }
297
298            Ok(ConnectorBuilderStage(
299                self.0.with_root_certificates(root_store),
300            ))
301        }
302    }
303
304    impl ConnectorBuilderStage<WantsClientCert> {
305        pub fn load_client_certs<P: AsRef<Path>>(
306            self,
307            cert_path: P,
308            key_path: P,
309        ) -> Result<ConnectorBuilderWithConfig> {
310            let certs = load_certs(cert_path)?;
311            let key = load_first_key(key_path)?;
312            self.with_single_cert(certs, key)
313        }
314
315        pub fn load_client_certs_from_bytes(
316            self,
317            cert_buf: &[u8],
318            key_buf: &[u8],
319        ) -> Result<ConnectorBuilderWithConfig> {
320            let certs = load_certs_from_reader(&mut Cursor::new(cert_buf))?;
321            let key = load_first_key_from_reader(&mut Cursor::new(key_buf))?;
322            self.with_single_cert(certs, key)
323        }
324
325        pub fn no_client_auth(self) -> ConnectorBuilderWithConfig {
326            ConnectorBuilderWithConfig(self.0.with_no_client_auth())
327        }
328
329        fn with_single_cert(
330            self,
331            certs: Vec<CertificateDer<'static>>,
332            key: PrivateKeyDer<'static>,
333        ) -> Result<ConnectorBuilderWithConfig> {
334            let config = self
335                .0
336                .with_client_auth_cert(certs, key)
337                .context("invalid cert")?;
338
339            Ok(ConnectorBuilderWithConfig(config))
340        }
341    }
342
343    pub struct ConnectorBuilderWithConfig(ClientConfig);
344
345    impl ConnectorBuilderWithConfig {
346        pub fn build(self) -> TlsConnector {
347            Arc::new(self.0).into()
348        }
349    }
350
351    pub struct AcceptorBuilder;
352
353    impl AcceptorBuilder {
354        pub fn with_safe_defaults() -> AcceptorBuilderStage<WantsVerifier> {
355            AcceptorBuilderStage(ServerConfig::builder())
356        }
357    }
358
359    pub struct AcceptorBuilderStage<S>(ConfigBuilder<ServerConfig, S>);
360
361    impl AcceptorBuilderStage<WantsVerifier> {
362        /// Require no client authentication.
363        pub fn no_client_authentication(self) -> AcceptorBuilderStage<WantsServerCert> {
364            AcceptorBuilderStage(self.0.with_no_client_auth())
365        }
366
367        /// Require client authentication. Must pass CA root path.
368        pub fn client_authenticate<P: AsRef<Path>>(
369            self,
370            path: P,
371        ) -> Result<AcceptorBuilderStage<WantsServerCert>> {
372            let root_store = load_root_ca(path)?;
373
374            let client_verifier = WebPkiClientVerifier::builder(root_store.into())
375                .build()
376                .context("invalid verifier")?;
377
378            Ok(AcceptorBuilderStage(
379                self.0.with_client_cert_verifier(client_verifier),
380            ))
381        }
382    }
383
384    impl AcceptorBuilderStage<WantsServerCert> {
385        pub fn load_server_certs(
386            self,
387            cert_path: impl AsRef<Path>,
388            key_path: impl AsRef<Path>,
389        ) -> Result<AcceptorBuilderWithConfig> {
390            let certs = load_certs(cert_path)?;
391            let key = load_first_key(key_path)?;
392
393            let config = self
394                .0
395                .with_single_cert(certs, key)
396                .context("invalid cert")?;
397
398            Ok(AcceptorBuilderWithConfig(config))
399        }
400    }
401
402    pub struct AcceptorBuilderWithConfig(ServerConfig);
403
404    impl AcceptorBuilderWithConfig {
405        pub fn build(self) -> TlsAcceptor {
406            TlsAcceptor::from(Arc::new(self.0))
407        }
408    }
409
410    #[derive(Debug)]
411    struct NoCertificateVerification;
412
413    impl ServerCertVerifier for NoCertificateVerification {
414        fn verify_server_cert(
415            &self,
416            _end_entity: &CertificateDer,
417            _intermediates: &[CertificateDer],
418            _server_name: &ServerName,
419            _ocsp_response: &[u8],
420            _now: UnixTime,
421        ) -> Result<ServerCertVerified, TlsError> {
422            info!("ignoring server cert");
423            Ok(ServerCertVerified::assertion())
424        }
425
426        fn verify_tls12_signature(
427            &self,
428            _message: &[u8],
429            _cert: &CertificateDer<'_>,
430            _dss: &futures_rustls::rustls::DigitallySignedStruct,
431        ) -> Result<HandshakeSignatureValid, TlsError> {
432            info!("ignoring server cert");
433            Ok(HandshakeSignatureValid::assertion())
434        }
435
436        fn verify_tls13_signature(
437            &self,
438            _message: &[u8],
439            _cert: &CertificateDer<'_>,
440            _dss: &futures_rustls::rustls::DigitallySignedStruct,
441        ) -> Result<HandshakeSignatureValid, TlsError> {
442            info!("ignoring server cert");
443            Ok(HandshakeSignatureValid::assertion())
444        }
445
446        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
447            let provider = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();
448            provider
449                .signature_verification_algorithms
450                .supported_schemes()
451        }
452    }
453}
454
#[cfg(test)]
mod test {

    use std::net::SocketAddr;

    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_rustls::TlsAcceptor;
    use futures_rustls::TlsConnector;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use fluvio_future::net::TcpListener;
    use fluvio_future::net::tcp_stream::stream;
    use fluvio_future::test_async;

    use anyhow::Result;

    use super::{AcceptorBuilder, ConnectorBuilder};

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    const ITER: u16 = 10;

    /// Decode a received frame as UTF-8 (the peers only exchange ASCII text).
    fn frame_to_string(frame: &[u8]) -> String {
        String::from_utf8(frame.to_vec()).expect("utf8")
    }

    #[test_async(ignore)]
    async fn test_rust_tls_all() -> Result<()> {
        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .no_client_authentication()
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .no_cert_verification()
                .build(),
        )
        .await
        .expect("no client cert test failed");

        // test client authentication

        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .client_authenticate(CA_PATH)?
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .load_ca_cert(CA_PATH)?
                .load_client_certs("certs/test-certs/client.crt", "certs/test-certs/client.key")?
                .build(),
        )
        .await
        .expect("client cert test fail");

        Ok(())
    }

    /// Exchange `ITER` request/reply messages between a TLS server using
    /// `acceptor` and a TLS client using `connector`, asserting each payload.
    async fn test_rustls(acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        // Bind on an ephemeral port before either side runs: the client can
        // connect as soon as the listener exists (no sleep-based race), and
        // concurrent test runs cannot collide on a fixed port.
        let bind_addr = "127.0.0.1:0".parse::<SocketAddr>().expect("parse");
        let listener = TcpListener::bind(&bind_addr).await?;
        let addr = listener.local_addr()?;
        debug!("server: listening on {}", addr);

        let server_ft = async {
            let mut incoming = listener.incoming();
            let tcp_stream = incoming.next().await.expect("stream").expect("no stream");
            debug!("server: got connection from client");
            debug!("server: try to accept tls connection");
            let tls_stream = acceptor.accept(tcp_stream).await.expect("accept");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for i in 0..ITER {
                let bytes = framed.next().await.expect("frame").expect("invalid value");
                debug!(
                    "server: loop {}, received from client: {} bytes",
                    i,
                    bytes.len()
                );

                let message = frame_to_string(&bytes);
                assert_eq!(message, format!("message{}", i));
                let reply = format!("{}reply", message);
                debug!("sever: send back reply: {}", reply);
                framed.send(Bytes::from(reply)).await.expect("send failed");
            }

            Ok::<(), anyhow::Error>(())
        };

        let client_ft = async {
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost".try_into().expect("domain"), tcp_stream)
                .await
                .expect("tls failed");
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..ITER {
                let message = format!("message{}", i);
                debug!("client: loop {} sending test message", i);
                framed.send(Bytes::from(message)).await.expect("send failed");
                let reply = framed.next().await.expect("messages").expect("frame");
                debug!("client: loop {}, received reply back", i);
                assert_eq!(frame_to_string(&reply), format!("message{}reply", i));
            }

            Ok::<(), anyhow::Error>(())
        };

        // Surface failures from either side instead of discarding them.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}