1use polars_core::error::{polars_err, PolarsResult};
2use polars_expr::state::ExecutionState;
3use polars_mem_engine::create_physical_plan;
4use polars_plan::plans::{AExpr, IRPlan, IR};
5use polars_plan::prelude::{Arena, Node};
6use polars_utils::pl_serialize;
7use pyo3::intern;
8use pyo3::prelude::{PyAnyMethods, PyModule, Python, *};
9use pyo3::types::{IntoPyDict, PyBytes};
10
11use crate::error::PyPolarsErr;
12use crate::lazyframe::visit::NodeTraverser;
13use crate::{PyDataFrame, PyLazyFrame};
14
15#[pyfunction]
16pub fn prepare_cloud_plan(lf: PyLazyFrame, py: Python<'_>) -> PyResult<Bound<'_, PyBytes>> {
17 let plan = lf.ldf.logical_plan;
18 let bytes = polars::prelude::prepare_cloud_plan(plan).map_err(PyPolarsErr::from)?;
19
20 Ok(PyBytes::new(py, &bytes))
21}
22
23#[pyfunction]
28pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec<u8>, py: Python) -> PyResult<PyDataFrame> {
29 let mut ir_plan: IRPlan =
31 pl_serialize::deserialize_from_reader(ir_plan_ser.as_slice()).map_err(PyPolarsErr::from)?;
32
33 gpu_post_opt(
35 py,
36 ir_plan.lp_top,
37 &mut ir_plan.lp_arena,
38 &mut ir_plan.expr_arena,
39 )
40 .map_err(PyPolarsErr::from)?;
41
42 let mut physical_plan =
44 create_physical_plan(ir_plan.lp_top, &mut ir_plan.lp_arena, &ir_plan.expr_arena)
45 .map_err(PyPolarsErr::from)?;
46
47 let mut state = ExecutionState::new();
49 let df = py.allow_threads(|| physical_plan.execute(&mut state).map_err(PyPolarsErr::from))?;
50
51 Ok(df.into())
52}
53
54fn gpu_post_opt(
56 py: Python,
57 root: Node,
58 lp_arena: &mut Arena<IR>,
59 expr_arena: &mut Arena<AExpr>,
60) -> PolarsResult<()> {
61 let cudf = PyModule::import(py, intern!(py, "cudf_polars")).unwrap();
63 let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap();
64
65 let polars = PyModule::import(py, intern!(py, "polars")).unwrap();
67 let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap();
68 let kwargs = [("raise_on_fail", true)].into_py_dict(py).unwrap();
69 let engine = engine.call((), Some(&kwargs)).unwrap();
70
71 let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena));
73
74 let arenas = nt.get_arenas();
76
77 let kwargs = [("config", engine)].into_py_dict(py).unwrap();
80 lambda
81 .call((nt,), Some(&kwargs))
82 .map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?;
83
84 std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap());
87 std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap());
88
89 Ok(())
90}