datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22
23use crate::{DFResult, RemoteSchemaRef, extract_primitive_array};
24use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
25use datafusion::common::DataFusionError;
26use datafusion::execution::SendableRecordBatchStream;
27use datafusion::physical_plan::common::collect;
28use datafusion::sql::unparser::Unparser;
29use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
30use std::fmt::Debug;
31use std::sync::Arc;
32
/// Crate-wide cell holding a single `odbc_api::Environment`.
///
/// The cell starts out empty (`OnceLock::new()`) and can be populated at most once.
/// Only compiled when the `dm` feature is enabled.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
35
/// A source of [`Connection`] handles for one remote database.
///
/// Implementations must be `Debug`, `Send` and `Sync` so they can be shared as
/// `Arc<dyn Pool>` (the type returned by `connect`).
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Asynchronously obtains a shared connection handle, or an error if none can be provided.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
40
/// A connection to a remote database that can describe and execute SQL.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Returns the remote schema associated with `sql`.
    ///
    /// NOTE(review): presumably the schema of the rows `sql` would produce — confirm against
    /// the backend implementations.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;

    /// Executes `sql` and returns the result as a stream of record batches.
    ///
    /// * `conn_options` - options of the connection the query runs on.
    /// * `sql` - the statement to execute.
    /// * `table_schema` - Arrow schema the caller expects for the result.
    /// * `projection` - optional list of column indices; `None` means no projection.
    /// * `unparsed_filters` - filter predicates already rendered as SQL text.
    /// * `limit` - optional maximum number of rows.
    ///
    /// NOTE(review): how `projection`, `unparsed_filters` and `limit` are applied is decided by
    /// each implementation and is not visible here — verify per backend.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;
}
55
56pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
57    match options {
58        #[cfg(feature = "postgres")]
59        ConnectionOptions::Postgres(options) => {
60            let pool = connect_postgres(options).await?;
61            Ok(Arc::new(pool))
62        }
63        #[cfg(feature = "mysql")]
64        ConnectionOptions::Mysql(options) => {
65            let pool = connect_mysql(options)?;
66            Ok(Arc::new(pool))
67        }
68        #[cfg(feature = "oracle")]
69        ConnectionOptions::Oracle(options) => {
70            let pool = connect_oracle(options).await?;
71            Ok(Arc::new(pool))
72        }
73        #[cfg(feature = "sqlite")]
74        ConnectionOptions::Sqlite(options) => {
75            let pool = connect_sqlite(options).await?;
76            Ok(Arc::new(pool))
77        }
78        #[cfg(feature = "dm")]
79        ConnectionOptions::Dm(options) => {
80            let pool = connect_dm(options)?;
81            Ok(Arc::new(pool))
82        }
83    }
84}
85
/// Backend-specific connection settings.
///
/// Every variant is gated behind the cargo feature of the same name, so the set of
/// available variants depends on how the crate is built.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a PostgreSQL connection (`postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection (`oracle` feature).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection (`mysql` feature).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection (`sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a Dm connection (`dm` feature).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
99
100impl ConnectionOptions {
101    pub(crate) fn stream_chunk_size(&self) -> usize {
102        match self {
103            #[cfg(feature = "postgres")]
104            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
105            #[cfg(feature = "oracle")]
106            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
107            #[cfg(feature = "mysql")]
108            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
109            #[cfg(feature = "sqlite")]
110            ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
111            #[cfg(feature = "dm")]
112            ConnectionOptions::Dm(options) => options.stream_chunk_size,
113        }
114    }
115
116    pub(crate) fn db_type(&self) -> RemoteDbType {
117        match self {
118            #[cfg(feature = "postgres")]
119            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
120            #[cfg(feature = "oracle")]
121            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
122            #[cfg(feature = "mysql")]
123            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
124            #[cfg(feature = "sqlite")]
125            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
126            #[cfg(feature = "dm")]
127            ConnectionOptions::Dm(_) => RemoteDbType::Dm,
128        }
129    }
130}
131
/// The kind of remote database a query targets.
///
/// Unlike [`ConnectionOptions`], the variants here are not feature-gated, so every kind is
/// always nameable.
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm (the backend behind the `dm` feature).
    Dm,
}
139
140impl RemoteDbType {
141    pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
142        sql.trim()[0..6].eq_ignore_ascii_case("select")
143    }
144
145    pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
146        match self {
147            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
148            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
149            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
150            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
151                "Oracle unparser not implemented".to_string(),
152            )),
153            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
154                "Dm unparser not implemented".to_string(),
155            )),
156        }
157    }
158
159    pub(crate) fn rewrite_query(
160        &self,
161        sql: &str,
162        unparsed_filters: &[String],
163        limit: Option<usize>,
164    ) -> String {
165        match self {
166            RemoteDbType::Postgres
167            | RemoteDbType::Mysql
168            | RemoteDbType::Sqlite
169            | RemoteDbType::Dm => {
170                let where_clause = if unparsed_filters.is_empty() {
171                    "".to_string()
172                } else {
173                    format!(" WHERE {}", unparsed_filters.join(" AND "))
174                };
175                let limit_clause = if let Some(limit) = limit {
176                    format!(" LIMIT {limit}")
177                } else {
178                    "".to_string()
179                };
180
181                if where_clause.is_empty() && limit_clause.is_empty() {
182                    sql.to_string()
183                } else {
184                    format!("SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}")
185                }
186            }
187            RemoteDbType::Oracle => {
188                let mut all_filters: Vec<String> = vec![];
189                all_filters.extend_from_slice(unparsed_filters);
190                if let Some(limit) = limit {
191                    all_filters.push(format!("ROWNUM <= {limit}"))
192                }
193
194                let where_clause = if all_filters.is_empty() {
195                    "".to_string()
196                } else {
197                    format!(" WHERE {}", all_filters.join(" AND "))
198                };
199                if where_clause.is_empty() {
200                    sql.to_string()
201                } else {
202                    format!("SELECT * FROM ({sql}){where_clause}")
203                }
204            }
205        }
206    }
207
208    pub(crate) fn query_limit_1(&self, sql: &str) -> String {
209        if !self.support_rewrite_with_filters_limit(sql) {
210            return sql.to_string();
211        }
212        self.rewrite_query(sql, &[], Some(1))
213    }
214
215    pub(crate) fn try_count1_query(&self, sql: &str) -> Option<String> {
216        if !self.support_rewrite_with_filters_limit(sql) {
217            return None;
218        }
219        match self {
220            RemoteDbType::Postgres
221            | RemoteDbType::Mysql
222            | RemoteDbType::Sqlite
223            | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({sql}) AS __subquery")),
224            RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({sql})")),
225        }
226    }
227
228    pub(crate) async fn fetch_count(
229        &self,
230        conn: Arc<dyn Connection>,
231        conn_options: &ConnectionOptions,
232        count1_query: &str,
233    ) -> DFResult<usize> {
234        let count1_schema = Arc::new(Schema::new(vec![Field::new(
235            "count(1)",
236            DataType::Int64,
237            false,
238        )]));
239        let stream = conn
240            .query(conn_options, count1_query, count1_schema, None, &[], None)
241            .await?;
242        let batches = collect(stream).await?;
243        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
244        if count_vec.len() != 1 {
245            return Err(DataFusionError::Execution(format!(
246                "Count query did not return exactly one row: {count_vec:?}",
247            )));
248        }
249        count_vec[0]
250            .map(|count| count as usize)
251            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
252    }
253}
254
/// Reports whether column `col_idx` is part of `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is always `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |columns| columns.contains(&col_idx))
}
261
/// Converts `decimal` into an `i128` after multiplying it by `10^scale`.
///
/// When `scale` is `None`, the number of fractional digits of `decimal` is used; if that
/// count does not fit in an `i32`, the scale falls back to `0`.
///
/// Returns `None` when `ToPrimitive::to_i128` cannot represent the scaled value.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // `BigDecimal::new(1, -scale)` is exactly 1 * 10^scale. Going through `10f32.powi(scale)`
    // is lossy because `f32` cannot represent powers of ten exactly above 10^10.
    let scale_decimal = bigdecimal::BigDecimal::new(1.into(), -i64::from(scale));
    (decimal * scale_decimal).to_i128()
}
274
/// Wraps `v` in `Ok` without touching it.
// `unused` is allowed; no use of this helper is visible in this file.
#[allow(unused)]
fn just_return<T>(v: T) -> DFResult<T> {
    Ok(v)
}
279
/// Copies the value behind `t` and wraps it in `Ok`.
// `unused` is allowed; no use of this helper is visible in this file.
#[allow(unused)]
fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
    Ok(*t)
}
283}