polars_python/
on_startup.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use polars::prelude::*;
6use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
7use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
8use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
9use polars_error::PolarsWarning;
10use polars_error::signals::register_polars_keyboard_interrupt_hook;
11use polars_ffi::version_0::SeriesExport;
12use polars_plan::plans::python_df_to_rust;
13use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
14use pyo3::prelude::*;
15use pyo3::{IntoPyObjectExt, intern};
16
17use crate::Wrap;
18use crate::dataframe::PyDataFrame;
19use crate::map::lazy::{ToSeries, call_lambda_with_series};
20use crate::prelude::ObjectValue;
21use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
22
23fn python_function_caller_series(
24    s: Column,
25    output_dtype: Option<DataType>,
26    lambda: &PyObject,
27) -> PolarsResult<Column> {
28    Python::with_gil(|py| {
29        let object =
30            call_lambda_with_series(py, s.as_materialized_series(), Some(output_dtype), lambda)?;
31        object.to_series(py, polars(py), s.name()).map(Column::from)
32    })
33}
34
35fn python_function_caller_df(df: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame> {
36    Python::with_gil(|py| {
37        let pypolars = polars(py).bind(py);
38
39        // create a PySeries struct/object for Python
40        let mut pydf = PyDataFrame::new(df);
41        // Wrap this PySeries object in the python side Series wrapper
42        let mut python_df_wrapper = pypolars
43            .getattr("wrap_df")
44            .unwrap()
45            .call1((pydf.clone(),))
46            .unwrap();
47
48        if !python_df_wrapper
49            .getattr("_df")
50            .unwrap()
51            .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
52            .unwrap()
53        {
54            let pldf = pl_df(py).bind(py);
55            let width = pydf.width();
56            // Don't resize the Vec to avoid calling SeriesExport's Drop impl
57            // The import takes ownership and is responsible for dropping
58            let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
59            unsafe {
60                pydf._export_columns(columns.as_mut_ptr() as usize);
61            }
62            // Wrap this PyDataFrame object in the python side DataFrame wrapper
63            python_df_wrapper = pldf
64                .getattr("_import_columns")
65                .unwrap()
66                .call1((columns.as_mut_ptr() as usize, width))
67                .unwrap();
68        }
69        // call the lambda and get a python side df wrapper
70        let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
71
72        // unpack the wrapper in a PyDataFrame
73        let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
74            let pytype = result_df_wrapper.bind(py).get_type();
75            PolarsError::ComputeError(
76                format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
77                    .into(),
78            )
79        })?;
80        // Downcast to Rust
81        match py_pydf.extract::<PyDataFrame>(py) {
82            Ok(pydf) => Ok(pydf.df),
83            Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
84        }
85    })
86}
87
88fn warning_function(msg: &str, warning: PolarsWarning) {
89    Python::with_gil(|py| {
90        let warn_fn = pl_utils(py)
91            .bind(py)
92            .getattr(intern!(py, "_polars_warn"))
93            .unwrap();
94
95        if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
96            eprintln!("{e}")
97        }
98    });
99}
100
101static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
102
103/// # Safety
104/// Caller must ensure that no other threads read the objects set by this registration.
105pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
106    // TODO: should we throw an error if we try to initialize while already initialized?
107    POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
108        set_polars_allow_extension(true);
109
110        // Stack frames can get really large in debug mode.
111        #[cfg(debug_assertions)]
112        {
113            recursive::set_minimum_stack_size(1024 * 1024);
114            recursive::set_stack_allocation_size(1024 * 1024 * 16);
115        }
116
117        // Register object type builder.
118        let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
119            Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
120                as Box<dyn AnonymousObjectBuilder>
121        });
122
123        let object_converter = Arc::new(|av: AnyValue| {
124            let object = Python::with_gil(|py| ObjectValue {
125                inner: Wrap(av).into_py_any(py).unwrap(),
126            });
127            Box::new(object) as Box<dyn Any>
128        });
129        let pyobject_converter = Arc::new(|av: AnyValue| {
130            let object = Python::with_gil(|py| Wrap(av).into_py_any(py).unwrap());
131            Box::new(object) as Box<dyn Any>
132        });
133
134        polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
135            from_py: FromPythonConvertRegistry {
136                sink_target: Arc::new(|py_f| {
137                    Python::with_gil(|py| {
138                        Ok(
139                            Box::new(py_f.extract::<Wrap<polars_plan::dsl::SinkTarget>>(py)?.0)
140                                as _,
141                        )
142                    })
143                }),
144            },
145            to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
146                df: Arc::new(|df| {
147                    Python::with_gil(|py| PyDataFrame::new(*df.downcast().unwrap()).into_py_any(py))
148                }),
149            },
150        });
151
152        let object_size = size_of::<ObjectValue>();
153        let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
154        registry::register_object_builder(
155            object_builder,
156            object_converter,
157            pyobject_converter,
158            physical_dtype,
159        );
160        // Register SERIES UDF.
161        python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
162        // Register DATAFRAME UDF.
163        python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
164        // Register warning function for `polars_warn!`.
165        polars_error::set_warning_function(warning_function);
166
167        if catch_keyboard_interrupt {
168            register_polars_keyboard_interrupt_hook();
169        }
170    });
171}