datafusion_table_providers/sql/db_connection_pool/dbconnection.rs

1use std::{any::Any, sync::Arc};
2
3use datafusion::{
4    arrow::datatypes::SchemaRef, execution::SendableRecordBatchStream, sql::TableReference,
5};
6use snafu::prelude::*;
7
8#[cfg(feature = "duckdb")]
9pub mod duckdbconn;
10#[cfg(feature = "mysql")]
11pub mod mysqlconn;
12#[cfg(feature = "odbc")]
13pub mod odbcconn;
14#[cfg(feature = "postgres")]
15pub mod postgresconn;
16#[cfg(feature = "sqlite")]
17pub mod sqliteconn;
18
/// Boxed, thread-safe error type used as the `source` of most [`Error`] variants
/// and as the default error type of the trait methods in this module.
pub type GenericError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Module-local shorthand for [`std::result::Result`] whose error type defaults to
/// [`GenericError`] when only the success type is given.
type Result<T, E = GenericError> = std::result::Result<T, E>;
21
22#[derive(Debug, Snafu)]
23pub enum Error {
24    #[snafu(display("Unable to downcast connection.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
25    UnableToDowncastConnection {},
26
27    #[snafu(display("{source}"))]
28    UnableToGetSchema { source: GenericError },
29
30    #[snafu(display("The field '{field_name}' has an unsupported data type: {data_type}."))]
31    UnsupportedDataType {
32        data_type: String,
33        field_name: String,
34    },
35
36    #[snafu(display("Failed to execute query.\n{source}"))]
37    UnableToQueryArrow { source: GenericError },
38
39    #[snafu(display(
40        "Table '{table_name}' not found. Ensure the table name is correctly spelled."
41    ))]
42    UndefinedTable {
43        table_name: String,
44        source: GenericError,
45    },
46
47    #[snafu(display("Unable to get schemas: {source}"))]
48    UnableToGetSchemas { source: GenericError },
49
50    #[snafu(display("Unable to get tables: {source}"))]
51    UnableToGetTables { source: GenericError },
52}
53
54pub trait SyncDbConnection<T, P>: DbConnection<T, P> {
55    fn new(conn: T) -> Self
56    where
57        Self: Sized;
58
59    fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
60
61    fn schemas(&self) -> Result<Vec<String>, Error>;
62
63    /// Get the schema for a table reference.
64    ///
65    /// # Arguments
66    ///
67    /// * `table_reference` - The table reference.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the schema cannot be retrieved.
72    fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
73
74    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
75    ///
76    /// # Arguments
77    ///
78    /// * `sql` - The SQL statement.
79    /// * `params` - The parameters for the SQL statement.
80    /// * `projected_schema` - The Projected schema for the query.
81    ///
82    /// # Errors
83    ///
84    /// Returns an error if the query fails.
85    fn query_arrow(
86        &self,
87        sql: &str,
88        params: &[P],
89        projected_schema: Option<SchemaRef>,
90    ) -> Result<SendableRecordBatchStream>;
91
92    /// Execute the given SQL statement with parameters, returning the number of affected rows.
93    ///
94    /// # Arguments
95    ///
96    /// * `sql` - The SQL statement.
97    /// * `params` - The parameters for the SQL statement.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the execution fails.
102    fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
103}
104
105#[async_trait::async_trait]
106pub trait AsyncDbConnection<T, P>: DbConnection<T, P> + Sync {
107    fn new(conn: T) -> Self
108    where
109        Self: Sized;
110
111    async fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
112
113    async fn schemas(&self) -> Result<Vec<String>, Error>;
114
115    /// Get the schema for a table reference.
116    ///
117    /// # Arguments
118    ///
119    /// * `table_reference` - The table reference.
120    async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
121
122    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
123    ///
124    /// # Arguments
125    ///
126    /// * `sql` - The SQL statement.
127    /// * `params` - The parameters for the SQL statement.
128    /// * `projected_schema` - The Projected schema for the query.
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if the query fails.
133    async fn query_arrow(
134        &self,
135        sql: &str,
136        params: &[P],
137        projected_schema: Option<SchemaRef>,
138    ) -> Result<SendableRecordBatchStream>;
139
140    /// Execute the given SQL statement with parameters, returning the number of affected rows.
141    ///
142    /// # Arguments
143    ///
144    /// * `sql` - The SQL statement.
145    /// * `params` - The parameters for the SQL statement.
146    async fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
147}
148
149pub trait DbConnection<T, P>: Send {
150    fn as_any(&self) -> &dyn Any;
151    fn as_any_mut(&mut self) -> &mut dyn Any;
152
153    fn as_sync(&self) -> Option<&dyn SyncDbConnection<T, P>> {
154        None
155    }
156    fn as_async(&self) -> Option<&dyn AsyncDbConnection<T, P>> {
157        None
158    }
159}
160
161pub async fn get_tables<T, P>(
162    conn: Box<dyn DbConnection<T, P>>,
163    schema: &str,
164) -> Result<Vec<String>, Error> {
165    let schema = if let Some(conn) = conn.as_sync() {
166        conn.tables(schema)?
167    } else if let Some(conn) = conn.as_async() {
168        conn.tables(schema).await?
169    } else {
170        return Err(Error::UnableToDowncastConnection {});
171    };
172    Ok(schema)
173}
174
175/// Get the schemas for the database.
176///
177/// # Errors
178///
179/// Returns an error if the schemas cannot be retrieved.
180pub async fn get_schemas<T, P>(conn: Box<dyn DbConnection<T, P>>) -> Result<Vec<String>, Error> {
181    let schema = if let Some(conn) = conn.as_sync() {
182        conn.schemas()?
183    } else if let Some(conn) = conn.as_async() {
184        conn.schemas().await?
185    } else {
186        return Err(Error::UnableToDowncastConnection {});
187    };
188    Ok(schema)
189}
190
191/// Get the schema for a table reference.
192///
193/// # Arguments
194///
195/// * `conn` - The database connection.
196/// * `table_reference` - The table reference.
197///
198/// # Errors
199///
200/// Returns an error if the schema cannot be retrieved.
201pub async fn get_schema<T, P>(
202    conn: Box<dyn DbConnection<T, P>>,
203    table_reference: &datafusion::sql::TableReference,
204) -> Result<Arc<datafusion::arrow::datatypes::Schema>, Error> {
205    let schema = if let Some(conn) = conn.as_sync() {
206        conn.get_schema(table_reference)?
207    } else if let Some(conn) = conn.as_async() {
208        conn.get_schema(table_reference).await?
209    } else {
210        return Err(Error::UnableToDowncastConnection {});
211    };
212    Ok(schema)
213}
214
215/// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
216///
217/// # Arguments
218///
219/// * `conn` - The database connection.
220/// * `sql` - The SQL statement.
221///
222/// # Errors
223///
224/// Returns an error if the query fails.
225pub async fn query_arrow<T, P>(
226    conn: Box<dyn DbConnection<T, P>>,
227    sql: String,
228    projected_schema: Option<SchemaRef>,
229) -> Result<SendableRecordBatchStream, Error> {
230    if let Some(conn) = conn.as_sync() {
231        conn.query_arrow(&sql, &[], projected_schema)
232            .context(UnableToQueryArrowSnafu {})
233    } else if let Some(conn) = conn.as_async() {
234        conn.query_arrow(&sql, &[], projected_schema)
235            .await
236            .context(UnableToQueryArrowSnafu {})
237    } else {
238        return Err(Error::UnableToDowncastConnection {});
239    }
240}