datafusion-remote-table 0.26.0

A DataFusion table provider for executing SQL on remote databases
Documentation
#[cfg(feature = "dm")]
mod dm;
#[cfg(feature = "mysql")]
mod mysql;
mod options;
#[cfg(feature = "oracle")]
mod oracle;
#[cfg(feature = "postgres")]
mod postgres;
#[cfg(feature = "sqlite")]
mod sqlite;

#[cfg(feature = "dm")]
pub use dm::*;
#[cfg(feature = "mysql")]
pub use mysql::*;
pub use options::*;
#[cfg(feature = "oracle")]
pub use oracle::*;
#[cfg(feature = "postgres")]
pub use postgres::*;
#[cfg(feature = "sqlite")]
pub use sqlite::*;

use std::any::Any;

use crate::{DFResult, Literalize, RemoteSchemaRef, RemoteSource, extract_primitive_array};
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, Int64Type, Schema, SchemaRef};
use datafusion_common::DataFusionError;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_physical_plan::common::collect;
use datafusion_sql::unparser::Unparser;
use datafusion_sql::unparser::dialect::{MySqlDialect, PostgreSqlDialect, SqliteDialect};
use std::fmt::Debug;
use std::sync::Arc;

/// Process-wide ODBC environment, only compiled with the `dm` feature.
///
/// Being a `OnceLock`, it can be set at most once; the code that initialises it is not in
/// this file (presumably the DM backend module — confirm).
#[cfg(feature = "dm")]
pub static ODBC_ENV: std::sync::OnceLock<odbc_api::Environment> = std::sync::OnceLock::new();

/// A pool of connections to one remote database, as produced by [`connect`].
#[async_trait::async_trait]
pub trait Pool: Debug + Send + Sync {
    /// Obtains a connection from the pool.
    async fn get(&self) -> DFResult<Arc<dyn Connection>>;
    /// Reports the pool's current connection counts.
    async fn state(&self) -> DFResult<PoolState>;
}

/// Snapshot of a [`Pool`]'s connection counts, as returned by [`Pool::state`].
#[derive(Debug, Clone)]
pub struct PoolState {
    // Total connections held by the pool (presumably including idle ones — confirm per backend).
    pub connections: usize,
    // Connections currently idle, i.e. not checked out (semantics defined by each backend).
    pub idle_connections: usize,
}

/// A single connection to a remote database, obtained from [`Pool::get`].
#[async_trait::async_trait]
pub trait Connection: Debug + Send + Sync {
    /// Returns `self` as `&dyn Any` so callers can downcast to a concrete connection type.
    fn as_any(&self) -> &dyn Any;

    /// Infers the remote schema of `source` (a table or an ad-hoc query).
    async fn infer_schema(&self, source: &RemoteSource) -> DFResult<RemoteSchemaRef>;

