datafusion_table_providers/sql/db_connection_pool/dbconnection/clickhouseconn.rs

1use std::io::Cursor;
2use std::{any::Any, sync::Arc};
3
4use arrow::array::RecordBatch;
5use arrow_ipc::reader::{StreamDecoder, StreamReader};
6use async_trait::async_trait;
7use clickhouse::{Client, Row};
8use datafusion::arrow::datatypes::{Schema, SchemaRef};
9use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
10use datafusion::physical_plan::EmptyRecordBatchStream;
11use datafusion::{execution::SendableRecordBatchStream, sql::TableReference};
12use regex::Regex;
13use serde::Deserialize;
14use snafu::ResultExt;
15
16use super::{AsyncDbConnection, DbConnection, Error, SyncDbConnection};
17
18impl DbConnection<Client, ()> for Client {
19    fn as_any(&self) -> &dyn Any {
20        self
21    }
22
23    fn as_any_mut(&mut self) -> &mut dyn Any {
24        self
25    }
26
27    fn as_sync(&self) -> Option<&dyn SyncDbConnection<Client, ()>> {
28        None
29    }
30
31    fn as_async(&self) -> Option<&dyn AsyncDbConnection<Client, ()>> {
32        Some(self)
33    }
34}
35
36#[async_trait]
37impl AsyncDbConnection<Client, ()> for Client {
38    fn new(conn: Client) -> Self
39    where
40        Self: Sized,
41    {
42        conn
43    }
44
45    async fn tables(&self, schema: &str) -> Result<Vec<String>, Error> {
46        #[derive(Row, Deserialize)]
47        struct Row {
48            name: String,
49        }
50
51        let tables: Vec<Row> = self
52            .query("SHOW TABLES FROM ?")
53            .bind(schema)
54            .fetch_all()
55            .await
56            .boxed()
57            .context(super::UnableToGetTablesSnafu)?;
58
59        Ok(tables.into_iter().map(|x| x.name).collect())
60    }
61
62    async fn schemas(&self) -> Result<Vec<String>, Error> {
63        #[derive(Row, Deserialize)]
64        struct Row {
65            name: String,
66        }
67        let tables: Vec<Row> = self
68            .query("SHOW DATABASES")
69            .fetch_all()
70            .await
71            .boxed()
72            .context(super::UnableToGetSchemasSnafu)?;
73
74        Ok(tables.into_iter().map(|x| x.name).collect())
75    }
76
77    /// Get the schema for a table reference.
78    ///
79    /// # Arguments
80    ///
81    /// * `table_reference` - The table reference.
82    async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error> {
83        #[derive(Row, Deserialize)]
84        struct CatalogRow {
85            db: String,
86        }
87
88        let database = match table_reference.schema() {
89            Some(db) => db.to_string(),
90            None => {
91                let row: CatalogRow = self
92                    .query("SELECT currentDatabase() AS db")
93                    .fetch_one()
94                    .await
95                    .boxed()
96                    .context(super::UnableToGetSchemaSnafu)?;
97                row.db
98            }
99        };
100
101        #[derive(Row, Deserialize)]
102        struct TableInfoRow {
103            engine: String,
104            as_select: String,
105        }
106
107        let table_info: TableInfoRow = self
108            .query("SELECT engine, as_select FROM system.tables WHERE database = ? AND name = ?")
109            .bind(&database)
110            .bind(table_reference.table())
111            .fetch_one()
112            .await
113            .boxed()
114            .context(super::UnableToGetSchemaSnafu)?;
115
116        let is_view = matches!(
117            table_info.engine.to_uppercase().as_str(),
118            "VIEW" | "MATERIALIZEDVIEW"
119        );
120
121        let statement = if is_view {
122            let view_query = table_info.as_select;
123            format!(
124                "SELECT * FROM ({}) LIMIT 0",
125                replace_clickhouse_ddl_parameters(&view_query)
126            )
127        } else {
128            let table_ref = TableReference::partial(database, table_reference.table());
129            format!("SELECT * FROM {} LIMIT 0", table_ref.to_quoted_string())
130        };
131
132        let mut bytes = self
133            .query(&statement)
134            .fetch_bytes("ArrowStream")
135            .boxed()
136            .context(super::UnableToGetSchemaSnafu)?;
137
138        let reader = bytes
139            .collect()
140            .await
141            .boxed()
142            .and_then(|bytes| StreamReader::try_new(Cursor::new(bytes), None).boxed())
143            .context(super::UnableToGetSchemaSnafu)?;
144
145        return Ok(reader.schema());
146    }
147
148    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
149    ///
150    /// # Arguments
151    ///
152    /// * `sql` - The SQL statement.
153    /// * `params` - The parameters for the SQL statement.
154    /// * `projected_schema` - The Projected schema for the query.
155    ///
156    /// # Errors
157    ///
158    /// Returns an error if the query fails.
159    async fn query_arrow(
160        &self,
161        sql: &str,
162        _params: &[()],
163        projected_schema: Option<SchemaRef>,
164    ) -> super::Result<SendableRecordBatchStream> {
165        let query = self.query(sql);
166
167        let mut bytes_stream = query
168            .fetch_bytes("ArrowStream")
169            .boxed()
170            .context(super::UnableToQueryArrowSnafu)?;
171
172        let mut first_batch: Option<RecordBatch> = None;
173        let mut decoder = StreamDecoder::new();
174
175        // fetch till first set of records
176        while let Some(buf) = bytes_stream.next().await? {
177            if let Some(batch) = decoder.decode(&mut buf.into())? {
178                first_batch = Some(batch);
179                break;
180            }
181        }
182
183        if let Some(first_batch) = first_batch {
184            let schema = first_batch.schema();
185            let stream = async_stream::stream! {
186                yield Ok(first_batch);
187                while let Some(buf) = bytes_stream
188                    .next()
189                    .await
190                    .map_err(|er| arrow::error::ArrowError::ExternalError(Box::new(er)))?
191                {
192                    if let Some(batch) = decoder.decode(&mut buf.into())? {
193                        yield Ok(batch);
194                    }
195                }
196            };
197            Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
198        } else if let Some(schema) = projected_schema {
199            Ok(Box::pin(RecordBatchStreamAdapter::new(
200                schema.clone(),
201                EmptyRecordBatchStream::new(schema),
202            )))
203        } else {
204            let schema: Arc<Schema> = Schema::empty().into();
205            Ok(Box::pin(RecordBatchStreamAdapter::new(
206                schema.clone(),
207                EmptyRecordBatchStream::new(schema),
208            )))
209        }
210    }
211
212    /// Execute the given SQL statement with parameters, returning the number of affected rows.
213    ///
214    /// # Arguments
215    ///
216    /// * `sql` - The SQL statement.
217    /// * `params` - The parameters for the SQL statement.
218    async fn execute(&self, sql: &str, params: &[()]) -> super::Result<u64> {
219        let mut query = self.query(sql);
220
221        for param in params {
222            query = query.bind(param);
223        }
224
225        query
226            .execute()
227            .await
228            .boxed()
229            .context(super::UnableToQueryArrowSnafu)?;
230
231        Ok(0)
232    }
233}
234
235pub fn replace_clickhouse_ddl_parameters(ddl_query: &str) -> String {
236    // Regex to find parameters in the format {parameter_name:DataType}
237    let param_pattern = Regex::new(r"\{(\w+?):(\w+?)\}").unwrap();
238
239    let modified_query = param_pattern.replace_all(ddl_query, |caps: &regex::Captures| {
240        // match against the datatype
241        let data_type = caps.get(2).map_or("", |m| m.as_str());
242        match data_type.to_lowercase().as_str() {
243            "string" => "''".to_string(),
244            "uint8" | "uint16" | "uint32" | "uint64" | "int8" | "int16" | "int32" | "int64" => {
245                "0".to_string()
246            }
247            "float32" | "float64" => "0.0".to_string(),
248            "date" => "'2000-01-01'".to_string(),
249            "datetime" => "'2000-01-01 00:00:00'".to_string(),
250            "bool" => "false".to_string(),
251            _ => "''".to_string(),
252        }
253    });
254
255    modified_query.into_owned()
256}