1use std::any::Any;
2use std::collections::HashSet;
3use std::sync::Arc;
4
5use arrow::array::RecordBatch;
6use arrow_schema::{DataType, Field};
7use async_stream::stream;
8use datafusion::arrow::datatypes::SchemaRef;
9use datafusion::error::DataFusionError;
10use datafusion::execution::SendableRecordBatchStream;
11use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
12use datafusion::sql::sqlparser::ast::TableFactor;
13use datafusion::sql::sqlparser::parser::Parser;
14use datafusion::sql::sqlparser::{dialect::DuckDbDialect, tokenizer::Tokenizer};
15use datafusion::sql::TableReference;
16use duckdb::vtab::to_duckdb_type_id;
17use duckdb::ToSql;
18use duckdb::{Connection, DuckdbConnectionManager};
19use dyn_clone::DynClone;
20use rand::distr::{Alphanumeric, SampleString};
21use snafu::{prelude::*, ResultExt};
22use tokio::sync::mpsc::Sender;
23
24use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
25use crate::util::schema::SchemaValidator;
26use crate::UnsupportedTypeAction;
27
28use super::DbConnection;
29use super::Result;
30use super::SyncDbConnection;
31
32#[derive(Debug, Snafu)]
33pub enum Error {
34 #[snafu(display("DuckDB connection failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
35 DuckDBConnectionError { source: duckdb::Error },
36
37 #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
38 DuckDBQueryError { source: duckdb::Error },
39
40 #[snafu(display(
41 "An unexpected error occurred.\n{message}\nVerify the configuration and try again."
42 ))]
43 ChannelError { message: String },
44
45 #[snafu(display(
46 "Unable to attach DuckDB database {path}.\n{source}\nEnsure the DuckDB file path is valid."
47 ))]
48 UnableToAttachDatabase {
49 path: Arc<str>,
50 source: std::io::Error,
51 },
52}
53
54pub trait DuckDBSyncParameter: ToSql + Sync + Send + DynClone {
55 fn as_input_parameter(&self) -> &dyn ToSql;
56}
57
58impl<T: ToSql + Sync + Send + DynClone> DuckDBSyncParameter for T {
59 fn as_input_parameter(&self) -> &dyn ToSql {
60 self
61 }
62}
63dyn_clone::clone_trait_object!(DuckDBSyncParameter);
64pub type DuckDBParameter = Box<dyn DuckDBSyncParameter>;
65
/// A set of DuckDB database files to attach to a connection, plus the naming
/// and search-path state needed to attach and detach them.
#[derive(Debug)]
pub struct DuckDBAttachments {
    /// Deduplicated paths of the database files to attach.
    attachments: HashSet<Arc<str>>,
    /// Random token embedded in the generated `attachment_<id>_<index>` names.
    random_id: String,
    /// Name placed first in the search path.
    main_db: String,
}
72
73impl DuckDBAttachments {
74 #[must_use]
76 pub fn new(main_db: &str, attachments: &[Arc<str>]) -> Self {
77 let random_id = Alphanumeric.sample_string(&mut rand::rng(), 8);
78 let attachments: HashSet<Arc<str>> = attachments.iter().cloned().collect();
79 Self {
80 attachments,
81 random_id,
82 main_db: main_db.to_string(),
83 }
84 }
85
86 #[must_use]
90 fn get_search_path<'a>(id: &str, attachments: impl IntoIterator<Item = &'a str>) -> Arc<str> {
91 let mut path = String::from(id);
92
93 for attachment in attachments {
94 path.push(',');
95 path.push_str(attachment);
96 }
97
98 Arc::from(path)
99 }
100
101 pub fn set_search_path<'a>(
109 &self,
110 conn: &Connection,
111 attachments: impl IntoIterator<Item = &'a str>,
112 ) -> Result<Arc<str>> {
113 let search_path = Self::get_search_path(&self.main_db, attachments);
114
115 tracing::trace!("Setting search_path to {search_path}");
116
117 conn.execute(&format!("SET search_path ='{}'", search_path), [])
118 .context(DuckDBConnectionSnafu)?;
119 Ok(search_path)
120 }
121
122 pub fn reset_search_path(&self, conn: &Connection) -> Result<()> {
128 conn.execute("RESET search_path", [])
129 .context(DuckDBConnectionSnafu)?;
130 Ok(())
131 }
132
133 pub fn attach(&self, conn: &Connection) -> Result<Arc<str>> {
141 let mut stmt = conn
143 .prepare("PRAGMA database_list;")
144 .context(DuckDBConnectionSnafu)?;
145 let mut rows = stmt.query([]).context(DuckDBConnectionSnafu)?;
146
147 let mut existing_attachments = std::collections::HashMap::new();
148 while let Some(row) = rows.next()? {
149 let db_name: String = row.get(1)?;
150 let db_path: Option<String> = row.get(2)?;
151 if db_name.starts_with("attachment_") {
152 existing_attachments.insert(db_path.unwrap_or_default(), db_name);
154 }
155 }
156
157 if !existing_attachments.is_empty() {
159 tracing::trace!(
160 "Attachments {:?} creation skipped as connection contains existing attachments: {existing_attachments:?}",
161 self.attachments
162 );
163 for db in &self.attachments {
164 if !existing_attachments.contains_key(db.as_ref()) {
165 tracing::warn!("{db} not found among existing attachments");
166 }
167 }
168 return self.set_search_path(conn, existing_attachments.values().map(|s| s.as_str()));
170 }
171
172 let mut created_attachments = Vec::new();
173
174 for (i, db) in self.attachments.iter().enumerate() {
175 std::fs::metadata(db.as_ref()).context(UnableToAttachDatabaseSnafu {
177 path: Arc::clone(db),
178 })?;
179 let attachment_name = Self::get_attachment_name(&self.random_id, i);
180 let sql = format!("ATTACH IF NOT EXISTS '{db}' AS {attachment_name} (READ_ONLY)");
181 tracing::trace!("Attaching {db} using: {sql}");
182 conn.execute(&sql, []).context(DuckDBConnectionSnafu)?;
183 created_attachments.push(attachment_name);
184 }
185
186 self.set_search_path(conn, created_attachments.iter().map(|s| s.as_str()))
187 }
188
189 pub fn detach(&self, conn: &Connection) -> Result<()> {
195 for (i, _) in self.attachments.iter().enumerate() {
196 conn.execute(
197 &format!(
198 "DETACH DATABASE IF EXISTS {}",
199 Self::get_attachment_name(&self.random_id, i)
200 ),
201 [],
202 )
203 .context(DuckDBConnectionSnafu)?;
204 }
205
206 self.reset_search_path(conn)?;
207 Ok(())
208 }
209
210 #[must_use]
211 fn get_attachment_name(random_id: &str, index: usize) -> String {
212 format!("attachment_{random_id}_{index}")
213 }
214}
215
216pub struct DuckDbConnection {
217 pub conn: r2d2::PooledConnection<DuckdbConnectionManager>,
218 attachments: Option<Arc<DuckDBAttachments>>,
219 unsupported_type_action: UnsupportedTypeAction,
220 connection_setup_queries: Vec<Arc<str>>,
221}
222
223impl SchemaValidator for DuckDbConnection {
224 type Error = super::Error;
225
226 fn is_data_type_supported(data_type: &DataType) -> bool {
227 match data_type {
228 DataType::List(inner_field)
229 | DataType::FixedSizeList(inner_field, _)
230 | DataType::LargeList(inner_field) => {
231 match inner_field.data_type() {
232 dt if dt.is_primitive() => true,
233 DataType::Utf8
234 | DataType::Binary
235 | DataType::Utf8View
236 | DataType::BinaryView
237 | DataType::Boolean => true,
238 _ => false, }
240 }
241 DataType::Struct(inner_fields) => inner_fields
242 .iter()
243 .all(|field| Self::is_data_type_supported(field.data_type())),
244 _ => true,
245 }
246 }
247
248 fn is_field_supported(field: &Arc<Field>) -> bool {
249 let duckdb_type_id = to_duckdb_type_id(field.data_type());
250 Self::is_data_type_supported(field.data_type()) && duckdb_type_id.is_ok()
251 }
252
253 fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
254 super::Error::UnsupportedDataType {
255 data_type: data_type.to_string(),
256 field_name: field_name.to_string(),
257 }
258 }
259}
260
261impl DuckDbConnection {
262 pub fn get_underlying_conn_mut(
263 &mut self,
264 ) -> &mut r2d2::PooledConnection<DuckdbConnectionManager> {
265 &mut self.conn
266 }
267
268 #[must_use]
269 pub fn with_unsupported_type_action(
270 mut self,
271 unsupported_type_action: UnsupportedTypeAction,
272 ) -> Self {
273 self.unsupported_type_action = unsupported_type_action;
274 self
275 }
276
277 #[must_use]
278 pub fn with_attachments(mut self, attachments: Option<Arc<DuckDBAttachments>>) -> Self {
279 self.attachments = attachments;
280 self
281 }
282
283 #[must_use]
284 pub fn with_connection_setup_queries(mut self, queries: Vec<Arc<str>>) -> Self {
285 self.connection_setup_queries = queries;
286 self
287 }
288
289 pub fn attach(conn: &Connection, attachments: &Option<Arc<DuckDBAttachments>>) -> Result<()> {
295 if let Some(attachments) = attachments {
296 attachments.attach(conn)?;
297 }
298 Ok(())
299 }
300
301 pub fn detach(conn: &Connection, attachments: &Option<Arc<DuckDBAttachments>>) -> Result<()> {
307 if let Some(attachments) = attachments {
308 attachments.detach(conn)?;
309 }
310 Ok(())
311 }
312
313 fn apply_connection_setup_queries(&self, conn: &Connection) -> Result<()> {
314 for query in self.connection_setup_queries.iter() {
315 conn.execute(query, []).context(DuckDBConnectionSnafu)?;
316 }
317 Ok(())
318 }
319}
320
321impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
322 for DuckDbConnection
323{
324 fn as_any(&self) -> &dyn Any {
325 self
326 }
327
328 fn as_any_mut(&mut self) -> &mut dyn Any {
329 self
330 }
331
332 fn as_sync(
333 &self,
334 ) -> Option<
335 &dyn SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
336 > {
337 Some(self)
338 }
339}
340
341impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
342 for DuckDbConnection
343{
344 fn new(conn: r2d2::PooledConnection<DuckdbConnectionManager>) -> Self {
345 DuckDbConnection {
346 conn,
347 attachments: None,
348 unsupported_type_action: UnsupportedTypeAction::default(),
349 connection_setup_queries: Vec::new(),
350 }
351 }
352
353 fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
354 let sql = "SELECT table_name FROM information_schema.tables \
355 WHERE table_schema = ? AND table_type = 'BASE TABLE'";
356
357 let mut stmt = self
358 .conn
359 .prepare(sql)
360 .boxed()
361 .context(super::UnableToGetTablesSnafu)?;
362 let mut rows = stmt
363 .query([schema])
364 .boxed()
365 .context(super::UnableToGetTablesSnafu)?;
366 let mut tables = vec![];
367
368 while let Some(row) = rows.next().boxed().context(super::UnableToGetTablesSnafu)? {
369 tables.push(row.get(0).boxed().context(super::UnableToGetTablesSnafu)?);
370 }
371
372 Ok(tables)
373 }
374
375 fn schemas(&self) -> Result<Vec<String>, super::Error> {
376 let sql = "SELECT DISTINCT schema_name FROM information_schema.schemata \
377 WHERE schema_name NOT IN ('information_schema', 'pg_catalog')";
378
379 let mut stmt = self
380 .conn
381 .prepare(sql)
382 .boxed()
383 .context(super::UnableToGetSchemasSnafu)?;
384 let mut rows = stmt
385 .query([])
386 .boxed()
387 .context(super::UnableToGetSchemasSnafu)?;
388 let mut schemas = vec![];
389
390 while let Some(row) = rows
391 .next()
392 .boxed()
393 .context(super::UnableToGetSchemasSnafu)?
394 {
395 schemas.push(row.get(0).boxed().context(super::UnableToGetSchemasSnafu)?);
396 }
397
398 Ok(schemas)
399 }
400
401 fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, super::Error> {
402 let table_str = if is_table_function(table_reference) {
403 table_reference.to_string()
404 } else {
405 table_reference.to_quoted_string()
406 };
407 let mut stmt = self
408 .conn
409 .prepare(&format!("SELECT * FROM {table_str} LIMIT 0"))
410 .boxed()
411 .context(super::UnableToGetSchemaSnafu)?;
412
413 let result: duckdb::Arrow<'_> = stmt
414 .query_arrow([])
415 .boxed()
416 .context(super::UnableToGetSchemaSnafu)?;
417
418 Self::handle_unsupported_schema(&result.get_schema(), self.unsupported_type_action)
419 }
420
421 fn query_arrow(
422 &self,
423 sql: &str,
424 params: &[DuckDBParameter],
425 _projected_schema: Option<SchemaRef>,
426 ) -> Result<SendableRecordBatchStream> {
427 let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
428
429 let conn = self.conn.try_clone()?;
430 Self::attach(&conn, &self.attachments)?;
431 self.apply_connection_setup_queries(&conn)?;
432
433 let fetch_schema_sql =
434 format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0");
435 let mut stmt = conn
436 .prepare(&fetch_schema_sql)
437 .boxed()
438 .context(super::UnableToGetSchemaSnafu)?;
439
440 let result: duckdb::Arrow<'_> = stmt
441 .query_arrow([])
442 .boxed()
443 .context(super::UnableToGetSchemaSnafu)?;
444
445 let schema = result.get_schema();
446
447 let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
448
449 let sql = sql.to_string();
450
451 let cloned_schema = schema.clone();
452
453 let create_stream = || -> Result<SendableRecordBatchStream> {
454 let join_handle = tokio::task::spawn_blocking(move || {
455 let mut stmt = conn.prepare(&sql).context(DuckDBQuerySnafu)?;
456 let params: &[&dyn ToSql] = ¶ms
457 .iter()
458 .map(|f| f.as_input_parameter())
459 .collect::<Vec<_>>();
460 let result: duckdb::ArrowStream<'_> = stmt
461 .stream_arrow(params, cloned_schema)
462 .context(DuckDBQuerySnafu)?;
463 for i in result {
464 blocking_channel_send(&batch_tx, i)?;
465 }
466 Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
467 });
468
469 let output_stream = stream! {
470 while let Some(batch) = batch_rx.recv().await {
471 yield Ok(batch);
472 }
473
474 match join_handle.await {
475 Ok(Err(task_error)) => {
476 yield Err(DataFusionError::Execution(format!(
477 "Failed to execute DuckDB query: {task_error}"
478 )))
479 },
480 Err(join_error) => {
481 yield Err(DataFusionError::Execution(format!(
482 "Failed to execute DuckDB query: {join_error}"
483 )))
484 },
485 _ => {}
486 }
487 };
488
489 Ok(Box::pin(RecordBatchStreamAdapter::new(
490 schema,
491 output_stream,
492 )))
493 };
494
495 run_sync_with_tokio(create_stream)
496 }
497
498 fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result<u64> {
499 let params: &[&dyn ToSql] = ¶ms
500 .iter()
501 .map(|f| f.as_input_parameter())
502 .collect::<Vec<_>>();
503
504 let rows_modified = self.conn.execute(sql, params).context(DuckDBQuerySnafu)?;
505 Ok(rows_modified as u64)
506 }
507}
508
509fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
510 match channel.blocking_send(item) {
511 Ok(()) => Ok(()),
512 Err(e) => Err(Error::ChannelError {
513 message: format!("{e}"),
514 }
515 .into()),
516 }
517}
518
519#[must_use]
520pub fn flatten_table_function_name(table_reference: &TableReference) -> String {
521 let table_name = table_reference.table();
522 let filtered_name: String = table_name
523 .chars()
524 .filter(|c| c.is_alphanumeric() || *c == '(')
525 .collect();
526 let result = filtered_name.replace('(', "_");
527
528 format!("{result}__view")
529}
530
531#[must_use]
532pub fn is_table_function(table_reference: &TableReference) -> bool {
533 let table_name = match table_reference {
534 TableReference::Full { .. } | TableReference::Partial { .. } => return false,
535 TableReference::Bare { table } => table,
536 };
537
538 let dialect = DuckDbDialect {};
539 let mut tokenizer = Tokenizer::new(&dialect, table_name);
540 let Ok(tokens) = tokenizer.tokenize() else {
541 return false;
542 };
543 let Ok(tf) = Parser::new(&dialect)
544 .with_tokens(tokens)
545 .parse_table_factor()
546 else {
547 return false;
548 };
549
550 let TableFactor::Table { args, .. } = tf else {
551 return false;
552 };
553
554 args.is_some()
555}
556
#[cfg(test)]
mod tests {
    use arrow_schema::{DataType, Field, Fields, SchemaBuilder};
    use tempfile::tempdir;

