use pyo3::{prelude::*, types::PyBytes};
use crate::context::PySessionContext;
use crate::errors::{py_datafusion_err, DataFusionError};
use crate::sql::logical::PyLogicalPlan;
use crate::utils::wait_for_future;
use datafusion_substrait::logical_plan::{consumer, producer};
use datafusion_substrait::serializer;
use datafusion_substrait::substrait::proto::Plan;
use prost::Message;
/// Python wrapper around a Substrait [`Plan`], exposed as `datafusion.substrait.plan`.
#[pyclass(name = "plan", module = "datafusion.substrait", subclass)]
#[derive(Debug, Clone)]
pub(crate) struct PyPlan {
    /// The wrapped Substrait protobuf plan.
    pub(crate) plan: Plan,
}

#[pymethods]
impl PyPlan {
    /// Encodes the plan as protobuf bytes and returns them as a Python `bytes` object.
    ///
    /// # Errors
    /// Returns an error if protobuf encoding fails; the prost error is mapped through
    /// `DataFusionError::EncodeError`.
    fn encode(&self, py: Python) -> PyResult<PyObject> {
        // Size the buffer from the message's exact encoded length so encoding never has to
        // grow (and re-copy) the vector.
        let mut proto_bytes = Vec::<u8>::with_capacity(self.plan.encoded_len());
        self.plan
            .encode(&mut proto_bytes)
            .map_err(DataFusionError::EncodeError)?;
        Ok(PyBytes::new(py, &proto_bytes).into())
    }
}
impl From<PyPlan> for Plan {
fn from(plan: PyPlan) -> Plan {
plan.plan
}
}
impl From<Plan> for PyPlan {
fn from(plan: Plan) -> PyPlan {
PyPlan { plan }
}
}
/// Helpers for turning SQL into Substrait plans and (de)serializing them, exposed as
/// `datafusion.substrait.serde`.
#[pyclass(name = "serde", module = "datafusion.substrait", subclass)]
#[derive(Debug, Clone)]
pub(crate) struct PySubstraitSerializer;

#[pymethods]
impl PySubstraitSerializer {
    /// Serializes the plan for `sql` (planned against `ctx`) to `path` via
    /// `serializer::serialize`.
    ///
    /// # Errors
    /// Propagates any planning, conversion, or I/O error as a Python exception.
    #[staticmethod]
    pub fn serialize(sql: &str, ctx: PySessionContext, path: &str, py: Python) -> PyResult<()> {
        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))
            .map_err(DataFusionError::from)?;
        Ok(())
    }

    /// Plans `sql` against `ctx` and returns the resulting Substrait plan.
    ///
    /// The encoded bytes are decoded straight from the Rust buffer. They are not
    /// round-tripped through a Python `bytes` object, which avoids an extra copy and a
    /// panicking downcast. Errors are converted once, so an error that is already a Python
    /// exception is not re-wrapped.
    ///
    /// # Errors
    /// Propagates planning/serialization errors and protobuf decoding errors.
    #[staticmethod]
    pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyPlan> {
        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))
            .map_err(DataFusionError::from)?;
        PySubstraitSerializer::deserialize_bytes(proto_bytes, py)
    }

    /// Plans `sql` against `ctx` and returns the encoded Substrait plan as Python `bytes`.
    ///
    /// # Errors
    /// Propagates any planning or serialization error as a Python exception.
    #[staticmethod]
    pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyObject> {
        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))
            .map_err(DataFusionError::from)?;
        Ok(PyBytes::new(py, &proto_bytes).into())
    }

    /// Deserializes a Substrait plan from `path` via `serializer::deserialize`.
    ///
    /// # Errors
    /// Propagates any read or decoding error as a Python exception.
    #[staticmethod]
    pub fn deserialize(path: &str, py: Python) -> PyResult<PyPlan> {
        let plan =
            wait_for_future(py, serializer::deserialize(path)).map_err(DataFusionError::from)?;
        Ok(PyPlan { plan: *plan })
    }

    /// Decodes a Substrait plan from its protobuf-encoded bytes.
    ///
    /// # Errors
    /// Propagates any decoding error as a Python exception.
    #[staticmethod]
    pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyResult<PyPlan> {
        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))
            .map_err(DataFusionError::from)?;
        Ok(PyPlan { plan: *plan })
    }
}
#[pyclass(name = "producer", module = "datafusion.substrait", subclass)]
#[derive(Debug, Clone)]
pub(crate) struct PySubstraitProducer;
#[pymethods]
impl PySubstraitProducer {
#[staticmethod]
pub fn to_substrait_plan(plan: PyLogicalPlan, ctx: &PySessionContext) -> PyResult<PyPlan> {
match producer::to_substrait_plan(&plan.plan, &ctx.ctx) {
Ok(plan) => Ok(PyPlan { plan: *plan }),
Err(e) => Err(py_datafusion_err(e)),
}
}
}
#[pyclass(name = "consumer", module = "datafusion.substrait", subclass)]
#[derive(Debug, Clone)]
pub(crate) struct PySubstraitConsumer;
#[pymethods]
impl PySubstraitConsumer {
#[staticmethod]
pub fn from_substrait_plan(
ctx: &mut PySessionContext,
plan: PyPlan,
py: Python,
) -> PyResult<PyLogicalPlan> {
let result = consumer::from_substrait_plan(&mut ctx.ctx, &plan.plan);
let logical_plan = wait_for_future(py, result).map_err(DataFusionError::from)?;
Ok(PyLogicalPlan::new(logical_plan))
}
}
/// Registers the Substrait wrapper classes on module `m`, stopping at the first failure.
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
    m.add_class::<PyPlan>()
        .and_then(|()| m.add_class::<PySubstraitConsumer>())
        .and_then(|()| m.add_class::<PySubstraitProducer>())
        .and_then(|()| m.add_class::<PySubstraitSerializer>())
}