use std::any::Any;
use std::fmt;
use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::common::stats::{ColumnStatistics, Precision};
use datafusion::common::Statistics;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::{BinaryExpr, Expr, Operator, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::scalar::ScalarValue;
use oxisql_core::{Row, Value};
use crate::error::OxiSqlFusionError;
use crate::types::rows_to_record_batch;
/// In-memory DataFusion table provider backed by `oxisql_core` rows.
///
/// Row buffers are held behind `Arc`, so cloning the provider does not copy
/// the row data itself.
#[derive(Clone)]
pub struct OxiSqlTableProvider {
/// Arrow schema reported to DataFusion and used to resolve column names.
schema: SchemaRef,
/// Every row of the table; used for row counts and for single-partition scans.
rows: Arc<Vec<Row>>,
/// Row chunks produced by `with_range_partition`. When this is empty, scans
/// read `rows` as one partition instead.
partitions: Vec<Arc<Vec<Row>>>,
}
impl fmt::Debug for OxiSqlTableProvider {
    /// Prints the schema and the number of rows; the row data itself is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let row_count = self.rows.len();
        let mut out = f.debug_struct("OxiSqlTableProvider");
        out.field("schema", &self.schema);
        out.field("row_count", &row_count);
        out.finish()
    }
}
impl OxiSqlTableProvider {
    /// Builds a provider over already-materialised `rows` described by `schema`.
    ///
    /// The provider starts unpartitioned: scans read all rows as a single
    /// partition until [`Self::with_range_partition`] is applied.
    pub fn from_rows(rows: Vec<Row>, schema: SchemaRef) -> Self {
        Self {
            schema,
            rows: Arc::new(rows),
            partitions: Vec::new(),
        }
    }

    /// Loads every row of `table_name` through `conn` and wraps the result in a
    /// provider that uses `schema`.
    ///
    /// `table_name` is interpolated verbatim into the SQL text, so it must come
    /// from a trusted source (identifiers cannot be bound as parameters).
    ///
    /// # Errors
    ///
    /// Returns [`OxiSqlFusionError::OxiSql`] carrying the stringified connection
    /// error when the query fails.
    pub async fn from_connection(
        conn: &dyn oxisql_core::Connection,
        table_name: &str,
        schema: SchemaRef,
    ) -> Result<Self, OxiSqlFusionError> {
        let rows = Self::select_all_rows(conn, table_name).await?;
        Ok(Self::from_rows(rows, schema))
    }

    /// Replaces the cached rows with a fresh `SELECT *` of `table_name`.
    ///
    /// Any range partitioning is discarded, because the stored partitions are
    /// chunks of the previous row set; keeping them would make scans return
    /// stale data while `len()` reports the new row count. Call
    /// [`Self::with_range_partition`] again afterwards if partitioning is wanted.
    ///
    /// `table_name` is interpolated verbatim into the SQL text and must be
    /// trusted. On failure the provider is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`OxiSqlFusionError::OxiSql`] when the query fails.
    pub async fn refresh(
        &mut self,
        conn: &dyn oxisql_core::Connection,
        table_name: &str,
    ) -> Result<(), OxiSqlFusionError> {
        let rows = Self::select_all_rows(conn, table_name).await?;
        self.rows = Arc::new(rows);
        // The old partitions were cut from the previous rows; drop them so
        // scans fall back to the refreshed row set.
        self.partitions.clear();
        Ok(())
    }

    /// Number of rows currently held by the provider.
    pub fn len(&self) -> usize {
        self.rows.len()
    }

    /// Returns `true` when the provider holds no rows.
    pub fn is_empty(&self) -> bool {
        self.rows.is_empty()
    }

    /// Sorts the rows by `key_column` and splits them into at most
    /// `n_partitions` contiguous chunks that scans expose as separate partitions.
    ///
    /// The sort is stable. Rows that have no value at the key position sort
    /// first, and values that cannot be ordered against each other are treated
    /// as equal. Each chunk holds `ceil(rows / n_partitions)` rows, except that
    /// the last one may be shorter.
    ///
    /// The provider is returned unchanged when `n_partitions` is `0` or when
    /// `key_column` is not part of the schema. An empty table stays
    /// unpartitioned.
    #[must_use]
    pub fn with_range_partition(mut self, key_column: &str, n_partitions: usize) -> Self {
        use std::cmp::Ordering;

        if n_partitions == 0 {
            return self;
        }
        let Ok(col_idx) = self.schema.index_of(key_column) else {
            return self;
        };

        // Reuse the row buffer when this provider is its only owner; fall back
        // to a copy when the `Arc` is shared with a clone of the provider.
        let mut sorted: Vec<Row> = Arc::try_unwrap(std::mem::take(&mut self.rows))
            .unwrap_or_else(|shared| (*shared).clone());

        sorted.sort_by(|a, b| match (a.get_by_index(col_idx), b.get_by_index(col_idx)) {
            (Some(l), Some(r)) => l.partial_cmp(r).unwrap_or(Ordering::Equal),
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
            (None, None) => Ordering::Equal,
        });

        // `chunks` panics on a chunk size of zero, which `div_ceil` yields for
        // an empty table; clamp to one (the result is then zero chunks).
        let chunk_size = sorted.len().div_ceil(n_partitions).max(1);
        let parts: Vec<Arc<Vec<Row>>> = sorted
            .chunks(chunk_size)
            .map(|chunk| Arc::new(chunk.to_vec()))
            .collect();

        self.rows = Arc::new(sorted);
        self.partitions = parts;
        self
    }

