polars_python/
on_startup.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use polars::prelude::*;
6use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
7use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
8use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
9use polars_error::PolarsWarning;
10use polars_error::signals::register_polars_keyboard_interrupt_hook;
11use polars_ffi::version_0::SeriesExport;
12use polars_plan::plans::python_df_to_rust;
13use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
14use pyo3::prelude::*;
15use pyo3::{IntoPyObjectExt, intern};
16
17use crate::Wrap;
18use crate::dataframe::PyDataFrame;
19use crate::lazyframe::PyLazyFrame;
20use crate::map::lazy::call_lambda_with_series;
21use crate::prelude::ObjectValue;
22use crate::py_modules::{pl_df, pl_utils, polars, polars_rs};
23use crate::series::PySeries;
24
25fn python_function_caller_series(
26    s: &[Column],
27    output_dtype: Option<DataType>,
28    lambda: &PyObject,
29) -> PolarsResult<Column> {
30    Python::with_gil(|py| call_lambda_with_series(py, s, output_dtype, lambda))
31}
32
33fn python_function_caller_df(df: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame> {
34    Python::with_gil(|py| {
35        let pypolars = polars(py).bind(py);
36
37        // create a PySeries struct/object for Python
38        let pydf = PyDataFrame::new(df);
39        // Wrap this PySeries object in the python side Series wrapper
40        let mut python_df_wrapper = pypolars
41            .getattr("wrap_df")
42            .unwrap()
43            .call1((pydf.clone(),))
44            .unwrap();
45
46        if !python_df_wrapper
47            .getattr("_df")
48            .unwrap()
49            .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
50            .unwrap()
51        {
52            let pldf = pl_df(py).bind(py);
53            let width = pydf.width();
54            // Don't resize the Vec to avoid calling SeriesExport's Drop impl
55            // The import takes ownership and is responsible for dropping
56            let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
57            unsafe {
58                pydf._export_columns(columns.as_mut_ptr() as usize);
59            }
60            // Wrap this PyDataFrame object in the python side DataFrame wrapper
61            python_df_wrapper = pldf
62                .getattr("_import_columns")
63                .unwrap()
64                .call1((columns.as_mut_ptr() as usize, width))
65                .unwrap();
66        }
67        // call the lambda and get a python side df wrapper
68        let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
69
70        // unpack the wrapper in a PyDataFrame
71        let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
72            let pytype = result_df_wrapper.bind(py).get_type();
73            PolarsError::ComputeError(
74                format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
75                    .into(),
76            )
77        })?;
78        // Downcast to Rust
79        match py_pydf.extract::<PyDataFrame>(py) {
80            Ok(pydf) => Ok(pydf.df.into_inner()),
81            Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
82        }
83    })
84}
85
86fn warning_function(msg: &str, warning: PolarsWarning) {
87    Python::with_gil(|py| {
88        let warn_fn = pl_utils(py)
89            .bind(py)
90            .getattr(intern!(py, "_polars_warn"))
91            .unwrap();
92
93        if let Err(e) = warn_fn.call1((msg, Wrap(warning).into_pyobject(py).unwrap())) {
94            eprintln!("{e}")
95        }
96    });
97}
98
99static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
100
101/// # Safety
102/// Caller must ensure that no other threads read the objects set by this registration.
103pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool) {
104    // TODO: should we throw an error if we try to initialize while already initialized?
105    POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
106        set_polars_allow_extension(true);
107
108        // Stack frames can get really large in debug mode.
109        #[cfg(debug_assertions)]
110        {
111            recursive::set_minimum_stack_size(1024 * 1024);
112            recursive::set_stack_allocation_size(1024 * 1024 * 16);
113        }
114
115        // Register object type builder.
116        let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
117            Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
118                as Box<dyn AnonymousObjectBuilder>
119        });
120
121        let object_converter = Arc::new(|av: AnyValue| {
122            let object = Python::with_gil(|py| ObjectValue {
123                inner: Wrap(av).into_py_any(py).unwrap(),
124            });
125            Box::new(object) as Box<dyn Any>
126        });
127        let pyobject_converter = Arc::new(|av: AnyValue| {
128            let object = Python::with_gil(|py| Wrap(av).into_py_any(py).unwrap());
129            Box::new(object) as Box<dyn Any>
130        });
131
132        polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
133            from_py: FromPythonConvertRegistry {
134                partition_target_cb_result: Arc::new(|py_f| {
135                    Python::with_gil(|py| {
136                        Ok(Box::new(
137                            py_f.extract::<Wrap<polars_plan::dsl::PartitionTargetCallbackResult>>(
138                                py,
139                            )?
140                            .0,
141                        ) as _)
142                    })
143                }),
144                series: Arc::new(|py_f| {
145                    Python::with_gil(|py| {
146                        Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
147                    })
148                }),
149                df: Arc::new(|py_f| {
150                    Python::with_gil(|py| {
151                        Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
152                    })
153                }),
154                dsl_plan: Arc::new(|py_f| {
155                    Python::with_gil(|py| {
156                        Ok(Box::new(
157                            py_f.extract::<PyLazyFrame>(py)?
158                                .ldf
159                                .into_inner()
160                                .logical_plan,
161                        ) as _)
162                    })
163                }),
164                schema: Arc::new(|py_f| {
165                    Python::with_gil(|py| {
166                        Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
167                    })
168                }),
169            },
170            to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
171                df: Arc::new(|df| {
172                    Python::with_gil(|py| PyDataFrame::new(*df.downcast().unwrap()).into_py_any(py))
173                }),
174                series: Arc::new(|series| {
175                    Python::with_gil(|py| {
176                        PySeries::new(*series.downcast().unwrap()).into_py_any(py)
177                    })
178                }),
179                dsl_plan: Arc::new(|dsl_plan| {
180                    Python::with_gil(|py| {
181                        PyLazyFrame::from(LazyFrame::from(
182                            *dsl_plan.downcast::<polars_plan::dsl::DslPlan>().unwrap(),
183                        ))
184                        .into_py_any(py)
185                    })
186                }),
187                schema: Arc::new(|schema| {
188                    Python::with_gil(|py| {
189                        Wrap(*schema.downcast::<polars_core::schema::Schema>().unwrap())
190                            .into_py_any(py)
191                    })
192                }),
193            },
194        });
195
196        let object_size = size_of::<ObjectValue>();
197        let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
198        registry::register_object_builder(
199            object_builder,
200            object_converter,
201            pyobject_converter,
202            physical_dtype,
203        );
204
205        use crate::dataset::dataset_provider_funcs;
206
207        polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
208            name: dataset_provider_funcs::name,
209            schema: dataset_provider_funcs::schema,
210            to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
211        });
212
213        // Register SERIES UDF.
214        python_dsl::CALL_COLUMNS_UDF_PYTHON = Some(python_function_caller_series);
215        // Register DATAFRAME UDF.
216        python_dsl::CALL_DF_UDF_PYTHON = Some(python_function_caller_df);
217        // Register warning function for `polars_warn!`.
218        polars_error::set_warning_function(warning_function);
219
220        if catch_keyboard_interrupt {
221            register_polars_keyboard_interrupt_hook();
222        }
223    });
224}