polars_python/
on_startup.rs

1use std::any::Any;
2
3use polars::prelude::*;
4use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
5use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
6use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
7use polars_core::error::PolarsError::ComputeError;
8use polars_error::{set_signals_function, PolarsWarning};
9use pyo3::prelude::*;
10use pyo3::{intern, IntoPyObjectExt};
11
12use crate::dataframe::PyDataFrame;
13use crate::map::lazy::{call_lambda_with_series, ToSeries};
14use crate::prelude::ObjectValue;
15use crate::py_modules::{pl_utils, polars};
16use crate::Wrap;
17
18fn python_function_caller_series(s: Column, lambda: &PyObject) -> PolarsResult<Column> {
19    Python::with_gil(|py| {
20        let object = call_lambda_with_series(py, s.clone().take_materialized_series(), lambda)
21            .map_err(|s| ComputeError(format!("{}", s).into()))?;
22        object.to_series(py, polars(py), s.name()).map(Column::from)
23    })
24}
25
26fn python_function_caller_df(df: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame> {
27    Python::with_gil(|py| {
28        // create a PyDataFrame struct/object for Python
29        let pydf = PyDataFrame::new(df);
30        // Wrap this PyDataFrame object in the python side DataFrame wrapper
31        let python_df_wrapper = polars(py)
32            .getattr(py, "wrap_df")
33            .unwrap()
34            .call1(py, (pydf,))
35            .unwrap();
36        // call the lambda and get a python side df wrapper
37        let result_df_wrapper = lambda.call1(py, (python_df_wrapper,)).map_err(|e| {
38            PolarsError::ComputeError(format!("User provided python function failed: {e}").into())
39        })?;
40        // unpack the wrapper in a PyDataFrame
41        let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
42            let pytype = result_df_wrapper.bind(py).get_type();
43            PolarsError::ComputeError(
44                format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
45                    .into(),
46            )
47        })?;
48
49        // Downcast to Rust
50        let pydf = py_pydf.extract::<PyDataFrame>(py).unwrap();
51        // Finally get the actual DataFrame
52        let df = pydf.df;
53
54        Ok(df)
55    })
56}
57
58fn warning_function(msg: &str, warning: PolarsWarning) {
59    Python::with_gil(|py| {
60        let warn_fn = pl_utils(py)
61            .bind(py)
62            .getattr(intern!(py, "_polars_warn"))
63            .unwrap();
64
65        if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
66            eprintln!("{e}")
67        }
68    });
69}
70
71/// # Safety
72/// Caller must ensure that no other threads read the objects set by this registration.
73pub unsafe fn register_startup_deps(check_python_signals: bool) {
74    set_polars_allow_extension(true);
75    if !registry::is_object_builder_registered() {
76        // Stack frames can get really large in debug mode.
77        #[cfg(debug_assertions)]
78        {
79            recursive::set_minimum_stack_size(1024 * 1024);
80            recursive::set_stack_allocation_size(1024 * 1024 * 16);
81        }
82
83        // register object type builder
84        let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
85            Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
86                as Box<dyn AnonymousObjectBuilder>
87        });
88
89        let object_converter = Arc::new(|av: AnyValue| {
90            let object = Python::with_gil(|py| ObjectValue {
91                inner: Wrap(av).into_py_any(py).unwrap(),
92            });
93            Box::new(object) as Box<dyn Any>
94        });
95
96        let object_size = size_of::<ObjectValue>();
97        let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
98        registry::register_object_builder(object_builder, object_converter, physical_dtype);
99        // register SERIES UDF
100        python_udf::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
101        // register DATAFRAME UDF
102        python_udf::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
103        // register warning function for `polars_warn!`
104        polars_error::set_warning_function(warning_function);
105
106        if check_python_signals {
107            fn signals_function() -> PolarsResult<()> {
108                Python::with_gil(|py| {
109                    py.check_signals()
110                        .map_err(|err| polars_err!(ComputeError: "{err}"))
111                })
112            }
113
114            set_signals_function(signals_function);
115        }
116
117        Python::with_gil(|py| {
118            // init AnyValue LUT
119            crate::conversion::any_value::LUT
120                .set(py, Default::default())
121                .unwrap();
122        });
123    }
124}