1use crate::memdx::error::Error;
20use crate::memdx::error::Result;
21use crate::tls_config::TlsConfig;
22use socket2::TcpKeepalive;
23use std::fmt::Debug;
24use std::io;
25use std::net::{IpAddr, Ipv4Addr, SocketAddr};
26use std::time::Duration;
27use tokio::io::{AsyncRead, AsyncWrite};
28use tokio::net::TcpStream;
29use tokio::time::{timeout_at, Instant};
30
31use crate::address::Address;
32#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
33use {
34 tokio_rustls::rustls::pki_types::DnsName, tokio_rustls::rustls::pki_types::ServerName,
35 tokio_rustls::TlsConnector,
36};
37
38#[derive(Debug)]
39pub struct ConnectOptions {
40 pub deadline: Instant,
41 pub tcp_keep_alive_time: Duration,
42}
43
44pub trait Stream: Debug + AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static {}
45
46impl Stream for TcpStream {}
47
48#[derive(Debug)]
49pub enum ConnectionType {
50 Tcp(TcpConnection),
51 Tls(TlsConnection),
52}
53
54impl ConnectionType {
55 pub fn into_inner(self) -> Box<dyn Stream> {
56 match self {
57 ConnectionType::Tcp(connection) => Box::new(connection.stream),
58 ConnectionType::Tls(connection) => Box::new(connection.stream),
59 }
60 }
61
62 pub fn local_addr(&self) -> &SocketAddr {
63 match self {
64 ConnectionType::Tcp(connection) => &connection.local_addr,
65 ConnectionType::Tls(connection) => &connection.local_addr,
66 }
67 }
68
69 pub fn peer_addr(&self) -> &SocketAddr {
70 match self {
71 ConnectionType::Tcp(connection) => &connection.peer_addr,
72 ConnectionType::Tls(connection) => &connection.peer_addr,
73 }
74 }
75}
76
77#[derive(Debug)]
78pub struct TcpConnection {
79 stream: TcpStream,
80
81 local_addr: SocketAddr,
82 peer_addr: SocketAddr,
83}
84
85impl TcpConnection {
86 async fn tcp_stream(
87 addr: &str,
88 opts: &ConnectOptions,
89 ) -> Result<(TcpStream, SocketAddr, SocketAddr)> {
90 let tcp_socket = timeout_at(opts.deadline, TcpStream::connect(addr))
91 .await
92 .map_err(|e| {
93 Error::new_connection_failed_error(
94 "failed to connect to server within timeout",
95 Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
96 )
97 })?
98 .map_err(|e| {
99 Error::new_connection_failed_error("failed to create tcp stream", Box::new(e))
100 })?;
101
102 let local_addr = tcp_socket
103 .local_addr()
104 .unwrap_or_else(|_e| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
105
106 let peer_addr = tcp_socket
107 .peer_addr()
108 .unwrap_or_else(|_e| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0));
109
110 socket2::SockRef::from(&tcp_socket)
112 .set_tcp_keepalive(&TcpKeepalive::new().with_time(opts.tcp_keep_alive_time))?;
113
114 tcp_socket.set_nodelay(false).map_err(|e| {
115 Error::new_connection_failed_error("failed to set tcp nodelay", Box::new(e))
116 })?;
117
118 Ok((tcp_socket, local_addr, peer_addr))
119 }
120
121 pub async fn connect(addr: Address, opts: ConnectOptions) -> Result<TcpConnection> {
122 let (stream, local_addr, peer_addr) =
123 TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
124
125 Ok(TcpConnection {
126 stream,
127 local_addr,
128 peer_addr,
129 })
130 }
131
132 fn local_addr(&self) -> &SocketAddr {
133 &self.local_addr
134 }
135
136 fn peer_addr(&self) -> &SocketAddr {
137 &self.peer_addr
138 }
139}
140
/// A TLS connection layered over a TCP stream, together with the addresses of
/// the underlying socket.
///
/// The concrete stream type is selected by cargo feature: `native-tls` when it
/// is enabled, otherwise `rustls-tls`.
#[derive(Debug)]
pub struct TlsConnection {
    /// TLS stream provided by `tokio_rustls`.
    #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
    stream: tokio_rustls::client::TlsStream<TcpStream>,
    /// TLS stream provided by `tokio_native_tls`.
    #[cfg(feature = "native-tls")]
    stream: tokio_native_tls::TlsStream<TcpStream>,

    /// Local address of the underlying TCP socket.
    local_addr: SocketAddr,
    /// Peer address of the underlying TCP socket.
    peer_addr: SocketAddr,
}

// Mark whichever TLS stream type is compiled in as usable as a `Stream`.
#[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
impl Stream for tokio_rustls::client::TlsStream<TcpStream> {}

#[cfg(feature = "native-tls")]
impl Stream for tokio_native_tls::TlsStream<TcpStream> {}
157
158impl TlsConnection {
159 #[cfg(all(feature = "rustls-tls", not(feature = "native-tls")))]
160 pub async fn connect(
161 addr: Address,
162 tls_config: TlsConfig,
163 opts: ConnectOptions,
164 ) -> Result<TlsConnection> {
165 let (tcp_socket, local_addr, peer_addr) =
166 TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
167
168 let connector = TlsConnector::from(tls_config);
169
170 let server_name = match DnsName::try_from(addr.host) {
171 Ok(name) => ServerName::DnsName(name),
172 Err(_e) => ServerName::IpAddress(tokio_rustls::rustls::pki_types::IpAddr::from(
173 peer_addr.ip(),
174 )),
175 };
176
177 let stream = timeout_at(opts.deadline, connector.connect(server_name, tcp_socket))
178 .await
179 .map_err(|e| {
180 Error::new_connection_failed_error(
181 "failed to upgrade tcp stream to tls within timeout",
182 Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
183 )
184 })?
185 .map_err(|e| {
186 Error::new_connection_failed_error(
187 "failed to upgrade tcp stream to tls",
188 Box::new(e),
189 )
190 })?;
191
192 Ok(TlsConnection {
193 stream,
194 local_addr,
195 peer_addr,
196 })
197 }
198
199 #[cfg(feature = "native-tls")]
200 pub async fn connect(
201 addr: Address,
202 tls_config: TlsConfig,
203 opts: ConnectOptions,
204 ) -> Result<TlsConnection> {
205 let (tcp_socket, local_addr, peer_addr) =
206 TcpConnection::tcp_stream(addr.to_string().as_str(), &opts).await?;
207
208 let tls_connector = tokio_native_tls::TlsConnector::from(tls_config);
209
210 let remote_addr = addr.to_string();
211 let stream = timeout_at(
212 opts.deadline,
213 tls_connector.connect(&remote_addr, tcp_socket),
214 )
215 .await
216 .map_err(|e| {
217 Error::new_connection_failed_error(
218 "failed to upgrade tcp stream to tls within timeout",
219 Box::new(io::Error::new(io::ErrorKind::TimedOut, e)),
220 )
221 })?
222 .map_err(|e| {
223 Error::new_connection_failed_error(
224 "failed to upgrade tcp stream to tls",
225 Box::new(io::Error::other(e)),
226 )
227 })?;
228
229 Ok(TlsConnection {
230 stream,
231 local_addr,
232 peer_addr,
233 })
234 }
235
236 fn local_addr(&self) -> &SocketAddr {
237 &self.local_addr
238 }
239
240 fn peer_addr(&self) -> &SocketAddr {
241 &self.peer_addr
242 }
243}