1use crate::net::TcpStream;
2
3pub use async_native_tls::TlsAcceptor;
4pub use async_native_tls::TlsConnector;
5pub use async_native_tls::TlsStream;
6
/// Server-side TLS stream: a native-TLS session layered over [`TcpStream`].
pub type DefaultServerTlsStream = TlsStream<TcpStream>;

/// Client-side TLS stream: a native-TLS session layered over [`TcpStream`].
///
/// This is the same concrete type as [`DefaultServerTlsStream`]; the two aliases only
/// name the role the stream plays.
pub type DefaultClientTlsStream = TlsStream<TcpStream>;
10
11pub use connector::*;
12
13pub use crate::net::certs::CertBuilder;
14
15mod split {
16
17 use async_net::TcpStream;
18 use futures_util::AsyncReadExt;
19
20 use super::*;
21 use crate::net::{BoxReadConnection, BoxWriteConnection, SplitConnection};
22
23 impl SplitConnection for TlsStream<TcpStream> {
24 fn split_connection(self) -> (BoxWriteConnection, BoxReadConnection) {
25 let (read, write) = self.split();
26 (Box::new(write), Box::new(read))
27 }
28 }
29}
30
31mod connector {
32 use std::io::Error as IoError;
33 use std::io::ErrorKind;
34 use std::sync::Arc;
35
36 use async_trait::async_trait;
37 use tracing::debug;
38
39 use crate::net::{
40 AsConnectionFd, BoxReadConnection, BoxWriteConnection, ConnectionFd, DomainConnector,
41 SplitConnection, TcpDomainConnector,
42 tcp_stream::{SocketOpts, stream, stream_with_opts},
43 };
44
45 use super::*;
46
47 #[derive(Clone)]
49 pub struct TlsAnonymousConnector(Arc<TlsConnector>);
50
51 impl From<TlsConnector> for TlsAnonymousConnector {
52 fn from(connector: TlsConnector) -> Self {
53 Self(Arc::new(connector))
54 }
55 }
56
57 #[async_trait]
58 impl TcpDomainConnector for TlsAnonymousConnector {
59 async fn connect(
60 &self,
61 domain: &str,
62 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
63 let tcp_stream = stream(domain).await?;
64 let fd = tcp_stream.as_connection_fd();
65 let (write, read) = self
66 .0
67 .connect(domain, tcp_stream)
68 .await
69 .map_err(|e| {
70 IoError::new(
71 ErrorKind::ConnectionRefused,
72 format!("failed to connect: {e}"),
73 )
74 })?
75 .split_connection();
76 Ok((write, read, fd))
77 }
78
79 fn new_domain(&self, _domain: String) -> DomainConnector {
80 Box::new(self.clone())
81 }
82
83 fn domain(&self) -> &str {
84 "localhost"
85 }
86
87 fn clone_box(&self) -> DomainConnector {
88 Box::new(self.clone())
89 }
90 }
91
92 #[derive(Clone)]
94 pub struct TlsDomainConnector {
95 domain: String,
96 connector: Arc<TlsConnector>,
97 }
98
99 impl TlsDomainConnector {
100 pub fn new(connector: TlsConnector, domain: String) -> Self {
101 Self {
102 domain,
103 connector: Arc::new(connector),
104 }
105 }
106
107 pub fn domain(&self) -> &str {
108 &self.domain
109 }
110
111 pub fn connector(&self) -> &TlsConnector {
112 &self.connector
113 }
114 }
115
116 #[async_trait]
117 impl TcpDomainConnector for TlsDomainConnector {
118 async fn connect(
119 &self,
120 addr: &str,
121 ) -> Result<(BoxWriteConnection, BoxReadConnection, ConnectionFd), IoError> {
122 debug!("connect to tls addr: {}", addr);
123 let socket_opts = SocketOpts {
124 keepalive: Some(Default::default()),
125 nodelay: Some(true),
126 };
127 let tcp_stream = stream_with_opts(addr, Some(socket_opts)).await?;
128 let fd = tcp_stream.as_connection_fd();
129
130 debug!("connect to tls domain: {}", self.domain);
131 let (write, read) = self
132 .connector
133 .connect(&self.domain, tcp_stream)
134 .await
135 .map_err(|e| {
136 IoError::new(
137 ErrorKind::ConnectionRefused,
138 format!("failed to connect: {e}"),
139 )
140 })?
141 .split_connection();
142 Ok((write, read, fd))
143 }
144
145 fn new_domain(&self, domain: String) -> DomainConnector {
146 let mut connector = self.clone();
147 connector.domain = domain;
148 Box::new(connector)
149 }
150
151 fn domain(&self) -> &str {
152 &self.domain
153 }
154
155 fn clone_box(&self) -> DomainConnector {
156 Box::new(self.clone())
157 }
158 }
159}
160
161pub use cert::*;
162
163mod cert {
164 use crate::net::certs::CertBuilder;
165 use anyhow::{Context, Result};
166 use native_tls::Certificate as NativeCertificate;
167 use native_tls::Identity;
168 use openssl::pkcs12::Pkcs12;
169 use openssl::pkey::Private;
170
171 pub type Certificate = openssl::x509::X509;
172 pub type PrivateKey = openssl::pkey::PKey<Private>;
173
174 pub struct X509PemBuilder(Vec<u8>);
175
176 impl CertBuilder for X509PemBuilder {
177 fn new(bytes: Vec<u8>) -> Self {
178 Self(bytes)
179 }
180 }
181
182 impl X509PemBuilder {
183 pub fn build(self) -> Result<Certificate> {
184 let cert = Certificate::from_pem(&self.0).context("invalid cert")?;
185 Ok(cert)
186 }
187
188 pub fn build_native(self) -> Result<NativeCertificate> {
189 NativeCertificate::from_pem(&self.0).context("invalid pem file")
190 }
191 }
192
193 pub struct PrivateKeyBuilder(Vec<u8>);
194
195 impl CertBuilder for PrivateKeyBuilder {
196 fn new(bytes: Vec<u8>) -> Self {
197 Self(bytes)
198 }
199 }
200
201 impl PrivateKeyBuilder {
202 pub fn build(self) -> Result<PrivateKey> {
203 let key = PrivateKey::private_key_from_pem(&self.0).context("invalid key")?;
204 Ok(key)
205 }
206 }
207
208 const PASSWORD: &str = "test";
209
210 pub struct IdentityBuilder(Vec<u8>);
211
212 impl CertBuilder for IdentityBuilder {
213 fn new(bytes: Vec<u8>) -> Self {
214 Self(bytes)
215 }
216 }
217
218 impl IdentityBuilder {
219 pub fn from_x509(x509: X509PemBuilder, key: PrivateKeyBuilder) -> Result<Self> {
221 let server_key = key.build()?;
222 let server_crt = x509.build()?;
223 let p12 = Pkcs12::builder()
224 .name("")
225 .pkey(&server_key)
226 .cert(&server_crt)
227 .build2(PASSWORD)
228 .context("Failed to create Pkcs12")?;
229
230 let der = p12.to_der()?;
231 Ok(Self(der))
232 }
233
234 pub fn build(self) -> Result<Identity> {
235 Identity::from_pkcs12(&self.0, PASSWORD).context("Failed to load der")
236 }
237 }
238}
239
240pub use builder::*;
241
242mod builder {
243 use anyhow::{Context, Result};
244 use native_tls::Identity;
245 use native_tls::TlsAcceptor as NativeTlsAcceptor;
246
247 use super::IdentityBuilder;
248 use super::TlsAcceptor;
249 use super::TlsConnector;
250 use super::X509PemBuilder;
251
252 pub struct ConnectorBuilder(TlsConnector);
253
254 impl ConnectorBuilder {
255 pub fn identity(builder: IdentityBuilder) -> Result<Self> {
256 let identity = builder.build()?;
257 let connector = TlsConnector::new().identity(identity);
258 Ok(Self(connector))
259 }
260
261 pub fn anonymous() -> Self {
262 let connector = TlsConnector::new()
263 .danger_accept_invalid_certs(true)
264 .danger_accept_invalid_hostnames(true);
265 Self(connector)
266 }
267
268 pub fn no_cert_verification(self) -> Self {
269 let connector = self.0.danger_accept_invalid_certs(true);
270 Self(connector)
271 }
272
273 pub fn danger_accept_invalid_hostnames(self) -> Self {
274 let connector = self.0.danger_accept_invalid_hostnames(true);
275 Self(connector)
276 }
277
278 pub fn use_sni(self, use_sni: bool) -> Self {
279 let connector = self.0.use_sni(use_sni);
280 Self(connector)
281 }
282
283 pub fn add_root_certificate(self, builder: X509PemBuilder) -> Result<Self> {
284 let certificate = builder.build_native()?;
285 let connector = self.0.add_root_certificate(certificate);
286 Ok(Self(connector))
287 }
288
289 pub fn build(self) -> TlsConnector {
290 self.0
291 }
292 }
293
294 pub struct AcceptorBuilder(Identity);
295
296 impl AcceptorBuilder {
297 pub fn identity(builder: IdentityBuilder) -> Result<Self> {
298 let identity = builder.build()?;
299 Ok(Self(identity))
300 }
301
302 pub fn build(self) -> Result<TlsAcceptor> {
303 let acceptor = NativeTlsAcceptor::new(self.0).context("invalid cert")?;
304 Ok(acceptor.into())
305 }
306 }
307}
308
#[cfg(test)]
mod test {

