datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22use std::any::Any;
23
24use crate::{DFResult, RemoteSchemaRef, RemoteSource, Unparse, extract_primitive_array};
25use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
26use datafusion::common::DataFusionError;
27use datafusion::execution::SendableRecordBatchStream;
28use datafusion::physical_plan::common::collect;
29use datafusion::sql::unparser::Unparser;
30use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
31use std::fmt::Debug;
32use std::sync::Arc;
33
/// Shared ODBC environment for the `dm` backend.
///
/// Stored in a `OnceLock`, so it starts empty and can be set at most once for
/// the lifetime of the process. NOTE(review): initialisation presumably happens
/// inside the `dm` module (not visible here) — confirm.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
36
/// A source of connections to a remote database.
///
/// Produced by [`connect`]; implementations must be thread-safe (`Send + Sync`).
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Obtains a connection from the pool, returned as a shared trait object.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
41
/// A live connection to a remote database that can infer schemas, run queries
/// and insert record batches.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Returns `self` as `&dyn Any` — presumably so callers can downcast to the
    /// concrete backend connection type (confirm against implementations).
    fn as_any(&self) -> &dyn Any;

    /// Infers the remote schema of `source` (a table path or a raw SQL query).
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;

    /// Executes `source` and streams the result as Arrow record batches.
    ///
    /// * `conn_options` – backend options (e.g. stream chunk size).
    /// * `table_schema` – Arrow schema the produced batches are expected to follow.
    /// * `projection` – column indices to keep; `None` keeps every column
    ///   (see `projections_contains`).
    /// * `unparsed_filters` – SQL predicate strings to push down to the remote side.
    /// * `limit` – optional maximum number of rows.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;

    /// Writes every batch of `input` into the remote table identified by the
    /// path segments in `table`, using `unparser` to render values as SQL.
    ///
    /// Returns a `usize` — presumably the number of rows inserted (confirm
    /// against implementations).
    async fn insert(
        &self,
        conn_options: &ConnectionOptions,
        unparser: Arc<dyn Unparse>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        input: SendableRecordBatchStream,
    ) -> DFResult<usize>;
}
67
68pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
69 match options {
70 #[cfg(feature = "postgres")]
71 ConnectionOptions::Postgres(options) => {
72 let pool = connect_postgres(options).await?;
73 Ok(Arc::new(pool))
74 }
75 #[cfg(feature = "mysql")]
76 ConnectionOptions::Mysql(options) => {
77 let pool = connect_mysql(options)?;
78 Ok(Arc::new(pool))
79 }
80 #[cfg(feature = "oracle")]
81 ConnectionOptions::Oracle(options) => {
82 let pool = connect_oracle(options).await?;
83 Ok(Arc::new(pool))
84 }
85 #[cfg(feature = "sqlite")]
86 ConnectionOptions::Sqlite(options) => {
87 let pool = connect_sqlite(options).await?;
88 Ok(Arc::new(pool))
89 }
90 #[cfg(feature = "dm")]
91 ConnectionOptions::Dm(options) => {
92 let pool = connect_dm(options)?;
93 Ok(Arc::new(pool))
94 }
95 }
96}
97
/// Backend-specific connection settings, one variant per supported database.
///
/// Each variant exists only when the matching cargo feature (`postgres`,
/// `oracle`, `mysql`, `sqlite`, `dm`) is enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// PostgreSQL settings (feature `postgres`).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Oracle settings (feature `oracle`).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// MySQL settings (feature `mysql`).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// SQLite settings (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Dm settings (feature `dm`).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
111
112impl ConnectionOptions {
113 pub(crate) fn stream_chunk_size(&self) -> usize {
114 match self {
115 #[cfg(feature = "postgres")]
116 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
117 #[cfg(feature = "oracle")]
118 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
119 #[cfg(feature = "mysql")]
120 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
121 #[cfg(feature = "sqlite")]
122 ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
123 #[cfg(feature = "dm")]
124 ConnectionOptions::Dm(options) => options.stream_chunk_size,
125 }
126 }
127
128 pub(crate) fn db_type(&self) -> RemoteDbType {
129 match self {
130 #[cfg(feature = "postgres")]
131 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
132 #[cfg(feature = "oracle")]
133 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
134 #[cfg(feature = "mysql")]
135 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
136 #[cfg(feature = "sqlite")]
137 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
138 #[cfg(feature = "dm")]
139 ConnectionOptions::Dm(_) => RemoteDbType::Dm,
140 }
141 }
142
143 pub fn with_pool_max_size(self, pool_max_size: usize) -> Self {
144 match self {
145 #[cfg(feature = "postgres")]
146 ConnectionOptions::Postgres(options) => {
147 ConnectionOptions::Postgres(options.with_pool_max_size(pool_max_size))
148 }
149 #[cfg(feature = "oracle")]
150 ConnectionOptions::Oracle(options) => {
151 ConnectionOptions::Oracle(options.with_pool_max_size(pool_max_size))
152 }
153 #[cfg(feature = "mysql")]
154 ConnectionOptions::Mysql(options) => {
155 ConnectionOptions::Mysql(options.with_pool_max_size(pool_max_size))
156 }
157 #[cfg(feature = "sqlite")]
158 ConnectionOptions::Sqlite(options) => ConnectionOptions::Sqlite(options),
159 #[cfg(feature = "dm")]
160 ConnectionOptions::Dm(options) => ConnectionOptions::Dm(options),
161 }
162 }
163}
164
/// The kinds of remote database this crate can talk to.
///
/// Derives `PartialEq`/`Eq`/`Hash` (in addition to `Copy`) so the type can be
/// compared directly and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemoteDbType {
    Postgres,
    Mysql,
    Oracle,
    Sqlite,
    Dm,
}
173
174impl RemoteDbType {
175 pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
176 match source {
177 RemoteSource::Table(_) => true,
178 RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
179 }
180 }
181
182 pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
183 match self {
184 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
185 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
186 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
187 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
188 "Oracle unparser not implemented".to_string(),
189 )),
190 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
191 "Dm unparser not implemented".to_string(),
192 )),
193 }
194 }
195
196 pub(crate) fn rewrite_query(
197 &self,
198 source: &RemoteSource,
199 unparsed_filters: &[String],
200 limit: Option<usize>,
201 ) -> String {
202 match source {
203 RemoteSource::Table(table) => match self {
204 RemoteDbType::Postgres
205 | RemoteDbType::Mysql
206 | RemoteDbType::Sqlite
207 | RemoteDbType::Dm => {
208 let where_clause = if unparsed_filters.is_empty() {
209 "".to_string()
210 } else {
211 format!(" WHERE {}", unparsed_filters.join(" AND "))
212 };
213 let limit_clause = if let Some(limit) = limit {
214 format!(" LIMIT {limit}")
215 } else {
216 "".to_string()
217 };
218
219 format!(
220 "{}{where_clause}{limit_clause}",
221 self.select_all_query(table)
222 )
223 }
224 RemoteDbType::Oracle => {
225 let mut all_filters: Vec<String> = vec![];
226 all_filters.extend_from_slice(unparsed_filters);
227 if let Some(limit) = limit {
228 all_filters.push(format!("ROWNUM <= {limit}"))
229 }
230
231 let where_clause = if all_filters.is_empty() {
232 "".to_string()
233 } else {
234 format!(" WHERE {}", all_filters.join(" AND "))
235 };
236 format!("{}{where_clause}", self.select_all_query(table))
237 }
238 },
239 RemoteSource::Query(query) => match self {
240 RemoteDbType::Postgres
241 | RemoteDbType::Mysql
242 | RemoteDbType::Sqlite
243 | RemoteDbType::Dm => {
244 let where_clause = if unparsed_filters.is_empty() {
245 "".to_string()
246 } else {
247 format!(" WHERE {}", unparsed_filters.join(" AND "))
248 };
249 let limit_clause = if let Some(limit) = limit {
250 format!(" LIMIT {limit}")
251 } else {
252 "".to_string()
253 };
254
255 if where_clause.is_empty() && limit_clause.is_empty() {
256 query.clone()
257 } else {
258 format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
259 }
260 }
261 RemoteDbType::Oracle => {
262 let mut all_filters: Vec<String> = vec![];
263 all_filters.extend_from_slice(unparsed_filters);
264 if let Some(limit) = limit {
265 all_filters.push(format!("ROWNUM <= {limit}"))
266 }
267
268 let where_clause = if all_filters.is_empty() {
269 "".to_string()
270 } else {
271 format!(" WHERE {}", all_filters.join(" AND "))
272 };
273 if where_clause.is_empty() {
274 query.clone()
275 } else {
276 format!("SELECT * FROM ({query}){where_clause}")
277 }
278 }
279 },
280 }
281 }
282
283 pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
284 match self {
285 RemoteDbType::Postgres
286 | RemoteDbType::Oracle
287 | RemoteDbType::Sqlite
288 | RemoteDbType::Dm => {
289 format!("\"{identifier}\"")
290 }
291 RemoteDbType::Mysql => {
292 format!("`{identifier}`")
293 }
294 }
295 }
296
297 pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
298 indentifiers
299 .iter()
300 .map(|identifier| self.sql_identifier(identifier))
301 .collect::<Vec<String>>()
302 .join(".")
303 }
304
305 pub(crate) fn sql_string_literal(&self, value: &str) -> String {
306 let value = value.replace("'", "''");
307 format!("'{value}'")
308 }
309
310 pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
311 match self {
312 RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
313 RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
314 RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
315 }
316 }
317
318 pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
319 match self {
320 RemoteDbType::Postgres
321 | RemoteDbType::Mysql
322 | RemoteDbType::Oracle
323 | RemoteDbType::Sqlite
324 | RemoteDbType::Dm => {
325 format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
326 }
327 }
328 }
329
330 pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
331 if !self.support_rewrite_with_filters_limit(source) {
332 return source.query(*self);
333 }
334 self.rewrite_query(source, &[], Some(1))
335 }
336
337 pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
338 if !self.support_rewrite_with_filters_limit(source) {
339 return None;
340 }
341 match source {
342 RemoteSource::Table(table) => Some(self.select_all_query(table)),
343 RemoteSource::Query(query) => match self {
344 RemoteDbType::Postgres
345 | RemoteDbType::Mysql
346 | RemoteDbType::Sqlite
347 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
348 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
349 },
350 }
351 }
352
353 pub(crate) async fn fetch_count(
354 &self,
355 conn: Arc<dyn Connection>,
356 conn_options: &ConnectionOptions,
357 count1_query: &str,
358 ) -> DFResult<usize> {
359 let count1_schema = Arc::new(Schema::new(vec![Field::new(
360 "count(1)",
361 DataType::Int64,
362 false,
363 )]));
364 let stream = conn
365 .query(
366 conn_options,
367 &RemoteSource::Query(count1_query.to_string()),
368 count1_schema,
369 None,
370 &[],
371 None,
372 )
373 .await?;
374 let batches = collect(stream).await?;
375 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
376 if count_vec.len() != 1 {
377 return Err(DataFusionError::Execution(format!(
378 "Count query did not return exactly one row: {count_vec:?}",
379 )));
380 }
381 count_vec[0]
382 .map(|count| count as usize)
383 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
384 }
385}
386
/// Reports whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so any index passes.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected_columns) = projection else {
        return true;
    };
    selected_columns.contains(&col_idx)
}
393
394#[allow(unused)]
395fn just_return<T>(v: T) -> DFResult<T> {
396 Ok(v)
397}
398
399#[allow(unused)]
400fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
401 Ok(*t)
402}
403
/// Checks that 10^40 exceeds the `f32` range (`f32::MAX` is about 3.4e38), so
/// `powi` overflows to positive infinity, while the same value fits in `f64`.
#[test]
fn tst_f32() {
    let as_f32 = 10f32.powi(40);
    assert!(
        as_f32.is_infinite() && as_f32.is_sign_positive(),
        "expected +inf, got {as_f32}"
    );
    assert!(10f64.powi(40).is_finite());
}