use std::any::Any;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableSource};
use pyo3::prelude::*;
use datafusion_optimizer::utils::split_conjunction;
use super::{data_type::DataTypeMap, function::SqlFunction};
/// A named schema grouping tables, views and functions.
///
/// Exposed to Python as `SqlSchema` in the `datafusion.common` module and
/// declared with `subclass`. Every field carries `#[pyo3(get, set)]`, so each
/// one is readable and writable as a Python attribute.
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlSchema {
/// Name of the schema.
#[pyo3(get, set)]
pub name: String,
/// Tables held by this schema; searched by name in `table_by_name`.
#[pyo3(get, set)]
pub tables: Vec<SqlTable>,
/// Views held by this schema.
#[pyo3(get, set)]
pub views: Vec<SqlView>,
/// Functions held by this schema.
#[pyo3(get, set)]
pub functions: Vec<SqlFunction>,
}
/// Description of a single table: its name, columns, key/index/constraint
/// metadata, statistics and an optional file path.
///
/// Exposed to Python as `SqlTable` in the `datafusion.common` module and
/// declared with `subclass`. Every field carries `#[pyo3(get, set)]`.
#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlTable {
/// Name of the table.
#[pyo3(get, set)]
pub name: String,
/// Column list as `(String, DataTypeMap)` pairs.
// NOTE(review): presumably (column name, column type) - confirm.
#[pyo3(get, set)]
pub columns: Vec<(String, DataTypeMap)>,
/// Optional primary key; `SqlTable::new` initialises it to `None`.
// NOTE(review): the expected string format is not established in this file.
#[pyo3(get, set)]
pub primary_key: Option<String>,
/// Foreign keys as strings; `SqlTable::new` initialises this to empty.
#[pyo3(get, set)]
pub foreign_keys: Vec<String>,
/// Indexes as strings; `SqlTable::new` initialises this to empty.
#[pyo3(get, set)]
pub indexes: Vec<String>,
/// Constraints as strings; `SqlTable::new` initialises this to empty.
#[pyo3(get, set)]
pub constraints: Vec<String>,
/// Statistics for the table (currently just a row count).
#[pyo3(get, set)]
pub statistics: SqlStatistics,
/// Optional file path associated with the table.
// NOTE(review): presumably the location of the backing data - confirm.
#[pyo3(get, set)]
pub filepath: Option<String>,
}
#[pymethods]
impl SqlTable {
#[new]
pub fn new(
table_name: String,
columns: Vec<(String, DataTypeMap)>,
row_count: f64,
filepath: Option<String>,
) -> Self {
Self {
name: table_name,
columns,
primary_key: None,
foreign_keys: Vec::new(),
indexes: Vec::new(),
constraints: Vec::new(),
statistics: SqlStatistics::new(row_count),
filepath,
}
}
}
/// A named view together with its definition.
///
/// Exposed to Python as `SqlView` in the `datafusion.common` module and
/// declared with `subclass`. Both fields carry `#[pyo3(get, set)]`.
#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlView {
/// Name of the view.
#[pyo3(get, set)]
pub name: String,
/// Definition of the view, stored as a string.
// NOTE(review): presumably the SQL text of the view - confirm.
#[pyo3(get, set)]
pub definition: String, }
#[pymethods]
impl SqlSchema {
#[new]
pub fn new(schema_name: &str) -> Self {
Self {
name: schema_name.to_owned(),
tables: Vec::new(),
views: Vec::new(),
functions: Vec::new(),
}
}
pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
for tbl in &self.tables {
if tbl.name.eq(table_name) {
return Some(tbl.clone());
}
}
None
}
pub fn add_table(&mut self, table: SqlTable) {
self.tables.push(table);
}
pub fn drop_table(&mut self, table_name: String) {
self.tables.retain(|x| !x.name.eq(&table_name));
}
}
/// A table source that carries a schema plus optional statistics and an
/// optional file path. It implements `TableSource` further down in this file.
pub struct SqlTableSource {
// Schema returned from `TableSource::schema`.
schema: SchemaRef,
// Statistics for the table, when supplied to `new`.
statistics: Option<SqlStatistics>,
// File path for the table, when supplied to `new`.
filepath: Option<String>,
}
impl SqlTableSource {
pub fn new(
schema: SchemaRef,
statistics: Option<SqlStatistics>,
filepath: Option<String>,
) -> Self {
Self {
schema,
statistics,
filepath,
}
}
pub fn statistics(&self) -> Option<&SqlStatistics> {
self.statistics.as_ref()
}
#[allow(dead_code)]
pub fn filepath(&self) -> Option<&String> {
self.filepath.as_ref()
}
}
impl TableSource for SqlTableSource {
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn supports_filter_pushdown(
&self,
filter: &Expr,
) -> datafusion_common::Result<TableProviderFilterPushDown> {
let filters = split_conjunction(filter);
if filters.iter().all(|f| is_supported_push_down_expr(f)) {
Ok(TableProviderFilterPushDown::Exact)
} else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
Ok(TableProviderFilterPushDown::Inexact)
} else {
Ok(TableProviderFilterPushDown::Unsupported)
}
}
fn table_type(&self) -> datafusion_expr::TableType {
datafusion_expr::TableType::Base
}
#[allow(deprecated)]
fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> datafusion_common::Result<Vec<TableProviderFilterPushDown>> {
filters
.iter()
.map(|f| self.supports_filter_pushdown(f))
.collect()
}
fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> {
None
}
}
/// Reports whether a single filter expression can be pushed down to the
/// table source.
///
/// Currently every expression is accepted unconditionally: the argument is
/// ignored and the result is always `true`.
fn is_supported_push_down_expr(_expr: &Expr) -> bool {
true
}
/// Table statistics; at present this holds only a row count.
///
/// Exposed to Python as `SqlStatistics` in the `datafusion.common` module and
/// declared with `subclass`.
#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
#[derive(Debug, Clone)]
pub struct SqlStatistics {
// Row count as passed to `new`. The field has no `#[pyo3(get, set)]`, so it
// is read through `get_row_count` rather than as an attribute.
// NOTE(review): stored as `f64`, presumably to allow estimated counts - confirm.
row_count: f64,
}
#[pymethods]
impl SqlStatistics {
#[new]
pub fn new(row_count: f64) -> Self {
Self { row_count }
}
#[pyo3(name = "getRowCount")]
pub fn get_row_count(&self) -> f64 {
self.row_count
}
}