polars_python/
on_startup.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use arrow::array::Array;
6use polars::chunked_array::object::ObjectArray;
7use polars::prelude::file_provider::FileProviderReturn;
8use polars::prelude::*;
9use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
10use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
11use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
12use polars_error::PolarsWarning;
13use polars_error::signals::register_polars_keyboard_interrupt_hook;
14use polars_ffi::version_0::SeriesExport;
15use polars_plan::plans::python_df_to_rust;
16use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
17use pyo3::prelude::*;
18use pyo3::{IntoPyObjectExt, intern};
19
20use crate::Wrap;
21use crate::dataframe::PyDataFrame;
22use crate::lazyframe::PyLazyFrame;
23use crate::map::lazy::call_lambda_with_series;
24use crate::prelude::ObjectValue;
25use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
26use crate::series::PySeries;
27
28fn python_function_caller_series(
29 s: &[Column],
30 output_dtype: Option<DataType>,
31 lambda: &Py<PyAny>,
32) -> PolarsResult<Column> {
33 Python::attach(|py| call_lambda_with_series(py, s, output_dtype, lambda))
34}
35
36fn python_function_caller_df(df: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame> {
37 Python::attach(|py| {
38 let pypolars = polars(py).bind(py);
39
40 let pydf = PyDataFrame::new(df);
42 let mut python_df_wrapper = pypolars
44 .getattr("wrap_df")
45 .unwrap()
46 .call1((pydf.clone(),))
47 .unwrap();
48
49 if !python_df_wrapper
50 .getattr("_df")
51 .unwrap()
52 .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
53 .unwrap()
54 {
55 let pldf = pl_df(py).bind(py);
56 let width = pydf.width();
57 let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
60 unsafe {
61 pydf._export_columns(columns.as_mut_ptr() as usize);
62 }
63 python_df_wrapper = pldf
65 .getattr("_import_columns")
66 .unwrap()
67 .call1((columns.as_mut_ptr() as usize, width))
68 .unwrap();
69 }
70 let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
72
73 let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
75 let pytype = result_df_wrapper.bind(py).get_type();
76 PolarsError::ComputeError(
77 format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
78 .into(),
79 )
80 })?;
81 match py_pydf.extract::<PyDataFrame>(py) {
83 Ok(pydf) => Ok(pydf.df.into_inner()),
84 Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
85 }
86 })
87}
88
89fn warning_function(msg: &str, warning: PolarsWarning) {
90 Python::attach(|py| {
91 let warn_fn = pl_utils(py)
92 .bind(py)
93 .getattr(intern!(py, "_polars_warn"))
94 .unwrap();
95
96 if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
97 eprintln!("{e}")
98 }
99 });
100}
101
102static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
103
104pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
107 POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
109 set_polars_allow_extension(true);
110
111 #[cfg(debug_assertions)]
113 {
114 recursive::set_minimum_stack_size(1024 * 1024);
115 recursive::set_stack_allocation_size(1024 * 1024 * 16);
116 }
117
118 let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
120 Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
121 as Box<dyn AnonymousObjectBuilder>
122 });
123
124 let object_converter = Arc::new(|av: AnyValue| {
125 let object = Python::attach(|py| ObjectValue {
126 inner: Wrap(av).into_py_any(py).unwrap(),
127 });
128 Box::new(object) as Box<dyn Any>
129 });
130 let pyobject_converter = Arc::new(|av: AnyValue| {
131 let object = Python::attach(|py| Wrap(av).into_py_any(py).unwrap());
132 Box::new(object) as Box<dyn Any>
133 });
134 fn object_array_getter(arr: &dyn Array, idx: usize) -> Option<AnyValue<'_>> {
135 let arr = arr.as_any().downcast_ref::<ObjectArray<ObjectValue>>().unwrap();
136 arr.get(idx).map(|v| AnyValue::Object(v))
137 }
138
139 polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
140 from_py: FromPythonConvertRegistry {
141 file_provider_result: Arc::new(|py_f| {
142 Python::attach(|py| {
143 Ok(Box::new(py_f.extract::<Wrap<FileProviderReturn>>(py)?.0) as _)
144 })
145 }),
146 series: Arc::new(|py_f| {
147 Python::attach(|py| {
148 Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
149 })
150 }),
151 df: Arc::new(|py_f| {
152 Python::attach(|py| {
153 Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
154 })
155 }),
156 dsl_plan: Arc::new(|py_f| {
157 Python::attach(|py| {
158 Ok(Box::new(
159 py_f.extract::<PyLazyFrame>(py)?
160 .ldf
161 .into_inner()
162 .logical_plan,
163 ) as _)
164 })
165 }),
166 schema: Arc::new(|py_f| {
167 Python::attach(|py| {
168 Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
169 })
170 }),
171 },
172 to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
173 df: Arc::new(|df| {
174 Python::attach(|py| {
175 PyDataFrame::new(df.downcast_ref::<DataFrame>().unwrap().clone())
176 .into_py_any(py)
177 })
178 }),
179 series: Arc::new(|series| {
180 Python::attach(|py| {
181 PySeries::new(series.downcast_ref::<Series>().unwrap().clone())
182 .into_py_any(py)
183 })
184 }),
185 dsl_plan: Arc::new(|dsl_plan| {
186 Python::attach(|py| {
187 PyLazyFrame::from(LazyFrame::from(
188 dsl_plan
189 .downcast_ref::<polars_plan::dsl::DslPlan>()
190 .unwrap()
191 .clone(),
192 ))
193 .into_py_any(py)
194 })
195 }),
196 schema: Arc::new(|schema| {
197 Python::attach(|py| {
198 Wrap(
199 schema
200 .downcast_ref::<polars_core::schema::Schema>()
201 .unwrap()
202 .clone(),
203 )
204 .into_py_any(py)
205 })
206 }),
207 },
208 });
209
210 let object_size = size_of::<ObjectValue>();
211 let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
212 registry::register_object_builder(
213 object_builder,
214 object_converter,
215 pyobject_converter,
216 physical_dtype,
217 Arc::new(object_array_getter)
218 );
219
220 use crate::dataset::dataset_provider_funcs;
221
222 polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
223 name: dataset_provider_funcs::name,
224 schema: dataset_provider_funcs::schema,
225 to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
226 });
227
228 python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
230 python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
232 polars_error::set_warning_function(warning_function);
234
235 if catch_keyboard_interrupt {
236 register_polars_keyboard_interrupt_hook();
237 }
238
239 use polars_core::datatypes::extension::UnknownExtensionTypeBehavior;
240 let behavior = match std::env::var("POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR").as_deref() {
241 Ok("load_as_storage") => UnknownExtensionTypeBehavior::LoadAsStorage,
242 Ok("load_as_extension") => UnknownExtensionTypeBehavior::LoadAsGeneric,
243 Ok("") | Err(_) => UnknownExtensionTypeBehavior::WarnAndLoadAsStorage,
244 _ => {
245 polars_warn!("Invalid value for 'POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR' environment variable. Expected one of 'load_as_storage' or 'load_as_extension'.");
246 UnknownExtensionTypeBehavior::WarnAndLoadAsStorage
247 },
248 };
249 polars_core::datatypes::extension::set_unknown_extension_type_behavior(behavior);
250 });
251}