#![forbid(unsafe_code)]
mod types;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
use narwhal_core::{
CancelHandle, Capabilities, Column, ColumnHeader, Connection, ConnectionConfig,
ConnectionParams, DatabaseDriver, Error, IsolationLevel, QueryResult, Result, Row as CoreRow,
RowStream, Schema, SslMode, Table, TableKind, TableSchema, Value,
};
use parking_lot::Mutex;
use tokio::sync::mpsc;
use tracing::{debug, info};
use url::Url;
use self::types::{parse_tsv_body, parse_tsv_value, value_to_sql_literal};
/// Unit struct that implements the ClickHouse driver.
///
/// The type carries no state; everything a connection needs is built inside
/// `connect`.
#[derive(Debug, Default)]
pub struct ClickhouseDriver;
impl ClickhouseDriver {
/// Driver identifier returned by `DatabaseDriver::name`.
pub const NAME: &'static str = "clickhouse";
/// Creates the driver. Usable in `const` contexts because the type is a unit struct.
pub const fn new() -> Self {
Self
}
/// Feature flags reported for every ClickHouse connection.
///
/// Cancellation, multiple schemas and streaming are switched on; transactions,
/// prepared statements, savepoints, affected-row counts and row-level DML are
/// switched off.
fn capabilities() -> Capabilities {
Capabilities::default()
.with_transactions(false)
.with_cancellation(true)
.with_multiple_schemas(true)
.with_prepared_statements(false)
.with_savepoints(false)
.with_rows_affected(false)
.with_streaming(true)
.with_row_level_dml(false)
}
}
impl DatabaseDriver for ClickhouseDriver {
/// Machine-readable driver name (`"clickhouse"`).
fn name(&self) -> &'static str {
Self::NAME
}
/// Human-readable driver name.
fn display_name(&self) -> &'static str {
"ClickHouse"
}
/// Returns a list of configuration problems; currently only a missing host is reported.
fn validate(&self, config: &ConnectionConfig) -> Vec<String> {
let mut problems = Vec::new();
if config.params.host.is_none() {
problems.push("host is required".into());
}
problems
}
/// Builds an HTTP client from `config`, verifies it with a `SELECT 1` request
/// and returns a boxed [`ClickhouseConnection`].
///
/// User and database both fall back to `"default"`; a missing password is
/// treated as empty and then omitted from basic auth.
///
/// # Errors
/// * `Error::Config` when the URL cannot be built, a certificate/key file
///   cannot be read or parsed, or only one of `ssl_cert`/`ssl_key` is set.
/// * A connection error when the client cannot be built, the ping request
///   fails, or the server answers with a non-success status.
async fn connect(
&self,
config: &ConnectionConfig,
password: Option<&str>,
) -> Result<Box<dyn narwhal_core::DynConnection>> {
let base_url = build_base_url(&config.params)?;
let user = config
.params
.username
.as_deref()
.unwrap_or("default")
.to_owned();
let database = config
.params
.database
.as_deref()
.unwrap_or("default")
.to_owned();
let pw = password.map(String::from).unwrap_or_default();
debug!(target: "narwhal::clickhouse", %base_url, %user, %database, "connecting");
// Timeout handed to the reqwest client builder for every request made through it.
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
let mut client_builder = reqwest::Client::builder().timeout(REQUEST_TIMEOUT);
// TLS settings only apply when TLS is in use (anything but `SslMode::Disable`).
if config.params.ssl_mode != SslMode::Disable {
// Optionally trust an extra root certificate loaded from a PEM file.
let has_root_cert = if let Some(path) = &config.params.ssl_root_cert {
let bytes = std::fs::read(path).map_err(|e| {
Error::Config(format!(
"failed to read ssl_root_cert '{}': {e}",
path.display()
))
})?;
let cert = reqwest::Certificate::from_pem(&bytes)
.map_err(|e| Error::Config(format!("failed to parse ssl_root_cert: {e}")))?;
client_builder = client_builder.add_root_certificate(cert);
true
} else {
false
};
// Certificate chain validation is never relaxed; hostname verification is
// enforced only for `SslMode::VerifyFull`.
let accept_invalid_certs = false;
let accept_invalid_hostnames = !matches!(config.params.ssl_mode, SslMode::VerifyFull);
// Tell the user that a custom CA under prefer/require still skips hostname checks.
if has_root_cert && matches!(config.params.ssl_mode, SslMode::Prefer | SslMode::Require)
{
tracing::info!(
target: "narwhal::clickhouse",
"ssl_root_cert is set with ssl_mode='{}'; \
CA certificate will be validated but hostname will not — \
use ssl_mode='verify-full' to also enforce hostname checking",
match config.params.ssl_mode {
SslMode::Prefer => "prefer",
SslMode::Require => "require",
_ => "unknown",
}
);
}
if !has_root_cert {
tracing::debug!(
target: "narwhal::clickhouse",
ssl_mode = ?config.params.ssl_mode,
"no ssl_root_cert provided; using system CA store for chain verification"
);
}
client_builder = client_builder
.danger_accept_invalid_certs(accept_invalid_certs)
.danger_accept_invalid_hostnames(accept_invalid_hostnames);
}
// Client certificate and key must be configured together.
if config.params.ssl_cert.is_some() != config.params.ssl_key.is_some() {
return Err(Error::Config(
"ssl_cert and ssl_key must both be provided or both omitted".into(),
));
}
if let (Some(cert_path), Some(key_path)) = (&config.params.ssl_cert, &config.params.ssl_key)
{
let cert_bytes = std::fs::read(cert_path).map_err(|e| {
Error::Config(format!(
"failed to read ssl_cert '{}': {e}",
cert_path.display()
))
})?;
let key_bytes = std::fs::read(key_path).map_err(|e| {
Error::Config(format!(
"failed to read ssl_key '{}': {e}",
key_path.display()
))
})?;
// The identity is parsed from a single PEM buffer: certificate bytes followed by key bytes.
let mut pem = Vec::with_capacity(cert_bytes.len() + key_bytes.len());
pem.extend_from_slice(&cert_bytes);
pem.extend_from_slice(&key_bytes);
let identity = reqwest::Identity::from_pem(&pem)
.map_err(|e| Error::Config(format!("failed to parse client identity PEM: {e}")))?;
client_builder = client_builder.identity(identity);
}
let client = client_builder
.build()
.map_err(|e| Error::connection_with("failed to build HTTP client", e))?;
// Ping: `SELECT 1` passed as a URL query parameter, authenticated like real queries.
let mut url = base_url.clone();
url.query_pairs_mut().append_pair("query", "SELECT 1");
let response = client
.post(url.as_str())
.basic_auth(&user, if pw.is_empty() { None } else { Some(&pw) })
.send()
.await
.map_err(|e| Error::connection_with("ping failed", e))?;
if !response.status().is_success() {
let status = response.status();
// A body that cannot be read is replaced by a placeholder so the status still surfaces.
let body = response
.text()
.await
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
return Err(Error::Connection(format!(
"ClickHouse returned {status}: {body}"
)));
}
info!(target: "narwhal::clickhouse", %base_url, "connected");
Ok(Box::new(ClickhouseConnection {
inner: Arc::new(SharedState {
client,
base_url,
user,
password: pw,
database,
active_queries: Arc::new(Mutex::new(HashSet::new())),
}),
}))
}
}
/// State shared between a connection and the cancel handles created from it.
struct SharedState {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Server root URL; per-request query parameters are appended to a clone.
    base_url: Url,
    /// User name sent with HTTP basic auth.
    user: String,
    /// Password sent with HTTP basic auth; an empty string means "no password".
    password: String,
    /// Database name passed as the `database` URL parameter.
    database: String,
    /// Query ids currently registered; read by the cancel handle.
    active_queries: Arc<Mutex<HashSet<String>>>,
}

