1use std::collections::VecDeque;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU8, Ordering};
5
6#[cfg(feature = "inner_pool")]
7use arc_swap::ArcSwap;
8use parking_lot::Mutex;
9use strum::Display;
10use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
11use tokio::sync::{broadcast, mpsc};
12use tokio::task::{AbortHandle, JoinSet};
13use tokio_rustls::rustls;
14
15use super::internal::{InternalConn, PendingQuery};
16use super::{ArrowOptions, CompressionMethod, Event};
17use crate::client::chunk::{ChunkReader, ChunkWriter};
18use crate::flags::{conn_read_buffer_size, conn_write_buffer_size};
19use crate::io::{ClickHouseRead, ClickHouseWrite};
20use crate::native::protocol::{
21 ClientHello, DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM, DBMS_TCP_PROTOCOL_VERSION, ServerHello,
22};
23use crate::prelude::*;
24use crate::{ClientOptions, Message, Operation};
25
/// Set of spawned per-connection IO tasks. Each task yields a queue of `PendingQuery`s
/// when it finishes (the tasks spawned in this file always yield an empty queue).
type IoHandle<T> = JoinSet<VecDeque<PendingQuery<T>>>;
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
31pub enum ConnectionStatus {
32 Open,
33 Closed,
34 Error,
35}
36
37impl From<u8> for ConnectionStatus {
38 fn from(value: u8) -> Self {
39 match value {
40 0 => Self::Open,
41 1 => Self::Closed,
42 _ => Self::Error,
43 }
44 }
45}
46
47impl From<ConnectionStatus> for u8 {
48 fn from(value: ConnectionStatus) -> u8 { value as u8 }
49}
50
/// Per-client settings, copied by value into every inner connection's IO task.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientMetadata {
    /// Identifier of the client; used to tag log events and trace spans.
    pub(crate) client_id: u16,
    /// Compression method taken from the client options.
    pub(crate) compression: CompressionMethod,
    /// Arrow conversion options taken from the client options (default when unset).
    pub(crate) arrow_options: ArrowOptions,
}
58
59impl ClientMetadata {
60 pub(crate) fn disable_compression(self) -> Self {
62 Self {
63 client_id: self.client_id,
64 compression: CompressionMethod::None,
65 arrow_options: self.arrow_options,
66 }
67 }
68
69 pub(crate) fn compression_settings(self) -> Settings {
71 match self.compression {
72 CompressionMethod::None | CompressionMethod::LZ4 => Settings::default(),
73 CompressionMethod::ZSTD => vec![
74 ("network_compression_method", "zstd"),
75 ("network_zstd_compression_level", "1"),
76 ]
77 .into(),
78 }
79 }
80}
81
/// Handle to one inner connection: how to reach its IO task and the status it reports.
#[derive(Debug)]
struct ConnectState<T: Send + Sync + 'static> {
    /// Encoded `ConnectionStatus`; written by the IO task when it exits and by `Connection`
    /// when it detects a failure.
    status: Arc<AtomicU8>,
    /// Sender for messages consumed by the IO task.
    channel: mpsc::Sender<Message<T>>,
    /// Abort handle of the spawned IO task; kept but currently never read.
    #[expect(unused)]
    handle: AbortHandle,
}
90
/// Client-side handle to a ClickHouse server. It is backed by one native-protocol
/// connection or, with the `inner_pool` feature, several.
#[derive(Debug)]
pub(super) struct Connection<T: ClientFormat> {
    /// Server addresses the connection was opened with (currently unused).
    #[expect(unused)]
    addrs: Arc<[SocketAddr]>,
    /// Client options the connection was created with.
    options: Arc<ClientOptions>,
    /// Background IO tasks, one per inner connection; aborted on shutdown and on drop.
    io_task: Arc<Mutex<IoHandle<T::Data>>>,
    /// Client id, compression and Arrow settings shared by all inner connections.
    metadata: ClientMetadata,
    /// State of the single underlying connection.
    #[cfg(not(feature = "inner_pool"))]
    state: Arc<ConnectState<T::Data>>,
    /// State of each inner connection, indexed by the value `send_operation` returns.
    #[cfg(feature = "inner_pool")]
    state: Vec<ArcSwap<ConnectState<T::Data>>>,
    /// Tracks outstanding work per inner connection to decide where operations go.
    #[cfg(feature = "inner_pool")]
    load_balancer: Arc<load::AtomicLoad>,
}
107
108impl<T: ClientFormat> Connection<T> {
109 #[instrument(
110 level = "trace",
111 name = "clickhouse.connection.create",
112 skip_all,
113 fields(
114 clickhouse.client.id = client_id,
115 db.system = "clickhouse",
116 db.operation = "connect",
117 network.transport = ?if options.use_tls { "tls" } else { "tcp" }
118 ),
119 err
120 )]
121 pub(crate) async fn connect(
122 client_id: u16,
123 addrs: Vec<SocketAddr>,
124 options: ClientOptions,
125 events: Arc<broadcast::Sender<Event>>,
126 trace_ctx: TraceContext,
127 ) -> Result<Self> {
128 let span = Span::current();
129 span.in_scope(|| trace!({ {ATT_CID} = client_id }, "connecting stream"));
130 let _ = trace_ctx.link(&span);
131
132 let mut io_task = JoinSet::new();
134
135 let metadata = ClientMetadata {
137 client_id,
138 compression: options.compression,
139 arrow_options: options.ext.arrow.unwrap_or_default(),
140 };
141
142 if options.use_tls {
144 drop(rustls::crypto::aws_lc_rs::default_provider().install_default());
145 }
146
147 let state = Arc::new(
149 Self::connect_inner(&addrs, &mut io_task, Arc::clone(&events), &options, metadata)
150 .await?,
151 );
152
153 #[cfg(feature = "inner_pool")]
154 let mut state = vec![ArcSwap::from(state)];
155
156 #[cfg(feature = "inner_pool")]
159 for _ in 0..options.ext.fast_mode_size.map_or(2, |s| s.clamp(2, 4)) {
160 let events = Arc::clone(&events);
161 state.push(ArcSwap::from(Arc::new(
162 Self::connect_inner(&addrs, &mut io_task, events, &options, metadata).await?,
163 )));
164 }
165
166 Ok(Self {
167 addrs: Arc::from(addrs.as_slice()),
168 io_task: Arc::new(Mutex::new(io_task)),
169 options: Arc::new(options),
170 metadata,
171 state,
172 #[cfg(feature = "inner_pool")]
175 load_balancer: Arc::new(load::AtomicLoad::new(2)),
176 })
177 }
178
179 async fn connect_inner(
180 addrs: &[SocketAddr],
181 io_task: &mut IoHandle<T::Data>,
182 events: Arc<broadcast::Sender<Event>>,
183 options: &ClientOptions,
184 metadata: ClientMetadata,
185 ) -> Result<ConnectState<T::Data>> {
186 if options.use_tls {
187 let tls_stream = super::tcp::connect_tls(addrs, options.domain.as_deref()).await?;
188 Self::establish_connection(tls_stream, io_task, events, options, metadata).await
189 } else {
190 let tcp_stream = super::tcp::connect_socket(addrs).await?;
191 Self::establish_connection(tcp_stream, io_task, events, options, metadata).await
192 }
193 }
194
    /// Performs the handshake on an already-connected `stream`, then spawns the IO task
    /// that drives the connection and registers it in `io_task`.
    ///
    /// The returned `ConnectState` holds the sender used to submit messages to the IO
    /// task, the status flag the IO task updates when it exits, and the task's abort handle.
    ///
    /// # Errors
    /// Returns any handshake error; in that case no task is spawned.
    async fn establish_connection<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        mut stream: RW,
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        let cid = metadata.client_id;

        // Shared with the IO task, which records on exit whether it closed or failed.
        let status = Arc::new(AtomicU8::new(ConnectionStatus::Open.into()));
        let internal_status = Arc::clone(&status);

        let server_hello = Arc::new(Self::perform_handshake(&mut stream, cid, options).await?);

        // Bounded queue carrying messages from callers to the IO task.
        let (operations, op_rx) = mpsc::channel(InternalConn::<T>::CAPACITY);

        let (reader, writer) = tokio::io::split(stream);

        let handle = io_task.spawn(
            async move {
                // Chunked framing is chosen per direction from what the handshake negotiated.
                let chunk_send = server_hello.supports_chunked_send();
                let chunk_recv = server_hello.supports_chunked_recv();

                let mut internal = InternalConn::<T>::new(metadata, events, server_hello);

                let reader = BufReader::with_capacity(conn_read_buffer_size(), reader);
                let writer = BufWriter::with_capacity(conn_write_buffer_size(), writer);

                // Each half is wrapped in a chunk codec only when that direction is chunked.
                // NOTE(review): `run_chunked` is used whenever sending is chunked and `run`
                // otherwise, independent of the receive mode — confirm against `InternalConn`.
                let result = match (chunk_send, chunk_recv) {
                    (true, true) => {
                        let reader = ChunkReader::new(reader);
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (true, false) => {
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (false, true) => {
                        let reader = ChunkReader::new(reader);
                        internal.run(reader, writer, op_rx).await
                    }
                    (false, false) => internal.run(reader, writer, op_rx).await,
                };

                // Publish why the loop ended so callers see a non-Open status and stop
                // submitting work to this connection.
                if let Err(error) = result {
                    error!(?error, "Internal connection lost");
                    internal_status.store(ConnectionStatus::Error.into(), Ordering::Release);
                } else {
                    info!("Internal connection closed");
                    internal_status.store(ConnectionStatus::Closed.into(), Ordering::Release);
                }
                trace!("Exiting inner connection");
                // The task always yields an empty queue (the JoinSet's output type).
                VecDeque::new()
            }
            .instrument(trace_span!(
                "clickhouse.connection.io",
                { ATT_CID } = cid,
                otel.kind = "server",
                peer.service = "clickhouse",
            )),
        );

        trace!({ ATT_CID } = cid, "spawned connection loop");
        Ok(ConnectState { status, channel: operations, handle })
    }
270
271 #[instrument(
272 level = "trace",
273 skip_all,
274 fields(
275 db.system = "clickhouse",
276 db.operation = op.as_ref(),
277 clickhouse.client.id = self.metadata.client_id,
278 clickhouse.query.id = %qid,
279 )
280 )]
281 pub(crate) async fn send_operation(
282 &self,
283 op: Operation<T::Data>,
284 qid: Qid,
285 finished: bool,
286 ) -> Result<usize> {
287 #[cfg(not(feature = "inner_pool"))]
288 let conn_idx = 0; #[cfg(feature = "inner_pool")]
290 let conn_idx = {
291 let key = (matches!(op, Operation::Query { .. } if !finished)
292 || matches!(op, Operation::Insert { .. } | Operation::InsertMany { .. }))
293 .then(|| qid.key());
294 self.load_balancer.assign(key, op.weight(finished) as usize)
295 };
296
297 let span = trace_span!(
298 "clickhouse.connection.send_operation",
299 { ATT_CID } = self.metadata.client_id,
300 { ATT_QID } = %qid,
301 db.system = "clickhouse",
302 db.operation = op.as_ref(),
303 finished
304 );
305
306 #[cfg(not(feature = "inner_pool"))]
308 let state = &self.state;
309 #[cfg(feature = "inner_pool")]
310 let state = self.state[conn_idx].load();
311
312 #[cfg(not(feature = "inner_pool"))]
314 let status = self.state.status.load(Ordering::Acquire);
315 #[cfg(feature = "inner_pool")]
316 let status = state.status.load(Ordering::Acquire);
317
318 if status > 0 {
320 return Err(Error::Client("No active connection".into()));
321 }
322
323 let result = state.channel.send(Message::Operation { qid, op }).instrument(span).await;
324 if result.is_err() {
325 error!({ ATT_QID } = %qid, "failed to send message");
326 self.update_status(conn_idx, ConnectionStatus::Closed);
327 return Err(Error::ChannelClosed);
328 }
329
330 Ok(conn_idx)
331 }
332
333 #[instrument(
334 level = "trace",
335 skip_all,
336 fields(db.system = "clickhouse", clickhouse.client.id = self.metadata.client_id)
337 )]
338 pub(crate) async fn shutdown(&self) -> Result<()> {
339 trace!({ ATT_CID } = self.metadata.client_id, "Shutting down connections");
340 #[cfg(not(feature = "inner_pool"))]
341 {
342 if self.state.channel.send(Message::Shutdown).await.is_err() {
343 error!("Failed to shutdown connection");
344 }
345 }
346 #[cfg(feature = "inner_pool")]
347 {
348 for (i, conn_state) in self.state.iter().enumerate() {
349 let state = conn_state.load();
350 debug!("Shutting down connection {i}");
351 if state.channel.send(Message::Shutdown).await.is_err() {
353 error!("Failed to shutdown connection {i}");
354 }
355 }
356 }
357 self.io_task.lock().abort_all();
358 Ok(())
359 }
360
361 pub(crate) async fn check_connection(&self, ping: bool) -> Result<()> {
362 self.check_channel()?;
364
365 if !ping {
366 return Ok(());
367 }
368
369 let (response, rx) = tokio::sync::oneshot::channel();
371 let cid = self.metadata.client_id;
372 let qid = Qid::default();
373 let idx = self
374 .send_operation(Operation::Ping { response }, qid, true)
375 .instrument(trace_span!(
376 "clickhouse.connection.ping",
377 { ATT_CID } = cid,
378 { ATT_QID } = %qid,
379 db.system = "clickhouse",
380 ))
381 .await?;
382
383 rx.await
384 .map_err(|_| {
385 self.update_status(idx, ConnectionStatus::Closed);
386 Error::ChannelClosed
387 })?
388 .inspect_err(|error| {
389 self.update_status(idx, ConnectionStatus::Error);
390 error!(?error, { ATT_CID } = cid, "Ping failed");
391 })?;
392
393 Ok(())
394 }
395
396 fn update_status(&self, idx: usize, status: ConnectionStatus) {
397 trace!({ ATT_CID } = self.metadata.client_id, ?status, "Updating status conn {idx}");
398
399 #[cfg(not(feature = "inner_pool"))]
400 let state = &self.state;
401 #[cfg(feature = "inner_pool")]
402 let state = self.state[idx].load();
403
404 state.status.store(status.into(), Ordering::Release);
405 }
406
407 async fn perform_handshake<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
408 stream: &mut RW,
409 client_id: u16,
410 options: &ClientOptions,
411 ) -> Result<ServerHello> {
412 use crate::client::reader::Reader;
413 use crate::client::writer::Writer;
414
415 let client_hello = ClientHello {
416 default_database: options.default_database.clone(),
417 username: options.username.clone(),
418 password: options.password.get().to_string(),
419 };
420
421 Writer::send_hello(stream, client_hello)
423 .await
424 .inspect_err(|error| error!(?error, { ATT_CID } = client_id, "Failed to send hello"))?;
425
426 let chunked_modes = (options.ext.chunked_send, options.ext.chunked_recv);
428 let server_hello =
429 Reader::receive_hello(stream, DBMS_TCP_PROTOCOL_VERSION, chunked_modes, client_id)
430 .await?;
431 trace!({ ATT_CID } = client_id, ?server_hello, "Finished handshake");
432
433 if server_hello.revision_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
434 Writer::send_addendum(stream, server_hello.revision_version, &server_hello).await?;
435 stream.flush().await.inspect_err(|error| error!(?error, "Error writing addendum"))?;
436 }
437
438 Ok(server_hello)
439 }
440}
441
442impl<T: ClientFormat> Connection<T> {
443 pub(crate) fn metadata(&self) -> ClientMetadata { self.metadata }
444
445 pub(crate) fn database(&self) -> &str { &self.options.default_database }
446
447 #[cfg(feature = "inner_pool")]
448 pub(crate) fn finish(&self, conn_idx: usize, weight: u8) {
449 self.load_balancer.finish(usize::from(weight), conn_idx);
450 }
451
452 pub(crate) fn status(&self) -> ConnectionStatus {
453 #[cfg(not(feature = "inner_pool"))]
454 let status = ConnectionStatus::from(self.state.status.load(Ordering::Acquire));
455
456 #[cfg(feature = "inner_pool")]
459 let status = ConnectionStatus::from(self.state[0].load().status.load(Ordering::Acquire));
460
461 status
462 }
463
464 fn check_channel(&self) -> Result<()> {
465 #[cfg(not(feature = "inner_pool"))]
466 {
467 if self.state.channel.is_closed() {
468 self.update_status(0, ConnectionStatus::Closed);
469 Err(Error::ChannelClosed)
470 } else {
471 Ok(())
472 }
473 }
474
475 #[cfg(feature = "inner_pool")]
478 if self.state[0].load().channel.is_closed() {
479 self.update_status(0, ConnectionStatus::Closed);
480 Err(Error::ChannelClosed)
481 } else {
482 Ok(())
483 }
484 }
485}
486
487impl<T: ClientFormat> Drop for Connection<T> {
488 fn drop(&mut self) {
489 trace!({ ATT_CID } = self.metadata.client_id, "Connection dropped");
490 self.io_task.lock().abort_all();
491 }
492}
493
494#[cfg(feature = "inner_pool")]
mod load {
    //! Approximate, lock-free accounting of outstanding work per inner connection.
    //!
    //! All per-connection loads are packed into one `AtomicU64`, one fixed-width lane per
    //! connection, so a single atomic load snapshots every connection's load.

