datafusion_table_providers/sql/db_connection_pool/dbconnection/
duckdbconn.rs

1use std::any::Any;
2use std::collections::HashSet;
3use std::sync::Arc;
4
5use arrow::array::RecordBatch;
6use arrow_schema::{DataType, Field};
7use async_stream::stream;
8use datafusion::arrow::datatypes::SchemaRef;
9use datafusion::error::DataFusionError;
10use datafusion::execution::SendableRecordBatchStream;
11use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
12use datafusion::sql::sqlparser::ast::TableFactor;
13use datafusion::sql::sqlparser::parser::Parser;
14use datafusion::sql::sqlparser::{dialect::DuckDbDialect, tokenizer::Tokenizer};
15use datafusion::sql::TableReference;
16use duckdb::vtab::to_duckdb_type_id;
17use duckdb::ToSql;
18use duckdb::{Connection, DuckdbConnectionManager};
19use dyn_clone::DynClone;
20use rand::distr::{Alphanumeric, SampleString};
21use snafu::{prelude::*, ResultExt};
22use tokio::sync::mpsc::Sender;
23
24use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
25use crate::util::schema::SchemaValidator;
26use crate::UnsupportedTypeAction;
27
28use super::DbConnection;
29use super::Result;
30use super::SyncDbConnection;
31
32#[derive(Debug, Snafu)]
33pub enum Error {
34    #[snafu(display("DuckDB connection failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
35    DuckDBConnectionError { source: duckdb::Error },
36
37    #[snafu(display("Query execution failed.\n{source}\nFor details, refer to the DuckDB manual: https://duckdb.org/docs/"))]
38    DuckDBQueryError { source: duckdb::Error },
39
40    #[snafu(display(
41        "An unexpected error occurred.\n{message}\nVerify the configuration and try again."
42    ))]
43    ChannelError { message: String },
44
45    #[snafu(display(
46        "Unable to attach DuckDB database {path}.\n{source}\nEnsure the DuckDB file path is valid."
47    ))]
48    UnableToAttachDatabase {
49        path: Arc<str>,
50        source: std::io::Error,
51    },
52}
53
54pub trait DuckDBSyncParameter: ToSql + Sync + Send + DynClone {
55    fn as_input_parameter(&self) -> &dyn ToSql;
56}
57
58impl<T: ToSql + Sync + Send + DynClone> DuckDBSyncParameter for T {
59    fn as_input_parameter(&self) -> &dyn ToSql {
60        self
61    }
62}
63dyn_clone::clone_trait_object!(DuckDBSyncParameter);
64pub type DuckDBParameter = Box<dyn DuckDBSyncParameter>;
65
/// Describes a set of DuckDB database files to attach to a connection.
#[derive(Debug)]
pub struct DuckDBAttachments {
    /// De-duplicated paths of the database files to attach.
    attachments: HashSet<Arc<str>>,
    /// Random token embedded in the generated attachment names.
    random_id: String,
    /// Name of the main database; always placed first in the search path.
    main_db: String,
}
72
73impl DuckDBAttachments {
74    /// Creates a new instance of a `DuckDBAttachments`, which instructs DuckDB connections to attach other DuckDB databases for queries.
75    #[must_use]
76    pub fn new(main_db: &str, attachments: &[Arc<str>]) -> Self {
77        let random_id = Alphanumeric.sample_string(&mut rand::rng(), 8);
78        let attachments: HashSet<Arc<str>> = attachments.iter().cloned().collect();
79        Self {
80            attachments,
81            random_id,
82            main_db: main_db.to_string(),
83        }
84    }
85
86    /// Returns the search path for the given database and attachments.
87    /// The given database needs to be included separately, as search path by default do not include the main database.
88    /// The `attachments` parameter represents full attachment names, e.g., ["attachment_zCVN0zYJ_0", ...]
89    #[must_use]
90    fn get_search_path<'a>(id: &str, attachments: impl IntoIterator<Item = &'a str>) -> Arc<str> {
91        let mut path = String::from(id);
92
93        for attachment in attachments {
94            path.push(',');
95            path.push_str(attachment);
96        }
97
98        Arc::from(path)
99    }
100
101    /// Sets the search path for the given connection.
102    ///
103    /// The `attachments` parameter represents full attachment names, e.g., ["attachment_zCVN0zYJ_0", ...]
104    /// # Errors
105    ///
106    /// Returns an error if the search path cannot be set or the connection fails.
107    /// Returns search path if successful.
108    pub fn set_search_path<'a>(
109        &self,
110        conn: &Connection,
111        attachments: impl IntoIterator<Item = &'a str>,
112    ) -> Result<Arc<str>> {
113        let search_path = Self::get_search_path(&self.main_db, attachments);
114
115        tracing::trace!("Setting search_path to {search_path}");
116
117        conn.execute(&format!("SET search_path ='{}'", search_path), [])
118            .context(DuckDBConnectionSnafu)?;
119        Ok(search_path)
120    }
121
122    /// Resets the search path for the given connection to default.
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if the search path cannot be set or the connection fails.
127    pub fn reset_search_path(&self, conn: &Connection) -> Result<()> {
128        conn.execute("RESET search_path", [])
129            .context(DuckDBConnectionSnafu)?;
130        Ok(())
131    }
132
133    /// Attaches the databases to the given connection and sets the search path for the newly attached databases.
134    /// If connection already contains attachments, it will skip the attachments override (including search_path).
135    ///
136    /// # Errors
137    ///
138    /// Returns an error if a specific attachment is missing, cannot be attached, search path cannot be set or the connection fails.
139    /// Returns search path if successful.
140    pub fn attach(&self, conn: &Connection) -> Result<Arc<str>> {
141        // Check if attachments already exist; skip attachments override in this case as it requires changing the search_path
142        let mut stmt = conn
143            .prepare("PRAGMA database_list;")
144            .context(DuckDBConnectionSnafu)?;
145        let mut rows = stmt.query([]).context(DuckDBConnectionSnafu)?;
146
147        let mut existing_attachments = std::collections::HashMap::new();
148        while let Some(row) = rows.next()? {
149            let db_name: String = row.get(1)?;
150            let db_path: Option<String> = row.get(2)?;
151            if db_name.starts_with("attachment_") {
152                // attachment always has a path so it is safe to use unwrap_or_default
153                existing_attachments.insert(db_path.unwrap_or_default(), db_name);
154            }
155        }
156
157        // Check if the connection already contains the desired attachments
158        if !existing_attachments.is_empty() {
159            tracing::trace!(
160                "Attachments {:?} creation skipped as connection contains existing attachments: {existing_attachments:?}",
161                self.attachments
162            );
163            for db in &self.attachments {
164                if !existing_attachments.contains_key(db.as_ref()) {
165                    tracing::warn!("{db} not found among existing attachments");
166                }
167            }
168            // The connection can have attachments but not the search_path, so we must set it based on the existing attachment names
169            return self.set_search_path(conn, existing_attachments.values().map(|s| s.as_str()));
170        }
171
172        let mut created_attachments = Vec::new();
173
174        for (i, db) in self.attachments.iter().enumerate() {
175            // check the db file exists
176            std::fs::metadata(db.as_ref()).context(UnableToAttachDatabaseSnafu {
177                path: Arc::clone(db),
178            })?;
179            let attachment_name = Self::get_attachment_name(&self.random_id, i);
180            let sql = format!("ATTACH IF NOT EXISTS '{db}' AS {attachment_name} (READ_ONLY)");
181            tracing::trace!("Attaching {db} using: {sql}");
182            conn.execute(&sql, []).context(DuckDBConnectionSnafu)?;
183            created_attachments.push(attachment_name);
184        }
185
186        self.set_search_path(conn, created_attachments.iter().map(|s| s.as_str()))
187    }
188
189    /// Detaches the databases from the given connection and resets the search path to default.
190    ///
191    /// # Errors
192    ///
193    /// Returns an error if an attachment cannot be detached, search path cannot be set or the connection fails.
194    pub fn detach(&self, conn: &Connection) -> Result<()> {
195        for (i, _) in self.attachments.iter().enumerate() {
196            conn.execute(
197                &format!(
198                    "DETACH DATABASE IF EXISTS {}",
199                    Self::get_attachment_name(&self.random_id, i)
200                ),
201                [],
202            )
203            .context(DuckDBConnectionSnafu)?;
204        }
205
206        self.reset_search_path(conn)?;
207        Ok(())
208    }
209
210    #[must_use]
211    fn get_attachment_name(random_id: &str, index: usize) -> String {
212        format!("attachment_{random_id}_{index}")
213    }
214}
215
216pub struct DuckDbConnection {
217    pub conn: r2d2::PooledConnection<DuckdbConnectionManager>,
218    attachments: Option<Arc<DuckDBAttachments>>,
219    unsupported_type_action: UnsupportedTypeAction,
220    connection_setup_queries: Vec<Arc<str>>,
221}
222
223impl SchemaValidator for DuckDbConnection {
224    type Error = super::Error;
225
226    fn is_data_type_supported(data_type: &DataType) -> bool {
227        match data_type {
228            DataType::List(inner_field)
229            | DataType::FixedSizeList(inner_field, _)
230            | DataType::LargeList(inner_field) => {
231                match inner_field.data_type() {
232                    dt if dt.is_primitive() => true,
233                    DataType::Utf8
234                    | DataType::Binary
235                    | DataType::Utf8View
236                    | DataType::BinaryView
237                    | DataType::Boolean => true,
238                    _ => false, // nested lists don't support anything else yet
239                }
240            }
241            DataType::Struct(inner_fields) => inner_fields
242                .iter()
243                .all(|field| Self::is_data_type_supported(field.data_type())),
244            _ => true,
245        }
246    }
247
248    fn is_field_supported(field: &Arc<Field>) -> bool {
249        let duckdb_type_id = to_duckdb_type_id(field.data_type());
250        Self::is_data_type_supported(field.data_type()) && duckdb_type_id.is_ok()
251    }
252
253    fn unsupported_type_error(data_type: &DataType, field_name: &str) -> Self::Error {
254        super::Error::UnsupportedDataType {
255            data_type: data_type.to_string(),
256            field_name: field_name.to_string(),
257        }
258    }
259}
260
261impl DuckDbConnection {
262    pub fn get_underlying_conn_mut(
263        &mut self,
264    ) -> &mut r2d2::PooledConnection<DuckdbConnectionManager> {
265        &mut self.conn
266    }
267
268    #[must_use]
269    pub fn with_unsupported_type_action(
270        mut self,
271        unsupported_type_action: UnsupportedTypeAction,
272    ) -> Self {
273        self.unsupported_type_action = unsupported_type_action;
274        self
275    }
276
277    #[must_use]
278    pub fn with_attachments(mut self, attachments: Option<Arc<DuckDBAttachments>>) -> Self {
279        self.attachments = attachments;
280        self
281    }
282
283    #[must_use]
284    pub fn with_connection_setup_queries(mut self, queries: Vec<Arc<str>>) -> Self {
285        self.connection_setup_queries = queries;
286        self
287    }
288
289    /// Passthrough if Option is Some for `DuckDBAttachments::attach`
290    ///
291    /// # Errors
292    ///
293    /// See `DuckDBAttachments::attach` for more information.
294    pub fn attach(conn: &Connection, attachments: &Option<Arc<DuckDBAttachments>>) -> Result<()> {
295        if let Some(attachments) = attachments {
296            attachments.attach(conn)?;
297        }
298        Ok(())
299    }
300
301    /// Passthrough if Option is Some for `DuckDBAttachments::detach`
302    ///
303    /// # Errors
304    ///
305    /// See `DuckDBAttachments::detach` for more information.
306    pub fn detach(conn: &Connection, attachments: &Option<Arc<DuckDBAttachments>>) -> Result<()> {
307        if let Some(attachments) = attachments {
308            attachments.detach(conn)?;
309        }
310        Ok(())
311    }
312
313    fn apply_connection_setup_queries(&self, conn: &Connection) -> Result<()> {
314        for query in self.connection_setup_queries.iter() {
315            conn.execute(query, []).context(DuckDBConnectionSnafu)?;
316        }
317        Ok(())
318    }
319}
320
321impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
322    for DuckDbConnection
323{
324    fn as_any(&self) -> &dyn Any {
325        self
326    }
327
328    fn as_any_mut(&mut self) -> &mut dyn Any {
329        self
330    }
331
332    fn as_sync(
333        &self,
334    ) -> Option<
335        &dyn SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
336    > {
337        Some(self)
338    }
339}
340
341impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
342    for DuckDbConnection
343{
344    fn new(conn: r2d2::PooledConnection<DuckdbConnectionManager>) -> Self {
345        DuckDbConnection {
346            conn,
347            attachments: None,
348            unsupported_type_action: UnsupportedTypeAction::default(),
349            connection_setup_queries: Vec::new(),
350        }
351    }
352
353    fn tables(&self, schema: &str) -> Result<Vec<String>, super::Error> {
354        let sql = "SELECT table_name FROM information_schema.tables \
355                  WHERE table_schema = ? AND table_type = 'BASE TABLE'";
356
357        let mut stmt = self
358            .conn
359            .prepare(sql)
360            .boxed()
361            .context(super::UnableToGetTablesSnafu)?;
362        let mut rows = stmt
363            .query([schema])
364            .boxed()
365            .context(super::UnableToGetTablesSnafu)?;
366        let mut tables = vec![];
367
368        while let Some(row) = rows.next().boxed().context(super::UnableToGetTablesSnafu)? {
369            tables.push(row.get(0).boxed().context(super::UnableToGetTablesSnafu)?);
370        }
371
372        Ok(tables)
373    }
374
375    fn schemas(&self) -> Result<Vec<String>, super::Error> {
376        let sql = "SELECT DISTINCT schema_name FROM information_schema.schemata \
377                  WHERE schema_name NOT IN ('information_schema', 'pg_catalog')";
378
379        let mut stmt = self
380            .conn
381            .prepare(sql)
382            .boxed()
383            .context(super::UnableToGetSchemasSnafu)?;
384        let mut rows = stmt
385            .query([])
386            .boxed()
387            .context(super::UnableToGetSchemasSnafu)?;
388        let mut schemas = vec![];
389
390        while let Some(row) = rows
391            .next()
392            .boxed()
393            .context(super::UnableToGetSchemasSnafu)?
394        {
395            schemas.push(row.get(0).boxed().context(super::UnableToGetSchemasSnafu)?);
396        }
397
398        Ok(schemas)
399    }
400
401    fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, super::Error> {
402        let table_str = if is_table_function(table_reference) {
403            table_reference.to_string()
404        } else {
405            table_reference.to_quoted_string()
406        };
407        let mut stmt = self
408            .conn
409            .prepare(&format!("SELECT * FROM {table_str} LIMIT 0"))
410            .boxed()
411            .context(super::UnableToGetSchemaSnafu)?;
412
413        let result: duckdb::Arrow<'_> = stmt
414            .query_arrow([])
415            .boxed()
416            .context(super::UnableToGetSchemaSnafu)?;
417
418        Self::handle_unsupported_schema(&result.get_schema(), self.unsupported_type_action)
419    }
420
421    fn query_arrow(
422        &self,
423        sql: &str,
424        params: &[DuckDBParameter],
425        _projected_schema: Option<SchemaRef>,
426    ) -> Result<SendableRecordBatchStream> {
427        let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
428
429        let conn = self.conn.try_clone()?;
430        Self::attach(&conn, &self.attachments)?;
431        self.apply_connection_setup_queries(&conn)?;
432
433        let fetch_schema_sql =
434            format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0");
435        let mut stmt = conn
436            .prepare(&fetch_schema_sql)
437            .boxed()
438            .context(super::UnableToGetSchemaSnafu)?;
439
440        let result: duckdb::Arrow<'_> = stmt
441            .query_arrow([])
442            .boxed()
443            .context(super::UnableToGetSchemaSnafu)?;
444
445        let schema = result.get_schema();
446
447        let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
448
449        let sql = sql.to_string();
450
451        let cloned_schema = schema.clone();
452
453        let create_stream = || -> Result<SendableRecordBatchStream> {
454            let join_handle = tokio::task::spawn_blocking(move || {
455                let mut stmt = conn.prepare(&sql).context(DuckDBQuerySnafu)?;
456                let params: &[&dyn ToSql] = &params
457                    .iter()
458                    .map(|f| f.as_input_parameter())
459                    .collect::<Vec<_>>();
460                let result: duckdb::ArrowStream<'_> = stmt
461                    .stream_arrow(params, cloned_schema)
462                    .context(DuckDBQuerySnafu)?;
463                for i in result {
464                    blocking_channel_send(&batch_tx, i)?;
465                }
466                Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
467            });
468
469            let output_stream = stream! {
470                while let Some(batch) = batch_rx.recv().await {
471                    yield Ok(batch);
472                }
473
474                match join_handle.await {
475                    Ok(Err(task_error)) => {
476                        yield Err(DataFusionError::Execution(format!(
477                            "Failed to execute DuckDB query: {task_error}"
478                        )))
479                    },
480                    Err(join_error) => {
481                        yield Err(DataFusionError::Execution(format!(
482                            "Failed to execute DuckDB query: {join_error}"
483                        )))
484                    },
485                    _ => {}
486                }
487            };
488
489            Ok(Box::pin(RecordBatchStreamAdapter::new(
490                schema,
491                output_stream,
492            )))
493        };
494
495        run_sync_with_tokio(create_stream)
496    }
497
498    fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result<u64> {
499        let params: &[&dyn ToSql] = &params
500            .iter()
501            .map(|f| f.as_input_parameter())
502            .collect::<Vec<_>>();
503
504        let rows_modified = self.conn.execute(sql, params).context(DuckDBQuerySnafu)?;
505        Ok(rows_modified as u64)
506    }
507}
508
509fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
510    match channel.blocking_send(item) {
511        Ok(()) => Ok(()),
512        Err(e) => Err(Error::ChannelError {
513            message: format!("{e}"),
514        }
515        .into()),
516    }
517}
518
519#[must_use]
520pub fn flatten_table_function_name(table_reference: &TableReference) -> String {
521    let table_name = table_reference.table();
522    let filtered_name: String = table_name
523        .chars()
524        .filter(|c| c.is_alphanumeric() || *c == '(')
525        .collect();
526    let result = filtered_name.replace('(', "_");
527
528    format!("{result}__view")
529}
530
531#[must_use]
532pub fn is_table_function(table_reference: &TableReference) -> bool {
533    let table_name = match table_reference {
534        TableReference::Full { .. } | TableReference::Partial { .. } => return false,
535        TableReference::Bare { table } => table,
536    };
537
538    let dialect = DuckDbDialect {};
539    let mut tokenizer = Tokenizer::new(&dialect, table_name);
540    let Ok(tokens) = tokenizer.tokenize() else {
541        return false;
542    };
543    let Ok(tf) = Parser::new(&dialect)
544        .with_tokens(tokens)
545        .parse_table_factor()
546    else {
547        return false;
548    };
549
550    let TableFactor::Table { args, .. } = tf else {
551        return false;
552    };
553
554    args.is_some()
555}
556
#[cfg(test)]
mod tests {
    use std::path::Path;