    use super::*;

    /// Non-nullable `List<Struct<field: Int64>>` column; the validator rejects this shape.
    fn list_of_structs(name: &str) -> Field {
        let element = Field::new(
            "struct",
            DataType::Struct(vec![Field::new("field", DataType::Int64, false)].into()),
            false,
        );
        Field::new(name, DataType::List(Arc::new(element)), false)
    }

    /// The five scalar columns shared by the schema tests.
    fn scalar_fields() -> Vec<Field> {
        vec![
            Field::new("string", DataType::Utf8, false),
            Field::new("int", DataType::Int64, false),
            Field::new("float", DataType::Float64, false),
            Field::new("bool", DataType::Boolean, false),
            Field::new("binary", DataType::Binary, false),
        ]
    }

    /// Path of `name` inside `dir`, as the `Arc<str>` form the attachments API takes.
    fn db_file(dir: &std::path::Path, name: &str) -> Arc<str> {
        Arc::from(dir.join(name).to_str().expect("to convert path to str"))
    }

    /// Creates two database files, each with one single-row table
    /// (`test1` in the first, `test2` in the second).
    fn seed_two_databases(db1_path: &std::path::Path, db2_path: &std::path::Path) -> Result<()> {
        let conn1 = Connection::open(db1_path)?;
        conn1.execute("CREATE TABLE test1 (id INTEGER, name VARCHAR)", [])?;
        conn1.execute("INSERT INTO test1 VALUES (1, 'test1_1')", [])?;

        let conn2 = Connection::open(db2_path)?;
        conn2.execute("CREATE TABLE test2 (id INTEGER, name VARCHAR)", [])?;
        conn2.execute("INSERT INTO test2 VALUES (2, 'test2_1')", [])?;

        Ok(())
    }