    /// Runs `source` on the remote database and streams the result as record batches.
    ///
    /// `table_schema` describes the result columns (presumably the full, unprojected
    /// schema — confirm); `projection` optionally selects column indices (`None` = all);
    /// `unparsed_filters` are SQL condition strings and `limit` an optional row cap, which
    /// implementations may push down into the remote SQL.
    async fn query(
        &self,
        conn_options: &ConnectionOptions,
        source: &RemoteSource,
        table_schema: SchemaRef,
        projection: Option<&Vec<usize>>,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> DFResult<SendableRecordBatchStream>;

    /// Inserts `batch` into the remote table named by the identifier parts in `table`.
    ///
    /// `literalizer` presumably renders Arrow values as SQL literals and `remote_schema`
    /// describes the target columns; the returned `usize` is presumably the number of
    /// rows inserted — confirm against the backend implementations.
    async fn insert(
        &self,
        conn_options: &ConnectionOptions,
        literalizer: Arc<dyn Literalize>,
        table: &[String],
        remote_schema: RemoteSchemaRef,
        batch: RecordBatch,
    ) -> DFResult<usize>;
}

/// Creates a connection pool for the database described by `options`.
///
/// Each backend is behind a cargo feature of the same name (`postgres`, `mysql`,
/// `oracle`, `sqlite`, `dm`).
///
/// # Errors
/// Returns `DataFusionError::Internal("Please enable the <name> feature")` when the
/// matching feature is not compiled in, and propagates any error from the backend's
/// `connect_*` function.
// `options` is only read in arms whose feature is enabled; in a build without that
// feature the binding is unused, hence the allow.
#[allow(unused_variables)]
pub async fn connect(options: &ConnectionOptions) -> DFResult<Arc<dyn Pool>> {
    match options {
        ConnectionOptions::Postgres(options) => {
            #[cfg(feature = "postgres")]
            {
                let pool = connect_postgres(options).await?;
                Ok(Arc::new(pool))
            }
            #[cfg(not(feature = "postgres"))]
            {
                Err(DataFusionError::Internal(
                    "Please enable the postgres feature".to_string(),
                ))
            }
        }
        ConnectionOptions::Mysql(options) => {
            // The MySQL constructor is synchronous, unlike the postgres/oracle/sqlite ones.
            #[cfg(feature = "mysql")]
            {
                let pool = connect_mysql(options)?;
                Ok(Arc::new(pool))
            }
            #[cfg(not(feature = "mysql"))]
            {
                Err(DataFusionError::Internal(
                    "Please enable the mysql feature".to_string(),
                ))
            }
        }
        ConnectionOptions::Oracle(options) => {
            #[cfg(feature = "oracle")]
            {
                let pool = connect_oracle(options).await?;
                Ok(Arc::new(pool))
            }
            #[cfg(not(feature = "oracle"))]
            {
                Err(DataFusionError::Internal(
                    "Please enable the oracle feature".to_string(),
                ))
            }
        }
        ConnectionOptions::Sqlite(options) => {
            #[cfg(feature = "sqlite")]
            {
                let pool = connect_sqlite(options).await?;
                Ok(Arc::new(pool))
            }
            #[cfg(not(feature = "sqlite"))]
            {
                Err(DataFusionError::Internal(
                    "Please enable the sqlite feature".to_string(),
                ))
            }
        }
        ConnectionOptions::Dm(options) => {
            // The DM constructor is synchronous as well.
            #[cfg(feature = "dm")]
            {
                let pool = connect_dm(options)?;
                Ok(Arc::new(pool))
            }
            #[cfg(not(feature = "dm"))]
            {
                Err(DataFusionError::Internal(
                    "Please enable the dm feature".to_string(),
                ))
            }
        }
    }
}

/// The remote database flavour; selects SQL dialect details such as identifier quoting,
/// literal syntax and how row limits are expressed.
#[derive(Debug, Clone, Copy)]
pub enum RemoteDbType {
    Postgres,
    Mysql,
    Oracle,
    Sqlite,
    // Dameng (DM) database, reached through ODBC when the `dm` feature is enabled.
    Dm,
}

impl RemoteDbType {
    /// Returns `true` when `source` may be wrapped with extra filter/limit clauses by
    /// [`RemoteDbType::rewrite_query`].
    ///
    /// Table sources always qualify. Query sources qualify only when their trimmed text
    /// starts with `SELECT` (ASCII case-insensitive); any other statement is left
    /// unrewritten by the callers in this impl.
    pub(crate) fn support_rewrite_with_filters_limit(&self, source: &RemoteSource) -> bool {
        match source {
            RemoteSource::Table(_) => true,
            // `str::get` instead of `[0..6]`: slicing panicked on queries shorter than six
            // bytes and on queries whose sixth byte falls inside a multi-byte character.
            RemoteSource::Query(query) => query
                .trim()
                .get(..6)
                .is_some_and(|prefix| prefix.eq_ignore_ascii_case("select")),
        }
    }

    /// Creates a SQL unparser for this database's dialect.
    ///
    /// # Errors
    /// Returns `DataFusionError::NotImplemented` for Oracle and Dm, which have no unparser
    /// dialect wired up.
    pub(crate) fn create_unparser(&self) -> DFResult<Unparser<'_>> {
        match self {
            RemoteDbType::Postgres => Ok(Unparser::new(&PostgreSqlDialect {})),
            RemoteDbType::Mysql => Ok(Unparser::new(&MySqlDialect {})),
            RemoteDbType::Sqlite => Ok(Unparser::new(&SqliteDialect {})),
            RemoteDbType::Oracle => Err(DataFusionError::NotImplemented(
                "Oracle unparser not implemented".to_string(),
            )),
            RemoteDbType::Dm => Err(DataFusionError::NotImplemented(
                "Dm unparser not implemented".to_string(),
            )),
        }
    }

    /// Builds `" WHERE c1 AND c2 ..."` from `conditions`, or an empty string when there are
    /// none.
    fn where_clause(conditions: &[String]) -> String {
        if conditions.is_empty() {
            String::new()
        } else {
            format!(" WHERE {}", conditions.join(" AND "))
        }
    }