    use std::sync::atomic::{AtomicU64, Ordering};

    /// Bits reserved for each connection's load counter. Sixteen bits leave ample headroom
    /// before a busy connection's count could carry into its neighbour's lane.
    const LANE_BITS: u32 = 16;
    /// Mask isolating one lane after it has been shifted down to bit 0.
    const LANE_MASK: u64 = (1 << LANE_BITS) - 1;
    /// Number of lanes that fit in the packed counter.
    const MAX_LANES: u8 = (u64::BITS / LANE_BITS) as u8;

    /// Bit offset of connection `idx`'s lane.
    fn lane_shift(idx: usize) -> u32 {
        debug_assert!(idx < usize::from(MAX_LANES), "connection index {idx} out of range");
        // `idx` is below MAX_LANES (4), so the cast cannot truncate.
        idx as u32 * LANE_BITS
    }

    /// `weight` positioned in connection `idx`'s lane, ready to be added or subtracted.
    fn lane_value(weight: usize, idx: usize) -> u64 {
        // Widening usize -> u64 is lossless on all supported targets.
        (weight as u64) << lane_shift(idx)
    }

    /// Outstanding-weight tracker for up to four connections.
    #[derive(Debug)]
    pub(super) struct AtomicLoad {
        load_counter: AtomicU64,
        max_connections: u8,
    }

