use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use duckdb::arrow::array::RecordBatch as DuckRecordBatch;
use tracing::debug;
use crate::error::DuckDbError;
/// Number of reader connections created by the constructors that do not take
/// an explicit pool size (`DuckDbEngine::in_memory` and `DuckDbEngine::persistent`).
const DEFAULT_READ_POOL_SIZE: usize = 4;
/// DuckDB-backed engine that keeps one writer connection plus a pool of
/// reader connections cloned from it.
///
/// Each connection sits behind its own `std::sync::Mutex`; the engine's async
/// methods lock a connection inside `tokio::task::spawn_blocking`.
pub struct DuckDbEngine {
/// Connection used by `execute`, `load_arrow`, `create_table`, `add_column`
/// and `drop_column`.
write_conn: Arc<std::sync::Mutex<duckdb::Connection>>,
/// Connections obtained with `try_clone` from the writer; used by `query`
/// and `table_exists`.
read_pool: Vec<Arc<std::sync::Mutex<duckdb::Connection>>>,
/// Ever-increasing counter; `next_reader` reduces it modulo the pool length
/// to choose which reader to hand out.
read_idx: AtomicUsize,
}
unsafe impl Send for DuckDbEngine {}
unsafe impl Sync for DuckDbEngine {}
impl DuckDbEngine {
    /// Opens an in-memory database with `DEFAULT_READ_POOL_SIZE` readers.
    pub fn in_memory() -> Result<Self, DuckDbError> {
        Self::in_memory_with_pool(DEFAULT_READ_POOL_SIZE)
    }

    /// Opens an in-memory database with `read_pool_size` readers.
    ///
    /// A requested size of `0` is raised to `1`.
    pub fn in_memory_with_pool(read_pool_size: usize) -> Result<Self, DuckDbError> {
        let writer = duckdb::Connection::open_in_memory()?;
        let pool_size = read_pool_size.max(1);
        Self::from_connection(writer, pool_size)
    }

    /// Opens the database at `path` with `DEFAULT_READ_POOL_SIZE` readers.
    pub fn persistent(path: &str) -> Result<Self, DuckDbError> {
        Self::persistent_with_pool(path, DEFAULT_READ_POOL_SIZE)
    }

    /// Opens the database at `path` with `read_pool_size` readers.
    ///
    /// A requested size of `0` is raised to `1`.
    pub fn persistent_with_pool(path: &str, read_pool_size: usize) -> Result<Self, DuckDbError> {
        let writer = duckdb::Connection::open(path)?;
        let pool_size = read_pool_size.max(1);
        Self::from_connection(writer, pool_size)
    }

