use std::future::Future;
use std::sync::OnceLock;
use anyhow::{anyhow, Context, Result};
use clickhouse::{Client, Row};
use serde::Serialize;
use tracing::warn;
use crate::backends::base::{BackendDialect, ColSpec};
use crate::backends::clickhouse::ClickhouseBackend;
use crate::chunker::Chunk;
use crate::config::ClickhouseTargetConfig;
use crate::sinks::base::Sink;
/// Latch used by `ClickhouseSink::new` (via `get_or_init`) so the `delete_orphans`
/// warning is logged at most once per process.
static DELETE_ORPHANS_WARNED: OnceLock<()> = OnceLock::new();
/// Latch used by `ClickhouseSink::new` (via `get_or_init`) so the
/// append-without-`ReplacingMergeTree` warning is logged at most once per process.
static APPEND_WITHOUT_REPLACING_WARNED: OnceLock<()> = OnceLock::new();
/// Warning text logged when `target.delete_orphans` is enabled on a ClickHouse target.
const ORPHAN_WARN_MSG: &str =
"target.delete_orphans=true on ClickHouse is a no-op — CH mutations are async \
background ops that don't fit chunkshop's per-document atomic write contract. \
Use ReplacingMergeTree(created_at) for lazy dedup or run manual ALTER TABLE … DELETE WHERE.";
/// Warning text logged when `mode` is `"append"` and the configured engine string does
/// not mention `ReplacingMergeTree` (or no engine is configured at all).
const APPEND_WITHOUT_REPLACING_MSG: &str =
"ClickHouse mode='append' on the default MergeTree engine accumulates duplicate \
rows when the same (doc_id, seq_num) is re-ingested. ClickHouse has no per-row \
UPSERT. Set target.engine: 'ReplacingMergeTree(created_at) ORDER BY (id)' for \
lazy dedup at merge time (run OPTIMIZE TABLE … FINAL to force a merge), or use \
mode='overwrite' for fresh ingests.";
/// Sink that creates, fills, queries and prunes a ClickHouse table of embedded chunks.
pub struct ClickhouseSink {
/// Target settings read by the sink: database/table names, write mode, engine,
/// source tag, promoted metadata columns and the overwrite/orphan flags.
cfg: ClickhouseTargetConfig,
/// Backend that hands out clients and generates the ClickHouse SQL fragments.
backend: ClickhouseBackend,
/// Embedding dimension of this cell's embedder; compared against the existing
/// table's dimension in append mode and passed to the table DDL generator.
embed_dim: usize,
}
impl ClickhouseSink {
    /// Builds a sink from its target configuration, backend and embedding dimension.
    ///
    /// Two configuration warnings may be logged here, each at most once per process:
    /// one when `delete_orphans` is set, and one when `mode` is `"append"` while the
    /// configured engine is absent or does not mention `ReplacingMergeTree`.
    pub fn new(cfg: ClickhouseTargetConfig, backend: ClickhouseBackend, embed_dim: usize) -> Self {
        if cfg.delete_orphans {
            DELETE_ORPHANS_WARNED.get_or_init(|| {
                warn!("{ORPHAN_WARN_MSG}");
            });
        }

        // A missing engine counts as "not ReplacingMergeTree".
        let dedups_on_merge = cfg
            .engine
            .as_deref()
            .is_some_and(|engine| engine.contains("ReplacingMergeTree"));
        if cfg.mode == "append" && !dedups_on_merge {
            APPEND_WITHOUT_REPLACING_WARNED.get_or_init(|| {
                warn!("{APPEND_WITHOUT_REPLACING_MSG}");
            });
        }

        Self {
            cfg,
            backend,
            embed_dim,
        }
    }