    /// Runs `SELECT * FROM <table_name>` on `conn` and returns the rows,
    /// converting a connection error into [`OxiSqlFusionError::OxiSql`].
    async fn select_all_rows(
        conn: &dyn oxisql_core::Connection,
        table_name: &str,
    ) -> Result<Vec<Row>, OxiSqlFusionError> {
        let sql = format!("SELECT * FROM {table_name}");
        conn.query(&sql, &[])
            .await
            .map_err(|e| OxiSqlFusionError::OxiSql(e.to_string()))
    }
}
fn is_simple_filter(expr: &Expr) -> bool {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
matches!(
op,
Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::LtEq
| Operator::Gt
| Operator::GtEq
) && is_col_or_literal(left)
&& is_col_or_literal(right)
}
Expr::IsNull(inner) | Expr::IsNotNull(inner) => is_col_or_literal(inner),
_ => false,
}
}
/// Returns `true` when `expr` is a bare column reference or a literal value.
fn is_col_or_literal(expr: &Expr) -> bool {
    let is_column = matches!(expr, Expr::Column(_));
    let is_literal = matches!(expr, Expr::Literal(_, _));
    is_column || is_literal
}
/// Evaluates a pushed-down predicate against one row, keeping the row whenever
/// the outcome cannot be decided here.
///
/// Only `column <op> literal`, `literal <op> column`, `column IS NULL` and
/// `column IS NOT NULL` are actually evaluated. Any other expression shape, an
/// unknown column name, a row with no value at the column position, or a
/// value/literal pair that cannot be compared all yield `true`, so this
/// function never rejects a row it could not evaluate.
fn eval_filter_on_row(expr: &Expr, row: &Row, schema: &arrow::datatypes::Schema) -> bool {
    use std::cmp::Ordering;

    // Schema position of the column named by a null-check operand, if any.
    let null_check_index = |operand: &Expr| match operand {
        Expr::Column(col) => schema.index_of(col.name.as_str()).ok(),
        _ => None,
    };
    // A missing cell counts as null, the same as an explicit `Value::Null`.
    let cell_is_null = |idx: usize| matches!(row.get_by_index(idx), Some(Value::Null) | None);

    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
            // Normalise both operand orders to (column, literal) and remember
            // whether the literal was written on the left-hand side.
            let (column, literal, literal_first) = match (left.as_ref(), right.as_ref()) {
                (Expr::Column(col), Expr::Literal(sv, _)) => (col, sv, false),
                (Expr::Literal(sv, _), Expr::Column(col)) => (col, sv, true),
                _ => return true,
            };
            let Ok(position) = schema.index_of(column.name.as_str()) else {
                return true;
            };
            let Some(cell) = row.get_by_index(position) else {
                return true;
            };
            let Some(cell_vs_literal) = compare_value_scalar(cell, literal) else {
                return true;
            };
            // The comparison above is `cell ? literal`; when the literal was the
            // left operand, the operator applies to the reversed ordering.
            let ordering = if literal_first {
                cell_vs_literal.reverse()
            } else {
                cell_vs_literal
            };
            match op {
                Operator::Eq => ordering == Ordering::Equal,
                Operator::NotEq => ordering != Ordering::Equal,
                Operator::Lt => ordering == Ordering::Less,
                Operator::LtEq => ordering != Ordering::Greater,
                Operator::Gt => ordering == Ordering::Greater,
                Operator::GtEq => ordering != Ordering::Less,
                _ => true,
            }
        }
        Expr::IsNull(inner) => match null_check_index(inner.as_ref()) {
            Some(idx) => cell_is_null(idx),
            None => true,
        },
        Expr::IsNotNull(inner) => match null_check_index(inner.as_ref()) {
            Some(idx) => !cell_is_null(idx),
            None => true,
        },
        _ => true,
    }
}
/// Orders a row value against a DataFusion literal, or returns `None` when the
/// pair is not one of the supported combinations.
///
/// Supported pairs: `I64` against the signed integer and float scalars, `F64`
/// against the float scalars, `Text` against `Utf8`/`LargeUtf8`, `Bool` against
/// `Boolean`, and `Null` against a null scalar of any of those types (reported
/// as `Equal`). Float comparisons that have no ordering (NaN) also give `None`.
fn compare_value_scalar(val: &Value, scalar: &ScalarValue) -> Option<std::cmp::Ordering> {
    use std::cmp::Ordering;

    // A null row value only pairs with a null literal; every other literal is
    // left undecided.
    if matches!(val, Value::Null) {
        let scalar_is_null = matches!(
            scalar,
            ScalarValue::Null
                | ScalarValue::Int64(None)
                | ScalarValue::Int32(None)
                | ScalarValue::Int16(None)
                | ScalarValue::Int8(None)
                | ScalarValue::Float64(None)
                | ScalarValue::Float32(None)
                | ScalarValue::Boolean(None)
                | ScalarValue::Utf8(None)
                | ScalarValue::LargeUtf8(None)
        );
        return scalar_is_null.then_some(Ordering::Equal);
    }

    match (val, scalar) {
        // Integer row value against integer literals, widened to i64.
        (Value::I64(lhs), ScalarValue::Int64(Some(rhs))) => lhs.partial_cmp(rhs),
        (Value::I64(lhs), ScalarValue::Int32(Some(rhs))) => lhs.partial_cmp(&i64::from(*rhs)),
        (Value::I64(lhs), ScalarValue::Int16(Some(rhs))) => lhs.partial_cmp(&i64::from(*rhs)),
        (Value::I64(lhs), ScalarValue::Int8(Some(rhs))) => lhs.partial_cmp(&i64::from(*rhs)),
        // Integer row value against float literals, compared as f64.
        (Value::I64(lhs), ScalarValue::Float64(Some(rhs))) => (*lhs as f64).partial_cmp(rhs),
        (Value::I64(lhs), ScalarValue::Float32(Some(rhs))) => {
            (*lhs as f64).partial_cmp(&f64::from(*rhs))
        }
        // Float row value against float literals.
        (Value::F64(lhs), ScalarValue::Float64(Some(rhs))) => lhs.partial_cmp(rhs),
        (Value::F64(lhs), ScalarValue::Float32(Some(rhs))) => lhs.partial_cmp(&f64::from(*rhs)),
        // Text and boolean.
        (
            Value::Text(lhs),
            ScalarValue::Utf8(Some(rhs)) | ScalarValue::LargeUtf8(Some(rhs)),
        ) => lhs.as_str().partial_cmp(rhs.as_str()),
        (Value::Bool(lhs), ScalarValue::Boolean(Some(rhs))) => lhs.partial_cmp(rhs),
        _ => None,
    }
}
#[async_trait]
impl TableProvider for OxiSqlTableProvider {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan(
&self,
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
_limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
let source_partitions: Vec<&[Row]> = if self.partitions.is_empty() {
vec![self.rows.as_slice()]
} else {
self.partitions.iter().map(|p| p.as_slice()).collect()
};
let schema_ref = Arc::clone(&self.schema);
let df_err =
|e: OxiSqlFusionError| datafusion::error::DataFusionError::External(Box::new(e));
let mut partitions: Vec<Vec<arrow::record_batch::RecordBatch>> =
Vec::with_capacity(source_partitions.len());
for slice in source_partitions {
let kept: Vec<Row> = if filters.is_empty() {
slice.to_vec()
} else {
slice
.iter()
.filter(|row| {
filters
.iter()
.all(|f| eval_filter_on_row(f, row, &schema_ref))
})
.cloned()
.collect()
};
let batch = rows_to_record_batch(kept, Arc::clone(&schema_ref)).map_err(df_err)?;
partitions.push(vec![batch]);
}
let exec = MemorySourceConfig::try_new_exec(
&partitions,
Arc::clone(&self.schema),
projection.cloned(),
)?;
Ok(exec as Arc<dyn ExecutionPlan>)
}
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> DFResult<Vec<TableProviderFilterPushDown>> {
Ok(filters
.iter()
.map(|f| {
if is_simple_filter(f) {
TableProviderFilterPushDown::Inexact
} else {
TableProviderFilterPushDown::Unsupported
}
})
.collect())
}
fn statistics(&self) -> Option<Statistics> {
let n_cols = self.schema.fields().len();
let col_stats: Vec<ColumnStatistics> = (0..n_cols)
.map(|_| ColumnStatistics::new_unknown())
.collect();
let mut stats = Statistics::default()
.with_num_rows(Precision::Exact(self.rows.len()))
.with_total_byte_size(Precision::Absent);
stats.column_statistics = col_stats;
Some(stats)
}
}
impl fmt::Display for OxiSqlTableProvider {
    /// Writes a one-line summary containing the row and column counts.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let row_total = self.rows.len();
        let column_total = self.schema.fields().len();
        write!(
            f,
            "OxiSqlTableProvider(rows={}, cols={})",
            row_total, column_total
        )
    }
}