    use std::net::SocketAddr;
    use std::time;

    use anyhow::Result;
    use async_native_tls::TlsAcceptor;
    use async_native_tls::TlsConnector;
    use bytes::Bytes;
    use futures_lite::future::zip;
    use futures_lite::stream::StreamExt;
    use futures_util::sink::SinkExt;
    use tokio_util::codec::BytesCodec;
    use tokio_util::codec::Framed;
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use tracing::debug;

    use crate::net::TcpListener;
    use crate::net::certs::CertBuilder;
    use crate::net::tcp_stream::stream;
    use crate::test_async;
    use crate::timer::sleep;

    use super::{
        AcceptorBuilder, ConnectorBuilder, IdentityBuilder, PrivateKeyBuilder, X509PemBuilder,
    };

    const CA_PATH: &str = "certs/test-certs/ca.crt";
    const SERVER_IDENTITY: &str = "certs/test-certs/server.pfx";
    const CLIENT_IDENTITY: &str = "certs/test-certs/client.pfx";
    const X509_SERVER: &str = "certs/test-certs/server.crt";
    const X509_SERVER_KEY: &str = "certs/test-certs/server.key";
    const X509_CLIENT: &str = "certs/test-certs/client.crt";
    const X509_CLIENT_KEY: &str = "certs/test-certs/client.key";