    use arrow_schema::{DataType, Field, Fields, SchemaBuilder};
    use tempfile::tempdir;

    use super::*;

    /// Converts a filesystem path into the `Arc<str>` form `DuckDBAttachments` takes.
    fn path_to_arc_str(path: &Path) -> Arc<str> {
        Arc::from(path.to_str().expect("to convert path to str"))
    }

    /// Creates a DuckDB file at `path` with a table `table (id INTEGER, name VARCHAR)`
    /// and, when `row` is given, inserts that single row.
    fn create_test_db(path: &Path, table: &str, row: Option<(i64, &str)>) -> Result<()> {
        let conn = Connection::open(path)?;
        conn.execute(
            &format!("CREATE TABLE {table} (id INTEGER, name VARCHAR)"),
            [],
        )?;
        if let Some((id, name)) = row {
            conn.execute(&format!("INSERT INTO {table} VALUES ({id}, '{name}')"), [])?;
        }
        Ok(())
    }

    /// Runs `sql` and returns the first row as an `(id, name)` pair.
    fn query_id_and_name(conn: &Connection, sql: &str) -> (i64, String) {
        conn.query_row(sql, [], |row| {
            Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
        })
        .expect("to get result")
    }

    /// The scalar fields shared by the schema tests; all of them are supported.
    fn basic_fields() -> Vec<Field> {
        vec![
            Field::new("string", DataType::Utf8, false),
            Field::new("int", DataType::Int64, false),
            Field::new("float", DataType::Float64, false),
            Field::new("bool", DataType::Boolean, false),
            Field::new("binary", DataType::Binary, false),
        ]
    }

