datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22
23use crate::{DFResult, RemoteSchemaRef, extract_primitive_array};
24use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
25use datafusion::common::DataFusionError;
26use datafusion::execution::SendableRecordBatchStream;
27use datafusion::physical_plan::common::collect;
28use datafusion::sql::unparser::Unparser;
29use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
30use std::fmt::Debug;
31use std::sync::Arc;
32
/// Process-wide ODBC environment used by the `dm` backend.
///
/// Held in a `OnceLock`, so it can be set at most once and is then shared
/// crate-wide (`pub(crate)`). NOTE(review): where and when it is initialized
/// is not visible in this file — presumably on first Dm connect; confirm.
#[cfg(feature = "dm")]
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
35
36#[async_trait::async_trait]
37pub trait Pool: Debug + Send + Sync {
38 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
39}
40
41#[async_trait::async_trait]
42pub trait Connection: Debug + Send + Sync {
43 async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;
44
45 async fn query(
46 &self,
47 conn_options: &ConnectionOptions,
48 sql: &str,
49 table_schema: SchemaRef,
50 projection: Option<&Vec<usize>>,
51 unparsed_filters: &[String],
52 limit: Option<usize>,
53 ) -> DFResult<SendableRecordBatchStream>;
54}
55
56pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
57 match options {
58 #[cfg(feature = "postgres")]
59 ConnectionOptions::Postgres(options) => {
60 let pool = connect_postgres(options).await?;
61 Ok(Arc::new(pool))
62 }
63 #[cfg(feature = "mysql")]
64 ConnectionOptions::Mysql(options) => {
65 let pool = connect_mysql(options)?;
66 Ok(Arc::new(pool))
67 }
68 #[cfg(feature = "oracle")]
69 ConnectionOptions::Oracle(options) => {
70 let pool = connect_oracle(options).await?;
71 Ok(Arc::new(pool))
72 }
73 #[cfg(feature = "sqlite")]
74 ConnectionOptions::Sqlite(options) => {
75 let pool = connect_sqlite(options).await?;
76 Ok(Arc::new(pool))
77 }
78 #[cfg(feature = "dm")]
79 ConnectionOptions::Dm(options) => {
80 let pool = connect_dm(options)?;
81 Ok(Arc::new(pool))
82 }
83 }
84}
85
/// Backend-specific connection settings.
///
/// Exactly one variant exists per enabled database feature; the variant picks
/// which backend [`connect`] dispatches to.
#[derive(Clone, Debug)]
pub enum ConnectionOptions {
    /// PostgreSQL settings (`postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Oracle settings (`oracle` feature).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// MySQL settings (`mysql` feature).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// SQLite settings (`sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    /// Dm settings (`dm` feature).
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
99
100impl ConnectionOptions {
101 pub(crate) fn stream_chunk_size(&self) -> usize {
102 match self {
103 #[cfg(feature = "postgres")]
104 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
105 #[cfg(feature = "oracle")]
106 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
107 #[cfg(feature = "mysql")]
108 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
109 #[cfg(feature = "sqlite")]
110 ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
111 #[cfg(feature = "dm")]
112 ConnectionOptions::Dm(options) => options.stream_chunk_size,
113 }
114 }
115
116 pub(crate) fn db_type(&self) -> RemoteDbType {
117 match self {
118 #[cfg(feature = "postgres")]
119 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
120 #[cfg(feature = "oracle")]
121 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
122 #[cfg(feature = "mysql")]
123 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
124 #[cfg(feature = "sqlite")]
125 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
126 #[cfg(feature = "dm")]
127 ConnectionOptions::Dm(_) => RemoteDbType::Dm,
128 }
129 }
130}
131
/// Identifies which remote database engine a query targets.
///
/// Used to pick SQL-rewriting strategy and unparser dialect. Being a plain
/// fieldless tag, it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemoteDbType {
    Postgres,
    Mysql,
    Oracle,
    Sqlite,
    Dm,
}
139
140impl RemoteDbType {
141 pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
142 sql.trim()[0..6].eq_ignore_ascii_case("select")
143 }
144
145 pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
146 match self {
147 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
148 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
149 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
150 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
151 "Oracle unparser not implemented".to_string(),
152 )),
153 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
154 "Dm unparser not implemented".to_string(),
155 )),
156 }
157 }
158
159 pub(crate) fn rewrite_query(
160 &self,
161 sql: &str,
162 unparsed_filters: &[String],
163 limit: Option<usize>,
164 ) -> String {
165 match self {
166 RemoteDbType::Postgres
167 | RemoteDbType::Mysql
168 | RemoteDbType::Sqlite
169 | RemoteDbType::Dm => {
170 let where_clause = if unparsed_filters.is_empty() {
171 "".to_string()
172 } else {
173 format!(" WHERE {}", unparsed_filters.join(" AND "))
174 };
175 let limit_clause = if let Some(limit) = limit {
176 format!(" LIMIT {limit}")
177 } else {
178 "".to_string()
179 };
180
181 if where_clause.is_empty() && limit_clause.is_empty() {
182 sql.to_string()
183 } else {
184 format!("SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}")
185 }
186 }
187 RemoteDbType::Oracle => {
188 let mut all_filters: Vec<String> = vec![];
189 all_filters.extend_from_slice(unparsed_filters);
190 if let Some(limit) = limit {
191 all_filters.push(format!("ROWNUM <= {limit}"))
192 }
193
194 let where_clause = if all_filters.is_empty() {
195 "".to_string()
196 } else {
197 format!(" WHERE {}", all_filters.join(" AND "))
198 };
199 if where_clause.is_empty() {
200 sql.to_string()
201 } else {
202 format!("SELECT * FROM ({sql}){where_clause}")
203 }
204 }
205 }
206 }
207
208 pub(crate) fn query_limit_1(&self, sql: &str) -> String {
209 if !self.support_rewrite_with_filters_limit(sql) {
210 return sql.to_string();
211 }
212 self.rewrite_query(sql, &[], Some(1))
213 }
214
215 pub(crate) fn try_count1_query(&self, sql: &str) -> Option<String> {
216 if !self.support_rewrite_with_filters_limit(sql) {
217 return None;
218 }
219 match self {
220 RemoteDbType::Postgres
221 | RemoteDbType::Mysql
222 | RemoteDbType::Sqlite
223 | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({sql}) AS __subquery")),
224 RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({sql})")),
225 }
226 }
227
228 pub(crate) async fn fetch_count(
229 &self,
230 conn: Arc<dyn Connection>,
231 conn_options: &ConnectionOptions,
232 count1_query: &str,
233 ) -> DFResult<usize> {
234 let count1_schema = Arc::new(Schema::new(vec![Field::new(
235 "count(1)",
236 DataType::Int64,
237 false,
238 )]));
239 let stream = conn
240 .query(conn_options, count1_query, count1_schema, None, &[], None)
241 .await?;
242 let batches = collect(stream).await?;
243 let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
244 if count_vec.len() != 1 {
245 return Err(DataFusionError::Execution(format!(
246 "Count query did not return exactly one row: {count_vec:?}",
247 )));
248 }
249 count_vec[0]
250 .map(|count| count as usize)
251 .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
252 }
253}
254
/// Reports whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is then
/// always `true`.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.iter().any(|&idx| idx == col_idx)
}
261
/// Converts `decimal` to its unscaled `i128` form, i.e. `decimal * 10^scale`,
/// truncating (toward zero) any digits beyond `scale`.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used
/// (falling back to `0` if that count does not fit in an `i32`). Negative
/// scales are supported.
///
/// Returns `None` if the scaled value does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // Rescale exactly in decimal arithmetic. Multiplying by `10f32.powi(scale)`
    // is inexact for scale > 10 and for every negative scale (e.g. 0.01 has no
    // exact binary form), and overflows to infinity beyond scale 38.
    // `with_scale` truncates extra digits, matching the old `to_i128` truncation.
    let (unscaled, _exponent) = decimal
        .with_scale(i64::from(scale))
        .into_bigint_and_exponent();
    unscaled.to_i128()
}
274
/// Wraps `v` in `Ok`, i.e. an infallible conversion that returns its input as-is.
///
/// NOTE(review): presumably passed as a value-conversion callback by the
/// feature-gated backend modules; `#[allow(unused)]` keeps builds with none of
/// those features warning-free — confirm.
#[allow(unused)]
fn just_return<T>(v: T) -> DFResult<T> {
    Ok(v)
}
279
/// Copies the value behind `t` and wraps it in `Ok`; an infallible conversion
/// for `Copy` types that are only available by reference.
///
/// NOTE(review): presumably passed as a value-conversion callback by the
/// feature-gated backend modules; `#[allow(unused)]` keeps builds with none of
/// those features warning-free — confirm.
#[allow(unused)]
fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
    Ok(*t)
}