fluvio_future/
rust_tls.rs

1use crate::net::TcpStream;
2
3pub use futures_rustls::TlsAcceptor;
4pub use futures_rustls::TlsConnector;
5pub use futures_rustls::client::TlsStream as ClientTlsStream;
6pub use futures_rustls::server::TlsStream as ServerTlsStream;
7
8pub type DefaultServerTlsStream = ServerTlsStream<TcpStream>;
9pub type DefaultClientTlsStream = ClientTlsStream<TcpStream>;
10
11pub use builder::*;
12pub use cert::*;
13pub use connector::*;
14
15mod split {
16
17    use futures_util::AsyncReadExt;
18
19    use super::*;
20    use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
21
22    impl SplitConnection for DefaultClientTlsStream {
23        fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
24            let (read, write) = self.split();
25            (Box::new(write), Box::new(read))
26        }
27    }
28
29    impl SplitConnection for DefaultServerTlsStream {
30        fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
31            let (read, write) = self.split();
32            (Box::new(write), Box::new(read))
33        }
34    }
35}
36
37mod cert {
38    use std::fs::File;
39    use std::io::BufRead;
40    use std::io::BufReader;
41    use std::iter;
42    use std::path::Path;
43
44    use anyhow::{Context, Result, anyhow};
45    use futures_rustls::rustls::RootCertStore;
46    use futures_rustls::rustls::pki_types::CertificateDer;
47    use futures_rustls::rustls::pki_types::PrivateKeyDer;
48    use rustls_pemfile::{Item, certs, read_one};
49
50    pub fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
51        load_certs_from_reader(&mut BufReader::new(File::open(path)?))
52    }
53
54    pub fn load_certs_from_reader(rd: &mut dyn BufRead) -> Result<Vec<CertificateDer<'static>>> {
55        certs(rd).map(|r| r.context("invalid cert")).collect()
56    }
57
58    /// Load the passed keys file
59    pub fn load_keys<P: AsRef<Path>>(path: P) -> Result<Vec<PrivateKeyDer<'static>>> {
60        load_keys_from_reader(&mut BufReader::new(File::open(path)?))
61    }
62
63    pub fn load_keys_from_reader(rd: &mut dyn BufRead) -> Result<Vec<PrivateKeyDer<'static>>> {
64        let mut keys = vec![];
65        for item in iter::from_fn(|| read_one(rd).transpose()) {
66            match item.unwrap() {
67                Item::Pkcs1Key(key) => keys.push(PrivateKeyDer::from(key)),
68                Item::Pkcs8Key(key) => keys.push(PrivateKeyDer::from(key)),
69                Item::Sec1Key(key) => keys.push(PrivateKeyDer::from(key)),
70                _ => {}
71            }
72        }
73
74        Ok(keys)
75    }
76
77    pub(crate) fn load_first_key<P: AsRef<Path>>(path: P) -> Result<PrivateKeyDer<'static>> {
78        load_first_key_from_reader(&mut BufReader::new(File::open(path)?))
79    }
80
81    pub(crate) fn load_first_key_from_reader(
82        rd: &mut dyn BufRead,
83    ) -> Result<PrivateKeyDer<'static>> {
84        let mut keys = load_keys_from_reader(rd)?;
85
86        if keys.is_empty() {
87            Err(anyhow!("no keys found"))
88        } else {
89            Ok(keys.remove(0))
90        }
91    }
92
93    pub fn load_root_ca<P: AsRef<Path>>(path: P) -> Result<RootCertStore> {
94        let certs = load_certs(path).map_err(|err| err.context("invalid ca crt"))?;
95
96        let mut root_store = RootCertStore::empty();
97
98        for cert in certs {
99            root_store.add(cert).context("invalid ca crt")?;
100        }
101
102        Ok(root_store)
103    }
104}
105
106mod connector {
107    use std::io::Error as IoError;
108    use std::io::ErrorKind;
109
110    use async_trait::async_trait;
111    use futures_rustls::rustls::pki_types::ServerName;
112    use tracing::debug;
113
114    use crate::net::{
115        AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
116        SplitConnection, TcpDomainConnector, tcp_stream::stream,
117    };
118
119    use super::TlsConnector;
120
121    pub type TlsError = IoError;
122
123    /// connect as anonymous client
124    #[derive(Clone)]
125    pub struct TlsAnonymousConnector(TlsConnector);
126
127    impl From<TlsConnector> for TlsAnonymousConnector {
128        fn from(connector: TlsConnector) -> Self {
129            Self(connector)
130        }
131    }
132
133    #[async_trait]
134    impl TcpDomainConnector for TlsAnonymousConnector {
135        async fn connect(
136            &self,
137            domain: &str,
138        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
139            let tcp_stream = stream(domain).await?;
140            let fd = tcp_stream.as_connection_fd();
141
142            let server_name = ServerName::try_from(domain).map_err(|err| {
143                IoError::new(ErrorKind::InvalidInput, format!("Invalid Dns Name: {err}"))
144            })?;
145
146            let (write, read) = self
147                .0
148                .connect(server_name.to_owned(), tcp_stream)
149                .await?
150                .split_connection();
151            Ok((write, read, fd))
152        }
153
154        fn new_domain(&self, _domain: String) -> DomainConnector {
155            Box::new(self.clone())
156        }
157
158        fn domain(&self) -> &str {
159            "localhost"
160        }
161
162        fn clone_box(&self) -> DomainConnector {
163            Box::new(self.clone())
164        }
165    }
166
167    #[derive(Clone)]
168    pub struct TlsDomainConnector {
169        domain: String,
170        connector: TlsConnector,
171    }
172
173    impl TlsDomainConnector {
174        pub fn new(connector: TlsConnector, domain: String) -> Self {
175            Self { domain, connector }
176        }
177    }
178
179    #[async_trait]
180    impl TcpDomainConnector for TlsDomainConnector {
181        async fn connect(
182            &self,
183            addr: &str,
184        ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
185            debug!("connect to tls addr: {}", addr);
186            let tcp_stream = stream(addr).await?;
187            let fd = tcp_stream.as_connection_fd();
188            debug!("connect to tls domain: {}", self.domain);
189
190            let server_name = ServerName::try_from(self.domain.as_str()).map_err(|err| {
191                IoError::new(ErrorKind::InvalidInput, format!("Invalid Dns Name: {err}"))
192            })?;
193
194            let (write, read) = self
195                .connector
196                .connect(server_name.to_owned(), tcp_stream)
197                .await?
198                .split_connection();
199            Ok((write, read, fd))
200        }
201
202        fn new_domain(&self, domain: String) -> DomainConnector {
203            let mut connector = self.clone();
204            connector.domain = domain;
205            Box::new(connector)
206        }
207
208        fn domain(&self) -> &str {
209            &self.domain
210        }
211
212        fn clone_box(&self) -> DomainConnector {
213            Box::new(self.clone())
214        }
215    }
216}
217
218mod builder {
219
220    use std::io::Cursor;
221    use std::path::Path;
222    use std::sync::Arc;
223
224    use futures_rustls::TlsAcceptor;
225    use futures_rustls::TlsConnector;
226    use futures_rustls::pki_types::UnixTime;
227    use futures_rustls::rustls::ClientConfig;
228    use futures_rustls::rustls::ConfigBuilder;
229    use futures_rustls::rustls::Error as TlsError;
230    use futures_rustls::rustls::RootCertStore;
231    use futures_rustls::rustls::ServerConfig;
232    use futures_rustls::rustls::SignatureScheme;
233    use futures_rustls::rustls::WantsVerifier;
234    use futures_rustls::rustls::client::WantsClientCert;
235    use futures_rustls::rustls::client::danger::HandshakeSignatureValid;
236    use futures_rustls::rustls::client::danger::ServerCertVerified;
237    use futures_rustls::rustls::client::danger::ServerCertVerifier;
238    use futures_rustls::rustls::pki_types::CertificateDer;
239    use futures_rustls::rustls::pki_types::PrivateKeyDer;
240    use futures_rustls::rustls::pki_types::ServerName;
241    use futures_rustls::rustls::server::WantsServerCert;
242    use futures_rustls::rustls::server::WebPkiClientVerifier;
243
244    use anyhow::{Context, Result};
245    use tracing::info;
246
247    use super::load_root_ca;
248    use super::{load_certs, load_first_key_from_reader};
249    use super::{load_certs_from_reader, load_first_key};
250
251    pub type ClientConfigBuilder<Stage> = ConfigBuilder<ClientConfig, Stage>;
252
253    pub struct ConnectorBuilder;
254
255    impl ConnectorBuilder {
256        pub fn with_safe_defaults() -> ConnectorBuilderStage<WantsVerifier> {
257            ConnectorBuilderStage(ClientConfig::builder())
258        }
259    }
260
261    pub struct ConnectorBuilderStage<S>(ConfigBuilder<ClientConfig, S>);
262
263    impl ConnectorBuilderStage<WantsVerifier> {
264        pub fn load_ca_cert<P: AsRef<Path>>(
265            self,
266            path: P,
267        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
268            let certs = load_certs(path)?;
269            self.with_root_certificates(certs)
270        }
271
272        pub fn load_ca_cert_from_bytes(
273            self,
274            buffer: &[u8],
275        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
276            let certs = load_certs_from_reader(&mut Cursor::new(buffer))?;
277            self.with_root_certificates(certs)
278        }
279
280        pub fn no_cert_verification(self) -> ConnectorBuilderWithConfig {
281            let config = self
282                .0
283                .dangerous()
284                .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
285                .with_no_client_auth();
286
287            ConnectorBuilderWithConfig(config)
288        }
289
290        fn with_root_certificates(
291            self,
292            certs: Vec<CertificateDer>,
293        ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
294            let mut root_store = RootCertStore::empty();
295
296            for cert in certs {
297                root_store.add(cert).context("invalid ca crt")?;
298            }
299
300            Ok(ConnectorBuilderStage(
301                self.0.with_root_certificates(root_store),
302            ))
303        }
304    }
305
306    impl ConnectorBuilderStage<WantsClientCert> {
307        pub fn load_client_certs<P: AsRef<Path>>(
308            self,
309            cert_path: P,
310            key_path: P,
311        ) -> Result<ConnectorBuilderWithConfig> {
312            let certs = load_certs(cert_path)?;
313            let key = load_first_key(key_path)?;
314            self.with_single_cert(certs, key)
315        }
316
317        pub fn load_client_certs_from_bytes(
318            self,
319            cert_buf: &[u8],
320            key_buf: &[u8],
321        ) -> Result<ConnectorBuilderWithConfig> {
322            let certs = load_certs_from_reader(&mut Cursor::new(cert_buf))?;
323            let key = load_first_key_from_reader(&mut Cursor::new(key_buf))?;
324            self.with_single_cert(certs, key)
325        }
326
327        pub fn no_client_auth(self) -> ConnectorBuilderWithConfig {
328            ConnectorBuilderWithConfig(self.0.with_no_client_auth())
329        }
330
331        fn with_single_cert(
332            self,
333            certs: Vec<CertificateDer<'static>>,
334            key: PrivateKeyDer<'static>,
335        ) -> Result<ConnectorBuilderWithConfig> {
336            let config = self
337                .0
338                .with_client_auth_cert(certs, key)
339                .context("invalid cert")?;
340
341            Ok(ConnectorBuilderWithConfig(config))
342        }
343    }
344
345    pub struct ConnectorBuilderWithConfig(ClientConfig);
346
347    impl ConnectorBuilderWithConfig {
348        pub fn build(self) -> TlsConnector {
349            Arc::new(self.0).into()
350        }
351    }
352
353    pub struct AcceptorBuilder;
354
355    impl AcceptorBuilder {
356        pub fn with_safe_defaults() -> AcceptorBuilderStage<WantsVerifier> {
357            AcceptorBuilderStage(ServerConfig::builder())
358        }
359    }
360
361    pub struct AcceptorBuilderStage<S>(ConfigBuilder<ServerConfig, S>);
362
363    impl AcceptorBuilderStage<WantsVerifier> {
364        /// Require no client authentication.
365        pub fn no_client_authentication(self) -> AcceptorBuilderStage<WantsServerCert> {
366            AcceptorBuilderStage(self.0.with_no_client_auth())
367        }
368
369        /// Require client authentication. Must pass CA root path.
370        pub fn client_authenticate<P: AsRef<Path>>(
371            self,
372            path: P,
373        ) -> Result<AcceptorBuilderStage<WantsServerCert>> {
374            let root_store = load_root_ca(path)?;
375
376            let client_verifier = WebPkiClientVerifier::builder(root_store.into())
377                .build()
378                .context("invalid verifier")?;
379
380            Ok(AcceptorBuilderStage(
381                self.0.with_client_cert_verifier(client_verifier),
382            ))
383        }
384    }
385
386    impl AcceptorBuilderStage<WantsServerCert> {
387        pub fn load_server_certs(
388            self,
389            cert_path: impl AsRef<Path>,
390            key_path: impl AsRef<Path>,
391        ) -> Result<AcceptorBuilderWithConfig> {
392            let certs = load_certs(cert_path)?;
393            let key = load_first_key(key_path)?;
394
395            let config = self
396                .0
397                .with_single_cert(certs, key)
398                .context("invalid cert")?;
399
400            Ok(AcceptorBuilderWithConfig(config))
401        }
402    }
403
404    pub struct AcceptorBuilderWithConfig(ServerConfig);
405
406    impl AcceptorBuilderWithConfig {
407        pub fn build(self) -> TlsAcceptor {
408            TlsAcceptor::from(Arc::new(self.0))
409        }
410    }
411
412    #[derive(Debug)]
413    struct NoCertificateVerification;
414
415    impl ServerCertVerifier for NoCertificateVerification {
416        fn verify_server_cert(
417            &self,
418            _end_entity: &CertificateDer,
419            _intermediates: &[CertificateDer],
420            _server_name: &ServerName,
421            _ocsp_response: &[u8],
422            _now: UnixTime,
423        ) -> Result<ServerCertVerified, TlsError> {
424            info!("ignoring server cert");
425            Ok(ServerCertVerified::assertion())
426        }
427
428        fn verify_tls12_signature(
429            &self,
430            _message: &[u8],
431            _cert: &CertificateDer<'_>,
432            _dss: &futures_rustls::rustls::DigitallySignedStruct,
433        ) -> Result<HandshakeSignatureValid, TlsError> {
434            info!("ignoring server cert");
435            Ok(HandshakeSignatureValid::assertion())
436        }
437
438        fn verify_tls13_signature(
439            &self,
440            _message: &[u8],
441            _cert: &CertificateDer<'_>,
442            _dss: &futures_rustls::rustls::DigitallySignedStruct,
443        ) -> Result<HandshakeSignatureValid, TlsError> {
444            info!("ignoring server cert");
445            Ok(HandshakeSignatureValid::assertion())
446        }
447
448        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
449            let provider = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();
450            provider
451                .signature_verification_algorithms
452                .supported_schemes()
453        }
454    }
455}
456
#[cfg(test)]
mod test {
    //! End-to-end echo test: a TLS server and client exchange `ITER` request/reply frames.

