clickhouse_datafusion/
connection.rs

1#![cfg_attr(feature = "mocks", expect(clippy::unused_async))]
2#![cfg_attr(feature = "mocks", expect(dead_code))]
3
4#[cfg(feature = "mocks")]
5mod mock;
6
7use clickhouse_arrow::ArrowConnectionPoolBuilder;
8#[cfg(not(feature = "mocks"))]
9use clickhouse_arrow::{
10    ArrowConnectionManager, ArrowFormat, ClickHouseResponse, ConnectionManager,
11    Error as ClickhouseNativeError, bb8,
12};
13#[cfg(not(feature = "mocks"))]
14use datafusion::arrow::array::RecordBatch;
15use datafusion::arrow::datatypes::SchemaRef;
16use datafusion::common::DataFusionError;
17use datafusion::common::error::GenericError;
18use datafusion::error::Result;
19use datafusion::physical_plan::SendableRecordBatchStream;
20#[cfg(not(feature = "mocks"))]
21use datafusion::sql::TableReference;
22use futures_util::TryStreamExt;
23use tracing::{debug, error};
24
25use crate::sql::JoinPushDown;
26use crate::stream::{RecordBatchStreamWrapper, record_batch_stream_from_stream};
27
28/// Type alias for a pooled connection to a `ClickHouse` database.
29#[cfg(not(feature = "mocks"))]
30pub type ArrowPoolConnection<'a> = bb8::PooledConnection<'a, ConnectionManager<ArrowFormat>>;
31#[cfg(not(feature = "mocks"))]
32pub type ArrowPool = bb8::Pool<ArrowConnectionManager>;
33/// Type alias for a pooled connection as mocks.
34#[cfg(feature = "mocks")]
35pub type ArrowPoolConnection<'a> = &'a ();
36#[cfg(feature = "mocks")]
37pub type ArrowPool = ();
38
/// A wrapper around an [`ArrowPool`] (a `bb8` pool of `ClickHouse` connections, or `()` when the
/// `mocks` feature is enabled), together with the join push-down policy and write concurrency
/// used by the `DataFusion` integration.
#[derive(Debug, Clone)]
pub struct ClickHouseConnectionPool {
    // "mocks" feature affects mainly this property and other properties related to connecting.
    pool:              ArrowPool,
    /// Federation join push-down policy, built from the pool identifier passed to `new`.
    join_push_down:    JoinPushDown,
    /// Number of concurrent write operations allowed when inserting data.
    /// Defaults to 4 (matching clickhouse-arrow's current connection pool limit).
    write_concurrency: usize,
}
49
50impl ClickHouseConnectionPool {
51    /// Create a new `ClickHouse` connection pool for use in `DataFusion`. The identifier is used in
52    /// the case of federation to determine if queries can be pushed down across two pools
53    pub fn new(identifier: impl Into<String>, pool: ArrowPool) -> Self {
54        debug!("Creating new ClickHouse connection pool");
55        let join_push_down = JoinPushDown::AllowedFor(identifier.into());
56        Self { pool, join_push_down, write_concurrency: 4 }
57    }
58
59    /// Set the write concurrency level for INSERT operations.
60    ///
61    /// This controls how many record batches can be written to `ClickHouse` concurrently.
62    /// Higher values may improve throughput but increase memory and connection usage.
63    ///
64    /// Default: 4 (matching clickhouse-arrow's current connection pool limit)
65    #[must_use]
66    pub fn with_write_concurrency(mut self, concurrency: usize) -> Self {
67        self.write_concurrency = concurrency;
68        self
69    }
70
71    /// Get the configured write concurrency level
72    pub fn write_concurrency(&self) -> usize { self.write_concurrency }
73
74    /// Create a new `ClickHouse` connection pool from a builder.
75    ///
76    /// # Errors
77    /// - Returns an error if the connection pool cannot be created.
78    pub async fn from_pool_builder(builder: ArrowConnectionPoolBuilder) -> Result<Self> {
79        let identifer = builder.connection_identifier();
80
81        // Since this pool will be used for ddl, it's essential it connects to the "default" db
82        #[cfg(not(feature = "mocks"))]
83        let pool = builder
84            .configure_client(|c| c.with_database("default"))
85            .build()
86            .await
87            .inspect_err(|error| error!(?error, "Error building ClickHouse connection pool"))
88            .map_err(crate::utils::map_clickhouse_err)?;
89
90        #[cfg(feature = "mocks")]
91        let pool = ();
92
93        Ok(Self::new(identifer, pool))
94    }
95
96    /// Access the underlying connection pool
97    pub fn pool(&self) -> &ArrowPool { &self.pool }
98
99    pub fn join_push_down(&self) -> JoinPushDown { self.join_push_down.clone() }
100}
101
102impl ClickHouseConnectionPool {
103    /// Get a managed [`ArrowPoolConnection`] wrapped in a [`ClickHouseConnection`]
104    ///
105    /// # Errors
106    /// - Returns an error if the connection cannot be established.
107    pub async fn connect(&self) -> Result<ClickHouseConnection<'_>> {
108        #[cfg(not(feature = "mocks"))]
109        let conn = self
110            .pool
111            .get()
112            .await
113            .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
114            .map_err(crate::utils::map_external_err)?;
115        #[cfg(feature = "mocks")]
116        let conn = &();
117        Ok(ClickHouseConnection::new(conn))
118    }
119
120    /// Get a managed static [`ArrowPoolConnection`] wrapped in a [`ClickHouseConnection`]
121    ///
122    /// # Errors
123    /// - Returns an error if the connection cannot be established.
124    pub async fn connect_static(&self) -> Result<ClickHouseConnection<'static>> {
125        #[cfg(not(feature = "mocks"))]
126        let conn = self
127            .pool
128            .get_owned()
129            .await
130            .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
131            .map_err(crate::utils::map_external_err)?;
132        #[cfg(feature = "mocks")]
133        let conn = &();
134        Ok(ClickHouseConnection::new_static(conn))
135    }
136}
137
/// A wrapper around [`ArrowPoolConnection`] that provides additional functionality relevant for
/// `DataFusion`.
///
/// The methods [`ClickHouseConnection::tables`], [`ClickHouseConnection::get_schema`], and
/// [`ClickHouseConnection::query_arrow`] will all be run against the `ClickHouse` instance.
/// NOTE(review): in this file `tables`/`get_schema` are only defined when the `mocks` feature is
/// disabled; presumably the `mock` module provides stand-ins otherwise — confirm.
#[derive(Debug)]
pub struct ClickHouseConnection<'a> {
    // The pooled connection every query is issued on (`&()` under the `mocks` feature).
    conn: ArrowPoolConnection<'a>,
}
147
148impl<'a> ClickHouseConnection<'a> {
149    pub fn new(conn: ArrowPoolConnection<'a>) -> Self { ClickHouseConnection { conn } }
150
151    // TODO: Use to provide interop with datafusion-table-providers
152    pub fn new_static(conn: ArrowPoolConnection<'static>) -> Self { ClickHouseConnection { conn } }
153
154    /// Issues a query against `ClickHouse` and returns the result as an arrow
155    /// [`SendableRecordBatchStream`] using the provided schema.
156    ///
157    /// The argument `coerce_schema` will be passed to `RecordBatchStream` only if
158    /// `projected_schema` is also provided. Otherwise coercion won't be necessary as the streamed
159    /// `RecordBatch`es will determine the schema.
160    ///
161    /// # Errors
162    /// - Returns an error if the query fails.
163    pub async fn query_arrow_with_schema(
164        &self,
165        sql: &str,
166        params: &[()],
167        schema: SchemaRef,
168        coerce_schema: bool,
169    ) -> Result<RecordBatchStreamWrapper, DataFusionError> {
170        debug!(sql, "Running query");
171        let batches = Box::pin(
172            self.query_arrow_raw(sql, params)
173                .await?
174                // Map the stream's clickhouse-arrow error to DataFusionError
175                .map_err(|e| DataFusionError::External(Box::new(e))),
176        );
177        Ok(RecordBatchStreamWrapper::new_from_stream(batches, schema).with_coercion(coerce_schema))
178    }
179
180    /// Issues a query against `ClickHouse` and returns the result as an arrow
181    /// [`SendableRecordBatchStream`] using the provided schema.
182    ///
183    /// This method allows interop with `datafusion-table-providers` if desired. Otherwise, the
184    /// method `Self::query_arrow_raw` can be used to prevent additional wrapping, or
185    /// `Self::query_arrow_with_schema` if schema coercion is desired.
186    ///
187    /// # Errors
188    /// - Returns an error if the query fails.
189    pub async fn query_arrow(
190        &self,
191        sql: &str,
192        params: &[()],
193        projected_schema: Option<SchemaRef>,
194    ) -> Result<SendableRecordBatchStream, GenericError> {
195        if let Some(schema) = projected_schema {
196            return Ok(Box::pin(self.query_arrow_with_schema(sql, params, schema, false).await?));
197        }
198
199        let batches = Box::pin(
200            self.query_arrow_raw(sql, params)
201                .await?
202                // Map the stream's clickhouse-arrow error to DataFusionError
203                .map_err(|e| DataFusionError::External(Box::new(e))),
204        );
205
206        Ok(Box::pin(
207            record_batch_stream_from_stream(batches)
208                .await
209                .inspect_err(|error| error!(?error, "Failed converting batches to stream"))
210                .map_err(Box::new)?,
211        ))
212    }
213}
214
215#[cfg(not(feature = "mocks"))]
216impl ClickHouseConnection<'_> {
217    /// Fetch the names of the tables in a schema (database).
218    ///
219    /// # Errors
220    /// - Returns an error if the tables cannot be fetched.
221    pub async fn tables(&self, schema: &str) -> Result<Vec<String>> {
222        debug!(schema, "Fetching tables");
223        self.conn
224            .fetch_tables(Some(schema), None)
225            .await
226            .inspect_err(|error| error!(?error, "Fetching tables failed"))
227            .map_err(crate::utils::map_clickhouse_err)
228    }
229
230    /// Fetch the names of the schemas (databases).
231    ///
232    /// # Errors
233    /// - Returns an error if the schemas cannot be fetched.
234    pub async fn schemas(&self) -> Result<Vec<String>> {
235        debug!("Fetching databases");
236        self.conn
237            .fetch_schemas(None)
238            .await
239            .inspect_err(|error| error!(?error, "Fetching databases failed"))
240            .map_err(crate::utils::map_clickhouse_err)
241    }
242
243    /// Fetch the schema for a table
244    ///
245    /// # Errors
246    /// - Returns an error if the schema cannot be fetched.
247    pub async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef> {
248        debug!(%table_reference, "Fetching schema for table");
249        let db = table_reference.schema();
250        let table = table_reference.table();
251        let mut schemas =
252            self.conn.fetch_schema(db, &[table][..], None).await.map_err(|error| {
253                if let ClickhouseNativeError::UndefinedTables { .. } = error {
254                    error!(?error, ?db, ?table, "Tables undefined");
255                } else {
256                    error!(?error, ?db, ?table, "Unknown error occurred while fetching schema");
257                }
258                crate::utils::map_clickhouse_err(error)
259            })?;
260
261        schemas
262            .remove(table)
263            .ok_or(DataFusionError::External("Schema not found for table".into()))
264    }
265
266    /// Issues a query against `ClickHouse` and returns the raw `ClickHouseResponse<RecordBatch>`
267    ///
268    /// # Errors
269    /// - Returns an error if the query fails
270    pub async fn query_arrow_raw(
271        &self,
272        sql: &str,
273        _params: &[()],
274    ) -> Result<ClickHouseResponse<RecordBatch>> {
275        self.conn
276            .query(sql, None)
277            .await
278            .inspect(|_| tracing::trace!("Query executed successfully"))
279            .inspect_err(|error| error!(?error, "Failed running query"))
280            // Convert the clickhouse-arrow error to a DataFusionError
281            .map_err(|e| DataFusionError::External(Box::new(e)))
282    }
283
284    /// Executes a statement against `ClickHouse` and returns the number of affected rows.
285    ///
286    /// # Errors
287    /// - Returns an error if the query fails.
288    pub async fn execute(&self, sql: &str, _params: &[()]) -> Result<u64, GenericError> {
289        debug!(sql, "Executing query");
290        self.conn
291            .execute(sql, None)
292            .await
293            .inspect_err(|error| error!(?error, "Failed executing query"))
294            .map_err(Box::new)?;
295        Ok(0)
296    }
297}
298
299// TODO: Provide compat with datafusion-table-providers DbConnection, AsyncDbConnection
300
#[cfg(test)]
mod tests {
    use datafusion::sql::TableReference;

    use super::*;

    /// Mirrors how `get_schema` splits a `TableReference` into a database and a table name.
    #[test]
    fn test_table_reference_schema_extraction() {
        let cases = [
            (TableReference::full("catalog", "schema", "table"), Some("schema")),
            (TableReference::partial("schema", "table"), Some("schema")),
            (TableReference::bare("table"), None),
        ];

        for (reference, expected_schema) in &cases {
            assert_eq!(reference.schema(), *expected_schema);
            assert_eq!(reference.table(), "table");
        }
    }

    /// `UndefinedTables` must surface as an external error that names the database and table.
    #[test]
    fn test_error_handling_patterns() {
        use clickhouse_arrow::Error as ClickhouseNativeError;

        use crate::utils::map_clickhouse_err;

        let mapped = map_clickhouse_err(ClickhouseNativeError::UndefinedTables {
            db:     "test_db".to_string(),
            tables: vec!["test_table".to_string()],
        });

        let DataFusionError::External(inner) = mapped else {
            panic!("Expected External error");
        };

        let message = inner.to_string();
        for needle in ["Tables undefined", "test_db", "test_table"] {
            assert!(message.contains(needle), "missing {needle:?} in {message:?}");
        }
    }

    /// Pools are created with a push-down policy keyed on their identifier.
    #[test]
    fn test_join_push_down_creation() {
        use crate::sql::JoinPushDown;

        // Exhaustive on purpose: a new variant should force this test to be revisited.
        match JoinPushDown::AllowedFor(String::from("test_pool")) {
            JoinPushDown::AllowedFor(id) => assert_eq!(id, "test_pool"),
            JoinPushDown::Disallow => panic!("Expected AllowedFor variant"),
        }
    }
}