use polars_core::error::{PolarsResult, polars_err};
use polars_expr::state::ExecutionState;
use polars_mem_engine::create_physical_plan;
use polars_plan::plans::{AExpr, IR, IRPlan};
use polars_plan::prelude::{Arena, Node};
use polars_utils::pl_serialize;
use pyo3::intern;
use pyo3::prelude::{PyAnyMethods, PyModule, Python, *};
use pyo3::types::IntoPyDict;
use crate::PyDataFrame;
use crate::error::PyPolarsErr;
use crate::lazyframe::visit::NodeTraverser;
use crate::utils::EnterPolarsExt;
/// Run a serialized IR plan through the `cudf_polars` GPU post-optimization and execute it.
///
/// `ir_plan_ser` holds the serialized [`IRPlan`] bytes. The plan is deserialized, handed to
/// `gpu_post_opt` (which lets the Python `cudf_polars` package rewrite it), lowered with
/// `create_physical_plan`, and finally executed via `enter_polars_df`.
///
/// Any deserialization, post-optimization, or planning failure is converted into a Python
/// exception through `PyPolarsErr`.
#[pyfunction]
pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec<u8>, py: Python) -> PyResult<PyDataFrame> {
    // Rebuild the logical plan from its serialized byte form.
    let mut plan: IRPlan = pl_serialize::deserialize_from_reader::<_, _, false>(&ir_plan_ser[..])
        .map_err(PyPolarsErr::from)?;
    let root = plan.lp_top;

    // Give cudf_polars the chance to rewrite the arenas in place.
    gpu_post_opt(py, root, &mut plan.lp_arena, &mut plan.expr_arena)
        .map_err(PyPolarsErr::from)?;

    // Lower the (possibly rewritten) plan and execute it.
    let mut executor = create_physical_plan(root, &mut plan.lp_arena, &mut plan.expr_arena, None)
        .map_err(PyPolarsErr::from)?;
    let mut exec_state = ExecutionState::new();
    py.enter_polars_df(|| executor.execute(&mut exec_state))
}
fn gpu_post_opt(
py: Python<'_>,
root: Node,
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<()> {
let cudf = PyModule::import(py, intern!(py, "cudf_polars")).unwrap();
let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap();
let polars = PyModule::import(py, intern!(py, "polars")).unwrap();
let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap();
let kwargs = [("raise_on_fail", true)].into_py_dict(py).unwrap();
let engine = engine.call((), Some(&kwargs)).unwrap();
let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena));
let arenas = nt.get_arenas();
let kwargs = [("config", engine)].into_py_dict(py).unwrap();
lambda
.call((nt,), Some(&kwargs))
.map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?;
std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap());
std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap());
Ok(())
}