polars_python/lazyframe/
visit.rs1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars::prelude::python_dsl::PythonScanSource;
5use polars_plan::plans::{Context, IR, to_aexpr};
6use polars_plan::prelude::expr_ir::ExprIR;
7use polars_plan::prelude::{AExpr, PythonOptions};
8use polars_utils::arena::{Arena, Node};
9use pyo3::prelude::*;
10use pyo3::types::{PyDict, PyList};
11
12use super::PyLazyFrame;
13use super::visitor::{expr_nodes, nodes};
14use crate::error::PyPolarsErr;
15use crate::{PyExpr, Wrap, raise_err};
16
17#[derive(Clone)]
18#[pyclass]
19pub struct PyExprIR {
20 #[pyo3(get)]
21 node: usize,
22 #[pyo3(get)]
23 output_name: String,
24}
25
26impl From<ExprIR> for PyExprIR {
27 fn from(value: ExprIR) -> Self {
28 Self {
29 node: value.node().0,
30 output_name: value.output_name().to_string(),
31 }
32 }
33}
34
35impl From<&ExprIR> for PyExprIR {
36 fn from(value: &ExprIR) -> Self {
37 Self {
38 node: value.node().0,
39 output_name: value.output_name().to_string(),
40 }
41 }
42}
43
/// Pair of `u16` values used as the traverser's version (see `NodeTraverser::VERSION`).
// NOTE(review): presumably (major, minor) — confirm against the consumers of `version()`.
type Version = (u16, u16);
45
46#[pyclass]
47pub struct NodeTraverser {
48 root: Node,
49 lp_arena: Arc<Mutex<Arena<IR>>>,
50 expr_arena: Arc<Mutex<Arena<AExpr>>>,
51 scratch: Vec<Node>,
52 expr_scratch: Vec<ExprIR>,
53 expr_mapping: Option<Vec<Node>>,
54}
55
56impl NodeTraverser {
57 const VERSION: Version = (7, 0);
62
63 pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
64 Self {
65 root,
66 lp_arena: Arc::new(Mutex::new(lp_arena)),
67 expr_arena: Arc::new(Mutex::new(expr_arena)),
68 scratch: vec![],
69 expr_scratch: vec![],
70 expr_mapping: None,
71 }
72 }
73
74 #[allow(clippy::type_complexity)]
75 pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
76 (self.lp_arena.clone(), self.expr_arena.clone())
77 }
78
79 fn fill_inputs(&mut self) {
80 let lp_arena = self.lp_arena.lock().unwrap();
81 let this_node = lp_arena.get(self.root);
82 self.scratch.clear();
83 this_node.copy_inputs(&mut self.scratch);
84 }
85
86 fn fill_expressions(&mut self) {
87 let lp_arena = self.lp_arena.lock().unwrap();
88 let this_node = lp_arena.get(self.root);
89 self.expr_scratch.clear();
90 this_node.copy_exprs(&mut self.expr_scratch);
91 }
92
93 fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
94 PyList::new(py, self.scratch.drain(..).map(|node| node.0))
95 }
96
97 fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
98 PyList::new(
99 py,
100 self.expr_scratch
101 .drain(..)
102 .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
103 )
104 }
105}
106
107#[pymethods]
108impl NodeTraverser {
109 fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
111 self.fill_expressions();
112 self.expr_to_list(py)
113 }
114
115 fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
117 self.fill_inputs();
118 self.scratch_to_list(py)
119 }
120
121 fn version(&self) -> Version {
123 NodeTraverser::VERSION
124 }
125
126 fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
128 let lp_arena = self.lp_arena.lock().unwrap();
129 let schema = lp_arena.get(self.root).schema(&lp_arena);
130 Wrap(&**schema).into_pyobject(py)
131 }
132
133 fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
135 let expr_node = Node(expr_node);
136 let lp_arena = self.lp_arena.lock().unwrap();
137 let schema = lp_arena.get(self.root).schema(&lp_arena);
138 let expr_arena = self.expr_arena.lock().unwrap();
139 let field = expr_arena
140 .get(expr_node)
141 .to_field(&schema, Context::Default, &expr_arena)
142 .map_err(PyPolarsErr::from)?;
143 Wrap(field.dtype).into_pyobject(py)
144 }
145
146 fn set_node(&mut self, node: usize) {
148 self.root = Node(node);
149 }
150
151 fn get_node(&mut self) -> usize {
153 self.root.0
154 }
155
156 fn set_udf(&mut self, function: PyObject) {
158 let mut lp_arena = self.lp_arena.lock().unwrap();
159 let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
160 let ir = IR::PythonScan {
161 options: PythonOptions {
162 scan_fn: Some(function.into()),
163 schema,
164 output_schema: None,
165 with_columns: None,
166 python_source: PythonScanSource::Cuda,
167 predicate: Default::default(),
168 n_rows: None,
169 validate_schema: false,
170 },
171 };
172 lp_arena.replace(self.root, ir);
173 }
174
175 fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
176 let lp_arena = self.lp_arena.lock().unwrap();
177 let lp_node = lp_arena.get(self.root);
178 nodes::into_py(py, lp_node)
179 }
180
181 fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
182 let expr_arena = self.expr_arena.lock().unwrap();
183 let n = match &self.expr_mapping {
184 Some(mapping) => *mapping.get(node).unwrap(),
185 None => Node(node),
186 };
187 let expr = expr_arena.get(n);
188 expr_nodes::into_py(py, expr)
189 }
190
191 fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
194 let mut expr_arena = self.expr_arena.lock().unwrap();
195 Ok((
196 expressions
197 .into_iter()
198 .map(|e| {
199 to_aexpr(e.inner, &mut expr_arena)
200 .map_err(PyPolarsErr::from)
201 .map(|v| v.0)
202 })
203 .collect::<Result<_, PyPolarsErr>>()?,
204 expr_arena.len(),
205 ))
206 }
207
208 fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
212 if mapping.len() != self.expr_arena.lock().unwrap().len() {
213 raise_err!("Invalid mapping length", ComputeError);
214 }
215 self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
216 Ok(())
217 }
218
219 fn unset_expr_mapping(&mut self) {
221 self.expr_mapping = None;
222 }
223}
224
225#[pymethods]
226#[allow(clippy::should_implement_trait)]
227impl PyLazyFrame {
228 fn visit(&self) -> PyResult<NodeTraverser> {
229 let mut lp_arena = Arena::with_capacity(16);
230 let mut expr_arena = Arena::with_capacity(16);
231 let root = self
232 .ldf
233 .clone()
234 .optimize(&mut lp_arena, &mut expr_arena)
235 .map_err(PyPolarsErr::from)?;
236 Ok(NodeTraverser {
237 root,
238 lp_arena: Arc::new(Mutex::new(lp_arena)),
239 expr_arena: Arc::new(Mutex::new(expr_arena)),
240 scratch: vec![],
241 expr_scratch: vec![],
242 expr_mapping: None,
243 })
244 }
245}