    /// Wraps an owned byte vector as a [`Bytes`] frame.
    fn to_bytes(bytes: Vec<u8>) -> Bytes {
        Bytes::from(bytes)
    }

    /// Builds a server acceptor from the PKCS#12 server identity.
    fn pkcs12_acceptor() -> TlsAcceptor {
        AcceptorBuilder::identity(IdentityBuilder::from_path(SERVER_IDENTITY).expect("identity"))
            .expect("identity:")
            .build()
            .expect("acceptor")
    }

    /// Builds a server acceptor from the PEM server certificate and key.
    #[cfg(not(windows))]
    fn x509_acceptor() -> TlsAcceptor {
        let identity = IdentityBuilder::from_x509(
            X509PemBuilder::from_path(X509_SERVER).expect("read"),
            PrivateKeyBuilder::from_path(X509_SERVER_KEY).expect("file"),
        )
        .expect("identity");
        AcceptorBuilder::identity(identity)
            .expect("identity:")
            .build()
            .expect("acceptor")
    }

    /// Starts a connector builder carrying the PEM client certificate and key.
    #[cfg(not(windows))]
    fn x509_client_builder() -> ConnectorBuilder {
        let identity = IdentityBuilder::from_x509(
            X509PemBuilder::from_path(X509_CLIENT).expect("read"),
            PrivateKeyBuilder::from_path(X509_CLIENT_KEY).expect("read"),
        )
        .expect("509");
        ConnectorBuilder::identity(identity).expect("connector")
    }