impl SharedState {
    /// Prepares an authenticated `POST` to `url` carrying `body` as the request body.
    ///
    /// The password is only attached to basic auth when it is non-empty.
    fn build_request(&self, url: &Url, body: String) -> reqwest::RequestBuilder {
        let password = (!self.password.is_empty()).then_some(self.password.as_str());
        let request = self.client.post(url.as_str());
        request.basic_auth(&self.user, password).body(body)
    }
}
/// A ClickHouse connection; a handle around the reference-counted [`SharedState`].
pub struct ClickhouseConnection {
// Shared with every `ClickhouseCancel` produced by `cancel_handle`.
inner: Arc<SharedState>,
}
/// Reports whether `sql` starts with a keyword whose statement yields a result set.
///
/// The leading keyword is the text before the first whitespace or `(` after
/// leading whitespace is skipped, compared case-insensitively. `DESC` is
/// accepted alongside `DESCRIBE` because ClickHouse treats it as an alias;
/// without it the rows of a `DESC` statement would be discarded.
fn statement_returns_rows(sql: &str) -> bool {
    const ROW_KEYWORDS: [&str; 7] = [
        "SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "EXISTS",
    ];
    let lead = sql
        .trim_start()
        .split(|c: char| c.is_whitespace() || c == '(')
        .next()
        .unwrap_or("");
    // Case-insensitive comparison avoids allocating an upper-cased copy.
    ROW_KEYWORDS.iter().any(|kw| lead.eq_ignore_ascii_case(kw))
}
fn build_base_url(params: &ConnectionParams) -> Result<Url> {
let host = params
.host
.as_deref()
.ok_or_else(|| Error::Config("host is required".into()))?;
let port = params.port.unwrap_or(8123);
let scheme = if params.ssl_mode == SslMode::Disable {
"http"
} else {
"https"
};
let host_part = if host.contains(':') && !host.starts_with('[') {
format!("[{host}]")
} else {
host.to_string()
};
Url::parse(&format!("{scheme}://{host_part}:{port}/"))
.map_err(|e| Error::Config(format!("invalid URL: {e}")))
}
/// Wraps `name` in double quotes for use as a ClickHouse identifier.
///
/// Embedded double quotes are doubled and backslashes are escaped as `\\`.
/// ClickHouse treats a backslash inside a quoted identifier as an escape
/// character, so leaving it as-is would let a name ending in `\` swallow the
/// closing quote.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        match ch {
            '"' => quoted.push_str("\"\""),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }
    quoted.push('"');
    quoted
}
impl ClickhouseConnection {
async fn http_query(&self, sql: &str, query_id: Option<&str>) -> Result<bytes::Bytes> {
let state = &self.inner;
let mut url = state.base_url.clone();
url.query_pairs_mut()
.append_pair("database", &state.database);
if let Some(qid) = query_id {
url.query_pairs_mut().append_pair("query_id", qid);
}
debug!(target: "narwhal::clickhouse", %sql, "sending HTTP query");
if let Some(qid) = query_id {
state.active_queries.lock().insert(qid.to_owned());
}
let response = match state.build_request(&url, sql.to_owned()).send().await {
Ok(r) => r,
Err(e) => {
if let Some(qid) = query_id {
state.active_queries.lock().remove(qid);
}
return Err(Error::query_with("HTTP request failed", e));
}
};
let status = response.status();
if !status.is_success() {
let body = response
.text()
.await
.unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
if let Some(qid) = query_id {
state.active_queries.lock().remove(qid);
}
return Err(Error::Query(format!(
"ClickHouse returned {status}: {body}"
)));
}
if let Some(qid) = query_id {
state.active_queries.lock().remove(qid);
}
response
.bytes()
.await
.map_err(|e| Error::query_with("failed to read response body", e))
}
/// Runs a row-returning statement and decodes the whole TSV response into a `QueryResult`.
///
/// Placeholders are substituted when `params` is non-empty, then
/// `FORMAT TabSeparatedWithNamesAndTypes` is appended on a new line. Trailing
/// whitespace and semicolons are stripped first, because a `;` left in front
/// of the `FORMAT` clause turns the request into a multi-statement query.
/// `rows_affected` is always `None`.
///
/// # Errors
/// Propagates any error from `http_query`.
async fn query_tsv(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
    let started = Instant::now();
    let query_id = Self::new_query_id();
    let formatted_sql = if params.is_empty() {
        sql.to_owned()
    } else {
        substitute_params(sql, params)
    };
    let statement = formatted_sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    let full_sql = format!("{statement}\nFORMAT TabSeparatedWithNamesAndTypes");
    let body = self.http_query(&full_sql, Some(&query_id)).await?;
    let (headers, type_strings, rows) = parse_tsv_body(&body);
    let column_headers: Vec<ColumnHeader> = headers
        .into_iter()
        .zip(type_strings)
        .map(|(name, data_type)| ColumnHeader { name, data_type })
        .collect();
    let core_rows: Vec<CoreRow> = rows.into_iter().map(CoreRow).collect();
    Ok(QueryResult {
        columns: column_headers,
        rows: core_rows,
        rows_affected: None,
        // Saturate instead of silently truncating the u128 millisecond count.
        elapsed_ms: u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX),
    })
}
/// Runs a statement that yields no result set and returns an empty `QueryResult`.
///
/// Placeholders are substituted when `params` is non-empty. The response body
/// is discarded; only the elapsed time is reported.
async fn execute_raw(&self, sql: &str, params: &[Value]) -> Result<QueryResult> {
    let timer = Instant::now();
    let qid = Self::new_query_id();
    let statement = match params {
        [] => sql.to_owned(),
        bound => substitute_params(sql, bound),
    };
    self.http_query(&statement, Some(&qid)).await?;
    let elapsed_ms = timer.elapsed().as_millis() as u64;
    Ok(QueryResult {
        columns: Vec::new(),
        rows: Vec::new(),
        rows_affected: None,
        elapsed_ms,
    })
}
/// Generates a fresh random (UUID v4) query id in its string form.
fn new_query_id() -> String {
    let id = uuid::Uuid::new_v4();
    id.to_string()
}
}
/// Replaces `?` and `$N` placeholders in `sql` with SQL literals built from `params`.
///
/// * `?` consumes parameters in order of appearance.
/// * `$N` (1-based) refers to `params[N - 1]` and does not advance the `?` counter.
/// * A placeholder with no matching parameter is emitted unchanged.
///
/// Text inside single-quoted, double-quoted or backtick-quoted regions is
/// copied verbatim. Inside such a region a backslash escapes the following
/// character and a doubled quote character stands for one literal quote, so
/// neither `'it\'s ?'` nor `'it''s ?'` ends the literal early.
fn substitute_params(sql: &str, params: &[Value]) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut quote: Option<char> = None;
    let mut placeholder_idx = 0usize;
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        if let Some(q) = quote {
            out.push(ch);
            if ch == '\\' {
                // Backslash escape: the next character can never close the quote.
                if let Some(escaped) = chars.next() {
                    out.push(escaped);
                }
            } else if ch == q {
                if chars.peek() == Some(&q) {
                    // Doubled quote character stays inside the quoted region.
                    out.push(q);
                    chars.next();
                } else {
                    quote = None;
                }
            }
            continue;
        }
        match ch {
            '\'' | '"' | '`' => {
                quote = Some(ch);
                out.push(ch);
            }
            '?' => match params.get(placeholder_idx) {
                Some(p) => {
                    out.push_str(&value_to_sql_literal(p));
                    placeholder_idx += 1;
                }
                None => out.push('?'),
            },
            '$' => {
                let mut digits = String::new();
                while let Some(&next) = chars.peek() {
                    if !next.is_ascii_digit() {
                        break;
                    }
                    digits.push(next);
                    chars.next();
                }
                // `$0`, an unparsable number or an out-of-range index yields no parameter.
                let param = digits
                    .parse::<usize>()
                    .ok()
                    .and_then(|n| n.checked_sub(1))
                    .and_then(|idx| params.get(idx));
                match param {
                    Some(p) => out.push_str(&value_to_sql_literal(p)),
                    None => {
                        out.push('$');
                        out.push_str(&digits);
                    }
                }
            }
            other => out.push(other),
        }
    }
    out
}
/// Escapes `s` for embedding between single quotes in a SQL string literal.
///
/// Every backslash becomes `\\` and every single quote becomes `''`; all other
/// characters pass through. The two replacements cannot interfere because
/// neither introduces the character the other one rewrites.
fn escape_sql_string(s: &str) -> String {
    s.replace('\\', "\\\\").replace('\'', "''")
}
/// Public re-exports of private placeholder helpers, hidden from the docs.
///
/// Both functions forward directly to the private function of the same name in
/// the parent module. The module name suggests it is meant for tests only.
#[doc(hidden)]
pub mod __test_only {
use narwhal_core::Value;
/// Forwards to the parent module's `replace_question_marks`.
pub fn replace_question_marks(sql: &str, params: &[Value]) -> String {
super::replace_question_marks(sql, params)
}
/// Forwards to the parent module's `substitute_params`.
pub fn substitute_params(sql: &str, params: &[Value]) -> String {
super::substitute_params(sql, params)
}
}
/// Replaces each unquoted `?` in `sql` with the SQL literal of the next value in `params`.
///
/// Once the parameters run out, remaining `?` characters are emitted unchanged.
/// Text inside single-quoted, double-quoted or backtick-quoted regions is
/// copied verbatim. Inside such a region a backslash escapes the following
/// character and a doubled quote character stands for one literal quote, so a
/// `?` that follows `\'` within a literal is not treated as a placeholder.
fn replace_question_marks(sql: &str, params: &[Value]) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut remaining = params.iter();
    let mut quote: Option<char> = None;
    let mut chars = sql.chars().peekable();
    while let Some(ch) = chars.next() {
        if let Some(q) = quote {
            out.push(ch);
            if ch == '\\' {
                // Backslash escape: the next character can never close the quote.
                if let Some(escaped) = chars.next() {
                    out.push(escaped);
                }
            } else if ch == q {
                if chars.peek() == Some(&q) {
                    // Doubled quote character stays inside the quoted region.
                    out.push(q);
                    chars.next();
                } else {
                    quote = None;
                }
            }
            continue;
        }
        match ch {
            '\'' | '"' | '`' => {
                quote = Some(ch);
                out.push(ch);
            }
            '?' => match remaining.next() {
                Some(p) => out.push_str(&value_to_sql_literal(p)),
                None => out.push('?'),
            },
            other => out.push(other),
        }
    }
    out
}
impl Connection for ClickhouseConnection {
/// Executes `sql`, choosing the TSV-decoding path for row-returning statements
/// and the fire-and-forget path for everything else.
async fn execute(&mut self, sql: &str, params: &[Value]) -> Result<QueryResult> {
    if !statement_returns_rows(sql) {
        return self.execute_raw(sql, params).await;
    }
    self.query_tsv(sql, params).await
}
/// Executes `sql` and returns a stream that yields rows as the response arrives.
///
/// Placeholders are substituted when `params` is non-empty. For row-returning
/// statements `FORMAT TabSeparatedWithNamesAndTypes` is appended, after
/// stripping trailing whitespace and semicolons: a `;` left in front of the
/// `FORMAT` clause turns the request into a multi-statement query. A
/// background task then decodes the body and feeds a bounded channel. For all
/// other statements the request is still sent, and an already-finished stream
/// with no columns is returned.
///
/// The query id stays registered in `active_queries` until the guard is
/// dropped: on return for the non-row and error paths, or when the decoding
/// task finishes or is aborted.
///
/// # Errors
/// A query error when the request fails, the server answers with a
/// non-success status, or the header lines cannot be read; `Error::Other` when
/// the decoding task goes away before delivering headers.
async fn stream(
    &mut self,
    sql: &str,
    params: &[Value],
) -> Result<Box<dyn narwhal_core::DynRowStream>> {
    let state = &self.inner;
    let formatted_sql = if params.is_empty() {
        sql.to_owned()
    } else {
        substitute_params(sql, params)
    };
    let returns_rows = statement_returns_rows(&formatted_sql);
    let body = if returns_rows {
        let statement =
            formatted_sql.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
        format!("{statement}\nFORMAT TabSeparatedWithNamesAndTypes")
    } else {
        formatted_sql
    };

    let query_id = Self::new_query_id();
    state.active_queries.lock().insert(query_id.clone());
    let guard = QueryGuard {
        active: Arc::clone(&state.active_queries),
        qid: query_id.clone(),
    };

    let mut url = state.base_url.clone();
    {
        let mut pairs = url.query_pairs_mut();
        pairs.append_pair("database", &state.database);
        pairs.append_pair("query_id", &query_id);
    }
    let response = state
        .build_request(&url, body)
        .send()
        .await
        .map_err(|e| Error::query_with("HTTP request failed", e))?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response
            .text()
            .await
            .unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
        return Err(Error::Query(format!(
            "ClickHouse returned {status}: {body}"
        )));
    }

