// datafusion_remote_table/connection/mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22use std::any::Any;
23
24use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
25use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
26use datafusion::common::DataFusionError;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::common::collect;
29use datafusion::sql::unparser::Unparser;
30use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Crate-wide slot for the ODBC environment, set at most once through
/// [`std::sync::OnceLock`].
///
/// Only compiled when the `dm` feature is enabled. Nothing in this file
/// initializes it — presumably the `dm` backend module does; confirm there.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
36
/// A source of database [`Connection`] handles.
///
/// Implementations must be `Debug + Send + Sync` so a pool can be shared
/// behind `Arc<dyn Pool>` (which is what [`connect`] returns).
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Asynchronously obtains a shared connection, or an error if none could
    /// be provided.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
41
/// Operations a remote database connection must support.
///
/// Implementations must be `Debug + Send + Sync`; connections are handed out
/// as `Arc<dyn Connection>` by [`Pool::get`].
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Exposes the connection as `&dyn Any` so callers can downcast to the
    /// concrete backend type.
    fn as_any(&self) -> &dyn Any;

    /// Determines the remote schema of `source` (a table or a raw query).
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;

    /// Runs `source` on the remote database and returns the result as a
    /// stream of record batches.
    ///
    /// * `conn_options` - options of the connection being used.
    /// * `source` - the table or raw query to read.
    /// * `table_schema` - Arrow schema associated with `source`.
    /// * `projection` - optional column indices; presumably indices into
    ///   `table_schema`, with `None` meaning every column (that is how
    ///   `projections_contains` in this module treats it) — confirm per backend.
    /// * `unparsed_filters` - filter predicates already rendered as SQL text.
    /// * `limit` - optional maximum number of rows.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;

    /// Writes the batches produced by `input` into the remote table named by
    /// the identifier parts in `table`.
    ///
    /// `literalizer` is presumably used to render values as SQL literals, and
    /// the returned `usize` is presumably the number of rows written —
    /// TODO confirm against the backend implementations.
    async fn insert(
        &self,
        conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        input: SendableRecordBatchStream,
    ) -> DFResult<usize>;
}
67
68pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
69    match options {
70        #[cfg(feature = "postgres")]
71        ConnectionOptions::Postgres(options) => {
72            let pool = connect_postgres(options).await?;
73            Ok(Arc::new(pool))
74        }
75        #[cfg(feature = "mysql")]
76        ConnectionOptions::Mysql(options) => {
77            let pool = connect_mysql(options)?;
78            Ok(Arc::new(pool))
79        }
80        #[cfg(feature = "oracle")]
81        ConnectionOptions::Oracle(options) => {
82            let pool = connect_oracle(options).await?;
83            Ok(Arc::new(pool))
84        }
85        #[cfg(feature = "sqlite")]
86        ConnectionOptions::Sqlite(options) => {
87            let pool = connect_sqlite(options).await?;
88            Ok(Arc::new(pool))
89        }
90        #[cfg(feature = "dm")]
91        ConnectionOptions::Dm(options) => {
92            let pool = connect_dm(options)?;
93            Ok(Arc::new(pool))
94        }
95    }
96}
97
/// Backend-specific connection settings.
///
/// Each variant exists only when its cargo feature is enabled, so the set of
/// variants depends on how the crate was built.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a PostgreSQL connection (feature `postgres`).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection (feature `oracle`).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection (feature `mysql`).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a Dm connection (feature `dm`).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
111
112impl ConnectionOptions {
113    pub(crate) fn stream_chunk_size(&self) -> usize {
114        match self {
115            #[cfg(feature = "postgres")]
116            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
117            #[cfg(feature = "oracle")]
118            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
119            #[cfg(feature = "mysql")]
120            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
121            #[cfg(feature = "sqlite")]
122            ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
123            #[cfg(feature = "dm")]
124            ConnectionOptions::Dm(options) => options.stream_chunk_size,
125        }
126    }
127
128    pub(crate) fn db_type(&self) -> RemoteDbType {
129        match self {
130            #[cfg(feature = "postgres")]
131            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
132            #[cfg(feature = "oracle")]
133            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
134            #[cfg(feature = "mysql")]
135            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
136            #[cfg(feature = "sqlite")]
137            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
138            #[cfg(feature = "dm")]
139            ConnectionOptions::Dm(_) => RemoteDbType::Dm,
140        }
141    }
142
143    pub fn with_pool_max_size(self, pool_max_size: usize) -> Self {
144        match self {
145            #[cfg(feature = "postgres")]
146            ConnectionOptions::Postgres(options) => {
147                ConnectionOptions::Postgres(options.with_pool_max_size(pool_max_size))
148            }
149            #[cfg(feature = "oracle")]
150            ConnectionOptions::Oracle(options) => {
151                ConnectionOptions::Oracle(options.with_pool_max_size(pool_max_size))
152            }
153            #[cfg(feature = "mysql")]
154            ConnectionOptions::Mysql(options) => {
155                ConnectionOptions::Mysql(options.with_pool_max_size(pool_max_size))
156            }
157            #[cfg(feature = "sqlite")]
158            ConnectionOptions::Sqlite(options) => ConnectionOptions::Sqlite(options),
159            #[cfg(feature = "dm")]
160            ConnectionOptions::Dm(options) => ConnectionOptions::Dm(options),
161        }
162    }
163}
164
/// Identifies which remote database dialect SQL text is generated for.
///
/// Unlike [`ConnectionOptions`], every variant is always present regardless
/// of which backend features are compiled in.
///
/// The type is a plain tag, so it derives equality and hashing: callers can
/// compare dialects with `==` or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm.
    Dm,
}
173
174impl RemoteDbType {
175    pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
176        match source {
177            RemoteSource::Table(_) => true,
178            RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
179        }
180    }
181
182    pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
183        match self {
184            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
185            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
186            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
187            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
188                "Oracle unparser not implemented".to_string(),
189            )),
190            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
191                "Dm unparser not implemented".to_string(),
192            )),
193        }
194    }
195
196    pub(crate) fn rewrite_query(
197        &self,
198        source: &RemoteSource,
199        unparsed_filters: &[String],
200        limit: Option<usize>,
201    ) -> String {
202        match source {
203            RemoteSource::Table(table) => match self {
204                RemoteDbType::Postgres
205                | RemoteDbType::Mysql
206                | RemoteDbType::Sqlite
207                | RemoteDbType::Dm => {
208                    let where_clause = if unparsed_filters.is_empty() {
209                        "".to_string()
210                    } else {
211                        format!(" WHERE {}", unparsed_filters.join(" AND "))
212                    };
213                    let limit_clause = if let Some(limit) = limit {
214                        format!(" LIMIT {limit}")
215                    } else {
216                        "".to_string()
217                    };
218
219                    format!(
220                        "{}{where_clause}{limit_clause}",
221                        self.select_all_query(table)
222                    )
223                }
224                RemoteDbType::Oracle => {
225                    let mut all_filters: Vec<String> = vec![];
226                    all_filters.extend_from_slice(unparsed_filters);
227                    if let Some(limit) = limit {
228                        all_filters.push(format!("ROWNUM <= {limit}"))
229                    }
230
231                    let where_clause = if all_filters.is_empty() {
232                        "".to_string()
233                    } else {
234                        format!(" WHERE {}", all_filters.join(" AND "))
235                    };
236                    format!("{}{where_clause}", self.select_all_query(table))
237                }
238            },
239            RemoteSource::Query(query) => match self {
240                RemoteDbType::Postgres
241                | RemoteDbType::Mysql
242                | RemoteDbType::Sqlite
243                | RemoteDbType::Dm => {
244                    let where_clause = if unparsed_filters.is_empty() {
245                        "".to_string()
246                    } else {
247                        format!(" WHERE {}", unparsed_filters.join(" AND "))
248                    };
249                    let limit_clause = if let Some(limit) = limit {
250                        format!(" LIMIT {limit}")
251                    } else {
252                        "".to_string()
253                    };
254
255                    if where_clause.is_empty() && limit_clause.is_empty() {
256                        query.clone()
257                    } else {
258                        format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
259                    }
260                }
261                RemoteDbType::Oracle => {
262                    let mut all_filters: Vec<String> = vec![];
263                    all_filters.extend_from_slice(unparsed_filters);
264                    if let Some(limit) = limit {
265                        all_filters.push(format!("ROWNUM <= {limit}"))
266                    }
267
268                    let where_clause = if all_filters.is_empty() {
269                        "".to_string()
270                    } else {
271                        format!(" WHERE {}", all_filters.join(" AND "))
272                    };
273                    if where_clause.is_empty() {
274                        query.clone()
275                    } else {
276                        format!("SELECT * FROM ({query}){where_clause}")
277                    }
278                }
279            },
280        }
281    }
282
283    pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
284        match self {
285            RemoteDbType::Postgres
286            | RemoteDbType::Oracle
287            | RemoteDbType::Sqlite
288            | RemoteDbType::Dm => {
289                format!("\"{identifier}\"")
290            }
291            RemoteDbType::Mysql => {
292                format!("`{identifier}`")
293            }
294        }
295    }
296
297    pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
298        indentifiers
299            .iter()
300            .map(|identifier| self.sql_identifier(identifier))
301            .collect::<Vec<String>>()
302            .join(".")
303    }
304
305    pub(crate) fn sql_string_literal(&self, value: &str) -> String {
306        let value = value.replace("'", "''");
307        format!("'{value}'")
308    }
309
310    pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
311        match self {
312            RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
313            RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
314            RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
315        }
316    }
317
318    pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
319        match self {
320            RemoteDbType::Postgres
321            | RemoteDbType::Mysql
322            | RemoteDbType::Oracle
323            | RemoteDbType::Sqlite
324            | RemoteDbType::Dm => {
325                format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
326            }
327        }
328    }
329
330    pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
331        if !self.support_rewrite_with_filters_limit(source) {
332            return source.query(*self);
333        }
334        self.rewrite_query(source, &[], Some(1))
335    }
336
337    pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
338        if !self.support_rewrite_with_filters_limit(source) {
339            return None;
340        }
341        match source {
342            RemoteSource::Table(table) => Some(self.select_all_query(table)),
343            RemoteSource::Query(query) => match self {
344                RemoteDbType::Postgres
345                | RemoteDbType::Mysql
346                | RemoteDbType::Sqlite
347                | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
348                RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
349            },
350        }
351    }
352
353    pub(crate) async fn fetch_count(
354        &self,
355        conn: Arc<dyn Connection>,
356        conn_options: &ConnectionOptions,
357        count1_query: &str,
358    ) -> DFResult<usize> {
359        let count1_schema = Arc::new(Schema::new(vec![Field::new(
360            "count(1)",
361            DataType::Int64,
362            false,
363        )]));
364        let stream = conn
365            .query(
366                conn_options,
367                &RemoteSource::Query(count1_query.to_string()),
368                count1_schema,
369                None,
370                &[],
371                None,
372            )
373            .await?;
374        let batches = collect(stream).await?;
375        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
376        if count_vec.len() != 1 {
377            return Err(DataFusionError::Execution(format!(
378                "Count query did not return exactly one row: {count_vec:?}",
379            )));
380        }
381        count_vec[0]
382            .map(|count| count as usize)
383            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
384    }
385}
386
/// Tells whether column `col_idx` is selected by `projection`.
///
/// A missing projection selects every column, so `None` always yields `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.contains(&col_idx)
}
393
/// Wraps `v` in `Ok` without touching it.
///
/// Not referenced in this file, hence `#[allow(unused)]`; presumably used as
/// a no-op conversion by the feature-gated backend modules — TODO confirm.
#[allow(unused)]
fn just_return<T>(v: T) -> DFResult<T> {
    Ok(v)
}
398
/// Copies the value out of `t` and wraps it in `Ok`.
///
/// Not referenced in this file, hence `#[allow(unused)]`; presumably used as
/// a no-op conversion by the feature-gated backend modules — TODO confirm.
#[allow(unused)]
fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
    Ok(*t)
}
403
/// Pins `f32` overflow behavior: 10^40 exceeds `f32::MAX` (about 3.4e38), so
/// `powi` saturates to positive infinity rather than producing a finite value.
#[test]
fn tst_f32() {
    let overflowed = 10f32.powi(40);
    println!("{overflowed}");
    assert!(overflowed.is_infinite() && overflowed.is_sign_positive());
}