polars_python/lazyframe/
visit.rs1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars_plan::plans::{to_aexpr, Context, IR};
5use polars_plan::prelude::expr_ir::ExprIR;
6use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource};
7use polars_utils::arena::{Arena, Node};
8use pyo3::prelude::*;
9use pyo3::types::{PyDict, PyList};
10
11use super::visitor::{expr_nodes, nodes};
12use super::PyLazyFrame;
13use crate::error::PyPolarsErr;
14use crate::{raise_err, PyExpr, Wrap};
15
16#[derive(Clone)]
17#[pyclass]
18pub struct PyExprIR {
19 #[pyo3(get)]
20 node: usize,
21 #[pyo3(get)]
22 output_name: String,
23}
24
25impl From<ExprIR> for PyExprIR {
26 fn from(value: ExprIR) -> Self {
27 Self {
28 node: value.node().0,
29 output_name: value.output_name().to_string(),
30 }
31 }
32}
33
34impl From<&ExprIR> for PyExprIR {
35 fn from(value: &ExprIR) -> Self {
36 Self {
37 node: value.node().0,
38 output_name: value.output_name().to_string(),
39 }
40 }
41}
42
/// Protocol version reported to Python by `NodeTraverser::version`.
// NOTE(review): presumably (major, minor) — confirm with the Python-side consumer.
type Version = (u16, u16);
44
45#[pyclass]
46pub struct NodeTraverser {
47 root: Node,
48 lp_arena: Arc<Mutex<Arena<IR>>>,
49 expr_arena: Arc<Mutex<Arena<AExpr>>>,
50 scratch: Vec<Node>,
51 expr_scratch: Vec<ExprIR>,
52 expr_mapping: Option<Vec<Node>>,
53}
54
55impl NodeTraverser {
56 const VERSION: Version = (5, 0);
61
62 pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
63 Self {
64 root,
65 lp_arena: Arc::new(Mutex::new(lp_arena)),
66 expr_arena: Arc::new(Mutex::new(expr_arena)),
67 scratch: vec![],
68 expr_scratch: vec![],
69 expr_mapping: None,
70 }
71 }
72
73 #[allow(clippy::type_complexity)]
74 pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
75 (self.lp_arena.clone(), self.expr_arena.clone())
76 }
77
78 fn fill_inputs(&mut self) {
79 let lp_arena = self.lp_arena.lock().unwrap();
80 let this_node = lp_arena.get(self.root);
81 self.scratch.clear();
82 this_node.copy_inputs(&mut self.scratch);
83 }
84
85 fn fill_expressions(&mut self) {
86 let lp_arena = self.lp_arena.lock().unwrap();
87 let this_node = lp_arena.get(self.root);
88 self.expr_scratch.clear();
89 this_node.copy_exprs(&mut self.expr_scratch);
90 }
91
92 fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
93 PyList::new(py, self.scratch.drain(..).map(|node| node.0))
94 }
95
96 fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
97 PyList::new(
98 py,
99 self.expr_scratch
100 .drain(..)
101 .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
102 )
103 }
104}
105
106#[pymethods]
107impl NodeTraverser {
108 fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
110 self.fill_expressions();
111 self.expr_to_list(py)
112 }
113
114 fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
116 self.fill_inputs();
117 self.scratch_to_list(py)
118 }
119
120 fn version(&self) -> Version {
122 NodeTraverser::VERSION
123 }
124
125 fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
127 let lp_arena = self.lp_arena.lock().unwrap();
128 let schema = lp_arena.get(self.root).schema(&lp_arena);
129 Wrap(&**schema).into_pyobject(py)
130 }
131
132 fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
134 let expr_node = Node(expr_node);
135 let lp_arena = self.lp_arena.lock().unwrap();
136 let schema = lp_arena.get(self.root).schema(&lp_arena);
137 let expr_arena = self.expr_arena.lock().unwrap();
138 let field = expr_arena
139 .get(expr_node)
140 .to_field(&schema, Context::Default, &expr_arena)
141 .map_err(PyPolarsErr::from)?;
142 Wrap(field.dtype).into_pyobject(py)
143 }
144
145 fn set_node(&mut self, node: usize) {
147 self.root = Node(node);
148 }
149
150 fn get_node(&mut self) -> usize {
152 self.root.0
153 }
154
155 fn set_udf(&mut self, function: PyObject) {
157 let mut lp_arena = self.lp_arena.lock().unwrap();
158 let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
159 let ir = IR::PythonScan {
160 options: PythonOptions {
161 scan_fn: Some(function.into()),
162 schema,
163 output_schema: None,
164 with_columns: None,
165 python_source: PythonScanSource::Cuda,
166 predicate: Default::default(),
167 n_rows: None,
168 },
169 };
170 lp_arena.replace(self.root, ir);
171 }
172
173 fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
174 let lp_arena = self.lp_arena.lock().unwrap();
175 let lp_node = lp_arena.get(self.root);
176 nodes::into_py(py, lp_node)
177 }
178
179 fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
180 let expr_arena = self.expr_arena.lock().unwrap();
181 let n = match &self.expr_mapping {
182 Some(mapping) => *mapping.get(node).unwrap(),
183 None => Node(node),
184 };
185 let expr = expr_arena.get(n);
186 expr_nodes::into_py(py, expr)
187 }
188
189 fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
192 let mut expr_arena = self.expr_arena.lock().unwrap();
193 Ok((
194 expressions
195 .into_iter()
196 .map(|e| {
197 to_aexpr(e.inner, &mut expr_arena)
198 .map_err(PyPolarsErr::from)
199 .map(|v| v.0)
200 })
201 .collect::<Result<_, PyPolarsErr>>()?,
202 expr_arena.len(),
203 ))
204 }
205
206 fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
210 if mapping.len() != self.expr_arena.lock().unwrap().len() {
211 raise_err!("Invalid mapping length", ComputeError);
212 }
213 self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
214 Ok(())
215 }
216
217 fn unset_expr_mapping(&mut self) {
219 self.expr_mapping = None;
220 }
221}
222
223#[pymethods]
224#[allow(clippy::should_implement_trait)]
225impl PyLazyFrame {
226 fn visit(&self) -> PyResult<NodeTraverser> {
227 let mut lp_arena = Arena::with_capacity(16);
228 let mut expr_arena = Arena::with_capacity(16);
229 let root = self
230 .ldf
231 .clone()
232 .optimize(&mut lp_arena, &mut expr_arena)
233 .map_err(PyPolarsErr::from)?;
234 Ok(NodeTraverser {
235 root,
236 lp_arena: Arc::new(Mutex::new(lp_arena)),
237 expr_arena: Arc::new(Mutex::new(expr_arena)),
238 scratch: vec![],
239 expr_scratch: vec![],
240 expr_mapping: None,
241 })
242 }
243}