    if !returns_rows {
        // Dropping the sender up front makes the stream report end-of-rows immediately.
        let (tx, rx) = mpsc::channel::<Result<CoreRow>>(1);
        drop(tx);
        return Ok(Box::new(ClickhouseRowStream {
            columns: Vec::new(),
            rx,
            task: None,
        }));
    }

    let (header_tx, header_rx) = tokio::sync::oneshot::channel::<Result<Vec<ColumnHeader>>>();
    let (row_tx, row_rx) = mpsc::channel::<Result<CoreRow>>(64);
    // The guard moves into the task so the id stays registered while rows are decoded.
    let task = tokio::spawn(async move {
        stream_tsv_chunks(response.bytes_stream(), header_tx, row_tx).await;
        drop(guard);
    });
    let columns = match header_rx.await {
        Ok(Ok(cols)) => cols,
        Ok(Err(e)) => {
            task.abort();
            return Err(e);
        }
        Err(_) => {
            task.abort();
            return Err(Error::Other("clickhouse stream cancelled".into()));
        }
    };
    Ok(Box::new(ClickhouseRowStream {
        columns,
        rx: row_rx,
        task: Some(task),
    }))
}
/// Always fails: this driver does not support transactions.
async fn begin(&mut self) -> Result<()> {
Err(Error::unsupported("transactions (ClickHouse)"))
}
/// Always fails: this driver does not support transactions; `_isolation` is ignored.
async fn begin_with(&mut self, _isolation: IsolationLevel) -> Result<()> {
Err(Error::unsupported("transactions (ClickHouse)"))
}
/// Always fails: this driver does not support transactions.
async fn commit(&mut self) -> Result<()> {
Err(Error::unsupported("transactions (ClickHouse)"))
}
/// Always fails: this driver does not support transactions.
async fn rollback(&mut self) -> Result<()> {
Err(Error::unsupported("transactions (ClickHouse)"))
}
/// Lists databases via `SHOW DATABASES`, exposing each one as a `Schema`.
///
/// `system` and both spellings of `information_schema` are filtered out. Rows
/// whose first cell is not a string are skipped.
async fn list_schemas(&mut self) -> Result<Vec<Schema>> {
    const HIDDEN: [&str; 3] = ["system", "INFORMATION_SCHEMA", "information_schema"];
    let result = self.query_tsv("SHOW DATABASES", &[]).await?;
    let schemas = result
        .rows
        .into_iter()
        .filter_map(|row| match row.0.into_iter().next() {
            Some(Value::String(name)) => Some(name),
            _ => None,
        })
        .filter(|name| !HIDDEN.contains(&name.as_str()))
        .map(|name| Schema { name })
        .collect();
    Ok(schemas)
}
/// Lists the tables of database `schema` from `system.tables`, ordered by name.
///
/// The engine name, compared case-insensitively, decides the kind: `view` is a
/// view, `materializedview` is a materialized view, anything else (including a
/// missing engine cell) is a plain table. Rows without a string name are skipped.
async fn list_tables(&mut self, schema: &str) -> Result<Vec<Table>> {
    let sql = format!(
        "SELECT name, engine FROM system.tables WHERE database = '{}' ORDER BY name",
        escape_sql_string(schema)
    );
    let result = self.query_tsv(&sql, &[]).await?;
    let tables = result
        .rows
        .into_iter()
        .filter_map(|row| {
            let mut cells = row.0.into_iter();
            let Some(Value::String(name)) = cells.next() else {
                return None;
            };
            let engine = match cells.next() {
                Some(Value::String(raw)) => raw.to_ascii_lowercase(),
                _ => String::new(),
            };
            let kind = match engine.as_str() {
                "view" => TableKind::View,
                "materializedview" => TableKind::MaterializedView,
                _ => TableKind::Table,
            };
            Some(Table {
                schema: schema.to_owned(),
                name,
                kind,
            })
        })
        .collect();
    Ok(tables)
}
/// Lists every visible schema together with its tables, using one query over
/// `system.tables` plus one call to `list_schemas`.
///
/// Schemas returned by `list_schemas` come first, in that order, with an empty
/// table list when they own no tables. Databases that only showed up in
/// `system.tables` are appended afterwards in sorted order.
async fn list_all_tables(&mut self) -> Result<Vec<(Schema, Vec<Table>)>> {
const SQL: &str = "
SELECT database, name, engine
FROM system.tables
WHERE database NOT IN ('system', 'INFORMATION_SCHEMA', 'information_schema')
ORDER BY database, name";
let result = self.query_tsv(SQL, &[]).await?;
// Group tables by database name; the BTreeMap keeps leftover databases sorted.
let mut map: std::collections::BTreeMap<String, Vec<Table>> =
std::collections::BTreeMap::new();
for row in result.rows {
let mut iter = row.0.into_iter();
// Rows lacking a string database or table name are skipped.
let schema = match iter.next() {
Some(Value::String(s)) => s,
_ => continue,
};
let name = match iter.next() {
Some(Value::String(s)) => s,
_ => continue,
};
let engine = match iter.next() {
Some(Value::String(s)) => s.to_ascii_lowercase(),
_ => String::new(),
};
// Engine name decides the kind; anything unrecognised is a plain table.
let kind = if engine == "view" {
TableKind::View
} else if engine == "materializedview" {
TableKind::MaterializedView
} else {
TableKind::Table
};
map.entry(schema.clone())
.or_default()
.push(Table { schema, name, kind });
}
let schemas = self.list_schemas().await?;
let mut out = Vec::with_capacity(schemas.len());
for schema in schemas {
// `remove` hands over the table list and leaves only unmatched databases behind.
let tables = map.remove(&schema.name).unwrap_or_default();
out.push((schema, tables));
}
// Databases seen in `system.tables` but not reported by `list_schemas`.
for (name, tables) in map {
out.push((Schema { name }, tables));
}
Ok(out)
}
/// Describes table `schema`.`name`: its kind and its columns.
///
/// Columns come from `DESCRIBE TABLE`; the first, second and fourth cells of
/// each row are used as column name, data type and default expression. A
/// column is nullable when its type is `Nullable(...)`, optionally wrapped in
/// `LowCardinality(...)`. Primary-key membership comes from `system.tables`;
/// if that lookup fails a warning is logged and no column is marked.
/// Indexes, foreign keys and unique constraints are always returned empty.
///
/// # Errors
/// `Error::Schema` when `DESCRIBE TABLE` returns no rows; otherwise any error
/// from the describe query or the table-kind lookup.
async fn describe_table(&mut self, schema: &str, name: &str) -> Result<TableSchema> {
    let sql = format!(
        "DESCRIBE TABLE {}.{}",
        quote_ident(schema),
        quote_ident(name)
    );
    let result = self.query_tsv(&sql, &[]).await?;
    if result.rows.is_empty() {
        return Err(Error::Schema(format!("table {schema}.{name} not found")));
    }
    let kind = self.lookup_table_kind(schema, name).await?;
    // Primary-key information is best-effort: a failure must not hide the columns.
    let pk_set: std::collections::HashSet<String> =
        match self.lookup_primary_key(schema, name).await {
            Ok(names) => names.into_iter().collect(),
            Err(error) => {
                tracing::warn!(
                    target: "narwhal::clickhouse",
                    schema, table = name, error = %error,
                    "primary-key lookup failed; continuing without"
                );
                std::collections::HashSet::new()
            }
        };
    let columns: Vec<Column> = result
        .rows
        .into_iter()
        .filter_map(|row| {
            let mut cells = row.0.into_iter();
            let Some(Value::String(col_name)) = cells.next() else {
                return None;
            };
            let data_type = match cells.next() {
                Some(Value::String(s)) => s,
                _ => String::new(),
            };
            // Third cell is skipped; only the default expression after it is kept.
            let _ = cells.next();
            let default = match cells.next() {
                Some(Value::String(s)) if !s.is_empty() => Some(s),
                _ => None,
            };
            // `LowCardinality(Nullable(T))` is nullable too, so look through the wrapper.
            let trimmed = data_type.trim();
            let nullable = trimmed
                .strip_prefix("LowCardinality(")
                .unwrap_or(trimmed)
                .starts_with("Nullable(");
            let primary_key = pk_set.contains(&col_name);
            Some(Column {
                name: col_name,
                data_type,
                nullable,
                primary_key,
                default,
            })
        })
        .collect();
    Ok(TableSchema {
        table: Table {
            schema: schema.to_owned(),
            name: name.to_owned(),
            kind,
        },
        columns,
        indexes: Vec::new(),
        foreign_keys: Vec::new(),
        unique_constraints: Vec::new(),
    })
}
/// Returns the `SHOW CREATE TABLE` statement text for `schema`.`name`.
///
/// The DDL is taken from the last cell of the first result row.
///
/// # Errors
/// `Error::Schema` when there is no row, the row is empty, or the cell is not
/// a string; otherwise any error from the query.
async fn fetch_ddl(&mut self, schema: &str, name: &str) -> Result<String> {
    let sql = format!(
        "SHOW CREATE TABLE {}.{}",
        quote_ident(schema),
        quote_ident(name)
    );
    let result = self.query_tsv(&sql, &[]).await?;
    let last_cell = result
        .rows
        .into_iter()
        .next()
        .and_then(|row| row.0.into_iter().next_back());
    match last_cell {
        Some(Value::String(ddl)) => Ok(ddl),
        _ => Err(Error::Schema(format!(
            "DDL not found for table {schema}.{name}"
        ))),
    }
}
/// Checks that the server answers by running `SELECT 1` and discarding the body.
async fn ping(&mut self) -> Result<()> {
self.http_query("SELECT 1", None).await.map(|_| ())
}
/// Sends `SET readonly = 2` (read-only) or `SET readonly = 0` (read-write).
///
/// NOTE(review): requests in this file carry no session identifier, so it is
/// unclear whether this `SET` affects later requests — confirm against the
/// ClickHouse HTTP interface documentation.
async fn set_read_only(&mut self, read_only: bool) -> Result<()> {
let sql = if read_only {
"SET readonly = 2"
} else {
"SET readonly = 0"
};
self.http_query(sql, None).await.map(|_| ())
}
/// Returns a cancel handle that shares this connection's state; never `None`.
fn cancel_handle(&self) -> Option<Box<dyn narwhal_core::DynCancelHandle>> {
Some(Box::new(ClickhouseCancel {
state: Arc::clone(&self.inner),
}))
}
/// Same capability set as the driver reports.
fn capabilities(&self) -> Capabilities {
ClickhouseDriver::capabilities()
}
/// Nothing to tear down explicitly; consuming `self` drops the connection's state handle.
async fn close(self: Box<Self>) -> Result<()> {
Ok(())
}
}
impl ClickhouseConnection {
/// Determines whether `schema`.`name` is a table, view or materialized view.
///
/// Reads the engine from `system.tables`. The engine name is compared
/// case-insensitively; a missing row, a non-string cell or any other engine
/// all map to `TableKind::Table`.
async fn lookup_table_kind(&self, schema: &str, name: &str) -> Result<TableKind> {
    let sql = format!(
        "SELECT engine FROM system.tables WHERE database = '{}' AND name = '{}'",
        escape_sql_string(schema),
        escape_sql_string(name)
    );
    let result = self.query_tsv(&sql, &[]).await?;
    let first_cell = result
        .rows
        .into_iter()
        .next()
        .and_then(|row| row.0.into_iter().next());
    let kind = match first_cell {
        Some(Value::String(engine)) => match engine.to_ascii_lowercase().as_str() {
            "view" => TableKind::View,
            "materializedview" => TableKind::MaterializedView,
            _ => TableKind::Table,
        },
        _ => TableKind::Table,
    };
    Ok(kind)
}
/// Reads the `primary_key` column of `system.tables` for `schema`.`name` and
/// splits it on commas into trimmed parts.
///
/// Returns an empty list when there is no row, the cell is not a string, or
/// the string is empty.
async fn lookup_primary_key(&mut self, schema: &str, name: &str) -> Result<Vec<String>> {
    let sql = format!(
        "SELECT primary_key FROM system.tables WHERE database = '{}' AND name = '{}'",
        escape_sql_string(schema),
        escape_sql_string(name)
    );
    let result = self.query_tsv(&sql, &[]).await?;
    let first_cell = result
        .rows
        .into_iter()
        .next()
        .and_then(|row| row.0.into_iter().next());
    let parts = match first_cell {
        Some(Value::String(pk)) if !pk.is_empty() => {
            pk.split(',').map(|part| part.trim().to_owned()).collect()
        }
        _ => Vec::new(),
    };
    Ok(parts)
}
}
struct QueryGuard {
active: Arc<Mutex<HashSet<String>>>,
qid: String,
}
impl Drop for QueryGuard {
fn drop(&mut self) {
let active = self.active.clone();
let qid = std::mem::take(&mut self.qid);
tokio::spawn(async move {
active.lock().remove(&qid);
});
}
}
/// Cancel handle that shares the connection's state, including the active-query set.
struct ClickhouseCancel {
    state: Arc<SharedState>,
}