    /// Returns the table reference produced by the backend's `fq_table` for the
    /// configured database and table names.
    fn fq(&self) -> String {
        let cfg = &self.cfg;
        self.backend.fq_table(&cfg.database_name, &cfg.table)
    }
}
impl ClickhouseSink {
/// Issues one "add column if not exists" statement per configured promoted-metadata
/// entry, mapping each entry's declared type through `pg_type_to_ch`.
///
/// # Errors
/// Stops at the first statement that fails and returns it with the context
/// `ADD COLUMN promote_metadata`.
async fn ensure_promote_columns(&self, client: &Client) -> Result<()> {
    let table = self.fq();
    for promoted in &self.cfg.promote_metadata {
        let column_type = pg_type_to_ch(&promoted.type_);
        let alter_sql = self.backend.add_column_if_not_exists_sql(
            &table,
            &promoted.column_name(),
            &column_type,
        );
        client
            .query(&alter_sql)
            .execute()
            .await
            .context("ADD COLUMN promote_metadata")?;
    }
    Ok(())
}
/// Runs every statement the backend emits for the chunks table (canonical columns,
/// configured `hnsw` and engine settings), then adds the promoted-metadata columns.
///
/// # Errors
/// Returns the first failing statement with the context
/// `emit_chunks_table_ddl statement`, or any error from `ensure_promote_columns`.
async fn create_base_ddl(&self, client: &Client) -> Result<()> {
    let table = self.fq();
    let columns = canonical_cols(self.embed_dim);
    let statements = self.backend.emit_chunks_table_ddl(
        &table,
        &columns,
        self.cfg.hnsw,
        self.embed_dim,
        self.cfg.engine.as_deref(),
        None,
    );
    for ddl in statements {
        client
            .query(&ddl)
            .execute()
            .await
            .context("emit_chunks_table_ddl statement")?;
    }
    self.ensure_promote_columns(client).await
}
/// Overwrite mode: drops the table if it exists and recreates it from scratch.
///
/// Unless `force_overwrite` is set, an existing table is first probed for up to ten
/// distinct non-empty `source` values. If any of them differs from this cell's
/// `source_tag`, the drop is refused.
///
/// # Errors
/// Returns an error when foreign source tags are found, or when the existence check,
/// the probe query, the `DROP TABLE` or the re-creation fails.
async fn overwrite_create(&self, client: &Client) -> Result<()> {
    #[derive(Row, serde::Deserialize)]
    struct SourceRow {
        source: String,
    }

    let cfg = &self.cfg;
    let table_present = self
        .backend
        .table_exists(client, &cfg.database_name, &cfg.table)
        .await?;
    if !table_present {
        // Nothing to protect and nothing to drop.
        return self.create_base_ddl(client).await;
    }

    if !cfg.force_overwrite {
        let probe_sql = format!(
            "SELECT DISTINCT source FROM {} WHERE source != '' LIMIT 10",
            self.fq()
        );
        let mut cursor = client.query(&probe_sql).fetch::<SourceRow>()?;
        let mut seen_tags = std::collections::BTreeSet::new();
        while let Some(row) = cursor.next().await? {
            seen_tags.insert(row.source);
        }

        let my_tag = cfg.source_tag.as_deref();
        let foreign: Vec<&String> = seen_tags
            .iter()
            .filter(|tag| my_tag != Some(tag.as_str()))
            .collect();
        if !foreign.is_empty() {
            return Err(anyhow!(
                "overwrite refuses to drop {db}.{tbl}: table holds rows with source_tag \
                 values {foreign:?} that differ from this cell's source_tag {my_tag:?}. \
                 Set target.force_overwrite: true in YAML to bypass.",
                db = cfg.database_name,
                tbl = cfg.table,
                foreign = foreign,
                my_tag = my_tag,
            ));
        }
    }

    let drop_sql = self.backend.drop_table_sql(&self.fq());
    client
        .query(&drop_sql)
        .execute()
        .await
        .context("DROP TABLE")?;
    self.create_base_ddl(client).await
}
/// Create-if-missing mode: creates the table when absent. Otherwise makes sure the
/// `source` column and the promoted-metadata columns exist on the existing table.
///
/// # Errors
/// Propagates failures from the existence check, the table creation, the
/// `ADD COLUMN source` statement or `ensure_promote_columns`.
async fn create_if_missing(&self, client: &Client) -> Result<()> {
    let cfg = &self.cfg;
    let already_there = self
        .backend
        .table_exists(client, &cfg.database_name, &cfg.table)
        .await?;

    if already_there {
        let add_source_sql =
            self.backend
                .add_column_if_not_exists_sql(&self.fq(), "source", "String");
        client
            .query(&add_source_sql)
            .execute()
            .await
            .context("ADD COLUMN source")?;
        self.ensure_promote_columns(client).await
    } else {
        self.create_base_ddl(client).await
    }
}
/// Append mode preflight: requires the table to exist, checks the embedding dimension
/// reported by the backend against this cell's, then makes sure the `source` column
/// and the promoted-metadata columns exist.
///
/// When the backend reports no dimension, a warning is logged and the preflight
/// continues without the check.
///
/// # Errors
/// Returns an error when the table is missing, when the reported dimension differs
/// from `embed_dim`, or when any of the queries fails.
async fn append_preflight(&self, client: &Client) -> Result<()> {
    let cfg = &self.cfg;
    let table_present = self
        .backend
        .table_exists(client, &cfg.database_name, &cfg.table)
        .await?;
    if !table_present {
        return Err(anyhow!(
            "append mode: table {}.{} does not exist. Use mode='create_if_missing' on the first cell.",
            cfg.database_name,
            cfg.table
        ));
    }

    let table_dim = self
        .backend
        .embedding_dim(client, &cfg.database_name, &cfg.table)
        .await?;
    match table_dim {
        Some(d) if d == self.embed_dim => {}
        Some(d) => {
            return Err(anyhow!(
                "append mode: target embedding dim is {d}, cell's embedder dim is {own}. \
                 Vectors are not comparable.",
                own = self.embed_dim
            ));
        }
        None => {
            warn!(
                "append mode on empty CH table — cannot verify embedding dim matches. \
                 Continuing on faith; subsequent reads with mismatched dim will produce \
                 garbage cosine distances."
            );
        }
    }

    let add_source_sql = self
        .backend
        .add_column_if_not_exists_sql(&self.fq(), "source", "String");
    client
        .query(&add_source_sql)
        .execute()
        .await
        .context("ADD COLUMN source")?;
    self.ensure_promote_columns(client).await
}
/// Prepares the target table according to `target.mode`.
///
/// Obtains a client, calls the backend's `with_create_lock`, creates the database,
/// then dispatches to the overwrite, create-if-missing or append routine.
///
/// # Errors
/// Returns an error for an unknown mode, or when the lock call, the
/// `CREATE DATABASE` statement or the selected routine fails.
pub async fn create_table_impl(&self) -> Result<()> {
    let db = &self.cfg.database_name;
    let client = self.backend.client().await?;

    // NOTE(review): whatever `with_create_lock` returns is dropped at the end of this
    // statement. If it is meant to be a guard held for the DDL below, it is released
    // too early — confirm against the backend.
    self.backend.with_create_lock(&client, db).await?;

    let create_db_sql = self.backend.create_database_sql(db);
    client
        .query(&create_db_sql)
        .execute()
        .await
        .context("CREATE DATABASE")?;

    let mode = self.cfg.mode.as_str();
    if mode == "overwrite" {
        self.overwrite_create(&client).await
    } else if mode == "create_if_missing" {
        self.create_if_missing(&client).await
    } else if mode == "append" {
        self.append_preflight(&client).await
    } else {
        Err(anyhow!("unknown target.mode: {other:?}", other = mode))
    }
}
/// Inserts the chunks of one document, with their embeddings and tags, into the table.
///
/// `chunks`, `embeddings` and `tags_per_chunk` are parallel slices: element `i` of each
/// describes the same chunk. The row id is `"{doc_id}::{seq_num}"`, metadata is stored
/// as a JSON string, and `source` is the configured `source_tag` or an empty string.
///
/// Two write paths exist:
/// * no promoted metadata: all rows are streamed through a single typed insert;
/// * promoted metadata configured: one parameterised `INSERT` is executed per chunk,
///   with each promoted value taken from the chunk metadata via `jsonb_path_get`
///   (strings verbatim, other JSON values serialised, missing paths as `""`).
///
/// The `_doc_id` argument is not used; each chunk carries its own `doc_id`.
///
/// # Errors
/// Returns an error when the slice lengths disagree, when metadata cannot be
/// serialised, or when any insert fails. An empty `chunks` slice is a successful no-op.
pub async fn write_document_impl(
    &self,
    _doc_id: &str,
    chunks: &[Chunk],
    embeddings: &[Vec<f32>],
    tags_per_chunk: &[Vec<String>],
) -> Result<()> {
    if chunks.len() != embeddings.len() {
        return Err(anyhow!(
            "chunks ({}) and embeddings ({}) length mismatch",
            chunks.len(),
            embeddings.len()
        ));
    }
    if chunks.len() != tags_per_chunk.len() {
        return Err(anyhow!(
            "chunks ({}) and tags_per_chunk ({}) length mismatch",
            chunks.len(),
            tags_per_chunk.len()
        ));
    }
    if chunks.is_empty() {
        return Ok(());
    }

    let promote = &self.cfg.promote_metadata;
    let client = self.backend.client().await?;
    // Everything below is identical for every chunk, so compute it once.
    let table = self.fq();
    let source = self.cfg.source_tag.clone().unwrap_or_default();

    if promote.is_empty() {
        let mut insert = client.insert_unescaped::<ChunkRow>(&table).await?;
        for ((c, emb), tags) in chunks.iter().zip(embeddings).zip(tags_per_chunk) {
            let row = ChunkRow {
                id: format!("{}::{}", c.doc_id, c.seq_num),
                doc_id: c.doc_id.clone(),
                // NOTE(review): `as` truncates silently if `seq_num` is wider than i32.
                seq_num: c.seq_num as i32,
                original_content: c.original_content.clone(),
                embedded_content: c.embedded_content.clone(),
                tags: tags.clone(),
                metadata: serde_json::to_string(&c.metadata)?,
                embedding: emb.clone(),
                source: source.clone(),
            };
            insert.write(&row).await?;
        }
        insert.end().await?;
        return Ok(());
    }

    // Promoted path: the statement text depends only on the configuration, so it is
    // built once here instead of once per chunk.
    let mut col_names: Vec<String> = vec![
        "id",
        "doc_id",
        "seq_num",
        "original_content",
        "embedded_content",
        "tags",
        "metadata",
        "embedding",
        "source",
    ]
    .into_iter()
    .map(|s| self.backend.quote_ident(s))
    .collect();
    for pc in promote {
        col_names.push(self.backend.quote_ident(&pc.column_name()));
    }
    // One placeholder per column keeps the two lists in lockstep.
    let placeholders = vec!["?"; col_names.len()].join(", ");
    let insert_sql = format!(
        "INSERT INTO {} ({}) VALUES ({})",
        table,
        col_names.join(", "),
        placeholders
    );

    for ((c, emb), tags) in chunks.iter().zip(embeddings).zip(tags_per_chunk) {
        let id = format!("{}::{}", c.doc_id, c.seq_num);
        let metadata = serde_json::to_string(&c.metadata)?;
        let mut q = client
            .query(&insert_sql)
            .bind(id)
            .bind(c.doc_id.clone())
            // NOTE(review): `as` truncates silently if `seq_num` is wider than i32.
            .bind(c.seq_num as i32)
            .bind(c.original_content.clone())
            .bind(c.embedded_content.clone())
            .bind(tags.clone())
            .bind(metadata)
            .bind(emb.clone())
            .bind(source.as_str());
        for pc in promote {
            let cell = match jsonb_path_get(&c.metadata, &pc.path) {
                Some(serde_json::Value::String(s)) => s.clone(),
                Some(other) => serde_json::to_string(other).unwrap_or_default(),
                None => String::new(),
            };
            q = q.bind(cell);
        }
        q.execute()
            .await
            .context("INSERT chunk row (promoted path)")?;
    }
    Ok(())
}
/// Deletes the rows of `doc_id` and returns how many rows matched before the delete.
///
/// When a `source_tag` is configured, both the count and the delete are restricted to
/// rows whose `source` equals that tag. If nothing matches, no `ALTER TABLE … DELETE`
/// is issued and `0` is returned.
///
/// The returned number comes from the preceding `SELECT count()`, not from the delete
/// statement itself.
///
/// # Errors
/// Propagates failures from obtaining a client, the count query or the delete statement.
pub async fn delete_document_impl(&self, doc_id: &str) -> Result<i64> {
    #[derive(Row, serde::Deserialize)]
    struct C {
        c: u64,
    }

    let client = self.backend.client().await?;
    let table = self.fq();
    let tag = self.cfg.source_tag.as_deref();
    // Single source of truth for the filter, shared by the count and the delete.
    let predicate = if tag.is_some() {
        "doc_id = ? AND source = ?"
    } else {
        "doc_id = ?"
    };

    let count_sql = format!("SELECT count() AS c FROM {table} WHERE {predicate}");
    let mut count_query = client.query(&count_sql).bind(doc_id);
    if let Some(tag) = tag {
        count_query = count_query.bind(tag);
    }
    let mut cur = count_query.fetch::<C>()?;
    let matched = cur.next().await?.map_or(0, |row| row.c);
    if matched == 0 {
        return Ok(0);
    }

    let delete_sql = format!("ALTER TABLE {table} DELETE WHERE {predicate}");
    let mut delete_query = client.query(&delete_sql).bind(doc_id);
    if let Some(tag) = tag {
        delete_query = delete_query.bind(tag);
    }
    delete_query.execute().await?;

    // Saturate rather than wrap if the count ever exceeds i64::MAX.
    Ok(i64::try_from(matched).unwrap_or(i64::MAX))
}
/// Returns the number of distinct `doc_id` values in the table (`uniqExact(doc_id)`).
///
/// A result set with no row is reported as `0`.
///
/// # Errors
/// Propagates failures from obtaining a client or running the query.
pub async fn count_docs_impl(&self) -> Result<i64> {
    #[derive(Row, serde::Deserialize)]
    struct DocCount {
        c: u64,
    }

    let client = self.backend.client().await?;
    let sql = format!("SELECT uniqExact(doc_id) AS c FROM {}", self.fq());
    let mut cursor = client.query(&sql).fetch::<DocCount>()?;
    let distinct_docs = match cursor.next().await? {
        Some(row) => row.c,
        None => 0,
    };
    Ok(distinct_docs as i64)
}
/// Returns up to `k` rows closest to `query_vec`, as `(doc_id, seq_num, distance)`
/// tuples ordered by ascending `cosineDistance(embedding, query_vec)`.
///
/// The query vector is inlined through the backend's `vector_literal`; the row limit
/// is bound as a parameter. A `k` larger than `u32::MAX` is clamped to `u32::MAX`
/// instead of being truncated.
///
/// # Errors
/// Propagates failures from obtaining a client or from running/reading the query.
pub async fn query_top_k_impl(
    &self,
    query_vec: &[f32],
    k: usize,
) -> Result<Vec<(String, i32, f64)>> {
    #[derive(Row, serde::Deserialize)]
    struct Hit {
        doc_id: String,
        seq_num: i32,
        dist: f64,
    }

    // Upper bound on the up-front allocation; the vector still grows past it on demand.
    const MAX_PREALLOCATED_HITS: usize = 1024;

    let client = self.backend.client().await?;
    let vec_lit = self.backend.vector_literal(query_vec);
    let q = format!(
        "SELECT doc_id, seq_num, cosineDistance(embedding, {vec_lit}) AS dist \
         FROM {} ORDER BY dist LIMIT ?",
        self.fq()
    );

    // `k as u32` would wrap for k > u32::MAX (e.g. 2^32 would become LIMIT 0).
    let limit = u32::try_from(k).unwrap_or(u32::MAX);
    let mut cur = client.query(&q).bind(limit).fetch::<Hit>()?;

    // Don't trust `k` for sizing: a huge value would request a huge allocation.
    let mut out = Vec::with_capacity(k.min(MAX_PREALLOCATED_HITS));
    while let Some(h) = cur.next().await? {
        out.push((h.doc_id, h.seq_num, h.dist));
    }
    Ok(out)
}
}
/// Returns the fixed column layout of the chunks table, in declaration order:
/// `id` (primary key), `doc_id`, `seq_num`, `original_content`, `embedded_content`,
/// `tags`, `metadata`, `embedding`, `source` (nullable) and `created_at`.
///
/// The `_dim` argument is currently unused; the embedding column is declared as a
/// plain `Array(Float32)`.
fn canonical_cols(_dim: usize) -> Vec<ColSpec> {
    /// A non-nullable, non-key column with no default; callers override what differs.
    fn plain(name: &'static str, type_ddl: &'static str) -> ColSpec {
        ColSpec {
            name,
            type_ddl: type_ddl.into(),
            nullable: false,
            default: None,
            is_primary_key: false,
        }
    }

    vec![
        ColSpec {
            is_primary_key: true,
            ..plain("id", "String")
        },
        plain("doc_id", "String"),
        plain("seq_num", "Int32"),
        plain("original_content", "String"),
        plain("embedded_content", "String"),
        ColSpec {
            default: Some("[]"),
            ..plain("tags", "Array(String)")
        },
        ColSpec {
            default: Some("'{}'"),
            ..plain("metadata", "String")
        },
        plain("embedding", "Array(Float32)"),
        ColSpec {
            nullable: true,
            ..plain("source", "String")
        },
        ColSpec {
            default: Some("now64()"),
            ..plain("created_at", "DateTime64(6)")
        },
    ]
}
/// Translates a Postgres-style type name into the ClickHouse type used for promoted
/// metadata columns.
///
/// Matching is exact (case-sensitive, no trimming). Names without a mapping are
/// returned unchanged.
fn pg_type_to_ch(pg_type: &str) -> String {
    const PG_TO_CH: [(&str, &str); 8] = [
        ("text", "String"),
        ("text[]", "Array(String)"),
        ("int", "Int32"),
        ("bigint", "Int64"),
        ("boolean", "UInt8"),
        ("jsonb", "String"),
        ("timestamptz", "DateTime64(6)"),
        ("date", "Date"),
    ];

    PG_TO_CH
        .iter()
        .find(|&&(pg, _)| pg == pg_type)
        .map_or(pg_type, |&(_, ch)| ch)
        .to_string()
}
/// Walks a dot-separated `path` through nested JSON objects and returns the value found.
///
/// Returns `None` as soon as a segment is missing or the current value is not an
/// object; arrays are not traversed. An empty `path` is a single empty segment, so it
/// looks up the key `""`.
fn jsonb_path_get<'a>(meta: &'a serde_json::Value, path: &str) -> Option<&'a serde_json::Value> {
    path.split('.')
        .try_fold(meta, |node, key| node.as_object()?.get(key))
}
/// Row written by the typed insert path of `write_document_impl` (used when no
/// promoted metadata columns are configured).
#[derive(Row, Serialize)]
pub(crate) struct ChunkRow {
/// Row identifier, built as `"{doc_id}::{seq_num}"`.
pub id: String,
/// Identifier of the document the chunk belongs to.
pub doc_id: String,
/// Sequence number of the chunk, cast to `i32` by the writer.
pub seq_num: i32,
/// Copied from the chunk's `original_content`.
pub original_content: String,
/// Copied from the chunk's `embedded_content`.
pub embedded_content: String,
/// Tags supplied for this chunk.
pub tags: Vec<String>,
/// Chunk metadata serialised to a JSON string.
pub metadata: String,
/// Embedding vector supplied for this chunk.
pub embedding: Vec<f32>,
/// Configured `source_tag`, or an empty string when none is set.
pub source: String,
}
/// `Sink` implementation: every trait method hands back the future of the matching
/// inherent `*_impl` method unchanged.
impl Sink for ClickhouseSink {
    /// Delegates to [`ClickhouseSink::create_table_impl`].
    fn create_table(&self) -> impl Future<Output = Result<()>> + Send {
        self.create_table_impl()
    }

    /// Delegates to [`ClickhouseSink::write_document_impl`].
    fn write_document(
        &self,
        doc_id: &str,
        chunks: &[Chunk],
        embeddings: &[Vec<f32>],
        tags_per_chunk: &[Vec<String>],
    ) -> impl Future<Output = Result<()>> + Send {
        self.write_document_impl(doc_id, chunks, embeddings, tags_per_chunk)
    }

    /// Delegates to [`ClickhouseSink::delete_document_impl`].
    fn delete_document(&self, doc_id: &str) -> impl Future<Output = Result<i64>> + Send {
        self.delete_document_impl(doc_id)
    }

    /// Delegates to [`ClickhouseSink::count_docs_impl`].
    fn count_docs(&self) -> impl Future<Output = Result<i64>> + Send {
        self.count_docs_impl()
    }

    /// Delegates to [`ClickhouseSink::query_top_k_impl`].
    fn query_top_k(
        &self,
        query_vec: &[f32],
        k: usize,
    ) -> impl Future<Output = Result<Vec<(String, i32, f64)>>> + Send {
        self.query_top_k_impl(query_vec, k)
    }
}