datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
29use datafusion::common::DataFusionError;
30use datafusion::execution::SendableRecordBatchStream;
31use datafusion::physical_plan::common::collect;
32use datafusion::sql::unparser::Unparser;
33use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
34use std::fmt::Debug;
35use std::sync::Arc;
36
/// Lazily-initialized, process-wide ODBC environment.
///
/// Only compiled when the `dm` feature is enabled. The cell starts empty; whoever
/// needs the environment first is expected to initialize it through the
/// `OnceLock` API (the initialization site is not in this module).
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
39
40#[async_trait::async_trait]
41pub trait Pool: Debug + Send + Sync {
42    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
43    async fn state(&self) -> DFResult<PoolState>;
44}
45
/// Snapshot of a connection pool's counters, as returned by [`Pool::state`].
#[derive(Debug, Clone)]
pub struct PoolState {
    /// Connection count reported by the pool (presumably the total it currently
    /// manages; the exact meaning is defined by each backend).
    pub connections: usize,
    /// Number of those connections that are idle, as reported by the pool.
    pub idle_connections: usize,
}
51
52#[async_trait::async_trait]
53pub trait Connection: Debug + Send + Sync {
54    fn as_any(&self) -> &dyn Any;
55
56    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;
57
58    async fn query(
59        &self,
60        conn_options: &ConnectionOptions,
61        source: &RemoteSource,
62        table_schema: SchemaRef,
63        projection: Option<&Vec<usize>>,
64        unparsed_filters: &[String],
65        limit: Option<usize>,
66    ) -> DFResult<SendableRecordBatchStream>;
67
68    async fn insert(
69        &self,
70        conn_options: &ConnectionOptions,
71        literalizer: Arc<dyn Literalize>,
72        table: &[String],
73        remote_schema: RemoteSchemaRef,
74        input: SendableRecordBatchStream,
75    ) -> DFResult<usize>;
76}
77
78#[allow(unused_variables)]
79pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
80    match options {
81        ConnectionOptions::Postgres(options) => {
82            #[cfg(feature = "postgres")]
83            {
84                let pool = connect_postgres(options).await?;
85                Ok(Arc::new(pool))
86            }
87            #[cfg(not(feature = "postgres"))]
88            {
89                Err(DataFusionError::Internal(
90                    "Please enable the postgres feature".to_string(),
91                ))
92            }
93        }
94        ConnectionOptions::Mysql(options) => {
95            #[cfg(feature = "mysql")]
96            {
97                let pool = connect_mysql(options)?;
98                Ok(Arc::new(pool))
99            }
100            #[cfg(not(feature = "mysql"))]
101            {
102                Err(DataFusionError::Internal(
103                    "Please enable the mysql feature".to_string(),
104                ))
105            }
106        }
107        ConnectionOptions::Oracle(options) => {
108            #[cfg(feature = "oracle")]
109            {
110                let pool = connect_oracle(options).await?;
111                Ok(Arc::new(pool))
112            }
113            #[cfg(not(feature = "oracle"))]
114            {
115                Err(DataFusionError::Internal(
116                    "Please enable the oracle feature".to_string(),
117                ))
118            }
119        }
120        ConnectionOptions::Sqlite(options) => {
121            #[cfg(feature = "sqlite")]
122            {
123                let pool = connect_sqlite(options).await?;
124                Ok(Arc::new(pool))
125            }
126            #[cfg(not(feature = "sqlite"))]
127            {
128                Err(DataFusionError::Internal(
129                    "Please enable the sqlite feature".to_string(),
130                ))
131            }
132        }
133        ConnectionOptions::Dm(options) => {
134            #[cfg(feature = "dm")]
135            {
136                let pool = connect_dm(options)?;
137                Ok(Arc::new(pool))
138            }
139            #[cfg(not(feature = "dm"))]
140            {
141                Err(DataFusionError::Internal(
142                    "Please enable the dm feature".to_string(),
143                ))
144            }
145        }
146    }
147}
148
/// The kind of remote database being talked to.
///
/// The variant selects the SQL flavour used when generating statements, such as
/// identifier quoting and how a row limit is expressed.
#[derive(Debug, Clone, Copy)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// DM database, available behind the `dm` feature.
    Dm,
}
157
158impl RemoteDbType {
159    pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
160        match source {
161            RemoteSource::Table(_) => true,
162            RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
163        }
164    }
165
166    pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
167        match self {
168            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
169            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
170            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
171            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
172                "Oracle unparser not implemented".to_string(),
173            )),
174            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
175                "Dm unparser not implemented".to_string(),
176            )),
177        }
178    }
179
180    pub(crate) fn rewrite_query(
181        &self,
182        source: &RemoteSource,
183        unparsed_filters: &[String],
184        limit: Option<usize>,
185    ) -> String {
186        match source {
187            RemoteSource::Table(table) => match self {
188                RemoteDbType::Postgres
189                | RemoteDbType::Mysql
190                | RemoteDbType::Sqlite
191                | RemoteDbType::Dm => {
192                    let where_clause = if unparsed_filters.is_empty() {
193                        "".to_string()
194                    } else {
195                        format!(" WHERE {}", unparsed_filters.join(" AND "))
196                    };
197                    let limit_clause = if let Some(limit) = limit {
198                        format!(" LIMIT {limit}")
199                    } else {
200                        "".to_string()
201                    };
202
203                    format!(
204                        "{}{where_clause}{limit_clause}",
205                        self.select_all_query(table)
206                    )
207                }
208                RemoteDbType::Oracle => {
209                    let mut all_filters: Vec<String> = vec![];
210                    all_filters.extend_from_slice(unparsed_filters);
211                    if let Some(limit) = limit {
212                        all_filters.push(format!("ROWNUM <= {limit}"))
213                    }
214
215                    let where_clause = if all_filters.is_empty() {
216                        "".to_string()
217                    } else {
218                        format!(" WHERE {}", all_filters.join(" AND "))
219                    };
220                    format!("{}{where_clause}", self.select_all_query(table))
221                }
222            },
223            RemoteSource::Query(query) => match self {
224                RemoteDbType::Postgres
225                | RemoteDbType::Mysql
226                | RemoteDbType::Sqlite
227                | RemoteDbType::Dm => {
228                    let where_clause = if unparsed_filters.is_empty() {
229                        "".to_string()
230                    } else {
231                        format!(" WHERE {}", unparsed_filters.join(" AND "))
232                    };
233                    let limit_clause = if let Some(limit) = limit {
234                        format!(" LIMIT {limit}")
235                    } else {
236                        "".to_string()
237                    };
238
239                    if where_clause.is_empty() && limit_clause.is_empty() {
240                        query.clone()
241                    } else {
242                        format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
243                    }
244                }
245                RemoteDbType::Oracle => {
246                    let mut all_filters: Vec<String> = vec![];
247                    all_filters.extend_from_slice(unparsed_filters);
248                    if let Some(limit) = limit {
249                        all_filters.push(format!("ROWNUM <= {limit}"))
250                    }
251
252                    let where_clause = if all_filters.is_empty() {
253                        "".to_string()
254                    } else {
255                        format!(" WHERE {}", all_filters.join(" AND "))
256                    };
257                    if where_clause.is_empty() {
258                        query.clone()
259                    } else {
260                        format!("SELECT * FROM ({query}){where_clause}")
261                    }
262                }
263            },
264        }
265    }
266
267    pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
268        match self {
269            RemoteDbType::Postgres
270            | RemoteDbType::Oracle
271            | RemoteDbType::Sqlite
272            | RemoteDbType::Dm => {
273                format!("\"{identifier}\"")
274            }
275            RemoteDbType::Mysql => {
276                format!("`{identifier}`")
277            }
278        }
279    }
280
281    pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
282        indentifiers
283            .iter()
284            .map(|identifier| self.sql_identifier(identifier))
285            .collect::<Vec<String>>()
286            .join(".")
287    }
288
289    pub(crate) fn sql_string_literal(&self, value: &str) -> String {
290        let value = value.replace("'", "''");
291        format!("'{value}'")
292    }
293
294    pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
295        match self {
296            RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
297            RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
298            RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
299        }
300    }
301
302    pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
303        match self {
304            RemoteDbType::Postgres
305            | RemoteDbType::Mysql
306            | RemoteDbType::Oracle
307            | RemoteDbType::Sqlite
308            | RemoteDbType::Dm => {
309                format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
310            }
311        }
312    }
313
314    pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
315        if !self.support_rewrite_with_filters_limit(source) {
316            return source.query(*self);
317        }
318        self.rewrite_query(source, &[], Some(1))
319    }
320
321    pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
322        if !self.support_rewrite_with_filters_limit(source) {
323            return None;
324        }
325        match source {
326            RemoteSource::Table(table) => Some(self.select_all_query(table)),
327            RemoteSource::Query(query) => match self {
328                RemoteDbType::Postgres
329                | RemoteDbType::Mysql
330                | RemoteDbType::Sqlite
331                | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
332                RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
333            },
334        }
335    }
336
337    pub(crate) async fn fetch_count(
338        &self,
339        conn: Arc<dyn Connection>,
340        conn_options: &ConnectionOptions,
341        count1_query: &str,
342    ) -> DFResult<usize> {
343        let count1_schema = Arc::new(Schema::new(vec![Field::new(
344            "count(1)",
345            DataType::Int64,
346            false,
347        )]));
348        let stream = conn
349            .query(
350                conn_options,
351                &RemoteSource::Query(count1_query.to_string()),
352                count1_schema,
353                None,
354                &[],
355                None,
356            )
357            .await?;
358        let batches = collect(stream).await?;
359        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
360        if count_vec.len() != 1 {
361            return Err(DataFusionError::Execution(format!(
362                "Count query did not return exactly one row: {count_vec:?}",
363            )));
364        }
365        count_vec[0]
366            .map(|count| count as usize)
367            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
368    }
369}
370
/// Reports whether column `col_idx` is part of `projection`.
///
/// A missing projection (`None`) selects every column, so it contains any index.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.contains(&col_idx)
}
377
378#[allow(unused)]
379fn just_return<T>(v: T) -> DFResult<T> {
380    Ok(v)
381}
382
383#[allow(unused)]
384fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
385    Ok(*t)
386}
387
/// Documents that `10f32.powi(40)` exceeds `f32::MAX` and therefore evaluates to
/// positive infinity rather than a finite number.
#[test]
fn tst_f32() {
    let value = 10f32.powi(40);
    println!("{}", value);
    // Without an assertion the test could never fail; pin the overflow result.
    assert!(value.is_infinite() && value.is_sign_positive());
}