use datafusion_python_util::wait_for_future;
use datafusion_substrait::logical_plan::{consumer, producer};
use datafusion_substrait::serializer;
use datafusion_substrait::substrait::proto::Plan;
use prost::Message;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use crate::context::PySessionContext;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::sql::logical::PyLogicalPlan;
/// Python-visible wrapper around a Substrait protobuf [`Plan`].
///
/// Registered with Python under the name `Plan` in the `datafusion.substrait`
/// module (see the `name` / `module` arguments of `#[pyclass]` below).
#[pyclass(
from_py_object,
frozen,
name = "Plan",
module = "datafusion.substrait",
subclass
)]
#[derive(Debug, Clone)]
pub struct PyPlan {
/// The wrapped Substrait plan message.
pub plan: Plan,
}
#[pymethods]
impl PyPlan {
fn encode(&self, py: Python) -> PyResult<Py<PyAny>> {
let mut proto_bytes = Vec::<u8>::new();
self.plan
.encode(&mut proto_bytes)
.map_err(PyDataFusionError::EncodeError)?;
Ok(PyBytes::new(py, &proto_bytes).into())
}
fn to_json(&self) -> PyDataFusionResult<String> {
let json = serde_json::to_string_pretty(&self.plan).map_err(to_datafusion_err)?;
Ok(json)
}
#[staticmethod]
fn from_json(json: &str) -> PyDataFusionResult<PyPlan> {
let plan: Plan = serde_json::from_str(json).map_err(to_datafusion_err)?;
Ok(PyPlan { plan })
}
}
impl From<PyPlan> for Plan {
fn from(plan: PyPlan) -> Plan {
plan.plan
}
}
impl From<Plan> for PyPlan {
fn from(plan: Plan) -> PyPlan {
PyPlan { plan }
}
}
/// Stateless holder for the Substrait serialization helpers.
///
/// Registered with Python under the name `Serde` in the `datafusion.substrait`
/// module; all of its methods are static.
#[pyclass(
from_py_object,
frozen,
name = "Serde",
module = "datafusion.substrait",
subclass
)]
#[derive(Debug, Clone)]
pub struct PySubstraitSerializer;
#[pymethods]
impl PySubstraitSerializer {
    /// Serializes the plan for `sql` (planned against `ctx`) to the file at
    /// `path` by delegating to `datafusion_substrait::serializer::serialize`.
    ///
    /// The async call is driven to completion with `wait_for_future`; an
    /// error from either the wait itself or the serializer is returned.
    #[staticmethod]
    pub fn serialize(
        sql: &str,
        ctx: PySessionContext,
        path: &str,
        py: Python,
    ) -> PyDataFusionResult<()> {
        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))??;
        Ok(())
    }

    /// Serializes the plan for `sql` to Substrait bytes and decodes those
    /// bytes back into a [`PyPlan`].
    ///
    /// The encoded bytes stay in Rust for the whole round trip: they are not
    /// wrapped in a Python `bytes` object and cast back, so there is no
    /// `unwrap` on a Python type cast and no extra copy of the buffer.
    #[staticmethod]
    pub fn serialize_to_plan(
        sql: &str,
        ctx: PySessionContext,
        py: Python,
    ) -> PyDataFusionResult<PyPlan> {
        let proto_bytes: Vec<u8> =
            wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
        Ok(PyPlan { plan: *plan })
    }

    /// Serializes the plan for `sql` to Substrait bytes and returns them as a
    /// Python `bytes` object.
    #[staticmethod]
    pub fn serialize_bytes(
        sql: &str,
        ctx: PySessionContext,
        py: Python,
    ) -> PyDataFusionResult<Py<PyAny>> {
        let proto_bytes: Vec<u8> =
            wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
        Ok(PyBytes::new(py, &proto_bytes).into())
    }

    /// Loads a plan from `path` by delegating to
    /// `datafusion_substrait::serializer::deserialize`.
    #[staticmethod]
    pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
        let plan = wait_for_future(py, serializer::deserialize(path))??;
        // The serializer hands back a boxed plan; move it out of the box.
        Ok(PyPlan { plan: *plan })
    }

    /// Decodes a plan from `proto_bytes` by delegating to
    /// `datafusion_substrait::serializer::deserialize_bytes`.
    #[staticmethod]
    pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
        // The serializer hands back a boxed plan; move it out of the box.
        Ok(PyPlan { plan: *plan })
    }
}
/// Stateless holder for the logical-plan-to-Substrait conversion.
///
/// Registered with Python under the name `Producer` in the
/// `datafusion.substrait` module.
#[pyclass(
from_py_object,
frozen,
name = "Producer",
module = "datafusion.substrait",
subclass
)]
#[derive(Debug, Clone)]
pub struct PySubstraitProducer;
#[pymethods]
impl PySubstraitProducer {
#[staticmethod]
pub fn to_substrait_plan(plan: PyLogicalPlan, ctx: &PySessionContext) -> PyResult<PyPlan> {
let session_state = ctx.ctx.state();
match producer::to_substrait_plan(&plan.plan, &session_state) {
Ok(plan) => Ok(PyPlan { plan: *plan }),
Err(e) => Err(py_datafusion_err(e)),
}
}
}
/// Stateless holder for the Substrait-to-logical-plan conversion.
///
/// Registered with Python under the name `Consumer` in the
/// `datafusion.substrait` module.
#[pyclass(
from_py_object,
frozen,
name = "Consumer",
module = "datafusion.substrait",
subclass
)]
#[derive(Debug, Clone)]
pub struct PySubstraitConsumer;
#[pymethods]
impl PySubstraitConsumer {
    /// Converts a Substrait plan into a DataFusion logical plan, using the
    /// session state taken from `ctx`.
    ///
    /// The async conversion is driven to completion with `wait_for_future`;
    /// an error from either the wait itself or the consumer is returned.
    #[staticmethod]
    pub fn from_substrait_plan(
        ctx: &PySessionContext,
        plan: PyPlan,
        py: Python,
    ) -> PyDataFusionResult<PyLogicalPlan> {
        let state = ctx.ctx.state();
        let converted = wait_for_future(py, consumer::from_substrait_plan(&state, &plan.plan))??;
        Ok(PyLogicalPlan::new(converted))
    }
}
/// Registers the Substrait wrapper classes on the Python module `m`.
///
/// Classes are added in a fixed order and registration stops at the first
/// failure, whose error is returned.
pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyPlan>()
        .and_then(|()| m.add_class::<PySubstraitConsumer>())
        .and_then(|()| m.add_class::<PySubstraitProducer>())
        .and_then(|()| m.add_class::<PySubstraitSerializer>())
}