datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5mod options;
6#[cfg(feature = "oracle")]
7mod oracle;
8#[cfg(feature = "postgres")]
9mod postgres;
10#[cfg(feature = "sqlite")]
11mod sqlite;
12
13#[cfg(feature = "dm")]
14pub use dm::*;
15#[cfg(feature = "mysql")]
16pub use mysql::*;
17pub use options::*;
18#[cfg(feature = "oracle")]
19pub use oracle::*;
20#[cfg(feature = "postgres")]
21pub use postgres::*;
22#[cfg(feature = "sqlite")]
23pub use sqlite::*;
24
25use std::any::Any;
26
27use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
28use arrow::array::RecordBatch;
29use arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
30use datafusion_common::DataFusionError;
31use datafusion_execution::SendableRecordBatchStream;
32use datafusion_physical_plan::common::collect;
33use datafusion_sql::unparser::Unparser;
34use datafusion_sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
35use std::fmt::Debug;
36use std::sync::Arc;
37
/// Process-wide ODBC environment, stored in a write-once [`std::sync::OnceLock`].
///
/// The cell starts empty; whoever needs the environment first is expected to
/// initialise it. Only compiled when the `dm` feature is enabled.
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> =
    std::sync::OnceLock::new();
40
41#[async_trait::async_trait]
42pub trait Pool: Debug + Send + Sync {
43 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
44 async fn state(&self) -> DFResult<PoolState>;
45}
46
/// Snapshot of a connection pool's usage, as returned by `Pool::state`.
#[derive(Clone, Debug)]
pub struct PoolState {
    /// Number of connections reported by the pool.
    pub connections: usize,
    /// Number of those connections reported as idle.
    pub idle_connections: usize,
}
52
53#[async_trait::async_trait]
54pub trait Connection: Debug + Send + Sync {
55 fn as_any(&self) -> &dyn Any;
56
57 async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;
58
59 async fn query(
60 &self,
61 conn_options: &ConnectionOptions,
62 source: &RemoteSource,
63 table_schema: SchemaRef,
64 projection: Option<&Vec<usize>>,
65 unparsed_filters: &[String],
66 limit: Option<usize>,
67 ) -> DFResult<SendableRecordBatchStream>;
68
69 async fn insert(
70 &self,
71 conn_options: &ConnectionOptions,
72 literalizer: Arc<dyn Literalize>,
73 table: &[String],
74 remote_schema: RemoteSchemaRef,
75 batch: RecordBatch,
76 ) -> DFResult<usize>;
77}
78
79#[allow(unused_variables)]
80pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
81 match options {
82 ConnectionOptions::Postgres(options) => {
83 #[cfg(feature = "postgres")]
84 {
85 let pool = connect_postgres(options).await?;
86 Ok(Arc::new(pool))
87 }
88 #[cfg(not(feature = "postgres"))]
89 {
90 Err(DataFusionError::Internal(
91 "Please enable the postgres feature".to_string(),
92 ))
93 }
94 }
95 ConnectionOptions::Mysql(options) => {
96 #[cfg(feature = "mysql")]
97 {
98 let pool = connect_mysql(options)?;
99 Ok(Arc::new(pool))
100 }
101 #[cfg(not(feature = "mysql"))]
102 {
103 Err(DataFusionError::Internal(
104 "Please enable the mysql feature".to_string(),
105 ))
106 }
107 }
108 ConnectionOptions::Oracle(options) => {
109 #[cfg(feature = "oracle")]
110 {
111 let pool = connect_oracle(options).await?;
112 Ok(Arc::new(pool))
113 }
114 #[cfg(not(feature = "oracle"))]
115 {
116 Err(DataFusionError::Internal(
117 "Please enable the oracle feature".to_string(),
118 ))
119 }
120 }
121 ConnectionOptions::Sqlite(options) => {
122 #[cfg(feature = "sqlite")]
123 {
124 let pool = connect_sqlite(options).await?;
125 Ok(Arc::new(pool))
126 }
127 #[cfg(not(feature = "sqlite"))]
128 {
129 Err(DataFusionError::Internal(
130 "Please enable the sqlite feature".to_string(),
131 ))
132 }
133 }
134 ConnectionOptions::Dm(options) => {
135 #[cfg(feature = "dm")]
136 {
137 let pool = connect_dm(options)?;
138 Ok(Arc::new(pool))
139 }
140 #[cfg(not(feature = "dm"))]
141 {
142 Err(DataFusionError::Internal(
143 "Please enable the dm feature".to_string(),
144 ))
145 }
146 }
147 }
148}
149
/// The kinds of remote database this crate can generate SQL for.
#[derive(Clone, Copy, Debug)]
pub enum RemoteDbType {
    /// PostgreSQL.
    Postgres,
    /// MySQL.
    Mysql,
    /// Oracle.
    Oracle,
    /// SQLite.
    Sqlite,
    /// The `dm` backend. NOTE(review): presumably the Dameng database — confirm.
    Dm,
}
158
159impl RemoteDbType {
160 pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
161 match source {
162 RemoteSource::Table(_) => true,
163 RemoteSource::Query(query) => query.trim()[0..6].eq_ignore_ascii_case("select"),
164 }
165 }
166
167 pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
168 match self {
169 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
170 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
171 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
172 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
173 "Oracle unparser not implemented".to_string(),
174 )),
175 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
176 "Dm unparser not implemented".to_string(),
177 )),
178 }
179 }
180
181 pub(crate) fn rewrite_query(
182 &self,
183 source: &RemoteSource,
184 unparsed_filters: &[String],
185 limit: Option<usize>,
186 ) -> String {
187 match source {
188 RemoteSource::Table(table) => match self {
189 RemoteDbType::Postgres
190 | RemoteDbType::Mysql
191 | RemoteDbType::Sqlite
192 | RemoteDbType::Dm => {
193 let where_clause = if unparsed_filters.is_empty() {
194 "".to_string()
195 } else {
196 format!(" WHERE {}", unparsed_filters.join(" AND "))
197 };
198 let limit_clause = if let Some(limit) = limit {
199 format!(" LIMIT {limit}")
200 } else {
201 "".to_string()
202 };
203
204 format!(
205 "{}{where_clause}{limit_clause}",
206 self.select_all_query(table)
207 )
208 }
209 RemoteDbType::Oracle => {
210 let mut all_filters: Vec<String> = vec![];
211 all_filters.extend_from_slice(unparsed_filters);
212 if let Some(limit) = limit {
213 all_filters.push(format!("ROWNUM <= {limit}"))
214 }
215
216 let where_clause = if all_filters.is_empty() {
217 "".to_string()
218 } else {
219 format!(" WHERE {}", all_filters.join(" AND "))
220 };
221 format!("{}{where_clause}", self.select_all_query(table))
222 }
223 },
224 RemoteSource::Query(query) => match self {
225 RemoteDbType::Postgres
226 | RemoteDbType::Mysql
227 | RemoteDbType::Sqlite
228 | RemoteDbType::Dm => {
229 let where_clause = if unparsed_filters.is_empty() {
230 "".to_string()
231 } else {
232 format!(" WHERE {}", unparsed_filters.join(" AND "))
233 };
234 let limit_clause = if let Some(limit) = limit {
235 format!(" LIMIT {limit}")
236 } else {
237 "".to_string()
238 };
239
240 if where_clause.is_empty() && limit_clause.is_empty() {
241 query.clone()
242 } else {
243 format!("SELECT * FROM ({query}) as __subquery{where_clause}{limit_clause}")
244 }
245 }
246 RemoteDbType::Oracle => {
247 let mut all_filters: Vec<String> = vec![];
248 all_filters.extend_from_slice(unparsed_filters);
249 if let Some(limit) = limit {
250 all_filters.push(format!("ROWNUM <= {limit}"))
251 }
252
253 let where_clause = if all_filters.is_empty() {
254 "".to_string()
255 } else {
256 format!(" WHERE {}", all_filters.join(" AND "))
257 };
258 if where_clause.is_empty() {
259 query.clone()
260 } else {
261 format!("SELECT * FROM ({query}){where_clause}")
262 }
263 }
264 },
265 }
266 }
267
268 pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
269 match self {
270 RemoteDbType::Postgres
271 | RemoteDbType::Oracle
272 | RemoteDbType::Sqlite
273 | RemoteDbType::Dm => {
274 format!("\"{identifier}\"")
275 }
276 RemoteDbType::Mysql => {
277 format!("`{identifier}`")
278 }
279 }
280 }
281
282 pub(crate) fn sql_table_name(&self, indentifiers: &[String]) -> String {
283 indentifiers
284 .iter()
285 .map(|identifier| self.sql_identifier(identifier))
286 .collect::<Vec<String>>()
287 .join(".")
288 }
289
290 pub(crate) fn sql_string_literal(&self, value: &str) -> String {
291 let value = value.replace("'", "''");
292 format!("'{value}'")
293 }
294
295 pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
296 match self {
297 RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
298 RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
299 RemoteDbType::Oracle | RemoteDbType::Dm => todo!(),
300 }
301 }
302
303 pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
304 match self {
305 RemoteDbType::Postgres
306 | RemoteDbType::Mysql
307 | RemoteDbType::Oracle
308 | RemoteDbType::Sqlite
309 | RemoteDbType::Dm => {
310 format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
311 }
312 }
313 }
314
315 pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
316 if !self.support_rewrite_with_filters_limit(source) {
317 return source.query(*self);
318 }
319 self.rewrite_query(source, &[], Some(1))
320 }
321
322 pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
323 if !self.support_rewrite_with_filters_limit(source) {
324 return None;
325 }
326 match source {
327 RemoteSource::Table(table) => Some(format!(
328 "SELECT COUNT(1) FROM {}",
329 self.sql_table_name(table)
330 )),
331 RemoteSource::Query(query) => match self {
332 RemoteDbType::Postgres
333 | RemoteDbType::Mysql
334 | RemoteDbType::Sqlite
335 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
336 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
337 },
338 }
339 }
340
341 pub(crate) async fn fetch_count(
342 &self,
343 conn: Arc<dyn Connection>,
344 conn_options: &ConnectionOptions,
345 count1_query: &str,
346 ) -> DFResult<usize> {
347 let count1_schema = Arc::new(Schema::new(vec![Field::new(
348 "count(1)",
349 DataType::Int64,
350 false,
351 )]));
352 let stream = conn
353 .query(
354 conn_options,
355 &RemoteSource::Query(count1_query.to_string()),
356 count1_schema,
357 None,
358 &[],
359 None,
360 )
361 .await?;
362 let batches = collect(stream).await?;
363 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
364 if count_vec.len() != 1 {
365 return Err(DataFusionError::Execution(format!(
366 "Count query did not return exactly one row: {count_vec:?}",
367 )));
368 }
369 count_vec[0]
370 .map(|count| count as usize)
371 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
372 }
373}
374
/// Reports whether column `col_idx` is selected by `projection`.
///
/// With no projection (`None`) every column counts as selected.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
381
382#[allow(unused)]
383fn just_return<T>(v: T) -> DFResult<T> {
384 Ok(v)
385}
386
387#[allow(unused)]
388fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
389 Ok(*t)
390}