polars_python/lazyframe/
visit.rs

1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars::prelude::python_dsl::PythonScanSource;
5use polars_plan::plans::{Context, IR, to_aexpr};
6use polars_plan::prelude::expr_ir::ExprIR;
7use polars_plan::prelude::{AExpr, PythonOptions};
8use polars_utils::arena::{Arena, Node};
9use pyo3::prelude::*;
10use pyo3::types::{PyDict, PyList};
11
12use super::PyLazyFrame;
13use super::visitor::{expr_nodes, nodes};
14use crate::error::PyPolarsErr;
15use crate::{PyExpr, Wrap, raise_err};
16
17#[derive(Clone)]
18#[pyclass]
19pub struct PyExprIR {
20    #[pyo3(get)]
21    node: usize,
22    #[pyo3(get)]
23    output_name: String,
24}
25
26impl From<ExprIR> for PyExprIR {
27    fn from(value: ExprIR) -> Self {
28        Self {
29            node: value.node().0,
30            output_name: value.output_name().to_string(),
31        }
32    }
33}
34
35impl From<&ExprIR> for PyExprIR {
36    fn from(value: &ExprIR) -> Self {
37        Self {
38            node: value.node().0,
39            output_name: value.output_name().to_string(),
40        }
41    }
42}
43
44type Version = (u16, u16);
45
46#[pyclass]
47pub struct NodeTraverser {
48    root: Node,
49    lp_arena: Arc<Mutex<Arena<IR>>>,
50    expr_arena: Arc<Mutex<Arena<AExpr>>>,
51    scratch: Vec<Node>,
52    expr_scratch: Vec<ExprIR>,
53    expr_mapping: Option<Vec<Node>>,
54}
55
56impl NodeTraverser {
57    // Versioning for IR, (major, minor)
58    // Increment major on breaking changes to the IR (e.g. renaming
59    // fields, reordering tuples), minor on backwards compatible
60    // changes (e.g. exposing a new expression node).
61    const VERSION: Version = (7, 0);
62
63    pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
64        Self {
65            root,
66            lp_arena: Arc::new(Mutex::new(lp_arena)),
67            expr_arena: Arc::new(Mutex::new(expr_arena)),
68            scratch: vec![],
69            expr_scratch: vec![],
70            expr_mapping: None,
71        }
72    }
73
74    #[allow(clippy::type_complexity)]
75    pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
76        (self.lp_arena.clone(), self.expr_arena.clone())
77    }
78
79    fn fill_inputs(&mut self) {
80        let lp_arena = self.lp_arena.lock().unwrap();
81        let this_node = lp_arena.get(self.root);
82        self.scratch.clear();
83        this_node.copy_inputs(&mut self.scratch);
84    }
85
86    fn fill_expressions(&mut self) {
87        let lp_arena = self.lp_arena.lock().unwrap();
88        let this_node = lp_arena.get(self.root);
89        self.expr_scratch.clear();
90        this_node.copy_exprs(&mut self.expr_scratch);
91    }
92
93    fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
94        PyList::new(py, self.scratch.drain(..).map(|node| node.0))
95    }
96
97    fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
98        PyList::new(
99            py,
100            self.expr_scratch
101                .drain(..)
102                .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
103        )
104    }
105}
106
107#[pymethods]
108impl NodeTraverser {
109    /// Get expression nodes
110    fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
111        self.fill_expressions();
112        self.expr_to_list(py)
113    }
114
115    /// Get input nodes
116    fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
117        self.fill_inputs();
118        self.scratch_to_list(py)
119    }
120
121    /// The current version of the IR
122    fn version(&self) -> Version {
123        NodeTraverser::VERSION
124    }
125
126    /// Get Schema of current node as python dict<str, pl.DataType>
127    fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
128        let lp_arena = self.lp_arena.lock().unwrap();
129        let schema = lp_arena.get(self.root).schema(&lp_arena);
130        Wrap(&**schema).into_pyobject(py)
131    }
132
133    /// Get expression dtype of expr_node, the schema used is that of the current root node
134    fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
135        let expr_node = Node(expr_node);
136        let lp_arena = self.lp_arena.lock().unwrap();
137        let schema = lp_arena.get(self.root).schema(&lp_arena);
138        let expr_arena = self.expr_arena.lock().unwrap();
139        let field = expr_arena
140            .get(expr_node)
141            .to_field(&schema, Context::Default, &expr_arena)
142            .map_err(PyPolarsErr::from)?;
143        Wrap(field.dtype).into_pyobject(py)
144    }
145
146    /// Set the current node in the plan.
147    fn set_node(&mut self, node: usize) {
148        self.root = Node(node);
149    }
150
151    /// Get the current node in the plan.
152    fn get_node(&mut self) -> usize {
153        self.root.0
154    }
155
156    /// Set a python UDF that will replace the subtree location with this function src.
157    fn set_udf(&mut self, function: PyObject) {
158        let mut lp_arena = self.lp_arena.lock().unwrap();
159        let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
160        let ir = IR::PythonScan {
161            options: PythonOptions {
162                scan_fn: Some(function.into()),
163                schema,
164                output_schema: None,
165                with_columns: None,
166                python_source: PythonScanSource::Cuda,
167                predicate: Default::default(),
168                n_rows: None,
169                validate_schema: false,
170            },
171        };
172        lp_arena.replace(self.root, ir);
173    }
174
175    fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
176        let lp_arena = self.lp_arena.lock().unwrap();
177        let lp_node = lp_arena.get(self.root);
178        nodes::into_py(py, lp_node)
179    }
180
181    fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
182        let expr_arena = self.expr_arena.lock().unwrap();
183        let n = match &self.expr_mapping {
184            Some(mapping) => *mapping.get(node).unwrap(),
185            None => Node(node),
186        };
187        let expr = expr_arena.get(n);
188        expr_nodes::into_py(py, expr)
189    }
190
191    /// Add some expressions to the arena and return their new node ids as well
192    /// as the total number of nodes in the arena.
193    fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
194        let mut expr_arena = self.expr_arena.lock().unwrap();
195        Ok((
196            expressions
197                .into_iter()
198                .map(|e| {
199                    to_aexpr(e.inner, &mut expr_arena)
200                        .map_err(PyPolarsErr::from)
201                        .map(|v| v.0)
202                })
203                .collect::<Result<_, PyPolarsErr>>()?,
204            expr_arena.len(),
205        ))
206    }
207
208    /// Set up a mapping of expression nodes used in `view_expression_node``.
209    /// With a mapping set, `view_expression_node(i)` produces the node for
210    /// `mapping[i]`.
211    fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
212        if mapping.len() != self.expr_arena.lock().unwrap().len() {
213            raise_err!("Invalid mapping length", ComputeError);
214        }
215        self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
216        Ok(())
217    }
218
219    /// Unset the expression mapping (reinstates the identity map)
220    fn unset_expr_mapping(&mut self) {
221        self.expr_mapping = None;
222    }
223}
224
225#[pymethods]
226#[allow(clippy::should_implement_trait)]
227impl PyLazyFrame {
228    fn visit(&self) -> PyResult<NodeTraverser> {
229        let mut lp_arena = Arena::with_capacity(16);
230        let mut expr_arena = Arena::with_capacity(16);
231        let root = self
232            .ldf
233            .clone()
234            .optimize(&mut lp_arena, &mut expr_arena)
235            .map_err(PyPolarsErr::from)?;
236        Ok(NodeTraverser {
237            root,
238            lp_arena: Arc::new(Mutex::new(lp_arena)),
239            expr_arena: Arc::new(Mutex::new(expr_arena)),
240            scratch: vec![],
241            expr_scratch: vec![],
242            expr_mapping: None,
243        })
244    }
245}