// polars_python/lazyframe/visit.rs

1use std::sync::{Arc, Mutex};
2
3use polars::prelude::PolarsError;
4use polars::prelude::python_dsl::PythonScanSource;
5use polars_plan::plans::{ExprToIRContext, IR, to_expr_ir};
6use polars_plan::prelude::expr_ir::ExprIR;
7use polars_plan::prelude::{AExpr, PythonOptions};
8use polars_utils::arena::{Arena, Node};
9use pyo3::prelude::*;
10use pyo3::types::{PyDict, PyList};
11
12use super::PyLazyFrame;
13use super::visitor::{expr_nodes, nodes};
14use crate::error::PyPolarsErr;
15use crate::{PyExpr, Wrap, raise_err};
16
17#[derive(Clone)]
18#[pyclass(frozen)]
19pub struct PyExprIR {
20    #[pyo3(get)]
21    node: usize,
22    #[pyo3(get)]
23    output_name: String,
24}
25
26impl From<ExprIR> for PyExprIR {
27    fn from(value: ExprIR) -> Self {
28        Self {
29            node: value.node().0,
30            output_name: value.output_name().to_string(),
31        }
32    }
33}
34
35impl From<&ExprIR> for PyExprIR {
36    fn from(value: &ExprIR) -> Self {
37        Self {
38            node: value.node().0,
39            output_name: value.output_name().to_string(),
40        }
41    }
42}
43
/// IR version reported to Python as a `(major, minor)` pair.
type Version = (u16, u16);
45
46#[pyclass]
47pub struct NodeTraverser {
48    root: Node,
49    lp_arena: Arc<Mutex<Arena<IR>>>,
50    expr_arena: Arc<Mutex<Arena<AExpr>>>,
51    scratch: Vec<Node>,
52    expr_scratch: Vec<ExprIR>,
53    expr_mapping: Option<Vec<Node>>,
54}
55
56impl NodeTraverser {
57    // Versioning for IR, (major, minor)
58    // Increment major on breaking changes to the IR (e.g. renaming
59    // fields, reordering tuples), minor on backwards compatible
60    // changes (e.g. exposing a new expression node).
61    const VERSION: Version = (10, 0);
62
63    pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
64        Self {
65            root,
66            lp_arena: Arc::new(Mutex::new(lp_arena)),
67            expr_arena: Arc::new(Mutex::new(expr_arena)),
68            scratch: vec![],
69            expr_scratch: vec![],
70            expr_mapping: None,
71        }
72    }
73
74    #[allow(clippy::type_complexity)]
75    pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
76        (self.lp_arena.clone(), self.expr_arena.clone())
77    }
78
79    fn fill_inputs(&mut self) {
80        let lp_arena = self.lp_arena.lock().unwrap();
81        let this_node = lp_arena.get(self.root);
82        self.scratch.clear();
83        this_node.copy_inputs(&mut self.scratch);
84    }
85
86    fn fill_expressions(&mut self) {
87        let lp_arena = self.lp_arena.lock().unwrap();
88        let this_node = lp_arena.get(self.root);
89        self.expr_scratch.clear();
90        this_node.copy_exprs(&mut self.expr_scratch);
91    }
92
93    fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
94        PyList::new(py, self.scratch.drain(..).map(|node| node.0))
95    }
96
97    fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
98        PyList::new(
99            py,
100            self.expr_scratch
101                .drain(..)
102                .map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
103        )
104    }
105}
106
107#[pymethods]
108impl NodeTraverser {
109    /// Get expression nodes
110    fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
111        self.fill_expressions();
112        self.expr_to_list(py)
113    }
114
115    /// Get input nodes
116    fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
117        self.fill_inputs();
118        self.scratch_to_list(py)
119    }
120
121    /// The current version of the IR
122    fn version(&self) -> Version {
123        NodeTraverser::VERSION
124    }
125
126    /// Get Schema of current node as python dict<str, pl.DataType>
127    fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
128        let lp_arena = self.lp_arena.lock().unwrap();
129        let schema = lp_arena.get(self.root).schema(&lp_arena);
130        Wrap((**schema).clone()).into_pyobject(py)
131    }
132
133    /// Get expression dtype of expr_node, the schema used is that of the current root node
134    fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
135        let expr_node = Node(expr_node);
136        let lp_arena = self.lp_arena.lock().unwrap();
137        let schema = lp_arena.get(self.root).schema(&lp_arena);
138        let expr_arena = self.expr_arena.lock().unwrap();
139        let field = expr_arena
140            .get(expr_node)
141            .to_field(&schema, &expr_arena)
142            .map_err(PyPolarsErr::from)?;
143        Wrap(field.dtype).into_pyobject(py)
144    }
145
146    /// Set the current node in the plan.
147    fn set_node(&mut self, node: usize) {
148        self.root = Node(node);
149    }
150
151    /// Get the current node in the plan.
152    fn get_node(&mut self) -> usize {
153        self.root.0
154    }
155
156    /// Set a python UDF that will replace the subtree location with this function src.
157    #[pyo3(signature = (function, is_pure = false))]
158    fn set_udf(&mut self, function: PyObject, is_pure: bool) {
159        let mut lp_arena = self.lp_arena.lock().unwrap();
160        let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
161        let ir = IR::PythonScan {
162            options: PythonOptions {
163                scan_fn: Some(function.into()),
164                schema,
165                output_schema: None,
166                with_columns: None,
167                python_source: PythonScanSource::Cuda,
168                predicate: Default::default(),
169                n_rows: None,
170                validate_schema: false,
171                is_pure,
172            },
173        };
174        lp_arena.replace(self.root, ir);
175    }
176
177    fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
178        let lp_arena = self.lp_arena.lock().unwrap();
179        let lp_node = lp_arena.get(self.root);
180        nodes::into_py(py, lp_node)
181    }
182
183    fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
184        let expr_arena = self.expr_arena.lock().unwrap();
185        let n = match &self.expr_mapping {
186            Some(mapping) => *mapping.get(node).unwrap(),
187            None => Node(node),
188        };
189        let expr = expr_arena.get(n);
190        expr_nodes::into_py(py, expr)
191    }
192
193    /// Add some expressions to the arena and return their new node ids as well
194    /// as the total number of nodes in the arena.
195    fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
196        let lp_arena = self.lp_arena.lock().unwrap();
197        let schema = lp_arena.get(self.root).schema(&lp_arena);
198        let mut expr_arena = self.expr_arena.lock().unwrap();
199        Ok((
200            expressions
201                .into_iter()
202                .map(|e| {
203                    let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
204                    ctx.allow_unknown = true;
205                    // NOTE: Probably throwing away the output names here is not okay?
206                    to_expr_ir(e.inner, &mut ctx)
207                        .map_err(PyPolarsErr::from)
208                        .map(|v| v.node().0)
209                })
210                .collect::<Result<_, PyPolarsErr>>()?,
211            expr_arena.len(),
212        ))
213    }
214
215    /// Set up a mapping of expression nodes used in `view_expression_node``.
216    /// With a mapping set, `view_expression_node(i)` produces the node for
217    /// `mapping[i]`.
218    fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
219        if mapping.len() != self.expr_arena.lock().unwrap().len() {
220            raise_err!("Invalid mapping length", ComputeError);
221        }
222        self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
223        Ok(())
224    }
225
226    /// Unset the expression mapping (reinstates the identity map)
227    fn unset_expr_mapping(&mut self) {
228        self.expr_mapping = None;
229    }
230}
231
232#[pymethods]
233#[allow(clippy::should_implement_trait)]
234impl PyLazyFrame {
235    fn visit(&self) -> PyResult<NodeTraverser> {
236        let mut lp_arena = Arena::with_capacity(16);
237        let mut expr_arena = Arena::with_capacity(16);
238        let root = self
239            .ldf
240            .read()
241            .clone()
242            .optimize(&mut lp_arena, &mut expr_arena)
243            .map_err(PyPolarsErr::from)?;
244        Ok(NodeTraverser {
245            root,
246            lp_arena: Arc::new(Mutex::new(lp_arena)),
247            expr_arena: Arc::new(Mutex::new(expr_arena)),
248            scratch: vec![],
249            expr_scratch: vec![],
250            expr_mapping: None,
251        })
252    }
253}