Skip to main content

clickhouse_native_client/
connection.rs

1use crate::{
2    wire_format::WireFormat,
3    Error,
4    Result,
5};
6use bytes::Bytes;
7use std::time::Duration;
8use tokio::{
9    io::{
10        AsyncRead,
11        AsyncReadExt,
12        AsyncWrite,
13        AsyncWriteExt,
14        BufReader,
15        BufWriter,
16    },
17    net::TcpStream,
18};
19
20#[cfg(feature = "tls")]
21use rustls::ServerName;
22#[cfg(feature = "tls")]
23use std::sync::Arc;
24#[cfg(feature = "tls")]
25use tokio_rustls::TlsConnector;
26
/// Capacity, in bytes, given to the `BufReader` that wraps the read half.
const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;
/// Capacity, in bytes, given to the `BufWriter` that wraps the write half.
const DEFAULT_WRITE_BUFFER_SIZE: usize = 8 * 1024;
30
/// Timeouts and TCP socket tuning used when opening a connection.
///
/// All fields are public; the builder-style setters on this type are a
/// convenience for chaining.
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// Upper bound on establishing the TCP connection (default: 5 seconds).
    /// `Duration::ZERO` disables the bound.
    pub connect_timeout: Duration,
    /// Receive timeout (0 = no timeout).
    // NOTE(review): nothing in this file reads this field - confirm where
    // it is enforced.
    pub recv_timeout: Duration,
    /// Send timeout (0 = no timeout).
    // NOTE(review): nothing in this file reads this field - confirm where
    // it is enforced.
    pub send_timeout: Duration,
    /// Whether TCP keepalive is configured on the socket.
    pub tcp_keepalive: bool,
    /// Idle time before keepalive probing starts (default: 60 seconds).
    pub tcp_keepalive_idle: Duration,
    /// Delay between keepalive probes (default: 5 seconds).
    pub tcp_keepalive_interval: Duration,
    /// Number of keepalive probes (default: 3).
    pub tcp_keepalive_count: u32,
    /// Whether `TCP_NODELAY` is set, i.e. Nagle's algorithm is disabled.
    pub tcp_nodelay: bool,
}
51
52impl Default for ConnectionOptions {
53    fn default() -> Self {
54        Self {
55            connect_timeout: Duration::from_secs(5),
56            recv_timeout: Duration::ZERO,
57            send_timeout: Duration::ZERO,
58            tcp_keepalive: false,
59            tcp_keepalive_idle: Duration::from_secs(60),
60            tcp_keepalive_interval: Duration::from_secs(5),
61            tcp_keepalive_count: 3,
62            tcp_nodelay: true,
63        }
64    }
65}
66
67impl ConnectionOptions {
68    /// Create new connection options
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    /// Set connection timeout
74    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
75        self.connect_timeout = timeout;
76        self
77    }
78
79    /// Set receive timeout
80    pub fn recv_timeout(mut self, timeout: Duration) -> Self {
81        self.recv_timeout = timeout;
82        self
83    }
84
85    /// Set send timeout
86    pub fn send_timeout(mut self, timeout: Duration) -> Self {
87        self.send_timeout = timeout;
88        self
89    }
90
91    /// Enable TCP keepalive
92    pub fn tcp_keepalive(mut self, enabled: bool) -> Self {
93        self.tcp_keepalive = enabled;
94        self
95    }
96
97    /// Set TCP keepalive idle time
98    pub fn tcp_keepalive_idle(mut self, duration: Duration) -> Self {
99        self.tcp_keepalive_idle = duration;
100        self
101    }
102
103    /// Set TCP keepalive interval
104    pub fn tcp_keepalive_interval(mut self, duration: Duration) -> Self {
105        self.tcp_keepalive_interval = duration;
106        self
107    }
108
109    /// Set TCP keepalive probe count
110    pub fn tcp_keepalive_count(mut self, count: u32) -> Self {
111        self.tcp_keepalive_count = count;
112        self
113    }
114
115    /// Enable/disable TCP_NODELAY
116    pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
117        self.tcp_nodelay = enabled;
118        self
119    }
120}
121
122/// Async connection wrapper for TCP/TLS socket
123/// This is the async I/O boundary - all socket operations are async
124pub struct Connection {
125    reader: BufReader<Box<dyn AsyncRead + Unpin + Send>>,
126    writer: BufWriter<Box<dyn AsyncWrite + Unpin + Send>>,
127}
128
129impl Connection {
130    /// Create a new connection from a TCP stream
131    pub fn new(stream: TcpStream) -> Self {
132        let (read_half, write_half) = tokio::io::split(stream);
133
134        Self {
135            reader: BufReader::with_capacity(
136                DEFAULT_READ_BUFFER_SIZE,
137                Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
138            ),
139            writer: BufWriter::with_capacity(
140                DEFAULT_WRITE_BUFFER_SIZE,
141                Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
142            ),
143        }
144    }
145
146    /// Create a new connection from a TLS stream
147    #[cfg(feature = "tls")]
148    pub fn new_tls(
149        stream: tokio_rustls::client::TlsStream<TcpStream>,
150    ) -> Self {
151        let (read_half, write_half) = tokio::io::split(stream);
152
153        Self {
154            reader: BufReader::with_capacity(
155                DEFAULT_READ_BUFFER_SIZE,
156                Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
157            ),
158            writer: BufWriter::with_capacity(
159                DEFAULT_WRITE_BUFFER_SIZE,
160                Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
161            ),
162        }
163    }
164
165    /// Connect to a ClickHouse server with default options
166    pub async fn connect(host: &str, port: u16) -> Result<Self> {
167        Self::connect_with_options(host, port, &ConnectionOptions::default())
168            .await
169    }
170
171    /// Connect to a ClickHouse server with custom options
172    pub async fn connect_with_options(
173        host: &str,
174        port: u16,
175        options: &ConnectionOptions,
176    ) -> Result<Self> {
177        let addr = format!("{}:{}", host, port);
178
179        // Apply connection timeout
180        let stream = if options.connect_timeout > Duration::ZERO {
181            tokio::time::timeout(
182                options.connect_timeout,
183                TcpStream::connect(&addr),
184            )
185            .await
186            .map_err(|_| {
187                Error::Connection(format!(
188                    "Connection timeout after {:?} to {}",
189                    options.connect_timeout, addr
190                ))
191            })?
192            .map_err(|e| {
193                Error::Connection(format!(
194                    "Failed to connect to {}: {}",
195                    addr, e
196                ))
197            })?
198        } else {
199            TcpStream::connect(&addr).await.map_err(|e| {
200                Error::Connection(format!(
201                    "Failed to connect to {}: {}",
202                    addr, e
203                ))
204            })?
205        };
206
207        // Apply TCP_NODELAY
208        if options.tcp_nodelay {
209            stream.set_nodelay(true).map_err(|e| {
210                Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
211            })?;
212        }
213
214        // Apply TCP keepalive
215        #[cfg(unix)]
216        if options.tcp_keepalive {
217            use socket2::{
218                Socket,
219                TcpKeepalive,
220            };
221            use std::os::unix::io::{
222                AsRawFd,
223                FromRawFd,
224            };
225
226            let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
227
228            let mut keepalive =
229                TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
230
231            #[cfg(any(target_os = "linux", target_os = "macos"))]
232            {
233                keepalive =
234                    keepalive.with_interval(options.tcp_keepalive_interval);
235            }
236
237            // Note: with_retries is not available in socket2 0.5.x
238            // TCP_KEEPCNT can be set via raw socket options if needed
239            // For now, we rely on system defaults for keepalive retry count
240
241            socket.set_tcp_keepalive(&keepalive).map_err(|e| {
242                Error::Connection(format!(
243                    "Failed to set TCP keepalive: {}",
244                    e
245                ))
246            })?;
247
248            // Prevent socket from being dropped
249            std::mem::forget(socket);
250        }
251
252        #[cfg(windows)]
253        if options.tcp_keepalive {
254            use socket2::{
255                Socket,
256                TcpKeepalive,
257            };
258            use std::os::windows::io::{
259                AsRawSocket,
260                FromRawSocket,
261            };
262
263            let socket =
264                unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
265
266            let keepalive = TcpKeepalive::new()
267                .with_time(options.tcp_keepalive_idle)
268                .with_interval(options.tcp_keepalive_interval);
269
270            socket.set_tcp_keepalive(&keepalive).map_err(|e| {
271                Error::Connection(format!(
272                    "Failed to set TCP keepalive: {}",
273                    e
274                ))
275            })?;
276
277            // Prevent socket from being dropped
278            std::mem::forget(socket);
279        }
280
281        Ok(Self::new(stream))
282    }
283
284    /// Connect to a ClickHouse server with TLS
285    #[cfg(feature = "tls")]
286    pub async fn connect_with_tls(
287        host: &str,
288        port: u16,
289        options: &ConnectionOptions,
290        ssl_config: Arc<rustls::ClientConfig>,
291        server_name: Option<&str>,
292    ) -> Result<Self> {
293        let addr = format!("{}:{}", host, port);
294
295        // Establish TCP connection first
296        let stream = if options.connect_timeout > Duration::ZERO {
297            tokio::time::timeout(
298                options.connect_timeout,
299                TcpStream::connect(&addr),
300            )
301            .await
302            .map_err(|_| {
303                Error::Connection(format!(
304                    "Connection timeout after {:?} to {}",
305                    options.connect_timeout, addr
306                ))
307            })?
308            .map_err(|e| {
309                Error::Connection(format!(
310                    "Failed to connect to {}: {}",
311                    addr, e
312                ))
313            })?
314        } else {
315            TcpStream::connect(&addr).await.map_err(|e| {
316                Error::Connection(format!(
317                    "Failed to connect to {}: {}",
318                    addr, e
319                ))
320            })?
321        };
322
323        // Apply TCP_NODELAY
324        if options.tcp_nodelay {
325            stream.set_nodelay(true).map_err(|e| {
326                Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
327            })?;
328        }
329
330        // Apply TCP keepalive (same as non-TLS connection)
331        #[cfg(unix)]
332        if options.tcp_keepalive {
333            use socket2::{
334                Socket,
335                TcpKeepalive,
336            };
337            use std::os::unix::io::{
338                AsRawFd,
339                FromRawFd,
340            };
341
342            let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
343
344            let mut keepalive =
345                TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
346
347            #[cfg(any(target_os = "linux", target_os = "macos"))]
348            {
349                keepalive =
350                    keepalive.with_interval(options.tcp_keepalive_interval);
351            }
352
353            // Note: with_retries is not available in socket2 0.5.x
354            // TCP_KEEPCNT can be set via raw socket options if needed
355            // For now, we rely on system defaults for keepalive retry count
356
357            socket.set_tcp_keepalive(&keepalive).map_err(|e| {
358                Error::Connection(format!(
359                    "Failed to set TCP keepalive: {}",
360                    e
361                ))
362            })?;
363
364            // Prevent socket from being dropped
365            std::mem::forget(socket);
366        }
367
368        #[cfg(windows)]
369        if options.tcp_keepalive {
370            use socket2::{
371                Socket,
372                TcpKeepalive,
373            };
374            use std::os::windows::io::{
375                AsRawSocket,
376                FromRawSocket,
377            };
378
379            let socket =
380                unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
381
382            let keepalive = TcpKeepalive::new()
383                .with_time(options.tcp_keepalive_idle)
384                .with_interval(options.tcp_keepalive_interval);
385
386            socket.set_tcp_keepalive(&keepalive).map_err(|e| {
387                Error::Connection(format!(
388                    "Failed to set TCP keepalive: {}",
389                    e
390                ))
391            })?;
392
393            // Prevent socket from being dropped
394            std::mem::forget(socket);
395        }
396
397        // Perform TLS handshake
398        let connector = TlsConnector::from(ssl_config);
399        let server_name_to_use = server_name.unwrap_or(host);
400
401        let domain =
402            ServerName::try_from(server_name_to_use).map_err(|e| {
403                Error::Connection(format!(
404                    "Invalid server name '{}': {}",
405                    server_name_to_use, e
406                ))
407            })?;
408
409        let tls_stream =
410            connector.connect(domain, stream).await.map_err(|e| {
411                Error::Connection(format!("TLS handshake failed: {}", e))
412            })?;
413
414        Ok(Self::new_tls(tls_stream))
415    }
416
417    /// Read a varint-encoded u64
418    pub async fn read_varint(&mut self) -> Result<u64> {
419        WireFormat::read_varint64(&mut self.reader).await
420    }
421
422    /// Write a varint-encoded u64
423    pub async fn write_varint(&mut self, value: u64) -> Result<()> {
424        WireFormat::write_varint64(&mut self.writer, value).await
425    }
426
427    /// Read a fixed-size value
428    pub async fn read_u8(&mut self) -> Result<u8> {
429        Ok(self.reader.read_u8().await?)
430    }
431
432    /// Read a little-endian u16
433    pub async fn read_u16(&mut self) -> Result<u16> {
434        Ok(self.reader.read_u16_le().await?)
435    }
436
437    /// Read a little-endian u32
438    pub async fn read_u32(&mut self) -> Result<u32> {
439        Ok(self.reader.read_u32_le().await?)
440    }
441
442    /// Read a little-endian u64
443    pub async fn read_u64(&mut self) -> Result<u64> {
444        Ok(self.reader.read_u64_le().await?)
445    }
446
447    /// Read a signed i8
448    pub async fn read_i8(&mut self) -> Result<i8> {
449        Ok(self.reader.read_i8().await?)
450    }
451
452    /// Read a little-endian i16
453    pub async fn read_i16(&mut self) -> Result<i16> {
454        Ok(self.reader.read_i16_le().await?)
455    }
456
457    /// Read a little-endian i32
458    pub async fn read_i32(&mut self) -> Result<i32> {
459        Ok(self.reader.read_i32_le().await?)
460    }
461
462    /// Read a little-endian i64
463    pub async fn read_i64(&mut self) -> Result<i64> {
464        Ok(self.reader.read_i64_le().await?)
465    }
466
467    /// Write fixed-size values
468    pub async fn write_u8(&mut self, value: u8) -> Result<()> {
469        Ok(self.writer.write_u8(value).await?)
470    }
471
472    /// Write a little-endian u16
473    pub async fn write_u16(&mut self, value: u16) -> Result<()> {
474        Ok(self.writer.write_u16_le(value).await?)
475    }
476
477    /// Write a little-endian u32
478    pub async fn write_u32(&mut self, value: u32) -> Result<()> {
479        Ok(self.writer.write_u32_le(value).await?)
480    }
481
482    /// Write a little-endian u64
483    pub async fn write_u64(&mut self, value: u64) -> Result<()> {
484        Ok(self.writer.write_u64_le(value).await?)
485    }
486
487    /// Write a little-endian u128
488    pub async fn write_u128(&mut self, value: u128) -> Result<()> {
489        Ok(self.writer.write_u128_le(value).await?)
490    }
491
492    /// Write a signed i8
493    pub async fn write_i8(&mut self, value: i8) -> Result<()> {
494        Ok(self.writer.write_i8(value).await?)
495    }
496
497    /// Write a little-endian i16
498    pub async fn write_i16(&mut self, value: i16) -> Result<()> {
499        Ok(self.writer.write_i16_le(value).await?)
500    }
501
502    /// Write a little-endian i32
503    pub async fn write_i32(&mut self, value: i32) -> Result<()> {
504        Ok(self.writer.write_i32_le(value).await?)
505    }
506
507    /// Write a little-endian i64
508    pub async fn write_i64(&mut self, value: i64) -> Result<()> {
509        Ok(self.writer.write_i64_le(value).await?)
510    }
511
512    /// Read a length-prefixed string
513    pub async fn read_string(&mut self) -> Result<String> {
514        WireFormat::read_string(&mut self.reader).await
515    }
516
517    /// Write a length-prefixed string
518    pub async fn write_string(&mut self, s: &str) -> Result<()> {
519        WireFormat::write_string(&mut self.writer, s).await
520    }
521
522    /// Write a quoted string for query parameters
523    pub async fn write_quoted_string(&mut self, s: &str) -> Result<()> {
524        WireFormat::write_quoted_string(&mut self.writer, s).await
525    }
526
527    /// Read exact number of bytes into a buffer
528    pub async fn read_bytes(&mut self, len: usize) -> Result<Bytes> {
529        let mut buf = vec![0u8; len];
530        self.reader.read_exact(&mut buf).await?;
531        Ok(Bytes::from(buf))
532    }
533
534    /// Read bytes into an existing buffer
535    pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
536        self.reader.read_exact(buf).await?;
537        Ok(())
538    }
539
540    /// Write bytes
541    pub async fn write_bytes(&mut self, data: &[u8]) -> Result<()> {
542        Ok(self.writer.write_all(data).await?)
543    }
544
545    /// Flush the write buffer
546    pub async fn flush(&mut self) -> Result<()> {
547        Ok(self.writer.flush().await?)
548    }
549
550    /// Read a complete packet (length-prefixed data)
551    /// Returns the packet data without the length prefix
552    pub async fn read_packet(&mut self) -> Result<Bytes> {
553        let len = self.read_varint().await? as usize;
554
555        if len == 0 {
556            return Ok(Bytes::new());
557        }
558
559        if len > 0x40000000 {
560            // 1GB limit
561            return Err(Error::Protocol(format!("Packet too large: {}", len)));
562        }
563
564        self.read_bytes(len).await
565    }
566
567    /// Write a packet with length prefix
568    pub async fn write_packet(&mut self, data: &[u8]) -> Result<()> {
569        self.write_varint(data.len() as u64).await?;
570        self.write_bytes(data).await?;
571        Ok(())
572    }
573}
574
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // Socket I/O cannot be exercised without a running ClickHouse server or
    // a mock, so only constants and basic structure are checked here.
    // Integration tests against a real server belong in the tests/ directory.

    /// Both buffers are expected to default to 8192 bytes.
    #[test]
    fn test_buffer_sizes() {
        let expected: usize = 8192;
        assert_eq!(DEFAULT_READ_BUFFER_SIZE, expected);
        assert_eq!(DEFAULT_WRITE_BUFFER_SIZE, expected);
    }
}