1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::net::TcpStream;
6use tokio_rustls::TlsConnector;
7use tokio_rustls::client::TlsStream;
8use tokio_rustls::rustls::pki_types::ServerName;
9use tokio_rustls::rustls::{self, ClientConfig, RootCertStore};
10
11use crate::constants::*;
12use crate::prelude::*;
13use crate::{Error, Result};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
17pub struct Destination {
18 inner: DestinationInner,
19}
20
/// Private representation behind [`Destination`].
///
/// The derived ordering follows the declaration order of the variants, so the order
/// below is significant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum DestinationInner {
    /// A list of socket addresses supplied up front.
    SocketAddrs(Vec<SocketAddr>),
    /// A single socket address supplied up front.
    SocketAddr(SocketAddr),
    /// A host string and a port, resolved on demand.
    HostPort(String, u16),
    /// A single endpoint string (e.g. `host:port`), resolved on demand.
    Endpoint(String),
}
28
29impl Destination {
30 pub(crate) async fn resolve(&self, ipv4_only: bool) -> Result<Vec<SocketAddr>> {
32 let addrs: Vec<SocketAddr> = match &self.inner {
33 DestinationInner::SocketAddrs(addrs) => return Ok(addrs.clone()),
34 DestinationInner::SocketAddr(addr) => return Ok(vec![*addr]),
35 DestinationInner::HostPort(host, port) => {
36 tokio::net::lookup_host((host.as_str(), *port)).await.map(Iterator::collect)
37 }
38 DestinationInner::Endpoint(endpoint) => {
39 tokio::net::lookup_host(endpoint).await.map(Iterator::collect)
40 }
41 }
42 .map_err(|_| {
43 Error::MalformedConnectionInformation("Could not resolve destination".into())
44 })?;
45
46 Ok(addrs
47 .into_iter()
48 .filter(|addr| !ipv4_only || matches!(addr, SocketAddr::V4(_)))
49 .collect())
50 }
51
52 pub(crate) fn domain(&self) -> String {
54 match &self.inner {
55 DestinationInner::SocketAddrs(addrs) => {
56 addrs.iter().next().map(|addr| addr.ip().to_string()).unwrap_or_default()
57 }
58 DestinationInner::SocketAddr(addr) => addr.ip().to_string(),
59 DestinationInner::HostPort(host, _) => host.clone(),
60 DestinationInner::Endpoint(endpoint) => {
61 endpoint.split(':').next().map(ToString::to_string).unwrap_or(endpoint.clone())
62 }
63 }
64 }
65}
66
67pub(super) async fn connect_tls(
69 addrs: &[SocketAddr],
70 domain: Option<&str>,
71) -> Result<TlsStream<TcpStream>> {
72 let domain: String =
73 domain.as_ref().map_or_else(|| addrs[0].ip().to_string(), ToString::to_string);
74 debug!(%domain, "Initiating TLS connection");
75 let stream = connect_socket(addrs).await?;
76 tls_stream(domain, stream).await
77}
78
79#[instrument(level = "trace", name = "clickhouse._connect_socket", skip_all)]
81pub(crate) async fn connect_socket(addrs: &[SocketAddr]) -> Result<TcpStream> {
82 debug!(?addrs, "Initiating TCP connection");
83 let addr = addrs.first().ok_or(Error::MissingConnectionInformation)?;
84 let domain = if addr.is_ipv4() { socket2::Domain::IPV4 } else { socket2::Domain::IPV6 };
85 let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
86 socket.set_nonblocking(true)?;
87 socket.set_recv_buffer_size(TCP_READ_BUFFER_SIZE as usize)?;
89 socket.set_send_buffer_size(TCP_WRITE_BUFFER_SIZE as usize)?;
90 let keepalive = socket2::TcpKeepalive::new()
92 .with_time(Duration::from_secs(TCP_KEEP_ALIVE_SECS))
93 .with_interval(Duration::from_secs(TCP_KEEP_ALIVE_INTERVAL))
94 .with_retries(TCP_KEEP_ALIVE_RETRIES);
95 socket.set_tcp_keepalive(&keepalive)?;
96
97 let sock_addr = socket2::SockAddr::from(*addr);
99 socket.connect_timeout(&sock_addr, Duration::from_secs(TCP_CONNECT_TIMEOUT))?;
100 trace!("Connected socket for {addr}");
101
102 let stream = std::net::TcpStream::from(socket);
104 stream.set_nodelay(true)?;
105 stream.set_nonblocking(true)?;
106
107 Ok(TcpStream::from_std(stream)?)
108}
109
110async fn tls_stream(domain: String, stream: TcpStream) -> Result<TlsStream<TcpStream>> {
112 let root_store = RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.into() };
113
114 let mut tls_config =
115 ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();
116
117 tls_config.resumption = rustls::client::Resumption::in_memory_sessions(256);
119
120 let connector = TlsConnector::from(Arc::new(tls_config));
121 let dnsname =
122 ServerName::try_from(domain.clone()).map_err(|e| Error::InvalidDnsName(e.to_string()))?;
123 Ok(connector.connect(dnsname, stream).await?)
124}
125
126impl std::fmt::Display for Destination {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 match &self.inner {
129 DestinationInner::SocketAddrs(addrs) => {
130 write!(f, "{}", addrs.first().map(ToString::to_string).unwrap_or_default())
131 }
132 DestinationInner::SocketAddr(addr) => write!(f, "{addr}"),
133 DestinationInner::HostPort(host, port) => write!(f, "{host}:{port}"),
134 DestinationInner::Endpoint(endpoint) => write!(f, "{endpoint}"),
135 }
136 }
137}
138
139impl From<Vec<SocketAddr>> for Destination {
141 fn from(addrs: Vec<SocketAddr>) -> Self {
142 Destination { inner: DestinationInner::SocketAddrs(addrs) }
143 }
144}
145
146impl From<SocketAddr> for Destination {
148 fn from(addr: SocketAddr) -> Self { Destination { inner: DestinationInner::SocketAddr(addr) } }
149}
150
151impl From<(String, u16)> for Destination {
152 fn from((host, port): (String, u16)) -> Self {
153 Destination { inner: DestinationInner::HostPort(host, port) }
154 }
155}
156
157impl From<&(String, u16)> for Destination {
158 fn from((host, port): &(String, u16)) -> Self {
159 Destination { inner: DestinationInner::HostPort(host.clone(), *port) }
160 }
161}
162
163impl From<(&String, u16)> for Destination {
164 fn from((host, port): (&String, u16)) -> Self {
165 Destination { inner: DestinationInner::HostPort(host.clone(), port) }
166 }
167}
168
169impl From<(&str, u16)> for Destination {
170 fn from((host, port): (&str, u16)) -> Self {
171 Destination { inner: DestinationInner::HostPort(host.to_string(), port) }
172 }
173}
174
175impl From<String> for Destination {
176 fn from(endpoint: String) -> Self {
177 Destination { inner: DestinationInner::Endpoint(endpoint) }
178 }
179}
180
181impl From<&String> for Destination {
182 fn from(endpoint: &String) -> Self {
183 Destination { inner: DestinationInner::Endpoint(endpoint.clone()) }
184 }
185}
186
187impl From<&str> for Destination {
188 fn from(endpoint: &str) -> Self {
189 Destination { inner: DestinationInner::Endpoint(endpoint.to_string()) }
190 }
191}
192
193impl From<std::borrow::Cow<'_, str>> for Destination {
194 fn from(endpoint: std::borrow::Cow<'_, str>) -> Self {
195 Destination { inner: DestinationInner::Endpoint(endpoint.into_owned()) }
196 }
197}
198
199impl From<(Ipv4Addr, u16)> for Destination {
200 fn from((host, port): (Ipv4Addr, u16)) -> Self {
201 Destination { inner: DestinationInner::SocketAddr((host, port).into()) }
202 }
203}
204
205impl From<(Ipv6Addr, u16)> for Destination {
206 fn from((host, port): (Ipv6Addr, u16)) -> Self {
207 Destination { inner: DestinationInner::SocketAddr((host, port).into()) }
208 }
209}
210
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    use super::*;

    /// Loopback address on port 9000, shared by the address-based fixtures.
    fn socket_addr() -> SocketAddr { SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 9000) }

    /// Wraps a raw representation so each test can state only the variant it exercises.
    fn destination(inner: DestinationInner) -> Destination { Destination { inner } }

    /// Asserts that a lookup failed with the resolver's error message.
    fn assert_resolve_failure(outcome: Result<Vec<SocketAddr>>) {
        assert!(matches!(
            outcome,
            Err(Error::MalformedConnectionInformation(msg))
                if msg == "Could not resolve destination"
        ));
    }

    /// Asserts that a lookup produced at least one address and that all use port 9000.
    fn assert_resolved_on_port_9000(resolved: &[SocketAddr]) {
        assert!(!resolved.is_empty());
        assert!(resolved.iter().all(|addr| addr.port() == 9000));
    }

    // --- resolve ---

    #[tokio::test]
    async fn test_resolve_socket_addrs() {
        let expected = vec![socket_addr()];
        let target = destination(DestinationInner::SocketAddrs(expected.clone()));
        let resolved = target.resolve(false).await.unwrap();
        assert_eq!(resolved, expected);
    }

    #[tokio::test]
    async fn test_resolve_socket_addr() {
        let expected = socket_addr();
        let target = destination(DestinationInner::SocketAddr(expected));
        let resolved = target.resolve(false).await.unwrap();
        assert_eq!(resolved, vec![expected]);
    }

    #[tokio::test]
    async fn test_resolve_host_port_valid() {
        let target = destination(DestinationInner::HostPort("localhost".to_string(), 9000));
        let resolved = target.resolve(false).await.unwrap();
        assert_resolved_on_port_9000(&resolved);
    }

    #[tokio::test]
    async fn test_resolve_host_port_invalid() {
        let target = destination(DestinationInner::HostPort("invalid-host-xyz".to_string(), 9000));
        assert_resolve_failure(target.resolve(false).await);
    }

    #[tokio::test]
    async fn test_resolve_endpoint_valid() {
        let target = destination(DestinationInner::Endpoint("localhost:9000".to_string()));
        let resolved = target.resolve(false).await.unwrap();
        assert_resolved_on_port_9000(&resolved);
    }

    #[tokio::test]
    async fn test_resolve_endpoint_invalid() {
        let target = destination(DestinationInner::Endpoint("invalid-host-xyz:9000".to_string()));
        assert_resolve_failure(target.resolve(false).await);
    }

    #[tokio::test]
    async fn test_resolve_ipv4_only() {
        let target = destination(DestinationInner::Endpoint("localhost:9000".to_string()));
        let resolved = target.resolve(true).await.unwrap();
        assert!(!resolved.is_empty());
        assert!(resolved.iter().all(|addr| matches!(addr, SocketAddr::V4(_))));
    }

    // --- domain ---

    #[test]
    fn test_domain_socket_addrs() {
        let target = destination(DestinationInner::SocketAddrs(vec![socket_addr()]));
        assert_eq!(target.domain(), "127.0.0.1");
    }

    #[test]
    fn test_domain_socket_addrs_empty() {
        let target = destination(DestinationInner::SocketAddrs(vec![]));
        assert_eq!(target.domain(), "");
    }

    #[test]
    fn test_domain_socket_addr() {
        let target = destination(DestinationInner::SocketAddr(socket_addr()));
        assert_eq!(target.domain(), "127.0.0.1");
    }

    #[test]
    fn test_domain_host_port() {
        let target = destination(DestinationInner::HostPort("localhost".to_string(), 9000));
        assert_eq!(target.domain(), "localhost");
    }

    #[test]
    fn test_domain_endpoint() {
        let target = destination(DestinationInner::Endpoint("localhost:9000".to_string()));
        assert_eq!(target.domain(), "localhost");
    }

    #[test]
    fn test_domain_endpoint_no_port() {
        let target = destination(DestinationInner::Endpoint("localhost".to_string()));
        assert_eq!(target.domain(), "localhost");
    }
}