    /// Fetches the first `(id, name)` row produced by `sql`.
    fn first_id_and_name(conn: &Connection, sql: &str) -> (i64, String) {
        conn.query_row(sql, [], |row| {
            Ok((
                row.get::<_, i64>(0).expect("to get i64"),
                row.get::<_, String>(1).expect("to get string"),
            ))
        })
        .expect("to get result")
    }

    #[test]
    fn test_is_table_function() {
        let cases = [
            ("table_name", false),
            ("table_name()", true),
            ("table_name(arg1, arg2)", true),
            ("read_parquet", false),
            ("read_parquet()", true),
            ("read_parquet('my_parquet_file.parquet')", true),
            ("read_csv_auto('my_csv_file.csv')", true),
        ];

        for (table_name, expected) in cases {
            let table_reference = TableReference::bare(table_name.to_string());
            assert_eq!(is_table_function(&table_reference), expected);
        }
    }

    #[test]
    fn test_field_is_unsupported() {
        let field = list_of_structs("list_struct");

        assert!(
            !DuckDbConnection::is_data_type_supported(field.data_type()),
            "list with struct should be unsupported"
        );
    }

    #[test]
    fn test_fields_are_supported() {
        for field in scalar_fields() {
            assert!(
                DuckDbConnection::is_data_type_supported(field.data_type()),
                "field should be supported"
            );
        }
    }

