use std::any::Any;
use std::sync::Arc;
use async_trait::async_trait;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::{DataFusionError, Result as DFResult};
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyType;
use crate::dataset_exec::DatasetExec;
use crate::pyarrow_filter_expression::PyArrowFilterExpression;
/// Wrapper around a Python object that is exposed to DataFusion as a table.
///
/// The `new` constructor in this file only accepts objects that are instances
/// of `pyarrow.dataset.Dataset`, and the `TableProvider` implementation in this
/// file reads the object's `schema` attribute and hands the object to
/// `DatasetExec` when a scan is requested.
#[derive(Debug)]
pub(crate) struct Dataset {
/// Owned (unbound) reference to the Python dataset object.
dataset: Py<PyAny>,
}
impl Dataset {
    /// Wraps `dataset` after checking that it is a `pyarrow.dataset.Dataset`.
    ///
    /// The `pyarrow.dataset` module is imported on every call and its
    /// `Dataset` attribute is used as the class for the instance check.
    ///
    /// # Errors
    ///
    /// Returns a `PyValueError` when `dataset` is not an instance of
    /// `pyarrow.dataset.Dataset`. Errors raised while importing
    /// `pyarrow.dataset`, reading its `Dataset` attribute, casting that
    /// attribute to a Python type, or running the instance check are
    /// propagated to the caller.
    pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
        let dataset_class = PyModule::import(py, "pyarrow.dataset")?.getattr("Dataset")?;
        let expected_type = dataset_class.cast::<PyType>()?;

        // Guard clause: reject anything that is not a pyarrow dataset.
        if !dataset.is_instance(expected_type)? {
            return Err(PyValueError::new_err(
                "dataset argument must be a pyarrow.dataset.Dataset object",
            ));
        }

        Ok(Self {
            dataset: dataset.clone().unbind(),
        })
    }
}
#[async_trait]
impl TableProvider for Dataset {
    /// Returns this provider as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Reads the `schema` attribute of the wrapped Python object and returns
    /// it as an Arrow schema.
    ///
    /// # Panics
    ///
    /// Panics if the `schema` attribute cannot be read or cannot be extracted
    /// through `PyArrowType`; the trait signature leaves no way to return an
    /// error from here.
    fn schema(&self) -> SchemaRef {
        Python::attach(|py| {
            let py_schema = self.dataset.bind(py).getattr("schema").unwrap();
            let arrow_schema = py_schema.extract::<PyArrowType<_>>().unwrap().0;
            Arc::new(arrow_schema)
        })
    }

    /// Always reports `TableType::Base`.
    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Builds a `DatasetExec` plan over the wrapped Python object.
    ///
    /// `projection` (cloned) and `filters` are forwarded to
    /// `DatasetExec::new`. The session and the limit are not used.
    ///
    /// # Errors
    ///
    /// Any error returned by `DatasetExec::new` is boxed into
    /// `DataFusionError::External`.
    async fn scan(
        &self,
        _ctx: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        _limit: Option<usize>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Python::attach(|py| {
            let py_dataset = self.dataset.bind(py);
            DatasetExec::new(py, py_dataset, projection.cloned(), filters)
                .map(|exec| Arc::new(exec) as Arc<dyn ExecutionPlan>)
                .map_err(|err| DataFusionError::External(Box::new(err)))
        })
    }

    /// Reports, per filter and in input order, whether it can be pushed down.
    ///
    /// A filter that converts into a `PyArrowFilterExpression` is reported as
    /// `Exact`; one whose conversion fails is reported as `Unsupported`. The
    /// conversion result itself is discarded, and this method never returns
    /// an error.
    fn supports_filters_pushdown(
        &self,
        filter: &[&Expr],
    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
        let support = filter
            .iter()
            .map(|&expr| {
                if PyArrowFilterExpression::try_from(expr).is_ok() {
                    TableProviderFilterPushDown::Exact
                } else {
                    TableProviderFilterPushDown::Unsupported
                }
            })
            .collect::<Vec<_>>();
        Ok(support)
    }
}