    impl AtomicLoad {
        /// Creates a tracker for `max_connections` connections, all starting idle.
        ///
        /// # Panics
        /// If `max_connections` is `0` or greater than `4`.
        pub(super) fn new(max_connections: u8) -> Self {
            assert!(max_connections <= MAX_LANES, "Max 4 connections supported");
            assert!(max_connections > 0, "At least 1 connection required");
            Self { load_counter: AtomicU64::new(0), max_connections }
        }

        /// Picks a connection index and charges it `weight`.
        ///
        /// With a `key`, the choice is sticky: `key % max_connections`. Without one, the
        /// connection with the smallest current load is chosen (lowest index on ties).
        /// The snapshot is not synchronised with concurrent callers, so the choice is
        /// best-effort. A `weight` of `0` selects a connection without recording load.
        pub(super) fn assign(&self, key: Option<usize>, weight: usize) -> usize {
            let connections = usize::from(self.max_connections);
            let idx = match key {
                Some(k) => k % connections,
                None => {
                    let load = self.load_counter.load(Ordering::Acquire);
                    (0..connections)
                        .min_by_key(|&i| (load >> lane_shift(i)) & LANE_MASK)
                        .unwrap_or(0)
                }
            };
            if weight != 0 {
                let _ = self.load_counter.fetch_add(lane_value(weight, idx), Ordering::SeqCst);
            }
            idx
        }

        /// Releases `weight` previously charged to connection `idx` by `assign`.
        pub(super) fn finish(&self, weight: usize, idx: usize) {
            if weight == 0 {
                return;
            }
            let _ = self.load_counter.fetch_sub(lane_value(weight, idx), Ordering::SeqCst);
        }
    }
}