clickhouse_arrow/client/connection.rs

use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};

#[cfg(feature = "inner_pool")]
use arc_swap::ArcSwap;
use parking_lot::Mutex;
use strum::Display;
use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
use tokio::sync::{broadcast, mpsc};
use tokio::task::{AbortHandle, JoinSet};
use tokio_rustls::rustls;

use super::internal::{InternalConn, PendingQuery};
use super::{ArrowOptions, CompressionMethod, Event};
use crate::client::chunk::{ChunkReader, ChunkWriter};
use crate::flags::{conn_read_buffer_size, conn_write_buffer_size};
use crate::io::{ClickHouseRead, ClickHouseWrite};
use crate::native::protocol::{
    ClientHello, DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM, DBMS_TCP_PROTOCOL_VERSION, ServerHello,
};
use crate::prelude::*;
use crate::{ClientOptions, Message, Operation};

// Type alias for the JoinSet used to spawn inner connections
type IoHandle<T> = JoinSet<VecDeque<PendingQuery<T>>>;

/// The status of the underlying connection to `ClickHouse`
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
pub enum ConnectionStatus {
    Open,
    Closed,
    Error,
}

impl From<u8> for ConnectionStatus {
    fn from(value: u8) -> Self {
        match value {
            0 => Self::Open,
            1 => Self::Closed,
            _ => Self::Error,
        }
    }
}

impl From<ConnectionStatus> for u8 {
    fn from(value: ConnectionStatus) -> u8 { value as u8 }
}
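
// NOTE: The `u8` round trip above relies on the declaration order of
// `ConnectionStatus` (Open = 0, Closed = 1, Error = 2); `send_operation`
// treats any non-zero status as "not open".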

/// Client metadata passed around the internal client
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClientMetadata {
    pub(crate) client_id:     u16,
    pub(crate) compression:   CompressionMethod,
    pub(crate) arrow_options: ArrowOptions,
}

impl ClientMetadata {
    /// Helper function to disable compression on the metadata.
    pub(crate) fn disable_compression(self) -> Self {
        Self {
            client_id:     self.client_id,
            compression:   CompressionMethod::None,
            arrow_options: self.arrow_options,
        }
    }

    /// Helper function to provide the query settings required by the configured
    /// compression method. `None` and `LZ4` need no extra settings here; `ZSTD`
    /// is requested from the server via `network_compression_method`.
    pub(crate) fn compression_settings(self) -> Settings {
        match self.compression {
            CompressionMethod::None | CompressionMethod::LZ4 => Settings::default(),
            CompressionMethod::ZSTD => vec![
                ("network_compression_method", "zstd"),
                ("network_zstd_compression_level", "1"),
            ]
            .into(),
        }
    }
}

/// Handle to an established connection: its status flag, operation channel, and
/// the io task's abort handle.
#[derive(Debug)]
struct ConnectState<T: Send + Sync + 'static> {
    status:  Arc<AtomicU8>,
    channel: mpsc::Sender<Message<T>>,
    #[expect(unused)]
    handle:  AbortHandle,
}

// NOTE: ArcSwaps are used to support reconnects in the future.
#[derive(Debug)]
pub(super) struct Connection<T: ClientFormat> {
    #[expect(unused)]
    addrs:         Arc<[SocketAddr]>,
    options:       Arc<ClientOptions>,
    io_task:       Arc<Mutex<IoHandle<T::Data>>>,
    metadata:      ClientMetadata,
    #[cfg(not(feature = "inner_pool"))]
    state:         Arc<ConnectState<T::Data>>,
    /// NOTE: Pool size defaults to 4 and is capped at `load::ABSOLUTE_MAX_CONNECTIONS` (16)
    #[cfg(feature = "inner_pool")]
    state:         Vec<ArcSwap<ConnectState<T::Data>>>,
    #[cfg(feature = "inner_pool")]
    load_balancer: Arc<load::AtomicLoad>,
}

impl<T: ClientFormat> Connection<T> {
    #[instrument(
        level = "trace",
        name = "clickhouse.connection.create",
        skip_all,
        fields(
            clickhouse.client.id = client_id,
            db.system = "clickhouse",
            db.operation = "connect",
            network.transport = ?if options.use_tls { "tls" } else { "tcp" }
        ),
        err
    )]
    pub(crate) async fn connect(
        client_id: u16,
        addrs: Vec<SocketAddr>,
        options: ClientOptions,
        events: Arc<broadcast::Sender<Event>>,
        trace_ctx: TraceContext,
    ) -> Result<Self> {
        let span = Span::current();
        span.in_scope(|| trace!({ {ATT_CID} = client_id }, "connecting stream"));
        let _ = trace_ctx.link(&span);

        // Create joinset
        let mut io_task = JoinSet::new();

        // Construct connection metadata
        let metadata = ClientMetadata {
            client_id,
            compression: options.compression,
            arrow_options: options.ext.arrow.unwrap_or_default(),
        };

        // Install rustls provider if using tls
        if options.use_tls {
            drop(rustls::crypto::aws_lc_rs::default_provider().install_default());
        }

        // Establish tcp connection, perform handshake, and spawn io task
        let state = Arc::new(
            Self::connect_inner(&addrs, &mut io_task, Arc::clone(&events), &options, metadata)
                .await?,
        );

        #[cfg(feature = "inner_pool")]
        let mut state = vec![ArcSwap::from(state)];

        // Inner pool: Spawn additional connections for improved concurrency.
        // Default is 4, max is 16. User can configure via fast_mode_size option.
        #[cfg(feature = "inner_pool")]
        let inner_pool_size = options
            .ext
            .fast_mode_size
            .map_or(load::DEFAULT_MAX_CONNECTIONS, |s| s.clamp(2, load::ABSOLUTE_MAX_CONNECTIONS));

        #[cfg(feature = "inner_pool")]
        for _ in 0..inner_pool_size.saturating_sub(1) {
            let events = Arc::clone(&events);
            state.push(ArcSwap::from(Arc::new(
                Self::connect_inner(&addrs, &mut io_task, events, &options, metadata).await?,
            )));
        }

        Ok(Self {
            addrs: Arc::from(addrs.as_slice()),
            io_task: Arc::new(Mutex::new(io_task)),
            options: Arc::new(options),
            metadata,
            state,
            #[cfg(feature = "inner_pool")]
            load_balancer: Arc::new(load::AtomicLoad::new(inner_pool_size)),
        })
    }

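    /// Establish a single inner connection: dial TCP (optionally wrapped in TLS)
    /// and hand the stream to `establish_connection`.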
    async fn connect_inner(
        addrs: &[SocketAddr],
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        if options.use_tls {
            let tls_stream = super::tcp::connect_tls(addrs, options.domain.as_deref()).await?;
            Self::establish_connection(tls_stream, io_task, events, options, metadata).await
        } else {
            let tcp_stream = super::tcp::connect_socket(addrs).await?;
            Self::establish_connection(tcp_stream, io_task, events, options, metadata).await
        }
    }

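    /// Perform the handshake on an established stream, then split it, wrap the
    /// halves in buffered (and, when negotiated, chunked) readers/writers, and
    /// spawn the io loop onto the shared `JoinSet`.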
    async fn establish_connection<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        mut stream: RW,
        io_task: &mut IoHandle<T::Data>,
        events: Arc<broadcast::Sender<Event>>,
        options: &ClientOptions,
        metadata: ClientMetadata,
    ) -> Result<ConnectState<T::Data>> {
        let cid = metadata.client_id;

        // Initialize the status to allow the io loop to signal broken/closed connections
        let status = Arc::new(AtomicU8::new(ConnectionStatus::Open.into()));
        let internal_status = Arc::clone(&status);

        // Perform connection handshake
        let server_hello = Arc::new(Self::perform_handshake(&mut stream, cid, options).await?);

        // Create operation channel
        let (operations, op_rx) = mpsc::channel(InternalConn::<T>::CAPACITY);

        // Split stream
        let (reader, writer) = tokio::io::split(stream);

        // Spawn the connection io loop
        let handle = io_task.spawn(
            async move {
                let chunk_send = server_hello.supports_chunked_send();
                let chunk_recv = server_hello.supports_chunked_recv();

                // Create and run internal client
                let mut internal = InternalConn::<T>::new(metadata, events, server_hello);

                let reader = BufReader::with_capacity(conn_read_buffer_size(), reader);
                let writer = BufWriter::with_capacity(conn_write_buffer_size(), writer);

                let result = match (chunk_send, chunk_recv) {
                    (true, true) => {
                        let reader = ChunkReader::new(reader);
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (true, false) => {
                        let writer = ChunkWriter::new(writer);
                        internal.run_chunked(reader, writer, op_rx).await
                    }
                    (false, true) => {
                        let reader = ChunkReader::new(reader);
                        internal.run(reader, writer, op_rx).await
                    }
                    (false, false) => internal.run(reader, writer, op_rx).await,
                };

                if let Err(error) = result {
                    error!(?error, "Internal connection lost");
                    internal_status.store(ConnectionStatus::Error.into(), Ordering::Release);
                } else {
                    info!("Internal connection closed");
                    internal_status.store(ConnectionStatus::Closed.into(), Ordering::Release);
                }
                trace!("Exiting inner connection");
                // TODO: Drain inner of pending queries
                VecDeque::new()
            }
            .instrument(trace_span!(
                "clickhouse.connection.io",
                { ATT_CID } = cid,
                otel.kind = "server",
                peer.service = "clickhouse",
            )),
        );

        trace!({ ATT_CID } = cid, "spawned connection loop");
        Ok(ConnectState { status, channel: operations, handle })
    }

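    /// Route an operation to an inner connection and enqueue it on that
    /// connection's channel.
    ///
    /// With the `inner_pool` feature, in-flight queries and inserts are pinned to
    /// a connection keyed by `qid` so follow-up packets reach the same io loop,
    /// while one-shot operations go to the least-loaded connection. Returns the
    /// index of the connection used, which callers can hand back to `finish` to
    /// release the load weight.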
    #[instrument(
        level = "trace",
        skip_all,
        fields(
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            clickhouse.client.id = self.metadata.client_id,
            clickhouse.query.id = %qid,
        )
    )]
    pub(crate) async fn send_operation(
        &self,
        op: Operation<T::Data>,
        qid: Qid,
        finished: bool,
    ) -> Result<usize> {
        #[cfg(not(feature = "inner_pool"))]
        let conn_idx = 0; // Dummy for non-fast mode
        #[cfg(feature = "inner_pool")]
        let conn_idx = {
            let key = (matches!(op, Operation::Query { .. } if !finished)
                || matches!(op, Operation::Insert { .. } | Operation::InsertMany { .. }))
            .then(|| qid.key());
            self.load_balancer.assign(key, op.weight(finished) as usize)
        };

        let span = trace_span!(
            "clickhouse.connection.send_operation",
            { ATT_CID } = self.metadata.client_id,
            { ATT_QID } = %qid,
            db.system = "clickhouse",
            db.operation = op.as_ref(),
            finished
        );

        // Get the current state
        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[conn_idx].load();

        // Get the current status
        #[cfg(not(feature = "inner_pool"))]
        let status = self.state.status.load(Ordering::Acquire);
        #[cfg(feature = "inner_pool")]
        let status = state.status.load(Ordering::Acquire);

        // First check if the underlying connection is ok (until re-connects are implemented)
        if status > 0 {
            return Err(Error::Client("No active connection".into()));
        }

        let result = state.channel.send(Message::Operation { qid, op }).instrument(span).await;
        if result.is_err() {
            error!({ ATT_QID } = %qid, "failed to send message");
            self.update_status(conn_idx, ConnectionStatus::Closed);
            return Err(Error::ChannelClosed);
        }

        Ok(conn_idx)
    }

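    /// Send a `Shutdown` message to every inner connection and abort the io
    /// tasks on the `JoinSet`.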
    #[instrument(
        level = "trace",
        skip_all,
        fields(db.system = "clickhouse", clickhouse.client.id = self.metadata.client_id)
    )]
    pub(crate) async fn shutdown(&self) -> Result<()> {
        trace!({ ATT_CID } = self.metadata.client_id, "Shutting down connections");
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.send(Message::Shutdown).await.is_err() {
                error!("Failed to shutdown connection");
            }
        }
        #[cfg(feature = "inner_pool")]
        {
            for (i, conn_state) in self.state.iter().enumerate() {
                let state = conn_state.load();
                debug!("Shutting down connection {i}");
                // Send a shutdown message to each internal connection in turn
                if state.channel.send(Message::Shutdown).await.is_err() {
                    error!("Failed to shutdown connection {i}");
                }
            }
        }
        self.io_task.lock().abort_all();
        Ok(())
    }

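    /// Verify the connection is usable: ensure the operation channel is still
    /// open and, if `ping` is true, round-trip a `Ping` through the server.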
    pub(crate) async fn check_connection(&self, ping: bool) -> Result<()> {
        // First check that internal channels are ok
        self.check_channel()?;

        if !ping {
            return Ok(());
        }

        // Then ping
        let (response, rx) = tokio::sync::oneshot::channel();
        let cid = self.metadata.client_id;
        let qid = Qid::default();
        let idx = self
            .send_operation(Operation::Ping { response }, qid, true)
            .instrument(trace_span!(
                "clickhouse.connection.ping",
                { ATT_CID } = cid,
                { ATT_QID } = %qid,
                db.system = "clickhouse",
            ))
            .await?;

        rx.await
            .map_err(|_| {
                self.update_status(idx, ConnectionStatus::Closed);
                Error::ChannelClosed
            })?
            .inspect_err(|error| {
                self.update_status(idx, ConnectionStatus::Error);
                error!(?error, { ATT_CID } = cid, "Ping failed");
            })?;

        Ok(())
    }

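    /// Record a new status for the inner connection at `idx`.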
    fn update_status(&self, idx: usize, status: ConnectionStatus) {
        trace!({ ATT_CID } = self.metadata.client_id, ?status, "Updating status conn {idx}");

        #[cfg(not(feature = "inner_pool"))]
        let state = &self.state;
        #[cfg(feature = "inner_pool")]
        let state = self.state[idx].load();

        state.status.store(status.into(), Ordering::Release);
    }

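    /// Exchange hellos with the server: send the `ClientHello`, read the
    /// `ServerHello`, and send the addendum when the negotiated protocol
    /// revision requires it.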
    async fn perform_handshake<RW: ClickHouseRead + ClickHouseWrite + Send + 'static>(
        stream: &mut RW,
        client_id: u16,
        options: &ClientOptions,
    ) -> Result<ServerHello> {
        use crate::client::reader::Reader;
        use crate::client::writer::Writer;

        let client_hello = ClientHello {
            default_database: options.default_database.clone(),
            username:         options.username.clone(),
            password:         options.password.get().to_string(),
        };

        // Send client hello
        Writer::send_hello(stream, client_hello)
            .await
            .inspect_err(|error| error!(?error, { ATT_CID } = client_id, "Failed to send hello"))?;

        // Receive server hello
        let chunked_modes = (options.ext.chunked_send, options.ext.chunked_recv);
        let server_hello =
            Reader::receive_hello(stream, DBMS_TCP_PROTOCOL_VERSION, chunked_modes, client_id)
                .await?;
        trace!({ ATT_CID } = client_id, ?server_hello, "Finished handshake");

        if server_hello.revision_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM {
            Writer::send_addendum(stream, server_hello.revision_version, &server_hello).await?;
            stream.flush().await.inspect_err(|error| error!(?error, "Error writing addendum"))?;
        }

        Ok(server_hello)
    }
}

impl<T: ClientFormat> Connection<T> {
    pub(crate) fn metadata(&self) -> ClientMetadata { self.metadata }

    pub(crate) fn database(&self) -> &str { &self.options.default_database }

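    /// Release the load-balancer weight that was reserved when the operation was
    /// assigned to `conn_idx`.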
    #[cfg(feature = "inner_pool")]
    pub(crate) fn finish(&self, conn_idx: usize, weight: u8) {
        self.load_balancer.finish(usize::from(weight), conn_idx);
    }

    pub(crate) fn status(&self) -> ConnectionStatus {
        #[cfg(not(feature = "inner_pool"))]
        let status = ConnectionStatus::from(self.state.status.load(Ordering::Acquire));

        // TODO: Status is strange if we have an internal pool. Figure this out.
        // Just use the first channel for now
        #[cfg(feature = "inner_pool")]
        let status = ConnectionStatus::from(self.state[0].load().status.load(Ordering::Acquire));

        status
    }

    fn check_channel(&self) -> Result<()> {
        #[cfg(not(feature = "inner_pool"))]
        {
            if self.state.channel.is_closed() {
                self.update_status(0, ConnectionStatus::Closed);
                Err(Error::ChannelClosed)
            } else {
                Ok(())
            }
        }

        // TODO: Checking channel is strange if we have an internal pool. Figure this out.
        // Just return status of first connection for now
        #[cfg(feature = "inner_pool")]
        if self.state[0].load().channel.is_closed() {
            self.update_status(0, ConnectionStatus::Closed);
            Err(Error::ChannelClosed)
        } else {
            Ok(())
        }
    }
}

impl<T: ClientFormat> Drop for Connection<T> {
    fn drop(&mut self) {
        trace!({ ATT_CID } = self.metadata.client_id, "Connection dropped");
        self.io_task.lock().abort_all();
    }
}

#[cfg(feature = "inner_pool")]
mod load {
    use std::sync::atomic::{AtomicUsize, Ordering};

    pub(super) const DEFAULT_MAX_CONNECTIONS: u8 = 4;
    pub(super) const ABSOLUTE_MAX_CONNECTIONS: u8 = 16;

    /// Array-based load balancer for distributing operations across multiple connections.
    ///
    /// Each connection has a dedicated atomic counter tracking its current load.
    /// This prevents the overflow issues inherent in bit-packed approaches and allows
    /// scaling up to 16 concurrent connections.
    #[derive(Debug)]
    pub(super) struct AtomicLoad {
        load_counters:   Box<[AtomicUsize]>,
        max_connections: u8,
    }

    impl AtomicLoad {
        /// Create a new load balancer with the specified maximum connections.
        ///
        /// # Panics
        /// - If `max_connections` is 0
        /// - If `max_connections` exceeds 16
        pub(super) fn new(max_connections: u8) -> Self {
            assert!(max_connections > 0, "At least 1 connection required");
            assert!(
                max_connections <= ABSOLUTE_MAX_CONNECTIONS,
                "Max {ABSOLUTE_MAX_CONNECTIONS} connections supported"
            );

            let load_counters = (0..max_connections)
                .map(|_| AtomicUsize::new(0))
                .collect::<Vec<_>>()
                .into_boxed_slice();

            Self { load_counters, max_connections }
        }

        /// Assign a connection index, incrementing its load by the specified weight.
        ///
        /// If `key` is Some, uses deterministic assignment (key % `max_connections`).
        /// If `key` is None, selects the least-loaded connection.
        ///
        /// Returns the selected connection index.
        pub(super) fn assign(&self, key: Option<usize>, weight: usize) -> usize {
            let idx = if let Some(k) = key {
                k % usize::from(self.max_connections)
            } else {
                // Select least-loaded connection
                (0..self.max_connections)
                    .min_by_key(|&i| self.load_counters[usize::from(i)].load(Ordering::Acquire))
                    .unwrap_or(0)
                    .into()
            };

            if weight > 0 {
                let _ = self.load_counters[idx].fetch_add(weight, Ordering::SeqCst);
            }
            idx
        }

        /// Decrement load by weight for the connection at the specified index.
        pub(crate) fn finish(&self, weight: usize, idx: usize) {
            if weight == 0 || idx >= self.load_counters.len() {
                return;
            }
            let _ = self.load_counters[idx].fetch_sub(weight, Ordering::SeqCst);
        }
    }
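
    // Illustrative sketch of how the balancer is used (weights are hypothetical,
    // and `qid_key` stands in for `Qid::key()`):
    //
    //     let lb = AtomicLoad::new(4);
    //     let streaming = lb.assign(Some(qid_key), 5); // pinned: qid_key % 4
    //     let one_shot = lb.assign(None, 1);           // least-loaded connection
    //     lb.finish(5, streaming);                     // release weight when done
    //     lb.finish(1, one_shot);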

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_atomic_load_supports_16_connections() {
            let load = AtomicLoad::new(16);

            // Assign 1000 tasks across 16 connections
            let assignments: Vec<_> = (0..1000).map(|_| load.assign(None, 1)).collect();

            // Verify reasonable distribution (should be ~62-63 per connection)
            for i in 0..16 {
                let count = assignments.iter().filter(|&&idx| idx == i).count();
                assert!(
                    (50..=75).contains(&count),
                    "Connection {i} got {count} assignments (expected ~62)"
                );
            }
        }

        #[test]
        fn test_no_overflow_with_heavy_inserts() {
            let load = AtomicLoad::new(4);

            // Simulate 1000 concurrent InsertMany operations (weight=7)
            for _ in 0..1000 {
                let idx = load.assign(None, 7);
                // Immediately finish to prevent unbounded growth
                load.finish(7, idx);
            }

            // All counters should be back to 0
            for i in 0..4 {
                assert_eq!(load.load_counters[i].load(Ordering::Acquire), 0);
            }
        }

        #[test]
        fn test_deterministic_assignment_by_key() {
            let load = AtomicLoad::new(8);

            // Same key should always go to same connection
            let key = 12345;
            let idx1 = load.assign(Some(key), 1);
            let idx2 = load.assign(Some(key), 1);
            let idx3 = load.assign(Some(key), 1);

            assert_eq!(idx1, idx2);
            assert_eq!(idx2, idx3);
            assert_eq!(idx1, key % 8);
        }

        #[test]
        fn test_least_loaded_selection() {
            let load = AtomicLoad::new(4);

            // Manually set load: [100, 50, 200, 75]
            load.load_counters[0].store(100, Ordering::Release);
            load.load_counters[1].store(50, Ordering::Release);
            load.load_counters[2].store(200, Ordering::Release);
            load.load_counters[3].store(75, Ordering::Release);

            // Next assignment should go to connection 1 (load=50)
            let idx = load.assign(None, 1);
            assert_eq!(idx, 1);
        }

        #[test]
        #[should_panic(expected = "Max 16 connections")]
        fn test_rejects_too_many_connections() { drop(AtomicLoad::new(17)); }

        #[test]
        #[should_panic(expected = "At least 1 connection")]
        fn test_rejects_zero_connections() { drop(AtomicLoad::new(0)); }

        #[test]
        fn test_zero_weight_returns_index_without_increment() {
            let load = AtomicLoad::new(4);

            let idx = load.assign(None, 0);
            assert!(idx < 4);

            // All counters should still be 0
            for i in 0..4 {
                assert_eq!(load.load_counters[i].load(Ordering::Acquire), 0);
            }
        }

        #[test]
        fn test_finish_with_invalid_index() {
            let load = AtomicLoad::new(4);

            // Assign some load
            let idx = load.assign(None, 10);
            load.load_counters[idx].store(10, Ordering::Release);

            // Finish with out-of-bounds index should not panic
            load.finish(5, 999);

            // Original load should be unchanged
            assert_eq!(load.load_counters[idx].load(Ordering::Acquire), 10);
        }

        #[test]
        fn test_finish_with_zero_weight() {
            let load = AtomicLoad::new(4);

            // Assign some load
            let idx = load.assign(None, 10);
            load.load_counters[idx].store(10, Ordering::Release);

            // Finish with zero weight should not modify counters
            load.finish(0, idx);

            // Load should be unchanged
            assert_eq!(load.load_counters[idx].load(Ordering::Acquire), 10);

            // Also test zero weight with invalid index (covers both branches)
            load.finish(0, 999);
        }
    }
}