use anyhow::{anyhow, Context, Result};
use clickhouse::Client;
use tokio::sync::OnceCell;
/// ClickHouse backend handle.
///
/// Holds the name of the environment variable that carries the DSN, plus a
/// cell in which the client built from that DSN is cached after first use.
pub struct ClickhouseBackend {
/// Name of the environment variable whose value is the ClickHouse DSN.
dsn_env: String,
/// Client cached after the first successful initialisation; starts empty.
client: OnceCell<Client>,
}
impl ClickhouseBackend {
    /// Creates a backend that will read its DSN from the environment variable
    /// named `dsn_env`.
    ///
    /// Nothing is read or built here: the client cell starts out empty and is
    /// filled by the first successful call to [`Self::client`].
    pub fn new(dsn_env: String) -> Self {
        let client = OnceCell::new();
        Self { dsn_env, client }
    }

    /// Returns a clone of the cached client, building it on first use.
    ///
    /// # Errors
    /// Fails when the environment variable named by `dsn_env` cannot be read,
    /// or when `build_client_from_dsn` rejects its value.
    pub async fn client(&self) -> Result<Client> {
        let cached = self
            .client
            .get_or_try_init(|| async {
                std::env::var(&self.dsn_env)
                    .with_context(|| format!("DSN env var {} not set", self.dsn_env))
                    .and_then(|dsn| build_client_from_dsn(&dsn))
            })
            .await?;
        Ok(cached.clone())
    }

    /// Ensures the client has been constructed, discarding the handle.
    ///
    /// This method issues no query itself; it only drives [`Self::client`].
    pub async fn connect(&self) -> Result<()> {
        self.client().await.map(|_| ())
    }

    /// Reports whether `system.tables` lists `table` inside database `db`.
    ///
    /// Both names are passed as bound parameters rather than spliced into the
    /// SQL text.
    ///
    /// # Errors
    /// Propagates query/cursor errors, and fails if the `count()` query yields
    /// no row at all.
    pub async fn table_exists(&self, client: &Client, db: &str, table: &str) -> Result<bool> {
        #[derive(clickhouse::Row, serde::Deserialize)]
        struct Count {
            c: u64,
        }

        let mut cursor = client
            .query("SELECT count() AS c FROM system.tables WHERE database = ? AND name = ?")
            .bind(db)
            .bind(table)
            .fetch::<Count>()?;

        match cursor.next().await? {
            Some(row) => Ok(row.c > 0),
            None => Err(anyhow!("system.tables count() returned no rows")),
        }
    }

    /// Looks up the length of the `embedding` array of one row of `db`.`table`.
    ///
    /// Returns `Ok(None)` when the query produces no row. Any error while
    /// starting the query or reading its first row is also mapped to
    /// `Ok(None)`, so this function never returns `Err`.
    pub async fn embedding_dim(
        &self,
        client: &Client,
        db: &str,
        table: &str,
    ) -> Result<Option<usize>> {
        #[derive(clickhouse::Row, serde::Deserialize)]
        struct DimRow {
            d: u64,
        }

        let target = self.fq_table(db, table);
        let sql = format!("SELECT length(embedding) AS d FROM {target} LIMIT 1");

        // Best-effort probe: a failure to open the cursor means "unknown".
        let mut cursor = match client.query(&sql).fetch::<DimRow>() {
            Err(_) => return Ok(None),
            Ok(cursor) => cursor,
        };

        // A read error and an empty result are both reported as `None`.
        let first_row = cursor.next().await;
        Ok(first_row.ok().flatten().map(|row| row.d as usize))
    }

