datafusion_remote_table/connection/
mod.rs

1#[cfg(feature = "mysql")]
2mod mysql;
3#[cfg(feature = "oracle")]
4mod oracle;
5#[cfg(feature = "postgres")]
6mod postgres;
7#[cfg(feature = "sqlite")]
8mod sqlite;
9
10#[cfg(feature = "mysql")]
11pub use mysql::*;
12#[cfg(feature = "oracle")]
13pub use oracle::*;
14#[cfg(feature = "postgres")]
15pub use postgres::*;
16#[cfg(feature = "sqlite")]
17pub use sqlite::*;
18
19use crate::{DFResult, RemoteSchemaRef};
20use datafusion::arrow::datatypes::SchemaRef;
21use datafusion::execution::SendableRecordBatchStream;
22use std::fmt::Debug;
23#[cfg(feature = "sqlite")]
24use std::path::PathBuf;
25use std::sync::Arc;
26
/// A source of database connections.
///
/// Implementors must be `Debug + Send + Sync`; [`connect`] hands pools out as
/// `Arc<dyn Pool>`.
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Asynchronously obtains a connection, returned as a shared trait object.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation cannot provide a connection.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
}
31
/// A single connection to a remote database, as handed out by [`Pool::get`].
///
/// Implementors must be `Debug + Send + Sync`.
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Infers the schema of the result produced by `sql`.
    ///
    /// Returns a pair of schemas: a [`RemoteSchemaRef`] and an Arrow [`SchemaRef`].
    // NOTE(review): presumably the first describes the remote database's native types and
    // the second is the Arrow schema derived from it — confirm against the implementations.
    async fn infer_schema(&self, sql: &str) -> DFResult<(RemoteSchemaRef, SchemaRef)>;

    /// Runs `sql` and returns the result as a stream of record batches.
    ///
    /// * `conn_options` - the options describing this connection's backend.
    /// * `sql` - the query text to execute.
    /// * `table_schema` - the Arrow schema associated with the query.
    /// * `projection` - optional list of column indices; `None` means no projection.
    // NOTE(review): `projection` presumably indexes into `table_schema` and
    // `conn_options` presumably supplies settings such as the stream chunk size —
    // confirm against the implementations.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        sql: &str,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
    ) -> DFResult<SendableRecordBatchStream>;
}
44
45pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
46    match options {
47        #[cfg(feature = "postgres")]
48        ConnectionOptions::Postgres(options) => {
49            let pool = connect_postgres(options).await?;
50            Ok(Arc::new(pool))
51        }
52        #[cfg(feature = "mysql")]
53        ConnectionOptions::Mysql(options) => {
54            let pool = connect_mysql(options)?;
55            Ok(Arc::new(pool))
56        }
57        #[cfg(feature = "oracle")]
58        ConnectionOptions::Oracle(options) => {
59            let pool = connect_oracle(options).await?;
60            Ok(Arc::new(pool))
61        }
62        #[cfg(feature = "sqlite")]
63        ConnectionOptions::Sqlite(path) => {
64            let pool = connect_sqlite(path).await?;
65            Ok(Arc::new(pool))
66        }
67    }
68}
69
/// Backend-specific options used to open a connection pool via [`connect`].
///
/// Every variant is gated behind the cargo feature of the same name, so only the
/// backends that are compiled in are available.
#[derive(Debug, Clone)]
pub enum ConnectionOptions {
    /// Options for a PostgreSQL connection (requires the `postgres` feature).
    #[cfg(feature = "postgres")]
    Postgres(PostgresConnectionOptions),
    /// Options for an Oracle connection (requires the `oracle` feature).
    #[cfg(feature = "oracle")]
    Oracle(OracleConnectionOptions),
    /// Options for a MySQL connection (requires the `mysql` feature).
    #[cfg(feature = "mysql")]
    Mysql(MysqlConnectionOptions),
    /// A SQLite connection identified by a path (requires the `sqlite` feature).
    // NOTE(review): presumably the path of the database file — confirm in `connect_sqlite`.
    #[cfg(feature = "sqlite")]
    Sqlite(PathBuf),
}
81
82impl ConnectionOptions {
83    pub fn stream_chunk_size(&self) -> usize {
84        match self {
85            #[cfg(feature = "postgres")]
86            ConnectionOptions::Postgres(options) => options.stream_chunk_size,
87            #[cfg(feature = "oracle")]
88            ConnectionOptions::Oracle(options) => options.stream_chunk_size,
89            #[cfg(feature = "mysql")]
90            ConnectionOptions::Mysql(options) => options.stream_chunk_size,
91            #[cfg(feature = "sqlite")]
92            ConnectionOptions::Sqlite(_) => unreachable!(),
93        }
94    }
95}
96
/// Reports whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) selects every column, so the answer is always
/// `true`; otherwise the index must appear in the projection list.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    projection.map_or(true, |selected| selected.contains(&col_idx))
}
103
/// Converts `decimal` into the unscaled `i128` it has at the given decimal `scale`,
/// i.e. the integer `decimal * 10^scale` with any excess fractional digits dropped.
///
/// When `scale` is `None`, the decimal's own fractional digit count is used (or `0`
/// if that count does not fit in an `i32`).
///
/// Returns `None` when the rescaled value does not fit in an `i128`.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "oracle"))]
fn big_decimal_to_i128(decimal: &bigdecimal::BigDecimal, scale: Option<i32>) -> Option<i128> {
    use bigdecimal::{ToPrimitive, Zero};

    // `i128::MAX` is roughly 1.7e38, so 10^38 is the largest power of ten that fits.
    const MAX_I128_POW10: i64 = 38;

    let fractional_digits = decimal.fractional_digit_count();
    let target_scale = match scale {
        Some(scale) => i64::from(scale),
        None => i64::from(i32::try_from(fractional_digits).unwrap_or_default()),
    };

    // A non-zero value that has to gain more than 38 digits can never fit in an `i128`;
    // bail out before materialising an enormous power of ten.
    if !decimal.is_zero() && target_scale.saturating_sub(fractional_digits) > MAX_I128_POW10 {
        return None;
    }

    // Rescale with exact arbitrary-precision arithmetic. Multiplying by a power of ten
    // computed as an `f32` is inexact for larger scales and overflows to infinity well
    // before the `i128` range is exhausted.
    let (unscaled, _exponent) = decimal.with_scale(target_scale).into_bigint_and_exponent();
    unscaled.to_i128()
}
113    let scale_decimal = bigdecimal::BigDecimal::from_f32(10f32.powi(scale))?;
114    (decimal * scale_decimal).to_i128()
115}