1use polars_core::error::{PolarsResult, polars_err};
2use polars_expr::state::ExecutionState;
3use polars_mem_engine::create_physical_plan;
4use polars_plan::plans::{AExpr, IR, IRPlan};
5use polars_plan::prelude::{Arena, Node};
6use polars_utils::pl_serialize;
7use pyo3::intern;
8use pyo3::prelude::{PyAnyMethods, PyModule, Python, *};
9use pyo3::types::{IntoPyDict, PyBytes};
10
11use crate::error::PyPolarsErr;
12use crate::lazyframe::visit::NodeTraverser;
13use crate::utils::EnterPolarsExt;
14use crate::{PyDataFrame, PyLazyFrame};
15
16#[pyfunction]
17pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python<'_>) -> PyResult<Bound<'_, PyBytes>> {
18 let plan = lf.ldf.logical_plan;
19 let bytes = polars::prelude::prepare_cloud_plan(plan).map_err(PyPolarsErr::from)?;
20
21 Ok(PyBytes::new(py, &bytes))
22}
23
24#[pyfunction]
29pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec<u8>, py: Python) -> PyResult<PyDataFrame> {
30 let mut ir_plan: IRPlan =
32 pl_serialize::deserialize_from_reader::<_, _, false>(ir_plan_ser.as_slice())
33 .map_err(PyPolarsErr::from)?;
34
35 gpu_post_opt(
37 py,
38 ir_plan.lp_top,
39 &mut ir_plan.lp_arena,
40 &mut ir_plan.expr_arena,
41 )
42 .map_err(PyPolarsErr::from)?;
43
44 let mut physical_plan = create_physical_plan(
46 ir_plan.lp_top,
47 &mut ir_plan.lp_arena,
48 &mut ir_plan.expr_arena,
49 None,
50 )
51 .map_err(PyPolarsErr::from)?;
52
53 let mut state = ExecutionState::new();
55 py.enter_polars_df(|| physical_plan.execute(&mut state))
56}
57
58fn gpu_post_opt(
60 py: Python<'_>,
61 root: Node,
62 lp_arena: &mut Arena<IR>,
63 expr_arena: &mut Arena<AExpr>,
64) -> PolarsResult<()> {
65 let cudf = PyModule::import(py, intern!(py, "cudf_polars")).unwrap();
67 let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap();
68
69 let polars = PyModule::import(py, intern!(py, "polars")).unwrap();
71 let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap();
72 let kwargs = [("raise_on_fail", true)].into_py_dict(py).unwrap();
73 let engine = engine.call((), Some(&kwargs)).unwrap();
74
75 let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena));
77
78 let arenas = nt.get_arenas();
80
81 let kwargs = [("config", engine)].into_py_dict(py).unwrap();
84 lambda
85 .call((nt,), Some(&kwargs))
86 .map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?;
87
88 std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap());
91 std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap());
92
93 Ok(())
94}