impl CancelHandle for ClickhouseCancel {
    /// Issues one `KILL QUERY` covering every currently registered query id.
    ///
    /// Best-effort: a failed request or a non-success status is only logged
    /// and the call still returns `Ok(())`. The ids are read, not removed.
    /// With no active queries nothing is sent.
    async fn cancel(&self) -> Result<()> {
        let state = &self.state;
        // Snapshot under the lock; the lock is released at the end of this statement.
        let quoted_ids: Vec<String> = state
            .active_queries
            .lock()
            .iter()
            .map(|id| format!("'{id}'"))
            .collect();
        if quoted_ids.is_empty() {
            return Ok(());
        }
        let kill_sql = format!(
            "KILL QUERY WHERE query_id IN ({})",
            quoted_ids.join(", ")
        );
        debug!(target: "narwhal::clickhouse", %kill_sql, "cancelling queries");
        let mut url = state.base_url.clone();
        url.query_pairs_mut()
            .append_pair("database", &state.database);
        match state.build_request(&url, kill_sql).send().await {
            Ok(response) if !response.status().is_success() => {
                debug!(
                    target: "narwhal::clickhouse",
                    status = %response.status(),
                    "KILL QUERY returned non-success (best-effort, ignoring)"
                );
            }
            Ok(_) => {}
            Err(e) => {
                debug!(
                    target: "narwhal::clickhouse",
                    error = %e,
                    "KILL QUERY request failed (best-effort, ignoring)"
                );
            }
        }
        Ok(())
    }
}
/// Row stream fed by a background decoding task through a bounded channel.
struct ClickhouseRowStream {
    /// Column headers known when the stream was created.
    columns: Vec<ColumnHeader>,
    /// Receiving end of the row channel; a closed channel means end of rows.
    rx: mpsc::Receiver<Result<CoreRow>>,
    /// Decoding task, if one was spawned; aborted on close or drop.
    task: Option<tokio::task::JoinHandle<()>>,
}