    /// Placeholder for a creation lock: currently does nothing and always
    /// returns `Ok(())`; both arguments are ignored.
    pub async fn with_create_lock(&self, _client: &Client, _key: &str) -> Result<()> {
        Ok(())
    }
}
/// Builds a ClickHouse [`Client`] from a DSN string.
///
/// Accepted schemes are `clickhouse`, `http`, `https`, `clickhouse+http` and
/// `clickhouse+https`; the two `https` variants select TLS. Missing parts fall
/// back to: port 8443 (TLS) or 8123 (plain), user `default`, empty password,
/// database `default`. User, password and database are percent-decoded; a
/// component that does not decode to valid UTF-8 is used verbatim.
///
/// # Errors
/// Fails when the DSN cannot be parsed as a URL, uses an unsupported scheme,
/// or has no host. Error messages that echo the DSN mask its password.
fn build_client_from_dsn(dsn: &str) -> Result<Client> {
    /// Returns `dsn` with the password masked, for use in error messages.
    ///
    /// Everything between the first `:` after `://` and the last `@` is
    /// replaced by `***`. This is deliberately greedy: it may hide more than
    /// the password for malformed input, but never less.
    fn redact_dsn(dsn: &str) -> String {
        let (prefix, rest) = match dsn.find("://") {
            Some(idx) => dsn.split_at(idx + 3),
            None => ("", dsn),
        };
        rest.rfind('@')
            .and_then(|at| {
                let (userinfo, tail) = rest.split_at(at);
                userinfo
                    .find(':')
                    .map(|colon| format!("{prefix}{}:***{tail}", &userinfo[..colon]))
            })
            .unwrap_or_else(|| dsn.to_string())
    }

    /// Scans the raw DSN text for an explicitly written `:port`.
    fn explicit_port(dsn: &str) -> Option<u16> {
        let (_, rest) = dsn.split_once("://")?;
        let authority = rest.split(|c: char| matches!(c, '/' | '?' | '#')).next()?;
        let host_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);
        let (_, port) = host_port.rsplit_once(':')?;
        port.parse().ok()
    }

    /// Percent-decodes `raw`, falling back to the raw text on invalid UTF-8.
    fn percent_decode(raw: &str) -> String {
        urlencoding::decode(raw)
            .map(|c| c.into_owned())
            .unwrap_or_else(|_| raw.to_string())
    }

    // Never echo the raw DSN: it may embed a password.
    let parsed = url::Url::parse(dsn)
        .with_context(|| format!("parsing CH DSN {:?}", redact_dsn(dsn)))?;

    let scheme = parsed.scheme();
    if !matches!(
        scheme,
        "clickhouse" | "http" | "https" | "clickhouse+http" | "clickhouse+https"
    ) {
        return Err(anyhow!(
            "expected clickhouse:// or http(s):// DSN for ClickHouse, got {scheme:?}"
        ));
    }
    let secure = matches!(scheme, "https" | "clickhouse+https");

    let host = parsed
        .host_str()
        .ok_or_else(|| anyhow!("DSN missing host: {:?}", redact_dsn(dsn)))?;

    // `Url::port()` yields `None` when the written port equals the scheme's
    // well-known default (80 for http, 443 for https). Recover such a port
    // from the raw text so `https://host:443` is not rewritten to 8443. The
    // scanned value is only trusted when it matches the parser's own default.
    let default_port = if secure { 8443 } else { 8123 };
    let port = parsed
        .port()
        .or_else(|| explicit_port(dsn).filter(|p| Some(*p) == parsed.port_or_known_default()))
        .unwrap_or(default_port);

    let url = format!(
        "{}://{}:{}",
        if secure { "https" } else { "http" },
        host,
        port
    );

    let user = match parsed.username() {
        "" => "default".to_string(),
        u => percent_decode(u),
    };
    let password = parsed.password().map(percent_decode).unwrap_or_default();
    // The URL path is percent-encoded too; decode it like the credentials.
    let database = match parsed.path().trim_start_matches('/') {
        "" => "default".to_string(),
        d => percent_decode(d),
    };

    Ok(Client::default()
        .with_url(url)
        .with_user(user)
        .with_password(password)
        .with_database(database))
}
use crate::backends::base::{BackendDialect, ColSpec};
/// ClickHouse implementation of the SQL-dialect hooks declared by
/// [`BackendDialect`].
impl BackendDialect for ClickhouseBackend {
    /// Dialect identifier.
    const NAME: &'static str = "clickhouse";
    /// No upsert clause is produced for this dialect (see `upsert_clause`).
    const SUPPORTS_UPSERT: bool = false;

    /// Wraps `name` in backticks, doubling any backtick it contains.
    fn quote_ident(&self, name: &str) -> String {
        format!("`{}`", name.replace('`', "``"))
    }

    /// Returns the quoted `db`.`table` reference.
    fn fq_table(&self, db: &str, table: &str) -> String {
        format!("{}.{}", self.quote_ident(db), self.quote_ident(table))
    }

    /// Column type for embeddings; the dimension is not encoded in the type.
    fn vector_type_ddl(&self, _dim: usize) -> String {
        "Array(Float32)".to_string()
    }

    /// Column type for JSON payloads (emitted as `String`).
    fn json_type_ddl(&self) -> String {
        "String".to_string()
    }

    /// Column type for tag lists.
    fn tags_array_type_ddl(&self) -> String {
        "Array(String)".to_string()
    }

    /// Column type for textual primary keys.
    fn text_pk_type_ddl(&self) -> String {
        "String".to_string()
    }

    /// Returns the timestamp column type `DateTime64(6)`.
    ///
    /// NOTE(review): the name suggests a "now" default, but only the type is
    /// emitted here — confirm against the trait's contract.
    fn timestamp_now_default_ddl(&self) -> String {
        "DateTime64(6)".to_string()
    }

    /// Renders `arr` as `[a,b,c]` with six fractional digits per element.
    fn vector_literal(&self, arr: &[f32]) -> String {
        let parts: Vec<String> = arr.iter().map(|x| format!("{x:.6}")).collect();
        format!("[{}]", parts.join(","))
    }

    /// Serialises `obj` to JSON text, falling back to `null` if that fails.
    ///
    /// The result is bare JSON, not a quoted SQL string literal.
    fn json_literal(&self, obj: &serde_json::Value) -> String {
        serde_json::to_string(obj).unwrap_or_else(|_| "null".to_string())
    }

