datafusion_table_providers/sql/db_connection_pool/
dbconnection.rs

1use std::{any::Any, sync::Arc};
2
3use datafusion::{
4    arrow::datatypes::SchemaRef, execution::SendableRecordBatchStream, sql::TableReference,
5};
6use snafu::prelude::*;
7
8#[cfg(feature = "clickhouse")]
9pub mod clickhouseconn;
10#[cfg(feature = "duckdb")]
11pub mod duckdbconn;
12#[cfg(feature = "mysql")]
13pub mod mysqlconn;
14#[cfg(feature = "odbc")]
15pub mod odbcconn;
16#[cfg(feature = "postgres")]
17pub mod postgresconn;
18#[cfg(feature = "sqlite")]
19pub mod sqliteconn;
20
/// Boxed, thread-safe, type-erased error used wherever a concrete error type is not named.
pub type GenericError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Module-local `Result` whose error type defaults to [`GenericError`] when not specified.
type Result<T, E = GenericError> = core::result::Result<T, E>;
23
24#[derive(Debug, Snafu)]
25pub enum Error {
26    #[snafu(display("Unable to downcast connection.\nReport a bug to request support: https://github.com/datafusion-contrib/datafusion-table-providers/issues"))]
27    UnableToDowncastConnection {},
28
29    #[snafu(display("{source}"))]
30    UnableToGetSchema { source: GenericError },
31
32    #[snafu(display("The field '{field_name}' has an unsupported data type: {data_type}."))]
33    UnsupportedDataType {
34        data_type: String,
35        field_name: String,
36    },
37
38    #[snafu(display("Failed to execute query.\n{source}"))]
39    UnableToQueryArrow { source: GenericError },
40
41    #[snafu(display(
42        "Table '{table_name}' not found. Ensure the table name is correctly spelled."
43    ))]
44    UndefinedTable {
45        table_name: String,
46        source: GenericError,
47    },
48
49    #[snafu(display("Unable to get schemas: {source}"))]
50    UnableToGetSchemas { source: GenericError },
51
52    #[snafu(display("Unable to get tables: {source}"))]
53    UnableToGetTables { source: GenericError },
54}
55
56pub trait SyncDbConnection<T, P>: DbConnection<T, P> {
57    fn new(conn: T) -> Self
58    where
59        Self: Sized;
60
61    fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
62
63    fn schemas(&self) -> Result<Vec<String>, Error>;
64
65    /// Get the schema for a table reference.
66    ///
67    /// # Arguments
68    ///
69    /// * `table_reference` - The table reference.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if the schema cannot be retrieved.
74    fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
75
76    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
77    ///
78    /// # Arguments
79    ///
80    /// * `sql` - The SQL statement.
81    /// * `params` - The parameters for the SQL statement.
82    /// * `projected_schema` - The Projected schema for the query.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the query fails.
87    fn query_arrow(
88        &self,
89        sql: &str,
90        params: &[P],
91        projected_schema: Option<SchemaRef>,
92    ) -> Result<SendableRecordBatchStream>;
93
94    /// Execute the given SQL statement with parameters, returning the number of affected rows.
95    ///
96    /// # Arguments
97    ///
98    /// * `sql` - The SQL statement.
99    /// * `params` - The parameters for the SQL statement.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if the execution fails.
104    fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
105}
106
107#[async_trait::async_trait]
108pub trait AsyncDbConnection<T, P>: DbConnection<T, P> + Sync {
109    fn new(conn: T) -> Self
110    where
111        Self: Sized;
112
113    async fn tables(&self, schema: &str) -> Result<Vec<String>, Error>;
114
115    async fn schemas(&self) -> Result<Vec<String>, Error>;
116
117    /// Get the schema for a table reference.
118    ///
119    /// # Arguments
120    ///
121    /// * `table_reference` - The table reference.
122    async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error>;
123
124    /// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
125    ///
126    /// # Arguments
127    ///
128    /// * `sql` - The SQL statement.
129    /// * `params` - The parameters for the SQL statement.
130    /// * `projected_schema` - The Projected schema for the query.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the query fails.
135    async fn query_arrow(
136        &self,
137        sql: &str,
138        params: &[P],
139        projected_schema: Option<SchemaRef>,
140    ) -> Result<SendableRecordBatchStream>;
141
142    /// Execute the given SQL statement with parameters, returning the number of affected rows.
143    ///
144    /// # Arguments
145    ///
146    /// * `sql` - The SQL statement.
147    /// * `params` - The parameters for the SQL statement.
148    async fn execute(&self, sql: &str, params: &[P]) -> Result<u64>;
149}
150
151pub trait DbConnection<T, P>: Send {
152    fn as_any(&self) -> &dyn Any;
153    fn as_any_mut(&mut self) -> &mut dyn Any;
154
155    fn as_sync(&self) -> Option<&dyn SyncDbConnection<T, P>> {
156        None
157    }
158    fn as_async(&self) -> Option<&dyn AsyncDbConnection<T, P>> {
159        None
160    }
161}
162
163pub async fn get_tables<T, P>(
164    conn: Box<dyn DbConnection<T, P>>,
165    schema: &str,
166) -> Result<Vec<String>, Error> {
167    let schema = if let Some(conn) = conn.as_sync() {
168        conn.tables(schema)?
169    } else if let Some(conn) = conn.as_async() {
170        conn.tables(schema).await?
171    } else {
172        return Err(Error::UnableToDowncastConnection {});
173    };
174    Ok(schema)
175}
176
177/// Get the schemas for the database.
178///
179/// # Errors
180///
181/// Returns an error if the schemas cannot be retrieved.
182pub async fn get_schemas<T, P>(conn: Box<dyn DbConnection<T, P>>) -> Result<Vec<String>, Error> {
183    let schema = if let Some(conn) = conn.as_sync() {
184        conn.schemas()?
185    } else if let Some(conn) = conn.as_async() {
186        conn.schemas().await?
187    } else {
188        return Err(Error::UnableToDowncastConnection {});
189    };
190    Ok(schema)
191}
192
193/// Get the schema for a table reference.
194///
195/// # Arguments
196///
197/// * `conn` - The database connection.
198/// * `table_reference` - The table reference.
199///
200/// # Errors
201///
202/// Returns an error if the schema cannot be retrieved.
203pub async fn get_schema<T, P>(
204    conn: Box<dyn DbConnection<T, P>>,
205    table_reference: &datafusion::sql::TableReference,
206) -> Result<Arc<datafusion::arrow::datatypes::Schema>, Error> {
207    let schema = if let Some(conn) = conn.as_sync() {
208        conn.get_schema(table_reference)?
209    } else if let Some(conn) = conn.as_async() {
210        conn.get_schema(table_reference).await?
211    } else {
212        return Err(Error::UnableToDowncastConnection {});
213    };
214    Ok(schema)
215}
216
217/// Query the database with the given SQL statement and parameters, returning a `Result` of `SendableRecordBatchStream`.
218///
219/// # Arguments
220///
221/// * `conn` - The database connection.
222/// * `sql` - The SQL statement.
223///
224/// # Errors
225///
226/// Returns an error if the query fails.
227pub async fn query_arrow<T, P>(
228    conn: Box<dyn DbConnection<T, P>>,
229    sql: String,
230    projected_schema: Option<SchemaRef>,
231) -> Result<SendableRecordBatchStream, Error> {
232    if let Some(conn) = conn.as_sync() {
233        conn.query_arrow(&sql, &[], projected_schema)
234            .context(UnableToQueryArrowSnafu {})
235    } else if let Some(conn) = conn.as_async() {
236        conn.query_arrow(&sql, &[], projected_schema)
237            .await
238            .context(UnableToQueryArrowSnafu {})
239    } else {
240        return Err(Error::UnableToDowncastConnection {});
241    }
242}