polars_python/lazyframe/visit.rs

1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars_plan::plans::{to_aexpr, Context, IR};
5use polars_plan::prelude::expr_ir::ExprIR;
6use polars_plan::prelude::{AExpr, PythonOptions, PythonScanSource};
7use polars_utils::arena::{Arena, Node};
8use pyo3::prelude::*;
9use pyo3::types::{PyDict, PyList};
10
11use super::visitor::{expr_nodes, nodes};
12use super::PyLazyFrame;
13use crate::error::PyPolarsErr;
14use crate::{raise_err, PyExpr, Wrap};
15
16#[derive(Clone)]
17#[pyclass]
18pub struct PyExprIR {
19    #[pyo3(get)]
20    node: usize,
21    #[pyo3(get)]
22    output_name: String,
23}
24
25impl From<ExprIR> for PyExprIR {
26    fn from(value: ExprIR) -> Self {
27        Self {
28            node: value.node().0,
29            output_name: value.output_name().to_string(),
30        }
31    }
32}
33
34impl From<&ExprIR> for PyExprIR {
35    fn from(value: &ExprIR) -> Self {
36        Self {
37            node: value.node().0,
38            output_name: value.output_name().to_string(),
39        }
40    }
41}
42
/// IR version reported to Python, as a `(major, minor)` pair.
type Version = (u16 /* major */, u16 /* minor */);
44
45#[pyclass]
46pub struct NodeTraverser {
47    root: Node,
48    lp_arena: Arc<Mutex<Arena<IR>>>,
49    expr_arena: Arc<Mutex<Arena<AExpr>>>,
50    scratch: Vec<Node>,
51    expr_scratch: Vec<ExprIR>,
52    expr_mapping: Option<Vec<Node>>,
53}
54
55impl NodeTraverser {
56    // Versioning for IR, (major, minor)
57    // Increment major on breaking changes to the IR (e.g. renaming
58    // fields, reordering tuples), minor on backwards compatible
59    // changes (e.g. exposing a new expression node).
60    const VERSION: Version = (5, 0);
61
62    pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
63        Self {
64            root,
65            lp_arena: Arc::new(Mutex::new(lp_arena)),
66            expr_arena: Arc::new(Mutex::new(expr_arena)),
67            scratch: vec![],
68            expr_scratch: vec![],
69            expr_mapping: None,
70        }
71    }
72
73    #[allow(clippy::type_complexity)]
74    pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
75        (self.lp_arena.clone(), self.expr_arena.clone())
76    }
77
78    fn fill_inputs(&mut self) {
79        let lp_arena = self.lp_arena.lock().unwrap();
80        let this_node = lp_arena.get(self.root);
81        self.scratch.clear();
82        this_node.copy_inputs(&mut self.scratch);
83    }
84
85    fn fill_expressions(&mut self) {
86        let lp_arena = self.lp_arena.lock().unwrap();
87        let this_node = lp_arena.get(self.root);
88        self.expr_scratch.clear();
89        this_node.copy_exprs(&mut self.expr_scratch);
90    }
91
92    fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
93        PyList::new(py, self.scratch.drain(..).map(|node| node.0))
94    }
95
96    fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
97        PyList::new(
98            py,
99            self.expr_scratch
100                .drain(..)
101                .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
102        )
103    }
104}
105
106#[pymethods]
107impl NodeTraverser {
108    /// Get expression nodes
109    fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
110        self.fill_expressions();
111        self.expr_to_list(py)
112    }
113
114    /// Get input nodes
115    fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
116        self.fill_inputs();
117        self.scratch_to_list(py)
118    }
119
120    /// The current version of the IR
121    fn version(&self) -> Version {
122        NodeTraverser::VERSION
123    }
124
125    /// Get Schema of current node as python dict<str, pl.DataType>
126    fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
127        let lp_arena = self.lp_arena.lock().unwrap();
128        let schema = lp_arena.get(self.root).schema(&lp_arena);
129        Wrap(&**schema).into_pyobject(py)
130    }
131
132    /// Get expression dtype of expr_node, the schema used is that of the current root node
133    fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
134        let expr_node = Node(expr_node);
135        let lp_arena = self.lp_arena.lock().unwrap();
136        let schema = lp_arena.get(self.root).schema(&lp_arena);
137        let expr_arena = self.expr_arena.lock().unwrap();
138        let field = expr_arena
139            .get(expr_node)
140            .to_field(&schema, Context::Default, &expr_arena)
141            .map_err(PyPolarsErr::from)?;
142        Wrap(field.dtype).into_pyobject(py)
143    }
144
145    /// Set the current node in the plan.
146    fn set_node(&mut self, node: usize) {
147        self.root = Node(node);
148    }
149
150    /// Get the current node in the plan.
151    fn get_node(&mut self) -> usize {
152        self.root.0
153    }
154
155    /// Set a python UDF that will replace the subtree location with this function src.
156    fn set_udf(&mut self, function: PyObject) {
157        let mut lp_arena = self.lp_arena.lock().unwrap();
158        let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
159        let ir = IR::PythonScan {
160            options: PythonOptions {
161                scan_fn: Some(function.into()),
162                schema,
163                output_schema: None,
164                with_columns: None,
165                python_source: PythonScanSource::Cuda,
166                predicate: Default::default(),
167                n_rows: None,
168            },
169        };
170        lp_arena.replace(self.root, ir);
171    }
172
173    fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
174        let lp_arena = self.lp_arena.lock().unwrap();
175        let lp_node = lp_arena.get(self.root);
176        nodes::into_py(py, lp_node)
177    }
178
179    fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
180        let expr_arena = self.expr_arena.lock().unwrap();
181        let n = match &self.expr_mapping {
182            Some(mapping) => *mapping.get(node).unwrap(),
183            None => Node(node),
184        };
185        let expr = expr_arena.get(n);
186        expr_nodes::into_py(py, expr)
187    }
188
189    /// Add some expressions to the arena and return their new node ids as well
190    /// as the total number of nodes in the arena.
191    fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
192        let mut expr_arena = self.expr_arena.lock().unwrap();
193        Ok((
194            expressions
195                .into_iter()
196                .map(|e| {
197                    to_aexpr(e.inner, &mut expr_arena)
198                        .map_err(PyPolarsErr::from)
199                        .map(|v| v.0)
200                })
201                .collect::<Result<_, PyPolarsErr>>()?,
202            expr_arena.len(),
203        ))
204    }
205
206    /// Set up a mapping of expression nodes used in `view_expression_node``.
207    /// With a mapping set, `view_expression_node(i)` produces the node for
208    /// `mapping[i]`.
209    fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
210        if mapping.len() != self.expr_arena.lock().unwrap().len() {
211            raise_err!("Invalid mapping length", ComputeError);
212        }
213        self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
214        Ok(())
215    }
216
217    /// Unset the expression mapping (reinstates the identity map)
218    fn unset_expr_mapping(&mut self) {
219        self.expr_mapping = None;
220    }
221}
222
223#[pymethods]
224#[allow(clippy::should_implement_trait)]
225impl PyLazyFrame {
226    fn visit(&self) -> PyResult<NodeTraverser> {
227        let mut lp_arena = Arena::with_capacity(16);
228        let mut expr_arena = Arena::with_capacity(16);
229        let root = self
230            .ldf
231            .clone()
232            .optimize(&mut lp_arena, &mut expr_arena)
233            .map_err(PyPolarsErr::from)?;
234        Ok(NodeTraverser {
235            root,
236            lp_arena: Arc::new(Mutex::new(lp_arena)),
237            expr_arena: Arc::new(Mutex::new(expr_arena)),
238            scratch: vec![],
239            expr_scratch: vec![],
240            expr_mapping: None,
241        })
242    }
243}