1use crate::net::TcpStream;
2
3pub use futures_rustls::TlsAcceptor;
4pub use futures_rustls::TlsConnector;
5pub use futures_rustls::client::TlsStream as ClientTlsStream;
6pub use futures_rustls::server::TlsStream as ServerTlsStream;
7
8pub type DefaultServerTlsStream = ServerTlsStream<TcpStream>;
9pub type DefaultClientTlsStream = ClientTlsStream<TcpStream>;
10
11pub use builder::*;
12pub use cert::*;
13pub use connector::*;
14
15mod split {
16
17 use futures_util::AsyncReadExt;
18
19 use super::*;
20 use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
21
22 impl SplitConnection for DefaultClientTlsStream {
23 fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
24 let (read, write) = self.split();
25 (Box::new(write), Box::new(read))
26 }
27 }
28
29 impl SplitConnection for DefaultServerTlsStream {
30 fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
31 let (read, write) = self.split();
32 (Box::new(write), Box::new(read))
33 }
34 }
35}
36
37mod cert {
38 use std::fs::File;
39 use std::io::BufRead;
40 use std::io::BufReader;
41 use std::iter;
42 use std::path::Path;
43
44 use anyhow::{Context, Result, anyhow};
45 use futures_rustls::rustls::RootCertStore;
46 use futures_rustls::rustls::pki_types::CertificateDer;
47 use futures_rustls::rustls::pki_types::PrivateKeyDer;
48 use rustls_pemfile::{Item, certs, read_one};
49
50 pub fn load_certs<P: AsRef<Path>>(path: P) -> Result<Vec<CertificateDer<'static>>> {
51 load_certs_from_reader(&mut BufReader::new(File::open(path)?))
52 }
53
54 pub fn load_certs_from_reader(rd: &mut dyn BufRead) -> Result<Vec<CertificateDer<'static>>> {
55 certs(rd).map(|r| r.context("invalid cert")).collect()
56 }
57
58 pub fn load_keys<P: AsRef<Path>>(path: P) -> Result<Vec<PrivateKeyDer<'static>>> {
60 load_keys_from_reader(&mut BufReader::new(File::open(path)?))
61 }
62
63 pub fn load_keys_from_reader(rd: &mut dyn BufRead) -> Result<Vec<PrivateKeyDer<'static>>> {
64 let mut keys = vec![];
65 for item in iter::from_fn(|| read_one(rd).transpose()) {
66 match item.unwrap() {
67 Item::Pkcs1Key(key) => keys.push(PrivateKeyDer::from(key)),
68 Item::Pkcs8Key(key) => keys.push(PrivateKeyDer::from(key)),
69 Item::Sec1Key(key) => keys.push(PrivateKeyDer::from(key)),
70 _ => {}
71 }
72 }
73
74 Ok(keys)
75 }
76
77 pub(crate) fn load_first_key<P: AsRef<Path>>(path: P) -> Result<PrivateKeyDer<'static>> {
78 load_first_key_from_reader(&mut BufReader::new(File::open(path)?))
79 }
80
81 pub(crate) fn load_first_key_from_reader(
82 rd: &mut dyn BufRead,
83 ) -> Result<PrivateKeyDer<'static>> {
84 let mut keys = load_keys_from_reader(rd)?;
85
86 if keys.is_empty() {
87 Err(anyhow!("no keys found"))
88 } else {
89 Ok(keys.remove(0))
90 }
91 }
92
93 pub fn load_root_ca<P: AsRef<Path>>(path: P) -> Result<RootCertStore> {
94 let certs = load_certs(path).map_err(|err| err.context("invalid ca crt"))?;
95
96 let mut root_store = RootCertStore::empty();
97
98 for cert in certs {
99 root_store.add(cert).context("invalid ca crt")?;
100 }
101
102 Ok(root_store)
103 }
104}
105
106mod connector {
107 use std::io::Error as IoError;
108 use std::io::ErrorKind;
109
110 use async_trait::async_trait;
111 use futures_rustls::rustls::pki_types::ServerName;
112 use tracing::debug;
113
114 use crate::net::{
115 AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
116 SplitConnection, TcpDomainConnector, tcp_stream::stream,
117 };
118
119 use super::TlsConnector;
120
121 pub type TlsError = IoError;
122
123 #[derive(Clone)]
125 pub struct TlsAnonymousConnector(TlsConnector);
126
127 impl From<TlsConnector> for TlsAnonymousConnector {
128 fn from(connector: TlsConnector) -> Self {
129 Self(connector)
130 }
131 }
132
133 #[async_trait]
134 impl TcpDomainConnector for TlsAnonymousConnector {
135 async fn connect(
136 &self,
137 domain: &str,
138 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
139 let tcp_stream = stream(domain).await?;
140 let fd = tcp_stream.as_connection_fd();
141
142 let server_name = ServerName::try_from(domain).map_err(|err| {
143 IoError::new(ErrorKind::InvalidInput, format!("Invalid Dns Name: {err}"))
144 })?;
145
146 let (write, read) = self
147 .0
148 .connect(server_name.to_owned(), tcp_stream)
149 .await?
150 .split_connection();
151 Ok((write, read, fd))
152 }
153
154 fn new_domain(&self, _domain: String) -> DomainConnector {
155 Box::new(self.clone())
156 }
157
158 fn domain(&self) -> &str {
159 "localhost"
160 }
161
162 fn clone_box(&self) -> DomainConnector {
163 Box::new(self.clone())
164 }
165 }
166
167 #[derive(Clone)]
168 pub struct TlsDomainConnector {
169 domain: String,
170 connector: TlsConnector,
171 }
172
173 impl TlsDomainConnector {
174 pub fn new(connector: TlsConnector, domain: String) -> Self {
175 Self { domain, connector }
176 }
177 }
178
179 #[async_trait]
180 impl TcpDomainConnector for TlsDomainConnector {
181 async fn connect(
182 &self,
183 addr: &str,
184 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
185 debug!("connect to tls addr: {}", addr);
186 let tcp_stream = stream(addr).await?;
187 let fd = tcp_stream.as_connection_fd();
188 debug!("connect to tls domain: {}", self.domain);
189
190 let server_name = ServerName::try_from(self.domain.as_str()).map_err(|err| {
191 IoError::new(ErrorKind::InvalidInput, format!("Invalid Dns Name: {err}"))
192 })?;
193
194 let (write, read) = self
195 .connector
196 .connect(server_name.to_owned(), tcp_stream)
197 .await?
198 .split_connection();
199 Ok((write, read, fd))
200 }
201
202 fn new_domain(&self, domain: String) -> DomainConnector {
203 let mut connector = self.clone();
204 connector.domain = domain;
205 Box::new(connector)
206 }
207
208 fn domain(&self) -> &str {
209 &self.domain
210 }
211
212 fn clone_box(&self) -> DomainConnector {
213 Box::new(self.clone())
214 }
215 }
216}
217
218mod builder {
219
220 use std::io::Cursor;
221 use std::path::Path;
222 use std::sync::Arc;
223
224 use futures_rustls::TlsAcceptor;
225 use futures_rustls::TlsConnector;
226 use futures_rustls::pki_types::UnixTime;
227 use futures_rustls::rustls::ClientConfig;
228 use futures_rustls::rustls::ConfigBuilder;
229 use futures_rustls::rustls::Error as TlsError;
230 use futures_rustls::rustls::RootCertStore;
231 use futures_rustls::rustls::ServerConfig;
232 use futures_rustls::rustls::SignatureScheme;
233 use futures_rustls::rustls::WantsVerifier;
234 use futures_rustls::rustls::client::WantsClientCert;
235 use futures_rustls::rustls::client::danger::HandshakeSignatureValid;
236 use futures_rustls::rustls::client::danger::ServerCertVerified;
237 use futures_rustls::rustls::client::danger::ServerCertVerifier;
238 use futures_rustls::rustls::pki_types::CertificateDer;
239 use futures_rustls::rustls::pki_types::PrivateKeyDer;
240 use futures_rustls::rustls::pki_types::ServerName;
241 use futures_rustls::rustls::server::WantsServerCert;
242 use futures_rustls::rustls::server::WebPkiClientVerifier;
243
244 use anyhow::{Context, Result};
245 use tracing::info;
246
247 use super::load_root_ca;
248 use super::{load_certs, load_first_key_from_reader};
249 use super::{load_certs_from_reader, load_first_key};
250
251 pub type ClientConfigBuilder<Stage> = ConfigBuilder<ClientConfig, Stage>;
252
253 pub struct ConnectorBuilder;
254
255 impl ConnectorBuilder {
256 pub fn with_safe_defaults() -> ConnectorBuilderStage<WantsVerifier> {
257 ConnectorBuilderStage(ClientConfig::builder())
258 }
259 }
260
261 pub struct ConnectorBuilderStage<S>(ConfigBuilder<ClientConfig, S>);
262
263 impl ConnectorBuilderStage<WantsVerifier> {
264 pub fn load_ca_cert<P: AsRef<Path>>(
265 self,
266 path: P,
267 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
268 let certs = load_certs(path)?;
269 self.with_root_certificates(certs)
270 }
271
272 pub fn load_ca_cert_from_bytes(
273 self,
274 buffer: &[u8],
275 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
276 let certs = load_certs_from_reader(&mut Cursor::new(buffer))?;
277 self.with_root_certificates(certs)
278 }
279
280 pub fn no_cert_verification(self) -> ConnectorBuilderWithConfig {
281 let config = self
282 .0
283 .dangerous()
284 .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
285 .with_no_client_auth();
286
287 ConnectorBuilderWithConfig(config)
288 }
289
290 fn with_root_certificates(
291 self,
292 certs: Vec<CertificateDer>,
293 ) -> Result<ConnectorBuilderStage<WantsClientCert>> {
294 let mut root_store = RootCertStore::empty();
295
296 for cert in certs {
297 root_store.add(cert).context("invalid ca crt")?;
298 }
299
300 Ok(ConnectorBuilderStage(
301 self.0.with_root_certificates(root_store),
302 ))
303 }
304 }
305
306 impl ConnectorBuilderStage<WantsClientCert> {
307 pub fn load_client_certs<P: AsRef<Path>>(
308 self,
309 cert_path: P,
310 key_path: P,
311 ) -> Result<ConnectorBuilderWithConfig> {
312 let certs = load_certs(cert_path)?;
313 let key = load_first_key(key_path)?;
314 self.with_single_cert(certs, key)
315 }
316
317 pub fn load_client_certs_from_bytes(
318 self,
319 cert_buf: &[u8],
320 key_buf: &[u8],
321 ) -> Result<ConnectorBuilderWithConfig> {
322 let certs = load_certs_from_reader(&mut Cursor::new(cert_buf))?;
323 let key = load_first_key_from_reader(&mut Cursor::new(key_buf))?;
324 self.with_single_cert(certs, key)
325 }
326
327 pub fn no_client_auth(self) -> ConnectorBuilderWithConfig {
328 ConnectorBuilderWithConfig(self.0.with_no_client_auth())
329 }
330
331 fn with_single_cert(
332 self,
333 certs: Vec<CertificateDer<'static>>,
334 key: PrivateKeyDer<'static>,
335 ) -> Result<ConnectorBuilderWithConfig> {
336 let config = self
337 .0
338 .with_client_auth_cert(certs, key)
339 .context("invalid cert")?;
340
341 Ok(ConnectorBuilderWithConfig(config))
342 }
343 }
344
345 pub struct ConnectorBuilderWithConfig(ClientConfig);
346
347 impl ConnectorBuilderWithConfig {
348 pub fn build(self) -> TlsConnector {
349 Arc::new(self.0).into()
350 }
351 }
352
353 pub struct AcceptorBuilder;
354
355 impl AcceptorBuilder {
356 pub fn with_safe_defaults() -> AcceptorBuilderStage<WantsVerifier> {
357 AcceptorBuilderStage(ServerConfig::builder())
358 }
359 }
360
361 pub struct AcceptorBuilderStage<S>(ConfigBuilder<ServerConfig, S>);
362
363 impl AcceptorBuilderStage<WantsVerifier> {
364 pub fn no_client_authentication(self) -> AcceptorBuilderStage<WantsServerCert> {
366 AcceptorBuilderStage(self.0.with_no_client_auth())
367 }
368
369 pub fn client_authenticate<P: AsRef<Path>>(
371 self,
372 path: P,
373 ) -> Result<AcceptorBuilderStage<WantsServerCert>> {
374 let root_store = load_root_ca(path)?;
375
376 let client_verifier = WebPkiClientVerifier::builder(root_store.into())
377 .build()
378 .context("invalid verifier")?;
379
380 Ok(AcceptorBuilderStage(
381 self.0.with_client_cert_verifier(client_verifier),
382 ))
383 }
384 }
385
386 impl AcceptorBuilderStage<WantsServerCert> {
387 pub fn load_server_certs(
388 self,
389 cert_path: impl AsRef<Path>,
390 key_path: impl AsRef<Path>,
391 ) -> Result<AcceptorBuilderWithConfig> {
392 let certs = load_certs(cert_path)?;
393 let key = load_first_key(key_path)?;
394
395 let config = self
396 .0
397 .with_single_cert(certs, key)
398 .context("invalid cert")?;
399
400 Ok(AcceptorBuilderWithConfig(config))
401 }
402 }
403
404 pub struct AcceptorBuilderWithConfig(ServerConfig);
405
406 impl AcceptorBuilderWithConfig {
407 pub fn build(self) -> TlsAcceptor {
408 TlsAcceptor::from(Arc::new(self.0))
409 }
410 }
411
412 #[derive(Debug)]
413 struct NoCertificateVerification;
414
415 impl ServerCertVerifier for NoCertificateVerification {
416 fn verify_server_cert(
417 &self,
418 _end_entity: &CertificateDer,
419 _intermediates: &[CertificateDer],
420 _server_name: &ServerName,
421 _ocsp_response: &[u8],
422 _now: UnixTime,
423 ) -> Result<ServerCertVerified, TlsError> {
424 info!("ignoring server cert");
425 Ok(ServerCertVerified::assertion())
426 }
427
428 fn verify_tls12_signature(
429 &self,
430 _message: &[u8],
431 _cert: &CertificateDer<'_>,
432 _dss: &futures_rustls::rustls::DigitallySignedStruct,
433 ) -> Result<HandshakeSignatureValid, TlsError> {
434 info!("ignoring server cert");
435 Ok(HandshakeSignatureValid::assertion())
436 }
437
438 fn verify_tls13_signature(
439 &self,
440 _message: &[u8],
441 _cert: &CertificateDer<'_>,
442 _dss: &futures_rustls::rustls::DigitallySignedStruct,
443 ) -> Result<HandshakeSignatureValid, TlsError> {
444 info!("ignoring server cert");
445 Ok(HandshakeSignatureValid::assertion())
446 }
447
448 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
449 let provider = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();
450 provider
451 .signature_verification_algorithms
452 .supported_schemes()
453 }
454 }
455}
456
#[cfg(test)]
mod test {