    /// A list-of-struct field, which is not supported.
    fn list_of_struct_field(name: &str) -> Field {
        Field::new(
            name,
            DataType::List(Arc::new(Field::new(
                "struct",
                DataType::Struct(vec![Field::new("field", DataType::Int64, false)].into()),
                false,
            ))),
            false,
        )
    }

    #[test]
    fn test_is_table_function() {
        let tests = vec![
            ("table_name", false),
            ("table_name()", true),
            ("table_name(arg1, arg2)", true),
            ("read_parquet", false),
            ("read_parquet()", true),
            ("read_parquet('my_parquet_file.parquet')", true),
            ("read_csv_auto('my_csv_file.csv')", true),
        ];

        for (table_name, expected) in tests {
            let table_reference = TableReference::bare(table_name.to_string());
            assert_eq!(
                is_table_function(&table_reference),
                expected,
                "unexpected result for {table_name}"
            );
        }
    }

    #[test]
    fn test_field_is_unsupported() {
        // A list with a struct is not supported
        let field = list_of_struct_field("list_struct");

        assert!(
            !DuckDbConnection::is_data_type_supported(field.data_type()),
            "list with struct should be unsupported"
        );
    }

    #[test]
    fn test_fields_are_supported() {
        // test that the usual field types are supported, string, numbers, etc
        for field in basic_fields() {
            assert!(
                DuckDbConnection::is_data_type_supported(field.data_type()),
                "field should be supported"
            );
        }
    }

