polars_python/
lazygroupby.rs

1use std::sync::Arc;
2
3use polars::lazy::frame::{LazyFrame, LazyGroupBy};
4use polars::prelude::{DataFrame, Schema};
5use pyo3::prelude::*;
6
7use crate::conversion::Wrap;
8use crate::error::PyPolarsErr;
9use crate::expr::ToExprs;
10use crate::py_modules::polars;
11use crate::{PyDataFrame, PyExpr, PyLazyFrame};
12
13#[pyclass]
14#[repr(transparent)]
15pub struct PyLazyGroupBy {
16    // option because we cannot get a self by value in pyo3
17    pub lgb: Option<LazyGroupBy>,
18}
19
20#[pymethods]
21impl PyLazyGroupBy {
22    fn agg(&mut self, aggs: Vec<PyExpr>) -> PyLazyFrame {
23        let lgb = self.lgb.clone().unwrap();
24        let aggs = aggs.to_exprs();
25        lgb.agg(aggs).into()
26    }
27
28    fn head(&mut self, n: usize) -> PyLazyFrame {
29        let lgb = self.lgb.clone().unwrap();
30        lgb.head(Some(n)).into()
31    }
32
33    fn tail(&mut self, n: usize) -> PyLazyFrame {
34        let lgb = self.lgb.clone().unwrap();
35        lgb.tail(Some(n)).into()
36    }
37
38    #[pyo3(signature = (lambda, schema=None))]
39    fn map_groups(
40        &mut self,
41        lambda: PyObject,
42        schema: Option<Wrap<Schema>>,
43    ) -> PyResult<PyLazyFrame> {
44        let lgb = self.lgb.clone().unwrap();
45        let schema = match schema {
46            Some(schema) => Arc::new(schema.0),
47            None => LazyFrame::from(lgb.logical_plan.clone())
48                .collect_schema()
49                .map_err(PyPolarsErr::from)?,
50        };
51
52        let function = move |df: DataFrame| {
53            Python::with_gil(|py| {
54                // get the pypolars module
55                let pypolars = polars(py).bind(py);
56
57                // create a PyDataFrame struct/object for Python
58                let pydf = PyDataFrame::new(df);
59
60                // Wrap this PySeries object in the python side DataFrame wrapper
61                let python_df_wrapper =
62                    pypolars.getattr("wrap_df").unwrap().call1((pydf,)).unwrap();
63
64                // call the lambda and get a python side DataFrame wrapper
65                let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
66                // unpack the wrapper in a PyDataFrame
67                let py_pydf = result_df_wrapper.getattr(py, "_df").expect(
68                "Could not get DataFrame attribute '_df'. Make sure that you return a DataFrame object.",
69            );
70                // Downcast to Rust
71                let pydf = py_pydf.extract::<PyDataFrame>(py).unwrap();
72                // Finally get the actual DataFrame
73                Ok(pydf.df)
74            })
75        };
76        Ok(lgb.apply(function, schema).into())
77    }
78}