    use std::net::SocketAddr;
    use std::time;

    use bytes::BufMut;
    use bytes::Bytes;
    use bytes::BytesMut;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_rustls::TlsAcceptor;
    use futures_rustls::TlsConnector;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use fluvio_future::net::TcpListener;
    use fluvio_future::net::tcp_stream::stream;
    use fluvio_future::test_async;
    use fluvio_future::timer::sleep;

    use anyhow::Result;

    use super::{AcceptorBuilder, ConnectorBuilder};

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    /// Number of request/reply round trips exchanged per connection.
    const ITER: u16 = 10;

    /// Copies `bytes` into a frozen [`Bytes`] buffer.
    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        let mut buf = BytesMut::with_capacity(bytes.len());
        buf.put_slice(&bytes);
        buf.freeze()
    }

    /// Runs the echo scenario twice: once with server certificate verification
    /// disabled and no client auth, once with mutual TLS against the test CA.
    #[test_async(ignore)]
    async fn test_rust_tls_all() -> Result<()> {
        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .no_client_authentication()
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .no_cert_verification()
                .build(),
        )
        .await
        .expect("no client cert test failed");

        test_rustls(
            AcceptorBuilder::with_safe_defaults()
                .client_authenticate(CA_PATH)?
                .load_server_certs("certs/test-certs/server.crt", "certs/test-certs/server.key")?
                .build(),
            ConnectorBuilder::with_safe_defaults()
                .load_ca_cert(CA_PATH)?
                .load_client_certs("certs/test-certs/client.crt", "certs/test-certs/client.key")?
                .build(),
        )
        .await
        .expect("client cert test fail");

        Ok(())
    }

    /// Starts a one-connection TLS server and a client on a fixed local port
    /// and exchanges `ITER` message/reply pairs between them.
    async fn test_rustls(acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        let addr = "127.0.0.1:19998".parse::<SocketAddr>().expect("parse");

        let server_ft = async {
            debug!("server: binding");
            let listener = TcpListener::bind(&addr).await.expect("listener failed");
            debug!("server: successfully binding. waiting for incoming");

            let mut incoming = listener.incoming();
            let stream = incoming.next().await.expect("stream");
            let tcp_stream = stream.expect("no stream");
            let acceptor = acceptor.clone();
            debug!("server: got connection from client");
            debug!("server: try to accept tls connection");
            let tls_stream = acceptor.accept(tcp_stream).await.expect("accept");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for i in 0..ITER {
                let receives_bytes = framed.next().await.expect("frame");

                let bytes = receives_bytes.expect("invalid value");
                debug!(
                    "server: loop {}, received from client: {} bytes",
                    i,
                    bytes.len()
                );

                let message = String::from_utf8(bytes.to_vec()).expect("utf8");
                assert_eq!(message, format!("message{i}"));
                let reply = format!("{message}reply");
                debug!("sever: send back reply: {}", reply);
                framed
                    .send(to_bytes(reply.into_bytes()))
                    .await
                    .expect("send failed");
            }

            Ok(()) as Result<()>
        };

        let client_ft = async {
            debug!("client: sleep to give server chance to come up");
            sleep(time::Duration::from_millis(100)).await;
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost".try_into().expect("domain"), tcp_stream)
                .await
                .expect("tls failed");
            // No boxing needed: the server side frames its TLS stream directly too.
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..ITER {
                let message = format!("message{i}");
                debug!("client: loop {} sending test message", i);
                framed
                    .send(to_bytes(message.into_bytes()))
                    .await
                    .expect("send failed");
                let reply = framed.next().await.expect("messages").expect("frame");
                debug!("client: loop {}, received reply back", i);
                let message = String::from_utf8(reply.to_vec()).expect("utf8");
                assert_eq!(message, format!("message{i}reply"));
            }

            Ok(()) as Result<()>
        };

        // Surface a failure from either side instead of discarding both results.
        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}
612}