narrowlink_network/
transport.rs1use std::{
2 io,
3 net::SocketAddr,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use tokio::{
9 io::{AsyncRead, AsyncWrite, ReadBuf},
10 net::TcpStream,
11};
12use tracing::debug;
13
14use crate::{error::NetworkError, AsyncSocket};
15
/// TLS settings used when connecting with [`StreamType::Tls`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TlsConfiguration {
    /// Server name sent during the TLS handshake (SNI); it is parsed as a rustls
    /// server name when the connection is established.
    pub sni: String,
}

/// Transport selected when opening a `UnifiedSocket`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StreamType {
    /// Plain TCP connection.
    Tcp,
    /// TCP connection wrapped in a rustls client session.
    Tls(TlsConfiguration),
}
23
24pub struct UnifiedSocket {
25 io: Box<dyn AsyncSocket>,
26 local_addr: SocketAddr,
27 peer_addr: SocketAddr,
28}
29
30impl UnifiedSocket {
31 pub async fn new(addr: &str, transport_type: StreamType) -> Result<Self, NetworkError> {
32 match transport_type {
33 StreamType::Tcp | StreamType::Tls(_) => {
34 let tcp_stream = TcpStream::connect(addr).await?;
35 let local_addr = tcp_stream.local_addr()?;
36 let peer_addr = tcp_stream.peer_addr()?;
37 let mut stream: Box<dyn AsyncSocket> = Box::new(tcp_stream);
38 if let StreamType::Tls(conf) = transport_type {
39 {
40 debug!("using rustls to connect to {}", peer_addr.to_string());
41 use std::sync::Arc;
42 use tokio_rustls::{
43 rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
44 TlsConnector,
45 };
46
47 let mut root_store = RootCertStore::empty();
48 root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
49 |ta| {
50 OwnedTrustAnchor::from_subject_spki_name_constraints(
51 ta.subject,
52 ta.spki,
53 ta.name_constraints,
54 )
55 },
56 ));
57
58 let config = ClientConfig::builder()
59 .with_safe_default_cipher_suites()
60 .with_safe_default_kx_groups()
61 .with_safe_default_protocol_versions()
62 .or(Err(NetworkError::TlsError))?
63 .with_root_certificates(root_store)
64 .with_no_client_auth();
65
66 let config = TlsConnector::from(Arc::new(config));
67
68 let dnsname = ServerName::try_from(conf.sni.as_str()).or(Err(
69 io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"),
70 ))?;
71 stream = Box::new(config.connect(dnsname, stream).await?);
72 }
73 }
74
75 Ok(Self {
76 io: stream,
77 local_addr,
78 peer_addr,
79 })
80 }
81 }
82 }
83 pub fn local_addr(&self) -> SocketAddr {
84 self.local_addr
85 }
86 pub fn peer_addr(&self) -> SocketAddr {
87 self.peer_addr
88 }
89}
90
91impl AsyncRead for UnifiedSocket {
92 fn poll_read(
93 mut self: Pin<&mut Self>,
94 cx: &mut Context<'_>,
95 buf: &mut ReadBuf<'_>,
96 ) -> Poll<io::Result<()>> {
97 Pin::new(&mut self.io).poll_read(cx, buf)
98 }
99}
100impl AsyncWrite for UnifiedSocket {
101 fn poll_write(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &[u8],
105 ) -> Poll<Result<usize, io::Error>> {
106 Pin::new(&mut self.io).poll_write(cx, buf)
107 }
108
109 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
110 Pin::new(&mut self.io).poll_flush(cx)
111 }
112
113 fn poll_shutdown(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 ) -> Poll<Result<(), io::Error>> {
117 Pin::new(&mut self.io).poll_shutdown(cx)
118 }
119}