datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22
23use crate::{DFResult, RemoteSchemaRef};
24use datafusion::arrow::datatypes::SchemaRef;
25use datafusion::common::DataFusionError;
26use datafusion::execution::SendableRecordBatchStream;
27use datafusion::sql::unparser::Unparser;
28use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
29use std::fmt::Debug;
30use std::sync::Arc;
31
/// Process-wide slot holding the ODBC environment; a `OnceLock` can be set at most once.
///
/// Only compiled when the `dm` feature is enabled. Nothing in this module initializes it;
/// presumably the `dm` connection code does so on first use (TODO confirm).
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
34
/// A source of [`Connection`]s to a remote database.
///
/// Implementors must be `Debug + Send + Sync` so a pool can be shared across threads
/// behind an `Arc<dyn Pool>` (which is what [`connect`] returns).
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Obtains a connection, returned as a shared trait object.
    ///
    /// NOTE(review): whether this hands out a pooled or a freshly opened connection is
    /// up to each implementation.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
39
/// A connection to a remote database that can describe and run SQL queries.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Returns the remote schema associated with `sql`.
    ///
    /// NOTE(review): presumably the schema of the query's result set; how it is inferred
    /// is up to each implementation.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;

    /// Runs `sql` and returns the result as a stream of record batches.
    ///
    /// How the arguments are used is decided by each implementation; the descriptions
    /// below are read off the names and types and should be confirmed against them:
    ///
    /// * `conn_options` - the options the connection was created from.
    /// * `sql` - the query text to execute.
    /// * `table_schema` - the Arrow schema of the table being scanned.
    /// * `projection` - indices of the columns to return, or `None` for all columns.
    /// * `unparsed_filters` - filter expressions already rendered as SQL text.
    /// * `limit` - maximum number of rows to return, if any.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;
}
54
55pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
56    match options {
57        #[cfg(feature = "postgres")]
58        ConnectionOptions::Postgres(options) => {
59            let pool = connect_postgres(options).await?;
60            Ok(Arc::new(pool))
61        }
62        #[cfg(feature = "mysql")]
63        ConnectionOptions::Mysql(options) => {
64            let pool = connect_mysql(options)?;
65            Ok(Arc::new(pool))
66        }
67        #[cfg(feature = "oracle")]
68        ConnectionOptions::Oracle(options) => {
69            let pool = connect_oracle(options).await?;
70            Ok(Arc::new(pool))
71        }
72        #[cfg(feature = "sqlite")]
73        ConnectionOptions::Sqlite(options) => {
74            let pool = connect_sqlite(options).await?;
75            Ok(Arc::new(pool))
76        }
77        #[cfg(feature = "dm")]
78        ConnectionOptions::Dm(options) => {
79            let pool = connect_dm(options)?;
80            Ok(Arc::new(pool))
81        }
82    }
83}
84
/// Connection settings for a remote database, one variant per supported backend.
///
/// Each variant is only compiled when the Cargo feature named in its `cfg` attribute is
/// enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a PostgreSQL connection.
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection.
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection.
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection.
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a Dm connection (presumably the Dameng database; TODO confirm).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
98
99impl ConnectionOptions {
100    pub(crate) fn stream_chunk_size(&self) -> usize {
101        match self {
102            #[cfg(feature = "postgres")]
103            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
104            #[cfg(feature = "oracle")]
105            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
106            #[cfg(feature = "mysql")]
107            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
108            #[cfg(feature = "sqlite")]
109            ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
110            #[cfg(feature = "dm")]
111            ConnectionOptions::Dm(options) => options.stream_chunk_size,
112        }
113    }
114
115    pub(crate) fn db_type(&self) -> RemoteDbType {
116        match self {
117            #[cfg(feature = "postgres")]
118            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
119            #[cfg(feature = "oracle")]
120            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
121            #[cfg(feature = "mysql")]
122            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
123            #[cfg(feature = "sqlite")]
124            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
125            #[cfg(feature = "dm")]
126            ConnectionOptions::Dm(_) => RemoteDbType::Dm,
127        }
128    }
129}
130
/// The kind of remote database, independent of any connection settings.
///
/// Unlike [`ConnectionOptions`], every variant exists regardless of which Cargo features
/// are enabled.
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm (presumably the Dameng database; TODO confirm).
    Dm,
}
138
139impl RemoteDbType {
140    pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
141        sql.trim()[0..6].eq_ignore_ascii_case("select")
142    }
143
144    pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
145        match self {
146            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
147            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
148            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
149            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
150                "Oracle unparser not implemented".to_string(),
151            )),
152            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
153                "Dm unparser not implemented".to_string(),
154            )),
155        }
156    }
157
158    pub(crate) fn try_rewrite_query(
159        &self,
160        sql: &str,
161        unparsed_filters: &[String],
162        limit: Option<usize>,
163    ) -> DFResult<String> {
164        match self {
165            RemoteDbType::Postgres | RemoteDbType::Mysql | RemoteDbType::Sqlite => {
166                let where_clause = if unparsed_filters.is_empty() {
167                    "".to_string()
168                } else {
169                    format!(" WHERE {}", unparsed_filters.join(" AND "))
170                };
171                let limit_clause = if let Some(limit) = limit {
172                    format!(" LIMIT {limit}")
173                } else {
174                    "".to_string()
175                };
176
177                if where_clause.is_empty() && limit_clause.is_empty() {
178                    Ok(sql.to_string())
179                } else {
180                    Ok(format!(
181                        "SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}"
182                    ))
183                }
184            }
185            RemoteDbType::Oracle => Ok(limit
186                .map(|l| format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {l}"))
187                .unwrap_or_else(|| sql.to_string())),
188            RemoteDbType::Dm => Ok(limit
189                .map(|l| format!("SELECT * FROM ({sql}) LIMIT {l}"))
190                .unwrap_or_else(|| sql.to_string())),
191        }
192    }
193
194    pub(crate) fn query_limit_1(&self, sql: &str) -> DFResult<String> {
195        if !self.support_rewrite_with_filters_limit(sql) {
196            return Ok(sql.to_string());
197        }
198        self.try_rewrite_query(sql, &[], Some(1))
199    }
200}
201
/// Tells whether column `col_idx` is part of `projection`.
///
/// A missing projection (`None`) means every column is selected, so the answer is `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.contains(&col_idx)
}
208
/// Converts `decimal` to its unscaled `i128` representation at the given `scale`, i.e.
/// `decimal * 10^scale` with any remaining fractional digits truncated.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used (falling back
/// to `0` if that count does not fit in an `i32`).
///
/// Returns `None` when the unscaled value does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // Rescale with exact big-integer arithmetic. Building the multiplier as an `f32` power
    // of ten is only exact up to 10^10 and overflows to infinity past 10^38, which corrupts
    // or drops values for larger scales.
    let (unscaled, _exponent) = decimal
        .with_scale(i64::from(scale))
        .into_bigint_and_exponent();
    unscaled.to_i128()
}
221
222#[allow(unused)]
223fn just_return<T>(v: T) -> DFResult<T> {
224    Ok(v)
225}
226
227#[allow(unused)]
228fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
229    Ok(*t)
230}