1use std::ops::Deref;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::Arc;
12
13use arrow::datatypes::{DataType, SchemaRef};
14use arrow::record_batch::RecordBatch;
15use duckdb::arrow::array::RecordBatch as DuckRecordBatch;
16use tracing::debug;
17
18use crate::error::DuckDbError;
19
/// Default number of read-only connections cloned from the writer for the
/// round-robin read pool used by `query`/`table_exists`.
const DEFAULT_READ_POOL_SIZE: usize = 4;
22
/// DuckDB-backed OLAP engine.
///
/// One writer connection handles all DDL/DML; a small pool of reader
/// connections (created via `try_clone` from the writer, so they share the
/// same database) serves queries in round-robin order.
pub struct DuckDbEngine {
    // Sole connection used for writes; serialized behind a mutex.
    write_conn: Arc<std::sync::Mutex<duckdb::Connection>>,
    // Reader connections cloned from the writer; never empty (constructors
    // clamp the requested pool size to at least 1).
    read_pool: Vec<Arc<std::sync::Mutex<duckdb::Connection>>>,
    // Monotonically increasing counter; modulo pool length picks the next
    // reader (round-robin).
    read_idx: AtomicUsize,
}
65
// SAFETY: every `duckdb::Connection` held by the engine is only ever accessed
// through its `std::sync::Mutex` wrapper, so no two threads touch a raw
// connection concurrently. NOTE(review): if `duckdb::Connection` is already
// `Send` (as recent duckdb-rs versions document), `Mutex` would make these
// impls derivable automatically — confirm against the pinned duckdb-rs
// version before relying on them.
unsafe impl Send for DuckDbEngine {}
unsafe impl Sync for DuckDbEngine {}
85
86impl DuckDbEngine {
87 pub fn in_memory() -> Result<Self, DuckDbError> {
93 Self::in_memory_with_pool(DEFAULT_READ_POOL_SIZE)
94 }
95
96 pub fn in_memory_with_pool(read_pool_size: usize) -> Result<Self, DuckDbError> {
101 let write_conn = duckdb::Connection::open_in_memory()?;
102 Self::from_connection(write_conn, read_pool_size.max(1))
103 }
104
105 pub fn persistent(path: &str) -> Result<Self, DuckDbError> {
110 Self::persistent_with_pool(path, DEFAULT_READ_POOL_SIZE)
111 }
112
113 pub fn persistent_with_pool(path: &str, read_pool_size: usize) -> Result<Self, DuckDbError> {
118 let write_conn = duckdb::Connection::open(path)?;
119 Self::from_connection(write_conn, read_pool_size.max(1))
120 }
121
122 fn from_connection(
128 write_conn: duckdb::Connection,
129 read_pool_size: usize,
130 ) -> Result<Self, DuckDbError> {
131 let mut read_pool = Vec::with_capacity(read_pool_size);
132 for _ in 0..read_pool_size {
133 let reader = write_conn.try_clone()?;
134 read_pool.push(Arc::new(std::sync::Mutex::new(reader)));
135 }
136
137 Ok(Self {
138 write_conn: Arc::new(std::sync::Mutex::new(write_conn)),
139 read_pool,
140 read_idx: AtomicUsize::new(0),
141 })
142 }
143
144 fn next_reader(&self) -> Arc<std::sync::Mutex<duckdb::Connection>> {
150 let idx = self.read_idx.fetch_add(1, Ordering::Relaxed) % self.read_pool.len();
151 Arc::clone(&self.read_pool[idx])
152 }
153
154 pub fn read_pool_size(&self) -> usize {
158 self.read_pool.len()
159 }
160}
161
162fn arrow_type_to_duckdb_sql(dt: &DataType) -> &'static str {
168 match dt {
169 DataType::Boolean => "BOOLEAN",
170 DataType::Int8 | DataType::UInt8 => "TINYINT",
171 DataType::Int16 | DataType::UInt16 => "SMALLINT",
172 DataType::Int32 | DataType::UInt32 => "INTEGER",
173 DataType::Int64 | DataType::UInt64 => "BIGINT",
174 DataType::Float16 | DataType::Float32 => "FLOAT",
175 DataType::Float64 => "DOUBLE",
176 DataType::Utf8 | DataType::LargeUtf8 => "VARCHAR",
177 DataType::Binary | DataType::LargeBinary => "BLOB",
178 DataType::Date32 | DataType::Date64 => "DATE",
179 DataType::Timestamp(_, _) => "TIMESTAMP",
180 _ => "VARCHAR", }
182}
183
184fn convert_duck_batch(b: DuckRecordBatch) -> Result<RecordBatch, DuckDbError> {
191 let schema = Arc::new(arrow::datatypes::Schema::new(
192 b.schema()
193 .fields()
194 .iter()
195 .map(|f| arrow::datatypes::Field::new(f.name(), f.data_type().clone(), f.is_nullable()))
196 .collect::<Vec<_>>(),
197 ));
198 RecordBatch::try_new(schema, b.columns().to_vec()).map_err(DuckDbError::Arrow)
199}
200
201fn convert_to_duck_batch(b: &RecordBatch) -> Result<DuckRecordBatch, DuckDbError> {
207 let duck_schema = Arc::new(duckdb::arrow::datatypes::Schema::new(
208 b.schema()
209 .fields()
210 .iter()
211 .map(|f| {
212 duckdb::arrow::datatypes::Field::new(
213 f.name(),
214 f.data_type().clone(),
215 f.is_nullable(),
216 )
217 })
218 .collect::<Vec<_>>(),
219 ));
220 DuckRecordBatch::try_new(duck_schema, b.columns().to_vec()).map_err(DuckDbError::Arrow)
221}
222
impl rhei_core::OlapEngine for DuckDbEngine {
    type Error = DuckDbError;

