use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};

#[cfg(feature = "inner_pool")]
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use strum::Display;
use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::{broadcast, mpsc};
use tokio::task::{AbortHandle, JoinSet};
use tokio_rustls::rustls;

use super::internal::{InternalConn, PendingQuery};
use super::{ArrowOptions, CompressionMethod, Event};
use crate::client::chunk::{ChunkReader, ChunkWriter};
use crate::flags::{conn_read_buffer_size, conn_write_buffer_size};
use crate::io::{ClickHouseRead, ClickHouseWrite};
use crate::native::protocol::{
    ClientHello, DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM, DBMS_TCP_PROTOCOL_VERSION, ServerHello,
};
use crate::prelude::*;
use crate::{ClientOptions, Message, Operation};

type IoHandle<T> = JoinSet<VecDeque<PendingQuery<T>>>;

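/// Lifecycle state of a connection, stored internally as an `AtomicU8`
/// (`Open` = 0, `Closed` = 1, `Error` = 2); see the `From` impls below.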
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum ConnectionStatus {
    Open,
    Closed,
    Error,
}

impl From<u8> for ConnectionStatus {
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Open,
            1 => Self::Closed,
            _ => Self::Error,
        }
    }
}

impl From<ConnectionStatus> for u8 {
    fn from(value: ConnectionStatus) -> u8 { value as u8 }
}

#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientMetadata {
    pub(crate) client_id: u16,
    pub(crate) compression: CompressionMethod,
    pub(crate) arrow_options: ArrowOptions,
}

impl ClientMetadata {
    pub(crate) fn disable_compression(self) -> Self {
        Self {
            client_id: self.client_id,
            compression: CompressionMethod::None,
            arrow_options: self.arrow_options,
        }
    }

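    /// Server settings implied by the negotiated compression method. Only
    /// ZSTD needs explicit settings here; `None` and `LZ4` fall back to the
    /// defaults (LZ4 framing appears to be handled by the protocol itself).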
    pub(crate) fn compression_settings(self) -> Settings {
        match self.compression {
            CompressionMethod::None | CompressionMethod::LZ4 => Settings::default(),
            CompressionMethod::ZSTD => vec![
                ("network_compression_method", "zstd"),
                ("network_zstd_compression_level", "1"),
            ]
            .into(),
        }
    }
}

#[derive(Debug)]
struct ConnectState<T: Send + Sync + 'static> {
    status: Arc<AtomicU8>,
    channel: mpsc::Sender<Message<T>>,
    #[expect(unused)]
    handle: AbortHandle,
}

#[derive(Debug)]
pub(super) struct Connection<T: ClientFormat> {
    #[expect(unused)]
    addrs: Arc<[SocketAddr]>,
    options: Arc<ClientOptions>,
    io_task: Arc<Mutex<IoHandle<T::Data>>>,
    metadata: ClientMetadata,
    #[cfg(not(feature = "inner_pool"))]
    state: Arc<ConnectState<T::Data>>,
    #[cfg(feature = "inner_pool")]
    state: Vec<ArcSwap<ConnectState<T::Data>>>,
    #[cfg(feature = "inner_pool")]
    load_balancer: Arc<load::AtomicLoad>,
}

impl<T: ClientFormat> Connection<T> {
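    /// Opens the socket (TCP or TLS), performs the ClickHouse handshake, and
    /// spawns the background I/O loop that owns the stream.
    ///
    /// A minimal usage sketch (hypothetical caller; the event channel, trace
    /// context, and options normally come from the surrounding client, and
    /// `ClientOptions::default()` is assumed here purely for brevity):
    ///
    /// ```ignore
    /// let (events, _rx) = tokio::sync::broadcast::channel(64);
    /// let conn = Connection::<T>::connect(
    ///     0,                                       // client_id
    ///     vec!["127.0.0.1:9000".parse().unwrap()], // ClickHouse native port
    ///     ClientOptions::default(),
    ///     Arc::new(events),
    ///     TraceContext::default(),
    /// )
    /// .await?;
    /// conn.check_connection(true).await?; // optional ping round trip
    /// ```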
    #[instrument(
        level = "trace",
        name = "clickhouse.connection.create",
        skip_all,
        fields(
            clickhouse.client.id = client_id,
            db.system = "clickhouse",
            db.operation = "connect",
            network.transport = ?if options.use_tls { "tls" } else { "tcp" }
        ),
        err
    )]
    pub(crate) async fn connect(
        client_id: u16,
        addrs: Vec<SocketAddr>,
        options: ClientOptions,
        events: Arc<broadcast::Sender<Event>>,
        trace_ctx: TraceContext,
    ) -> Result<Self> {
        let span = Span::current();
        span.in_scope(|| trace!({ ATT_CID } = client_id, "connecting stream"));
        let _ = trace_ctx.link(&span);

        let mut io_task = JoinSet::new();

        let metadata = ClientMetadata {
            client_id,
            compression: options.compression,
            arrow_options: options.ext.arrow.unwrap_or_default(),
        };

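        // `install_default` fails if a process-wide crypto provider is already
        // installed (e.g. by another client); that is harmless here, so the
        // result is deliberately discarded.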
        if options.use_tls {
            drop(rustls::crypto::aws_lc_rs::default_provider().install_default());
        }

        let state = Arc::new(
            Self::connect_inner(&addrs, &mut io_task, Arc::clone(&events), &options, metadata)
                .await?,
        );

        #[cfg(feature = "inner_pool")]
        let mut state = vec![ArcSwap::from(state)];

        #[cfg(feature = "inner_pool")]
        let inner_pool_size = options
            .ext
            .fast_mode_size
            .map_or(load::DEFAULT_MAX_CONNECTIONS, |s| s.clamp(2, load::ABSOLUTE_MAX_CONNECTIONS));

        #[cfg(feature = "inner_pool")]
        for _ in 0..inner_pool_size.saturating_sub(1) {
            let events = Arc::clone(&events);
            state.push(ArcSwap::from(Arc::new(
                Self::connect_inner(&addrs, &mut io_task, events, &options, metadata).await?,
            )));
        }

        Ok(Self {
            addrs: Arc::from(addrs.as_slice()),
            io_task: Arc::new(Mutex::new(io_task)),
            options: Arc::new(options),
            metadata,
            state,
            #[cfg(feature = "inner_pool")]
            load_balancer: Arc::new(load::AtomicLoad::new(inner_pool_size)),
        })
    }

    async fn connect_inner(
        addrs: &[SocketAddr],
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        if options.use_tls {
            let tls_stream = super::tcp::connect_tls(addrs, options.domain.as_deref()).await?;
            Self::establish_connection(tls_stream, io_task, events, options, metadata).await
        } else {
            let tcp_stream = super::tcp::connect_socket(addrs).await?;
            Self::establish_connection(tcp_stream, io_task, events, options, metadata).await
        }
    }

    async fn establish_connection<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        mut stream: RW,
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        let cid = metadata.client_id;

        let status = Arc::new(AtomicU8::new(ConnectionStatus::Open.into()));
        let internal_status = Arc::clone(&status);

        let server_hello = Arc::new(Self::perform_handshake(&mut stream, cid, options).await?);

        let (operations, op_rx) = mpsc::channel(InternalConn::<T>::CAPACITY);

        let (reader, writer) = tokio::io::split(stream);

        let handle = io_task.spawn(
            async move {
                let chunk_send = server_hello.supports_chunked_send();
                let chunk_recv = server_hello.supports_chunked_recv();

                let mut internal = InternalConn::<T>::new(metadata, events, server_hello);

                let reader = BufReader::with_capacity(conn_read_buffer_size(), reader);
                let writer = BufWriter::with_capacity(conn_write_buffer_size(), writer);

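                // The server hello negotiates chunked framing independently
                // per direction, so all four send/recv combinations exist.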
                let result = match (chunk_send, chunk_recv) {
                    (true, true) => {
                        let reader = ChunkReader::new(reader);
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (true, false) => {
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (false, true) => {
                        let reader = ChunkReader::new(reader);
                        internal.run(reader, writer, op_rx).await
                    }
                    (false, false) => internal.run(reader, writer, op_rx).await,
                };

                if let Err(error) = result {
                    error!(?error, "Internal connection lost");
                    internal_status.store(ConnectionStatus::Error.into(), Ordering::Release);
                } else {
                    info!("Internal connection closed");
                    internal_status.store(ConnectionStatus::Closed.into(), Ordering::Release);
                }
                trace!("Exiting inner connection");
                VecDeque::new()
            }
            .instrument(trace_span!(
                "clickhouse.connection.io",
                { ATT_CID } = cid,
                otel.kind = "server",
                peer.service = "clickhouse",
            )),
        );

        trace!({ ATT_CID } = cid, "spawned connection loop");
        Ok(ConnectState { status, channel: operations, handle })
    }

    #[instrument(
        level = "trace",
        skip_all,
        fields(
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            clickhouse.client.id = self.metadata.client_id,
            clickhouse.query.id = %qid,
        )
    )]
    pub(crate) async fn send_operation(
        &self,
        op: Operation<T::Data>,
        qid: Qid,
        finished: bool,
    ) -> Result<usize> {
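        // With the "inner_pool" feature, multi-packet operations (unfinished
        // queries and inserts) are keyed by query id so that every packet
        // lands on the same underlying connection; one-shot operations are
        // routed to the least-loaded connection instead.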
        #[cfg(not(feature = "inner_pool"))]
        let conn_idx = 0;

        #[cfg(feature = "inner_pool")]
        let conn_idx = {
            let key = (matches!(op, Operation::Query { .. } if !finished)
                || matches!(op, Operation::Insert { .. } | Operation::InsertMany { .. }))
            .then(|| qid.key());
            self.load_balancer.assign(key, op.weight(finished) as usize)
        };

        let span = trace_span!(
            "clickhouse.connection.send_operation",
            { ATT_CID } = self.metadata.client_id,
            { ATT_QID } = %qid,
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            finished
        );

        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[conn_idx].load();

        #[cfg(not(feature = "inner_pool"))]
        let status = self.state.status.load(Ordering::Acquire);
        #[cfg(feature = "inner_pool")]
        let status = state.status.load(Ordering::Acquire);

        // Any non-zero status byte means `Closed` or `Error`.
        if status > 0 {
            return Err(Error::Client("No active connection".into()));
        }

        let result = state.channel.send(Message::Operation { qid, op }).instrument(span).await;
        if result.is_err() {
            error!({ ATT_QID } = %qid, "failed to send message");
            self.update_status(conn_idx, ConnectionStatus::Closed);
            return Err(Error::ChannelClosed);
        }

        Ok(conn_idx)
    }

    #[instrument(
        level = "trace",
        skip_all,
        fields(db.system = "clickhouse", clickhouse.client.id = self.metadata.client_id)
    )]
    pub(crate) async fn shutdown(&self) -> Result<()> {
        trace!({ ATT_CID } = self.metadata.client_id, "Shutting down connections");
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.send(Message::Shutdown).await.is_err() {
                error!("Failed to shutdown connection");
            }
        }
        #[cfg(feature = "inner_pool")]
        {
            for (i, conn_state) in self.state.iter().enumerate() {
                let state = conn_state.load();
                debug!("Shutting down connection {i}");
                if state.channel.send(Message::Shutdown).await.is_err() {
                    error!("Failed to shutdown connection {i}");
                }
            }
        }
        self.io_task.lock().abort_all();
        Ok(())
    }

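    /// Verifies that the operation channel is still open and, when `ping` is
    /// true, round-trips a `Ping` through the server to confirm liveness.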
    pub(crate) async fn check_connection(&self, ping: bool) -> Result<()> {
        self.check_channel()?;

        if !ping {
            return Ok(());
        }

        let (response, rx) = tokio::sync::oneshot::channel();
        let cid = self.metadata.client_id;
        let qid = Qid::default();
        let idx = self
            .send_operation(Operation::Ping { response }, qid, true)
            .instrument(trace_span!(
                "clickhouse.connection.ping",
                { ATT_CID } = cid,
                { ATT_QID } = %qid,
                db.system = "clickhouse",
            ))
            .await?;

        rx.await
            .map_err(|_| {
                self.update_status(idx, ConnectionStatus::Closed);
                Error::ChannelClosed
            })?
            .inspect_err(|error| {
                self.update_status(idx, ConnectionStatus::Error);
                error!(?error, { ATT_CID } = cid, "Ping failed");
            })?;

        Ok(())
    }

    fn update_status(&self, idx: usize, status: ConnectionStatus) {
        trace!({ ATT_CID } = self.metadata.client_id, ?status, "Updating status conn {idx}");

        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[idx].load();

        state.status.store(status.into(), Ordering::Release);
    }

    async fn perform_handshake<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        stream: &mut RW,
        client_id: u16,
        options: &ClientOptions,
    ) -> Result<ServerHello> {
        use crate::client::reader::Reader;
        use crate::client::writer::Writer;

        let client_hello = ClientHello {
            default_database: options.default_database.clone(),
            username: options.username.clone(),
            password: options.password.get().to_string(),
        };

        Writer::send_hello(stream, client_hello)
            .await
            .inspect_err(|error| error!(?error, { ATT_CID } = client_id, "Failed to send hello"))?;

        let chunked_modes = (options.ext.chunked_send, options.ext.chunked_recv);
        let server_hello =
            Reader::receive_hello(stream, DBMS_TCP_PROTOCOL_VERSION, chunked_modes, client_id)
                .await?;
        trace!({ ATT_CID } = client_id, ?server_hello, "Finished handshake");

        // Newer servers expect an addendum (e.g. a quota key) after the hello
        // exchange; flush immediately so the server can proceed.
        if server_hello.revision_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
            Writer::send_addendum(stream, server_hello.revision_version, &server_hello).await?;
            stream.flush().await.inspect_err(|error| error!(?error, "Error writing addendum"))?;
        }

        Ok(server_hello)
    }
}

impl<T: ClientFormat> Connection<T> {
    pub(crate) fn metadata(&self) -> ClientMetadata { self.metadata }

    pub(crate) fn database(&self) -> &str { &self.options.default_database }

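    /// Releases the load-balancer weight reserved by `send_operation` once the
    /// operation on `conn_idx` completes.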
    #[cfg(feature = "inner_pool")]
    pub(crate) fn finish(&self, conn_idx: usize, weight: u8) {
        self.load_balancer.finish(usize::from(weight), conn_idx);
    }

    pub(crate) fn status(&self) -> ConnectionStatus {
        #[cfg(not(feature = "inner_pool"))]
        let status = ConnectionStatus::from(self.state.status.load(Ordering::Acquire));

        #[cfg(feature = "inner_pool")]
        let status = ConnectionStatus::from(self.state[0].load().status.load(Ordering::Acquire));

        status
    }

    fn check_channel(&self) -> Result<()> {
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.is_closed() {
                self.update_status(0, ConnectionStatus::Closed);
                Err(Error::ChannelClosed)
            } else {
                Ok(())
            }
        }

        #[cfg(feature = "inner_pool")]
        if self.state[0].load().channel.is_closed() {
            self.update_status(0, ConnectionStatus::Closed);
            Err(Error::ChannelClosed)
        } else {
            Ok(())
        }
    }
}

impl<T: ClientFormat> Drop for Connection<T> {
    fn drop(&mut self) {
        trace!({ ATT_CID } = self.metadata.client_id, "Connection dropped");
        self.io_task.lock().abort_all();
    }
}

#[cfg(feature = "inner_pool")]
mod load {
    use std::sync::atomic::{AtomicUsize, Ordering};

    pub(super) const DEFAULT_MAX_CONNECTIONS: u8 = 4;
    pub(super) const ABSOLUTE_MAX_CONNECTIONS: u8 = 16;

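    /// Weighted least-loaded balancer over the inner connection pool: each
    /// connection has an atomic counter of in-flight weight that `assign`
    /// bumps and `finish` releases.
    ///
    /// A sketch of the intended round trip (hypothetical weight values):
    ///
    /// ```ignore
    /// let load = AtomicLoad::new(4);
    /// let idx = load.assign(None, 2); // least-loaded pick, reserve weight 2
    /// // ... run the operation on connection `idx` ...
    /// load.finish(2, idx);            // release the reserved weight
    /// ```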
    #[derive(Debug)]
    pub(super) struct AtomicLoad {
        load_counters: Box<[AtomicUsize]>,
        max_connections: u8,
    }

    impl AtomicLoad {
        pub(super) fn new(max_connections: u8) -> Self {
            assert!(max_connections > 0, "At least 1 connection required");
            assert!(
                max_connections <= ABSOLUTE_MAX_CONNECTIONS,
                "Max {ABSOLUTE_MAX_CONNECTIONS} connections supported"
            );

            let load_counters = (0..max_connections)
                .map(|_| AtomicUsize::new(0))
                .collect::<Vec<_>>()
                .into_boxed_slice();

            Self { load_counters, max_connections }
        }

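        /// Picks a connection index: `Some(key)` maps deterministically via
        /// `key % max_connections` so related packets share a connection,
        /// while `None` selects the connection with the smallest in-flight
        /// counter. `weight` is added to the winner and must be returned
        /// through `finish`.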
        pub(super) fn assign(&self, key: Option<usize>, weight: usize) -> usize {
            let idx = if let Some(k) = key {
                k % usize::from(self.max_connections)
            } else {
                (0..self.max_connections)
                    .min_by_key(|&i| self.load_counters[usize::from(i)].load(Ordering::Acquire))
                    .unwrap_or(0)
                    .into()
            };

            if weight > 0 {
                let _ = self.load_counters[idx].fetch_add(weight, Ordering::SeqCst);
            }
            idx
        }

        /// Releases previously reserved weight; zero weights and out-of-range
        /// indices are ignored.
        pub(super) fn finish(&self, weight: usize, idx: usize) {
            if weight == 0 || idx >= self.load_counters.len() {
                return;
            }
            let _ = self.load_counters[idx].fetch_sub(weight, Ordering::SeqCst);
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_atomic_load_supports_16_connections() {
            let load = AtomicLoad::new(16);

            let assignments: Vec<_> = (0..1000).map(|_| load.assign(None, 1)).collect();

            for i in 0..16 {
                let count = assignments.iter().filter(|&&idx| idx == i).count();
                assert!(
                    (50..=75).contains(&count),
                    "Connection {i} got {count} assignments (expected ~62)"
                );
            }
        }

        #[test]
        fn test_no_overflow_with_heavy_inserts() {
            let load = AtomicLoad::new(4);

            for _ in 0..1000 {
                let idx = load.assign(None, 7);
                load.finish(7, idx);
            }

            for i in 0..4 {
                assert_eq!(load.load_counters[i].load(Ordering::Acquire), 0);
            }
        }

        #[test]
        fn test_deterministic_assignment_by_key() {
            let load = AtomicLoad::new(8);

            let key = 12345;
            let idx1 = load.assign(Some(key), 1);
            let idx2 = load.assign(Some(key), 1);
            let idx3 = load.assign(Some(key), 1);

            assert_eq!(idx1, idx2);
            assert_eq!(idx2, idx3);
            assert_eq!(idx1, key % 8);
        }

        #[test]
        fn test_least_loaded_selection() {
            let load = AtomicLoad::new(4);

            load.load_counters[0].store(100, Ordering::Release);
            load.load_counters[1].store(50, Ordering::Release);
            load.load_counters[2].store(200, Ordering::Release);
            load.load_counters[3].store(75, Ordering::Release);

            let idx = load.assign(None, 1);
            assert_eq!(idx, 1);
        }

        #[test]
        #[should_panic(expected = "Max 16 connections")]
        fn test_rejects_too_many_connections() { drop(AtomicLoad::new(17)); }

        #[test]
        #[should_panic(expected = "At least 1 connection")]
        fn test_rejects_zero_connections() { drop(AtomicLoad::new(0)); }

        #[test]
        fn test_zero_weight_returns_index_without_increment() {
            let load = AtomicLoad::new(4);

            let idx = load.assign(None, 0);
            assert!(idx < 4);

            for i in 0..4 {
                assert_eq!(load.load_counters[i].load(Ordering::Acquire), 0);
            }
        }

        #[test]
        fn test_finish_with_invalid_index() {
            let load = AtomicLoad::new(4);

            let idx = load.assign(None, 10);
            load.load_counters[idx].store(10, Ordering::Release);

            load.finish(5, 999);

            assert_eq!(load.load_counters[idx].load(Ordering::Acquire), 10);
        }

        #[test]
        fn test_finish_with_zero_weight() {
            let load = AtomicLoad::new(4);

            let idx = load.assign(None, 10);
            load.load_counters[idx].store(10, Ordering::Release);

            load.finish(0, idx);

            assert_eq!(load.load_counters[idx].load(Ordering::Acquire), 10);

            // Zero weight combined with an out-of-range index is also a no-op.
            load.finish(0, 999);
        }
    }
}