impl ClickhouseRowStream {
    /// Aborts the decoding task if it has not been taken already.
    fn abort_task(&mut self) {
        if let Some(task) = self.task.take() {
            task.abort();
        }
    }
}

impl RowStream for ClickhouseRowStream {
    fn columns(&self) -> &[ColumnHeader] {
        self.columns.as_slice()
    }

    /// Yields the next row, `Ok(None)` once the channel is closed, or the
    /// error forwarded by the decoding task.
    async fn next_row(&mut self) -> Result<Option<CoreRow>> {
        self.rx.recv().await.transpose()
    }

    /// Stops the decoding task; always succeeds.
    async fn close(mut self: Box<Self>) -> Result<()> {
        self.abort_task();
        Ok(())
    }
}

impl Drop for ClickhouseRowStream {
    fn drop(&mut self) {
        self.abort_task();
    }
}
/// Decodes a chunked `TabSeparatedWithNamesAndTypes` byte stream.
///
/// The first two newline-terminated lines are the column names and the column
/// types. They are paired into `ColumnHeader`s and sent once through
/// `header_tx`. Every following non-empty line is split on tabs, each field is
/// converted with `parse_tsv_value` using the type at the same position, and
/// the row is sent through `row_tx`.
///
/// The function returns when:
/// * the stream fails or ends before both header lines arrived (an error goes
///   to `header_tx`);
/// * either receiver has been dropped;
/// * the stream fails while reading rows (an error goes to `row_tx`);
/// * the stream ends. Bytes left over without a trailing newline, other than a
///   lone `\r`, are reported as a truncation error on `row_tx`.
async fn stream_tsv_chunks<S>(
stream: S,
header_tx: tokio::sync::oneshot::Sender<Result<Vec<ColumnHeader>>>,
row_tx: mpsc::Sender<Result<CoreRow>>,
) where
S: futures_util::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>> + Unpin,
{
use futures_util::StreamExt;
let mut stream = stream;
// Bytes received but not yet consumed as complete lines.
let mut buf: Vec<u8> = Vec::new();
let mut header_lines: Vec<String> = Vec::new();
// Phase 1: accumulate chunks until both header lines are available.
while header_lines.len() < 2 {
match stream.next().await {
Some(Ok(chunk)) => {
buf.extend_from_slice(&chunk);
while header_lines.len() < 2 {
let Some(pos) = buf.iter().position(|&b| b == b'\n') else {
break;
};
// Take the line including its newline out of the buffer.
let line_bytes: Vec<u8> = buf.drain(..=pos).collect();
if std::str::from_utf8(&line_bytes).is_err() {
tracing::warn!(
target: "narwhal::clickhouse",
"header line contained invalid UTF-8; lossy conversion applied"
);
}
let line = String::from_utf8_lossy(&line_bytes);
let line = line.trim_end_matches('\n').trim_end_matches('\r');
header_lines.push(line.to_owned());
}
}
Some(Err(e)) => {
let _ = header_tx.send(Err(Error::query_with("stream error", e)));
return;
}
None => {
let _ = header_tx.send(Err(Error::Query(
"clickhouse stream ended before headers were complete".into(),
)));
return;
}
}
}
let header_line = header_lines[0].as_str();
let type_line = header_lines[1].as_str();
let headers: Vec<String> = header_line.split('\t').map(String::from).collect();
let type_strings: Vec<String> = type_line.split('\t').map(String::from).collect();
let column_headers: Vec<ColumnHeader> = headers
.iter()
.zip(type_strings.iter())
.map(|(name, data_type)| ColumnHeader {
name: name.clone(),
data_type: data_type.clone(),
})
.collect();
// A dropped header receiver means nobody is waiting for rows either.
if header_tx.send(Ok(column_headers)).is_err() {
return;
}
// Phase 2: emit every complete line in the buffer, then pull the next chunk.
loop {
while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
let mut line_bytes: Vec<u8> = buf.drain(..=pos).collect();
// Strip the line terminator (`\n` or `\r\n`).
if line_bytes.last() == Some(&b'\n') {
line_bytes.pop();
}
if line_bytes.last() == Some(&b'\r') {
line_bytes.pop();
}
if line_bytes.is_empty() {
continue;
}
// Fields stay as raw bytes so non-UTF-8 values reach `parse_tsv_value` unchanged.
let fields: Vec<&[u8]> = line_bytes.split(|&b| b == b'\t').collect();
let mut row = Vec::with_capacity(headers.len());
for (i, field) in fields.iter().enumerate() {
// Fields beyond the declared columns are parsed with the "String" type.
let ch_type = type_strings.get(i).map_or("String", String::as_str);
row.push(parse_tsv_value(field, ch_type));
}
// Short rows are padded with nulls up to the header count.
while row.len() < headers.len() {
row.push(Value::Null);
}
if row_tx.send(Ok(CoreRow(row))).await.is_err() {
return;
}
}
match stream.next().await {
Some(Ok(chunk)) => {
buf.extend_from_slice(&chunk);
}
Some(Err(e)) => {
let _ = row_tx.send(Err(Error::query_with("stream error", e))).await;
return;
}
None => {
// End of stream: leftover bytes mean the last row never got its newline.
if !buf.is_empty() {
if buf.last() == Some(&b'\r') {
buf.pop();
}
if !buf.is_empty() {
let _ = row_tx
.send(Err(Error::Query(
"clickhouse stream truncated mid-row (no trailing newline)".into(),
)))
.await;
}
}
return;
}
}
}
}
/// Tests for `stream_tsv_chunks`, driven by in-memory byte streams.
#[cfg(test)]
mod stream_tests {
use super::*;
use bytes::Bytes;
use futures_util::stream;
/// A payload split at arbitrary offsets still yields both headers and both rows.
#[tokio::test]
async fn chunked_tsv_decodes_rows() {
let payload: &[u8] = b"id\tname\nUInt32\tString\n1\talice\n2\tbob\n";
// Chunk boundaries deliberately fall in the middle of lines.
let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> = vec![
Ok(Bytes::copy_from_slice(&payload[..8])),
Ok(Bytes::copy_from_slice(&payload[8..20])),
Ok(Bytes::copy_from_slice(&payload[20..])),
];
let byte_stream = stream::iter(chunks);
let (header_tx, header_rx) = tokio::sync::oneshot::channel::<Result<Vec<ColumnHeader>>>();
let (row_tx, mut row_rx) = mpsc::channel::<Result<CoreRow>>(64);
stream_tsv_chunks(byte_stream, header_tx, row_tx).await;
let columns = header_rx.await.expect("header rx").expect("headers");
assert_eq!(columns.len(), 2);
assert_eq!(columns[0].name, "id");
assert_eq!(columns[0].data_type, "UInt32");
assert_eq!(columns[1].name, "name");
assert_eq!(columns[1].data_type, "String");
let mut rows = Vec::new();
while let Some(result) = row_rx.recv().await {
let row = result.expect("row");
rows.push(row);
}
assert_eq!(rows.len(), 2);
assert!(matches!(rows[0].0.first(), Some(Value::Int(1))));
assert!(matches!(rows[0].0.get(1), Some(Value::String(_))));
assert!(matches!(rows[1].0.first(), Some(Value::Int(2))));
assert!(matches!(rows[1].0.get(1), Some(Value::String(_))));
}
/// Non-UTF-8 bytes in a `String` column arrive unchanged as `Value::Bytes`.
#[tokio::test]
async fn chunked_tsv_preserves_binary_string() {
let mut payload: Vec<u8> = b"col\nString\n".to_vec();
payload.extend_from_slice(&[0xFF, 0xFE, 0x00, 0x01]);
payload.push(b'\n');
let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
vec![Ok(Bytes::copy_from_slice(&payload))];
let byte_stream = stream::iter(chunks);
let (header_tx, header_rx) = tokio::sync::oneshot::channel::<Result<Vec<ColumnHeader>>>();
let (row_tx, mut row_rx) = mpsc::channel::<Result<CoreRow>>(64);
stream_tsv_chunks(byte_stream, header_tx, row_tx).await;
let columns = header_rx.await.expect("header rx").expect("headers");
assert_eq!(columns.len(), 1);
assert_eq!(columns[0].name, "col");
assert_eq!(columns[0].data_type, "String");
let row = row_rx.recv().await.expect("row rx").expect("row");
match row.0.first() {
Some(Value::Bytes(b)) => assert_eq!(b, &vec![0xFF, 0xFE, 0x00, 0x01]),
other => panic!("expected Value::Bytes, got {other:?}"),
}
}
/// A final row without a trailing newline is reported as a truncation error.
#[tokio::test]
async fn chunked_tsv_truncated_mid_row_errors() {
let payload: &[u8] = b"id\tname\nUInt32\tString\n1\tali";
let chunks: Vec<std::result::Result<Bytes, reqwest::Error>> =
vec![Ok(Bytes::copy_from_slice(payload))];
let byte_stream = stream::iter(chunks);
let (header_tx, header_rx) = tokio::sync::oneshot::channel::<Result<Vec<ColumnHeader>>>();
let (row_tx, mut row_rx) = mpsc::channel::<Result<CoreRow>>(64);
stream_tsv_chunks(byte_stream, header_tx, row_tx).await;
let _columns = header_rx.await.expect("header rx").expect("headers");
match row_rx.recv().await {
Some(Err(Error::Query(msg))) => {
assert!(
msg.contains("truncated"),
"expected truncation error, got: {msg}"
);
}
other => panic!("expected Err(Error::Query(truncated)), got {other:?}"),
}
}
}
/// Tests for active-query tracking, `QueryGuard` and `ClickhouseCancel`.
#[cfg(test)]
mod cancel_tests {
use super::*;
/// Plain insert/remove round trip on the shared id set.
#[tokio::test]
async fn tracks_active_query_id() {
let active: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
active.lock().insert("test-qid-1".to_owned());
assert!(active.lock().contains("test-qid-1"));
active.lock().remove("test-qid-1");
assert!(!active.lock().contains("test-qid-1"));
assert!(active.lock().is_empty());
}
/// `cancel` only reads the id set; the ids must still be registered afterwards.
#[tokio::test]
async fn cancel_reads_cloned_not_drained() {
let active: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
active.lock().insert("qid-1".to_owned());
active.lock().insert("qid-2".to_owned());
let client = reqwest::Client::new();
// Port 1 on loopback: the KILL QUERY request is expected to fail, which `cancel` tolerates.
let base_url = Url::parse("http://127.0.0.1:1/").expect("url");
let state = Arc::new(SharedState {
client,
base_url,
user: "default".to_owned(),
password: String::new(),
database: "default".to_owned(),
active_queries: active.clone(),
});
let cancel = ClickhouseCancel { state };
let _ = cancel.cancel().await;
let remaining = active.lock();
assert!(
remaining.contains("qid-1"),
"qid-1 should still be present after cancel"
);
assert!(
remaining.contains("qid-2"),
"qid-2 should still be present after cancel"
);
}
/// Dropping a `QueryGuard` removes its id from the set.
#[tokio::test]
async fn stream_error_path_clears_active_query() {
let active: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
let qid = "test-qid-guard".to_owned();
active.lock().insert(qid.clone());
assert!(active.lock().contains(&qid));
{
let guard_active = active.clone();
let guard_qid = qid.clone();
let _guard = QueryGuard {
active: guard_active,
qid: guard_qid,
};
}
// Give any task spawned by the guard's drop a chance to run before asserting.
tokio::task::yield_now().await;
assert!(
!active.lock().contains(&qid),
"query ID should be removed after guard drops"
);
}
/// With nothing registered, `cancel` returns `Ok` and leaves the set empty.
#[tokio::test]
async fn cancel_with_no_active_queries_is_noop() {
let active: Arc<Mutex<HashSet<String>>> = Arc::new(Mutex::new(HashSet::new()));
let client = reqwest::Client::new();
let base_url = Url::parse("http://127.0.0.1:1/").expect("url");
let state = Arc::new(SharedState {
client,
base_url,
user: "default".to_owned(),
password: String::new(),
database: "default".to_owned(),
active_queries: active.clone(),
});
let cancel = ClickhouseCancel { state };
let result = cancel.cancel().await;
assert!(result.is_ok());
assert!(active.lock().is_empty());
}
}
/// Tests for `build_base_url`: IPv6 bracketing, default port, scheme selection.
#[cfg(test)]
mod build_base_url_tests {
    use super::*;
    use narwhal_core::SslMode;

