polars_python/
on_startup.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use polars::prelude::*;
6use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
7use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
8use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
9use polars_error::PolarsWarning;
10use polars_error::signals::register_polars_keyboard_interrupt_hook;
11use polars_ffi::version_0::SeriesExport;
12use polars_plan::plans::python_df_to_rust;
13use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
14use pyo3::prelude::*;
15use pyo3::{IntoPyObjectExt, intern};
16
17use crate::Wrap;
18use crate::dataframe::PyDataFrame;
19use crate::lazyframe::PyLazyFrame;
20use crate::map::lazy::call_lambda_with_series;
21use crate::prelude::ObjectValue;
22use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
23use crate::series::PySeries;
24
25fn python_function_caller_series(
26 s: &[Column],
27 output_dtype: Option<DataType>,
28 lambda: &PyObject,
29) -> PolarsResult<Column> {
30 Python::with_gil(|py| call_lambda_with_series(py, s, output_dtype, lambda))
31}
32
33fn python_function_caller_df(df: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame> {
34 Python::with_gil(|py| {
35 let pypolars = polars(py).bind(py);
36
37 let pydf = PyDataFrame::new(df);
39 let mut python_df_wrapper = pypolars
41 .getattr("wrap_df")
42 .unwrap()
43 .call1((pydf.clone(),))
44 .unwrap();
45
46 if !python_df_wrapper
47 .getattr("_df")
48 .unwrap()
49 .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
50 .unwrap()
51 {
52 let pldf = pl_df(py).bind(py);
53 let width = pydf.width();
54 let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
57 unsafe {
58 pydf._export_columns(columns.as_mut_ptr() as usize);
59 }
60 python_df_wrapper = pldf
62 .getattr("_import_columns")
63 .unwrap()
64 .call1((columns.as_mut_ptr() as usize, width))
65 .unwrap();
66 }
67 let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
69
70 let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
72 let pytype = result_df_wrapper.bind(py).get_type();
73 PolarsError::ComputeError(
74 format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
75 .into(),
76 )
77 })?;
78 match py_pydf.extract::<PyDataFrame>(py) {
80 Ok(pydf) => Ok(pydf.df.into_inner()),
81 Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
82 }
83 })
84}
85
86fn warning_function(msg: &str, warning: PolarsWarning) {
87 Python::with_gil(|py| {
88 let warn_fn = pl_utils(py)
89 .bind(py)
90 .getattr(intern!(py, "_polars_warn"))
91 .unwrap();
92
93 if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
94 eprintln!("{e}")
95 }
96 });
97}
98
/// Once-flag for [`register_startup_deps`]. It ensures the registration closure runs at most
/// once per process; the `()` payload carries no data.
static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();