    #[test]
    fn test_schema_rebuild_with_supported_fields() {
        let schema = Arc::new(SchemaBuilder::from(Fields::from(basic_fields())).finish());

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
                .expect("should rebuild schema successfully");

        assert_eq!(schema, rebuilt_schema);
    }

    #[test]
    fn test_schema_rebuild_with_unsupported_fields() {
        let mut fields = basic_fields();
        fields.extend([
            list_of_struct_field("list_struct"),
            Field::new("another_bool", DataType::Boolean, false),
            list_of_struct_field("another_list_struct"),
            Field::new("another_float", DataType::Float32, false),
        ]);

        // this also tests that ordering is preserved when rebuilding the schema with removed fields
        let mut rebuilt_fields = basic_fields();
        rebuilt_fields.extend([
            Field::new("another_bool", DataType::Boolean, false),
            Field::new("another_float", DataType::Float32, false),
        ]);

        let schema = Arc::new(SchemaBuilder::from(Fields::from(fields)).finish());
        let expected_rebuilt_schema =
            Arc::new(SchemaBuilder::from(Fields::from(rebuilt_fields)).finish());

        assert!(
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Error)
                .is_err()
        );

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Warn)
                .expect("should rebuild schema successfully");

        assert_eq!(rebuilt_schema, expected_rebuilt_schema);

        let rebuilt_schema =
            DuckDbConnection::handle_unsupported_schema(&schema, UnsupportedTypeAction::Ignore)
                .expect("should rebuild schema successfully");

        assert_eq!(rebuilt_schema, expected_rebuilt_schema);
    }

    #[test]
    fn test_duckdb_attachments_deduplication() {
        let db1: Arc<str> = Arc::from("db1.duckdb");
        let db2: Arc<str> = Arc::from("db2.duckdb");
        let db3: Arc<str> = Arc::from("db3.duckdb");

        // Create attachments with duplicates
        let attachments = vec![
            Arc::clone(&db1),
            Arc::clone(&db2),
            Arc::clone(&db1), // duplicate of db1
            Arc::clone(&db3),
            Arc::clone(&db2), // duplicate of db2
        ];

        let duckdb_attachments = DuckDBAttachments::new("main_db", &attachments);

        // Verify that duplicates are removed
        assert_eq!(duckdb_attachments.attachments.len(), 3);
        assert!(duckdb_attachments.attachments.contains(&db1));
        assert!(duckdb_attachments.attachments.contains(&db2));
        assert!(duckdb_attachments.attachments.contains(&db3));
    }

    #[test]
    fn test_duckdb_attachments_search_path() -> Result<()> {
        let temp_dir = tempdir()?;
        let db_files: Vec<_> = ["db1.duckdb", "db2.duckdb", "db3.duckdb"]
            .iter()
            .map(|file_name| temp_dir.path().join(file_name))
            .collect();

        for db_file in &db_files {
            create_test_db(db_file, "test1", None)?;
        }

        let db1 = path_to_arc_str(&db_files[0]);
        let db2 = path_to_arc_str(&db_files[1]);
        let db3 = path_to_arc_str(&db_files[2]);

        // Create attachments with duplicates
        let attachments = vec![
            Arc::clone(&db1),
            Arc::clone(&db2),
            Arc::clone(&db1), // duplicate of db1
            Arc::clone(&db3),
            Arc::clone(&db2), // duplicate of db2
        ];

        let duckdb_attachments = DuckDBAttachments::new("main", &attachments);

        let conn = Connection::open_in_memory()?;

        let search_path = duckdb_attachments.attach(&conn)?;

        // Verify that the search path contains the main database and unique attachments
        assert!(search_path.starts_with("main"));
        assert!(search_path.contains("attachment_"));
        assert_eq!(search_path.split(',').count(), 4); // main + 3 unique attachments

        Ok(())
    }

    #[test]
    fn test_duckdb_attachments_empty() -> Result<()> {
        let duckdb_attachments = DuckDBAttachments::new("main", &[]);

        // Verify empty attachments
        assert!(duckdb_attachments.attachments.is_empty());

        // Verify search path only contains main database
        let conn = Connection::open_in_memory()?;

        let search_path = duckdb_attachments.attach(&conn)?;
        assert_eq!(&*search_path, "main");

        Ok(())
    }

    #[test]
    fn test_duckdb_attachments_with_real_files() -> Result<()> {
        // Create a temporary directory for our test files
        let temp_dir = tempdir()?;
        let db1_path = temp_dir.path().join("db1.duckdb");
        let db2_path = temp_dir.path().join("db2.duckdb");

        // Create two test databases with some data
        create_test_db(&db1_path, "test1", Some((1, "test1_1")))?;
        create_test_db(&db2_path, "test2", Some((2, "test2_1")))?;

        // Create attachments with duplicates
        let attachments = vec![
            path_to_arc_str(&db1_path),
            path_to_arc_str(&db2_path),
            path_to_arc_str(&db1_path), // duplicate of db1
        ];

        // Create a new in-memory DuckDB connection
        let conn = Connection::open_in_memory()?;

        // Create DuckDBAttachments and attach the databases
        let duckdb_attachments = DuckDBAttachments::new("main", &attachments);
        duckdb_attachments.attach(&conn)?;

        // Verify we can query data from both databases
        let result1 = query_id_and_name(&conn, "SELECT * FROM test1 LIMIT 1");
        let result2 = query_id_and_name(&conn, "SELECT * FROM test2 LIMIT 1");

        assert_eq!(result1, (1, "test1_1".to_string()));
        assert_eq!(result2, (2, "test2_1".to_string()));

        // Verify the search path
        let search_path: String = conn
            .query_row("SELECT current_setting('search_path');", [], |row| {
                row.get::<_, String>(0)
            })
            .expect("to get search path");
        assert!(search_path.contains("main"));
        assert!(search_path.contains("attachment_"));

        // Clean up
        duckdb_attachments.detach(&conn)?;
        Ok(())
    }

    #[test]
    fn test_duckdb_attach_multiple_times() -> Result<()> {
        // Create a temporary directory for our test files
        let temp_dir = tempdir()?;
        let db1_path = temp_dir.path().join("db1.duckdb");
        let db2_path = temp_dir.path().join("db2.duckdb");

        // Create two test databases with some data
        create_test_db(&db1_path, "test1", Some((1, "test1_1")))?;
        create_test_db(&db2_path, "test2", Some((2, "test2_1")))?;

        let attachments = vec![path_to_arc_str(&db1_path), path_to_arc_str(&db2_path)];

        let conn = Connection::open_in_memory()?;

        // Simulate attaching to the same connection multiple times
        for _ in 0..3 {
            DuckDBAttachments::new("main", &attachments).attach(&conn)?;
        }

        let join_result: (i64, String, i64, String) = conn
            .query_row(
                "SELECT t1.id, t1.name, t2.id, t2.name FROM test1 t1, test2 t2",
                [],
                |row| {
                    Ok((
                        row.get::<_, i64>(0)?,
                        row.get::<_, String>(1)?,
                        row.get::<_, i64>(2)?,
                        row.get::<_, String>(3)?,
                    ))
                },
            )
            .expect("to get join result");

        assert_eq!(
            join_result,
            (1, "test1_1".to_string(), 2, "test2_1".to_string())
        );

        Ok(())
    }
}