    /// Builds the engine from an already opened writer connection.
    ///
    /// Clones `write_conn` `read_pool_size` times to populate the reader
    /// pool; the first failing clone aborts construction with its error.
    fn from_connection(
        write_conn: duckdb::Connection,
        read_pool_size: usize,
    ) -> Result<Self, DuckDbError> {
        let read_pool = (0..read_pool_size)
            .map(|_| {
                write_conn
                    .try_clone()
                    .map(|reader| Arc::new(std::sync::Mutex::new(reader)))
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self {
            write_conn: Arc::new(std::sync::Mutex::new(write_conn)),
            read_pool,
            read_idx: AtomicUsize::new(0),
        })
    }

    /// Hands out the next reader connection in round-robin order.
    fn next_reader(&self) -> Arc<std::sync::Mutex<duckdb::Connection>> {
        let ticket = self.read_idx.fetch_add(1, Ordering::Relaxed);
        let slot = ticket % self.read_pool.len();
        Arc::clone(&self.read_pool[slot])
    }

    /// Number of reader connections held by this engine.
    pub fn read_pool_size(&self) -> usize {
        self.read_pool.len()
    }
}
/// Maps an Arrow [`DataType`] to the DuckDB SQL type name used in generated DDL.
///
/// Unsigned Arrow integers map to DuckDB's unsigned integer types so that the
/// full value range fits (a `UInt8` column can hold 255, which a signed
/// `TINYINT` cannot). Any Arrow type without an explicit arm falls back to
/// `VARCHAR`.
fn arrow_type_to_duckdb_sql(dt: &DataType) -> &'static str {
    match dt {
        DataType::Boolean => "BOOLEAN",
        // Signed integers.
        DataType::Int8 => "TINYINT",
        DataType::Int16 => "SMALLINT",
        DataType::Int32 => "INTEGER",
        DataType::Int64 => "BIGINT",
        // Unsigned integers: same width, unsigned DuckDB counterpart.
        DataType::UInt8 => "UTINYINT",
        DataType::UInt16 => "USMALLINT",
        DataType::UInt32 => "UINTEGER",
        DataType::UInt64 => "UBIGINT",
        DataType::Float16 | DataType::Float32 => "FLOAT",
        DataType::Float64 => "DOUBLE",
        DataType::Utf8 | DataType::LargeUtf8 => "VARCHAR",
        DataType::Binary | DataType::LargeBinary => "BLOB",
        DataType::Date32 | DataType::Date64 => "DATE",
        // NOTE(review): time unit and timezone are ignored here; confirm that
        // plain TIMESTAMP is what callers want for tz-aware or non-microsecond
        // Arrow timestamps.
        DataType::Timestamp(_, _) => "TIMESTAMP",
        _ => "VARCHAR",
    }
}
/// Re-wraps a batch returned by DuckDB as an `arrow` [`RecordBatch`].
///
/// The schema is rebuilt field by field from each field's name, data type and
/// nullability; the column arrays are reused as they are. A failure from
/// `RecordBatch::try_new` is reported as `DuckDbError::Arrow`.
fn convert_duck_batch(b: DuckRecordBatch) -> Result<RecordBatch, DuckDbError> {
    let source_schema = b.schema();
    let fields: Vec<arrow::datatypes::Field> = source_schema
        .fields()
        .iter()
        .map(|field| {
            arrow::datatypes::Field::new(
                field.name(),
                field.data_type().clone(),
                field.is_nullable(),
            )
        })
        .collect();
    let schema = Arc::new(arrow::datatypes::Schema::new(fields));
    let columns = b.columns().to_vec();
    RecordBatch::try_new(schema, columns).map_err(DuckDbError::Arrow)
}
/// Re-wraps an `arrow` [`RecordBatch`] as the batch type expected by DuckDB.
///
/// The schema is rebuilt field by field from each field's name, data type and
/// nullability; the column arrays are reused as they are. A failure from
/// `DuckRecordBatch::try_new` is reported as `DuckDbError::Arrow`.
fn convert_to_duck_batch(b: &RecordBatch) -> Result<DuckRecordBatch, DuckDbError> {
    let source_schema = b.schema();
    let fields: Vec<duckdb::arrow::datatypes::Field> = source_schema
        .fields()
        .iter()
        .map(|field| {
            duckdb::arrow::datatypes::Field::new(
                field.name(),
                field.data_type().clone(),
                field.is_nullable(),
            )
        })
        .collect();
    let duck_schema = Arc::new(duckdb::arrow::datatypes::Schema::new(fields));
    let columns = b.columns().to_vec();
    DuckRecordBatch::try_new(duck_schema, columns).map_err(DuckDbError::Arrow)
}
impl rhei_core::OlapEngine for DuckDbEngine {
type Error = DuckDbError;
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
debug!(sql, "DuckDB query (reader)");
let conn = self.next_reader();
let sql = sql.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
let mut stmt = conn.prepare(&sql).map_err(DuckDbError::DuckDb)?;
let arrow_result = stmt.query_arrow([]).map_err(DuckDbError::DuckDb)?;
let duck_batches: Vec<DuckRecordBatch> = arrow_result.collect();
duck_batches
.into_iter()
.map(convert_duck_batch)
.collect::<Result<Vec<_>, _>>()
})
.await
.map_err(DuckDbError::from_join)?
}
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
debug!(sql, "DuckDB execute (writer)");
let conn = Arc::clone(&self.write_conn);
let sql = sql.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
let rows = conn.execute(&sql, []).map_err(DuckDbError::DuckDb)?;
Ok(rows as u64)
})
.await
.map_err(DuckDbError::from_join)?
}
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
if batches.is_empty() {
return Ok(0);
}
debug!(
table,
batch_count = batches.len(),
"DuckDB load_arrow (writer, appender)"
);
rhei_core::validate_identifier(table)?;
let conn = Arc::clone(&self.write_conn);
let table = table.to_string();
let batches = batches.to_vec();
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
let mut appender = conn.appender(&table).map_err(DuckDbError::DuckDb)?;
let mut total_rows: u64 = 0;
for batch in &batches {
if batch.num_rows() == 0 {
continue;
}
let duck_batch = convert_to_duck_batch(batch)?;
appender
.append_record_batch(duck_batch)
.map_err(DuckDbError::DuckDb)?;
total_rows += batch.num_rows() as u64;
}
appender.flush().map_err(DuckDbError::DuckDb)?;
Ok(total_rows)
})
.await
.map_err(DuckDbError::from_join)?
}
async fn create_table(
&self,
table_name: &str,
schema: &SchemaRef,
primary_key: &[String],
) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)?;
for field in schema.fields() {
rhei_core::validate_identifier(field.name())?;
}
for pk_col in primary_key {
rhei_core::validate_identifier(pk_col)?;
}
let mut columns: Vec<String> = schema
.fields()
.iter()
.map(|f| {
let nullable = if f.is_nullable() { "" } else { " NOT NULL" };
format!(
"{} {}{}",
f.name(),
arrow_type_to_duckdb_sql(f.data_type()),
nullable
)
})
.collect();
if !primary_key.is_empty() {
columns.push(format!("PRIMARY KEY ({})", primary_key.join(", ")));
}
let ddl = format!(
"CREATE TABLE IF NOT EXISTS {} ({})",
table_name,
columns.join(", ")
);
debug!(ddl = ddl.as_str(), "DuckDB create_table (writer)");
let conn = Arc::clone(&self.write_conn);
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
Ok(())
})
.await
.map_err(DuckDbError::from_join)?
}
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
let conn = self.next_reader();
let table_name = table_name.to_string();
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
let mut stmt = conn
.prepare("SELECT count(*) FROM information_schema.tables WHERE table_name = ?")
.map_err(DuckDbError::DuckDb)?;
let mut rows = stmt
.query_arrow(duckdb::params![table_name])
.map_err(DuckDbError::DuckDb)?;
if let Some(batch) = rows.next() {
if batch.num_rows() > 0 {
let col = batch
.column(0)
.as_any()
.downcast_ref::<duckdb::arrow::array::Int64Array>();
if let Some(arr) = col {
return Ok(arr.value(0) > 0);
}
}
}
Ok(false)
})
.await
.map_err(DuckDbError::from_join)?
}
async fn add_column(
&self,
table_name: &str,
column_name: &str,
data_type: &DataType,
) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)?;
rhei_core::validate_identifier(column_name)?;
let duckdb_type = arrow_type_to_duckdb_sql(data_type);
let ddl = format!(
"ALTER TABLE {} ADD COLUMN {} {}",
table_name, column_name, duckdb_type
);
debug!(ddl = ddl.as_str(), "DuckDB add_column (writer)");
let conn = Arc::clone(&self.write_conn);
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
Ok(())
})
.await
.map_err(DuckDbError::from_join)?
}
fn supports_transactions(&self) -> bool {
true
}
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
rhei_core::validate_identifier(table_name)?;
rhei_core::validate_identifier(column_name)?;
let ddl = format!("ALTER TABLE {} DROP COLUMN {}", table_name, column_name);
debug!(ddl = ddl.as_str(), "DuckDB drop_column (writer)");
let conn = Arc::clone(&self.write_conn);
tokio::task::spawn_blocking(move || {
let conn = conn.lock().unwrap();
conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
Ok(())
})
.await
.map_err(DuckDbError::from_join)?
}
}
/// Cloneable handle around an `Arc<DuckDbEngine>`.
///
/// Cloning the handle clones the `Arc`, so all clones refer to the same
/// underlying engine.
#[derive(Clone)]
pub struct SharedDuckDbEngine(pub Arc<DuckDbEngine>);
impl SharedDuckDbEngine {
    /// Takes ownership of `engine` and places it behind a new `Arc`.
    pub fn new(engine: DuckDbEngine) -> Self {
        let shared = Arc::new(engine);
        Self(shared)
    }
}
/// Lets a `SharedDuckDbEngine` be used wherever a `&DuckDbEngine` is expected.
impl Deref for SharedDuckDbEngine {
    type Target = DuckDbEngine;

