datafusion_remote_table/connection/
mod.rs1#[cfg(feature = "mysql")]
2mod mysql;
3#[cfg(feature = "oracle")]
4mod oracle;
5#[cfg(feature = "postgres")]
6mod postgres;
7#[cfg(feature = "sqlite")]
8mod sqlite;
9
10#[cfg(feature = "mysql")]
11pub use mysql::*;
12#[cfg(feature = "oracle")]
13pub use oracle::*;
14#[cfg(feature = "postgres")]
15pub use postgres::*;
16#[cfg(feature = "sqlite")]
17pub use sqlite::*;
18
19use crate::{DFResult, RemoteSchemaRef};
20use datafusion::arrow::datatypes::SchemaRef;
21use datafusion::common::DataFusionError;
22use datafusion::execution::SendableRecordBatchStream;
23use datafusion::prelude::Expr;
24use datafusion::sql::unparser::Unparser;
25use datafusion::sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
26use std::fmt::Debug;
27#[cfg(feature = "sqlite")]
28use std::path::PathBuf;
29use std::sync::Arc;
30
31#[async_trait::async_trait]
32pub trait Pool: Debug + Send + Sync {
33 async fn get(&self) -> DFResult<Arc<dyn Connection>>;
34}
35
36#[async_trait::async_trait]
37pub trait Connection: Debug + Send + Sync {
38 async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)>;
39
40 async fn query(
41 &self,
42 conn_options: &ConnectionOptions,
43 sql: &str,
44 table_schema: SchemaRef,
45 projection: Option<&Vec<usize>>,
46 filters: &[Expr],
47 limit: Option<usize>,
48 ) -> DFResult<SendableRecordBatchStream>;
49}
50
51pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
52 match options {
53 #[cfg(feature = "postgres")]
54 ConnectionOptions::Postgres(options) => {
55 let pool = connect_postgres(options).await?;
56 Ok(Arc::new(pool))
57 }
58 #[cfg(feature = "mysql")]
59 ConnectionOptions::Mysql(options) => {
60 let pool = connect_mysql(options)?;
61 Ok(Arc::new(pool))
62 }
63 #[cfg(feature = "oracle")]
64 ConnectionOptions::Oracle(options) => {
65 let pool = connect_oracle(options).await?;
66 Ok(Arc::new(pool))
67 }
68 #[cfg(feature = "sqlite")]
69 ConnectionOptions::Sqlite(path) => {
70 let pool = connect_sqlite(path).await?;
71 Ok(Arc::new(pool))
72 }
73 }
74}
75
/// Backend-specific settings needed to open a connection.
///
/// Every variant is compiled in only when the matching cargo feature is
/// enabled.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Settings for a PostgreSQL backend.
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Settings for an Oracle backend.
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Settings for a MySQL backend.
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// Path handed to the SQLite connector (presumably the database file; verify).
    #[cfg(feature = "sqlite")]
    Sqlite(PathBuf),
}
87
88impl ConnectionOptions {
89 pub(crate) fn stream_chunk_size(&self) -> usize {
90 match self {
91 #[cfg(feature = "postgres")]
92 ConnectionOptions::Postgres(options) => options.stream_chunk_size,
93 #[cfg(feature = "oracle")]
94 ConnectionOptions::Oracle(options) => options.stream_chunk_size,
95 #[cfg(feature = "mysql")]
96 ConnectionOptions::Mysql(options) => options.stream_chunk_size,
97 #[cfg(feature = "sqlite")]
98 ConnectionOptions::Sqlite(_) => unreachable!(),
99 }
100 }
101
102 pub(crate) fn db_type(&self) -> RemoteDbType {
103 match self {
104 #[cfg(feature = "postgres")]
105 ConnectionOptions::Postgres(_) => RemoteDbType::Postgres,
106 #[cfg(feature = "oracle")]
107 ConnectionOptions::Oracle(_) => RemoteDbType::Oracle,
108 #[cfg(feature = "mysql")]
109 ConnectionOptions::Mysql(_) => RemoteDbType::Mysql,
110 #[cfg(feature = "sqlite")]
111 ConnectionOptions::Sqlite(_) => RemoteDbType::Sqlite,
112 }
113 }
114}
115
/// Identifies which remote database backend a connection targets.
///
/// Unlike `ConnectionOptions`, all variants exist regardless of the enabled
/// cargo features. The type is a plain tag, so it derives the cheap value
/// traits (`Copy`, equality, `Debug`) for use in assertions and logging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RemoteDbType {
    /// PostgreSQL backend.
    Postgres,
    /// MySQL backend.
    Mysql,
    /// Oracle backend.
    Oracle,
    /// SQLite backend.
    Sqlite,
}
122
123impl RemoteDbType {
124 pub(crate) fn support_rewrite_with_filters_limit(&self, sql: &str) -> bool {
125 sql.trim()[0..6].eq_ignore_ascii_case("select")
126 }
127
128 pub(crate) fn create_unparser(&self) -> DFResult<Unparser> {
129 match self {
130 RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
131 RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
132 RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
133 RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
134 "Oracle unparser not implemented".to_string(),
135 )),
136 }
137 }
138
139 pub(crate) fn try_rewrite_query(
140 &self,
141 sql: &str,
142 filters: &[Expr],
143 limit: Option<usize>,
144 ) -> Option<String> {
145 if !self.support_rewrite_with_filters_limit(sql) {
146 return None;
147 }
148 match self {
149 RemoteDbType::Postgres | RemoteDbType::Mysql | RemoteDbType::Sqlite => {
150 let where_clause = if filters.is_empty() {
151 "".to_string()
152 } else {
153 let unparser = self.create_unparser().ok()?;
154 let filters_ast = filters
155 .iter()
156 .map(|f| unparser.expr_to_sql(f).expect("checked already"))
157 .collect::<Vec<_>>();
158 format!(
159 " WHERE {}",
160 filters_ast
161 .iter()
162 .map(|f| format!("{f}"))
163 .collect::<Vec<_>>()
164 .join(" AND ")
165 )
166 };
167 let limit_clause = if let Some(limit) = limit {
168 format!(" LIMIT {limit}")
169 } else {
170 "".to_string()
171 };
172
173 if where_clause.is_empty() && limit_clause.is_empty() {
174 None
175 } else {
176 Some(format!(
177 "SELECT * FROM ({sql}) as __subquery{where_clause}{limit_clause}"
178 ))
179 }
180 }
181 RemoteDbType::Oracle => {
182 let limit = limit?;
183 Some(format!("SELECT * FROM ({sql}) WHERE ROWNUM <= {limit}"))
184 }
185 }
186 }
187}
188
/// Tells whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is always
/// `true`; otherwise the index must appear in the projection list.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
195
/// Converts `decimal` to its unscaled `i128` representation at `scale`.
///
/// The result equals `decimal * 10^scale` with any remaining fractional digits
/// truncated. When `scale` is `None`, the decimal's own fractional digit count
/// is used, so no digits are lost.
///
/// Returns `None` when the unscaled value does not fit into an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::ToPrimitive;
    let scale = scale.map_or_else(|| decimal.fractional_digit_count(), i64::from);
    // Rescale on the exact big-integer digits. Building the factor as
    // `10f32.powi(scale)` is inexact beyond 10^10 (and for negative scales) and
    // overflows to infinity past 10^38, which corrupts or drops the value.
    let (unscaled, _exponent) = decimal.with_scale(scale).into_bigint_and_exponent();
    unscaled.to_i128()
}
207}