    /// Builds the clause suffix (leading space included) that applies `unparsed_filters`
    /// and `limit` in this dialect, or an empty string when there is nothing to apply.
    fn filters_limit_suffix(&self, unparsed_filters: &[String], limit: Option<usize>) -> String {
        match self {
            // Oracle has no `LIMIT` clause; the cap becomes one more `ROWNUM` condition.
            RemoteDbType::Oracle => {
                let mut conditions = unparsed_filters.to_vec();
                if let Some(limit) = limit {
                    conditions.push(format!("ROWNUM <= {limit}"));
                }
                Self::where_clause(&conditions)
            }
            RemoteDbType::Postgres
            | RemoteDbType::Mysql
            | RemoteDbType::Sqlite
            | RemoteDbType::Dm => {
                let limit_clause = limit
                    .map(|limit| format!(" LIMIT {limit}"))
                    .unwrap_or_default();
                format!("{}{limit_clause}", Self::where_clause(unparsed_filters))
            }
        }
    }

    /// Builds the SQL to send to the remote database for `source`, applying the
    /// already-unparsed `unparsed_filters` (joined with `AND`) and an optional row `limit`.
    ///
    /// Tables become `SELECT * FROM <table>` plus the clauses. Ad-hoc queries are wrapped
    /// in a subquery only when there is at least one clause to add; otherwise they are
    /// returned unchanged. This does not itself check
    /// [`RemoteDbType::support_rewrite_with_filters_limit`].
    pub(crate) fn rewrite_query(
        &self,
        source: &RemoteSource,
        unparsed_filters: &[String],
        limit: Option<usize>,
    ) -> String {
        let suffix = self.filters_limit_suffix(unparsed_filters, limit);
        match source {
            RemoteSource::Table(table) => format!("{}{suffix}", self.select_all_query(table)),
            RemoteSource::Query(query) if suffix.is_empty() => query.clone(),
            RemoteSource::Query(query) => match self {
                RemoteDbType::Postgres
                | RemoteDbType::Mysql
                | RemoteDbType::Sqlite
                | RemoteDbType::Dm => {
                    format!("SELECT * FROM ({query}) as __subquery{suffix}")
                }
                // Oracle rejects `AS` before a table alias, so the subquery stays unaliased.
                RemoteDbType::Oracle => format!("SELECT * FROM ({query}){suffix}"),
            },
        }
    }

    /// Quotes a single identifier for this dialect.
    ///
    /// MySQL uses backticks; the others use double quotes. An embedded quote character is
    /// doubled so the identifier cannot terminate the quoting early.
    pub(crate) fn sql_identifier(&self, identifier: &str) -> String {
        match self {
            RemoteDbType::Postgres
            | RemoteDbType::Oracle
            | RemoteDbType::Sqlite
            | RemoteDbType::Dm => {
                format!("\"{}\"", identifier.replace('"', "\"\""))
            }
            RemoteDbType::Mysql => {
                format!("`{}`", identifier.replace('`', "``"))
            }
        }
    }

    /// Quotes each part of a (possibly schema-qualified) name and joins them with `.`.
    pub(crate) fn sql_table_name(&self, identifiers: &[String]) -> String {
        identifiers
            .iter()
            .map(|identifier| self.sql_identifier(identifier))
            .collect::<Vec<String>>()
            .join(".")
    }

    /// Renders `value` as a single-quoted SQL string literal.
    ///
    /// Single quotes are doubled for every dialect. For MySQL, backslashes are doubled
    /// first, because MySQL treats `\` as an escape character inside string literals in its
    /// default SQL mode. Without that, a trailing `\` could escape the closing quote.
    pub(crate) fn sql_string_literal(&self, value: &str) -> String {
        let escaped = match self {
            RemoteDbType::Mysql => value.replace('\\', "\\\\").replace('\'', "''"),
            RemoteDbType::Postgres
            | RemoteDbType::Oracle
            | RemoteDbType::Sqlite
            | RemoteDbType::Dm => value.replace('\'', "''"),
        };
        format!("'{escaped}'")
    }

    /// Renders `value` as a hex-encoded binary literal for this dialect.
    ///
    /// # Panics
    /// Not yet implemented for Oracle and Dm; calling it for those panics via `todo!`.
    pub(crate) fn sql_binary_literal(&self, value: &[u8]) -> String {
        match self {
            RemoteDbType::Postgres => format!("E'\\\\x{}'", hex::encode(value)),
            RemoteDbType::Mysql | RemoteDbType::Sqlite => format!("X'{}'", hex::encode(value)),
            RemoteDbType::Oracle | RemoteDbType::Dm => {
                todo!("binary literals are not yet supported for {self:?}")
            }
        }
    }

