1#![cfg_attr(feature = "mocks", expect(clippy::unused_async))]
2#![cfg_attr(feature = "mocks", expect(dead_code))]
3
4#[cfg(feature = "mocks")]
5mod mock;
6
7use clickhouse_arrow::ArrowConnectionPoolBuilder;
8#[cfg(not(feature = "mocks"))]
9use clickhouse_arrow::{
10 ArrowConnectionManager, ArrowFormat, ClickHouseResponse, ConnectionManager,
11 Error as ClickhouseNativeError, bb8,
12};
13#[cfg(not(feature = "mocks"))]
14use datafusion::arrow::array::RecordBatch;
15use datafusion::arrow::datatypes::SchemaRef;
16use datafusion::common::DataFusionError;
17use datafusion::common::error::GenericError;
18use datafusion::error::Result;
19use datafusion::physical_plan::SendableRecordBatchStream;
20#[cfg(not(feature = "mocks"))]
21use datafusion::sql::TableReference;
22use futures_util::TryStreamExt;
23use tracing::{debug, error};
24
25use crate::sql::JoinPushDown;
26use crate::stream::{RecordBatchStreamWrapper, record_batch_stream_from_stream};
27
28#[cfg(not(feature = "mocks"))]
30pub type ArrowPoolConnection<'a> = bb8::PooledConnection<'a, ConnectionManager<ArrowFormat>>;
31#[cfg(not(feature = "mocks"))]
32pub type ArrowPool = bb8::Pool<ArrowConnectionManager>;
33#[cfg(feature = "mocks")]
35pub type ArrowPoolConnection<'a> = &'a ();
36#[cfg(feature = "mocks")]
37pub type ArrowPool = ();
38
/// A ClickHouse Arrow connection pool plus the settings the DataFusion integration
/// needs alongside it (join push-down policy and write concurrency).
#[derive(Debug, Clone)]
pub struct ClickHouseConnectionPool {
    /// Underlying pool; this is `()` when the `mocks` feature is enabled.
    pool: ArrowPool,
    /// Join push-down policy; `new` sets it to `JoinPushDown::AllowedFor(identifier)`.
    join_push_down: JoinPushDown,
    /// Write concurrency setting; `new` initialises it to 4 and
    /// `with_write_concurrency` overrides it.
    write_concurrency: usize,
}
49
50impl ClickHouseConnectionPool {
51 pub fn new(identifier: impl Into<String>, pool: ArrowPool) -> Self {
54 debug!("Creating new ClickHouse connection pool");
55 let join_push_down = JoinPushDown::AllowedFor(identifier.into());
56 Self { pool, join_push_down, write_concurrency: 4 }
57 }
58
59 #[must_use]
66 pub fn with_write_concurrency(mut self, concurrency: usize) -> Self {
67 self.write_concurrency = concurrency;
68 self
69 }
70
71 pub fn write_concurrency(&self) -> usize { self.write_concurrency }
73
74 pub async fn from_pool_builder(builder: ArrowConnectionPoolBuilder) -> Result<Self> {
79 let identifer = builder.connection_identifier();
80
81 #[cfg(not(feature = "mocks"))]
83 let pool = builder
84 .configure_client(|c| c.with_database("default"))
85 .build()
86 .await
87 .inspect_err(|error| error!(?error, "Error building ClickHouse connection pool"))
88 .map_err(crate::utils::map_clickhouse_err)?;
89
90 #[cfg(feature = "mocks")]
91 let pool = ();
92
93 Ok(Self::new(identifer, pool))
94 }
95
96 pub fn pool(&self) -> &ArrowPool { &self.pool }
98
99 pub fn join_push_down(&self) -> JoinPushDown { self.join_push_down.clone() }
100}
101
102impl ClickHouseConnectionPool {
103 pub async fn connect(&self) -> Result<ClickHouseConnection<'_>> {
108 #[cfg(not(feature = "mocks"))]
109 let conn = self
110 .pool
111 .get()
112 .await
113 .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
114 .map_err(crate::utils::map_external_err)?;
115 #[cfg(feature = "mocks")]
116 let conn = &();
117 Ok(ClickHouseConnection::new(conn))
118 }
119
120 pub async fn connect_static(&self) -> Result<ClickHouseConnection<'static>> {
125 #[cfg(not(feature = "mocks"))]
126 let conn = self
127 .pool
128 .get_owned()
129 .await
130 .inspect_err(|error| error!(?error, "Failed getting connection from pool"))
131 .map_err(crate::utils::map_external_err)?;
132 #[cfg(feature = "mocks")]
133 let conn = &();
134 Ok(ClickHouseConnection::new_static(conn))
135 }
136}
137
/// A single ClickHouse connection, typically obtained from
/// `ClickHouseConnectionPool::connect` or `ClickHouseConnectionPool::connect_static`.
///
/// With the `mocks` feature the wrapped connection is a placeholder `&()`.
#[derive(Debug)]
pub struct ClickHouseConnection<'a> {
    /// Pooled connection; `'a` is `'static` when created via `new_static`.
    conn: ArrowPoolConnection<'a>,
}
147
148impl<'a> ClickHouseConnection<'a> {
149 pub fn new(conn: ArrowPoolConnection<'a>) -> Self { ClickHouseConnection { conn } }
150
151 pub fn new_static(conn: ArrowPoolConnection<'static>) -> Self { ClickHouseConnection { conn } }
153
154 pub async fn query_arrow_with_schema(
164 &self,
165 sql: &str,
166 params: &[()],
167 schema: SchemaRef,
168 coerce_schema: bool,
169 ) -> Result<RecordBatchStreamWrapper, DataFusionError> {
170 debug!(sql, "Running query");
171 let batches = Box::pin(
172 self.query_arrow_raw(sql, params)
173 .await?
174 .map_err(|e| DataFusionError::External(Box::new(e))),
176 );
177 Ok(RecordBatchStreamWrapper::new_from_stream(batches, schema).with_coercion(coerce_schema))
178 }
179
180 pub async fn query_arrow(
190 &self,
191 sql: &str,
192 params: &[()],
193 projected_schema: Option<SchemaRef>,
194 ) -> Result<SendableRecordBatchStream, GenericError> {
195 if let Some(schema) = projected_schema {
196 return Ok(Box::pin(self.query_arrow_with_schema(sql, params, schema, false).await?));
197 }
198
199 let batches = Box::pin(
200 self.query_arrow_raw(sql, params)
201 .await?
202 .map_err(|e| DataFusionError::External(Box::new(e))),
204 );
205
206 Ok(Box::pin(
207 record_batch_stream_from_stream(batches)
208 .await
209 .inspect_err(|error| error!(?error, "Failed converting batches to stream"))
210 .map_err(Box::new)?,
211 ))
212 }
213}
214
215#[cfg(not(feature = "mocks"))]
216impl ClickHouseConnection<'_> {
217 pub async fn tables(&self, schema: &str) -> Result<Vec<String>> {
222 debug!(schema, "Fetching tables");
223 self.conn
224 .fetch_tables(Some(schema), None)
225 .await
226 .inspect_err(|error| error!(?error, "Fetching tables failed"))
227 .map_err(crate::utils::map_clickhouse_err)
228 }
229
230 pub async fn schemas(&self) -> Result<Vec<String>> {
235 debug!("Fetching databases");
236 self.conn
237 .fetch_schemas(None)
238 .await
239 .inspect_err(|error| error!(?error, "Fetching databases failed"))
240 .map_err(crate::utils::map_clickhouse_err)
241 }
242
243 pub async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef> {
248 debug!(%table_reference, "Fetching schema for table");
249 let db = table_reference.schema();
250 let table = table_reference.table();
251 let mut schemas =
252 self.conn.fetch_schema(db, &[table][..], None).await.map_err(|error| {
253 if let ClickhouseNativeError::UndefinedTables { .. } = error {
254 error!(?error, ?db, ?table, "Tables undefined");
255 } else {
256 error!(?error, ?db, ?table, "Unknown error occurred while fetching schema");
257 }
258 crate::utils::map_clickhouse_err(error)
259 })?;
260
261 schemas
262 .remove(table)
263 .ok_or(DataFusionError::External("Schema not found for table".into()))
264 }
265
266 pub async fn query_arrow_raw(
271 &self,
272 sql: &str,
273 _params: &[()],
274 ) -> Result<ClickHouseResponse<RecordBatch>> {
275 self.conn
276 .query(sql, None)
277 .await
278 .inspect(|_| tracing::trace!("Query executed successfully"))
279 .inspect_err(|error| error!(?error, "Failed running query"))
280 .map_err(|e| DataFusionError::External(Box::new(e)))
282 }
283
284 pub async fn execute(&self, sql: &str, _params: &[()]) -> Result<u64, GenericError> {
289 debug!(sql, "Executing query");
290 self.conn
291 .execute(sql, None)
292 .await
293 .inspect_err(|error| error!(?error, "Failed executing query"))
294 .map_err(Box::new)?;
295 Ok(0)
296 }
297}
298
#[cfg(test)]
mod tests {
    use datafusion::sql::TableReference;

    use super::*;

    /// `schema()`/`table()` split full, partial and bare references as expected.
    #[test]
    fn test_table_reference_schema_extraction() {
        // (reference, expected schema part); every case names the table "table".
        let cases = [
            (TableReference::full("catalog", "schema", "table"), Some("schema")),
            (TableReference::partial("schema", "table"), Some("schema")),
            (TableReference::bare("table"), None),
        ];

        for (reference, expected_schema) in &cases {
            assert_eq!(reference.schema(), *expected_schema);
            assert_eq!(reference.table(), "table");
        }
    }

    /// `UndefinedTables` maps to an external error whose text names the db and tables.
    #[test]
    fn test_error_handling_patterns() {
        use clickhouse_arrow::Error as ClickhouseNativeError;

        use crate::utils::map_clickhouse_err;

        let mapped = map_clickhouse_err(ClickhouseNativeError::UndefinedTables {
            db: "test_db".to_string(),
            tables: vec!["test_table".to_string()],
        });

        let DataFusionError::External(inner) = mapped else {
            panic!("Expected External error");
        };

        let rendered = inner.to_string();
        for needle in ["Tables undefined", "test_db", "test_table"] {
            assert!(rendered.contains(needle));
        }
    }

    /// `AllowedFor` carries the identifier it was built from.
    #[test]
    fn test_join_push_down_creation() {
        use crate::sql::JoinPushDown;

        let identifier = "test_pool";
        let JoinPushDown::AllowedFor(id) = JoinPushDown::AllowedFor(identifier.to_string()) else {
            panic!("Expected AllowedFor variant");
        };
        assert_eq!(id, "test_pool");
    }
}