    /// Runs a read-only SQL query on the next pooled reader connection and
    /// returns all result batches, converted to this crate's Arrow version.
    ///
    /// The blocking DuckDB work runs on `spawn_blocking` so the async
    /// executor is never blocked; the SQL string is cloned into the closure.
    async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
        debug!(sql, "DuckDB query (reader)");
        let conn = self.next_reader();
        let sql = sql.to_string();
        tokio::task::spawn_blocking(move || {
            // unwrap: mutex poisoning only occurs if another thread panicked
            // while holding the connection — treated as a fatal bug.
            let conn = conn.lock().unwrap();
            let mut stmt = conn.prepare(&sql).map_err(DuckDbError::DuckDb)?;
            let arrow_result = stmt.query_arrow([]).map_err(DuckDbError::DuckDb)?;
            // Collect all batches eagerly: the statement cannot outlive the
            // closure, so results must be materialized before returning.
            let duck_batches: Vec<DuckRecordBatch> = arrow_result.collect();
            duck_batches
                .into_iter()
                .map(convert_duck_batch)
                .collect::<Result<Vec<_>, _>>()
        })
        .await
        // Outer `?` unwraps the JoinError layer; inner Result flows through.
        .map_err(DuckDbError::from_join)?
    }

    /// Executes a DDL/DML statement on the writer connection and returns the
    /// affected row count.
    async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
        debug!(sql, "DuckDB execute (writer)");
        let conn = Arc::clone(&self.write_conn);
        let sql = sql.to_string();
        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();
            let rows = conn.execute(&sql, []).map_err(DuckDbError::DuckDb)?;
            Ok(rows as u64)
        })
        .await
        .map_err(DuckDbError::from_join)?
    }

    /// Bulk-loads Arrow batches into `table` via the DuckDB appender on the
    /// writer connection; returns the total number of rows appended.
    ///
    /// Empty input short-circuits to 0 without touching the database; empty
    /// batches inside the slice are skipped. The table name is validated
    /// before use because the appender API takes it as an identifier.
    async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
        if batches.is_empty() {
            return Ok(0);
        }

        debug!(
            table,
            batch_count = batches.len(),
            "DuckDB load_arrow (writer, appender)"
        );

        rhei_core::validate_identifier(table)?;

        let conn = Arc::clone(&self.write_conn);
        let table = table.to_string();
        // Clone the batches so the blocking closure can own them ('static).
        let batches = batches.to_vec();

        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();

            let mut appender = conn.appender(&table).map_err(DuckDbError::DuckDb)?;
            let mut total_rows: u64 = 0;

            for batch in &batches {
                if batch.num_rows() == 0 {
                    continue;
                }
                // Re-wrap into duckdb's Arrow types before appending.
                let duck_batch = convert_to_duck_batch(batch)?;
                appender
                    .append_record_batch(duck_batch)
                    .map_err(DuckDbError::DuckDb)?;
                total_rows += batch.num_rows() as u64;
            }

            // Explicit flush so append errors surface here rather than being
            // lost when the appender is dropped.
            appender.flush().map_err(DuckDbError::DuckDb)?;

            Ok(total_rows)
        })
        .await
        .map_err(DuckDbError::from_join)?
    }

    /// Creates `table_name` (idempotent: `IF NOT EXISTS`) with columns
    /// derived from the Arrow schema and an optional composite PRIMARY KEY.
    ///
    /// All identifiers (table, columns, PK columns) are validated up front
    /// because they are interpolated into the DDL string unquoted.
    async fn create_table(
        &self,
        table_name: &str,
        schema: &SchemaRef,
        primary_key: &[String],
    ) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)?;
        for field in schema.fields() {
            rhei_core::validate_identifier(field.name())?;
        }
        for pk_col in primary_key {
            rhei_core::validate_identifier(pk_col)?;
        }

        let mut columns: Vec<String> = schema
            .fields()
            .iter()
            .map(|f| {
                // Non-nullable Arrow fields become NOT NULL columns.
                let nullable = if f.is_nullable() { "" } else { " NOT NULL" };
                format!(
                    "{} {}{}",
                    f.name(),
                    arrow_type_to_duckdb_sql(f.data_type()),
                    nullable
                )
            })
            .collect();

        if !primary_key.is_empty() {
            columns.push(format!("PRIMARY KEY ({})", primary_key.join(", ")));
        }

        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} ({})",
            table_name,
            columns.join(", ")
        );

        debug!(ddl = ddl.as_str(), "DuckDB create_table (writer)");
        let conn = Arc::clone(&self.write_conn);
        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();
            conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
            Ok(())
        })
        .await
        .map_err(DuckDbError::from_join)?
    }

    /// Returns whether a table named `table_name` exists, by counting matches
    /// in `information_schema.tables` on a pooled reader connection.
    ///
    /// The name is bound as a query parameter, so no identifier validation is
    /// needed here. NOTE(review): the lookup does not filter by schema, so a
    /// same-named table in any schema counts — confirm this is intended.
    async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
        let conn = self.next_reader();
        let table_name = table_name.to_string();
        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();
            let mut stmt = conn
                .prepare("SELECT count(*) FROM information_schema.tables WHERE table_name = ?")
                .map_err(DuckDbError::DuckDb)?;
            let mut rows = stmt
                .query_arrow(duckdb::params![table_name])
                .map_err(DuckDbError::DuckDb)?;

            if let Some(batch) = rows.next() {
                if batch.num_rows() > 0 {
                    // count(*) comes back as a BIGINT column -> Int64Array;
                    // if the downcast ever fails we conservatively report
                    // "does not exist" below.
                    let col = batch
                        .column(0)
                        .as_any()
                        .downcast_ref::<duckdb::arrow::array::Int64Array>();
                    if let Some(arr) = col {
                        return Ok(arr.value(0) > 0);
                    }
                }
            }
            Ok(false)
        })
        .await
        .map_err(DuckDbError::from_join)?
    }

    /// Adds a column to an existing table via `ALTER TABLE … ADD COLUMN`,
    /// mapping the Arrow type through `arrow_type_to_duckdb_sql`.
    /// Identifiers are validated because they are interpolated unquoted.
    async fn add_column(
        &self,
        table_name: &str,
        column_name: &str,
        data_type: &DataType,
    ) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)?;
        rhei_core::validate_identifier(column_name)?;

        let duckdb_type = arrow_type_to_duckdb_sql(data_type);
        let ddl = format!(
            "ALTER TABLE {} ADD COLUMN {} {}",
            table_name, column_name, duckdb_type
        );

        debug!(ddl = ddl.as_str(), "DuckDB add_column (writer)");
        let conn = Arc::clone(&self.write_conn);
        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();
            conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
            Ok(())
        })
        .await
        .map_err(DuckDbError::from_join)?
    }

    /// DuckDB supports transactions.
    fn supports_transactions(&self) -> bool {
        true
    }

    /// Drops a column via `ALTER TABLE … DROP COLUMN`, mirroring
    /// [`Self::add_column`]; identifiers are validated first.
    async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
        rhei_core::validate_identifier(table_name)?;
        rhei_core::validate_identifier(column_name)?;

        let ddl = format!("ALTER TABLE {} DROP COLUMN {}", table_name, column_name);

        debug!(ddl = ddl.as_str(), "DuckDB drop_column (writer)");
        let conn = Arc::clone(&self.write_conn);
        tokio::task::spawn_blocking(move || {
            let conn = conn.lock().unwrap();
            conn.execute(&ddl, []).map_err(DuckDbError::DuckDb)?;
            Ok(())
        })
        .await
        .map_err(DuckDbError::from_join)?
    }
}
490
/// Cheaply cloneable handle around a [`DuckDbEngine`]: clones share the same
/// engine via the inner `Arc`, letting many tasks use one connection set.
#[derive(Clone)]
pub struct SharedDuckDbEngine(pub Arc<DuckDbEngine>);
503
504impl SharedDuckDbEngine {
505 pub fn new(engine: DuckDbEngine) -> Self {
507 Self(Arc::new(engine))
508 }
509}
510
511impl Deref for SharedDuckDbEngine {
512 type Target = DuckDbEngine;
513
514 fn deref(&self) -> &Self::Target {
516 &self.0
517 }
518}
519
520impl rhei_core::OlapEngine for SharedDuckDbEngine {
523 type Error = DuckDbError;
524
525 async fn query(&self, sql: &str) -> Result<Vec<RecordBatch>, Self::Error> {
527 self.0.query(sql).await
528 }
529
530 async fn execute(&self, sql: &str) -> Result<u64, Self::Error> {
532 self.0.execute(sql).await
533 }
534
535 async fn load_arrow(&self, table: &str, batches: &[RecordBatch]) -> Result<u64, Self::Error> {
537 self.0.load_arrow(table, batches).await
538 }
539
540 async fn create_table(
542 &self,
543 table_name: &str,
544 schema: &SchemaRef,
545 primary_key: &[String],
546 ) -> Result<(), Self::Error> {
547 self.0.create_table(table_name, schema, primary_key).await
548 }
549
550 async fn table_exists(&self, table_name: &str) -> Result<bool, Self::Error> {
552 self.0.table_exists(table_name).await
553 }
554
555 async fn add_column(
557 &self,
558 table_name: &str,
559 column_name: &str,
560 data_type: &DataType,
561 ) -> Result<(), Self::Error> {
562 self.0.add_column(table_name, column_name, data_type).await
563 }
564
565 async fn drop_column(&self, table_name: &str, column_name: &str) -> Result<(), Self::Error> {
567 self.0.drop_column(table_name, column_name).await
568 }
569
570 fn supports_transactions(&self) -> bool {
572 self.0.supports_transactions()
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{Field, Schema};
    use rhei_core::OlapEngine;

    // Smoke test: create a table on an in-memory engine and verify existence
    // reporting for both a present and an absent table.
    #[tokio::test]
    async fn test_in_memory_basic() {
        let engine = DuckDbEngine::in_memory().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        engine
            .create_table("test_table", &schema, &[])
            .await
            .unwrap();
        assert!(engine.table_exists("test_table").await.unwrap());
        assert!(!engine.table_exists("nonexistent").await.unwrap());
    }

    // Queries issued more times than the pool size must cycle through all
    // readers, and every reader must see data written via the writer
    // (readers are try_clone'd from it, so they share the database).
    #[tokio::test]
    async fn test_read_pool_round_robin() {
        let engine = DuckDbEngine::in_memory_with_pool(2).unwrap();
        assert_eq!(engine.read_pool_size(), 2);

        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
        engine.create_table("t", &schema, &[]).await.unwrap();
        engine.execute("INSERT INTO t VALUES (1)").await.unwrap();

        // 4 queries over a pool of 2: each reader serves twice.
        for _ in 0..4 {
            let batches = engine.query("SELECT * FROM t").await.unwrap();
            assert_eq!(batches.len(), 1);
            assert_eq!(batches[0].num_rows(), 1);
        }
    }

    // The Arc-backed handle must expose the same API via Deref delegation.
    #[tokio::test]
    async fn test_shared_engine() {
        let engine = SharedDuckDbEngine::new(DuckDbEngine::in_memory().unwrap());
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        engine
            .create_table("shared_test", &schema, &[])
            .await
            .unwrap();
        engine
            .execute("INSERT INTO shared_test VALUES (42)")
            .await
            .unwrap();
        let batches = engine.query("SELECT * FROM shared_test").await.unwrap();
        assert_eq!(batches[0].num_rows(), 1);
    }

    // A requested pool size of 0 must be clamped to 1 by the constructors.
    #[tokio::test]
    async fn test_pool_size_clamped_to_one() {
        let engine = DuckDbEngine::in_memory_with_pool(0).unwrap();
        assert_eq!(engine.read_pool_size(), 1);
    }

    // Round-trips a batch of scalar columns (including nulls) through
    // load_arrow and reads it back via the query path.
    #[tokio::test]
    async fn test_load_arrow_basic_types() {
        let engine = DuckDbEngine::in_memory().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("active", DataType::Boolean, true),
        ]));
        engine
            .create_table("load_test", &schema, &[])
            .await
            .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow::array::StringArray::from(vec![
                    Some("alice"),
                    None,
                    Some("charlie"),
                ])),
                Arc::new(arrow::array::Float64Array::from(vec![
                    Some(9.5),
                    Some(8.0),
                    None,
                ])),
                Arc::new(arrow::array::BooleanArray::from(vec![
                    Some(true),
                    Some(false),
                    None,
                ])),
            ],
        )
        .unwrap();

        let rows = engine.load_arrow("load_test", &[batch]).await.unwrap();
        assert_eq!(rows, 3);

        let result = engine
            .query("SELECT * FROM load_test ORDER BY id")
            .await
            .unwrap();
        assert_eq!(result[0].num_rows(), 3);
    }

    // Date32 (days since the Unix epoch) and microsecond timestamps must
    // load through the appender into DATE / TIMESTAMP columns.
    #[tokio::test]
    async fn test_load_arrow_date_and_timestamp() {
        use arrow::datatypes::TimeUnit;

        let engine = DuckDbEngine::in_memory().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("created_date", DataType::Date32, true),
            Field::new(
                "created_ts",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                true,
            ),
        ]));
        engine
            .create_table("dates_test", &schema, &[])
            .await
            .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
                // 19737 days after 1970-01-01; 0 = the epoch itself.
                Arc::new(arrow::array::Date32Array::from(vec![Some(19737), Some(0)])),
                Arc::new(arrow::array::TimestampMicrosecondArray::from(vec![
                    Some(1_705_276_800_000_000), None,
                ])),
            ],
        )
        .unwrap();

        let rows = engine.load_arrow("dates_test", &[batch]).await.unwrap();
        assert_eq!(rows, 2);

        let result = engine
            .query("SELECT * FROM dates_test ORDER BY id")
            .await
            .unwrap();
        assert_eq!(result[0].num_rows(), 2);
    }

    // A composite PRIMARY KEY emitted by create_table must actually be
    // enforced: duplicate (tenant_id, order_id) pairs fail, while rows
    // differing in either component succeed.
    #[tokio::test]
    async fn test_create_table_with_composite_pk_enforced() {
        let engine = DuckDbEngine::in_memory().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("tenant_id", DataType::Int64, false),
            Field::new("order_id", DataType::Int64, false),
            Field::new("amount", DataType::Float64, true),
        ]));
        let pk = vec!["tenant_id".to_string(), "order_id".to_string()];
        engine
            .create_table("orders_pk_test", &schema, &pk)
            .await
            .unwrap();

        engine
            .execute("INSERT INTO orders_pk_test VALUES (1, 100, 9.99)")
            .await
            .unwrap();

        // Exact duplicate of the PK pair (1, 100) must be rejected.
        let err = engine
            .execute("INSERT INTO orders_pk_test VALUES (1, 100, 19.99)")
            .await
            .unwrap_err();
        // Match loosely on the message: the exact wording varies by DuckDB
        // version.
        let msg = err.to_string().to_ascii_lowercase();
        assert!(
            msg.contains("constraint") || msg.contains("primary key") || msg.contains("unique"),
            "expected a PK constraint error, got: {err}"
        );

        // Same tenant, different order — and vice versa — must both insert.
        engine
            .execute("INSERT INTO orders_pk_test VALUES (1, 101, 5.00)")
            .await
            .unwrap();
        engine
            .execute("INSERT INTO orders_pk_test VALUES (2, 100, 7.50)")
            .await
            .unwrap();

        let batches = engine
            .query("SELECT COUNT(*) FROM orders_pk_test")
            .await
            .unwrap();
        let count = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<duckdb::arrow::array::Int64Array>()
            .unwrap()
            .value(0);
        assert_eq!(count, 3);
    }

    // BLOB columns must round-trip raw (non-UTF-8) bytes and nulls.
    #[tokio::test]
    async fn test_load_arrow_binary() {
        let engine = DuckDbEngine::in_memory().unwrap();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Binary, true),
        ]));
        engine
            .create_table("binary_test", &schema, &[])
            .await
            .unwrap();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(arrow::array::Int32Array::from(vec![1, 2])),
                Arc::new(arrow::array::BinaryArray::from(vec![
                    Some(b"\x00\x01\x02\xff" as &[u8]),
                    None,
                ])),
            ],
        )
        .unwrap();

        let rows = engine.load_arrow("binary_test", &[batch]).await.unwrap();
        assert_eq!(rows, 2);

        let result = engine
            .query("SELECT * FROM binary_test ORDER BY id")
            .await
            .unwrap();
        assert_eq!(result[0].num_rows(), 2);
    }
}