datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22use std::any::Any;
23
24use crate::{DFResult, RemoteSchemaRef, extract_primitive_array};
25use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
26use datafusion::common::DataFusionError;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::common::collect;
29use datafusion::sql::unparser::Unparser;
30use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Crate-wide slot for a single ODBC environment, populated at most once
/// through `OnceLock`. Only compiled when the `dm` feature is enabled.
// NOTE(review): the initialization site is not in this file — presumably the `dm` module sets it.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> =
    std::sync::OnceLock::new();
36
37#[async_trait::async_trait]
38pub trait Pool: Debug + Send + Sync {
39    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
40}
41
42#[async_trait::async_trait]
43pub trait Connection: Debug + Send + Sync {
44    fn as_any(&self) -> &dyn Any;
45
46    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;
47
48    async fn query(
49        &self,
50        conn_options: &ConnectionOptions,
51        sql: &str,
52        table_schema: SchemaRef,
53        projection: Option<&Vec<usize>>,
54        unparsed_filters: &[String],
55        limit: Option<usize>,
56    ) -> DFResult<SendableRecordBatchStream>;
57}
58
59pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
60    match options {
61        #[cfg(feature = "postgres")]
62        ConnectionOptions::Postgres(options) => {
63            let pool = connect_postgres(options).await?;
64            Ok(Arc::new(pool))
65        }
66        #[cfg(feature = "mysql")]
67        ConnectionOptions::Mysql(options) => {
68            let pool = connect_mysql(options)?;
69            Ok(Arc::new(pool))
70        }
71        #[cfg(feature = "oracle")]
72        ConnectionOptions::Oracle(options) => {
73            let pool = connect_oracle(options).await?;
74            Ok(Arc::new(pool))
75        }
76        #[cfg(feature = "sqlite")]
77        ConnectionOptions::Sqlite(options) => {
78            let pool = connect_sqlite(options).await?;
79            Ok(Arc::new(pool))
80        }
81        #[cfg(feature = "dm")]
82        ConnectionOptions::Dm(options) => {
83            let pool = connect_dm(options)?;
84            Ok(Arc::new(pool))
85        }
86    }
87}
88
/// Backend-specific connection settings.
///
/// Every variant is gated behind the cargo feature of the same name, so the
/// set of available variants depends on how the crate is built.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a PostgreSQL connection.
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection.
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection.
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection.
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a DM connection.
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
102
103impl ConnectionOptions {
104    pub(crate) fn stream_chunk_size(&self) -> usize {
105        match self {
106            #[cfg(feature = "postgres")]
107            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
108            #[cfg(feature = "oracle")]
109            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
110            #[cfg(feature = "mysql")]
111            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
112            #[cfg(feature = "sqlite")]
113            ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
114            #[cfg(feature = "dm")]
115            ConnectionOptions::Dm(options) => options.stream_chunk_size,
116        }
117    }
118
119    pub(crate) fn db_type(&self) -> RemoteDbType {
120        match self {
121            #[cfg(feature = "postgres")]
122            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
123            #[cfg(feature = "oracle")]
124            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
125            #[cfg(feature = "mysql")]
126            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
127            #[cfg(feature = "sqlite")]
128            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
129            #[cfg(feature = "dm")]
130            ConnectionOptions::Dm(_) => RemoteDbType::Dm,
131        }
132    }
133
134    pub fn with_pool_max_size(self, pool_max_size: usize) -> Self {
135        match self {
136            #[cfg(feature = "postgres")]
137            ConnectionOptions::Postgres(options) => {
138                ConnectionOptions::Postgres(options.with_pool_max_size(pool_max_size))
139            }
140            #[cfg(feature = "oracle")]
141            ConnectionOptions::Oracle(options) => {
142                ConnectionOptions::Oracle(options.with_pool_max_size(pool_max_size))
143            }
144            #[cfg(feature = "mysql")]
145            ConnectionOptions::Mysql(options) => {
146                ConnectionOptions::Mysql(options.with_pool_max_size(pool_max_size))
147            }
148            #[cfg(feature = "sqlite")]
149            ConnectionOptions::Sqlite(options) => ConnectionOptions::Sqlite(options),
150            #[cfg(feature = "dm")]
151            ConnectionOptions::Dm(options) => ConnectionOptions::Dm(options),
152        }
153    }
154}
155
/// Identifies which kind of remote database a connection talks to.
///
/// Unlike `ConnectionOptions`, the variants are not feature-gated, so the tag
/// is always available regardless of which backends are compiled in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// DM.
    Dm,
}
163
164impl RemoteDbType {
165    pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
166        sql.trim()[0..6].eq_ignore_ascii_case("select")
167    }
168
169    pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
170        match self {
171            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
172            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
173            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
174            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
175                "Oracle unparser not implemented".to_string(),
176            )),
177            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
178                "Dm unparser not implemented".to_string(),
179            )),
180        }
181    }
182
183    pub(crate) fn rewrite_query(
184        &self,
185        sql: &str,
186        unparsed_filters: &[String],
187        limit: Option<usize>,
188    ) -> String {
189        match self {
190            RemoteDbType::Postgres
191            | RemoteDbType::Mysql
192            | RemoteDbType::Sqlite
193            | RemoteDbType::Dm => {
194                let where_clause = if unparsed_filters.is_empty() {
195                    "".to_string()
196                } else {
197                    format!(" WHERE {}", unparsed_filters.join(" AND "))
198                };
199                let limit_clause = if let Some(limit) = limit {
200                    format!(" LIMIT {limit}")
201                } else {
202                    "".to_string()
203                };
204
205                if where_clause.is_empty() && limit_clause.is_empty() {
206                    sql.to_string()
207                } else {
208                    format!("SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}")
209                }
210            }
211            RemoteDbType::Oracle => {
212                let mut all_filters: Vec<String> = vec![];
213                all_filters.extend_from_slice(unparsed_filters);
214                if let Some(limit) = limit {
215                    all_filters.push(format!("ROWNUM <= {limit}"))
216                }
217
218                let where_clause = if all_filters.is_empty() {
219                    "".to_string()
220                } else {
221                    format!(" WHERE {}", all_filters.join(" AND "))
222                };
223                if where_clause.is_empty() {
224                    sql.to_string()
225                } else {
226                    format!("SELECT * FROM ({sql}){where_clause}")
227                }
228            }
229        }
230    }
231
232    pub(crate) fn query_limit_1(&self, sql: &str) -> String {
233        if !self.support_rewrite_with_filters_limit(sql) {
234            return sql.to_string();
235        }
236        self.rewrite_query(sql, &[], Some(1))
237    }
238
239    pub(crate) fn try_count1_query(&self, sql: &str) -> Option<String> {
240        if !self.support_rewrite_with_filters_limit(sql) {
241            return None;
242        }
243        match self {
244            RemoteDbType::Postgres
245            | RemoteDbType::Mysql
246            | RemoteDbType::Sqlite
247            | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({sql}) AS __subquery")),
248            RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({sql})")),
249        }
250    }
251
252    pub(crate) async fn fetch_count(
253        &self,
254        conn: Arc<dyn Connection>,
255        conn_options: &ConnectionOptions,
256        count1_query: &str,
257    ) -> DFResult<usize> {
258        let count1_schema = Arc::new(Schema::new(vec![Field::new(
259            "count(1)",
260            DataType::Int64,
261            false,
262        )]));
263        let stream = conn
264            .query(conn_options, count1_query, count1_schema, None, &[], None)
265            .await?;
266        let batches = collect(stream).await?;
267        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
268        if count_vec.len() != 1 {
269            return Err(DataFusionError::Execution(format!(
270                "Count query did not return exactly one row: {count_vec:?}",
271            )));
272        }
273        count_vec[0]
274            .map(|count| count as usize)
275            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
276    }
277}
278
/// Reports whether column `col_idx` is part of `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is always
/// `true` in that case.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
285
/// Converts `decimal` into its unscaled `i128` representation at `scale`,
/// i.e. `decimal * 10^scale` with any remaining fractional digits dropped.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used
/// (falling back to `0` if that count does not fit in an `i32`).
///
/// Returns `None` when the unscaled value does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale: i32 = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // Rescale with exact integer arithmetic. Multiplying by a power of ten built
    // from an `f32` is inexact for scales >= 11 and for every negative scale,
    // which corrupts the low-order digits of the result.
    let (unscaled, _exponent) = decimal
        .with_scale(i64::from(scale))
        .into_bigint_and_exponent();
    unscaled.to_i128()
}
298
299#[allow(unused)]
300fn just_return<T>(v: T) -> DFResult<T> {
301    Ok(v)
302}
303
304#[allow(unused)]
305fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
306    Ok(*t)
307}