use std::fmt;
use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::execution::TaskContext;
use datafusion::logical_expr::{
Between, BinaryExpr, Expr, Like, Operator, TableProviderFilterPushDown,
};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
SendableRecordBatchStream,
};
use datafusion::scalar::ScalarValue;
use oxisql_core::Connection;
/// Direction of a single `ORDER BY` key pushed down into the generated SQL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortOrder {
    /// Ascending order, rendered as `ASC`.
    Asc,
    /// Descending order, rendered as `DESC`.
    Desc,
}
impl fmt::Display for SortOrder {
    /// Writes the SQL keyword for this direction (`ASC` or `DESC`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Asc => "ASC",
            Self::Desc => "DESC",
        };
        f.write_str(keyword)
    }
}
/// DataFusion table provider that reads one remote table through an oxisql
/// [`Connection`], generating a SQL query for every scan.
pub struct OxiSqlStreamProvider {
    /// Full, unprojected schema of the remote table.
    schema: SchemaRef,
    /// Table name as written into the `FROM` clause (not quoted or escaped).
    table_name: String,
    /// Connection every generated query runs on.
    conn: Arc<dyn Connection>,
    /// Optional `ORDER BY` keys added to every scan.
    sort_order: Option<Vec<(String, SortOrder)>>,
    /// Optional `(n_parallel, target_batch_size)` for `LIMIT`/`OFFSET` partitioning.
    auto_partition_config: Option<(usize, usize)>,
}
impl OxiSqlStreamProvider {
    /// Creates a provider for `table_name` described by `schema`, with no sort
    /// order and auto-partitioning disabled.
    pub fn new(
        conn: Arc<dyn Connection>,
        table_name: impl Into<String>,
        schema: SchemaRef,
    ) -> Self {
        Self {
            schema,
            table_name: table_name.into(),
            conn,
            sort_order: None,
            auto_partition_config: None,
        }
    }

    /// Sets the `ORDER BY` keys appended to every scan.
    ///
    /// Column names are written into the SQL as given, without quoting.
    #[must_use]
    pub fn with_sort(mut self, order: Vec<(String, SortOrder)>) -> Self {
        self.sort_order = Some(order);
        self
    }

    /// Returns the configured `ORDER BY` keys, if any.
    pub fn sort_order(&self) -> Option<&[(String, SortOrder)]> {
        self.sort_order.as_deref()
    }

    /// Enables splitting scans that carry no `LIMIT` into `n_parallel`
    /// partitions. Each partition issues its own `LIMIT`/`OFFSET` query, and
    /// the offset steps by `target_batch_size`.
    ///
    /// Has no effect unless `n_parallel > 1` and `target_batch_size > 0`.
    #[must_use]
    pub fn with_auto_partition(mut self, n_parallel: usize, target_batch_size: usize) -> Self {
        self.auto_partition_config = Some((n_parallel, target_batch_size));
        self
    }

    /// Returns the `(n_parallel, target_batch_size)` auto-partition settings, if set.
    pub fn auto_partition_config(&self) -> Option<(usize, usize)> {
        self.auto_partition_config
    }

    /// Builds a provider backed by a MySQL connection.
    #[cfg(feature = "mysql")]
    pub fn from_mysql(
        conn: oxisql_mysql::MyConnection,
        table_name: impl Into<String>,
        schema: SchemaRef,
    ) -> Self {
        Self::new(Arc::new(conn) as Arc<dyn Connection>, table_name, schema)
    }

    /// Builds a provider backed by a PostgreSQL connection.
    #[cfg(feature = "postgres")]
    pub fn from_postgres(
        conn: oxisql_postgres::PgConnection,
        table_name: impl Into<String>,
        schema: SchemaRef,
    ) -> Self {
        Self::new(Arc::new(conn) as Arc<dyn Connection>, table_name, schema)
    }

    /// Builds a provider backed by a SQLite-compatible connection.
    #[cfg(feature = "sqlite")]
    pub fn from_sqlite(
        conn: oxisql_sqlite_compat::SqliteConnection,
        table_name: impl Into<String>,
        schema: SchemaRef,
    ) -> Self {
        Self::new(Arc::new(conn) as Arc<dyn Connection>, table_name, schema)
    }

