datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
29use datafusion::common::DataFusionError;
30use datafusion::execution::SendableRecordBatchStream;
31use datafusion::physical_plan::common::collect;
32use datafusion::sql::unparser::Unparser;
33use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
34use std::fmt::Debug;
35use std::sync::Arc;
36
/// Process-wide slot holding the ODBC environment; only compiled when the `dm` feature is on.
///
/// The cell starts out empty and can be initialised at most once (see [`std::sync::OnceLock`]).
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
39
40#[async_trait::async_trait]
41pub trait Pool: Debug + Send + Sync {
42 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
43 async fn state(&self) -> DFResult<PoolState>;
44}
45
/// Counters describing a pool, as returned by `Pool::state`.
#[derive(Clone, Debug)]
pub struct PoolState {
    /// Connection count reported by the pool.
    // NOTE(review): presumably the total, including idle ones — confirm per pool implementation.
    pub connections: usize,
    /// Idle connection count reported by the pool.
    pub idle_connections: usize,
}
51
52#[async_trait::async_trait]
53pub trait Connection: Debug + Send + Sync {
54 fn as_any(&self) -> &dyn Any;
55
56 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;
57
58 async fn query(
59 &self,
60 conn_options: &ConnectionOptions,
61 source: &RemoteSource,
62 table_schema: SchemaRef,
63 projection: Option<&Vec<usize>>,
64 unparsed_filters: &[String],
65 limit: Option<usize>,
66 ) -> DFResult<SendableRecordBatchStream>;
67
68 async fn insert(
69 &self,
70 conn_options: &ConnectionOptions,
71 literalizer: Arc<dyn Literalize>,
72 table: &[String],
73 remote_schema: RemoteSchemaRef,
74 input: SendableRecordBatchStream,
75 ) -> DFResult<usize>;
76}
77
78#[allow(unused_variables)]
79pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
80 match options {
81 ConnectionOptions::Postgres(options) => {
82 #[cfg(feature = "postgres")]
83 {
84 let pool = connect_postgres(options).await?;
85 Ok(Arc::new(pool))
86 }
87 #[cfg(not(feature = "postgres"))]
88 {
89 Err(DataFusionError::Internal(
90 "Please enable the postgres feature".to_string(),
91 ))
92 }
93 }
94 ConnectionOptions::Mysql(options) => {
95 #[cfg(feature = "mysql")]
96 {
97 let pool = connect_mysql(options)?;
98 Ok(Arc::new(pool))
99 }
100 #[cfg(not(feature = "mysql"))]
101 {
102 Err(DataFusionError::Internal(
103 "Please enable the mysql feature".to_string(),
104 ))
105 }
106 }
107 ConnectionOptions::Oracle(options) => {
108 #[cfg(feature = "oracle")]
109 {
110 let pool = connect_oracle(options).await?;
111 Ok(Arc::new(pool))
112 }
113 #[cfg(not(feature = "oracle"))]
114 {
115 Err(DataFusionError::Internal(
116 "Please enable the oracle feature".to_string(),
117 ))
118 }
119 }
120 ConnectionOptions::Sqlite(options) => {
121 #[cfg(feature = "sqlite")]
122 {
123 let pool = connect_sqlite(options).await?;
124 Ok(Arc::new(pool))
125 }
126 #[cfg(not(feature = "sqlite"))]
127 {
128 Err(DataFusionError::Internal(
129 "Please enable the sqlite feature".to_string(),
130 ))
131 }
132 }
133 ConnectionOptions::Dm(options) => {
134 #[cfg(feature = "dm")]
135 {
136 let pool = connect_dm(options)?;
137 Ok(Arc::new(pool))
138 }
139 #[cfg(not(feature = "dm"))]
140 {
141 Err(DataFusionError::Internal(
142 "Please enable the dm feature".to_string(),
143 ))
144 }
145 }
146 }
147}
148
/// The kind of remote database being talked to.
///
/// The value selects the SQL dialect details (quoting, limit syntax, literals) used when
/// building queries.
#[derive(Clone, Copy, Debug)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// The `dm` backend.
    Dm,
}
157
158impl RemoteDbType {
159 pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
160 match source {
161 RemoteSource::Table(_) => true,
162 RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
163 }
164 }
165
166 pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
167 match self {
168 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
169 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
170 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
171 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
172 "Oracle unparser not implemented".to_string(),
173 )),
174 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
175 "Dm unparser not implemented".to_string(),
176 )),
177 }
178 }
179
180 pub(crate) fn rewrite_query(
181 &self,
182 source: &RemoteSource,
183 unparsed_filters: &[String],
184 limit: Option<usize>,
185 ) -> String {
186 match source {
187 RemoteSource::Table(table) => match self {
188 RemoteDbType::Postgres
189 | RemoteDbType::Mysql
190 | RemoteDbType::Sqlite
191 | RemoteDbType::Dm => {
192 let where_clause = if unparsed_filters.is_empty() {
193 "".to_string()
194 } else {
195 format!(" WHERE {}", unparsed_filters.join(" AND "))
196 };
197 let limit_clause = if let Some(limit) = limit {
198 format!(" LIMIT {limit}")
199 } else {
200 "".to_string()
201 };
202
203 format!(
204 "{}{where_clause}{limit_clause}",
205 self.select_all_query(table)
206 )
207 }
208 RemoteDbType::Oracle => {
209 let mut all_filters: Vec<String> = vec![];
210 all_filters.extend_from_slice(unparsed_filters);
211 if let Some(limit) = limit {
212 all_filters.push(format!("ROWNUM <= {limit}"))
213 }
214
215 let where_clause = if all_filters.is_empty() {
216 "".to_string()
217 } else {
218 format!(" WHERE {}", all_filters.join(" AND "))
219 };
220 format!("{}{where_clause}", self.select_all_query(table))
221 }
222 },
223 RemoteSource::Query(query) => match self {
224 RemoteDbType::Postgres
225 | RemoteDbType::Mysql
226 | RemoteDbType::Sqlite
227 | RemoteDbType::Dm => {
228 let where_clause = if unparsed_filters.is_empty() {
229 "".to_string()
230 } else {
231 format!(" WHERE {}", unparsed_filters.join(" AND "))
232 };
233 let limit_clause = if let Some(limit) = limit {
234 format!(" LIMIT {limit}")
235 } else {
236 "".to_string()
237 };
238
239 if where_clause.is_empty() && limit_clause.is_empty() {
240 query.clone()
241 } else {
242 format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
243 }
244 }
245 RemoteDbType::Oracle => {
246 let mut all_filters: Vec<String> = vec![];
247 all_filters.extend_from_slice(unparsed_filters);
248 if let Some(limit) = limit {
249 all_filters.push(format!("ROWNUM <= {limit}"))
250 }
251
252 let where_clause = if all_filters.is_empty() {
253 "".to_string()
254 } else {
255 format!(" WHERE {}", all_filters.join(" AND "))
256 };
257 if where_clause.is_empty() {
258 query.clone()
259 } else {
260 format!("SELECT * FROM ({query}){where_clause}")
261 }
262 }
263 },
264 }
265 }
266
267 pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
268 match self {
269 RemoteDbType::Postgres
270 | RemoteDbType::Oracle
271 | RemoteDbType::Sqlite
272 | RemoteDbType::Dm => {
273 format!("\"{identifier}\"")
274 }
275 RemoteDbType::Mysql => {
276 format!("`{identifier}`")
277 }
278 }
279 }
280
281 pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
282 indentifiers
283 .iter()
284 .map(|identifier| self.sql_identifier(identifier))
285 .collect::<Vec<String>>()
286 .join(".")
287 }
288
289 pub(crate) fn sql_string_literal(&self, value: &str) -> String {
290 let value = value.replace("'", "''");
291 format!("'{value}'")
292 }
293
294 pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
295 match self {
296 RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
297 RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
298 RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
299 }
300 }
301
302 pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
303 match self {
304 RemoteDbType::Postgres
305 | RemoteDbType::Mysql
306 | RemoteDbType::Oracle
307 | RemoteDbType::Sqlite
308 | RemoteDbType::Dm => {
309 format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
310 }
311 }
312 }
313
314 pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
315 if !self.support_rewrite_with_filters_limit(source) {
316 return source.query(*self);
317 }
318 self.rewrite_query(source, &[], Some(1))
319 }
320
321 pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
322 if !self.support_rewrite_with_filters_limit(source) {
323 return None;
324 }
325 match source {
326 RemoteSource::Table(table) => Some(self.select_all_query(table)),
327 RemoteSource::Query(query) => match self {
328 RemoteDbType::Postgres
329 | RemoteDbType::Mysql
330 | RemoteDbType::Sqlite
331 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
332 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
333 },
334 }
335 }
336
337 pub(crate) async fn fetch_count(
338 &self,
339 conn: Arc<dyn Connection>,
340 conn_options: &ConnectionOptions,
341 count1_query: &str,
342 ) -> DFResult<usize> {
343 let count1_schema = Arc::new(Schema::new(vec![Field::new(
344 "count(1)",
345 DataType::Int64,
346 false,
347 )]));
348 let stream = conn
349 .query(
350 conn_options,
351 &RemoteSource::Query(count1_query.to_string()),
352 count1_schema,
353 None,
354 &[],
355 None,
356 )
357 .await?;
358 let batches = collect(stream).await?;
359 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
360 if count_vec.len() != 1 {
361 return Err(DataFusionError::Execution(format!(
362 "Count query did not return exactly one row: {count_vec:?}",
363 )));
364 }
365 count_vec[0]
366 .map(|count| count as usize)
367 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
368 }
369}
370
/// Tells whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is always `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.contains(&col_idx)
}
377
378#[allow(unused)]
379fn just_return<T>(v: T) -> DFResult<T> {
380 Ok(v)
381}
382
383#[allow(unused)]
384fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
385 Ok(*t)
386}
387
/// Shows what `10f32.powi(40)` evaluates to.
///
/// 10^40 is larger than `f32::MAX` (about 3.4e38), so the result overflows to infinity. The
/// assertion pins that down; previously the test only printed and could never fail.
#[test]
fn tst_f32() {
    let value = 10f32.powi(40);
    println!("{}", value);
    assert!(value.is_infinite() && value.is_sign_positive());
}