1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars::prelude::python_dsl::PythonScanSource;
5use polars_plan::plans::{ExprToIRContext, IR, to_expr_ir};
6use polars_plan::prelude::expr_ir::ExprIR;
7use polars_plan::prelude::{AExpr, PythonOptions};
8use polars_utils::arena::{Arena, Node};
9use pyo3::prelude::*;
10use pyo3::types::{PyDict, PyList};
11
12use super::PyLazyFrame;
13use super::visitor::{expr_nodes, nodes};
14use crate::error::PyPolarsErr;
15use crate::{PyExpr, Wrap, raise_err};
16
17#[derive(Clone)]
18#[pyclass(frozen)]
19pub struct PyExprIR {
20 #[pyo3(get)]
21 node: usize,
22 #[pyo3(get)]
23 output_name: String,
24}
25
26impl From<ExprIR> for PyExprIR {
27 fn from(value: ExprIR) -> Self {
28 Self {
29 node: value.node().0,
30 output_name: value.output_name().to_string(),
31 }
32 }
33}
34
35impl From<&ExprIR> for PyExprIR {
36 fn from(value: &ExprIR) -> Self {
37 Self {
38 node: value.node().0,
39 output_name: value.output_name().to_string(),
40 }
41 }
42}
43
/// Pair of `u16` components returned by `NodeTraverser::version`.
// NOTE(review): presumably `(major, minor)` — confirm against the consumer.
type Version = (u16, u16);
45
46#[pyclass]
47pub struct NodeTraverser {
48 root: Node,
49 lp_arena: Arc<Mutex<Arena<IR>>>,
50 expr_arena: Arc<Mutex<Arena<AExpr>>>,
51 scratch: Vec<Node>,
52 expr_scratch: Vec<ExprIR>,
53 expr_mapping: Option<Vec<Node>>,
54}
55
56impl NodeTraverser {
57 const VERSION: Version = (10, 0);
62
63 pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
64 Self {
65 root,
66 lp_arena: Arc::new(Mutex::new(lp_arena)),
67 expr_arena: Arc::new(Mutex::new(expr_arena)),
68 scratch: vec![],
69 expr_scratch: vec![],
70 expr_mapping: None,
71 }
72 }
73
74 #[allow(clippy::type_complexity)]
75 pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
76 (self.lp_arena.clone(), self.expr_arena.clone())
77 }
78
79 fn fill_inputs(&mut self) {
80 let lp_arena = self.lp_arena.lock().unwrap();
81 let this_node = lp_arena.get(self.root);
82 self.scratch.clear();
83 this_node.copy_inputs(&mut self.scratch);
84 }
85
86 fn fill_expressions(&mut self) {
87 let lp_arena = self.lp_arena.lock().unwrap();
88 let this_node = lp_arena.get(self.root);
89 self.expr_scratch.clear();
90 this_node.copy_exprs(&mut self.expr_scratch);
91 }
92
93 fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
94 PyList::new(py, self.scratch.drain(..).map(|node| node.0))
95 }
96
97 fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
98 PyList::new(
99 py,
100 self.expr_scratch
101 .drain(..)
102 .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
103 )
104 }
105}
106
107#[pymethods]
108impl NodeTraverser {
109 fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
111 self.fill_expressions();
112 self.expr_to_list(py)
113 }
114
115 fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
117 self.fill_inputs();
118 self.scratch_to_list(py)
119 }
120
121 fn version(&self) -> Version {
123 NodeTraverser::VERSION
124 }
125
126 fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
128 let lp_arena = self.lp_arena.lock().unwrap();
129 let schema = lp_arena.get(self.root).schema(&lp_arena);
130 Wrap((**schema).clone()).into_pyobject(py)
131 }
132
133 fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
135 let expr_node = Node(expr_node);
136 let lp_arena = self.lp_arena.lock().unwrap();
137 let schema = lp_arena.get(self.root).schema(&lp_arena);
138 let expr_arena = self.expr_arena.lock().unwrap();
139 let field = expr_arena
140 .get(expr_node)
141 .to_field(&schema, &expr_arena)
142 .map_err(PyPolarsErr::from)?;
143 Wrap(field.dtype).into_pyobject(py)
144 }
145
146 fn set_node(&mut self, node: usize) {
148 self.root = Node(node);
149 }
150
151 fn get_node(&mut self) -> usize {
153 self.root.0
154 }
155
156 #[pyo3(signature = (function, is_pure = false))]
158 fn set_udf(&mut self, function: PyObject, is_pure: bool) {
159 let mut lp_arena = self.lp_arena.lock().unwrap();
160 let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
161 let ir = IR::PythonScan {
162 options: PythonOptions {
163 scan_fn: Some(function.into()),
164 schema,
165 output_schema: None,
166 with_columns: None,
167 python_source: PythonScanSource::Cuda,
168 predicate: Default::default(),
169 n_rows: None,
170 validate_schema: false,
171 is_pure,
172 },
173 };
174 lp_arena.replace(self.root, ir);
175 }
176
177 fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
178 let lp_arena = self.lp_arena.lock().unwrap();
179 let lp_node = lp_arena.get(self.root);
180 nodes::into_py(py, lp_node)
181 }
182
183 fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
184 let expr_arena = self.expr_arena.lock().unwrap();
185 let n = match &self.expr_mapping {
186 Some(mapping) => *mapping.get(node).unwrap(),
187 None => Node(node),
188 };
189 let expr = expr_arena.get(n);
190 expr_nodes::into_py(py, expr)
191 }
192
193 fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
196 let lp_arena = self.lp_arena.lock().unwrap();
197 let schema = lp_arena.get(self.root).schema(&lp_arena);
198 let mut expr_arena = self.expr_arena.lock().unwrap();
199 Ok((
200 expressions
201 .into_iter()
202 .map(|e| {
203 let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
204 ctx.allow_unknown = true;
205 to_expr_ir(e.inner, &mut ctx)
207 .map_err(PyPolarsErr::from)
208 .map(|v| v.node().0)
209 })
210 .collect::<Result<_, PyPolarsErr>>()?,
211 expr_arena.len(),
212 ))
213 }
214
215 fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
219 if mapping.len() != self.expr_arena.lock().unwrap().len() {
220 raise_err!("Invalid mapping length", ComputeError);
221 }
222 self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
223 Ok(())
224 }
225
226 fn unset_expr_mapping(&mut self) {
228 self.expr_mapping = None;
229 }
230}
231
232#[pymethods]
233#[allow(clippy::should_implement_trait)]
234impl PyLazyFrame {
235 fn visit(&self) -> PyResult<NodeTraverser> {
236 let mut lp_arena = Arena::with_capacity(16);
237 let mut expr_arena = Arena::with_capacity(16);
238 let root = self
239 .ldf
240 .read()
241 .clone()
242 .optimize(&mut lp_arena, &mut expr_arena)
243 .map_err(PyPolarsErr::from)?;
244 Ok(NodeTraverser {
245 root,
246 lp_arena: Arc::new(Mutex::new(lp_arena)),
247 expr_arena: Arc::new(Mutex::new(expr_arena)),
248 scratch: vec![],
249 expr_scratch: vec![],
250 expr_mapping: None,
251 })
252 }
253}