clickhouse_arrow/client/connection.rs

use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};

#[cfg(feature = "inner_pool")]
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use strum::Display;
use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::{broadcast, mpsc};
use tokio::task::{AbortHandle, JoinSet};
use tokio_rustls::rustls;

use super::internal::{InternalConn, PendingQuery};
use super::{ArrowOptions, CompressionMethod, Event};
use crate::client::chunk::{ChunkReader, ChunkWriter};
use crate::flags::{conn_read_buffer_size, conn_write_buffer_size};
use crate::io::{ClickHouseRead, ClickHouseWrite};
use crate::native::protocol::{
    ClientHello, DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM, DBMS_TCP_PROTOCOL_VERSION, ServerHello,
};
use crate::prelude::*;
use crate::{ClientOptions, Message, Operation};

// Type alias for the JoinSet used to spawn inner connections
type IoHandle<T> = JoinSet<VecDeque<PendingQuery<T>>>;

/// The status of the underlying connection to `ClickHouse`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum ConnectionStatus {
    Open,
    Closed,
    Error,
}

// The `u8` round-trip relies on declaration order: `Open` encodes as 0,
// `Closed` as 1, and any other value decodes as `Error`.
impl From<u8> for ConnectionStatus {
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Open,
            1 => Self::Closed,
            _ => Self::Error,
        }
    }
}

impl From<ConnectionStatus> for u8 {
    fn from(value: ConnectionStatus) -> u8 { value as u8 }
}

/// Client metadata passed around the internal client
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientMetadata {
    pub(crate) client_id:     u16,
    pub(crate) compression:   CompressionMethod,
    pub(crate) arrow_options: ArrowOptions,
}

impl ClientMetadata {
    /// Helper function to disable compression on the metadata.
    pub(crate) fn disable_compression(self) -> Self {
        Self {
            client_id:     self.client_id,
            compression:   CompressionMethod::None,
            arrow_options: self.arrow_options,
        }
    }

    /// Helper function to provide settings for compression.
    pub(crate) fn compression_settings(self) -> Settings {
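        // LZ4 is the server's default network compression method, so only ZSTD
        // needs to be requested explicitly via settings.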
        match self.compression {
            CompressionMethod::None | CompressionMethod::LZ4 => Settings::default(),
            CompressionMethod::ZSTD => vec![
                ("network_compression_method", "zstd"),
                ("network_zstd_compression_level", "1"),
            ]
            .into(),
        }
    }
}

/// A struct defining the information needed to connect over TCP.
#[derive(Debug)]
struct ConnectState<T: Send + Sync + 'static> {
    status:  Arc<AtomicU8>,
    channel: mpsc::Sender<Message<T>>,
    #[expect(unused)]
    handle:  AbortHandle,
}

// NOTE: ArcSwaps are used to support reconnects in the future.
#[derive(Debug)]
pub(super) struct Connection<T: ClientFormat> {
    #[expect(unused)]
    addrs:         Arc<[SocketAddr]>,
    options:       Arc<ClientOptions>,
    io_task:       Arc<Mutex<IoHandle<T::Data>>>,
    metadata:      ClientMetadata,
    #[cfg(not(feature = "inner_pool"))]
    state:         Arc<ConnectState<T::Data>>,
    /// NOTE: Max connections must remain at most 4 unless the load-balancing
    /// algorithm changes.
    #[cfg(feature = "inner_pool")]
    state:         Vec<ArcSwap<ConnectState<T::Data>>>,
    #[cfg(feature = "inner_pool")]
    load_balancer: Arc<load::AtomicLoad>,
}

impl<T: ClientFormat> Connection<T> {
    #[instrument(
        level = "trace",
        name = "clickhouse.connection.create",
        skip_all,
        fields(
            clickhouse.client.id = client_id,
            db.system = "clickhouse",
            db.operation = "connect",
            network.transport = ?if options.use_tls { "tls" } else { "tcp" }
        ),
        err
    )]
    pub(crate) async fn connect(
        client_id: u16,
        addrs: Vec<SocketAddr>,
        options: ClientOptions,
        events: Arc<broadcast::Sender<Event>>,
        trace_ctx: TraceContext,
    ) -> Result<Self> {
        let span = Span::current();
        span.in_scope(|| trace!({ ATT_CID } = client_id, "connecting stream"));
        let _ = trace_ctx.link(&span);

        // Create joinset
        let mut io_task = JoinSet::new();

        // Construct connection metadata
        let metadata = ClientMetadata {
            client_id,
            compression: options.compression,
            arrow_options: options.ext.arrow.unwrap_or_default(),
        };

        // Install the rustls crypto provider when using TLS. `install_default`
        // fails if a provider was already installed, which is safe to ignore.
        if options.use_tls {
            drop(rustls::crypto::aws_lc_rs::default_provider().install_default());
        }

        // Establish tcp connection, perform handshake, and spawn io task
        let state = Arc::new(
            Self::connect_inner(&addrs, &mut io_task, Arc::clone(&events), &options, metadata)
                .await?,
        );

        #[cfg(feature = "inner_pool")]
        let mut state = vec![ArcSwap::from(state)];

        // Currently "inner_pool" defaults to 2 connections, but the algorithm can
        // support up to 4 (possibly more with a u64 load_counter). The first
        // connection was created above, so only spawn the remainder here.
        #[cfg(feature = "inner_pool")]
        for _ in 1..options.ext.fast_mode_size.map_or(2, |s| s.clamp(2, 4)) {
            let events = Arc::clone(&events);
            state.push(ArcSwap::from(Arc::new(
                Self::connect_inner(&addrs, &mut io_task, events, &options, metadata).await?,
            )));
        }

        Ok(Self {
            addrs: Arc::from(addrs.as_slice()),
            io_task: Arc::new(Mutex::new(io_task)),
            options: Arc::new(options),
            metadata,
            state,
            // Currently only using 2 connections
            // TODO: Provide inner pool configuration option
            #[cfg(feature = "inner_pool")]
            load_balancer: Arc::new(load::AtomicLoad::new(2)),
        })
    }

    async fn connect_inner(
        addrs: &[SocketAddr],
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        if options.use_tls {
            let tls_stream = super::tcp::connect_tls(addrs, options.domain.as_deref()).await?;
            Self::establish_connection(tls_stream, io_task, events, options, metadata).await
        } else {
            let tcp_stream = super::tcp::connect_socket(addrs).await?;
            Self::establish_connection(tcp_stream, io_task, events, options, metadata).await
        }
    }

    async fn establish_connection<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        mut stream: RW,
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        let cid = metadata.client_id;

        // Initialize the status to allow the io loop to signal broken/closed connections
        let status = Arc::new(AtomicU8::new(ConnectionStatus::Open.into()));
        let internal_status = Arc::clone(&status);

        // Perform connection handshake
        let server_hello = Arc::new(Self::perform_handshake(&mut stream, cid, options).await?);

        // Create operation channel
        let (operations, op_rx) = mpsc::channel(InternalConn::<T>::CAPACITY);

        // Split stream
        let (reader, writer) = tokio::io::split(stream);

        // Spawn the io loop
        let handle = io_task.spawn(
            async move {
                let chunk_send = server_hello.supports_chunked_send();
                let chunk_recv = server_hello.supports_chunked_recv();

                // Create and run internal client
                let mut internal = InternalConn::<T>::new(metadata, events, server_hello);

                let reader = BufReader::with_capacity(conn_read_buffer_size(), reader);
                let writer = BufWriter::with_capacity(conn_write_buffer_size(), writer);

                // Chunked framing is negotiated per direction during the hello
                // exchange, so only wrap the directions the server agreed to.
                let result = match (chunk_send, chunk_recv) {
                    (true, true) => {
                        let reader = ChunkReader::new(reader);
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (true, false) => {
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (false, true) => {
                        let reader = ChunkReader::new(reader);
                        internal.run(reader, writer, op_rx).await
                    }
                    (false, false) => internal.run(reader, writer, op_rx).await,
                };

                if let Err(error) = result {
                    error!(?error, "Internal connection lost");
                    internal_status.store(ConnectionStatus::Error.into(), Ordering::Release);
                } else {
                    info!("Internal connection closed");
                    internal_status.store(ConnectionStatus::Closed.into(), Ordering::Release);
                }
                trace!("Exiting inner connection");
                // TODO: Drain inner of pending queries
                VecDeque::new()
            }
            .instrument(trace_span!(
                "clickhouse.connection.io",
                { ATT_CID } = cid,
                otel.kind = "server",
                peer.service = "clickhouse",
            )),
        );

        trace!({ ATT_CID } = cid, "spawned connection loop");
        Ok(ConnectState { status, channel: operations, handle })
    }

    #[instrument(
        level = "trace",
        skip_all,
        fields(
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            clickhouse.client.id = self.metadata.client_id,
            clickhouse.query.id = %qid,
        )
    )]
    pub(crate) async fn send_operation(
        &self,
        op: Operation<T::Data>,
        qid: Qid,
        finished: bool,
    ) -> Result<usize> {
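        // With the inner pool, multi-packet operations (unfinished queries and
        // inserts) are pinned to a connection by query id so all of their packets
        // share one stream; one-shot operations take the least-loaded connection.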
        #[cfg(not(feature = "inner_pool"))]
        let conn_idx = 0; // Dummy for non-fast mode
        #[cfg(feature = "inner_pool")]
        let conn_idx = {
            let key = (matches!(op, Operation::Query { .. } if !finished)
                || matches!(op, Operation::Insert { .. } | Operation::InsertMany { .. }))
            .then(|| qid.key());
            self.load_balancer.assign(key, op.weight(finished) as usize)
        };

        let span = trace_span!(
            "clickhouse.connection.send_operation",
            { ATT_CID } = self.metadata.client_id,
            { ATT_QID } = %qid,
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            finished
        );

        // Get the current state
        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[conn_idx].load();

        // Get the current status
        #[cfg(not(feature = "inner_pool"))]
        let status = self.state.status.load(Ordering::Acquire);
        #[cfg(feature = "inner_pool")]
        let status = state.status.load(Ordering::Acquire);

        // Fail fast if the underlying connection is broken or closed (until
        // re-connects are implemented)
        if ConnectionStatus::from(status) != ConnectionStatus::Open {
            return Err(Error::Client("No active connection".into()));
        }

        let result = state.channel.send(Message::Operation { qid, op }).instrument(span).await;
        if result.is_err() {
            error!({ ATT_QID } = %qid, "failed to send message");
            self.update_status(conn_idx, ConnectionStatus::Closed);
            return Err(Error::ChannelClosed);
        }

        Ok(conn_idx)
    }

    #[instrument(
        level = "trace",
        skip_all,
        fields(db.system = "clickhouse", clickhouse.client.id = self.metadata.client_id)
    )]
    pub(crate) async fn shutdown(&self) -> Result<()> {
        trace!({ ATT_CID } = self.metadata.client_id, "Shutting down connections");
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.send(Message::Shutdown).await.is_err() {
                error!("Failed to shutdown connection");
            }
        }
        #[cfg(feature = "inner_pool")]
        {
            for (i, conn_state) in self.state.iter().enumerate() {
                let state = conn_state.load();
                debug!("Shutting down connection {i}");
                // Send the shutdown message to each internal connection in turn
                if state.channel.send(Message::Shutdown).await.is_err() {
                    error!("Failed to shutdown connection {i}");
                }
            }
        }
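        // Graceful shutdown messages have been sent; abort any io tasks that remain.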
        self.io_task.lock().abort_all();
        Ok(())
    }

    pub(crate) async fn check_connection(&self, ping: bool) -> Result<()> {
        // First check that internal channels are ok
        self.check_channel()?;

        if !ping {
            return Ok(());
        }

        // Then ping
        let (response, rx) = tokio::sync::oneshot::channel();
        let cid = self.metadata.client_id;
        let qid = Qid::default();
        let idx = self
            .send_operation(Operation::Ping { response }, qid, true)
            .instrument(trace_span!(
                "clickhouse.connection.ping",
                { ATT_CID } = cid,
                { ATT_QID } = %qid,
                db.system = "clickhouse",
            ))
            .await?;

        rx.await
            .map_err(|_| {
                self.update_status(idx, ConnectionStatus::Closed);
                Error::ChannelClosed
            })?
            .inspect_err(|error| {
                self.update_status(idx, ConnectionStatus::Error);
                error!(?error, { ATT_CID } = cid, "Ping failed");
            })?;

        Ok(())
    }

    fn update_status(&self, idx: usize, status: ConnectionStatus) {
        trace!({ ATT_CID } = self.metadata.client_id, ?status, "Updating status for conn {idx}");

        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[idx].load();

        state.status.store(status.into(), Ordering::Release);
    }

    async fn perform_handshake<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        stream: &mut RW,
        client_id: u16,
        options: &ClientOptions,
    ) -> Result<ServerHello> {
        use crate::client::reader::Reader;
        use crate::client::writer::Writer;

        let client_hello = ClientHello {
            default_database: options.default_database.clone(),
            username:         options.username.clone(),
            password:         options.password.get().to_string(),
        };

        // Send client hello
        Writer::send_hello(stream, client_hello)
            .await
            .inspect_err(|error| error!(?error, { ATT_CID } = client_id, "Failed to send hello"))?;

        // Receive server hello
        let chunked_modes = (options.ext.chunked_send, options.ext.chunked_recv);
        let server_hello =
            Reader::receive_hello(stream, DBMS_TCP_PROTOCOL_VERSION, chunked_modes, client_id)
                .await?;
        trace!({ ATT_CID } = client_id, ?server_hello, "Finished handshake");

        if server_hello.revision_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
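            // Servers at or above this revision expect an addendum (e.g. the
            // quota key) immediately after the hello exchange.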
            Writer::send_addendum(stream, server_hello.revision_version, &server_hello).await?;
            stream.flush().await.inspect_err(|error| error!(?error, "Error writing addendum"))?;
        }

        Ok(server_hello)
    }
}

impl<T: ClientFormat> Connection<T> {
    pub(crate) fn metadata(&self) -> ClientMetadata { self.metadata }

    pub(crate) fn database(&self) -> &str { &self.options.default_database }

    #[cfg(feature = "inner_pool")]
    pub(crate) fn finish(&self, conn_idx: usize, weight: u8) {
        self.load_balancer.finish(usize::from(weight), conn_idx);
    }

    pub(crate) fn status(&self) -> ConnectionStatus {
        #[cfg(not(feature = "inner_pool"))]
        let status = ConnectionStatus::from(self.state.status.load(Ordering::Acquire));

        // TODO: Status is strange if we have an internal pool. Figure this out.
        // Just use the first channel for now
        #[cfg(feature = "inner_pool")]
        let status = ConnectionStatus::from(self.state[0].load().status.load(Ordering::Acquire));

        status
    }

    fn check_channel(&self) -> Result<()> {
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.is_closed() {
                self.update_status(0, ConnectionStatus::Closed);
                Err(Error::ChannelClosed)
            } else {
                Ok(())
            }
        }

        // TODO: Checking channel is strange if we have an internal pool. Figure this out.
        // Just return status of first connection for now
        #[cfg(feature = "inner_pool")]
        if self.state[0].load().channel.is_closed() {
            self.update_status(0, ConnectionStatus::Closed);
            Err(Error::ChannelClosed)
        } else {
            Ok(())
        }
    }
}

impl<T: ClientFormat> Drop for Connection<T> {
    fn drop(&mut self) {
        trace!({ ATT_CID } = self.metadata.client_id, "Connection dropped");
        self.io_task.lock().abort_all();
    }
}

#[cfg(feature = "inner_pool")]
mod load {
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug)]
    pub(super) struct AtomicLoad {
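        /// Byte `i` of this counter tracks the in-flight weight assigned to
        /// connection `i`; four one-byte slots are what cap the pool at 4
        /// connections (a u64 counter could hold more).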
        load_counter:    AtomicUsize,
        max_connections: u8,
    }

    impl AtomicLoad {
        /// Create the load balancer.
        ///
        /// # Panics
        /// - Panics if `max_connections` is 0 or greater than 4; only up to 4
        ///   connections are currently supported.
        pub(super) fn new(max_connections: u8) -> Self {
            assert!(max_connections <= 4, "Max 4 connections supported");
            assert!(max_connections > 0, "At least 1 connection required");
            Self { load_counter: AtomicUsize::new(0), max_connections }
        }

        /// Assign a connection index, incrementing its load by `weight`.
        /// If `key` is `Some`, use `key % max_connections` (deterministic);
        /// if `key` is `None`, pick the least-loaded connection.
        /// Returns the connection index.
        pub(super) fn assign(&self, key: Option<usize>, weight: usize) -> usize {
            let idx = if let Some(k) = key {
                k % usize::from(self.max_connections) // Deterministic assignment
            } else {
                // Select the least-loaded connection by comparing per-connection bytes
                let load = self.load_counter.load(Ordering::Acquire);
                usize::from(
                    (0..self.max_connections)
                        .min_by_key(|&i| (load >> (i * 8)) & 0xFF)
                        .unwrap_or(0),
                )
            };
            if weight == 0 {
                return idx;
            }

            // Increment load in the chosen connection's byte
            let _ = self.load_counter.fetch_add(weight << (idx * 8), Ordering::SeqCst);
            idx
        }

        /// Finish an operation, decrementing load by `weight` for `idx`
        pub(crate) fn finish(&self, weight: usize, idx: usize) {
            if weight == 0 {
                return;
            }

            let _ = self.load_counter.fetch_sub(weight << (idx * 8), Ordering::SeqCst);
        }
    }
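
    /// A small sanity check of the byte-packed load accounting; a sketch of the
    /// expected behavior rather than an exhaustive suite.
    #[cfg(test)]
    mod tests {
        use super::AtomicLoad;

        #[test]
        fn assigns_keyed_and_least_loaded() {
            let lb = AtomicLoad::new(2);

            // Both connections idle: the least-loaded scan picks index 0 first.
            assert_eq!(lb.assign(None, 3), 0);
            // Index 0 now carries weight 3, so the next keyless assign picks 1.
            assert_eq!(lb.assign(None, 1), 1);

            // Keyed assignment is deterministic: key % max_connections.
            assert_eq!(lb.assign(Some(4), 0), 0);
            assert_eq!(lb.assign(Some(5), 0), 1);

            // Finishing releases the weight, so index 0 is least-loaded again.
            lb.finish(3, 0);
            assert_eq!(lb.assign(None, 0), 0);
        }
    }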
}