Skip to main content

polars_python/
utils.rs

1use std::panic::AssertUnwindSafe;
2
3use polars::frame::DataFrame;
4use polars::series::IntoSeries;
5use polars_error::PolarsResult;
6use polars_error::signals::{KeyboardInterrupt, catch_keyboard_interrupt};
7use pyo3::exceptions::PyKeyboardInterrupt;
8use pyo3::marker::Ungil;
9use pyo3::types::PyAnyMethods;
10use pyo3::{PyErr, PyResult, Python};
11
12use crate::dataframe::PyDataFrame;
13use crate::error::PyPolarsErr;
14use crate::series::PySeries;
15use crate::timeout::{cancel_polars_timeout, is_timeout_enabled, schedule_polars_timeout};
16
/// Calls `$method` on the downcasted `ChunkedArray` for all possible publicly
/// exposed Polars dtypes.
///
/// Expands to a `match` on `$self.dtype()`. Each arm downcasts `$self` with the
/// accessor belonging to that dtype (`bool()`, `u8()`, `str()`, ...), unwraps
/// the result and invokes `$method($args, ...)` on it.
///
/// # Invocation forms
///
/// ```ignore
/// apply_all_polars_dtypes!(series, method)          // no arguments
/// apply_all_polars_dtypes!(series, method,)         // no arguments, trailing comma
/// apply_all_polars_dtypes!(series, method, a, b)    // with arguments
/// apply_all_polars_dtypes!(series, method, a, b,)   // with arguments, trailing comma
/// ```
///
/// # Notes
///
/// * `$self` is expanded twice (once for `dtype()`, once in the selected arm),
///   so it should be a cheap, side-effect-free expression.
/// * The expansion uses the unqualified names `DataType` and
///   `CategoricalPhysical` (plus `ObjectChunked` and `ObjectValue` when the
///   `object` feature is enabled); they must be in scope at the call site.
///
/// # Panics
///
/// Panics for `DataType::BinaryOffset` and `DataType::Unknown(_)`, and if any
/// of the downcasts (or `cat_physical()`) fails.
#[macro_export]
macro_rules! apply_all_polars_dtypes {
    // The arguments are matched as `, arg` pairs followed by an optional
    // trailing comma, so the comma after `$method` is only needed when
    // arguments follow (but is still accepted on its own).
    ($self:expr, $method:ident $(, $args:expr)* $(,)?) => {
        match $self.dtype() {
            DataType::Boolean => $self.bool().unwrap().$method($($args),*),
            DataType::UInt8 => $self.u8().unwrap().$method($($args),*),
            DataType::UInt16 => $self.u16().unwrap().$method($($args),*),
            DataType::UInt32 => $self.u32().unwrap().$method($($args),*),
            DataType::UInt64 => $self.u64().unwrap().$method($($args),*),
            DataType::UInt128 => $self.u128().unwrap().$method($($args),*),
            DataType::Int8 => $self.i8().unwrap().$method($($args),*),
            DataType::Int16 => $self.i16().unwrap().$method($($args),*),
            DataType::Int32 => $self.i32().unwrap().$method($($args),*),
            DataType::Int64 => $self.i64().unwrap().$method($($args),*),
            DataType::Int128 => $self.i128().unwrap().$method($($args),*),
            DataType::Float16 => $self.f16().unwrap().$method($($args),*),
            DataType::Float32 => $self.f32().unwrap().$method($($args),*),
            DataType::Float64 => $self.f64().unwrap().$method($($args),*),
            DataType::String => $self.str().unwrap().$method($($args),*),
            DataType::Binary => $self.binary().unwrap().$method($($args),*),
            DataType::Decimal(_, _) => $self.decimal().unwrap().$method($($args),*),

            DataType::Date => $self.date().unwrap().$method($($args),*),
            DataType::Datetime(_, _) => $self.datetime().unwrap().$method($($args),*),
            DataType::Duration(_) => $self.duration().unwrap().$method($($args),*),
            DataType::Time => $self.time().unwrap().$method($($args),*),

            DataType::List(_) => $self.list().unwrap().$method($($args),*),
            DataType::Struct(_) => $self.struct_().unwrap().$method($($args),*),
            DataType::Array(_, _) => $self.array().unwrap().$method($($args),*),

            // Categorical and Enum dispatch on the physical type reported by
            // `cat_physical()`.
            dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => match dt.cat_physical().unwrap() {
                CategoricalPhysical::U8 => $self.cat8().unwrap().$method($($args),*),
                CategoricalPhysical::U16 => $self.cat16().unwrap().$method($($args),*),
                CategoricalPhysical::U32 => $self.cat32().unwrap().$method($($args),*),
            },

            #[cfg(feature = "object")]
            DataType::Object(_) => {
                $self
                .as_any()
                .downcast_ref::<ObjectChunked<ObjectValue>>()
                .unwrap()
                .$method($($args),*)
            },
            DataType::Extension(_, _) => $self.ext().unwrap().$method($($args),*),

            DataType::Null => $self.null().unwrap().$method($($args),*),

            dt @ (DataType::BinaryOffset | DataType::Unknown(_)) => panic!("dtype {:?} not supported", dt)
        }
    }
}
71
72/// Boilerplate for `|e| PyPolarsErr::from(e).into()`
73#[allow(unused)]
74pub(crate) fn to_py_err<E: Into<PyPolarsErr>>(e: E) -> PyErr {
75    e.into().into()
76}
77
78pub trait EnterPolarsExt {
79    /// Whenever you have a block of code in the public Python API that
80    /// (potentially) takes a long time, wrap it in enter_polars. This will
81    /// ensure we release the GIL and catch KeyboardInterrupts.
82    ///
83    /// This not only can increase performance and usability, it can avoid
84    /// deadlocks on the GIL for Python UDFs.
85    fn enter_polars<T, E, F>(self, f: F) -> PyResult<T>
86    where
87        F: Ungil + Send + FnOnce() -> Result<T, E>,
88        T: Ungil + Send,
89        E: Ungil + Send + Into<PyPolarsErr>;
90
91    /// Same as enter_polars, but wraps the result in PyResult::Ok, useful
92    /// shorthand for infallible functions.
93    #[inline(always)]
94    fn enter_polars_ok<T, F>(self, f: F) -> PyResult<T>
95    where
96        Self: Sized,
97        F: Ungil + Send + FnOnce() -> T,
98        T: Ungil + Send,
99    {
100        self.enter_polars(move || PyResult::Ok(f()))
101    }
102
103    /// Same as enter_polars, but expects a PolarsResult<DataFrame> as return
104    /// which is converted to a PyDataFrame.
105    #[inline(always)]
106    fn enter_polars_df<F>(self, f: F) -> PyResult<PyDataFrame>
107    where
108        Self: Sized,
109        F: Ungil + Send + FnOnce() -> PolarsResult<DataFrame>,
110    {
111        self.enter_polars(f).map(PyDataFrame::new)
112    }
113
114    /// Same as enter_polars, but expects a PolarsResult<S> as return which
115    /// is converted to a PySeries through S: IntoSeries.
116    #[inline(always)]
117    fn enter_polars_series<T, F>(self, f: F) -> PyResult<PySeries>
118    where
119        Self: Sized,
120        T: Ungil + Send + IntoSeries,
121        F: Ungil + Send + FnOnce() -> PolarsResult<T>,
122    {
123        self.enter_polars(f).map(|s| PySeries::new(s.into_series()))
124    }
125}
126
127fn get_traceback(py: Python<'_>) -> PyResult<String> {
128    let tb = py.import(pyo3::intern!(py, "traceback"))?;
129    let format_stack = tb.getattr("format_stack")?;
130    let lines: Vec<String> = format_stack.call0()?.extract()?;
131    Ok(lines.join("\n"))
132}
133
134impl EnterPolarsExt for Python<'_> {
135    fn enter_polars<T, E, F>(self, f: F) -> PyResult<T>
136    where
137        F: Ungil + Send + FnOnce() -> Result<T, E>,
138        T: Ungil + Send,
139        E: Ungil + Send + Into<PyPolarsErr>,
140    {
141        let timeout = if is_timeout_enabled() {
142            std::hint::cold_path();
143            schedule_polars_timeout(get_traceback(self).ok())
144        } else {
145            None
146        };
147        let ret = self.detach(|| catch_keyboard_interrupt(AssertUnwindSafe(f)));
148        cancel_polars_timeout(timeout);
149        match ret {
150            Ok(Ok(ret)) => Ok(ret),
151            Ok(Err(err)) => Err(PyErr::from(err.into())),
152            Err(KeyboardInterrupt) => Err(PyKeyboardInterrupt::new_err("")),
153        }
154    }
155}