datafusion_table_providers/sql/db_connection_pool/
dbconnection.rs

1use std::{any::Any, sync::Arc};
2
3use datafusion::{
4    arrow::datatypes::SchemaRef, execution::SendableRecordBatchStream, sql::TableReference,
5};
6use snafu::prelude::*;
7
8#[cfg(feature = "duckdb")]
9pub mod duckdbconn;
10#[cfg(feature = "mysql")]
11pub mod mysqlconn;
12#[cfg(feature = "odbc")]
13pub mod odbcconn;
14#[cfg(feature = "postgres")]
15pub mod postgresconn;
16#[cfg(feature = "sqlite")]
17pub mod sqliteconn;
18
/// Type-erased, boxed error that is `Send + Sync`, used wherever the concrete
/// error type of a database driver is not known to this module.
pub type GenericError = Box<dyn std::error::Error + Send + Sync>;
/// Module-local `Result` alias whose error type defaults to [`GenericError`]
/// when the second type argument is omitted.
type Result<T, E = GenericError> = core::result::Result<T, E>;
21
22#[derive(Debug, Snafu)]
23pub enum Error {
24    #[snafu(display("Unable to downcast connection"))]
25    UnableToDowncastConnection {},
26
27    #[snafu(display("Unable to get schema: {source}"))]
28    UnableToGetSchema { source: GenericError },
29
30    #[snafu(display("Unable to query arrow: {source}"))]
31    UnableToQueryArrow { source: GenericError },
32
33    #[snafu(display("Table {table_name} not found. Ensure the table name is correctly spelled."))]
34    UndefinedTable {
35        table_name: String,
36        source: GenericError,
37    },
38
39    #[snafu(display("Unable to get schemas: {source}"))]
40    UnableToGetSchemas { source: GenericError },
41
42    #[snafu(display("Unable to get tables: {source}"))]
43    UnableToGetTables { source: GenericError },
44}
45
46pub trait SyncDbConnection<T, P>: DbConnection<T, P> {
47    fn new(conn: T) -> Self
48    where
49        Self: Sized;
50
51    fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
52
53    fn schemas(&self) -> Result<Vec<String>, Error>;
54
55    /// Get the schema for a table reference.
56    ///
57    /// # Arguments
58    ///
59    /// * `table_reference` - The table reference.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if the schema cannot be retrieved.
64    fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
65
66    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
67    ///
68    /// # Arguments
69    ///
70    /// * `sql` - The SQL statement.
71    /// * `params` - The parameters for the SQL statement.
72    /// * `projected_schema` - The Projected schema for the query.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the query fails.
77    fn query_arrow(
78        &self,
79        sql: &str,
80        params: &[P],
81        projected_schema: Option<SchemaRef>,
82    ) -> Result<SendableRecordBatchStream>;
83
84    /// Execute the given SQL statement with parameters, returning the number of affected rows.
85    ///
86    /// # Arguments
87    ///
88    /// * `sql` - The SQL statement.
89    /// * `params` - The parameters for the SQL statement.
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the execution fails.
94    fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
95}
96
97#[async_trait::async_trait]
98pub trait AsyncDbConnection<T, P>: DbConnection<T, P> + Sync {
99    fn new(conn: T) -> Self
100    where
101        Self: Sized;
102
103    async fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
104
105    async fn schemas(&self) -> Result<Vec<String>, Error>;
106
107    /// Get the schema for a table reference.
108    ///
109    /// # Arguments
110    ///
111    /// * `table_reference` - The table reference.
112    async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
113
114    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
115    ///
116    /// # Arguments
117    ///
118    /// * `sql` - The SQL statement.
119    /// * `params` - The parameters for the SQL statement.
120    /// * `projected_schema` - The Projected schema for the query.
121    ///
122    /// # Errors
123    ///
124    /// Returns an error if the query fails.
125    async fn query_arrow(
126        &self,
127        sql: &str,
128        params: &[P],
129        projected_schema: Option<SchemaRef>,
130    ) -> Result<SendableRecordBatchStream>;
131
132    /// Execute the given SQL statement with parameters, returning the number of affected rows.
133    ///
134    /// # Arguments
135    ///
136    /// * `sql` - The SQL statement.
137    /// * `params` - The parameters for the SQL statement.
138    async fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
139}
140
141pub trait DbConnection<T, P>: Send {
142    fn as_any(&self) -> &dyn Any;
143    fn as_any_mut(&mut self) -> &mut dyn Any;
144
145    fn as_sync(&self) -> Option<&dyn SyncDbConnection<T, P>> {
146        None
147    }
148    fn as_async(&self) -> Option<&dyn AsyncDbConnection<T, P>> {
149        None
150    }
151}
152
153pub async fn get_tables<T, P>(
154    conn: Box<dyn DbConnection<T, P>>,
155    schema: &str,
156) -> Result<Vec<String>, Error> {
157    let schema = if let Some(conn) = conn.as_sync() {
158        conn.tables(schema)?
159    } else if let Some(conn) = conn.as_async() {
160        conn.tables(schema).await?
161    } else {
162        return Err(Error::UnableToDowncastConnection {});
163    };
164    Ok(schema)
165}
166
167/// Get the schemas for the database.
168///
169/// # Errors
170///
171/// Returns an error if the schemas cannot be retrieved.
172pub async fn get_schemas<T, P>(conn: Box<dyn DbConnection<T, P>>) -> Result<Vec<String>, Error> {
173    let schema = if let Some(conn) = conn.as_sync() {
174        conn.schemas()?
175    } else if let Some(conn) = conn.as_async() {
176        conn.schemas().await?
177    } else {
178        return Err(Error::UnableToDowncastConnection {});
179    };
180    Ok(schema)
181}
182
183/// Get the schema for a table reference.
184///
185/// # Arguments
186///
187/// * `conn` - The database connection.
188/// * `table_reference` - The table reference.
189///
190/// # Errors
191///
192/// Returns an error if the schema cannot be retrieved.
193pub async fn get_schema<T, P>(
194    conn: Box<dyn DbConnection<T, P>>,
195    table_reference: &datafusion::sql::TableReference,
196) -> Result<Arc<datafusion::arrow::datatypes::Schema>, Error> {
197    let schema = if let Some(conn) = conn.as_sync() {
198        conn.get_schema(table_reference)?
199    } else if let Some(conn) = conn.as_async() {
200        conn.get_schema(table_reference).await?
201    } else {
202        return Err(Error::UnableToDowncastConnection {});
203    };
204    Ok(schema)
205}
206
207/// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
208///
209/// # Arguments
210///
211/// * `conn` - The database connection.
212/// * `sql` - The SQL statement.
213///
214/// # Errors
215///
216/// Returns an error if the query fails.
217pub async fn query_arrow<T, P>(
218    conn: Box<dyn DbConnection<T, P>>,
219    sql: String,
220    projected_schema: Option<SchemaRef>,
221) -> Result<SendableRecordBatchStream, Error> {
222    if let Some(conn) = conn.as_sync() {
223        conn.query_arrow(&sql, &[], projected_schema)
224            .context(UnableToQueryArrowSnafu {})
225    } else if let Some(conn) = conn.as_async() {
226        conn.query_arrow(&sql, &[], projected_schema)
227            .await
228            .context(UnableToQueryArrowSnafu {})
229    } else {
230        return Err(Error::UnableToDowncastConnection {});
231    }
232}