datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22use std::any::Any;
23
24use crate::{DFResult, RemoteSchemaRef, extract_primitive_array};
25use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
26use datafusion::common::DataFusionError;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::common::collect;
29use datafusion::sql::unparser::Unparser;
30use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Process-wide ODBC environment, stored in a `OnceLock` so it can be set at
/// most once and then shared by reference.
///
/// NOTE(review): presumably initialized and used by the `dm` backend, which is
/// the only feature this static is compiled for — confirm in `dm.rs`.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> =
    std::sync::OnceLock::new();
36
37#[async_trait::async_trait]
38pub trait Pool: Debug + Send + Sync {
39 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
40}
41
42#[async_trait::async_trait]
43pub trait Connection: Debug + Send + Sync {
44 fn as_any(&self) -> &dyn Any;
45
46 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;
47
48 async fn query(
49 &self,
50 conn_options: &ConnectionOptions,
51 sql: &str,
52 table_schema: SchemaRef,
53 projection: Option<&Vec<usize>>,
54 unparsed_filters: &[String],
55 limit: Option<usize>,
56 ) -> DFResult<SendableRecordBatchStream>;
57}
58
59pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
60 match options {
61 #[cfg(feature = "postgres")]
62 ConnectionOptions::Postgres(options) => {
63 let pool = connect_postgres(options).await?;
64 Ok(Arc::new(pool))
65 }
66 #[cfg(feature = "mysql")]
67 ConnectionOptions::Mysql(options) => {
68 let pool = connect_mysql(options)?;
69 Ok(Arc::new(pool))
70 }
71 #[cfg(feature = "oracle")]
72 ConnectionOptions::Oracle(options) => {
73 let pool = connect_oracle(options).await?;
74 Ok(Arc::new(pool))
75 }
76 #[cfg(feature = "sqlite")]
77 ConnectionOptions::Sqlite(options) => {
78 let pool = connect_sqlite(options).await?;
79 Ok(Arc::new(pool))
80 }
81 #[cfg(feature = "dm")]
82 ConnectionOptions::Dm(options) => {
83 let pool = connect_dm(options)?;
84 Ok(Arc::new(pool))
85 }
86 }
87}
88
/// Backend-specific connection settings.
///
/// Each variant exists only when the corresponding cargo feature is enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// PostgreSQL settings (feature `postgres`).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Oracle settings (feature `oracle`).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// MySQL settings (feature `mysql`).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// SQLite settings (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Dm settings (feature `dm`).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
102
103impl ConnectionOptions {
104 pub(crate) fn stream_chunk_size(&self) -> usize {
105 match self {
106 #[cfg(feature = "postgres")]
107 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
108 #[cfg(feature = "oracle")]
109 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
110 #[cfg(feature = "mysql")]
111 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
112 #[cfg(feature = "sqlite")]
113 ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
114 #[cfg(feature = "dm")]
115 ConnectionOptions::Dm(options) => options.stream_chunk_size,
116 }
117 }
118
119 pub(crate) fn db_type(&self) -> RemoteDbType {
120 match self {
121 #[cfg(feature = "postgres")]
122 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
123 #[cfg(feature = "oracle")]
124 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
125 #[cfg(feature = "mysql")]
126 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
127 #[cfg(feature = "sqlite")]
128 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
129 #[cfg(feature = "dm")]
130 ConnectionOptions::Dm(_) => RemoteDbType::Dm,
131 }
132 }
133}
134
/// Kind of remote database, used to select dialect-specific behavior such as
/// query rewriting, SQL unparsing and row counting.
///
/// Unlike `ConnectionOptions`, these variants are not feature-gated. The type
/// is a plain fieldless tag, so it derives `Copy` and equality and can be
/// compared and logged directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// Dm (the `dm` feature's backend).
    Dm,
}
142
143impl RemoteDbType {
144 pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
145 sql.trim()[0..6].eq_ignore_ascii_case("select")
146 }
147
148 pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
149 match self {
150 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
151 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
152 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
153 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
154 "Oracle unparser not implemented".to_string(),
155 )),
156 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
157 "Dm unparser not implemented".to_string(),
158 )),
159 }
160 }
161
162 pub(crate) fn rewrite_query(
163 &self,
164 sql: &str,
165 unparsed_filters: &[String],
166 limit: Option<usize>,
167 ) -> String {
168 match self {
169 RemoteDbType::Postgres
170 | RemoteDbType::Mysql
171 | RemoteDbType::Sqlite
172 | RemoteDbType::Dm => {
173 let where_clause = if unparsed_filters.is_empty() {
174 "".to_string()
175 } else {
176 format!(" WHERE {}", unparsed_filters.join(" AND "))
177 };
178 let limit_clause = if let Some(limit) = limit {
179 format!(" LIMIT {limit}")
180 } else {
181 "".to_string()
182 };
183
184 if where_clause.is_empty() && limit_clause.is_empty() {
185 sql.to_string()
186 } else {
187 format!("SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}")
188 }
189 }
190 RemoteDbType::Oracle => {
191 let mut all_filters: Vec<String> = vec![];
192 all_filters.extend_from_slice(unparsed_filters);
193 if let Some(limit) = limit {
194 all_filters.push(format!("ROWNUM <= {limit}"))
195 }
196
197 let where_clause = if all_filters.is_empty() {
198 "".to_string()
199 } else {
200 format!(" WHERE {}", all_filters.join(" AND "))
201 };
202 if where_clause.is_empty() {
203 sql.to_string()
204 } else {
205 format!("SELECT * FROM ({sql}){where_clause}")
206 }
207 }
208 }
209 }
210
211 pub(crate) fn query_limit_1(&self, sql: &str) -> String {
212 if !self.support_rewrite_with_filters_limit(sql) {
213 return sql.to_string();
214 }
215 self.rewrite_query(sql, &[], Some(1))
216 }
217
218 pub(crate) fn try_count1_query(&self, sql: &str) -> Option<String> {
219 if !self.support_rewrite_with_filters_limit(sql) {
220 return None;
221 }
222 match self {
223 RemoteDbType::Postgres
224 | RemoteDbType::Mysql
225 | RemoteDbType::Sqlite
226 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({sql}) AS __subquery")),
227 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({sql})")),
228 }
229 }
230
231 pub(crate) async fn fetch_count(
232 &self,
233 conn: Arc<dyn Connection>,
234 conn_options: &ConnectionOptions,
235 count1_query: &str,
236 ) -> DFResult<usize> {
237 let count1_schema = Arc::new(Schema::new(vec![Field::new(
238 "count(1)",
239 DataType::Int64,
240 false,
241 )]));
242 let stream = conn
243 .query(conn_options, count1_query, count1_schema, None, &[], None)
244 .await?;
245 let batches = collect(stream).await?;
246 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
247 if count_vec.len() != 1 {
248 return Err(DataFusionError::Execution(format!(
249 "Count query did not return exactly one row: {count_vec:?}",
250 )));
251 }
252 count_vec[0]
253 .map(|count| count as usize)
254 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
255 }
256}
257
/// Reports whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so it always yields `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.iter().any(|&idx| idx == col_idx)
}
264
/// Converts `decimal` to an unscaled integer, `trunc(decimal * 10^scale)`.
///
/// When `scale` is `None`, the decimal's own number of fractional digits is
/// used, so the result is its full unscaled mantissa. A negative `scale`
/// drops integer digits (e.g. `12345` at scale `-2` gives `123`).
///
/// Returns `None` if the result does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    // Rescale with exact decimal arithmetic. A float multiplier such as
    // `10f32.powi(scale)` is inexact for scales above 10 and for any negative
    // scale, and becomes infinite above 10^38.
    // `with_scale` truncates toward zero when reducing the scale, matching
    // `BigDecimal::to_i128`.
    let target_scale = scale.map_or_else(|| decimal.fractional_digit_count(), i64::from);
    let (unscaled, _) = decimal.with_scale(target_scale).into_bigint_and_exponent();
    unscaled.to_i128()
}
277
278#[allow(unused)]
279fn just_return<T>(v: T) -> DFResult<T> {
280 Ok(v)
281}
282
283#[allow(unused)]
284fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
285 Ok(*t)
286}