    #[test]
    fn test_schema_rebuild_with_supported_fields() {
        let schema = Arc::new(SchemaBuilder::from(Fields::from(scalar_fields())).finish());

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
                .expect("should rebuild schema successfully");

        assert_eq!(schema, rebuilt_schema);
    }

    #[test]
    fn test_schema_rebuild_with_unsupported_fields() {
        // Input: supported columns interleaved with two unsupported list-of-struct columns.
        let mut fields = scalar_fields();
        fields.extend([
            list_of_structs("list_struct"),
            Field::new("another_bool", DataType::Boolean, false),
            list_of_structs("another_list_struct"),
            Field::new("another_float", DataType::Float32, false),
        ]);

        // Expected: the same columns in the same order, minus the unsupported ones.
        let mut rebuilt_fields = scalar_fields();
        rebuilt_fields.extend([
            Field::new("another_bool", DataType::Boolean, false),
            Field::new("another_float", DataType::Float32, false),
        ]);

        let schema = Arc::new(SchemaBuilder::from(Fields::from(fields)).finish());
        let expected_rebuilt_schema =
            Arc::new(SchemaBuilder::from(Fields::from(rebuilt_fields)).finish());

        assert!(
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
                .is_err()
        );

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Warn)
                .expect("should rebuild schema successfully");