    #[test_async]
    async fn test_native_tls_pk12() -> Result<()> {
        const PK12_PORT: u16 = 9900;

        // Client presents its PKCS#12 identity and accepts invalid certs and hostnames.
        let acceptor = pkcs12_acceptor();
        let connector = ConnectorBuilder::identity(IdentityBuilder::from_path(CLIENT_IDENTITY)?)
            .expect("connector")
            .danger_accept_invalid_hostnames()
            .no_cert_verification()
            .build();
        test_tls(PK12_PORT, acceptor, connector)
            .await
            .expect("pkcs12 test with invalid hostnames accepted failed");

        // Same, without explicitly accepting invalid hostnames.
        let acceptor = pkcs12_acceptor();
        let connector = ConnectorBuilder::identity(IdentityBuilder::from_path(CLIENT_IDENTITY)?)
            .expect("connector")
            .no_cert_verification()
            .build();
        test_tls(PK12_PORT, acceptor, connector)
            .await
            .expect("pkcs12 test with default hostname setting failed");

        Ok(())
    }

    #[test_async]
    #[cfg(not(windows))]
    async fn test_native_tls_x509() -> Result<()> {
        const X500_PORT: u16 = 9910;

        // Client presents its PEM identity and accepts invalid certs and hostnames.
        let acceptor = x509_acceptor();
        let connector = x509_client_builder()
            .danger_accept_invalid_hostnames()
            .no_cert_verification()
            .build();
        test_tls(X500_PORT, acceptor, connector)
            .await
            .expect("x509 test with invalid hostnames accepted failed");

        // Same identity, with the test CA additionally installed as a trusted root.
        let acceptor = x509_acceptor();
        let connector = x509_client_builder()
            .add_root_certificate(X509PemBuilder::from_path(CA_PATH).expect("cert"))
            .expect("root")
            .no_cert_verification()
            .build();
        test_tls(X500_PORT, acceptor, connector)
            .await
            .expect("x509 test with CA root failed");

        Ok(())
    }

    /// Runs one TLS session on `127.0.0.1:port`.
    ///
    /// The server sends `TEST_ITERATION` pairs of frames. The client checks each frame's
    /// exact contents. Errors from either side are propagated.
    async fn test_tls(port: u16, acceptor: TlsAcceptor, connector: TlsConnector) -> Result<()> {
        const TEST_ITERATION: u16 = 2;

        let addr = format!("127.0.0.1:{port}")
            .parse::<SocketAddr>()
            .expect("parse");

        let server_ft = async {
            debug!("server: binding");
            let listener = TcpListener::bind(&addr).await.expect("listener failed");
            debug!("server: successfully binding. waiting for incoming");

            let mut incoming = listener.incoming();
            let incoming_stream = incoming.next().await.expect("incoming");

            debug!("server: got connection from client");
            let tcp_stream = incoming_stream.expect("no stream");

            let tls_stream = acceptor.accept(tcp_stream).await.unwrap();
            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());

            for _ in 0..TEST_ITERATION {
                debug!("server: sending values to client");
                let data = vec![0x05, 0x0a, 0x63];
                framed.send(to_bytes(data)).await.expect("send failed");
                sleep(time::Duration::from_micros(1)).await;
                debug!("server: sending 2nd value to client");
                let data2 = vec![0x20, 0x11];
                framed.send(to_bytes(data2)).await.expect("2nd send failed");
            }

            // Keep the connection open while the client drains its frames.
            sleep(time::Duration::from_secs(1)).await;

            Ok(()) as Result<()>
        };

        let client_ft = async {
            debug!("client: sleep to give server chance to come up");
            sleep(time::Duration::from_millis(100)).await;
            debug!("client: trying to connect");
            let tcp_stream = stream(&addr).await.expect("connection fail");
            let tls_stream = connector
                .connect("localhost", tcp_stream)
                .await
                .expect("tls failed");

            let mut framed = Framed::new(tls_stream.compat(), BytesCodec::new());
            debug!("client: got connection. waiting");

            for i in 0..TEST_ITERATION {
                let value = framed.next().await.expect("frame");
                debug!("{} client :received first value from server", i);
                let bytes = value.expect("invalid value");
                assert_eq!(&bytes[..], &[0x05_u8, 0x0a, 0x63][..]);

                let value2 = framed.next().await.expect("frame");
                debug!("client: received 2nd value from server");
                let bytes = value2.expect("packet decoding works");
                assert_eq!(&bytes[..], &[0x20_u8, 0x11][..]);
            }

            Ok(()) as Result<()>
        };

        let (client_result, server_result) = zip(client_ft, server_ft).await;
        client_result?;
        server_result?;

        Ok(())
    }
}