datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "mysql")]
2mod mysql;
3#[cfg(feature = "oracle")]
4mod oracle;
5#[cfg(feature = "postgres")]
6mod postgres;
7#[cfg(feature = "sqlite")]
8mod sqlite;
9
10#[cfg(feature = "mysql")]
11pub use mysql::*;
12#[cfg(feature = "oracle")]
13pub use oracle::*;
14#[cfg(feature = "postgres")]
15pub use postgres::*;
16#[cfg(feature = "sqlite")]
17pub use sqlite::*;
18
19use crate::{DFResult, RemoteSchemaRef};
20use datafusion::arrow::datatypes::SchemaRef;
21use datafusion::common::DataFusionError;
22use datafusion::execution::SendableRecordBatchStream;
23use datafusion::prelude::Expr;
24use datafusion::sql::unparser::Unparser;
25use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
26use std::fmt::Debug;
27#[cfg(feature = "sqlite")]
28use std::path::PathBuf;
29use std::sync::Arc;
30
31#[async_trait::async_trait]
32pub trait Pool: Debug + Send + Sync {
33    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
34}
35
36#[async_trait::async_trait]
37pub trait Connection: Debug + Send + Sync {
38    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)>;
39
40    async fn query(
41        &self,
42        conn_options: &ConnectionOptions,
43        sql: &str,
44        table_schema: SchemaRef,
45        projection: Option<&Vec<usize>>,
46        filters: &[Expr],
47        limit: Option<usize>,
48    ) -> DFResult<SendableRecordBatchStream>;
49}
50
51pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
52    match options {
53        #[cfg(feature = "postgres")]
54        ConnectionOptions::Postgres(options) => {
55            let pool = connect_postgres(options).await?;
56            Ok(Arc::new(pool))
57        }
58        #[cfg(feature = "mysql")]
59        ConnectionOptions::Mysql(options) => {
60            let pool = connect_mysql(options)?;
61            Ok(Arc::new(pool))
62        }
63        #[cfg(feature = "oracle")]
64        ConnectionOptions::Oracle(options) => {
65            let pool = connect_oracle(options).await?;
66            Ok(Arc::new(pool))
67        }
68        #[cfg(feature = "sqlite")]
69        ConnectionOptions::Sqlite(path) => {
70            let pool = connect_sqlite(path).await?;
71            Ok(Arc::new(pool))
72        }
73    }
74}
75
/// Connection settings for one of the supported databases.
///
/// Every variant is gated behind the cargo feature of the same name, so the
/// set of available variants depends on how the crate is built.
#[derive(Clone, Debug)]
pub enum ConnectionOptions {
    /// Options for a PostgreSQL connection.
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Options for an Oracle connection.
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Options for a MySQL connection.
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// A SQLite connection, identified by a path.
    #[cfg(feature = "sqlite")]
    Sqlite(PathBuf),
}
87
88impl ConnectionOptions {
89    pub(crate) fn stream_chunk_size(&self) -> usize {
90        match self {
91            #[cfg(feature = "postgres")]
92            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
93            #[cfg(feature = "oracle")]
94            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
95            #[cfg(feature = "mysql")]
96            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
97            #[cfg(feature = "sqlite")]
98            ConnectionOptions::Sqlite(_) => unreachable!(),
99        }
100    }
101
102    pub(crate) fn db_type(&self) -> RemoteDbType {
103        match self {
104            #[cfg(feature = "postgres")]
105            ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
106            #[cfg(feature = "oracle")]
107            ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
108            #[cfg(feature = "mysql")]
109            ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
110            #[cfg(feature = "sqlite")]
111            ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
112        }
113    }
114}
115
/// Identifies which kind of database a connection talks to.
///
/// Unlike `ConnectionOptions`, the variants are not feature-gated, so the tag
/// is always available regardless of the enabled cargo features.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RemoteDbType {
    Postgres,
    Mysql,
    Oracle,
    Sqlite,
}
122
123impl RemoteDbType {
124    pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
125        sql.trim()[0..6].eq_ignore_ascii_case("select")
126    }
127
128    pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
129        match self {
130            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
131            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
132            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
133            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
134                "Oracle unparser not implemented".to_string(),
135            )),
136        }
137    }
138
139    pub(crate) fn try_rewrite_query(
140        &self,
141        sql: &str,
142        filters: &[Expr],
143        limit: Option<usize>,
144    ) -> Option<String> {
145        if !self.support_rewrite_with_filters_limit(sql) {
146            return None;
147        }
148        match self {
149            RemoteDbType::Postgres | RemoteDbType::Mysql | RemoteDbType::Sqlite => {
150                let where_clause = if filters.is_empty() {
151                    "".to_string()
152                } else {
153                    let unparser = self.create_unparser().ok()?;
154                    let filters_ast = filters
155                        .iter()
156                        .map(|f| unparser.expr_to_sql(f).expect("checked already"))
157                        .collect::<Vec<_>>();
158                    format!(
159                        " WHERE {}",
160                        filters_ast
161                            .iter()
162                            .map(|f| format!("{f}"))
163                            .collect::<Vec<_>>()
164                            .join(" AND ")
165                    )
166                };
167                let limit_clause = if let Some(limit) = limit {
168                    format!(" LIMIT {limit}")
169                } else {
170                    "".to_string()
171                };
172
173                if where_clause.is_empty() && limit_clause.is_empty() {
174                    None
175                } else {
176                    Some(format!(
177                        "SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}"
178                    ))
179                }
180            }
181            RemoteDbType::Oracle => {
182                let limit = limit?;
183                Some(format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {limit}"))
184            }
185        }
186    }
187}
188
/// Reports whether column `col_idx` is part of `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is
/// always `true` in that case.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
195
/// Converts `decimal` to the unscaled `i128` integer it has at `scale`
/// fractional digits, i.e. `decimal * 10^scale` truncated toward zero.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used
/// (falling back to `0` if that count does not fit in an `i32`).
///
/// Returns `None` if the result does not fit in an `i128`.
///
/// The rescaling is done with exact integer arithmetic; going through a
/// floating-point power of ten would lose precision for negative or large
/// scales and skew the truncated result.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::{ToPrimitive, Zero};

    // A non-zero integer multiplied by 10^39 always exceeds `i128::MAX`.
    const I128_OVERFLOW_DIGITS: i64 = 39;

    if decimal.is_zero() {
        return Some(0);
    }

    let target_scale: i64 = match scale {
        Some(scale) => i64::from(scale),
        None => i32::try_from(decimal.fractional_digit_count())
            .map(i64::from)
            .unwrap_or_default(),
    };

    // Number of decimal digits appended (> 0) or dropped (< 0) by rescaling.
    let shift = target_scale.saturating_sub(decimal.fractional_digit_count());
    if shift >= I128_OVERFLOW_DIGITS {
        // Guaranteed overflow; avoid building an enormous power of ten.
        return None;
    }
    if shift < 0 && shift.unsigned_abs() >= decimal.digits() {
        // Every significant digit is truncated away.
        return Some(0);
    }

    let (unscaled, _exponent) = decimal.with_scale(target_scale).into_bigint_and_exponent();
    unscaled.to_i128()
}
205    let scale_decimal = bigdecimal::BigDecimal::from_f32(10f32.powi(scale))?;
206    (decimal * scale_decimal).to_i128()
207}