datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22
23use crate::{DFResult, RemoteSchemaRef};
24use datafusion::arrow::datatypes::SchemaRef;
25use datafusion::common::DataFusionError;
26use datafusion::execution::SendableRecordBatchStream;
27use datafusion::prelude::Expr;
28use datafusion::sql::unparser::Unparser;
29use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
30use std::fmt::Debug;
31use std::sync::Arc;
32
/// Crate-wide, lazily initialised slot for an ODBC [`odbc_api::Environment`].
///
/// Only compiled when the `dm` feature is enabled. The slot starts empty; the
/// code that fills it is not part of this module.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> =
    std::sync::OnceLock::new();
35
36#[async_trait::async_trait]
37pub trait Pool: Debug + Send + Sync {
38 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
39}
40
41#[async_trait::async_trait]
42pub trait Connection: Debug + Send + Sync {
43 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;
44
45 async fn query(
46 &self,
47 conn_options: &ConnectionOptions,
48 sql: &str,
49 table_schema: SchemaRef,
50 projection: Option<&Vec<usize>>,
51 filters: &[Expr],
52 limit: Option<usize>,
53 ) -> DFResult<SendableRecordBatchStream>;
54}
55
56pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
57 match options {
58 #[cfg(feature = "postgres")]
59 ConnectionOptions::Postgres(options) => {
60 let pool = connect_postgres(options).await?;
61 Ok(Arc::new(pool))
62 }
63 #[cfg(feature = "mysql")]
64 ConnectionOptions::Mysql(options) => {
65 let pool = connect_mysql(options)?;
66 Ok(Arc::new(pool))
67 }
68 #[cfg(feature = "oracle")]
69 ConnectionOptions::Oracle(options) => {
70 let pool = connect_oracle(options).await?;
71 Ok(Arc::new(pool))
72 }
73 #[cfg(feature = "sqlite")]
74 ConnectionOptions::Sqlite(options) => {
75 let pool = connect_sqlite(options).await?;
76 Ok(Arc::new(pool))
77 }
78 #[cfg(feature = "dm")]
79 ConnectionOptions::Dm(options) => {
80 let pool = connect_dm(options)?;
81 Ok(Arc::new(pool))
82 }
83 }
84}
85
/// Backend-specific connection settings.
///
/// Every variant is gated behind the Cargo feature of the same name, so the
/// set of available variants depends on how the crate was built.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a Postgres connection (`postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle connection (`oracle` feature).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL connection (`mysql` feature).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Settings for a SQLite connection (`sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Settings for a Dm connection (`dm` feature).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
99
100impl ConnectionOptions {
101 pub(crate) fn stream_chunk_size(&self) -> usize {
102 match self {
103 #[cfg(feature = "postgres")]
104 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
105 #[cfg(feature = "oracle")]
106 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
107 #[cfg(feature = "mysql")]
108 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
109 #[cfg(feature = "sqlite")]
110 ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
111 #[cfg(feature = "dm")]
112 ConnectionOptions::Dm(options) => options.stream_chunk_size,
113 }
114 }
115
116 pub(crate) fn db_type(&self) -> RemoteDbType {
117 match self {
118 #[cfg(feature = "postgres")]
119 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
120 #[cfg(feature = "oracle")]
121 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
122 #[cfg(feature = "mysql")]
123 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
124 #[cfg(feature = "sqlite")]
125 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
126 #[cfg(feature = "dm")]
127 ConnectionOptions::Dm(_) => RemoteDbType::Dm,
128 }
129 }
130}
131
/// Tag identifying which kind of remote database is being talked to.
///
/// Unlike `ConnectionOptions`, the variants here are not feature-gated, so
/// every backend tag exists regardless of the enabled Cargo features.
///
/// The enum is field-less, so it derives the cheap value-type traits
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`); this lets it be logged,
/// compared and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm.
    Dm,
}
139
140impl RemoteDbType {
141 pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
142 sql.trim()[0..6].eq_ignore_ascii_case("select")
143 }
144
145 pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
146 match self {
147 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
148 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
149 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
150 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
151 "Oracle unparser not implemented".to_string(),
152 )),
153 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
154 "Dm unparser not implemented".to_string(),
155 )),
156 }
157 }
158
159 pub(crate) fn try_rewrite_query(
160 &self,
161 sql: &str,
162 filters: &[Expr],
163 limit: Option<usize>,
164 ) -> DFResult<String> {
165 match self {
166 RemoteDbType::Postgres | RemoteDbType::Mysql | RemoteDbType::Sqlite => {
167 let where_clause = if filters.is_empty() {
168 "".to_string()
169 } else {
170 let unparser = self.create_unparser()?;
171 let filters_ast = filters
172 .iter()
173 .map(|f| unparser.expr_to_sql(f))
174 .collect::<DFResult<Vec<_>>>()?;
175 format!(
176 " WHERE {}",
177 filters_ast
178 .iter()
179 .map(|f| format!("{f}"))
180 .collect::<Vec<_>>()
181 .join(" AND ")
182 )
183 };
184 let limit_clause = if let Some(limit) = limit {
185 format!(" LIMIT {limit}")
186 } else {
187 "".to_string()
188 };
189
190 if where_clause.is_empty() && limit_clause.is_empty() {
191 Ok(sql.to_string())
192 } else {
193 Ok(format!(
194 "SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}"
195 ))
196 }
197 }
198 RemoteDbType::Oracle => Ok(limit
199 .map(|l| format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {l}"))
200 .unwrap_or_else(|| sql.to_string())),
201 RemoteDbType::Dm => Ok(limit
202 .map(|l| format!("SELECT * FROM ({sql}) LIMIT {l}"))
203 .unwrap_or_else(|| sql.to_string())),
204 }
205 }
206
207 pub(crate) fn query_limit_1(&self, sql: &str) -> DFResult<String> {
208 if !self.support_rewrite_with_filters_limit(sql) {
209 return Ok(sql.to_string());
210 }
211 self.try_rewrite_query(sql, &[], Some(1))
212 }
213}
214
/// Tells whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |columns| columns.contains(&col_idx))
}
221
/// Converts `decimal` to its unscaled integer representation at `scale`,
/// i.e. `decimal * 10^scale` with any remaining fractional digits truncated.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used
/// (falling back to `0` if that count does not fit in an `i32`).
///
/// The rescaling is done with exact decimal arithmetic. Building the factor
/// as an `f32` power of ten is inexact from `10^11` upwards and overflows to
/// infinity from `10^39`, which would corrupt or drop high-scale values.
///
/// Returns `None` when the unscaled value does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale: i32 = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // `with_scale` yields `digits * 10^-scale`; the digits are the value we want.
    let (unscaled, _exponent) = decimal
        .with_scale(i64::from(scale))
        .into_bigint_and_exponent();
    unscaled.to_i128()
}