datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "dm")]
2mod dm;
3#[cfg(feature = "mysql")]
4mod mysql;
5#[cfg(feature = "oracle")]
6mod oracle;
7#[cfg(feature = "postgres")]
8mod postgres;
9#[cfg(feature = "sqlite")]
10mod sqlite;
11
12#[cfg(feature = "dm")]
13pub use dm::*;
14#[cfg(feature = "mysql")]
15pub use mysql::*;
16#[cfg(feature = "oracle")]
17pub use oracle::*;
18#[cfg(feature = "postgres")]
19pub use postgres::*;
20#[cfg(feature = "sqlite")]
21pub use sqlite::*;
22
23use crate::{DFResult, RemoteSchemaRef};
24use datafusion::arrow::datatypes::SchemaRef;
25use datafusion::common::DataFusionError;
26use datafusion::execution::SendableRecordBatchStream;
27use datafusion::sql::unparser::Unparser;
28use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
29use std::fmt::Debug;
30use std::sync::Arc;
31
#[cfg(feature = "dm")]
/// Process-wide ODBC environment, initialised at most once through [`std::sync::OnceLock`].
///
/// NOTE(review): presumably shared by the `dm` backend when it opens ODBC connections —
/// confirm against `dm.rs`.
pub(crate) static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();
34
/// Source of [`Connection`]s to a single remote database.
///
/// NOTE(review): whether connections are actually pooled/reused is up to each
/// implementation; this trait only promises a shared connection handle.
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Obtains a connection, returned behind an [`Arc`] so it can be shared.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
39
/// A connection to a remote database that can describe and execute SQL queries.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Infers the schema of the result set produced by `sql`.
    async fn infer_schema(&self, sql: &str) -> DFResult<RemoteSchemaRef>;

    /// Executes `sql` and returns its rows as a stream of Arrow record batches.
    ///
    /// * `conn_options` – options of the backend this connection belongs to
    ///   (presumably consulted for settings such as the stream chunk size — confirm
    ///   in implementations).
    /// * `sql` – the base query text.
    /// * `table_schema` – Arrow schema of the remote table/query.
    /// * `projection` – optional list of column indices to return; presumably
    ///   indices into `table_schema`, with `None` meaning all columns.
    /// * `unparsed_filters` – filter predicates already rendered as SQL strings,
    ///   presumably applied via `RemoteDbType::try_rewrite_query`.
    /// * `limit` – optional maximum number of rows.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;
}
54
55pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
56 match options {
57 #[cfg(feature = "postgres")]
58 ConnectionOptions::Postgres(options) => {
59 let pool = connect_postgres(options).await?;
60 Ok(Arc::new(pool))
61 }
62 #[cfg(feature = "mysql")]
63 ConnectionOptions::Mysql(options) => {
64 let pool = connect_mysql(options)?;
65 Ok(Arc::new(pool))
66 }
67 #[cfg(feature = "oracle")]
68 ConnectionOptions::Oracle(options) => {
69 let pool = connect_oracle(options).await?;
70 Ok(Arc::new(pool))
71 }
72 #[cfg(feature = "sqlite")]
73 ConnectionOptions::Sqlite(options) => {
74 let pool = connect_sqlite(options).await?;
75 Ok(Arc::new(pool))
76 }
77 #[cfg(feature = "dm")]
78 ConnectionOptions::Dm(options) => {
79 let pool = connect_dm(options)?;
80 Ok(Arc::new(pool))
81 }
82 }
83}
84
/// Backend-specific connection options.
///
/// Each variant is compiled in only when its Cargo feature (`postgres`, `oracle`,
/// `mysql`, `sqlite`, `dm`) is enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    #[cfg(feature = "sqlite")]
    Sqlite(SqliteConnectionOptions),
    #[cfg(feature = "dm")]
    Dm(DmConnectionOptions),
}
98
99impl ConnectionOptions {
100 pub(crate) fn stream_chunk_size(&self) -> usize {
101 match self {
102 #[cfg(feature = "postgres")]
103 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
104 #[cfg(feature = "oracle")]
105 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
106 #[cfg(feature = "mysql")]
107 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
108 #[cfg(feature = "sqlite")]
109 ConnectionOptions::Sqlite(options) => options.stream_chunk_size,
110 #[cfg(feature = "dm")]
111 ConnectionOptions::Dm(options) => options.stream_chunk_size,
112 }
113 }
114
115 pub(crate) fn db_type(&self) -> RemoteDbType {
116 match self {
117 #[cfg(feature = "postgres")]
118 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
119 #[cfg(feature = "oracle")]
120 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
121 #[cfg(feature = "mysql")]
122 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
123 #[cfg(feature = "sqlite")]
124 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
125 #[cfg(feature = "dm")]
126 ConnectionOptions::Dm(_) => RemoteDbType::Dm,
127 }
128 }
129}
130
/// Kind of remote database.
///
/// Unlike [`ConnectionOptions`], every variant is always present, whichever
/// Cargo features are enabled. The enum is a plain fieldless tag, so it is
/// `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RemoteDbType {
    Postgres,
    Mysql,
    Oracle,
    Sqlite,
    Dm,
}
138
139impl RemoteDbType {
140 pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
141 sql.trim()[0..6].eq_ignore_ascii_case("select")
142 }
143
144 pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
145 match self {
146 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
147 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
148 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
149 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
150 "Oracle unparser not implemented".to_string(),
151 )),
152 RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
153 "Dm unparser not implemented".to_string(),
154 )),
155 }
156 }
157
158 pub(crate) fn try_rewrite_query(
159 &self,
160 sql: &str,
161 unparsed_filters: &[String],
162 limit: Option<usize>,
163 ) -> DFResult<String> {
164 match self {
165 RemoteDbType::Postgres | RemoteDbType::Mysql | RemoteDbType::Sqlite => {
166 let where_clause = if unparsed_filters.is_empty() {
167 "".to_string()
168 } else {
169 format!(" WHERE {}", unparsed_filters.join(" AND "))
170 };
171 let limit_clause = if let Some(limit) = limit {
172 format!(" LIMIT {limit}")
173 } else {
174 "".to_string()
175 };
176
177 if where_clause.is_empty() && limit_clause.is_empty() {
178 Ok(sql.to_string())
179 } else {
180 Ok(format!(
181 "SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}"
182 ))
183 }
184 }
185 RemoteDbType::Oracle => Ok(limit
186 .map(|l| format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {l}"))
187 .unwrap_or_else(|| sql.to_string())),
188 RemoteDbType::Dm => Ok(limit
189 .map(|l| format!("SELECT * FROM ({sql}) LIMIT {l}"))
190 .unwrap_or_else(|| sql.to_string())),
191 }
192 }
193
194 pub(crate) fn query_limit_1(&self, sql: &str) -> DFResult<String> {
195 if !self.support_rewrite_with_filters_limit(sql) {
196 return Ok(sql.to_string());
197 }
198 self.try_rewrite_query(sql, &[], Some(1))
199 }
200}
201
/// Returns whether column `col_idx` is selected by `projection`.
///
/// `None` means there is no projection, so every column is selected. Otherwise
/// the column is selected only if its index appears in the list.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.contains(&col_idx)
}
208
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
/// Converts `decimal` to its unscaled `i128` representation at `scale` fractional
/// digits, i.e. `decimal * 10^scale`. Any remaining fraction is dropped by
/// `BigDecimal::with_scale`, which truncates.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used. If
/// that count does not fit in an `i32`, the scale falls back to 0.
///
/// Returns `None` if the unscaled value does not fit in an `i128`.
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.unwrap_or_else(|| {
        decimal
            .fractional_digit_count()
            .try_into()
            .unwrap_or_default()
    });
    // Rescale in arbitrary precision. A floating-point power of ten such as
    // `10f32.powi(scale)` is not exactly representable for scale > 10 or scale < 0,
    // and multiplying by it would corrupt the low-order digits of the result.
    let (unscaled, _exponent) = decimal
        .with_scale(i64::from(scale))
        .into_bigint_and_exponent();
    unscaled.to_i128()
}
221
222#[allow(unused)]
223fn just_return<T>(v: T) -> DFResult<T> {
224 Ok(v)
225}
226
227#[allow(unused)]
228fn just_deref<T: Copy>(t: &T) -> DFResult<T> {
229 Ok(*t)
230}