    /// Builds `JSONExtractString(col_expr, 'a', 'b', ...)` for a dotted path.
    ///
    /// Each segment becomes a single-quoted string literal. Backslashes and
    /// single quotes inside a segment are backslash-escaped so a segment can
    /// neither terminate the literal early nor inject SQL. `col_expr` is
    /// inserted verbatim and must already be a valid SQL expression.
    fn json_path_sql(&self, col_expr: &str, dotted_path: &str) -> String {
        let segs: Vec<String> = dotted_path
            .split('.')
            .map(|s| {
                // Escape the backslash first so escapes added for quotes are
                // not themselves doubled.
                let escaped = s.replace('\\', "\\\\").replace('\'', "\\'");
                format!("'{escaped}'")
            })
            .collect();
        format!("JSONExtractString({col_expr}, {})", segs.join(", "))
    }

    /// Always returns an empty string; see `SUPPORTS_UPSERT`.
    fn upsert_clause(&self, _key_cols: &[&str], _update_cols: &[&str]) -> String {
        String::new()
    }

    /// `CREATE DATABASE IF NOT EXISTS` for the quoted `name`.
    fn create_database_sql(&self, name: &str) -> String {
        format!("CREATE DATABASE IF NOT EXISTS {}", self.quote_ident(name))
    }

    /// `ALTER TABLE ... ADD COLUMN IF NOT EXISTS` for the quoted `col`.
    ///
    /// `fq` and `type_ddl` are inserted verbatim.
    fn add_column_if_not_exists_sql(&self, fq: &str, col: &str, type_ddl: &str) -> String {
        format!(
            "ALTER TABLE {fq} ADD COLUMN IF NOT EXISTS {} {type_ddl}",
            self.quote_ident(col)
        )
    }

    /// `DROP TABLE IF EXISTS ... SYNC`; `fq` is inserted verbatim.
    fn drop_table_sql(&self, fq: &str) -> String {
        format!("DROP TABLE IF EXISTS {fq} SYNC")
    }

    /// Produces the single `CREATE TABLE IF NOT EXISTS` statement for a
    /// chunks table.
    ///
    /// * `cols` — one line per column: quoted name, type, optional `DEFAULT`.
    ///   Columns flagged as primary key form the default `ORDER BY` list.
    /// * `hnsw` — when true, appends a `vector_similarity` index on the
    ///   `embedding` column.
    /// * `engine` — used verbatim after `ENGINE =` when given; otherwise
    ///   `MergeTree() ORDER BY (...)` is generated, with `tuple()` when there
    ///   are no primary-key columns.
    ///
    /// NOTE(review): `_dim` and `_vector_metric` are ignored; the index is
    /// always declared with `'cosineDistance'` — confirm this is intended.
    fn emit_chunks_table_ddl(
        &self,
        fq: &str,
        cols: &[ColSpec],
        hnsw: bool,
        _dim: usize,
        engine: Option<&str>,
        _vector_metric: Option<&str>,
    ) -> Vec<String> {
        let mut col_lines: Vec<String> = Vec::with_capacity(cols.len());
        let mut order_by_cols: Vec<&str> = Vec::new();
        for c in cols {
            let mut line = format!(" {} {}", self.quote_ident(c.name), c.type_ddl);
            if let Some(default) = c.default {
                line.push_str(&format!(" DEFAULT {default}"));
            }
            col_lines.push(line);
            if c.is_primary_key {
                order_by_cols.push(c.name);
            }
        }
        if hnsw {
            col_lines.push(
                " INDEX vec_idx embedding TYPE vector_similarity('hnsw', 'cosineDistance') GRANULARITY 1"
                    .to_string(),
            );
        }
        let body = col_lines.join(",\n");
        let engine_clause = match engine {
            Some(e) => e.to_string(),
            None => {
                let order_by = if order_by_cols.is_empty() {
                    "tuple()".to_string()
                } else {
                    order_by_cols
                        .iter()
                        .map(|c| self.quote_ident(c))
                        .collect::<Vec<_>>()
                        .join(", ")
                };
                format!("MergeTree() ORDER BY ({order_by})")
            }
        };
        vec![format!(
            "CREATE TABLE IF NOT EXISTS {fq} (\n{body}\n) ENGINE = {engine_clause}"
        )]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `clickhouse://` DSN with user, password, port and database parses.
    #[test]
    fn dsn_parses_clickhouse_scheme_with_credentials() {
        let dsn = "clickhouse://default:chpw@localhost:8124/chunkshop_test";
        let _ = build_client_from_dsn(dsn).expect("parse");
    }

    /// A plain `http://` DSN is accepted as an alias.
    #[test]
    fn dsn_parses_http_alias() {
        let dsn = "http://localhost:8123/test";
        let _ = build_client_from_dsn(dsn).expect("parse");
    }

    /// A non-ClickHouse scheme is rejected with the scheme-mismatch message.
    #[test]
    fn dsn_rejects_unknown_scheme() {
        // `.err().expect(..)` avoids requiring `Debug` on the success type.
        let failure = build_client_from_dsn("postgres://x/y")
            .err()
            .expect("expected error for postgres scheme");
        let rendered = format!("{failure:#}");
        assert!(rendered.contains("expected clickhouse://"));
    }
}