use sqlparser::ast as sqlast;
use std::{borrow::Cow, fmt};
use clickhouse_rs::{types::SqlType, Block, ClientHandle, Pool};
use std::{collections::HashMap, sync::Arc};
use crate::{
ast::Ident,
compile::{sql::create_table, ConnectionString},
runtime::{
error::rt_unimplemented, normalize::Normalizer, resolve_params, LazySQLParam, Result,
SQLEngine, SQLEnginePool, SQLEngineType, SQLParam,
},
types::{
arrow::{ArrowRecordBatchRelation, EmptyRelation},
try_fields_to_arrow_fields,
value::ArrowRecordBatch,
ArrowSchema, Field, Relation, Type, Value,
},
};
use super::value;
/// SQL normalizer that adapts parsed statements to the ClickHouse dialect.
pub struct ClickHouseNormalizer();

impl ClickHouseNormalizer {
    /// Builds a normalizer. It holds no state, so every instance behaves the same.
    pub fn new() -> Self {
        Self()
    }
}
impl Normalizer for ClickHouseNormalizer {
    /// Identifiers are quoted with double quotes.
    fn quote_style(&self) -> Option<char> {
        Some('"')
    }

    /// No placeholder text is produced for named parameters; this always returns `None`.
    fn param(&self, _key: &str) -> Option<&str> {
        None
    }

    /// Rewrites `CREATE TABLE` statements into a form suitable for ClickHouse; every other
    /// statement is returned borrowed and untouched.
    ///
    /// * Temporary tables: `OR REPLACE` is turned into `IF NOT EXISTS`, and the engine is
    ///   forced to `Memory`.
    /// * Non-temporary tables without an explicit engine: the engine becomes `MergeTree` and
    ///   an empty `ORDER BY` list is attached.
    fn preprocess<'a>(&self, stmt: &'a sqlast::Statement) -> Cow<'a, sqlast::Statement> {
        if !matches!(stmt, sqlast::Statement::CreateTable { .. }) {
            return Cow::Borrowed(stmt);
        }

        let mut rewritten = stmt.clone();
        if let sqlast::Statement::CreateTable {
            engine,
            order_by,
            temporary,
            or_replace,
            if_not_exists,
            ..
        } = &mut rewritten
        {
            if *temporary {
                // NOTE(review): presumably ClickHouse does not accept OR REPLACE for temporary
                // tables, so IF NOT EXISTS is substituted. `mem::take` clears the flag and
                // reports whether it was set.
                if std::mem::take(or_replace) {
                    *if_not_exists = true;
                }
                *engine = Some("Memory".to_string());
            } else if engine.is_none() {
                // MergeTree needs an ORDER BY clause, so an empty one is supplied.
                *engine = Some("MergeTree".to_string());
                *order_by = Some(vec![]);
            }
        } else {
            // Guarded by the `matches!` check above.
            unreachable!();
        }
        Cow::Owned(rewritten)
    }
}
/// A [`SQLEngine`] that talks to ClickHouse through one `clickhouse_rs` client handle.
pub struct ClickHouseEngine {
    conn: ClientHandle,
}

impl fmt::Debug for ClickHouseEngine {
    /// Emits a fixed label only; the connection handle itself is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ClickhouseEngine")
    }
}
#[async_trait::async_trait]
impl SQLEnginePool for ClickHouseEngine {
async fn new(url: Arc<ConnectionString>) -> Result<Box<dyn SQLEngine>> {
let mut url = url.get_url().clone();
url.set_scheme("tcp").unwrap();
let mut conn = Pool::new(url.as_str()).get_handle().await?;
conn.ping().await?;
Ok(Box::new(ClickHouseEngine { conn }))
}
async fn create(url: Arc<ConnectionString>) -> Result<()> {
let mut url = url.get_url().clone();
url.set_scheme("tcp").unwrap();
let db_name = url.path().to_string();
let db_name = db_name.strip_prefix("/").unwrap();
url.set_path("");
let mut conn = Pool::new(url.as_str()).get_handle().await?;
conn.ping().await?;
conn.execute(format!("DROP DATABASE IF EXISTS \"{}\"", db_name))
.await?;
conn.execute(format!("CREATE DATABASE \"{}\"", db_name))
.await?;
Ok(())
}
}
#[async_trait::async_trait]
impl SQLEngine for ClickHouseEngine {
    /// Runs `query` on ClickHouse and returns the full result as one Arrow record batch.
    ///
    /// Parameters are resolved only so `check_params` can reject them. The statement is sent
    /// as plain SQL text, so any parameter other than an empty relation is an error.
    async fn query(
        &mut self,
        query: &sqlast::Statement,
        params: HashMap<Ident, LazySQLParam>,
    ) -> Result<Arc<dyn Relation>> {
        self.check_params(&resolve_params(params).await?)?;
        let query = ClickHouseNormalizer::new().normalize(query).as_result()?;
        let query_string = query.to_string();
        let result = self.conn.query(query_string).fetch_all().await?;

        let mut schema = Vec::new();
        let mut arrays = Vec::new();
        for column in result.columns() {
            // Fetch the column type once; it drives the field type, nullability and the
            // Arrow conversion.
            let sql_type = column.sql_type();
            schema.push(Field {
                name: column.name().to_string().into(),
                type_: (&sql_type).try_into()?,
                nullable: matches!(sql_type, SqlType::Nullable(_)),
            });
            arrays.push(value::column_to_arrow(column, sql_type, false)?);
        }

        let schema = Arc::new(ArrowSchema::new(try_fields_to_arrow_fields(&schema)?));
        let record_batch = ArrowRecordBatch::try_new(schema.clone(), arrays)?;
        let relation = ArrowRecordBatchRelation::new(schema, Arc::new(vec![record_batch]));
        Ok(relation)
    }