    use std::net::SocketAddr;

    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_rustls::TlsAcceptor;
    use futures_rustls::TlsConnector;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use fluvio_future::net::TcpListener;
    use fluvio_future::net::tcp_stream::stream;
    use fluvio_future::test_async;

    use anyhow::Result;

    use super::{AcceptorBuilder, ConnectorBuilder};

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    const SERVER_CRT: &str = "certs/test-certs/server.crt";
    const SERVER_KEY: &str = "certs/test-certs/server.key";
    const CLIENT_CRT: &str = "certs/test-certs/client.crt";
    const CLIENT_KEY: &str = "certs/test-certs/client.key";
    const ITER: u16 = 10;

    #[test_async(ignore)]
    async fn test_rust_tls_all() -> Result<()> {
        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .no_client_authentication()
                .load_server_certs(SERVER_CRT, SERVER_KEY)?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .no_cert_verification()
                .build(),
        )
        .await
        .expect("no client cert test failed");

        // test client authentication
        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .client_authenticate(CA_PATH)?
                .load_server_certs(SERVER_CRT, SERVER_KEY)?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .load_ca_cert(CA_PATH)?
                .load_client_certs(CLIENT_CRT, CLIENT_KEY)?
                .build(),
        )
        .await
        .expect("client cert test fail");

        Ok(())
    }

    /// Runs a TLS echo exchange of `ITER` messages between `acceptor` and `connector`.
    async fn test_rustls(acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        let addr = "127.0.0.1:19998".parse::<SocketAddr>().expect("parse");

        // Bind before either side runs so the client's connect cannot race the listener.
        debug!("server: binding");
        let listener = TcpListener::bind(&addr).await.expect("listener failed");
        debug!("server: successfully binding. waiting for incoming");

        let server_ft = async {
            let mut incoming = listener.incoming();
            let tcp_stream = incoming.next().await.expect("stream").expect("no stream");
            debug!("server: got connection from client");
            debug!("server: try to accept tls connection");
            let tls_stream = acceptor.accept(tcp_stream).await.expect("accept");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for i in 0..ITER {
                let bytes = framed.next().await.expect("frame").expect("invalid value");
                debug!(
                    "server: loop {}, received from client: {} bytes",
                    i,
                    bytes.len()
                );

                let message = String::from_utf8(bytes.to_vec()).expect("utf8");
                assert_eq!(message, format!("message{i}"));
                let reply = format!("{message}reply");
                debug!("server: send back reply: {}", reply);
                framed.send(Bytes::from(reply)).await.expect("send failed");
            }

            Ok(()) as Result<()>
        };

        let client_ft = async {
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost".try_into().expect("domain"), tcp_stream)
                .await
                .expect("tls failed");
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..ITER {
                debug!("client: loop {} sending test message", i);
                framed
                    .send(Bytes::from(format!("message{i}")))
                    .await
                    .expect("send failed");
                let reply = framed.next().await.expect("messages").expect("frame");
                debug!("client: loop {}, received reply back", i);
                let reply = String::from_utf8(reply.to_vec()).expect("utf8");
                assert_eq!(reply, format!("message{i}reply"));
            }

            Ok(()) as Result<()>
        };

        // Surface failures from either side instead of discarding them.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}