clickhouse_datafusion/
connection.rs

1#![cfg_attr(feature = "mocks", expect(clippy::unused_async))]
2#![cfg_attr(feature = "mocks", expect(dead_code))]
3
4#[cfg(feature = "mocks")]
5mod mock;
6
7use clickhouse_arrow::ArrowConnectionPoolBuilder;
8#[cfg(not(feature = "mocks"))]
9use clickhouse_arrow::{
10    ArrowConnectionManager, ArrowFormat, ClickHouseResponse, ConnectionManager,
11    Error as ClickhouseNativeError, bb8,
12};
13#[cfg(not(feature = "mocks"))]
14use datafusion::arrow::array::RecordBatch;
15use datafusion::arrow::datatypes::SchemaRef;
16use datafusion::common::DataFusionError;
17use datafusion::common::error::GenericError;
18use datafusion::error::Result;
19use datafusion::physical_plan::SendableRecordBatchStream;
20#[cfg(not(feature = "mocks"))]
21use datafusion::sql::TableReference;
22use futures_util::TryStreamExt;
23use tracing::{debug, error};
24
25use crate::sql::JoinPushDown;
26use crate::stream::{RecordBatchStreamWrapper, record_batch_stream_from_stream};
27
28/// Type alias for a pooled connection to a `ClickHouse` database.
29#[cfg(not(feature = "mocks"))]
30pub type ArrowPoolConnection<'a> = bb8::PooledConnection<'a, ConnectionManager<ArrowFormat>>;
31#[cfg(not(feature = "mocks"))]
32pub type ArrowPool = bb8::Pool<ArrowConnectionManager>;
33/// Type alias for a pooled connection as mocks.
34#[cfg(feature = "mocks")]
35pub type ArrowPoolConnection<'a> = &'a ();
36#[cfg(feature = "mocks")]
37pub type ArrowPool = ();
38
39/// A wrapper around a [`clickhouse_arrow::ConnectionPool<ArrowFormat>`]
40#[derive(Debug, Clone)]
41pub struct ClickHouseConnectionPool {
42    // "mocks" feature affects mainly this property and other properties related to connecting.
43    pool:           ArrowPool,
44    join_push_down: JoinPushDown,
45}
46
47impl ClickHouseConnectionPool {
48    /// Create a new `ClickHouse` connection pool for use in `DataFusion`. The identifier is used in
49    /// the case of federation to determine if queries can be pushed down across two pools
50    pub fn new(identifier: impl Into<String>, pool: ArrowPool) -> Self {
51        debug!("Creating new ClickHouse connection pool");
52        let join_push_down = JoinPushDown::AllowedFor(identifier.into());
53        Self { pool, join_push_down }
54    }
55
56    /// Create a new `ClickHouse` connection pool from a builder.
57    ///
58    /// # Errors
59    /// - Returns an error if the connection pool cannot be created.
60    pub async fn from_pool_builder(builder: ArrowConnectionPoolBuilder) -> Result<Self> {
61        let identifer = builder.connection_identifier();
62
63        // Since this pool will be used for ddl, it's essential it connects to the "default" db
64        #[cfg(not(feature = "mocks"))]
65        let pool = builder
66            .configure_client(|c| c.with_database("default"))
67            .build()
68            .await
69            .inspect_err(|error| error!(?error, "Error building ClickHouse connection pool"))
70            .map_err(crate::utils::map_clickhouse_err)?;
71
72        #[cfg(feature = "mocks")]
73        let pool = ();
74
75        Ok(Self::new(identifer, pool))
76    }
77
78    /// Access the underlying connection pool
79    pub fn pool(&self) -> &ArrowPool { &self.pool }
80
81    pub fn join_push_down(&self) -> JoinPushDown { self.join_push_down.clone() }
82}
83
84impl ClickHouseConnectionPool {
85    /// Get a managed [`ArrowPoolConnection`] wrapped in a [`ClickHouseConnection`]
86    ///
87    /// # Errors
88    /// - Returns an error if the connection cannot be established.
89    pub async fn connect(&self) -> Result<ClickHouseConnection<'_>> {
90        #[cfg(not(feature = "mocks"))]
91        let conn = self
92            .pool
93            .get()
94            .await
95            .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
96            .map_err(crate::utils::map_external_err)?;
97        #[cfg(feature = "mocks")]
98        let conn = &();
99        Ok(ClickHouseConnection::new(conn))
100    }
101
102    /// Get a managed static [`ArrowPoolConnection`] wrapped in a [`ClickHouseConnection`]
103    ///
104    /// # Errors
105    /// - Returns an error if the connection cannot be established.
106    pub async fn connect_static(&self) -> Result<ClickHouseConnection<'static>> {
107        #[cfg(not(feature = "mocks"))]
108        let conn = self
109            .pool
110            .get_owned()
111            .await
112            .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
113            .map_err(crate::utils::map_external_err)?;
114        #[cfg(feature = "mocks")]
115        let conn = &();
116        Ok(ClickHouseConnection::new_static(conn))
117    }
118}
119
120/// A wrapper around [`ArrowPoolConnection`] that provides additional functionality relevant for
121/// `DataFusion`.
122///
123/// The methods [`ClickHouseConnection::tables`], [`ClickHouseConnection::get_schema`], and
124/// [`ClickHouseConnection::query_arrow`] will all be run against the `ClickHouse` instance.
125#[derive(Debug)]
126pub struct ClickHouseConnection<'a> {
127    conn: ArrowPoolConnection<'a>,
128}
129
130impl<'a> ClickHouseConnection<'a> {
131    pub fn new(conn: ArrowPoolConnection<'a>) -> Self { ClickHouseConnection { conn } }
132
133    // TODO: Use to provide interop with datafusion-table-providers
134    pub fn new_static(conn: ArrowPoolConnection<'static>) -> Self { ClickHouseConnection { conn } }
135
136    /// Issues a query against `ClickHouse` and returns the result as an arrow
137    /// [`SendableRecordBatchStream`] using the provided schema.
138    ///
139    /// The argument `coerce_schema` will be passed to `RecordBatchStream` only if
140    /// `projected_schema` is also provided. Otherwise coercion won't be necessary as the streamed
141    /// `RecordBatch`es will determine the schema.
142    ///
143    /// # Errors
144    /// - Returns an error if the query fails.
145    pub async fn query_arrow_with_schema(
146        &self,
147        sql: &str,
148        params: &[()],
149        schema: SchemaRef,
150        coerce_schema: bool,
151    ) -> Result<RecordBatchStreamWrapper, DataFusionError> {
152        debug!(sql, "Running query");
153        let batches = Box::pin(
154            self.query_arrow_raw(sql, params)
155                .await?
156                // Map the stream's clickhouse-arrow error to DataFusionError
157                .map_err(|e| DataFusionError::External(Box::new(e))),
158        );
159        Ok(RecordBatchStreamWrapper::new_from_stream(batches, schema).with_coercion(coerce_schema))
160    }
161
162    /// Issues a query against `ClickHouse` and returns the result as an arrow
163    /// [`SendableRecordBatchStream`] using the provided schema.
164    ///
165    /// This method allows interop with `datafusion-table-providers` if desired. Otherwise, the
166    /// method `Self::query_arrow_raw` can be used to prevent additional wrapping, or
167    /// `Self::query_arrow_with_schema` if schema coercion is desired.
168    ///
169    /// # Errors
170    /// - Returns an error if the query fails.
171    pub async fn query_arrow(
172        &self,
173        sql: &str,
174        params: &[()],
175        projected_schema: Option<SchemaRef>,
176    ) -> Result<SendableRecordBatchStream, GenericError> {
177        if let Some(schema) = projected_schema {
178            return Ok(Box::pin(self.query_arrow_with_schema(sql, params, schema, false).await?));
179        }
180
181        let batches = Box::pin(
182            self.query_arrow_raw(sql, params)
183                .await?
184                // Map the stream's clickhouse-arrow error to DataFusionError
185                .map_err(|e| DataFusionError::External(Box::new(e))),
186        );
187
188        Ok(Box::pin(
189            record_batch_stream_from_stream(batches)
190                .await
191                .inspect_err(|error| error!(?error, "Failed converting batches to stream"))
192                .map_err(Box::new)?,
193        ))
194    }
195}
196
197#[cfg(not(feature = "mocks"))]
198impl ClickHouseConnection<'_> {
199    /// Fetch the names of the tables in a schema (database).
200    ///
201    /// # Errors
202    /// - Returns an error if the tables cannot be fetched.
203    pub async fn tables(&self, schema: &str) -> Result<Vec<String>> {
204        debug!(schema, "Fetching tables");
205        self.conn
206            .fetch_tables(Some(schema), None)
207            .await
208            .inspect_err(|error| error!(?error, "Fetching tables failed"))
209            .map_err(crate::utils::map_clickhouse_err)
210    }
211
212    /// Fetch the names of the schemas (databases).
213    ///
214    /// # Errors
215    /// - Returns an error if the schemas cannot be fetched.
216    pub async fn schemas(&self) -> Result<Vec<String>> {
217        debug!("Fetching databases");
218        self.conn
219            .fetch_schemas(None)
220            .await
221            .inspect_err(|error| error!(?error, "Fetching databases failed"))
222            .map_err(crate::utils::map_clickhouse_err)
223    }
224
225    /// Fetch the schema for a table
226    ///
227    /// # Errors
228    /// - Returns an error if the schema cannot be fetched.
229    pub async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef> {
230        debug!(%table_reference, "Fetching schema for table");
231        let db = table_reference.schema();
232        let table = table_reference.table();
233        let mut schemas =
234            self.conn.fetch_schema(db, &[table][..], None).await.map_err(|error| {
235                if let ClickhouseNativeError::UndefinedTables { .. } = error {
236                    error!(?error, ?db, ?table, "Tables undefined");
237                } else {
238                    error!(?error, ?db, ?table, "Unknown error occurred while fetching schema");
239                }
240                crate::utils::map_clickhouse_err(error)
241            })?;
242
243        schemas
244            .remove(table)
245            .ok_or(DataFusionError::External("Schema not found for table".into()))
246    }
247
248    /// Issues a query against `ClickHouse` and returns the raw `ClickHouseResponse<RecordBatch>`
249    ///
250    /// # Errors
251    /// - Returns an error if the query fails
252    pub async fn query_arrow_raw(
253        &self,
254        sql: &str,
255        _params: &[()],
256    ) -> Result<ClickHouseResponse<RecordBatch>> {
257        self.conn
258            .query(sql, None)
259            .await
260            .inspect(|_| tracing::trace!("Query executed successfully"))
261            .inspect_err(|error| error!(?error, "Failed running query"))
262            // Convert the clickhouse-arrow error to a DataFusionError
263            .map_err(|e| DataFusionError::External(Box::new(e)))
264    }
265
266    /// Executes a statement against `ClickHouse` and returns the number of affected rows.
267    ///
268    /// # Errors
269    /// - Returns an error if the query fails.
270    pub async fn execute(&self, sql: &str, _params: &[()]) -> Result<u64, GenericError> {
271        debug!(sql, "Executing query");
272        self.conn
273            .execute(sql, None)
274            .await
275            .inspect_err(|error| error!(?error, "Failed executing query"))
276            .map_err(Box::new)?;
277        Ok(0)
278    }
279}
280
281// TODO: Provide compat with datafusion-table-providers DbConnection, AsyncDbConnection
282
#[cfg(test)]
mod tests {
    use datafusion::sql::TableReference;

    use super::*;

    /// `get_schema` relies on `TableReference::schema` / `TableReference::table`; check what they
    /// return for full, partial and bare references.
    #[test]
    fn test_table_reference_schema_extraction() {
        let cases = [
            (TableReference::full("catalog", "schema", "table"), Some("schema")),
            (TableReference::partial("schema", "table"), Some("schema")),
            (TableReference::bare("table"), None),
        ];

        for (reference, expected_schema) in cases {
            assert_eq!(reference.schema(), expected_schema);
            assert_eq!(reference.table(), "table");
        }
    }

    /// An `UndefinedTables` error mapped through `map_clickhouse_err` must surface as
    /// `DataFusionError::External` and keep the database and table names in its message.
    #[test]
    fn test_error_handling_patterns() {
        use clickhouse_arrow::Error as ClickhouseNativeError;

        use crate::utils::map_clickhouse_err;

        let mapped = map_clickhouse_err(ClickhouseNativeError::UndefinedTables {
            db:     "test_db".to_string(),
            tables: vec!["test_table".to_string()],
        });

        let DataFusionError::External(inner) = mapped else {
            panic!("Expected External error");
        };

        let rendered = inner.to_string();
        for needle in ["Tables undefined", "test_db", "test_table"] {
            assert!(rendered.contains(needle));
        }
    }

    /// The pool constructor builds `JoinPushDown::AllowedFor` from the identifier; check that the
    /// variant carries the identifier unchanged.
    #[test]
    fn test_join_push_down_creation() {
        use crate::sql::JoinPushDown;

        let policy = JoinPushDown::AllowedFor("test_pool".to_string());

        match policy {
            JoinPushDown::AllowedFor(id) => assert_eq!(id, "test_pool"),
            JoinPushDown::Disallow => panic!("Expected AllowedFor variant"),
        }
    }
}