/// Register the Python-side hooks that the Rust Polars crates call back into.
///
/// It does the following:
/// - enables extension types;
/// - in debug builds, adjusts the `recursive` stack sizes;
/// - installs the Python <-> Rust converters (DataFrame, Series, DSL plan, Schema, partition
///   callback result);
/// - registers the `ObjectValue` object builder and converters;
/// - sets the dataset-provider vtable;
/// - sets the column/DataFrame UDF callers and the warning function;
/// - optionally installs the keyboard-interrupt hook.
///
/// All of this runs inside `POLARS_REGISTRY_INIT_LOCK.get_or_init`, so only the first call
/// does anything. Later calls are no-ops, and their `catch_keyboard_interrupt` argument is
/// ignored.
///
/// # Safety
/// This function assigns to `python_dsl::CALL_COLUMNS_UDF_PYTHON` and
/// `python_dsl::CALL_DF_UDF_PYTHON`, which are presumably `static mut` items (confirm). The
/// crate-level `#![allow(unsafe_op_in_unsafe_fn)]` lets those writes appear without an
/// explicit `unsafe` block. The caller must ensure nothing reads those statics concurrently
/// with this call.
pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
    POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
        set_polars_allow_extension(true);

        // Debug builds only: raise the minimum stack size and stack allocation size used by
        // the `recursive` crate. Presumably this is because unoptimized builds use much
        // larger stack frames (confirm).
        #[cfg(debug_assertions)]
        {
            recursive::set_minimum_stack_size(1024 * 1024);
            recursive::set_stack_allocation_size(1024 * 1024 * 16);
        }

        // Factory for builders of `ObjectValue` chunked arrays, type-erased so the object
        // registry can hold it.
        let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
            Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
                as Box<dyn AnonymousObjectBuilder>
        });

        // AnyValue -> `ObjectValue` wrapping a Python object. Panics if the Python
        // conversion fails (`unwrap`).
        let object_converter = Arc::new(|av: AnyValue| {
            let object = Python::with_gil(|py| ObjectValue {
                inner: Wrap(av).into_py_any(py).unwrap(),
            });
            Box::new(object) as Box<dyn Any>
        });
        // AnyValue -> bare Python object (not wrapped in `ObjectValue`). Also panics on
        // conversion failure.
        let pyobject_converter = Arc::new(|av: AnyValue| {
            let object = Python::with_gil(|py| Wrap(av).into_py_any(py).unwrap());
            Box::new(object) as Box<dyn Any>
        });

        // Type-erased converters between Python objects and Rust values. The `from_py`
        // closures propagate extraction errors with `?`. The `to_py` closures `unwrap` the
        // `Box<dyn Any>` downcast, so they panic if handed a value of the wrong type.
        polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
            from_py: FromPythonConvertRegistry {
                partition_target_cb_result: Arc::new(|py_f| {
                    Python::with_gil(|py| {
                        Ok(Box::new(
                            py_f.extract::<Wrap<polars_plan::dsl::PartitionTargetCallbackResult>>(
                                py,
                            )?
                            .0,
                        ) as _)
                    })
                }),
                series: Arc::new(|py_f| {
                    Python::with_gil(|py| {
                        Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
                    })
                }),
                df: Arc::new(|py_f| {
                    Python::with_gil(|py| {
                        Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
                    })
                }),
                // Only the logical plan of the extracted LazyFrame is kept.
                dsl_plan: Arc::new(|py_f| {
                    Python::with_gil(|py| {
                        Ok(Box::new(
                            py_f.extract::<PyLazyFrame>(py)?
                                .ldf
                                .into_inner()
                                .logical_plan,
                        ) as _)
                    })
                }),
                schema: Arc::new(|py_f| {
                    Python::with_gil(|py| {
                        Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
                    })
                }),
            },
            to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
                df: Arc::new(|df| {
                    Python::with_gil(|py| PyDataFrame::new(*df.downcast().unwrap()).into_py_any(py))
                }),
                series: Arc::new(|series| {
                    Python::with_gil(|py| {
                        PySeries::new(*series.downcast().unwrap()).into_py_any(py)
                    })
                }),
                // A DslPlan is surfaced to Python as a LazyFrame.
                dsl_plan: Arc::new(|dsl_plan| {
                    Python::with_gil(|py| {
                        PyLazyFrame::from(LazyFrame::from(
                            *dsl_plan.downcast::<polars_plan::dsl::DslPlan>().unwrap(),
                        ))
                        .into_py_any(py)
                    })
                }),
                schema: Arc::new(|schema| {
                    Python::with_gil(|py| {
                        Wrap(*schema.downcast::<polars_core::schema::Schema>().unwrap())
                            .into_py_any(py)
                    })
                }),
            },
        });

        // The object dtype's physical Arrow representation is a fixed-size binary exactly
        // as wide as one `ObjectValue`.
        let object_size = size_of::<ObjectValue>();
        let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
        registry::register_object_builder(
            object_builder,
            object_converter,
            pyobject_converter,
            physical_dtype,
        );

        use crate::dataset::dataset_provider_funcs;

        // Install the Python dataset-provider callbacks via `get_or_init`, i.e. only if the
        // vtable has not been set already.
        polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
            name: dataset_provider_funcs::name,
            schema: dataset_provider_funcs::schema,
            to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
        });

        // Register the column and DataFrame UDF callers. These are the writes covered by
        // the `# Safety` section.
        python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
        python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
        // Route Polars warnings to Python.
        polars_error::set_warning_function(warning_function);

        if catch_keyboard_interrupt {
            register_polars_keyboard_interrupt_hook();
        }
    });
}