1use crate::net::TcpStream;
2
3pub use futures_rustls::client::TlsStream as ClientTlsStream;
4pub use futures_rustls::server::TlsStream as ServerTlsStream;
5pub use futures_rustls::TlsAcceptor;
6pub use futures_rustls::TlsConnector;
7
8pub type DefaultServerTlsStream = ServerTlsStream<TcpStream>;
9pub type DefaultClientTlsStream = ClientTlsStream<TcpStream>;
10
11pub use builder::*;
12pub use cert::*;
13pub use connector::*;
14
15mod split {
16
17 use futures_util::AsyncReadExt;
18
19 use super::*;
20 use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
21
22 impl SplitConnection for DefaultClientTlsStream {
23 fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
24 let (read, write) = self.split();
25 (Box::new(write), Box::new(read))
26 }
27 }
28
29 impl SplitConnection for DefaultServerTlsStream {
30 fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
31 let (read, write) = self.split();
32 (Box::new(write), Box::new(read))
33 }
34 }
35}
36
37mod cert {
38 use std::fs::File;
39 use std::io::BufRead;
40 use std::io::BufReader;
41 use std::path::Path;
42
43 use anyhow::{anyhow, Context, Result};
44 use futures_rustls::rustls::pki_types::CertificateDer;
45 use futures_rustls::rustls::pki_types::PrivateKeyDer;
46 use futures_rustls::rustls::RootCertStore;
47 use rustls_pemfile::certs;
48 use rustls_pemfile::pkcs8_private_keys;
49
50 pub fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
51 load_certs_from_reader(&mut BufReader::new(File::open(path)?))
52 }
53
54 pub fn load_certs_from_reader(rd: &mut dyn BufRead) -> Result<Vec<CertificateDer<'static>>> {
55 certs(rd).map(|r| r.context("invalid cert")).collect()
56 }
57
58 pub fn load_keys<P: AsRef<Path>>(path: P) -> Result<Vec<PrivateKeyDer<'static>>> {
60 load_keys_from_reader(&mut BufReader::new(File::open(path)?))
61 }
62
63 pub fn load_keys_from_reader(rd: &mut dyn BufRead) -> Result<Vec<PrivateKeyDer<'static>>> {
64 pkcs8_private_keys(rd)
65 .map(|r| r.map(|p| p.into()).context("invalid key"))
66 .collect()
67 }
68
69 pub(crate) fn load_first_key<P: AsRef<Path>>(path: P) -> Result<PrivateKeyDer<'static>> {
70 load_first_key_from_reader(&mut BufReader::new(File::open(path)?))
71 }
72
73 pub(crate) fn load_first_key_from_reader(
74 rd: &mut dyn BufRead,
75 ) -> Result<PrivateKeyDer<'static>> {
76 let mut keys = load_keys_from_reader(rd)?;
77
78 if keys.is_empty() {
79 Err(anyhow!("no keys found"))
80 } else {
81 Ok(keys.remove(0))
82 }
83 }
84
85 pub fn load_root_ca<P: AsRef<Path>>(path: P) -> Result<RootCertStore> {
86 let certs = load_certs(path).map_err(|err| err.context("invalid ca crt"))?;
87
88 let mut root_store = RootCertStore::empty();
89
90 for cert in certs {
91 root_store.add(cert).context("invalid ca crt")?;
92 }
93
94 Ok(root_store)
95 }
96}
97
98mod connector {
99 use std::io::Error as IoError;
100 use std::io::ErrorKind;
101
102 use async_trait::async_trait;
103 use futures_rustls::rustls::pki_types::ServerName;
104 use tracing::debug;
105
106 use crate::net::{
107 tcp_stream::stream, AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd,
108 DomainConnector, SplitConnection, TcpDomainConnector,
109 };
110
111 use super::TlsConnector;
112
113 pub type TlsError = IoError;
114
115 #[derive(Clone)]
117 pub struct TlsAnonymousConnector(TlsConnector);
118
119 impl From<TlsConnector> for TlsAnonymousConnector {
120 fn from(connector: TlsConnector) -> Self {
121 Self(connector)
122 }
123 }
124
125 #[async_trait]
126 impl TcpDomainConnector for TlsAnonymousConnector {
127 async fn connect(
128 &self,
129 domain: &str,
130 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
131 let tcp_stream = stream(domain).await?;
132 let fd = tcp_stream.as_connection_fd();
133
134 let server_name = ServerName::try_from(domain).map_err(|err| {
135 IoError::new(
136 ErrorKind::InvalidInput,
137 format!("Invalid Dns Name: {}", err),
138 )
139 })?;
140
141 let (write, read) = self
142 .0
143 .connect(server_name.to_owned(), tcp_stream)
144 .await?
145 .split_connection();
146 Ok((write, read, fd))
147 }
148
149 fn new_domain(&self, _domain: String) -> DomainConnector {
150 Box::new(self.clone())
151 }
152
153 fn domain(&self) -> &str {
154 "localhost"
155 }
156
157 fn clone_box(&self) -> DomainConnector {
158 Box::new(self.clone())
159 }
160 }
161
162 #[derive(Clone)]
163 pub struct TlsDomainConnector {
164 domain: String,
165 connector: TlsConnector,
166 }
167
168 impl TlsDomainConnector {
169 pub fn new(connector: TlsConnector, domain: String) -> Self {
170 Self { domain, connector }
171 }
172 }
173
174 #[async_trait]
175 impl TcpDomainConnector for TlsDomainConnector {
176 async fn connect(
177 &self,
178 addr: &str,
179 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
180 debug!("connect to tls addr: {}", addr);
181 let tcp_stream = stream(addr).await?;
182 let fd = tcp_stream.as_connection_fd();
183 debug!("connect to tls domain: {}", self.domain);
184
185 let server_name = ServerName::try_from(self.domain.as_str()).map_err(|err| {
186 IoError::new(
187 ErrorKind::InvalidInput,
188 format!("Invalid Dns Name: {}", err),
189 )
190 })?;
191
192 let (write, read) = self
193 .connector
194 .connect(server_name.to_owned(), tcp_stream)
195 .await?
196 .split_connection();
197 Ok((write, read, fd))
198 }
199
200 fn new_domain(&self, domain: String) -> DomainConnector {
201 let mut connector = self.clone();
202 connector.domain = domain;
203 Box::new(connector)
204 }
205
206 fn domain(&self) -> &str {
207 &self.domain
208 }
209
210 fn clone_box(&self) -> DomainConnector {
211 Box::new(self.clone())
212 }
213 }
214}
215
216mod builder {
217
218 use std::io::Cursor;
219 use std::path::Path;
220 use std::sync::Arc;
221
222 use futures_rustls::pki_types::UnixTime;
223 use futures_rustls::rustls::client::danger::HandshakeSignatureValid;
224 use futures_rustls::rustls::client::danger::ServerCertVerified;
225 use futures_rustls::rustls::client::danger::ServerCertVerifier;
226 use futures_rustls::rustls::client::WantsClientCert;
227 use futures_rustls::rustls::pki_types::CertificateDer;
228 use futures_rustls::rustls::pki_types::PrivateKeyDer;
229 use futures_rustls::rustls::pki_types::ServerName;
230 use futures_rustls::rustls::server::WantsServerCert;
231 use futures_rustls::rustls::server::WebPkiClientVerifier;
232 use futures_rustls::rustls::ClientConfig;
233 use futures_rustls::rustls::ConfigBuilder;
234 use futures_rustls::rustls::Error as TlsError;
235 use futures_rustls::rustls::RootCertStore;
236 use futures_rustls::rustls::ServerConfig;
237 use futures_rustls::rustls::SignatureScheme;
238 use futures_rustls::rustls::WantsVerifier;
239 use futures_rustls::TlsAcceptor;
240 use futures_rustls::TlsConnector;
241
242 use anyhow::{Context, Result};
243 use tracing::info;
244
245 use super::load_root_ca;
246 use super::{load_certs, load_first_key_from_reader};
247 use super::{load_certs_from_reader, load_first_key};
248
249 pub type ClientConfigBuilder<Stage> = ConfigBuilder<ClientConfig, Stage>;
250
251 pub struct ConnectorBuilder;
252
253 impl ConnectorBuilder {
254 pub fn with_safe_defaults() -> ConnectorBuilderStage<WantsVerifier> {
255 ConnectorBuilderStage(ClientConfig::builder())
256 }
257 }
258
259 pub struct ConnectorBuilderStage<S>(ConfigBuilder<ClientConfig, S>);
260
261 impl ConnectorBuilderStage<WantsVerifier> {
262 pub fn load_ca_cert<P: AsRef<Path>>(
263 self,
264 path: P,
265 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
266 let certs = load_certs(path)?;
267 self.with_root_certificates(certs)
268 }
269
270 pub fn load_ca_cert_from_bytes(
271 self,
272 buffer: &[u8],
273 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
274 let certs = load_certs_from_reader(&mut Cursor::new(buffer))?;
275 self.with_root_certificates(certs)
276 }
277
278 pub fn no_cert_verification(self) -> ConnectorBuilderWithConfig {
279 let config = self
280 .0
281 .dangerous()
282 .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
283 .with_no_client_auth();
284
285 ConnectorBuilderWithConfig(config)
286 }
287
288 fn with_root_certificates(
289 self,
290 certs: Vec<CertificateDer>,
291 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
292 let mut root_store = RootCertStore::empty();
293
294 for cert in certs {
295 root_store.add(cert).context("invalid ca crt")?;
296 }
297
298 Ok(ConnectorBuilderStage(
299 self.0.with_root_certificates(root_store),
300 ))
301 }
302 }
303
304 impl ConnectorBuilderStage<WantsClientCert> {
305 pub fn load_client_certs<P: AsRef<Path>>(
306 self,
307 cert_path: P,
308 key_path: P,
309 ) -> Result<ConnectorBuilderWithConfig> {
310 let certs = load_certs(cert_path)?;
311 let key = load_first_key(key_path)?;
312 self.with_single_cert(certs, key)
313 }
314
315 pub fn load_client_certs_from_bytes(
316 self,
317 cert_buf: &[u8],
318 key_buf: &[u8],
319 ) -> Result<ConnectorBuilderWithConfig> {
320 let certs = load_certs_from_reader(&mut Cursor::new(cert_buf))?;
321 let key = load_first_key_from_reader(&mut Cursor::new(key_buf))?;
322 self.with_single_cert(certs, key)
323 }
324
325 pub fn no_client_auth(self) -> ConnectorBuilderWithConfig {
326 ConnectorBuilderWithConfig(self.0.with_no_client_auth())
327 }
328
329 fn with_single_cert(
330 self,
331 certs: Vec<CertificateDer<'static>>,
332 key: PrivateKeyDer<'static>,
333 ) -> Result<ConnectorBuilderWithConfig> {
334 let config = self
335 .0
336 .with_client_auth_cert(certs, key)
337 .context("invalid cert")?;
338
339 Ok(ConnectorBuilderWithConfig(config))
340 }
341 }
342
343 pub struct ConnectorBuilderWithConfig(ClientConfig);
344
345 impl ConnectorBuilderWithConfig {
346 pub fn build(self) -> TlsConnector {
347 Arc::new(self.0).into()
348 }
349 }
350
351 pub struct AcceptorBuilder;
352
353 impl AcceptorBuilder {
354 pub fn with_safe_defaults() -> AcceptorBuilderStage<WantsVerifier> {
355 AcceptorBuilderStage(ServerConfig::builder())
356 }
357 }
358
359 pub struct AcceptorBuilderStage<S>(ConfigBuilder<ServerConfig, S>);
360
361 impl AcceptorBuilderStage<WantsVerifier> {
362 pub fn no_client_authentication(self) -> AcceptorBuilderStage<WantsServerCert> {
364 AcceptorBuilderStage(self.0.with_no_client_auth())
365 }
366
367 pub fn client_authenticate<P: AsRef<Path>>(
369 self,
370 path: P,
371 ) -> Result<AcceptorBuilderStage<WantsServerCert>> {
372 let root_store = load_root_ca(path)?;
373
374 let client_verifier = WebPkiClientVerifier::builder(root_store.into())
375 .build()
376 .context("invalid verifier")?;
377
378 Ok(AcceptorBuilderStage(
379 self.0.with_client_cert_verifier(client_verifier),
380 ))
381 }
382 }
383
384 impl AcceptorBuilderStage<WantsServerCert> {
385 pub fn load_server_certs(
386 self,
387 cert_path: impl AsRef<Path>,
388 key_path: impl AsRef<Path>,
389 ) -> Result<AcceptorBuilderWithConfig> {
390 let certs = load_certs(cert_path)?;
391 let key = load_first_key(key_path)?;
392
393 let config = self
394 .0
395 .with_single_cert(certs, key)
396 .context("invalid cert")?;
397
398 Ok(AcceptorBuilderWithConfig(config))
399 }
400 }
401
402 pub struct AcceptorBuilderWithConfig(ServerConfig);
403
404 impl AcceptorBuilderWithConfig {
405 pub fn build(self) -> TlsAcceptor {
406 TlsAcceptor::from(Arc::new(self.0))
407 }
408 }
409
410 #[derive(Debug)]
411 struct NoCertificateVerification;
412
413 impl ServerCertVerifier for NoCertificateVerification {
414 fn verify_server_cert(
415 &self,
416 _end_entity: &CertificateDer,
417 _intermediates: &[CertificateDer],
418 _server_name: &ServerName,
419 _ocsp_response: &[u8],
420 _now: UnixTime,
421 ) -> Result<ServerCertVerified, TlsError> {
422 info!("ignoring server cert");
423 Ok(ServerCertVerified::assertion())
424 }
425
426 fn verify_tls12_signature(
427 &self,
428 _message: &[u8],
429 _cert: &CertificateDer<'_>,
430 _dss: &futures_rustls::rustls::DigitallySignedStruct,
431 ) -> Result<HandshakeSignatureValid, TlsError> {
432 info!("ignoring server cert");
433 Ok(HandshakeSignatureValid::assertion())
434 }
435
436 fn verify_tls13_signature(
437 &self,
438 _message: &[u8],
439 _cert: &CertificateDer<'_>,
440 _dss: &futures_rustls::rustls::DigitallySignedStruct,
441 ) -> Result<HandshakeSignatureValid, TlsError> {
442 info!("ignoring server cert");
443 Ok(HandshakeSignatureValid::assertion())
444 }
445
446 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
447 let provider = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();
448 provider
449 .signature_verification_algorithms
450 .supported_schemes()
451 }
452 }
453}
454
#[cfg(test)]
mod test {

