polars_python/
on_startup.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use polars::prelude::*;
6use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
7use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
8use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
9use polars_error::PolarsWarning;
10use polars_error::signals::register_polars_keyboard_interrupt_hook;
11use polars_ffi::version_0::SeriesExport;
12use polars_plan::plans::python_df_to_rust;
13use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
14use pyo3::prelude::*;
15use pyo3::{IntoPyObjectExt, intern};
16
17use crate::Wrap;
18use crate::dataframe::PyDataFrame;
19use crate::map::lazy::{ToSeries, call_lambda_with_series};
20use crate::prelude::ObjectValue;
21use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
22
23fn python_function_caller_series(s: Column, lambda: &PyObject) -> PolarsResult<Column> {
24 Python::with_gil(|py| {
25 let object = call_lambda_with_series(py, s.as_materialized_series(), lambda)?;
26 object.to_series(py, polars(py), s.name()).map(Column::from)
27 })
28}
29
30fn python_function_caller_df(df: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame> {
31 Python::with_gil(|py| {
32 let pypolars = polars(py).bind(py);
33
34 let mut pydf = PyDataFrame::new(df);
36 let mut python_df_wrapper = pypolars
38 .getattr("wrap_df")
39 .unwrap()
40 .call1((pydf.clone(),))
41 .unwrap();
42
43 if !python_df_wrapper
44 .getattr("_df")
45 .unwrap()
46 .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
47 .unwrap()
48 {
49 let pldf = pl_df(py).bind(py);
50 let width = pydf.width();
51 let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
54 unsafe {
55 pydf._export_columns(columns.as_mut_ptr() as usize);
56 }
57 python_df_wrapper = pldf
59 .getattr("_import_columns")
60 .unwrap()
61 .call1((columns.as_mut_ptr() as usize, width))
62 .unwrap();
63 }
64 let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
66
67 let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
69 let pytype = result_df_wrapper.bind(py).get_type();
70 PolarsError::ComputeError(
71 format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
72 .into(),
73 )
74 })?;
75 match py_pydf.extract::<PyDataFrame>(py) {
77 Ok(pydf) => Ok(pydf.df),
78 Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
79 }
80 })
81}
82
83fn warning_function(msg: &str, warning: PolarsWarning) {
84 Python::with_gil(|py| {
85 let warn_fn = pl_utils(py)
86 .bind(py)
87 .getattr(intern!(py, "_polars_warn"))
88 .unwrap();
89
90 if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
91 eprintln!("{e}")
92 }
93 });
94}
95
96static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
97
98pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
101 POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
103 set_polars_allow_extension(true);
104
105 #[cfg(debug_assertions)]
107 {
108 recursive::set_minimum_stack_size(1024 * 1024);
109 recursive::set_stack_allocation_size(1024 * 1024 * 16);
110 }
111
112 let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
114 Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
115 as Box<dyn AnonymousObjectBuilder>
116 });
117
118 let object_converter = Arc::new(|av: AnyValue| {
119 let object = Python::with_gil(|py| ObjectValue {
120 inner: Wrap(av).into_py_any(py).unwrap(),
121 });
122 Box::new(object) as Box<dyn Any>
123 });
124 let pyobject_converter = Arc::new(|av: AnyValue| {
125 let object = Python::with_gil(|py| Wrap(av).into_py_any(py).unwrap());
126 Box::new(object) as Box<dyn Any>
127 });
128
129 polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
130 from_py: FromPythonConvertRegistry {
131 sink_target: Arc::new(|py_f| {
132 Python::with_gil(|py| {
133 Ok(
134 Box::new(py_f.extract::<Wrap<polars_plan::dsl::SinkTarget>>(py)?.0)
135 as _,
136 )
137 })
138 }),
139 },
140 });
141
142 let object_size = size_of::<ObjectValue>();
143 let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
144 registry::register_object_builder(
145 object_builder,
146 object_converter,
147 pyobject_converter,
148 physical_dtype,
149 );
150 python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
152 python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
154 polars_error::set_warning_function(warning_function);
156
157 if catch_keyboard_interrupt {
158 register_polars_keyboard_interrupt_hook();
159 }
160 });
161}