datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22use std::any::Any;
23
24use crate::{DFResult, RemoteSchemaRef, extract_primitive_array};
25use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
26use datafusion::common::DataFusionError;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::common::collect;
29use datafusion::sql::unparser::Unparser;
30use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Process-wide slot for the ODBC environment; compiled only with the `dm` feature.
///
/// The slot starts out empty. The code that initializes it is not part of this
/// section of the file — presumably the `dm` backend fills it on first connect
/// (TODO confirm).
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
36
/// A source of [`Connection`]s to a remote database.
///
/// [`connect`] returns implementations of this trait behind an `Arc<dyn Pool>`.
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Hands out a connection, or an error when one cannot be obtained.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
41
/// A connection to a remote database that can describe and execute SQL queries.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Returns `self` as [`Any`], which lets callers downcast to the concrete
    /// connection type.
    fn as_any(&self) -> &dyn Any;

    /// Determines the remote schema of the result produced by `sql`.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;

    /// Executes `sql` and returns the result as a stream of record batches.
    ///
    /// * `conn_options` - options of the backend this connection belongs to.
    /// * `sql` - the base query text.
    /// * `table_schema` - Arrow schema of the result; presumably the schema
    ///   before `projection` is applied (TODO confirm against implementations).
    /// * `projection` - optional column indices to keep; `None` selects every
    ///   column (see `projections_contains`).
    /// * `unparsed_filters` - filter predicates already rendered as SQL text.
    /// * `limit` - optional cap on the number of rows.
    ///
    /// NOTE(review): whether filters and limit are pushed into the SQL or
    /// applied afterwards is up to each implementation; not visible here.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;
}
58
59pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
60    match options {
61        #[cfg(feature = "postgres")]
62        ConnectionOptions::Postgres(options) => {
63            let pool = connect_postgres(options).await?;
64            Ok(Arc::new(pool))
65        }
66        #[cfg(feature = "mysql")]
67        ConnectionOptions::Mysql(options) => {
68            let pool = connect_mysql(options)?;
69            Ok(Arc::new(pool))
70        }
71        #[cfg(feature = "oracle")]
72        ConnectionOptions::Oracle(options) => {
73            let pool = connect_oracle(options).await?;
74            Ok(Arc::new(pool))
75        }
76        #[cfg(feature = "sqlite")]
77        ConnectionOptions::Sqlite(options) => {
78            let pool = connect_sqlite(options).await?;
79            Ok(Arc::new(pool))
80        }
81        #[cfg(feature = "dm")]
82        ConnectionOptions::Dm(options) => {
83            let pool = connect_dm(options)?;
84            Ok(Arc::new(pool))
85        }
86    }
87}
88
/// Connection settings for one of the supported remote databases.
///
/// Each variant exists only when the cargo feature of the same name is enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a Postgres connection (`postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection (`oracle` feature).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection (`mysql` feature).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection (`sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a Dm connection (`dm` feature).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
102
103impl ConnectionOptions {
104    pub(crate) fn stream_chunk_size(&self) -> usize {
105        match self {
106            #[cfg(feature = "postgres")]
107            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
108            #[cfg(feature = "oracle")]
109            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
110            #[cfg(feature = "mysql")]
111            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
112            #[cfg(feature = "sqlite")]
113            ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
114            #[cfg(feature = "dm")]
115            ConnectionOptions::Dm(options) => options.stream_chunk_size,
116        }
117    }
118
119    pub(crate) fn db_type(&self) -> RemoteDbType {
120        match self {
121            #[cfg(feature = "postgres")]
122            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
123            #[cfg(feature = "oracle")]
124            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
125            #[cfg(feature = "mysql")]
126            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
127            #[cfg(feature = "sqlite")]
128            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
129            #[cfg(feature = "dm")]
130            ConnectionOptions::Dm(_) => RemoteDbType::Dm,
131        }
132    }
133}
134
/// The kind of remote database a query is written for.
///
/// Unlike [`ConnectionOptions`], every variant is always present regardless of
/// the enabled cargo features. The value is a plain tag, so it is cheap to copy,
/// compare and print in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteDbType {
    /// Postgres.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm.
    Dm,
}
142
143impl RemoteDbType {
144    pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
145        sql.trim()[0..6].eq_ignore_ascii_case("select")
146    }
147
148    pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
149        match self {
150            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
151            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
152            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
153            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
154                "Oracle unparser not implemented".to_string(),
155            )),
156            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
157                "Dm unparser not implemented".to_string(),
158            )),
159        }
160    }
161
162    pub(crate) fn rewrite_query(
163        &self,
164        sql: &str,
165        unparsed_filters: &[String],
166        limit: Option<usize>,
167    ) -> String {
168        match self {
169            RemoteDbType::Postgres
170            | RemoteDbType::Mysql
171            | RemoteDbType::Sqlite
172            | RemoteDbType::Dm => {
173                let where_clause = if unparsed_filters.is_empty() {
174                    "".to_string()
175                } else {
176                    format!(" WHERE {}", unparsed_filters.join(" AND "))
177                };
178                let limit_clause = if let Some(limit) = limit {
179                    format!(" LIMIT {limit}")
180                } else {
181                    "".to_string()
182                };
183
184                if where_clause.is_empty() && limit_clause.is_empty() {
185                    sql.to_string()
186                } else {
187                    format!("SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}")
188                }
189            }
190            RemoteDbType::Oracle => {
191                let mut all_filters: Vec<String> = vec![];
192                all_filters.extend_from_slice(unparsed_filters);
193                if let Some(limit) = limit {
194                    all_filters.push(format!("ROWNUM <= {limit}"))
195                }
196
197                let where_clause = if all_filters.is_empty() {
198                    "".to_string()
199                } else {
200                    format!(" WHERE {}", all_filters.join(" AND "))
201                };
202                if where_clause.is_empty() {
203                    sql.to_string()
204                } else {
205                    format!("SELECT * FROM ({sql}){where_clause}")
206                }
207            }
208        }
209    }
210
211    pub(crate) fn query_limit_1(&self, sql: &str) -> String {
212        if !self.support_rewrite_with_filters_limit(sql) {
213            return sql.to_string();
214        }
215        self.rewrite_query(sql, &[], Some(1))
216    }
217
218    pub(crate) fn try_count1_query(&self, sql: &str) -> Option<String> {
219        if !self.support_rewrite_with_filters_limit(sql) {
220            return None;
221        }
222        match self {
223            RemoteDbType::Postgres
224            | RemoteDbType::Mysql
225            | RemoteDbType::Sqlite
226            | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({sql}) AS __subquery")),
227            RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({sql})")),
228        }
229    }
230
231    pub(crate) async fn fetch_count(
232        &self,
233        conn: Arc<dyn Connection>,
234        conn_options: &ConnectionOptions,
235        count1_query: &str,
236    ) -> DFResult<usize> {
237        let count1_schema = Arc::new(Schema::new(vec![Field::new(
238            "count(1)",
239            DataType::Int64,
240            false,
241        )]));
242        let stream = conn
243            .query(conn_options, count1_query, count1_schema, None, &[], None)
244            .await?;
245        let batches = collect(stream).await?;
246        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
247        if count_vec.len() != 1 {
248            return Err(DataFusionError::Execution(format!(
249                "Count query did not return exactly one row: {count_vec:?}",
250            )));
251        }
252        count_vec[0]
253            .map(|count| count as usize)
254            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
255    }
256}
257
/// Reports whether column `col_idx` is part of `projection`.
///
/// With no projection (`None`) every column counts as selected.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(columns) = projection else {
        return true;
    };
    columns.contains(&col_idx)
}
264
/// Converts `decimal` into the unscaled `i128` it has at `scale` fractional
/// digits, i.e. `decimal * 10^scale` with any remaining fraction dropped.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used.
/// Returns `None` when the unscaled value does not fit in an `i128`.
///
/// The value is rescaled with exact decimal arithmetic. Going through an `f32`
/// power of ten is inexact for scales above 10 and for every negative scale,
/// which produces wrong digits.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.map_or_else(|| decimal.fractional_digit_count(), i64::from);
    // `with_scale` drops surplus fractional digits, leaving an integer mantissa
    // that is exactly the value multiplied by 10^scale.
    let (unscaled, _exponent) = decimal.with_scale(scale).into_bigint_and_exponent();
    unscaled.to_i128()
}
277
278#[allow(unused)]
279fn just_return<T>(v: T) -> DFResult<T> {
280    Ok(v)
281}
282
283#[allow(unused)]
284fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
285    Ok(*t)
286}