    use std::net::SocketAddr;
    use std::time;

    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_rustls::TlsAcceptor;
    use futures_rustls::TlsConnector;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use fluvio_future::net::tcp_stream::stream;
    use fluvio_future::net::TcpListener;
    use fluvio_future::test_async;
    use fluvio_future::timer::sleep;

    use anyhow::Result;

    use super::{AcceptorBuilder, ConnectorBuilder};

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    const ITER: u16 = 10;

    /// Converts an owned byte vector into `Bytes` without copying.
    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        Bytes::from(bytes)
    }

    /// Runs the echo round-trip twice: first without client authentication, then
    /// with mutual TLS against the test CA.
    #[test_async(ignore)]
    async fn test_rust_tls_all() -> Result<()> {
        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .no_client_authentication()
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .no_cert_verification()
                .build(),
        )
        .await
        .expect("no client cert test failed");

        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .client_authenticate(CA_PATH)?
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .load_ca_cert(CA_PATH)?
                .load_client_certs("certs/test-certs/client.crt", "certs/test-certs/client.key")?
                .build(),
        )
        .await
        .expect("client cert test fail");

        Ok(())
    }

    /// Accepts one TLS connection with `acceptor` and connects to it with `connector`.
    /// It then exchanges `ITER` request/reply frames, asserting the expected payloads.
    async fn test_rustls(acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        let addr = "127.0.0.1:19998".parse::<SocketAddr>().expect("parse");

        let server_ft = async {
            debug!("server: binding");
            let listener = TcpListener::bind(&addr).await.expect("listener failed");
            debug!("server: successfully binding. waiting for incoming");

            let mut incoming = listener.incoming();
            let stream = incoming.next().await.expect("stream");
            let tcp_stream = stream.expect("no stream");
            debug!("server: got connection from client");
            debug!("server: try to accept tls connection");
            let tls_stream = acceptor.accept(tcp_stream).await.expect("accept");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for i in 0..ITER {
                let bytes = framed
                    .next()
                    .await
                    .expect("frame")
                    .expect("invalid value");
                debug!(
                    "server: loop {}, received from client: {} bytes",
                    i,
                    bytes.len()
                );

                let message = String::from_utf8(bytes.to_vec()).expect("utf8");
                assert_eq!(message, format!("message{}", i));
                let reply = format!("{}reply", message);
                debug!("server: send back reply: {}", reply);
                framed
                    .send(to_bytes(reply.into_bytes()))
                    .await
                    .expect("send failed");
            }

            Ok(()) as Result<()>
        };

        let client_ft = async {
            debug!("client: sleep to give server chance to come up");
            sleep(time::Duration::from_millis(100)).await;
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost".try_into().expect("domain"), tcp_stream)
                .await
                .expect("tls failed");
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..ITER {
                let message = format!("message{}", i);
                debug!("client: loop {} sending test message", i);
                framed
                    .send(to_bytes(message.into_bytes()))
                    .await
                    .expect("send failed");
                let reply = framed.next().await.expect("messages").expect("frame");
                debug!("client: loop {}, received reply back", i);
                let message = String::from_utf8(reply.to_vec()).expect("utf8");
                assert_eq!(message, format!("message{}reply", i));
            }

            Ok(()) as Result<()>
        };

        // Propagate both sides' results instead of discarding them.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}
610}