polars_python/interop/numpy/
utils.rs

1use std::ffi::{c_int, c_void};
2
3use ndarray::{Dim, Dimension};
4use numpy::npyffi::PyArrayObject;
5use numpy::{npyffi, Element, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, PY_ARRAY_API};
6use polars_core::prelude::*;
7use pyo3::intern;
8use pyo3::prelude::*;
9use pyo3::types::PyTuple;
10
11/// Create a NumPy ndarray view of the data.
12pub(super) unsafe fn create_borrowed_np_array<I>(
13    py: Python,
14    dtype: Bound<PyArrayDescr>,
15    mut shape: Dim<I>,
16    flags: c_int,
17    data: *mut c_void,
18    owner: PyObject,
19) -> PyObject
20where
21    Dim<I>: Dimension + ToNpyDims,
22{
23    // See: https://numpy.org/doc/stable/reference/c-api/array.html
24    let array = PY_ARRAY_API.PyArray_NewFromDescr(
25        py,
26        PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
27        dtype.into_dtype_ptr(),
28        shape.ndim_cint(),
29        shape.as_dims_ptr(),
30        // We don't provide strides, but provide flags that tell c/f-order
31        std::ptr::null_mut(),
32        data,
33        flags,
34        std::ptr::null_mut(),
35    );
36
37    // This keeps the memory alive
38    let owner_ptr = owner.as_ptr();
39    // SetBaseObject steals a reference
40    // so we can forget.
41    std::mem::forget(owner);
42    PY_ARRAY_API.PyArray_SetBaseObject(py, array as *mut PyArrayObject, owner_ptr);
43
44    Py::from_owned_ptr(py, array)
45}
46
47/// Returns whether the data type supports creating a NumPy view.
48pub(super) fn dtype_supports_view(dtype: &DataType) -> bool {
49    match dtype {
50        dt if dt.is_primitive_numeric() => true,
51        DataType::Datetime(_, _) | DataType::Duration(_) => true,
52        DataType::Array(inner, _) => dtype_supports_view(inner.as_ref()),
53        _ => false,
54    }
55}
56
57/// Returns whether the Series contains nulls at any level of nesting.
58///
59/// Of the nested types, only Array types are handled since only those are relevant for NumPy views.
60pub(super) fn series_contains_null(s: &Series) -> bool {
61    if s.null_count() > 0 {
62        true
63    } else if let Ok(ca) = s.array() {
64        let s_inner = ca.get_inner();
65        series_contains_null(&s_inner)
66    } else {
67        false
68    }
69}
70
71/// Reshape the first dimension of a NumPy array to the given height and width.
72pub(super) fn reshape_numpy_array(
73    py: Python,
74    arr: PyObject,
75    height: usize,
76    width: usize,
77) -> PyResult<PyObject> {
78    let shape = arr
79        .getattr(py, intern!(py, "shape"))?
80        .extract::<Vec<usize>>(py)?;
81
82    if shape.len() == 1 {
83        // In this case, we can avoid allocating a Vec.
84        let new_shape = (height, width);
85        arr.call_method1(py, intern!(py, "reshape"), new_shape)
86    } else {
87        let mut new_shape_vec = vec![height, width];
88        for v in &shape[1..] {
89            new_shape_vec.push(*v)
90        }
91        let new_shape = PyTuple::new(py, new_shape_vec)?;
92        arr.call_method1(py, intern!(py, "reshape"), new_shape)
93    }
94}
95
96/// Get the NumPy temporal data type associated with the given Polars [`DataType`].
97pub(super) fn polars_dtype_to_np_temporal_dtype<'a>(
98    py: Python<'a>,
99    dtype: &DataType,
100) -> Bound<'a, PyArrayDescr> {
101    use numpy::datetime::{units, Datetime, Timedelta};
102    match dtype {
103        DataType::Datetime(TimeUnit::Milliseconds, _) => {
104            Datetime::<units::Milliseconds>::get_dtype(py)
105        },
106        DataType::Datetime(TimeUnit::Microseconds, _) => {
107            Datetime::<units::Microseconds>::get_dtype(py)
108        },
109        DataType::Datetime(TimeUnit::Nanoseconds, _) => {
110            Datetime::<units::Nanoseconds>::get_dtype(py)
111        },
112        DataType::Duration(TimeUnit::Milliseconds) => {
113            Timedelta::<units::Milliseconds>::get_dtype(py)
114        },
115        DataType::Duration(TimeUnit::Microseconds) => {
116            Timedelta::<units::Microseconds>::get_dtype(py)
117        },
118        DataType::Duration(TimeUnit::Nanoseconds) => Timedelta::<units::Nanoseconds>::get_dtype(py),
119        _ => panic!("only Datetime/Duration inputs supported, got {}", dtype),
120    }
121}