clickhouse_native_client/
connection.rs1use crate::{
2 wire_format::WireFormat,
3 Error,
4 Result,
5};
6use bytes::Bytes;
7use std::time::Duration;
8use tokio::{
9 io::{
10 AsyncRead,
11 AsyncReadExt,
12 AsyncWrite,
13 AsyncWriteExt,
14 BufReader,
15 BufWriter,
16 },
17 net::TcpStream,
18};
19
20#[cfg(feature = "tls")]
21use rustls::ServerName;
22#[cfg(feature = "tls")]
23use std::sync::Arc;
24#[cfg(feature = "tls")]
25use tokio_rustls::TlsConnector;
26
/// Capacity, in bytes, of the `BufReader` placed over the read half (8 KiB).
const DEFAULT_READ_BUFFER_SIZE: usize = 8 * 1024;
/// Capacity, in bytes, of the `BufWriter` placed over the write half (8 KiB).
const DEFAULT_WRITE_BUFFER_SIZE: usize = 8 * 1024;
30
/// Socket-level settings used when opening a [`Connection`].
///
/// Start from [`ConnectionOptions::new`] (or `Default::default()`) and adjust
/// fields with the chained setter methods. The type is a plain value, so two
/// option sets can be compared with `==`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectionOptions {
    /// Maximum time to wait for the TCP connect to complete.
    /// `Duration::ZERO` disables the timeout.
    pub connect_timeout: Duration,
    /// Receive timeout.
    /// NOTE(review): not read by `Connection` in this module — presumably
    /// enforced by the caller; verify.
    pub recv_timeout: Duration,
    /// Send timeout.
    /// NOTE(review): not read by `Connection` in this module — presumably
    /// enforced by the caller; verify.
    pub send_timeout: Duration,
    /// Whether to enable TCP keepalive on the socket.
    pub tcp_keepalive: bool,
    /// Idle time before keepalive probes start (used when `tcp_keepalive` is set).
    pub tcp_keepalive_idle: Duration,
    /// Interval between keepalive probes (used when `tcp_keepalive` is set;
    /// only applied on Linux, macOS and Windows).
    pub tcp_keepalive_interval: Duration,
    /// Number of keepalive probes.
    /// NOTE(review): not currently applied by `Connection` in this module.
    pub tcp_keepalive_count: u32,
    /// Whether to set `TCP_NODELAY` on the socket.
    pub tcp_nodelay: bool,
}
51
52impl Default for ConnectionOptions {
53 fn default() -> Self {
54 Self {
55 connect_timeout: Duration::from_secs(5),
56 recv_timeout: Duration::ZERO,
57 send_timeout: Duration::ZERO,
58 tcp_keepalive: false,
59 tcp_keepalive_idle: Duration::from_secs(60),
60 tcp_keepalive_interval: Duration::from_secs(5),
61 tcp_keepalive_count: 3,
62 tcp_nodelay: true,
63 }
64 }
65}
66
67impl ConnectionOptions {
68 pub fn new() -> Self {
70 Self::default()
71 }
72
73 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
75 self.connect_timeout = timeout;
76 self
77 }
78
79 pub fn recv_timeout(mut self, timeout: Duration) -> Self {
81 self.recv_timeout = timeout;
82 self
83 }
84
85 pub fn send_timeout(mut self, timeout: Duration) -> Self {
87 self.send_timeout = timeout;
88 self
89 }
90
91 pub fn tcp_keepalive(mut self, enabled: bool) -> Self {
93 self.tcp_keepalive = enabled;
94 self
95 }
96
97 pub fn tcp_keepalive_idle(mut self, duration: Duration) -> Self {
99 self.tcp_keepalive_idle = duration;
100 self
101 }
102
103 pub fn tcp_keepalive_interval(mut self, duration: Duration) -> Self {
105 self.tcp_keepalive_interval = duration;
106 self
107 }
108
109 pub fn tcp_keepalive_count(mut self, count: u32) -> Self {
111 self.tcp_keepalive_count = count;
112 self
113 }
114
115 pub fn tcp_nodelay(mut self, enabled: bool) -> Self {
117 self.tcp_nodelay = enabled;
118 self
119 }
120}
121
122pub struct Connection {
125 reader: BufReader<Box<dyn AsyncRead + Unpin + Send>>,
126 writer: BufWriter<Box<dyn AsyncWrite + Unpin + Send>>,
127}
128
129impl Connection {
130 pub fn new(stream: TcpStream) -> Self {
132 let (read_half, write_half) = tokio::io::split(stream);
133
134 Self {
135 reader: BufReader::with_capacity(
136 DEFAULT_READ_BUFFER_SIZE,
137 Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
138 ),
139 writer: BufWriter::with_capacity(
140 DEFAULT_WRITE_BUFFER_SIZE,
141 Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
142 ),
143 }
144 }
145
146 #[cfg(feature = "tls")]
148 pub fn new_tls(
149 stream: tokio_rustls::client::TlsStream<TcpStream>,
150 ) -> Self {
151 let (read_half, write_half) = tokio::io::split(stream);
152
153 Self {
154 reader: BufReader::with_capacity(
155 DEFAULT_READ_BUFFER_SIZE,
156 Box::new(read_half) as Box<dyn AsyncRead + Unpin + Send>,
157 ),
158 writer: BufWriter::with_capacity(
159 DEFAULT_WRITE_BUFFER_SIZE,
160 Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
161 ),
162 }
163 }
164
165 pub async fn connect(host: &str, port: u16) -> Result<Self> {
167 Self::connect_with_options(host, port, &ConnectionOptions::default())
168 .await
169 }
170
171 pub async fn connect_with_options(
173 host: &str,
174 port: u16,
175 options: &ConnectionOptions,
176 ) -> Result<Self> {
177 let addr = format!("{}:{}", host, port);
178
179 let stream = if options.connect_timeout > Duration::ZERO {
181 tokio::time::timeout(
182 options.connect_timeout,
183 TcpStream::connect(&addr),
184 )
185 .await
186 .map_err(|_| {
187 Error::Connection(format!(
188 "Connection timeout after {:?} to {}",
189 options.connect_timeout, addr
190 ))
191 })?
192 .map_err(|e| {
193 Error::Connection(format!(
194 "Failed to connect to {}: {}",
195 addr, e
196 ))
197 })?
198 } else {
199 TcpStream::connect(&addr).await.map_err(|e| {
200 Error::Connection(format!(
201 "Failed to connect to {}: {}",
202 addr, e
203 ))
204 })?
205 };
206
207 if options.tcp_nodelay {
209 stream.set_nodelay(true).map_err(|e| {
210 Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
211 })?;
212 }
213
214 #[cfg(unix)]
216 if options.tcp_keepalive {
217 use socket2::{
218 Socket,
219 TcpKeepalive,
220 };
221 use std::os::unix::io::{
222 AsRawFd,
223 FromRawFd,
224 };
225
226 let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
227
228 let mut keepalive =
229 TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
230
231 #[cfg(any(target_os = "linux", target_os = "macos"))]
232 {
233 keepalive =
234 keepalive.with_interval(options.tcp_keepalive_interval);
235 }
236
237 socket.set_tcp_keepalive(&keepalive).map_err(|e| {
242 Error::Connection(format!(
243 "Failed to set TCP keepalive: {}",
244 e
245 ))
246 })?;
247
248 std::mem::forget(socket);
250 }
251
252 #[cfg(windows)]
253 if options.tcp_keepalive {
254 use socket2::{
255 Socket,
256 TcpKeepalive,
257 };
258 use std::os::windows::io::{
259 AsRawSocket,
260 FromRawSocket,
261 };
262
263 let socket =
264 unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
265
266 let keepalive = TcpKeepalive::new()
267 .with_time(options.tcp_keepalive_idle)
268 .with_interval(options.tcp_keepalive_interval);
269
270 socket.set_tcp_keepalive(&keepalive).map_err(|e| {
271 Error::Connection(format!(
272 "Failed to set TCP keepalive: {}",
273 e
274 ))
275 })?;
276
277 std::mem::forget(socket);
279 }
280
281 Ok(Self::new(stream))
282 }
283
284 #[cfg(feature = "tls")]
286 pub async fn connect_with_tls(
287 host: &str,
288 port: u16,
289 options: &ConnectionOptions,
290 ssl_config: Arc<rustls::ClientConfig>,
291 server_name: Option<&str>,
292 ) -> Result<Self> {
293 let addr = format!("{}:{}", host, port);
294
295 let stream = if options.connect_timeout > Duration::ZERO {
297 tokio::time::timeout(
298 options.connect_timeout,
299 TcpStream::connect(&addr),
300 )
301 .await
302 .map_err(|_| {
303 Error::Connection(format!(
304 "Connection timeout after {:?} to {}",
305 options.connect_timeout, addr
306 ))
307 })?
308 .map_err(|e| {
309 Error::Connection(format!(
310 "Failed to connect to {}: {}",
311 addr, e
312 ))
313 })?
314 } else {
315 TcpStream::connect(&addr).await.map_err(|e| {
316 Error::Connection(format!(
317 "Failed to connect to {}: {}",
318 addr, e
319 ))
320 })?
321 };
322
323 if options.tcp_nodelay {
325 stream.set_nodelay(true).map_err(|e| {
326 Error::Connection(format!("Failed to set TCP_NODELAY: {}", e))
327 })?;
328 }
329
330 #[cfg(unix)]
332 if options.tcp_keepalive {
333 use socket2::{
334 Socket,
335 TcpKeepalive,
336 };
337 use std::os::unix::io::{
338 AsRawFd,
339 FromRawFd,
340 };
341
342 let socket = unsafe { Socket::from_raw_fd(stream.as_raw_fd()) };
343
344 let mut keepalive =
345 TcpKeepalive::new().with_time(options.tcp_keepalive_idle);
346
347 #[cfg(any(target_os = "linux", target_os = "macos"))]
348 {
349 keepalive =
350 keepalive.with_interval(options.tcp_keepalive_interval);
351 }
352
353 socket.set_tcp_keepalive(&keepalive).map_err(|e| {
358 Error::Connection(format!(
359 "Failed to set TCP keepalive: {}",
360 e
361 ))
362 })?;
363
364 std::mem::forget(socket);
366 }
367
368 #[cfg(windows)]
369 if options.tcp_keepalive {
370 use socket2::{
371 Socket,
372 TcpKeepalive,
373 };
374 use std::os::windows::io::{
375 AsRawSocket,
376 FromRawSocket,
377 };
378
379 let socket =
380 unsafe { Socket::from_raw_socket(stream.as_raw_socket()) };
381
382 let keepalive = TcpKeepalive::new()
383 .with_time(options.tcp_keepalive_idle)
384 .with_interval(options.tcp_keepalive_interval);
385
386 socket.set_tcp_keepalive(&keepalive).map_err(|e| {
387 Error::Connection(format!(
388 "Failed to set TCP keepalive: {}",
389 e
390 ))
391 })?;
392
393 std::mem::forget(socket);
395 }
396
397 let connector = TlsConnector::from(ssl_config);
399 let server_name_to_use = server_name.unwrap_or(host);
400
401 let domain =
402 ServerName::try_from(server_name_to_use).map_err(|e| {
403 Error::Connection(format!(
404 "Invalid server name '{}': {}",
405 server_name_to_use, e
406 ))
407 })?;
408
409 let tls_stream =
410 connector.connect(domain, stream).await.map_err(|e| {
411 Error::Connection(format!("TLS handshake failed: {}", e))
412 })?;
413
414 Ok(Self::new_tls(tls_stream))
415 }
416
417 pub async fn read_varint(&mut self) -> Result<u64> {
419 WireFormat::read_varint64(&mut self.reader).await
420 }
421
422 pub async fn write_varint(&mut self, value: u64) -> Result<()> {
424 WireFormat::write_varint64(&mut self.writer, value).await
425 }
426
427 pub async fn read_u8(&mut self) -> Result<u8> {
429 Ok(self.reader.read_u8().await?)
430 }
431
432 pub async fn read_u16(&mut self) -> Result<u16> {
434 Ok(self.reader.read_u16_le().await?)
435 }
436
437 pub async fn read_u32(&mut self) -> Result<u32> {
439 Ok(self.reader.read_u32_le().await?)
440 }
441
442 pub async fn read_u64(&mut self) -> Result<u64> {
444 Ok(self.reader.read_u64_le().await?)
445 }
446
447 pub async fn read_i8(&mut self) -> Result<i8> {
449 Ok(self.reader.read_i8().await?)
450 }
451
452 pub async fn read_i16(&mut self) -> Result<i16> {
454 Ok(self.reader.read_i16_le().await?)
455 }
456
457 pub async fn read_i32(&mut self) -> Result<i32> {
459 Ok(self.reader.read_i32_le().await?)
460 }
461
462 pub async fn read_i64(&mut self) -> Result<i64> {
464 Ok(self.reader.read_i64_le().await?)
465 }
466
467 pub async fn write_u8(&mut self, value: u8) -> Result<()> {
469 Ok(self.writer.write_u8(value).await?)
470 }
471
472 pub async fn write_u16(&mut self, value: u16) -> Result<()> {
474 Ok(self.writer.write_u16_le(value).await?)
475 }
476
477 pub async fn write_u32(&mut self, value: u32) -> Result<()> {
479 Ok(self.writer.write_u32_le(value).await?)
480 }
481
482 pub async fn write_u64(&mut self, value: u64) -> Result<()> {
484 Ok(self.writer.write_u64_le(value).await?)
485 }
486
487 pub async fn write_u128(&mut self, value: u128) -> Result<()> {
489 Ok(self.writer.write_u128_le(value).await?)
490 }
491
492 pub async fn write_i8(&mut self, value: i8) -> Result<()> {
494 Ok(self.writer.write_i8(value).await?)
495 }
496
497 pub async fn write_i16(&mut self, value: i16) -> Result<()> {
499 Ok(self.writer.write_i16_le(value).await?)
500 }
501
502 pub async fn write_i32(&mut self, value: i32) -> Result<()> {
504 Ok(self.writer.write_i32_le(value).await?)
505 }
506
507 pub async fn write_i64(&mut self, value: i64) -> Result<()> {
509 Ok(self.writer.write_i64_le(value).await?)
510 }
511
512 pub async fn read_string(&mut self) -> Result<String> {
514 WireFormat::read_string(&mut self.reader).await
515 }
516
517 pub async fn write_string(&mut self, s: &str) -> Result<()> {
519 WireFormat::write_string(&mut self.writer, s).await
520 }
521
522 pub async fn write_quoted_string(&mut self, s: &str) -> Result<()> {
524 WireFormat::write_quoted_string(&mut self.writer, s).await
525 }
526
527 pub async fn read_bytes(&mut self, len: usize) -> Result<Bytes> {
529 let mut buf = vec![0u8; len];
530 self.reader.read_exact(&mut buf).await?;
531 Ok(Bytes::from(buf))
532 }
533
534 pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
536 self.reader.read_exact(buf).await?;
537 Ok(())
538 }
539
540 pub async fn write_bytes(&mut self, data: &[u8]) -> Result<()> {
542 Ok(self.writer.write_all(data).await?)
543 }
544
545 pub async fn flush(&mut self) -> Result<()> {
547 Ok(self.writer.flush().await?)
548 }
549
550 pub async fn read_packet(&mut self) -> Result<Bytes> {
553 let len = self.read_varint().await? as usize;
554
555 if len == 0 {
556 return Ok(Bytes::new());
557 }
558
559 if len > 0x40000000 {
560 return Err(Error::Protocol(format!("Packet too large: {}", len)));
562 }
563
564 self.read_bytes(len).await
565 }
566
567 pub async fn write_packet(&mut self, data: &[u8]) -> Result<()> {
569 self.write_varint(data.len() as u64).await?;
570 self.write_bytes(data).await?;
571 Ok(())
572 }
573}
574
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    /// Both I/O buffers are 8 KiB.
    #[test]
    fn test_buffer_sizes() {
        assert_eq!(DEFAULT_READ_BUFFER_SIZE, 8192);
        assert_eq!(DEFAULT_WRITE_BUFFER_SIZE, 8192);
    }

    /// `Default` and `new()` produce the documented baseline settings.
    #[test]
    fn test_connection_options_defaults() {
        for opts in [ConnectionOptions::default(), ConnectionOptions::new()] {
            assert_eq!(opts.connect_timeout, Duration::from_secs(5));
            assert_eq!(opts.recv_timeout, Duration::ZERO);
            assert_eq!(opts.send_timeout, Duration::ZERO);
            assert!(!opts.tcp_keepalive);
            assert_eq!(opts.tcp_keepalive_idle, Duration::from_secs(60));
            assert_eq!(opts.tcp_keepalive_interval, Duration::from_secs(5));
            assert_eq!(opts.tcp_keepalive_count, 3);
            assert!(opts.tcp_nodelay);
        }
    }

    /// Every chained setter writes exactly its own field.
    #[test]
    fn test_connection_options_builder() {
        let opts = ConnectionOptions::new()
            .connect_timeout(Duration::from_secs(1))
            .recv_timeout(Duration::from_secs(2))
            .send_timeout(Duration::from_secs(3))
            .tcp_keepalive(true)
            .tcp_keepalive_idle(Duration::from_secs(30))
            .tcp_keepalive_interval(Duration::from_secs(10))
            .tcp_keepalive_count(7)
            .tcp_nodelay(false);

        assert_eq!(opts.connect_timeout, Duration::from_secs(1));
        assert_eq!(opts.recv_timeout, Duration::from_secs(2));
        assert_eq!(opts.send_timeout, Duration::from_secs(3));
        assert!(opts.tcp_keepalive);
        assert_eq!(opts.tcp_keepalive_idle, Duration::from_secs(30));
        assert_eq!(opts.tcp_keepalive_interval, Duration::from_secs(10));
        assert_eq!(opts.tcp_keepalive_count, 7);
        assert!(!opts.tcp_nodelay);
    }
}