narrowlink_network/
transport.rs

1use std::{
2    io,
3    net::SocketAddr,
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use tokio::{
9    io::{AsyncRead, AsyncWrite, ReadBuf},
10    net::TcpStream,
11};
12use tracing::debug;
13
14use crate::{error::NetworkError, AsyncSocket};
15
/// TLS parameters used when connecting with [`StreamType::Tls`].
#[derive(Debug, Clone)]
pub struct TlsConfiguration {
    /// Server name handed to the rustls handshake. It must parse as a rustls `ServerName`;
    /// otherwise `UnifiedSocket::new` fails with an `InvalidInput` I/O error.
    pub sni: String,
}

/// Transport that `UnifiedSocket::new` uses to reach the remote peer.
#[derive(Debug, Clone)]
pub enum StreamType {
    /// Plain, unencrypted TCP.
    Tcp,
    /// TCP with a rustls client session layered on top.
    Tls(TlsConfiguration),
}
23
24pub struct UnifiedSocket {
25    io: Box<dyn AsyncSocket>,
26    local_addr: SocketAddr,
27    peer_addr: SocketAddr,
28}
29
30impl UnifiedSocket {
31    pub async fn new(addr: &str, transport_type: StreamType) -> Result<Self, NetworkError> {
32        match transport_type {
33            StreamType::Tcp | StreamType::Tls(_) => {
34                let tcp_stream = TcpStream::connect(addr).await?;
35                let local_addr = tcp_stream.local_addr()?;
36                let peer_addr = tcp_stream.peer_addr()?;
37                let mut stream: Box<dyn AsyncSocket> = Box::new(tcp_stream);
38                if let StreamType::Tls(conf) = transport_type {
39                    {
40                        debug!("using rustls to connect to {}", peer_addr.to_string());
41                        use std::sync::Arc;
42                        use tokio_rustls::{
43                            rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
44                            TlsConnector,
45                        };
46
47                        let mut root_store = RootCertStore::empty();
48                        root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(
49                            |ta| {
50                                OwnedTrustAnchor::from_subject_spki_name_constraints(
51                                    ta.subject,
52                                    ta.spki,
53                                    ta.name_constraints,
54                                )
55                            },
56                        ));
57
58                        let config = ClientConfig::builder()
59                            .with_safe_default_cipher_suites()
60                            .with_safe_default_kx_groups()
61                            .with_safe_default_protocol_versions()
62                            .or(Err(NetworkError::TlsError))?
63                            .with_root_certificates(root_store)
64                            .with_no_client_auth();
65
66                        let config = TlsConnector::from(Arc::new(config));
67
68                        let dnsname = ServerName::try_from(conf.sni.as_str()).or(Err(
69                            io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"),
70                        ))?;
71                        stream = Box::new(config.connect(dnsname, stream).await?);
72                    }
73                }
74
75                Ok(Self {
76                    io: stream,
77                    local_addr,
78                    peer_addr,
79                })
80            }
81        }
82    }
83    pub fn local_addr(&self) -> SocketAddr {
84        self.local_addr
85    }
86    pub fn peer_addr(&self) -> SocketAddr {
87        self.peer_addr
88    }
89}
90
91impl AsyncRead for UnifiedSocket {
92    fn poll_read(
93        mut self: Pin<&mut Self>,
94        cx: &mut Context<'_>,
95        buf: &mut ReadBuf<'_>,
96    ) -> Poll<io::Result<()>> {
97        Pin::new(&mut self.io).poll_read(cx, buf)
98    }
99}
100impl AsyncWrite for UnifiedSocket {
101    fn poll_write(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &[u8],
105    ) -> Poll<Result<usize, io::Error>> {
106        Pin::new(&mut self.io).poll_write(cx, buf)
107    }
108
109    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
110        Pin::new(&mut self.io).poll_flush(cx)
111    }
112
113    fn poll_shutdown(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116    ) -> Poll<Result<(), io::Error>> {
117        Pin::new(&mut self.io).poll_shutdown(cx)
118    }
119}