clickhouse_arrow/client/
tcp.rs

1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::net::TcpStream;
6use tokio_rustls::TlsConnector;
7use tokio_rustls::client::TlsStream;
8use tokio_rustls::rustls::pki_types::ServerName;
9use tokio_rustls::rustls::{self, ClientConfig, RootCertStore};
10
11use crate::constants::*;
12use crate::prelude::*;
13use crate::{Error, Result};
14
/// The server a ClickHouse connection should target.
///
/// Built through the `From` impls in this module: pre-resolved socket addresses,
/// `(host, port)` pairs, or endpoint strings such as `"localhost:9000"`. Call `resolve` to turn
/// it into concrete socket addresses and `domain` to obtain the host part.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Destination {
    // Private: code outside this module constructs a `Destination` via its `From` impls.
    inner: DestinationInner,
}
20
/// The different ways a [`Destination`] can be specified.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare variants by
/// declaration order before comparing their payloads.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
enum DestinationInner {
    /// Pre-resolved addresses, returned unchanged by `resolve`. Expected to be non-empty,
    /// but note that `From<Vec<SocketAddr>>` does not enforce this.
    SocketAddrs(Vec<SocketAddr>),
    /// A single pre-resolved address (e.g. `127.0.0.1:9000`).
    SocketAddr(SocketAddr),
    /// Host name and port (e.g. `("localhost", 9000)`), resolved with `tokio::net::lookup_host`.
    HostPort(String, u16),
    /// A `host:port` string (e.g. `"localhost:9000"`), resolved with `tokio::net::lookup_host`.
    Endpoint(String),
}
28
29impl Destination {
30    /// Resolve to Vec<SocketAddr> using [`tokio::net::lookup_host`]
31    pub(crate) async fn resolve(&self, ipv4_only: bool) -> Result<Vec<SocketAddr>> {
32        let addrs: Vec<SocketAddr> = match &self.inner {
33            DestinationInner::SocketAddrs(addrs) => return Ok(addrs.clone()),
34            DestinationInner::SocketAddr(addr) => return Ok(vec![*addr]),
35            DestinationInner::HostPort(host, port) => {
36                tokio::net::lookup_host((host.as_str(), *port)).await.map(Iterator::collect)
37            }
38            DestinationInner::Endpoint(endpoint) => {
39                tokio::net::lookup_host(endpoint).await.map(Iterator::collect)
40            }
41        }
42        .map_err(|_| {
43            Error::MalformedConnectionInformation("Could not resolve destination".into())
44        })?;
45
46        Ok(addrs
47            .into_iter()
48            .filter(|addr| !ipv4_only || matches!(addr, SocketAddr::V4(_)))
49            .collect())
50    }
51
52    // Create a domain from this Destination
53    pub(crate) fn domain(&self) -> String {
54        match &self.inner {
55            DestinationInner::SocketAddrs(addrs) => {
56                addrs.iter().next().map(|addr| addr.ip().to_string()).unwrap_or_default()
57            }
58            DestinationInner::SocketAddr(addr) => addr.ip().to_string(),
59            DestinationInner::HostPort(host, _) => host.to_string(),
60            DestinationInner::Endpoint(endpoint) => {
61                endpoint.split(':').next().map(ToString::to_string).unwrap_or(endpoint.to_string())
62            }
63        }
64    }
65}
66
67/// Connects to `ClickHouse`'s native server port over TLS.
68pub(super) async fn connect_tls(
69    addrs: &[SocketAddr],
70    domain: Option<&str>,
71) -> Result<TlsStream<TcpStream>> {
72    let domain: String =
73        domain.as_ref().map_or_else(|| addrs[0].ip().to_string(), ToString::to_string);
74    debug!(%domain, "Initiating TLS connection");
75    let stream = connect_socket(addrs).await?;
76    tls_stream(domain, stream).await
77}
78
79/// Connects to `ClickHouse`'s native server port and configures common socket options.
80#[instrument(level = "trace", name = "clickhouse._connect_socket", skip_all)]
81pub(crate) async fn connect_socket(addrs: &[SocketAddr]) -> Result<TcpStream> {
82    debug!(?addrs, "Initiating TCP connection");
83    let addr = addrs.first().ok_or(Error::MissingConnectionInformation)?;
84    let domain = if addr.is_ipv4() { socket2::Domain::IPV4 } else { socket2::Domain::IPV6 };
85    let socket = socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
86    socket.set_nonblocking(true)?;
87    // Increase buffer sizes for high-throughput data transfer
88    socket.set_recv_buffer_size(TCP_READ_BUFFER_SIZE as usize)?;
89    socket.set_send_buffer_size(TCP_WRITE_BUFFER_SIZE as usize)?;
90    // Configure TCP keepalive
91    let keepalive = socket2::TcpKeepalive::new()
92        .with_time(Duration::from_secs(TCP_KEEP_ALIVE_SECS))
93        .with_interval(Duration::from_secs(TCP_KEEP_ALIVE_INTERVAL))
94        .with_retries(TCP_KEEP_ALIVE_RETRIES);
95    socket.set_tcp_keepalive(&keepalive)?;
96
97    // Connect with a timeout
98    let sock_addr = socket2::SockAddr::from(*addr);
99    socket.connect_timeout(&sock_addr, Duration::from_secs(TCP_CONNECT_TIMEOUT))?;
100    trace!("Connected socket for {addr}");
101
102    // Convert to TcpStream
103    let stream = std::net::TcpStream::from(socket);
104    stream.set_nodelay(true)?;
105    stream.set_nonblocking(true)?;
106
107    Ok(TcpStream::from_std(stream)?)
108}
109
110// Helper function to facilitate TLS connection setup
111async fn tls_stream(domain: String, stream: TcpStream) -> Result<TlsStream<TcpStream>> {
112    let root_store = RootCertStore { roots: webpki_roots::TLS_SERVER_ROOTS.into() };
113
114    let mut tls_config =
115        ClientConfig::builder().with_root_certificates(root_store).with_no_client_auth();
116
117    // Enable session resumption by default
118    tls_config.resumption = rustls::client::Resumption::in_memory_sessions(256);
119
120    let connector = TlsConnector::from(Arc::new(tls_config));
121    let dnsname =
122        ServerName::try_from(domain.clone()).map_err(|e| Error::InvalidDnsName(e.to_string()))?;
123    Ok(connector.connect(dnsname, stream).await?)
124}
125
126impl std::fmt::Display for Destination {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        match &self.inner {
129            DestinationInner::SocketAddrs(addrs) => {
130                write!(f, "{}", addrs.first().map(ToString::to_string).unwrap_or_default())
131            }
132            DestinationInner::SocketAddr(addr) => write!(f, "{addr}"),
133            DestinationInner::HostPort(host, port) => write!(f, "{host}:{port}"),
134            DestinationInner::Endpoint(endpoint) => write!(f, "{endpoint}"),
135        }
136    }
137}
138
139// From implementations for common destination types
140impl From<Vec<SocketAddr>> for Destination {
141    fn from(addrs: Vec<SocketAddr>) -> Self {
142        Destination { inner: DestinationInner::SocketAddrs(addrs) }
143    }
144}
145
146// From implementations for common destination types
147impl From<SocketAddr> for Destination {
148    fn from(addr: SocketAddr) -> Self { Destination { inner: DestinationInner::SocketAddr(addr) } }
149}
150
151impl From<(String, u16)> for Destination {
152    fn from((host, port): (String, u16)) -> Self {
153        Destination { inner: DestinationInner::HostPort(host, port) }
154    }
155}
156
157impl From<&(String, u16)> for Destination {
158    fn from((host, port): &(String, u16)) -> Self {
159        Destination { inner: DestinationInner::HostPort(host.to_string(), *port) }
160    }
161}
162
163impl From<(&String, u16)> for Destination {
164    fn from((host, port): (&String, u16)) -> Self {
165        Destination { inner: DestinationInner::HostPort(host.to_string(), port) }
166    }
167}
168
169impl From<(&str, u16)> for Destination {
170    fn from((host, port): (&str, u16)) -> Self {
171        Destination { inner: DestinationInner::HostPort(host.to_string(), port) }
172    }
173}
174
175impl From<String> for Destination {
176    fn from(endpoint: String) -> Self {
177        Destination { inner: DestinationInner::Endpoint(endpoint) }
178    }
179}
180
181impl From<&String> for Destination {
182    fn from(endpoint: &String) -> Self {
183        Destination { inner: DestinationInner::Endpoint(endpoint.to_string()) }
184    }
185}
186
187impl From<&str> for Destination {
188    fn from(endpoint: &str) -> Self {
189        Destination { inner: DestinationInner::Endpoint(endpoint.to_string()) }
190    }
191}
192
193impl From<std::borrow::Cow<'_, str>> for Destination {
194    fn from(endpoint: std::borrow::Cow<'_, str>) -> Self {
195        Destination { inner: DestinationInner::Endpoint(endpoint.into_owned()) }
196    }
197}
198
199impl From<(Ipv4Addr, u16)> for Destination {
200    fn from((host, port): (Ipv4Addr, u16)) -> Self {
201        Destination { inner: DestinationInner::SocketAddr((host, port).into()) }
202    }
203}
204
205impl From<(Ipv6Addr, u16)> for Destination {
206    fn from((host, port): (Ipv6Addr, u16)) -> Self {
207        Destination { inner: DestinationInner::SocketAddr((host, port).into()) }
208    }
209}
210
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    use super::*;

    /// A name under the `.invalid` TLD, which RFC 6761 reserves as never resolvable. Made-up
    /// names in real namespaces can be "resolved" by resolvers that rewrite NXDOMAIN answers.
    const UNRESOLVABLE_HOST: &str = "nonexistent-host.invalid";

    // Helper to create Destination variants
    fn socket_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9000)
    }

    #[tokio::test]
    async fn test_resolve_socket_addrs() {
        let addrs = vec![socket_addr()];
        let dest = Destination { inner: DestinationInner::SocketAddrs(addrs.clone()) };
        let result = dest.resolve(false).await.unwrap();
        assert_eq!(result, addrs);
    }

    #[tokio::test]
    async fn test_resolve_socket_addr() {
        let addr = socket_addr();
        let dest = Destination { inner: DestinationInner::SocketAddr(addr) };
        let result = dest.resolve(false).await.unwrap();
        assert_eq!(result, vec![addr]);
    }

    #[tokio::test]
    async fn test_resolve_host_port_valid() {
        let dest = Destination { inner: DestinationInner::HostPort("localhost".to_string(), 9000) };
        let result = dest.resolve(false).await.unwrap();
        assert!(!result.is_empty());
        assert!(result.iter().all(|addr| addr.port() == 9000));
    }

    #[tokio::test]
    async fn test_resolve_host_port_invalid() {
        let dest =
            Destination { inner: DestinationInner::HostPort(UNRESOLVABLE_HOST.to_string(), 9000) };
        let result = dest.resolve(false).await;
        assert!(matches!(
            result,
            Err(Error::MalformedConnectionInformation(msg))
            if msg == "Could not resolve destination"
        ));
    }

    #[tokio::test]
    async fn test_resolve_endpoint_valid() {
        let dest = Destination { inner: DestinationInner::Endpoint("localhost:9000".to_string()) };
        let result = dest.resolve(false).await.unwrap();
        assert!(!result.is_empty());
        assert!(result.iter().all(|addr| addr.port() == 9000));
    }

    #[tokio::test]
    async fn test_resolve_endpoint_invalid() {
        let dest =
            Destination { inner: DestinationInner::Endpoint(format!("{UNRESOLVABLE_HOST}:9000")) };
        let result = dest.resolve(false).await;
        assert!(matches!(
            result,
            Err(Error::MalformedConnectionInformation(msg))
            if msg == "Could not resolve destination"
        ));
    }

    #[tokio::test]
    async fn test_resolve_ipv4_only() {
        let dest = Destination { inner: DestinationInner::Endpoint("localhost:9000".to_string()) };
        let result = dest.resolve(true).await.unwrap();
        assert!(!result.is_empty());
        assert!(result.iter().all(|addr| matches!(addr, SocketAddr::V4(_))));
    }

    #[test]
    fn test_domain_socket_addrs() {
        let addrs = vec![socket_addr()];
        let dest = Destination { inner: DestinationInner::SocketAddrs(addrs) };
        assert_eq!(dest.domain(), "127.0.0.1");
    }

    #[test]
    fn test_domain_socket_addrs_empty() {
        let dest = Destination { inner: DestinationInner::SocketAddrs(vec![]) };
        assert_eq!(dest.domain(), "");
    }

    #[test]
    fn test_domain_socket_addr() {
        let addr = socket_addr();
        let dest = Destination { inner: DestinationInner::SocketAddr(addr) };
        assert_eq!(dest.domain(), "127.0.0.1");
    }

    #[test]
    fn test_domain_host_port() {
        let dest = Destination { inner: DestinationInner::HostPort("localhost".to_string(), 9000) };
        assert_eq!(dest.domain(), "localhost");
    }

    #[test]
    fn test_domain_endpoint() {
        let dest = Destination { inner: DestinationInner::Endpoint("localhost:9000".to_string()) };
        assert_eq!(dest.domain(), "localhost");
    }

    #[test]
    fn test_domain_endpoint_no_port() {
        let dest = Destination { inner: DestinationInner::Endpoint("localhost".to_string()) };
        assert_eq!(dest.domain(), "localhost");
    }

    #[test]
    fn test_display_variants() {
        let addr = socket_addr();
        assert_eq!(Destination::from(vec![addr]).to_string(), "127.0.0.1:9000");
        assert_eq!(Destination::from(Vec::<SocketAddr>::new()).to_string(), "");
        assert_eq!(Destination::from(addr).to_string(), "127.0.0.1:9000");
        assert_eq!(Destination::from(("localhost", 9000_u16)).to_string(), "localhost:9000");
        assert_eq!(Destination::from("localhost:9000").to_string(), "localhost:9000");
    }

    #[test]
    fn test_from_conversions() {
        let host_port =
            Destination { inner: DestinationInner::HostPort("localhost".to_string(), 9000) };
        let host = String::from("localhost");
        assert_eq!(Destination::from((host.clone(), 9000_u16)), host_port);
        assert_eq!(Destination::from(&(host.clone(), 9000_u16)), host_port);
        assert_eq!(Destination::from((&host, 9000_u16)), host_port);
        assert_eq!(Destination::from(("localhost", 9000_u16)), host_port);

        let endpoint =
            Destination { inner: DestinationInner::Endpoint("localhost:9000".to_string()) };
        let text = String::from("localhost:9000");
        assert_eq!(Destination::from(text.clone()), endpoint);
        assert_eq!(Destination::from(&text), endpoint);
        assert_eq!(Destination::from("localhost:9000"), endpoint);
        assert_eq!(Destination::from(std::borrow::Cow::Borrowed("localhost:9000")), endpoint);

        assert_eq!(
            Destination::from((Ipv4Addr::LOCALHOST, 9000_u16)),
            Destination { inner: DestinationInner::SocketAddr(socket_addr()) }
        );
        let v6: SocketAddr = "[::1]:9000".parse().unwrap();
        assert_eq!(
            Destination::from((Ipv6Addr::LOCALHOST, 9000_u16)),
            Destination { inner: DestinationInner::SocketAddr(v6) }
        );
    }
}