datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
29use datafusion::common::DataFusionError;
30use datafusion::execution::SendableRecordBatchStream;
31use datafusion::physical_plan::common::collect;
32use datafusion::sql::unparser::Unparser;
33use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
34use std::fmt::Debug;
35use std::sync::Arc;
36
/// Process-wide slot for the ODBC environment, compiled only with the `dm` feature.
///
/// The slot is a [`std::sync::OnceLock`], so it starts out empty and can be filled at most
/// once. The code that initializes it is not part of this section of the file.
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
39
40#[async_trait::async_trait]
41pub trait Pool: Debug + Send + Sync {
42 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
43}
44
45#[async_trait::async_trait]
46pub trait Connection: Debug + Send + Sync {
47 fn as_any(&self) -> &dyn Any;
48
49 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;
50
51 async fn query(
52 &self,
53 conn_options: &ConnectionOptions,
54 source: &RemoteSource,
55 table_schema: SchemaRef,
56 projection: Option<&Vec<usize>>,
57 unparsed_filters: &[String],
58 limit: Option<usize>,
59 ) -> DFResult<SendableRecordBatchStream>;
60
61 async fn insert(
62 &self,
63 conn_options: &ConnectionOptions,
64 literalizer: Arc<dyn Literalize>,
65 table: &[String],
66 remote_schema: RemoteSchemaRef,
67 input: SendableRecordBatchStream,
68 ) -> DFResult<usize>;
69}
70
71#[allow(unused_variables)]
72pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
73 match options {
74 ConnectionOptions::Postgres(options) => {
75 #[cfg(feature = "postgres")]
76 {
77 let pool = connect_postgres(options).await?;
78 Ok(Arc::new(pool))
79 }
80 #[cfg(not(feature = "postgres"))]
81 {
82 Err(DataFusionError::Internal(
83 "Please enable the postgres feature".to_string(),
84 ))
85 }
86 }
87 ConnectionOptions::Mysql(options) => {
88 #[cfg(feature = "mysql")]
89 {
90 let pool = connect_mysql(options)?;
91 Ok(Arc::new(pool))
92 }
93 #[cfg(not(feature = "mysql"))]
94 {
95 Err(DataFusionError::Internal(
96 "Please enable the mysql feature".to_string(),
97 ))
98 }
99 }
100 ConnectionOptions::Oracle(options) => {
101 #[cfg(feature = "oracle")]
102 {
103 let pool = connect_oracle(options).await?;
104 Ok(Arc::new(pool))
105 }
106 #[cfg(not(feature = "oracle"))]
107 {
108 Err(DataFusionError::Internal(
109 "Please enable the oracle feature".to_string(),
110 ))
111 }
112 }
113 ConnectionOptions::Sqlite(options) => {
114 #[cfg(feature = "sqlite")]
115 {
116 let pool = connect_sqlite(options).await?;
117 Ok(Arc::new(pool))
118 }
119 #[cfg(not(feature = "sqlite"))]
120 {
121 Err(DataFusionError::Internal(
122 "Please enable the sqlite feature".to_string(),
123 ))
124 }
125 }
126 ConnectionOptions::Dm(options) => {
127 #[cfg(feature = "dm")]
128 {
129 let pool = connect_dm(options)?;
130 Ok(Arc::new(pool))
131 }
132 #[cfg(not(feature = "dm"))]
133 {
134 Err(DataFusionError::Internal(
135 "Please enable the dm feature".to_string(),
136 ))
137 }
138 }
139 }
140}
141
/// The remote database flavours this module can generate SQL for.
///
/// The variants mirror the backends handled by [`connect`].
#[derive(Debug, Clone, Copy)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle Database.
    Oracle,
    /// SQLite.
    Sqlite,
    /// The backend behind the `dm` feature (presumably Dameng - TODO confirm).
    Dm,
}
150
151impl RemoteDbType {
152 pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
153 match source {
154 RemoteSource::Table(_) => true,
155 RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
156 }
157 }
158
159 pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
160 match self {
161 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
162 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
163 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
164 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
165 "Oracle unparser not implemented".to_string(),
166 )),
167 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
168 "Dm unparser not implemented".to_string(),
169 )),
170 }
171 }
172
173 pub(crate) fn rewrite_query(
174 &self,
175 source: &RemoteSource,
176 unparsed_filters: &[String],
177 limit: Option<usize>,
178 ) -> String {
179 match source {
180 RemoteSource::Table(table) => match self {
181 RemoteDbType::Postgres
182 | RemoteDbType::Mysql
183 | RemoteDbType::Sqlite
184 | RemoteDbType::Dm => {
185 let where_clause = if unparsed_filters.is_empty() {
186 "".to_string()
187 } else {
188 format!(" WHERE {}", unparsed_filters.join(" AND "))
189 };
190 let limit_clause = if let Some(limit) = limit {
191 format!(" LIMIT {limit}")
192 } else {
193 "".to_string()
194 };
195
196 format!(
197 "{}{where_clause}{limit_clause}",
198 self.select_all_query(table)
199 )
200 }
201 RemoteDbType::Oracle => {
202 let mut all_filters: Vec<String> = vec![];
203 all_filters.extend_from_slice(unparsed_filters);
204 if let Some(limit) = limit {
205 all_filters.push(format!("ROWNUM <= {limit}"))
206 }
207
208 let where_clause = if all_filters.is_empty() {
209 "".to_string()
210 } else {
211 format!(" WHERE {}", all_filters.join(" AND "))
212 };
213 format!("{}{where_clause}", self.select_all_query(table))
214 }
215 },
216 RemoteSource::Query(query) => match self {
217 RemoteDbType::Postgres
218 | RemoteDbType::Mysql
219 | RemoteDbType::Sqlite
220 | RemoteDbType::Dm => {
221 let where_clause = if unparsed_filters.is_empty() {
222 "".to_string()
223 } else {
224 format!(" WHERE {}", unparsed_filters.join(" AND "))
225 };
226 let limit_clause = if let Some(limit) = limit {
227 format!(" LIMIT {limit}")
228 } else {
229 "".to_string()
230 };
231
232 if where_clause.is_empty() && limit_clause.is_empty() {
233 query.clone()
234 } else {
235 format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
236 }
237 }
238 RemoteDbType::Oracle => {
239 let mut all_filters: Vec<String> = vec![];
240 all_filters.extend_from_slice(unparsed_filters);
241 if let Some(limit) = limit {
242 all_filters.push(format!("ROWNUM <= {limit}"))
243 }
244
245 let where_clause = if all_filters.is_empty() {
246 "".to_string()
247 } else {
248 format!(" WHERE {}", all_filters.join(" AND "))
249 };
250 if where_clause.is_empty() {
251 query.clone()
252 } else {
253 format!("SELECT * FROM ({query}){where_clause}")
254 }
255 }
256 },
257 }
258 }
259
260 pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
261 match self {
262 RemoteDbType::Postgres
263 | RemoteDbType::Oracle
264 | RemoteDbType::Sqlite
265 | RemoteDbType::Dm => {
266 format!("\"{identifier}\"")
267 }
268 RemoteDbType::Mysql => {
269 format!("`{identifier}`")
270 }
271 }
272 }
273
274 pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
275 indentifiers
276 .iter()
277 .map(|identifier| self.sql_identifier(identifier))
278 .collect::<Vec<String>>()
279 .join(".")
280 }
281
282 pub(crate) fn sql_string_literal(&self, value: &str) -> String {
283 let value = value.replace("'", "''");
284 format!("'{value}'")
285 }
286
287 pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
288 match self {
289 RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
290 RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
291 RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
292 }
293 }
294
295 pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
296 match self {
297 RemoteDbType::Postgres
298 | RemoteDbType::Mysql
299 | RemoteDbType::Oracle
300 | RemoteDbType::Sqlite
301 | RemoteDbType::Dm => {
302 format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
303 }
304 }
305 }
306
307 pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
308 if !self.support_rewrite_with_filters_limit(source) {
309 return source.query(*self);
310 }
311 self.rewrite_query(source, &[], Some(1))
312 }
313
314 pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
315 if !self.support_rewrite_with_filters_limit(source) {
316 return None;
317 }
318 match source {
319 RemoteSource::Table(table) => Some(self.select_all_query(table)),
320 RemoteSource::Query(query) => match self {
321 RemoteDbType::Postgres
322 | RemoteDbType::Mysql
323 | RemoteDbType::Sqlite
324 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
325 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
326 },
327 }
328 }
329
330 pub(crate) async fn fetch_count(
331 &self,
332 conn: Arc<dyn Connection>,
333 conn_options: &ConnectionOptions,
334 count1_query: &str,
335 ) -> DFResult<usize> {
336 let count1_schema = Arc::new(Schema::new(vec![Field::new(
337 "count(1)",
338 DataType::Int64,
339 false,
340 )]));
341 let stream = conn
342 .query(
343 conn_options,
344 &RemoteSource::Query(count1_query.to_string()),
345 count1_schema,
346 None,
347 &[],
348 None,
349 )
350 .await?;
351 let batches = collect(stream).await?;
352 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
353 if count_vec.len() != 1 {
354 return Err(DataFusionError::Execution(format!(
355 "Count query did not return exactly one row: {count_vec:?}",
356 )));
357 }
358 count_vec[0]
359 .map(|count| count as usize)
360 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
361 }
362}
363
/// Reports whether column `col_idx` is part of `projection`.
///
/// `None` means no projection was given, so every column counts as included.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
370
371#[allow(unused)]
372fn just_return<T>(v: T) -> DFResult<T> {
373 Ok(v)
374}
375
376#[allow(unused)]
377fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
378 Ok(*t)
379}
380
/// Scratch test: prints `10f32` raised to the 40th power.
///
/// The mathematical result (1e40) is larger than `f32::MAX`, so the printed value is `inf`.
/// Nothing is asserted.
#[test]
fn tst_f32() {
    let value = 10f32.powi(40);
    println!("{value}");
}