    /// Borrows the engine stored inside the `Arc`.
    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
/// `OlapEngine` implementation for the shared handle.
///
/// Every method forwards its arguments unchanged to the same-named method on
/// the wrapped engine (`self.0`) and returns that result.
impl rhei_core::OlapEngine for SharedDuckDbEngine {
type Error = DuckDbError;
/// Forwards to `query` on the wrapped engine.
async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
self.0.query(sql).await
}
/// Forwards to `execute` on the wrapped engine.
async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
self.0.execute(sql).await
}
/// Forwards to `load_arrow` on the wrapped engine.
async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
self.0.load_arrow(table, batches).await
}
/// Forwards to `create_table` on the wrapped engine.
async fn create_table(
&self,
table_name: &str,
schema: &SchemaRef,
primary_key: &[String],
) -> Result<(), Self::Error> {
self.0.create_table(table_name, schema, primary_key).await
}
/// Forwards to `table_exists` on the wrapped engine.
async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
self.0.table_exists(table_name).await
}
/// Forwards to `add_column` on the wrapped engine.
async fn add_column(
&self,
table_name: &str,
column_name: &str,
data_type: &DataType,
) -> Result<(), Self::Error> {
self.0.add_column(table_name, column_name, data_type).await
}
/// Forwards to `drop_column` on the wrapped engine.
async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
self.0.drop_column(table_name, column_name).await
}
/// Forwards to `supports_transactions` on the wrapped engine.
fn supports_transactions(&self) -> bool {
self.0.supports_transactions()
}
}
// Tests for the engine. Each test opens its own in-memory database, so the
// table names used below do not clash between tests.
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::{Field, Schema};
use rhei_core::OlapEngine;
/// `create_table` makes the table visible to `table_exists`, and an unknown
/// table name is reported as absent.
#[tokio::test]
async fn test_in_memory_basic() {
let engine = DuckDbEngine::in_memory().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
]));
engine
.create_table("test_table", &schema, &[])
.await
.unwrap();
assert!(engine.table_exists("test_table").await.unwrap());
assert!(!engine.table_exists("nonexistent").await.unwrap());
}
/// With a pool of two readers, four consecutive queries (two full trips
/// around the pool) all see the row written through the writer connection.
#[tokio::test]
async fn test_read_pool_round_robin() {
let engine = DuckDbEngine::in_memory_with_pool(2).unwrap();
assert_eq!(engine.read_pool_size(), 2);
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
engine.create_table("t", &schema, &[]).await.unwrap();
engine.execute("INSERT INTO t VALUES (1)").await.unwrap();
// Four queries against two readers: each reader is used twice.
for _ in 0..4 {
let batches = engine.query("SELECT * FROM t").await.unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_rows(), 1);
}
}
/// The `SharedDuckDbEngine` wrapper supports create, insert and query.
#[tokio::test]
async fn test_shared_engine() {
let engine = SharedDuckDbEngine::new(DuckDbEngine::in_memory().unwrap());
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
engine
.create_table("shared_test", &schema, &[])
.await
.unwrap();
engine
.execute("INSERT INTO shared_test VALUES (42)")
.await
.unwrap();
let batches = engine.query("SELECT * FROM shared_test").await.unwrap();
assert_eq!(batches[0].num_rows(), 1);
}
/// Requesting a pool size of zero still yields one reader.
#[tokio::test]
async fn test_pool_size_clamped_to_one() {
let engine = DuckDbEngine::in_memory_with_pool(0).unwrap();
assert_eq!(engine.read_pool_size(), 1);
}
/// `load_arrow` appends Int64, Utf8, Float64 and Boolean columns, including
/// null values, and reports the number of rows appended.
#[tokio::test]
async fn test_load_arrow_basic_types() {
let engine = DuckDbEngine::in_memory().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
Field::new("score", DataType::Float64, true),
Field::new("active", DataType::Boolean, true),
]));
engine
.create_table("load_test", &schema, &[])
.await
.unwrap();
// Three rows; every nullable column carries exactly one null.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
Arc::new(arrow::array::StringArray::from(vec![
Some("alice"),
None,
Some("charlie"),
])),
Arc::new(arrow::array::Float64Array::from(vec![
Some(9.5),
Some(8.0),
None,
])),
Arc::new(arrow::array::BooleanArray::from(vec![
Some(true),
Some(false),
None,
])),
],
)
.unwrap();
let rows = engine.load_arrow("load_test", &[batch]).await.unwrap();
assert_eq!(rows, 3);
// Only the row count is checked on read-back, not the cell values.
let result = engine
.query("SELECT * FROM load_test ORDER BY id")
.await
.unwrap();
assert_eq!(result[0].num_rows(), 3);
}
/// `load_arrow` appends Date32 and microsecond Timestamp columns.
#[tokio::test]
async fn test_load_arrow_date_and_timestamp() {
use arrow::datatypes::TimeUnit;
let engine = DuckDbEngine::in_memory().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("created_date", DataType::Date32, true),
Field::new(
"created_ts",
DataType::Timestamp(TimeUnit::Microsecond, None),
true,
),
]));
engine
.create_table("dates_test", &schema, &[])
.await
.unwrap();
// Date32 counts days since 1970-01-01: 19737 is 2024-01-15 and 0 is the
// epoch itself. The timestamp is 2024-01-15T00:00:00 in microseconds.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
Arc::new(arrow::array::Date32Array::from(vec![Some(19737), Some(0)])),
Arc::new(arrow::array::TimestampMicrosecondArray::from(vec![
Some(1_705_276_800_000_000), None,
])),
],
)
.unwrap();
let rows = engine.load_arrow("dates_test", &[batch]).await.unwrap();
assert_eq!(rows, 2);
// Only the row count is checked on read-back, not the cell values.
let result = engine
.query("SELECT * FROM dates_test ORDER BY id")
.await
.unwrap();
assert_eq!(result[0].num_rows(), 2);
}
/// A composite primary key passed to `create_table` is enforced: a duplicate
/// `(tenant_id, order_id)` pair is rejected, while rows that differ in
/// either key column are accepted.
#[tokio::test]
async fn test_create_table_with_composite_pk_enforced() {
let engine = DuckDbEngine::in_memory().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("tenant_id", DataType::Int64, false),
Field::new("order_id", DataType::Int64, false),
Field::new("amount", DataType::Float64, true),
]));
let pk = vec!["tenant_id".to_string(), "order_id".to_string()];
engine
.create_table("orders_pk_test", &schema, &pk)
.await
.unwrap();
engine
.execute("INSERT INTO orders_pk_test VALUES (1, 100, 9.99)")
.await
.unwrap();
// Same (tenant_id, order_id) as the first row: must fail.
let err = engine
.execute("INSERT INTO orders_pk_test VALUES (1, 100, 19.99)")
.await
.unwrap_err();
// The exact error wording is not pinned; any of these keywords is accepted.
let msg = err.to_string().to_ascii_lowercase();
assert!(
msg.contains("constraint") || msg.contains("primary key") || msg.contains("unique"),
"expected a PK constraint error, got: {err}"
);
// Same tenant, different order id: allowed.
engine
.execute("INSERT INTO orders_pk_test VALUES (1, 101, 5.00)")
.await
.unwrap();
// Different tenant, same order id: allowed.
engine
.execute("INSERT INTO orders_pk_test VALUES (2, 100, 7.50)")
.await
.unwrap();
// Three successful inserts; the rejected duplicate left no row behind.
let batches = engine
.query("SELECT COUNT(*) FROM orders_pk_test")
.await
.unwrap();
let count = batches[0]
.column(0)
.as_any()
.downcast_ref::<duckdb::arrow::array::Int64Array>()
.unwrap()
.value(0);
assert_eq!(count, 3);
}
/// `load_arrow` appends a Binary column containing arbitrary bytes and a null.
#[tokio::test]
async fn test_load_arrow_binary() {
let engine = DuckDbEngine::in_memory().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("data", DataType::Binary, true),
]));
engine
.create_table("binary_test", &schema, &[])
.await
.unwrap();
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
Arc::new(arrow::array::BinaryArray::from(vec![
Some(b"\x00\x01\x02\xff" as &[u8]),
None,
])),
],
)
.unwrap();
let rows = engine.load_arrow("binary_test", &[batch]).await.unwrap();
assert_eq!(rows, 2);
// Only the row count is checked on read-back, not the cell values.
let result = engine
.query("SELECT * FROM binary_test ORDER BY id")
.await
.unwrap();
assert_eq!(result[0].num_rows(), 2);
}
}