datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use arrow::array::RecordBatch;
29use arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
30use datafusion_common::DataFusionError;
31use datafusion_execution::SendableRecordBatchStream;
32use datafusion_physical_plan::common::collect;
33use datafusion_sql::unparser::Unparser;
34use datafusion_sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
35use std::fmt::Debug;
36use std::sync::Arc;
37
/// Process-wide slot for a single `odbc_api::Environment`, filled at most once.
///
/// Only compiled when the `dm` feature is enabled. Nothing in this module
/// initialises it.
// NOTE(review): presumably set by the `dm` backend when it first connects — confirm.
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
40
/// A source of [`Connection`]s to a remote database.
///
/// Implementations must be `Debug + Send + Sync`; [`connect`] hands them out
/// as `Arc<dyn Pool>`.
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Obtains a connection from the pool.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
    /// Reports the pool's current connection counts as a [`PoolState`].
    async fn state(&self) -> DFResult<PoolState>;
}
46
/// Connection counts reported by [`Pool::state`].
#[derive(Debug, Clone)]
pub struct PoolState {
    /// Number of connections held by the pool.
    // NOTE(review): presumably the total, idle ones included — confirm per backend.
    pub connections: usize,
    /// Number of those connections that are currently idle.
    pub idle_connections: usize,
}
52
/// A single connection to a remote database, able to infer schemas, run
/// queries and insert record batches.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Exposes the connection as `&dyn Any`.
    // NOTE(review): presumably so callers can downcast to the concrete backend type — confirm.
    fn as_any(&self) -> &dyn Any;

    /// Determines the remote schema of `source`.
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;

    /// Runs `source` against the remote database and returns the result as a
    /// stream of record batches.
    ///
    /// * `conn_options` - options of the connection being used.
    /// * `source` - the table or query to read.
    /// * `table_schema` - Arrow schema of the source.
    /// * `projection` - column indices to read; `None` carries no index list
    ///   (NOTE(review): presumably "all columns" — confirm).
    /// * `unparsed_filters` - filter predicates as strings
    ///   (NOTE(review): presumably already rendered as SQL for this database — confirm).
    /// * `limit` - optional maximum number of rows.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;

    /// Inserts `batch` into the remote table named by the `table` identifier
    /// parts.
    ///
    /// * `literalizer` - the [`Literalize`] implementation handed to the backend
    ///   (NOTE(review): presumably renders values as SQL literals — confirm).
    /// * `remote_schema` - remote schema of the target table.
    ///
    /// Returns a `usize`.
    // NOTE(review): presumably the number of rows inserted — confirm.
    async fn insert(
        &self,
        conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        batch: RecordBatch,
    ) -> DFResult<usize>;
}
78
79#[allow(unused_variables)]
80pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
81    match options {
82        ConnectionOptions::Postgres(options) => {
83            #[cfg(feature = "postgres")]
84            {
85                let pool = connect_postgres(options).await?;
86                Ok(Arc::new(pool))
87            }
88            #[cfg(not(feature = "postgres"))]
89            {
90                Err(DataFusionError::Internal(
91                    "Please enable the postgres feature".to_string(),
92                ))
93            }
94        }
95        ConnectionOptions::Mysql(options) => {
96            #[cfg(feature = "mysql")]
97            {
98                let pool = connect_mysql(options)?;
99                Ok(Arc::new(pool))
100            }
101            #[cfg(not(feature = "mysql"))]
102            {
103                Err(DataFusionError::Internal(
104                    "Please enable the mysql feature".to_string(),
105                ))
106            }
107        }
108        ConnectionOptions::Oracle(options) => {
109            #[cfg(feature = "oracle")]
110            {
111                let pool = connect_oracle(options).await?;
112                Ok(Arc::new(pool))
113            }
114            #[cfg(not(feature = "oracle"))]
115            {
116                Err(DataFusionError::Internal(
117                    "Please enable the oracle feature".to_string(),
118                ))
119            }
120        }
121        ConnectionOptions::Sqlite(options) => {
122            #[cfg(feature = "sqlite")]
123            {
124                let pool = connect_sqlite(options).await?;
125                Ok(Arc::new(pool))
126            }
127            #[cfg(not(feature = "sqlite"))]
128            {
129                Err(DataFusionError::Internal(
130                    "Please enable the sqlite feature".to_string(),
131                ))
132            }
133        }
134        ConnectionOptions::Dm(options) => {
135            #[cfg(feature = "dm")]
136            {
137                let pool = connect_dm(options)?;
138                Ok(Arc::new(pool))
139            }
140            #[cfg(not(feature = "dm"))]
141            {
142                Err(DataFusionError::Internal(
143                    "Please enable the dm feature".to_string(),
144                ))
145            }
146        }
147    }
148}
149
/// The kind of remote database being talked to.
///
/// Selects the SQL flavour used when quoting identifiers, writing literals
/// and rewriting queries.
#[derive(Debug, Clone, Copy)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// The database behind the `dm` feature.
    // NOTE(review): presumably Dameng (DM), reached through ODBC — confirm.
    Dm,
}
158
159impl RemoteDbType {
160    pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
161        match source {
162            RemoteSource::Table(_) => true,
163            RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
164        }
165    }
166
167    pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
168        match self {
169            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
170            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
171            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
172            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
173                "Oracle unparser not implemented".to_string(),
174            )),
175            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
176                "Dm unparser not implemented".to_string(),
177            )),
178        }
179    }
180
181    pub(crate) fn rewrite_query(
182        &self,
183        source: &RemoteSource,
184        unparsed_filters: &[String],
185        limit: Option<usize>,
186    ) -> String {
187        match source {
188            RemoteSource::Table(table) => match self {
189                RemoteDbType::Postgres
190                | RemoteDbType::Mysql
191                | RemoteDbType::Sqlite
192                | RemoteDbType::Dm => {
193                    let where_clause = if unparsed_filters.is_empty() {
194                        "".to_string()
195                    } else {
196                        format!(" WHERE {}", unparsed_filters.join(" AND "))
197                    };
198                    let limit_clause = if let Some(limit) = limit {
199                        format!(" LIMIT {limit}")
200                    } else {
201                        "".to_string()
202                    };
203
204                    format!(
205                        "{}{where_clause}{limit_clause}",
206                        self.select_all_query(table)
207                    )
208                }
209                RemoteDbType::Oracle => {
210                    let mut all_filters: Vec<String> = vec![];
211                    all_filters.extend_from_slice(unparsed_filters);
212                    if let Some(limit) = limit {
213                        all_filters.push(format!("ROWNUM <= {limit}"))
214                    }
215
216                    let where_clause = if all_filters.is_empty() {
217                        "".to_string()
218                    } else {
219                        format!(" WHERE {}", all_filters.join(" AND "))
220                    };
221                    format!("{}{where_clause}", self.select_all_query(table))
222                }
223            },
224            RemoteSource::Query(query) => match self {
225                RemoteDbType::Postgres
226                | RemoteDbType::Mysql
227                | RemoteDbType::Sqlite
228                | RemoteDbType::Dm => {
229                    let where_clause = if unparsed_filters.is_empty() {
230                        "".to_string()
231                    } else {
232                        format!(" WHERE {}", unparsed_filters.join(" AND "))
233                    };
234                    let limit_clause = if let Some(limit) = limit {
235                        format!(" LIMIT {limit}")
236                    } else {
237                        "".to_string()
238                    };
239
240                    if where_clause.is_empty() && limit_clause.is_empty() {
241                        query.clone()
242                    } else {
243                        format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
244                    }
245                }
246                RemoteDbType::Oracle => {
247                    let mut all_filters: Vec<String> = vec![];
248                    all_filters.extend_from_slice(unparsed_filters);
249                    if let Some(limit) = limit {
250                        all_filters.push(format!("ROWNUM <= {limit}"))
251                    }
252
253                    let where_clause = if all_filters.is_empty() {
254                        "".to_string()
255                    } else {
256                        format!(" WHERE {}", all_filters.join(" AND "))
257                    };
258                    if where_clause.is_empty() {
259                        query.clone()
260                    } else {
261                        format!("SELECT * FROM ({query}){where_clause}")
262                    }
263                }
264            },
265        }
266    }
267
268    pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
269        match self {
270            RemoteDbType::Postgres
271            | RemoteDbType::Oracle
272            | RemoteDbType::Sqlite
273            | RemoteDbType::Dm => {
274                format!("\"{identifier}\"")
275            }
276            RemoteDbType::Mysql => {
277                format!("`{identifier}`")
278            }
279        }
280    }
281
282    pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
283        indentifiers
284            .iter()
285            .map(|identifier| self.sql_identifier(identifier))
286            .collect::<Vec<String>>()
287            .join(".")
288    }
289
290    pub(crate) fn sql_string_literal(&self, value: &str) -> String {
291        let value = value.replace("'", "''");
292        format!("'{value}'")
293    }
294
295    pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
296        match self {
297            RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
298            RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
299            RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
300        }
301    }
302
303    pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
304        match self {
305            RemoteDbType::Postgres
306            | RemoteDbType::Mysql
307            | RemoteDbType::Oracle
308            | RemoteDbType::Sqlite
309            | RemoteDbType::Dm => {
310                format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
311            }
312        }
313    }
314
315    pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
316        if !self.support_rewrite_with_filters_limit(source) {
317            return source.query(*self);
318        }
319        self.rewrite_query(source, &[], Some(1))
320    }
321
322    pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
323        if !self.support_rewrite_with_filters_limit(source) {
324            return None;
325        }
326        match source {
327            RemoteSource::Table(table) => Some(format!(
328                "SELECT COUNT(1) FROM {}",
329                self.sql_table_name(table)
330            )),
331            RemoteSource::Query(query) => match self {
332                RemoteDbType::Postgres
333                | RemoteDbType::Mysql
334                | RemoteDbType::Sqlite
335                | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
336                RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
337            },
338        }
339    }
340
341    pub(crate) async fn fetch_count(
342        &self,
343        conn: Arc<dyn Connection>,
344        conn_options: &ConnectionOptions,
345        count1_query: &str,
346    ) -> DFResult<usize> {
347        let count1_schema = Arc::new(Schema::new(vec![Field::new(
348            "count(1)",
349            DataType::Int64,
350            false,
351        )]));
352        let stream = conn
353            .query(
354                conn_options,
355                &RemoteSource::Query(count1_query.to_string()),
356                count1_schema,
357                None,
358                &[],
359                None,
360            )
361            .await?;
362        let batches = collect(stream).await?;
363        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
364        if count_vec.len() != 1 {
365            return Err(DataFusionError::Execution(format!(
366                "Count query did not return exactly one row: {count_vec:?}",
367            )));
368        }
369        count_vec[0]
370            .map(|count| count as usize)
371            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
372    }
373}
374
/// Tells whether column `col_idx` is part of `projection`.
///
/// With no projection (`None`) every column counts as included.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.iter().any(|&idx| idx == col_idx)
}
381
382#[allow(unused)]
383fn just_return<T>(v: T) -> DFResult<T> {
384    Ok(v)
385}
386
387#[allow(unused)]
388fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
389    Ok(*t)
390}