    /// Executes `stmt` on ClickHouse and discards any result. Parameters are checked the same
    /// way as in [`query`](Self::query).
    async fn exec(
        &mut self,
        stmt: &sqlast::Statement,
        params: HashMap<Ident, LazySQLParam>,
    ) -> Result<()> {
        self.check_params(&resolve_params(params).await?)?;
        let stmt = ClickHouseNormalizer::new().normalize(stmt).as_result()?;
        let query_string = stmt.to_string();
        self.conn.execute(query_string).await?;
        Ok(())
    }

    /// Creates `table` from the record fields of `type_` and inserts `value` into it, one
    /// ClickHouse block per relation batch.
    ///
    /// # Errors
    /// Returns an unimplemented error if:
    /// * `type_` is not a list of records, or
    /// * `value` is not a relation.
    ///
    /// Errors from table creation, column conversion or insertion are propagated.
    async fn load(
        &mut self,
        table: &sqlast::ObjectName,
        value: Value,
        type_: Type,
        temporary: bool,
    ) -> Result<()> {
        let fields = match type_ {
            Type::List(r) => match r.as_ref() {
                Type::Record(fields) => fields.clone(),
                _ => {
                    return rt_unimplemented!("Loading non-record lists into ClickHouse");
                }
            },
            _ => {
                return rt_unimplemented!("Loading non-record lists into ClickHouse");
            }
        };

        // Validate the value before creating the table, so an unsupported value does not
        // leave an empty table behind.
        let relation = match value {
            Value::Relation(relation) => relation,
            _ => {
                return rt_unimplemented!("Loading non-relation values into ClickHouse");
            }
        };

        let create_table_stmt = create_table(table.clone(), &fields, temporary)?;
        self.exec(&create_table_stmt, HashMap::new()).await?;

        let table_name = table.to_string();
        for batch_idx in 0..relation.num_batches() {
            let batch = relation.batch(batch_idx).as_arrow_recordbatch();
            let mut block = Block::new();
            // Columns are matched to fields by position.
            for (i, field) in fields.iter().enumerate() {
                let column = batch.column(i);
                block = value::arrow_to_column(block, field.name.as_str(), column.as_ref())?;
            }
            self.conn.insert(&table_name, block).await?;
        }
        Ok(())
    }

    /// Reports whether a single-part table `name` exists in the current database (or in a
    /// row whose `database` is empty) according to `system.tables`.
    ///
    /// # Errors
    /// Multi-part names are unimplemented; query errors are propagated.
    async fn table_exists(&mut self, name: &sqlast::ObjectName) -> Result<bool> {
        let ident = if name.0.len() == 1 {
            name.0[0].get().value.clone()
        } else {
            return rt_unimplemented!("Multi-part table names in clickhouse: {}", name);
        };
        // ClickHouse string literals use `\` as the escape character. Backslashes must be
        // escaped before quotes: otherwise a name containing `\'` or ending in `\` could
        // terminate the literal and inject SQL.
        let escaped_name = ident.replace('\\', "\\\\").replace('\'', "\\'");
        let query = format!(
            "SELECT name FROM system.tables WHERE name = '{escaped_name}' AND (database=currentDatabase() OR database='')"
        );
        Ok(self.conn.query(query).fetch_all().await?.row_count() > 0)
    }

    /// Identifies this engine as ClickHouse.
    fn engine_type(&self) -> SQLEngineType {
        SQLEngineType::ClickHouse
    }
}
impl ClickHouseEngine {
    /// Rejects every parameter that this engine cannot bind.
    ///
    /// Statements are sent to ClickHouse as plain SQL text, so parameters cannot be bound.
    /// The only accepted parameter value is an [`EmptyRelation`]; any other relation, function
    /// or scalar value yields an unimplemented error naming the parameter.
    fn check_params(&self, params: &HashMap<Ident, SQLParam>) -> Result<()> {
        for (name, param) in params.iter() {
            match &param.value {
                Value::Relation(r) if r.as_any().downcast_ref::<EmptyRelation>().is_some() => {
                    continue
                }
                Value::Relation(_) => {
                    return rt_unimplemented!("Relation parameters in ClickHouse ({:?})", name);
                }
                Value::Fn(_) => {
                    return rt_unimplemented!("Function parameters in ClickHouse ({:?})", name);
                }
                _ => {
                    return rt_unimplemented!("Scalar parameters in ClickHouse ({:?})", name);
                }
            }
        }
        Ok(())
    }
}