Skip to main content

polars_python/
on_startup.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::any::Any;
3use std::sync::OnceLock;
4
5use arrow::array::Array;
6use polars::chunked_array::object::ObjectArray;
7use polars::prelude::file_provider::FileProviderReturn;
8use polars::prelude::*;
9use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
10use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
11use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
12use polars_error::PolarsWarning;
13use polars_error::signals::register_polars_keyboard_interrupt_hook;
14use polars_ffi::version_0::SeriesExport;
15use polars_plan::plans::python_df_to_rust;
16use polars_utils::python_convert_registry::{FromPythonConvertRegistry, PythonConvertRegistry};
17use pyo3::IntoPyObjectExt;
18use pyo3::prelude::*;
19
20use crate::Wrap;
21use crate::dataframe::PyDataFrame;
22use crate::lazyframe::PyLazyFrame;
23use crate::map::lazy::call_lambda_with_series;
24use crate::prelude::ObjectValue;
25use crate::py_modules::{pl_df, polars, polars_rs};
26use crate::series::PySeries;
27
28fn python_function_caller_series(
29    s: &[Column],
30    output_dtype: Option<DataType>,
31    lambda: &Py<PyAny>,
32) -> PolarsResult<Column> {
33    Python::attach(|py| call_lambda_with_series(py, s, output_dtype, lambda))
34}
35
36fn python_function_caller_df(df: DataFrame, lambda: &Py<PyAny>) -> PolarsResult<DataFrame> {
37    Python::attach(|py| {
38        let pypolars = polars(py).bind(py);
39
40        // create a PySeries struct/object for Python
41        let pydf = PyDataFrame::new(df);
42        // Wrap this PySeries object in the python side Series wrapper
43        let mut python_df_wrapper = pypolars
44            .getattr("wrap_df")
45            .unwrap()
46            .call1((pydf.clone(),))
47            .unwrap();
48
49        if !python_df_wrapper
50            .getattr("_df")
51            .unwrap()
52            .is_instance(polars_rs(py).getattr(py, "PyDataFrame").unwrap().bind(py))
53            .unwrap()
54        {
55            let pldf = pl_df(py).bind(py);
56            let width = pydf.width();
57            // Don't resize the Vec to avoid calling SeriesExport's Drop impl
58            // The import takes ownership and is responsible for dropping
59            let mut columns: Vec<SeriesExport> = Vec::with_capacity(width);
60            unsafe {
61                pydf._export_columns(columns.as_mut_ptr() as usize);
62            }
63            // Wrap this PyDataFrame object in the python side DataFrame wrapper
64            python_df_wrapper = pldf
65                .getattr("_import_columns")
66                .unwrap()
67                .call1((columns.as_mut_ptr() as usize, width))
68                .unwrap();
69        }
70        // call the lambda and get a python side df wrapper
71        let result_df_wrapper = lambda.call1(py, (python_df_wrapper,))?;
72
73        // unpack the wrapper in a PyDataFrame
74        let py_pydf = result_df_wrapper.getattr(py, "_df").map_err(|_| {
75            let pytype = result_df_wrapper.bind(py).get_type();
76            PolarsError::ComputeError(
77                format!("Expected 'LazyFrame.map' to return a 'DataFrame', got a '{pytype}'",)
78                    .into(),
79            )
80        })?;
81        // Downcast to Rust
82        match py_pydf.extract::<PyDataFrame>(py) {
83            Ok(pydf) => Ok(pydf.df.into_inner()),
84            Err(_) => python_df_to_rust(py, result_df_wrapper.into_bound(py)),
85        }
86    })
87}
88
89fn warning_function(msg: &str, warning: PolarsWarning) {
90    Python::attach(|py| {
91        let Some(warn_fn) = WARN_FUNCTION.get() else {
92            eprintln!("{msg}");
93            return;
94        };
95
96        if let Err(e) = warn_fn
97            .bind(py)
98            .call1((msg, Wrap(warning).into_pyobject(py).unwrap()))
99        {
100            eprintln!("{e}")
101        }
102    });
103}
104
105static POLARS_REGISTRY_INIT_LOCK: OnceLock<()> = OnceLock::new();
106static WARN_FUNCTION: OnceLock<Py<PyAny>> = OnceLock::new();
107
108/// # Safety
109/// Caller must ensure that no other threads read the objects set by this registration.
110pub unsafe fn register_startup_deps(catch_keyboard_interrupt: bool, warn_function: Py<PyAny>) {
111    // TODO: should we throw an error if we try to initialize while already initialized?
112    POLARS_REGISTRY_INIT_LOCK.get_or_init(|| {
113        WARN_FUNCTION.set(warn_function).unwrap();
114        set_polars_allow_extension(true);
115
116        // Stack frames can get really large in debug mode.
117        #[cfg(debug_assertions)]
118        {
119            recursive::set_minimum_stack_size(1024 * 1024);
120            recursive::set_stack_allocation_size(1024 * 1024 * 16);
121        }
122
123        #[cfg(feature = "backtrace_filter")]
124        {
125            use std::path::Path;
126            use color_backtrace::{BacktracePrinter, default_output_stream, default_is_dependency_frame, Frame, ColorScheme};
127            use color_backtrace::termcolor::{ColorSpec, Color};
128
129            let polars_base_path = || {
130                let on_startup = Path::new(file!()).canonicalize().ok()?;
131                let src = on_startup.parent()?;
132                let polars_python = src.parent()?;
133                let crates = polars_python.parent()?;
134                let root = crates.parent()?;
135                Some(root.to_path_buf())
136            };
137
138            let mut btp = BacktracePrinter::default();
139            if let Some(bp) = polars_base_path() {
140                btp = btp.dependency_predicate(Box::new(move |frame: &Frame| -> bool {
141                    if let Some(file) = frame.filename.as_ref().and_then(|f| f.canonicalize().ok()) {
142                        !file.starts_with(&bp)
143                    } else {
144                        default_is_dependency_frame(frame)
145                    }
146                }));
147            }
148
149            let mut color_scheme = ColorScheme::classic();
150            color_scheme.dependency_code = ColorSpec::new();
151            color_scheme.dependency_code.set_dimmed(true);
152            color_scheme.dependency_code = color_scheme.dependency_code_hash.clone();
153            color_scheme.crate_code = ColorSpec::new();
154            color_scheme.crate_code.set_fg(Some(Color::Blue));
155            color_scheme.crate_code_hash = color_scheme.crate_code.clone();
156
157            btp
158                .color_scheme(color_scheme)
159                .install(default_output_stream());
160        }
161
162        // Register object type builder.
163        let object_builder = Box::new(|name: PlSmallStr, capacity: usize| {
164            Box::new(ObjectChunkedBuilder::<ObjectValue>::new(name, capacity))
165                as Box<dyn AnonymousObjectBuilder>
166        });
167
168        let object_converter = Arc::new(|av: AnyValue| {
169            let object = Python::attach(|py| ObjectValue {
170                inner: Wrap(av).into_py_any(py).unwrap(),
171            });
172            Box::new(object) as Box<dyn Any>
173        });
174        let pyobject_converter = Arc::new(|av: AnyValue| {
175            let object = Python::attach(|py| Wrap(av).into_py_any(py).unwrap());
176            Box::new(object) as Box<dyn Any>
177        });
178        fn object_array_getter(arr: &dyn Array, idx: usize) -> Option<AnyValue<'_>> {
179            let arr = arr.as_any().downcast_ref::<ObjectArray<ObjectValue>>().unwrap();
180            arr.get(idx).map(|v| AnyValue::Object(v))
181        }
182        fn with_gil(f: &mut dyn FnMut()) {
183            Python::attach(|_| f())
184        }
185
186        polars_utils::python_convert_registry::register_converters(PythonConvertRegistry {
187            from_py: FromPythonConvertRegistry {
188                file_provider_result: Arc::new(|py_f| {
189                    Python::attach(|py| {
190                        Ok(Box::new(py_f.extract::<Wrap<FileProviderReturn>>(py)?.0) as _)
191                    })
192                }),
193                series: Arc::new(|py_f| {
194                    Python::attach(|py| {
195                        Ok(Box::new(py_f.extract::<PySeries>(py)?.series.into_inner()) as _)
196                    })
197                }),
198                df: Arc::new(|py_f| {
199                    Python::attach(|py| {
200                        Ok(Box::new(py_f.extract::<PyDataFrame>(py)?.df.into_inner()) as _)
201                    })
202                }),
203                dsl_plan: Arc::new(|py_f| {
204                    Python::attach(|py| {
205                        Ok(Box::new(
206                            py_f.extract::<PyLazyFrame>(py)?
207                                .ldf
208                                .into_inner()
209                                .logical_plan,
210                        ) as _)
211                    })
212                }),
213                schema: Arc::new(|py_f| {
214                    Python::attach(|py| {
215                        Ok(Box::new(py_f.extract::<Wrap<polars_core::schema::Schema>>(py)?.0) as _)
216                    })
217                }),
218            },
219            to_py: polars_utils::python_convert_registry::ToPythonConvertRegistry {
220                df: Arc::new(|df| {
221                    Python::attach(|py| {
222                        PyDataFrame::new(df.downcast_ref::<DataFrame>().unwrap().clone())
223                            .into_py_any(py)
224                    })
225                }),
226                series: Arc::new(|series| {
227                    Python::attach(|py| {
228                        PySeries::new(series.downcast_ref::<Series>().unwrap().clone())
229                            .into_py_any(py)
230                    })
231                }),
232                dsl_plan: Arc::new(|dsl_plan| {
233                    Python::attach(|py| {
234                        PyLazyFrame::from(LazyFrame::from(
235                            dsl_plan
236                                .downcast_ref::<polars_plan::dsl::DslPlan>()
237                                .unwrap()
238                                .clone(),
239                        ))
240                        .into_py_any(py)
241                    })
242                }),
243                schema: Arc::new(|schema| {
244                    Python::attach(|py| {
245                        Wrap(
246                            schema
247                                .downcast_ref::<polars_core::schema::Schema>()
248                                .unwrap()
249                                .clone(),
250                        )
251                        .into_py_any(py)
252                    })
253                }),
254            },
255        });
256
257        let object_size = size_of::<ObjectValue>();
258        let physical_dtype = ArrowDataType::FixedSizeBinary(object_size);
259        registry::register_object_builder(
260            object_builder,
261            object_converter,
262            pyobject_converter,
263            physical_dtype,
264            Arc::new(object_array_getter),
265            Arc::new(with_gil)
266        );
267
268        use crate::dataset::dataset_provider_funcs;
269
270        polars_plan::dsl::DATASET_PROVIDER_VTABLE.get_or_init(|| PythonDatasetProviderVTable {
271            name: dataset_provider_funcs::name,
272            schema: dataset_provider_funcs::schema,
273            to_dataset_scan: dataset_provider_funcs::to_dataset_scan,
274        });
275
276        use crate::delta::dv_provider_funcs;
277
278        polars_plan::dsl::deletion::DELTA_DV_PROVIDER_VTABLE.get_or_init(|| {
279            polars_plan::dsl::deletion::DeltaDeletionVectorProviderVTable {
280                call: dv_provider_funcs::call,
281            }
282        });
283
284        // Register SERIES UDF.
285        python_dsl::CALL_PYTHON_COLUMNS_UDF.set(python_function_caller_series).unwrap();
286        // Register DATAFRAME UDF.
287        python_dsl::CALL_PYTHON_DF_UDF.set(python_function_caller_df).unwrap();
288        // Register warning function for `polars_warn!`.
289        polars_error::set_warning_function(warning_function);
290
291        if catch_keyboard_interrupt {
292            register_polars_keyboard_interrupt_hook();
293        }
294
295        use polars_core::datatypes::extension::UnknownExtensionTypeBehavior;
296        let behavior = match std::env::var("POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR").as_deref() {
297            Ok("load_as_storage") => UnknownExtensionTypeBehavior::LoadAsStorage,
298            Ok("load_as_extension") => UnknownExtensionTypeBehavior::LoadAsGeneric,
299            Ok("") | Err(_) => UnknownExtensionTypeBehavior::WarnAndLoadAsStorage,
300            _ => {
301                polars_warn!("Invalid value for 'POLARS_UNKNOWN_EXTENSION_TYPE_BEHAVIOR' environment variable. Expected one of 'load_as_storage' or 'load_as_extension'.");
302                UnknownExtensionTypeBehavior::WarnAndLoadAsStorage
303            },
304        };
305        polars_core::datatypes::extension::set_unknown_extension_type_behavior(behavior);
306    });
307}