datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
29use datafusion::common::DataFusionError;
30use datafusion::execution::SendableRecordBatchStream;
31use datafusion::physical_plan::common::collect;
32use datafusion::sql::unparser::Unparser;
33use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
34use std::fmt::Debug;
35use std::sync::Arc;
36
/// Process-wide slot for the ODBC environment, only compiled with the `dm` feature.
///
/// The `OnceLock` starts empty; whichever code first needs the environment is
/// expected to initialise it (the initialisation site is outside this file).
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
39
/// A source of [`Connection`]s to one remote database.
///
/// Implementations must be `Debug + Send + Sync` so a pool can be shared as
/// `Arc<dyn Pool>` (which is what [`connect`] returns).
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Hands out a connection, or an error if none can be obtained.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
44
/// Operations every remote database connection has to provide.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Exposes the connection as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Determines the remote schema of `source` (a table or a raw query).
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;

    /// Runs `source` against the remote database and streams the result as record batches.
    ///
    /// * `table_schema` - Arrow schema associated with the source.
    /// * `projection` - optional column indices to read; `None` means all columns.
    /// * `unparsed_filters` - filter predicates already rendered as strings
    ///   (presumably SQL in the remote dialect; confirm against implementations).
    /// * `limit` - optional maximum number of rows.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;

    /// Writes the batches produced by `input` into the remote table named by `table`.
    ///
    /// `table` holds the parts of the (possibly qualified) table name. The returned
    /// `usize` is presumably the number of rows written; confirm against implementations.
    async fn insert(
        &self,
        conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        input: SendableRecordBatchStream,
    ) -> DFResult<usize>;
}
70
71#[allow(unused_variables)]
72pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
73    match options {
74        ConnectionOptions::Postgres(options) => {
75            #[cfg(feature = "postgres")]
76            {
77                let pool = connect_postgres(options).await?;
78                Ok(Arc::new(pool))
79            }
80            #[cfg(not(feature = "postgres"))]
81            {
82                Err(DataFusionError::Internal(
83                    "Please enable the postgres feature".to_string(),
84                ))
85            }
86        }
87        ConnectionOptions::Mysql(options) => {
88            #[cfg(feature = "mysql")]
89            {
90                let pool = connect_mysql(options)?;
91                Ok(Arc::new(pool))
92            }
93            #[cfg(not(feature = "mysql"))]
94            {
95                Err(DataFusionError::Internal(
96                    "Please enable the mysql feature".to_string(),
97                ))
98            }
99        }
100        ConnectionOptions::Oracle(options) => {
101            #[cfg(feature = "oracle")]
102            {
103                let pool = connect_oracle(options).await?;
104                Ok(Arc::new(pool))
105            }
106            #[cfg(not(feature = "oracle"))]
107            {
108                Err(DataFusionError::Internal(
109                    "Please enable the oracle feature".to_string(),
110                ))
111            }
112        }
113        ConnectionOptions::Sqlite(options) => {
114            #[cfg(feature = "sqlite")]
115            {
116                let pool = connect_sqlite(options).await?;
117                Ok(Arc::new(pool))
118            }
119            #[cfg(not(feature = "sqlite"))]
120            {
121                Err(DataFusionError::Internal(
122                    "Please enable the sqlite feature".to_string(),
123                ))
124            }
125        }
126        ConnectionOptions::Dm(options) => {
127            #[cfg(feature = "dm")]
128            {
129                let pool = connect_dm(options)?;
130                Ok(Arc::new(pool))
131            }
132            #[cfg(not(feature = "dm"))]
133            {
134                Err(DataFusionError::Internal(
135                    "Please enable the dm feature".to_string(),
136                ))
137            }
138        }
139    }
140}
141
/// The kind of remote database a source belongs to.
///
/// The variant selects dialect-specific behaviour such as identifier quoting,
/// literal syntax and how row limits are expressed.
#[derive(Debug, Clone, Copy)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle Database.
    Oracle,
    /// SQLite.
    Sqlite,
    /// "Dm" backend (presumably the Dameng database, accessed via ODBC; confirm).
    Dm,
}
150
151impl RemoteDbType {
152    pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
153        match source {
154            RemoteSource::Table(_) => true,
155            RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
156        }
157    }
158
159    pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
160        match self {
161            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
162            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
163            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
164            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
165                "Oracle unparser not implemented".to_string(),
166            )),
167            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
168                "Dm unparser not implemented".to_string(),
169            )),
170        }
171    }
172
173    pub(crate) fn rewrite_query(
174        &self,
175        source: &RemoteSource,
176        unparsed_filters: &[String],
177        limit: Option<usize>,
178    ) -> String {
179        match source {
180            RemoteSource::Table(table) => match self {
181                RemoteDbType::Postgres
182                | RemoteDbType::Mysql
183                | RemoteDbType::Sqlite
184                | RemoteDbType::Dm => {
185                    let where_clause = if unparsed_filters.is_empty() {
186                        "".to_string()
187                    } else {
188                        format!(" WHERE {}", unparsed_filters.join(" AND "))
189                    };
190                    let limit_clause = if let Some(limit) = limit {
191                        format!(" LIMIT {limit}")
192                    } else {
193                        "".to_string()
194                    };
195
196                    format!(
197                        "{}{where_clause}{limit_clause}",
198                        self.select_all_query(table)
199                    )
200                }
201                RemoteDbType::Oracle => {
202                    let mut all_filters: Vec<String> = vec![];
203                    all_filters.extend_from_slice(unparsed_filters);
204                    if let Some(limit) = limit {
205                        all_filters.push(format!("ROWNUM <= {limit}"))
206                    }
207
208                    let where_clause = if all_filters.is_empty() {
209                        "".to_string()
210                    } else {
211                        format!(" WHERE {}", all_filters.join(" AND "))
212                    };
213                    format!("{}{where_clause}", self.select_all_query(table))
214                }
215            },
216            RemoteSource::Query(query) => match self {
217                RemoteDbType::Postgres
218                | RemoteDbType::Mysql
219                | RemoteDbType::Sqlite
220                | RemoteDbType::Dm => {
221                    let where_clause = if unparsed_filters.is_empty() {
222                        "".to_string()
223                    } else {
224                        format!(" WHERE {}", unparsed_filters.join(" AND "))
225                    };
226                    let limit_clause = if let Some(limit) = limit {
227                        format!(" LIMIT {limit}")
228                    } else {
229                        "".to_string()
230                    };
231
232                    if where_clause.is_empty() && limit_clause.is_empty() {
233                        query.clone()
234                    } else {
235                        format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
236                    }
237                }
238                RemoteDbType::Oracle => {
239                    let mut all_filters: Vec<String> = vec![];
240                    all_filters.extend_from_slice(unparsed_filters);
241                    if let Some(limit) = limit {
242                        all_filters.push(format!("ROWNUM <= {limit}"))
243                    }
244
245                    let where_clause = if all_filters.is_empty() {
246                        "".to_string()
247                    } else {
248                        format!(" WHERE {}", all_filters.join(" AND "))
249                    };
250                    if where_clause.is_empty() {
251                        query.clone()
252                    } else {
253                        format!("SELECT * FROM ({query}){where_clause}")
254                    }
255                }
256            },
257        }
258    }
259
260    pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
261        match self {
262            RemoteDbType::Postgres
263            | RemoteDbType::Oracle
264            | RemoteDbType::Sqlite
265            | RemoteDbType::Dm => {
266                format!("\"{identifier}\"")
267            }
268            RemoteDbType::Mysql => {
269                format!("`{identifier}`")
270            }
271        }
272    }
273
274    pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
275        indentifiers
276            .iter()
277            .map(|identifier| self.sql_identifier(identifier))
278            .collect::<Vec<String>>()
279            .join(".")
280    }
281
282    pub(crate) fn sql_string_literal(&self, value: &str) -> String {
283        let value = value.replace("'", "''");
284        format!("'{value}'")
285    }
286
287    pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
288        match self {
289            RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
290            RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
291            RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
292        }
293    }
294
295    pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
296        match self {
297            RemoteDbType::Postgres
298            | RemoteDbType::Mysql
299            | RemoteDbType::Oracle
300            | RemoteDbType::Sqlite
301            | RemoteDbType::Dm => {
302                format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
303            }
304        }
305    }
306
307    pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
308        if !self.support_rewrite_with_filters_limit(source) {
309            return source.query(*self);
310        }
311        self.rewrite_query(source, &[], Some(1))
312    }
313
314    pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
315        if !self.support_rewrite_with_filters_limit(source) {
316            return None;
317        }
318        match source {
319            RemoteSource::Table(table) => Some(self.select_all_query(table)),
320            RemoteSource::Query(query) => match self {
321                RemoteDbType::Postgres
322                | RemoteDbType::Mysql
323                | RemoteDbType::Sqlite
324                | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
325                RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
326            },
327        }
328    }
329
330    pub(crate) async fn fetch_count(
331        &self,
332        conn: Arc<dyn Connection>,
333        conn_options: &ConnectionOptions,
334        count1_query: &str,
335    ) -> DFResult<usize> {
336        let count1_schema = Arc::new(Schema::new(vec![Field::new(
337            "count(1)",
338            DataType::Int64,
339            false,
340        )]));
341        let stream = conn
342            .query(
343                conn_options,
344                &RemoteSource::Query(count1_query.to_string()),
345                count1_schema,
346                None,
347                &[],
348                None,
349            )
350            .await?;
351        let batches = collect(stream).await?;
352        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
353        if count_vec.len() != 1 {
354            return Err(DataFusionError::Execution(format!(
355                "Count query did not return exactly one row: {count_vec:?}",
356            )));
357        }
358        count_vec[0]
359            .map(|count| count as usize)
360            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
361    }
362}
363
/// Tells whether column `col_idx` is selected by `projection`.
///
/// An absent projection (`None`) means every column is selected, so the answer
/// is always `true` in that case.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected_columns) = projection else {
        return true;
    };
    selected_columns.contains(&col_idx)
}
370
/// Wraps `v` in `Ok` without touching it.
// `unused`: nothing in this file calls it; presumably the feature-gated backend
// modules do, so it can be dead code depending on enabled features (confirm).
#[allow(unused)]
fn just_return<T>(v: T) -> DFResult<T> {
    Ok(v)
}
375
/// Copies the value behind `t` and wraps it in `Ok`.
// `unused`: nothing in this file calls it; presumably the feature-gated backend
// modules do, so it can be dead code depending on enabled features (confirm).
#[allow(unused)]
fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
    Ok(*t)
}
380
/// Pins how `f32` behaves for `10^40`: the value is not representable and
/// overflows to positive infinity instead of panicking.
#[test]
fn tst_f32() {
    let value = 10f32.powi(40);
    println!("{}", value);
    // `f32::MAX` is roughly 3.4e38, so 1e40 saturates to +inf.
    assert!(value.is_infinite() && value.is_sign_positive());
}