    /// Renders the `SELECT` column list for `projection`.
    ///
    /// `None` selects every column (`*`). An empty projection, which DataFusion
    /// may request (for example for `COUNT(*)`), selects the constant `1`.
    /// `SELECT  FROM t` is not valid SQL in any backend.
    fn project_clause(&self, projection: Option<&Vec<usize>>) -> String {
        match projection {
            None => "*".to_string(),
            // NOTE(review): the returned rows then carry one column while the
            // projected schema has none; confirm `rows_to_record_batch`
            // preserves the row count for a zero-field schema.
            Some(indices) if indices.is_empty() => "1".to_string(),
            Some(indices) => indices
                .iter()
                .map(|&i| self.schema.field(i).name().as_str())
                .collect::<Vec<_>>()
                .join(", "),
        }
    }

    /// Returns the schema of the rows a scan with `projection` produces.
    ///
    /// The table's schema-level metadata is kept in both branches, so a
    /// projected schema agrees with the unprojected one apart from its fields.
    fn projected_schema(&self, projection: Option<&Vec<usize>>) -> SchemaRef {
        match projection {
            None => Arc::clone(&self.schema),
            Some(indices) => {
                let fields: Vec<_> = indices
                    .iter()
                    .map(|&i| self.schema.field(i).clone())
                    .collect();
                Arc::new(arrow::datatypes::Schema::new_with_metadata(
                    fields,
                    self.schema.metadata().clone(),
                ))
            }
        }
    }
}
impl fmt::Debug for OxiSqlStreamProvider {
    /// Prints only `table_name` and `schema`; the connection and the
    /// sort/partition settings are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("OxiSqlStreamProvider");
        builder.field("table_name", &self.table_name);
        builder.field("schema", &self.schema);
        builder.finish()
    }
}
#[async_trait]
impl TableProvider for OxiSqlStreamProvider {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
let col_clause = self.project_clause(projection);
let output_schema = self.projected_schema(projection);
let mut base_sql = format!("SELECT {} FROM {}", col_clause, self.table_name);
let where_parts: Vec<String> = filters.iter().filter_map(expr_to_sql).collect();
if !where_parts.is_empty() {
base_sql.push_str(" WHERE ");
base_sql.push_str(&where_parts.join(" AND "));
}
if let Some(ref order) = self.sort_order {
if !order.is_empty() {
base_sql.push_str(" ORDER BY ");
let order_clause = order
.iter()
.map(|(col, dir)| format!("{col} {dir}"))
.collect::<Vec<_>>()
.join(", ");
base_sql.push_str(&order_clause);
}
}
if let Some(n) = limit {
let mut sql = base_sql;
sql.push_str(&format!(" LIMIT {n}"));
let exec = OxiSqlExecPlan::new(Arc::clone(&self.conn), sql, Arc::clone(&output_schema));
return Ok(Arc::new(exec) as Arc<dyn ExecutionPlan>);
}
if let Some((n_parallel, target_batch_size)) = self.auto_partition_config {
if n_parallel > 1 && target_batch_size > 0 {
let sqls: Vec<String> = (0..n_parallel)
.map(|i| {
format!(
"{} LIMIT {} OFFSET {}",
base_sql,
target_batch_size,
i * target_batch_size
)
})
.collect();
let exec = OxiSqlMultiPartExecPlan::new(
Arc::clone(&self.conn),
sqls,
Arc::clone(&output_schema),
);
return Ok(Arc::new(exec) as Arc<dyn ExecutionPlan>);
}
}
let exec =
OxiSqlExecPlan::new(Arc::clone(&self.conn), base_sql, Arc::clone(&output_schema));
Ok(Arc::new(exec) as Arc<dyn ExecutionPlan>)
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DFResult<Vec<TableProviderFilterPushDown>> {
Ok(filters
.iter()
.map(|f| {
if can_push_filter(f) {
TableProviderFilterPushDown::Exact
} else {
TableProviderFilterPushDown::Unsupported
}
})
.collect())
}
}
/// Reports whether `expr` has a shape the SQL pushdown understands.
///
/// Accepted shapes:
/// * a comparison (`=`, `<>`, `<`, `<=`, `>`, `>=`) or `AND`/`OR` between two operands;
/// * `IS NULL` / `IS NOT NULL`;
/// * `NOT`;
/// * `[NOT] IN`;
/// * `[NOT] BETWEEN`;
/// * `[NOT] LIKE`.
///
/// Every operand must itself be accepted by `can_push_atom`. The cases mirror
/// those rendered by `expr_to_sql`.
pub fn can_push_filter(expr: &Expr) -> bool {
    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
            let op_supported = matches!(
                op,
                Operator::Eq
                    | Operator::NotEq
                    | Operator::Lt
                    | Operator::LtEq
                    | Operator::Gt
                    | Operator::GtEq
                    | Operator::And
                    | Operator::Or
            );
            op_supported && [left, right].iter().all(|side| can_push_atom(side))
        }
        Expr::IsNull(operand) | Expr::IsNotNull(operand) => can_push_atom(operand),
        Expr::Not(negated) => can_push_filter(negated),
        Expr::InList(in_list) => std::iter::once(&*in_list.expr)
            .chain(in_list.list.iter())
            .all(can_push_atom),
        Expr::Between(Between {
            expr: subject,
            low,
            high,
            ..
        }) => [subject, low, high].iter().all(|part| can_push_atom(part)),
        Expr::Like(Like {
            expr: subject,
            pattern,
            ..
        }) => [subject, pattern].iter().all(|part| can_push_atom(part)),
        _ => false,
    }
}
fn can_push_atom(expr: &Expr) -> bool {
match expr {
Expr::Column(_) => true,
Expr::Literal(_, _) => true,
other => can_push_filter(other),
}
}
pub fn expr_to_sql(expr: &Expr) -> Option<String> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let l = atom_to_sql(left)?;
let r = atom_to_sql(right)?;
let op_str = match op {
Operator::Eq => "=",
Operator::NotEq => "<>",
Operator::Lt => "<",
Operator::LtEq => "<=",
Operator::Gt => ">",
Operator::GtEq => ">=",
Operator::And => "AND",
Operator::Or => "OR",
_ => return None,
};
Some(format!("({l} {op_str} {r})"))
}
Expr::IsNull(inner) => Some(format!("({} IS NULL)", atom_to_sql(inner)?)),
Expr::IsNotNull(inner) => Some(format!("({} IS NOT NULL)", atom_to_sql(inner)?)),
Expr::Not(inner) => Some(format!("(NOT {})", expr_to_sql(inner)?)),
Expr::InList(inlist) => {
let e = atom_to_sql(&inlist.expr)?;
let items: Option<Vec<String>> = inlist.list.iter().map(atom_to_sql).collect();
let items = items?;
let not_kw = if inlist.negated { "NOT " } else { "" };
Some(format!("({e} {not_kw}IN ({}))", items.join(", ")))
}
Expr::Between(Between {
expr,
low,
high,
negated,
}) => {
let e = atom_to_sql(expr)?;
let lo = atom_to_sql(low)?;
let hi = atom_to_sql(high)?;
let not_kw = if *negated { "NOT " } else { "" };
Some(format!("({e} {not_kw}BETWEEN {lo} AND {hi})"))
}
Expr::Like(Like {
expr,
pattern,
negated,
case_insensitive,
escape_char,
}) => {
let e = atom_to_sql(expr)?;
let p = atom_to_sql(pattern)?;
let not_kw = if *negated { "NOT " } else { "" };
let like_kw = if *case_insensitive { "ILIKE" } else { "LIKE" };
let escape_clause = match escape_char {
Some(c) => format!(" ESCAPE '{c}'"),
None => String::new(),
};
Some(format!("({e} {not_kw}{like_kw} {p}{escape_clause})"))
}
_ => None,
}
}
fn atom_to_sql(expr: &Expr) -> Option<String> {
match expr {
Expr::Column(col) => Some(col.name.clone()),
Expr::Literal(scalar, _metadata) => scalar_to_sql(scalar),
other => expr_to_sql(other),
}
}
fn scalar_to_sql(scalar: &ScalarValue) -> Option<String> {
match scalar {
ScalarValue::Int8(Some(v)) => Some(v.to_string()),
ScalarValue::Int16(Some(v)) => Some(v.to_string()),
ScalarValue::Int32(Some(v)) => Some(v.to_string()),
ScalarValue::Int64(Some(v)) => Some(v.to_string()),
ScalarValue::Float32(Some(v)) => Some(v.to_string()),
ScalarValue::Float64(Some(v)) => Some(v.to_string()),
ScalarValue::Boolean(Some(v)) => Some(if *v { "TRUE" } else { "FALSE" }.to_string()),
ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
Some(format!("'{}'", s.replace('\'', "''")))
}
ScalarValue::Null => Some("NULL".to_string()),
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
| ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Boolean(None)
| ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None) => Some("NULL".to_string()),
_ => None,
}
}
struct OxiSqlExecPlan {
schema: SchemaRef,
sql: String,
conn: Arc<dyn Connection>,
cache: Arc<PlanProperties>,
}
impl fmt::Debug for OxiSqlExecPlan {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OxiSqlExecPlan")
.field("sql", &self.sql)
.field("schema", &self.schema)
.finish()
}
}
impl OxiSqlExecPlan {
fn new(conn: Arc<dyn Connection>, sql: String, schema: SchemaRef) -> Self {
let eq = EquivalenceProperties::new(Arc::clone(&schema));
let properties = PlanProperties::new(
eq,
Partitioning::UnknownPartitioning(1),
EmissionType::Incremental,
Boundedness::Bounded,
);
Self {
schema,
sql,
conn,
cache: Arc::new(properties),
}
}
}
impl DisplayAs for OxiSqlExecPlan {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "OxiSqlExecPlan sql={:?}", self.sql)
}
}
impl ExecutionPlan for OxiSqlExecPlan {
fn name(&self) -> &'static str {
"OxiSqlExecPlan"
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.cache
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
if children.is_empty() {
Ok(self)
} else {
Err(datafusion::error::DataFusionError::Internal(
"OxiSqlExecPlan has no children".into(),
))
}
}
fn execute(
&self,
_partition: usize,
_context: Arc<TaskContext>,
) -> datafusion::error::Result<SendableRecordBatchStream> {
let sql = self.sql.clone();
let conn = Arc::clone(&self.conn);
let schema = Arc::clone(&self.schema);
let stream = futures::stream::once(async move {
let rows = conn
.query(&sql, &[])
.await
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
crate::types::rows_to_record_batch(rows, schema)
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.schema),
stream,
)))
}
}
struct OxiSqlMultiPartExecPlan {
schema: SchemaRef,
sqls: Vec<String>,
conn: Arc<dyn Connection>,
cache: Arc<PlanProperties>,
}
impl fmt::Debug for OxiSqlMultiPartExecPlan {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OxiSqlMultiPartExecPlan")
.field("partitions", &self.sqls.len())
.field("schema", &self.schema)
.finish()
}
}
impl OxiSqlMultiPartExecPlan {
fn new(conn: Arc<dyn Connection>, sqls: Vec<String>, schema: SchemaRef) -> Self {
let n = sqls.len().max(1);
let eq = EquivalenceProperties::new(Arc::clone(&schema));
let properties = PlanProperties::new(
eq,
Partitioning::UnknownPartitioning(n),
EmissionType::Incremental,
Boundedness::Bounded,
);
Self {
schema,
sqls,
conn,
cache: Arc::new(properties),
}
}
}
impl DisplayAs for OxiSqlMultiPartExecPlan {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "OxiSqlMultiPartExecPlan partitions={}", self.sqls.len())
}
}
impl ExecutionPlan for OxiSqlMultiPartExecPlan {
fn name(&self) -> &'static str {
"OxiSqlMultiPartExecPlan"
}
fn properties(&self) -> &Arc<PlanProperties> {
&self.cache
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
if children.is_empty() {
Ok(self)
} else {
Err(datafusion::error::DataFusionError::Internal(
"OxiSqlMultiPartExecPlan has no children".into(),
))
}
}
fn execute(
&self,
partition: usize,
_context: Arc<TaskContext>,
) -> datafusion::error::Result<SendableRecordBatchStream> {
let sql = self.sqls.get(partition).cloned().ok_or_else(|| {
datafusion::error::DataFusionError::Internal(format!(
"OxiSqlMultiPartExecPlan: partition index {partition} out of range ({})",
self.sqls.len()
))
})?;
let conn = Arc::clone(&self.conn);
let schema = Arc::clone(&self.schema);
let stream = futures::stream::once(async move {
let rows = conn
.query(&sql, &[])
.await
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
crate::types::rows_to_record_batch(rows, schema)
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
Arc::clone(&self.schema),
stream,
)))
}
}