clickhouse_datafusion/
connection.rs1#![cfg_attr(feature = "mocks", expect(clippy::unused_async))]
2#![cfg_attr(feature = "mocks", expect(dead_code))]
3
4#[cfg(feature = "mocks")]
5mod mock;
6
7use clickhouse_arrow::ArrowConnectionPoolBuilder;
8#[cfg(not(feature = "mocks"))]
9use clickhouse_arrow::{
10 ArrowConnectionManager, ArrowFormat, ClickHouseResponse, ConnectionManager,
11 Error as ClickhouseNativeError, bb8,
12};
13#[cfg(not(feature = "mocks"))]
14use datafusion::arrow::array::RecordBatch;
15use datafusion::arrow::datatypes::SchemaRef;
16use datafusion::common::DataFusionError;
17use datafusion::common::error::GenericError;
18use datafusion::error::Result;
19use datafusion::physical_plan::SendableRecordBatchStream;
20#[cfg(not(feature = "mocks"))]
21use datafusion::sql::TableReference;
22use futures_util::TryStreamExt;
23use tracing::{debug, error};
24
25use crate::sql::JoinPushDown;
26use crate::stream::{RecordBatchStreamWrapper, record_batch_stream_from_stream};
27
28#[cfg(not(feature = "mocks"))]
30pub type ArrowPoolConnection<'a> = bb8::PooledConnection<'a, ConnectionManager<ArrowFormat>>;
31#[cfg(not(feature = "mocks"))]
32pub type ArrowPool = bb8::Pool<ArrowConnectionManager>;
33#[cfg(feature = "mocks")]
35pub type ArrowPoolConnection<'a> = &'a ();
36#[cfg(feature = "mocks")]
37pub type ArrowPool = ();
38
39#[derive(Debug, Clone)]
41pub struct ClickHouseConnectionPool {
42 pool: ArrowPool,
44 join_push_down: JoinPushDown,
45}
46
47impl ClickHouseConnectionPool {
48 pub fn new(identifier: impl Into<String>, pool: ArrowPool) -> Self {
51 debug!("Creating new ClickHouse connection pool");
52 let join_push_down = JoinPushDown::AllowedFor(identifier.into());
53 Self { pool, join_push_down }
54 }
55
56 pub async fn from_pool_builder(builder: ArrowConnectionPoolBuilder) -> Result<Self> {
61 let identifer = builder.connection_identifier();
62
63 #[cfg(not(feature = "mocks"))]
65 let pool = builder
66 .configure_client(|c| c.with_database("default"))
67 .build()
68 .await
69 .inspect_err(|error| error!(?error, "Error building ClickHouse connection pool"))
70 .map_err(crate::utils::map_clickhouse_err)?;
71
72 #[cfg(feature = "mocks")]
73 let pool = ();
74
75 Ok(Self::new(identifer, pool))
76 }
77
78 pub fn pool(&self) -> &ArrowPool { &self.pool }
80
81 pub fn join_push_down(&self) -> JoinPushDown { self.join_push_down.clone() }
82}
83
84impl ClickHouseConnectionPool {
85 pub async fn connect(&self) -> Result<ClickHouseConnection<'_>> {
90 #[cfg(not(feature = "mocks"))]
91 let conn = self
92 .pool
93 .get()
94 .await
95 .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
96 .map_err(crate::utils::map_external_err)?;
97 #[cfg(feature = "mocks")]
98 let conn = &();
99 Ok(ClickHouseConnection::new(conn))
100 }
101
102 pub async fn connect_static(&self) -> Result<ClickHouseConnection<'static>> {
107 #[cfg(not(feature = "mocks"))]
108 let conn = self
109 .pool
110 .get_owned()
111 .await
112 .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
113 .map_err(crate::utils::map_external_err)?;
114 #[cfg(feature = "mocks")]
115 let conn = &();
116 Ok(ClickHouseConnection::new_static(conn))
117 }
118}
119
120#[derive(Debug)]
126pub struct ClickHouseConnection<'a> {
127 conn: ArrowPoolConnection<'a>,
128}
129
130impl<'a> ClickHouseConnection<'a> {
131 pub fn new(conn: ArrowPoolConnection<'a>) -> Self { ClickHouseConnection { conn } }
132
133 pub fn new_static(conn: ArrowPoolConnection<'static>) -> Self { ClickHouseConnection { conn } }
135
136 pub async fn query_arrow_with_schema(
146 &self,
147 sql: &str,
148 params: &[()],
149 schema: SchemaRef,
150 coerce_schema: bool,
151 ) -> Result<RecordBatchStreamWrapper, DataFusionError> {
152 debug!(sql, "Running query");
153 let batches = Box::pin(
154 self.query_arrow_raw(sql, params)
155 .await?
156 .map_err(|e| DataFusionError::External(Box::new(e))),
158 );
159 Ok(RecordBatchStreamWrapper::new_from_stream(batches, schema).with_coercion(coerce_schema))
160 }
161
162 pub async fn query_arrow(
172 &self,
173 sql: &str,
174 params: &[()],
175 projected_schema: Option<SchemaRef>,
176 ) -> Result<SendableRecordBatchStream, GenericError> {
177 if let Some(schema) = projected_schema {
178 return Ok(Box::pin(self.query_arrow_with_schema(sql, params, schema, false).await?));
179 }
180
181 let batches = Box::pin(
182 self.query_arrow_raw(sql, params)
183 .await?
184 .map_err(|e| DataFusionError::External(Box::new(e))),
186 );
187
188 Ok(Box::pin(
189 record_batch_stream_from_stream(batches)
190 .await
191 .inspect_err(|error| error!(?error, "Failed converting batches to stream"))
192 .map_err(Box::new)?,
193 ))
194 }
195}
196
197#[cfg(not(feature = "mocks"))]
198impl ClickHouseConnection<'_> {
199 pub async fn tables(&self, schema: &str) -> Result<Vec<String>> {
204 debug!(schema, "Fetching tables");
205 self.conn
206 .fetch_tables(Some(schema), None)
207 .await
208 .inspect_err(|error| error!(?error, "Fetching tables failed"))
209 .map_err(crate::utils::map_clickhouse_err)
210 }
211
212 pub async fn schemas(&self) -> Result<Vec<String>> {
217 debug!("Fetching databases");
218 self.conn
219 .fetch_schemas(None)
220 .await
221 .inspect_err(|error| error!(?error, "Fetching databases failed"))
222 .map_err(crate::utils::map_clickhouse_err)
223 }
224
225 pub async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef> {
230 debug!(%table_reference, "Fetching schema for table");
231 let db = table_reference.schema();
232 let table = table_reference.table();
233 let mut schemas =
234 self.conn.fetch_schema(db, &[table][..], None).await.map_err(|error| {
235 if let ClickhouseNativeError::UndefinedTables { .. } = error {
236 error!(?error, ?db, ?table, "Tables undefined");
237 } else {
238 error!(?error, ?db, ?table, "Unknown error occurred while fetching schema");
239 }
240 crate::utils::map_clickhouse_err(error)
241 })?;
242
243 schemas
244 .remove(table)
245 .ok_or(DataFusionError::External("Schema not found for table".into()))
246 }
247
248 pub async fn query_arrow_raw(
253 &self,
254 sql: &str,
255 _params: &[()],
256 ) -> Result<ClickHouseResponse<RecordBatch>> {
257 self.conn
258 .query(sql, None)
259 .await
260 .inspect(|_| tracing::trace!("Query executed successfully"))
261 .inspect_err(|error| error!(?error, "Failed running query"))
262 .map_err(|e| DataFusionError::External(Box::new(e)))
264 }
265
266 pub async fn execute(&self, sql: &str, _params: &[()]) -> Result<u64, GenericError> {
271 debug!(sql, "Executing query");
272 self.conn
273 .execute(sql, None)
274 .await
275 .inspect_err(|error| error!(?error, "Failed executing query"))
276 .map_err(Box::new)?;
277 Ok(0)
278 }
279}
280
#[cfg(test)]
mod tests {
    use datafusion::sql::TableReference;

    use super::*;

    /// `schema()` / `table()` return the expected parts for full, partial and bare references.
    #[test]
    fn test_table_reference_schema_extraction() {
        let cases = [
            (TableReference::full("catalog", "schema", "table"), Some("schema")),
            (TableReference::partial("schema", "table"), Some("schema")),
            (TableReference::bare("table"), None),
        ];

        for (table_ref, expected_schema) in cases {
            assert_eq!(table_ref.schema(), expected_schema);
            assert_eq!(table_ref.table(), "table");
        }
    }

    /// `map_clickhouse_err` turns `UndefinedTables` into an `External` error whose message
    /// keeps the database and table names.
    #[test]
    fn test_error_handling_patterns() {
        use clickhouse_arrow::Error as ClickhouseNativeError;

        use crate::utils::map_clickhouse_err;

        let undefined_tables_error = ClickhouseNativeError::UndefinedTables {
            db: "test_db".to_string(),
            tables: vec!["test_table".to_string()],
        };

        let DataFusionError::External(boxed_error) = map_clickhouse_err(undefined_tables_error)
        else {
            panic!("Expected External error")
        };

        let error_str = boxed_error.to_string();
        assert!(error_str.contains("Tables undefined"));
        assert!(error_str.contains("test_db"));
        assert!(error_str.contains("test_table"));
    }

    /// `JoinPushDown::AllowedFor` carries the identifier it was built from.
    #[test]
    fn test_join_push_down_creation() {
        use crate::sql::JoinPushDown;

        let identifier = "test_pool";
        let join_push_down = JoinPushDown::AllowedFor(identifier.to_string());

        let JoinPushDown::AllowedFor(id) = join_push_down else {
            panic!("Expected AllowedFor variant")
        };
        assert_eq!(id, "test_pool");
    }
}