    /// Builds connection parameters with only host, port and SSL mode set.
    fn params(host: &str, port: Option<u16>, ssl_mode: SslMode) -> ConnectionParams {
        ConnectionParams::with(|p| {
            p.host = Some(host.to_owned());
            p.port = port;
            p.ssl_mode = ssl_mode;
        })
    }

    #[test]
    fn ipv6_loopback() {
        let url = build_base_url(&params("::1", None, SslMode::Disable)).unwrap();
        assert_eq!(url.as_str(), "http://[::1]:8123/");
    }

    #[test]
    fn ipv6_full_address() {
        let url = build_base_url(&params("2001:db8::1", Some(8443), SslMode::Prefer)).unwrap();
        assert_eq!(url.as_str(), "https://[2001:db8::1]:8443/");
    }

    #[test]
    fn pre_bracketed_ipv6() {
        let url = build_base_url(&params("[::1]", None, SslMode::Disable)).unwrap();
        assert_eq!(url.as_str(), "http://[::1]:8123/");
    }

    #[test]
    fn ipv4_unchanged() {
        let url = build_base_url(&params("192.168.1.1", Some(8123), SslMode::Disable)).unwrap();
        assert_eq!(url.as_str(), "http://192.168.1.1:8123/");
    }

    #[test]
    fn dns_hostname_unchanged() {
        let url = build_base_url(&params("clickhouse.prod", None, SslMode::Require)).unwrap();
        assert_eq!(url.as_str(), "https://clickhouse.prod:8123/");
    }