        assert_eq!(rebuilt_schema, expected_rebuilt_schema);

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Ignore)
                .expect("should rebuild schema successfully");

        assert_eq!(rebuilt_schema, expected_rebuilt_schema);
    }

    #[test]
    fn test_duckdb_attachments_deduplication() {
        let db1: Arc<str> = Arc::from("db1.duckdb");
        let db2: Arc<str> = Arc::from("db2.duckdb");
        let db3: Arc<str> = Arc::from("db3.duckdb");

        // db1 and db2 each appear twice.
        let attachments = vec![
            Arc::clone(&db1),
            Arc::clone(&db2),
            Arc::clone(&db1),
            Arc::clone(&db3),
            Arc::clone(&db2),
        ];

        let duckdb_attachments = DuckDBAttachments::new("main_db", &attachments);

        assert_eq!(duckdb_attachments.attachments.len(), 3);
        assert!(duckdb_attachments.attachments.contains(&db1));
        assert!(duckdb_attachments.attachments.contains(&db2));
        assert!(duckdb_attachments.attachments.contains(&db3));
    }

    #[test]
    fn test_duckdb_attachments_search_path() -> Result<()> {
        let temp_dir = tempdir()?;
        let db1 = db_file(temp_dir.path(), "db1.duckdb");
        let db2 = db_file(temp_dir.path(), "db2.duckdb");
        let db3 = db_file(temp_dir.path(), "db3.duckdb");

        for db in [&db1, &db2, &db3] {
            let conn1 = Connection::open(db.as_ref())?;
            conn1.execute("CREATE TABLE test1 (id INTEGER, name VARCHAR)", [])?;
        }

        // db1 and db2 each appear twice.
        let attachments = vec![
            Arc::clone(&db1),
            Arc::clone(&db2),
            Arc::clone(&db1),
            Arc::clone(&db3),
            Arc::clone(&db2),
        ];

        let duckdb_attachments = DuckDBAttachments::new("main", &attachments);

        let conn = Connection::open_in_memory()?;

        let search_path = duckdb_attachments.attach(&conn)?;

        assert!(search_path.starts_with("main"));
        assert!(search_path.contains("attachment_"));
        // The main database plus three distinct attachments.
        assert_eq!(search_path.split(',').count(), 4);

        Ok(())
    }

    #[test]
    fn test_duckdb_attachments_empty() -> Result<()> {
        let duckdb_attachments = DuckDBAttachments::new("main", &[]);

        assert!(duckdb_attachments.attachments.is_empty());

        let conn = Connection::open_in_memory()?;

        let search_path = duckdb_attachments.attach(&conn)?;
        assert_eq!(&*search_path, "main");

        Ok(())
    }

    #[test]
    fn test_duckdb_attachments_with_real_files() -> Result<()> {
        let temp_dir = tempdir()?;
        let db1_path = temp_dir.path().join("db1.duckdb");
        let db2_path = temp_dir.path().join("db2.duckdb");

        // The seeding connections are closed before the files are attached.
        seed_two_databases(&db1_path, &db2_path)?;

        // db1 is listed twice.
        let attachments: Vec<Arc<str>> = vec![
            Arc::from(db1_path.to_str().unwrap()),
            Arc::from(db2_path.to_str().unwrap()),
            Arc::from(db1_path.to_str().unwrap()),
        ];

        let conn = Connection::open_in_memory()?;

        let duckdb_attachments = DuckDBAttachments::new("main", &attachments);
        duckdb_attachments.attach(&conn)?;

        // Unqualified table names resolve through the search path.
        let result1 = first_id_and_name(&conn, "SELECT * FROM test1 LIMIT 1");
        let result2 = first_id_and_name(&conn, "SELECT * FROM test2 LIMIT 1");

        assert_eq!(result1, (1, "test1_1".to_string()));
        assert_eq!(result2, (2, "test2_1".to_string()));

        let search_path: String = conn
            .query_row("SELECT current_setting('search_path');", [], |row| {
                Ok(row.get::<_, String>(0).expect("to get string"))
            })
            .expect("to get search path");
        assert!(search_path.contains("main"));
        assert!(search_path.contains("attachment_"));

        duckdb_attachments.detach(&conn)?;

        Ok(())
    }

    #[test]
    fn test_duckdb_attach_multiple_times() -> Result<()> {
        let temp_dir = tempdir()?;
        let db1_path = temp_dir.path().join("db1.duckdb");
        let db2_path = temp_dir.path().join("db2.duckdb");

        seed_two_databases(&db1_path, &db2_path)?;

        let attachments: Vec<Arc<str>> = vec![
            Arc::from(db1_path.to_str().expect("to convert path top str")),
            Arc::from(db2_path.to_str().expect("to convert path top str")),
        ];

        let conn = Connection::open_in_memory()?;

        // Three independent instances attach to the same connection in turn.
        for _ in 0..3 {
            DuckDBAttachments::new("main", &attachments).attach(&conn)?;
        }

        let join_result: (i64, String, i64, String) = conn
            .query_row(
                "SELECT t1.id, t1.name, t2.id, t2.name FROM test1 t1, test2 t2",
                [],
                |row| {
                    Ok((
                        row.get::<_, i64>(0).expect("to get i64"),
                        row.get::<_, String>(1).expect("to get string"),
                        row.get::<_, i64>(2).expect("to get i64"),
                        row.get::<_, String>(3).expect("to get string"),
                    ))
                },
            )
            .expect("to get join result");

        assert_eq!(
            join_result,
            (1, "test1_1".to_string(), 2, "test2_1".to_string())
        );

        Ok(())
    }
}