// datafusion_table_providers/sql/db_connection_pool/dbconnection/clickhouseconn.rs
use std::io::Cursor;
use std::{any::Any, sync::Arc};

use arrow::array::RecordBatch;
use arrow_ipc::reader::{StreamDecoder, StreamReader};
use async_trait::async_trait;
use clickhouse::sql::Identifier;
use clickhouse::{Client, Row};
use datafusion::arrow::datatypes::{Schema, SchemaRef};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::EmptyRecordBatchStream;
use datafusion::{execution::SendableRecordBatchStream, sql::TableReference};
use regex::Regex;
use serde::Deserialize;
use snafu::ResultExt;

use super::{AsyncDbConnection, DbConnection, Error, SyncDbConnection};
17
18impl DbConnection<Client, ()> for Client {
19 fn as_any(&self) -> &dyn Any {
20 self
21 }
22
23 fn as_any_mut(&mut self) -> &mut dyn Any {
24 self
25 }
26
27 fn as_sync(&self) -> Option<&dyn SyncDbConnection<Client, ()>> {
28 None
29 }
30
31 fn as_async(&self) -> Option<&dyn AsyncDbConnection<Client, ()>> {
32 Some(self)
33 }
34}
35
36#[async_trait]
37impl AsyncDbConnection<Client, ()> for Client {
38 fn new(conn: Client) -> Self
39 where
40 Self: Sized,
41 {
42 conn
43 }
44
45 async fn tables(&self, schema: &str) -> Result<Vec<String>, Error> {
46 #[derive(Row, Deserialize)]
47 struct Row {
48 name: String,
49 }
50
51 let tables: Vec<Row> = self
52 .query("SHOW TABLES FROM ?")
53 .bind(schema)
54 .fetch_all()
55 .await
56 .boxed()
57 .context(super::UnableToGetTablesSnafu)?;
58
59 Ok(tables.into_iter().map(|x| x.name).collect())
60 }
61
62 async fn schemas(&self) -> Result<Vec<String>, Error> {
63 #[derive(Row, Deserialize)]
64 struct Row {
65 name: String,
66 }
67 let tables: Vec<Row> = self
68 .query("SHOW DATABASES")
69 .fetch_all()
70 .await
71 .boxed()
72 .context(super::UnableToGetSchemasSnafu)?;
73
74 Ok(tables.into_iter().map(|x| x.name).collect())
75 }
76
77 async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error> {
83 #[derive(Row, Deserialize)]
84 struct CatalogRow {
85 db: String,
86 }
87
88 let database = match table_reference.schema() {
89 Some(db) => db.to_string(),
90 None => {
91 let row: CatalogRow = self
92 .query("SELECT currentDatabase() AS db")
93 .fetch_one()
94 .await
95 .boxed()
96 .context(super::UnableToGetSchemaSnafu)?;
97 row.db
98 }
99 };
100
101 #[derive(Row, Deserialize)]
102 struct TableInfoRow {
103 engine: String,
104 as_select: String,
105 }
106
107 let table_info: TableInfoRow = self
108 .query("SELECT engine, as_select FROM system.tables WHERE database = ? AND name = ?")
109 .bind(&database)
110 .bind(table_reference.table())
111 .fetch_one()
112 .await
113 .boxed()
114 .context(super::UnableToGetSchemaSnafu)?;
115
116 let is_view = matches!(
117 table_info.engine.to_uppercase().as_str(),
118 "VIEW" | "MATERIALIZEDVIEW"
119 );
120
121 let statement = if is_view {
122 let view_query = table_info.as_select;
123 format!(
124 "SELECT * FROM ({}) LIMIT 0",
125 replace_clickhouse_ddl_parameters(&view_query)
126 )
127 } else {
128 let table_ref = TableReference::partial(database, table_reference.table());
129 format!("SELECT * FROM {} LIMIT 0", table_ref.to_quoted_string())
130 };
131
132 let mut bytes = self
133 .query(&statement)
134 .fetch_bytes("ArrowStream")
135 .boxed()
136 .context(super::UnableToGetSchemaSnafu)?;
137
138 let reader = bytes
139 .collect()
140 .await
141 .boxed()
142 .and_then(|bytes| StreamReader::try_new(Cursor::new(bytes), None).boxed())
143 .context(super::UnableToGetSchemaSnafu)?;
144
145 return Ok(reader.schema());
146 }
147
148 async fn query_arrow(
160 &self,
161 sql: &str,
162 _params: &[()],
163 projected_schema: Option<SchemaRef>,
164 ) -> super::Result<SendableRecordBatchStream> {
165 let query = self.query(sql);
166
167 let mut bytes_stream = query
168 .fetch_bytes("ArrowStream")
169 .boxed()
170 .context(super::UnableToQueryArrowSnafu)?;
171
172 let mut first_batch: Option<RecordBatch> = None;
173 let mut decoder = StreamDecoder::new();
174
175 while let Some(buf) = bytes_stream.next().await? {
177 if let Some(batch) = decoder.decode(&mut buf.into())? {
178 first_batch = Some(batch);
179 break;
180 }
181 }
182
183 if let Some(first_batch) = first_batch {
184 let schema = first_batch.schema();
185 let stream = async_stream::stream! {
186 yield Ok(first_batch);
187 while let Some(buf) = bytes_stream
188 .next()
189 .await
190 .map_err(|er| arrow::error::ArrowError::ExternalError(Box::new(er)))?
191 {
192 if let Some(batch) = decoder.decode(&mut buf.into())? {
193 yield Ok(batch);
194 }
195 }
196 };
197 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
198 } else if let Some(schema) = projected_schema {
199 Ok(Box::pin(RecordBatchStreamAdapter::new(
200 schema.clone(),
201 EmptyRecordBatchStream::new(schema),
202 )))
203 } else {
204 let schema: Arc<Schema> = Schema::empty().into();
205 Ok(Box::pin(RecordBatchStreamAdapter::new(
206 schema.clone(),
207 EmptyRecordBatchStream::new(schema),
208 )))
209 }
210 }
211
212 async fn execute(&self, sql: &str, params: &[()]) -> super::Result<u64> {
219 let mut query = self.query(sql);
220
221 for param in params {
222 query = query.bind(param);
223 }
224
225 query
226 .execute()
227 .await
228 .boxed()
229 .context(super::UnableToQueryArrowSnafu)?;
230
231 Ok(0)
232 }
233}
234
235pub fn replace_clickhouse_ddl_parameters(ddl_query: &str) -> String {
236 let param_pattern = Regex::new(r"\{(\w+?):(\w+?)\}").unwrap();
238
239 let modified_query = param_pattern.replace_all(ddl_query, |caps: ®ex::Captures| {
240 let data_type = caps.get(2).map_or("", |m| m.as_str());
242 match data_type.to_lowercase().as_str() {
243 "string" => "''".to_string(),
244 "uint8" | "uint16" | "uint32" | "uint64" | "int8" | "int16" | "int32" | "int64" => {
245 "0".to_string()
246 }
247 "float32" | "float64" => "0.0".to_string(),
248 "date" => "'2000-01-01'".to_string(),
249 "datetime" => "'2000-01-01 00:00:00'".to_string(),
250 "bool" => "false".to_string(),
251 _ => "''".to_string(),
252 }
253 });
254
255 modified_query.into_owned()
256}