    /// Builds `SELECT * FROM <quoted table name>`; the same for every dialect.
    pub(crate) fn select_all_query(&self, table_identifiers: &[String]) -> String {
        format!("SELECT * FROM {}", self.sql_table_name(table_identifiers))
    }

    /// Returns SQL fetching at most one row of `source` when the source can be rewritten,
    /// or the source's own query unchanged otherwise.
    pub(crate) fn limit_1_query_if_possible(&self, source: &RemoteSource) -> String {
        if !self.support_rewrite_with_filters_limit(source) {
            return source.query(*self);
        }
        self.rewrite_query(source, &[], Some(1))
    }

    /// Returns a `SELECT COUNT(1)` query over `source`, or `None` when the source cannot be
    /// rewritten (see [`RemoteDbType::support_rewrite_with_filters_limit`]).
    pub(crate) fn try_count1_query(&self, source: &RemoteSource) -> Option<String> {
        if !self.support_rewrite_with_filters_limit(source) {
            return None;
        }
        match source {
            RemoteSource::Table(table) => Some(format!(
                "SELECT COUNT(1) FROM {}",
                self.sql_table_name(table)
            )),
            RemoteSource::Query(query) => match self {
                RemoteDbType::Postgres
                | RemoteDbType::Mysql
                | RemoteDbType::Sqlite
                | RemoteDbType::Dm => Some(format!("SELECT COUNT(1) FROM ({query}) AS __subquery")),
                // Oracle rejects `AS` before a table alias, so the subquery stays unaliased.
                RemoteDbType::Oracle => Some(format!("SELECT COUNT(1) FROM ({query})")),
            },
        }
    }

    /// Runs `count1_query` on `conn` and returns the single count it produces.
    ///
    /// The result is read as one non-nullable Int64 column named `count(1)`.
    ///
    /// # Errors
    /// Propagates query/collection errors, and returns `DataFusionError::Execution` when the
    /// query does not yield exactly one row, yields NULL, or yields a value that does not fit
    /// in `usize` (e.g. a negative number).
    pub(crate) async fn fetch_count(
        &self,
        conn: Arc<dyn Connection>,
        conn_options: &ConnectionOptions,
        count1_query: &str,
    ) -> DFResult<usize> {
        let count1_schema = Arc::new(Schema::new(vec![Field::new(
            "count(1)",
            DataType::Int64,
            false,
        )]));
        let stream = conn
            .query(
                conn_options,
                &RemoteSource::Query(count1_query.to_string()),
                count1_schema,
                None,
                &[],
                None,
            )
            .await?;
        let batches = collect(stream).await?;
        let count_vec = extract_primitive_array::<Int64Type>(&batches, 0)?;
        if count_vec.len() != 1 {
            return Err(DataFusionError::Execution(format!(
                "Count query did not return exactly one row: {count_vec:?}",
            )));
        }
        count_vec[0]
            .ok_or_else(|| DataFusionError::Execution("Count query returned null".to_string()))
            .and_then(|count| {
                // Checked conversion: `as usize` silently wrapped negative counts.
                usize::try_from(count).map_err(|_| {
                    DataFusionError::Execution(format!(
                        "Count query returned an out-of-range value: {count}"
                    ))
                })
            })
    }
}

/// Reports whether column `col_idx` is selected by `projection`.
///
/// A missing projection (`None`) means every column is selected.
pub(crate) fn projections_contains(projection: Option<&Vec<usize>>, col_idx: usize) -> bool {
    let Some(selected) = projection else {
        return true;
    };
    selected.iter().any(|&idx| idx == col_idx)
}

/// Wraps `value` in `Ok` unchanged.
///
/// Presumably passed as an identity conversion callback by the feature-gated backend
/// modules; `unused` is allowed because no caller may exist in some feature combinations.
#[allow(unused)]
fn just_return<T>(value: T) -> DFResult<T> {
    Result::Ok(value)
}

/// Copies the value behind `value_ref` out and wraps it in `Ok`.
///
/// `unused` is allowed because no caller may exist in some feature combinations.
#[allow(unused)]
fn just_deref<T: Copy>(value_ref: &T) -> DFResult<T> {
    let copied: T = *value_ref;
    Ok(copied)
}