    #[test]
    fn missing_host_returns_config_error() {
        let params = ConnectionParams::default();
        let err = build_base_url(&params).unwrap_err();
        assert!(matches!(err, Error::Config(msg) if msg.contains("host is required")));
    }
}
/// Tests for client-certificate (mTLS) configuration checks in `connect`.
#[cfg(test)]
mod mtls_tests {
use super::*;
use narwhal_core::SslMode;
/// Builds a localhost configuration with the given optional client cert/key paths.
fn make_config(ssl_cert: Option<&str>, ssl_key: Option<&str>) -> ConnectionConfig {
let params = ConnectionParams::with(|p| {
p.host = Some("localhost".to_owned());
p.port = Some(8123);
p.ssl_mode = SslMode::Prefer;
p.ssl_cert = ssl_cert.map(std::path::PathBuf::from);
p.ssl_key = ssl_key.map(std::path::PathBuf::from);
});
ConnectionConfig {
id: uuid::Uuid::new_v4(),
name: "test".into(),
driver: "clickhouse".into(),
params,
}
}
/// Setting only one of `ssl_cert` / `ssl_key` must fail with a config error.
/// The check runs before the files are read, so the paths need not exist.
#[tokio::test]
async fn mtls_half_config_rejected() {
let driver = ClickhouseDriver::new();
// Certificate without key.
let config = make_config(Some("/tmp/cert.pem"), None);
let result = driver.connect(&config, None).await;
assert!(result.is_err(), "expected error when only ssl_cert is set");
let err = result.err().unwrap();
assert!(
matches!(err, Error::Config(ref msg) if msg.contains("ssl_cert and ssl_key must both be provided or both omitted")),
"expected Config error about ssl_cert/ssl_key, got: {err:?}"
);
// Key without certificate.
let config = make_config(None, Some("/tmp/key.pem"));
let result = driver.connect(&config, None).await;
assert!(result.is_err(), "expected error when only ssl_key is set");
let err = result.err().unwrap();
assert!(
matches!(err, Error::Config(ref msg) if msg.contains("ssl_cert and ssl_key must both be provided or both omitted")),
"expected Config error about ssl_cert/ssl_key, got: {err:?}"
);
}
}