polars_python/interop/numpy/
utils.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::ffi::{c_int, c_void};
3
4use ndarray::{Dim, Dimension};
5use numpy::npyffi::PyArrayObject;
6use numpy::{Element, PY_ARRAY_API, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, npyffi};
7use polars_core::prelude::*;
8use pyo3::intern;
9use pyo3::prelude::*;
10use pyo3::types::PyTuple;
11
12/// Create a NumPy ndarray view of the data.
13pub(super) unsafe fn create_borrowed_np_array<I>(
14    py: Python<'_>,
15    dtype: Bound<PyArrayDescr>,
16    mut shape: Dim<I>,
17    flags: c_int,
18    data: *mut c_void,
19    owner: PyObject,
20) -> PyObject
21where
22    Dim<I>: Dimension + ToNpyDims,
23{
24    // See: https://numpy.org/doc/stable/reference/c-api/array.html
25    let array = PY_ARRAY_API.PyArray_NewFromDescr(
26        py,
27        PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
28        dtype.into_dtype_ptr(),
29        shape.ndim_cint(),
30        shape.as_dims_ptr(),
31        // We don't provide strides, but provide flags that tell c/f-order
32        std::ptr::null_mut(),
33        data,
34        flags,
35        std::ptr::null_mut(),
36    );
37
38    // This keeps the memory alive
39    let owner_ptr = owner.as_ptr();
40    // SetBaseObject steals a reference
41    // so we can forget.
42    std::mem::forget(owner);
43    PY_ARRAY_API.PyArray_SetBaseObject(py, array as *mut PyArrayObject, owner_ptr);
44
45    Py::from_owned_ptr(py, array)
46}
47
48/// Returns whether the data type supports creating a NumPy view.
49pub(super) fn dtype_supports_view(dtype: &DataType) -> bool {
50    match dtype {
51        dt if dt.is_primitive_numeric() => true,
52        DataType::Datetime(_, _) | DataType::Duration(_) => true,
53        DataType::Array(inner, _) => dtype_supports_view(inner.as_ref()),
54        _ => false,
55    }
56}
57
58/// Returns whether the Series contains nulls at any level of nesting.
59///
60/// Of the nested types, only Array types are handled since only those are relevant for NumPy views.
61pub(super) fn series_contains_null(s: &Series) -> bool {
62    if s.null_count() > 0 {
63        true
64    } else if let Ok(ca) = s.array() {
65        let s_inner = ca.get_inner();
66        series_contains_null(&s_inner)
67    } else {
68        false
69    }
70}
71
72/// Reshape the first dimension of a NumPy array to the given height and width.
73pub(super) fn reshape_numpy_array(
74    py: Python<'_>,
75    arr: PyObject,
76    height: usize,
77    width: usize,
78) -> PyResult<PyObject> {
79    let shape = arr
80        .getattr(py, intern!(py, "shape"))?
81        .extract::<Vec<usize>>(py)?;
82
83    if shape.len() == 1 {
84        // In this case, we can avoid allocating a Vec.
85        let new_shape = (height, width);
86        arr.call_method1(py, intern!(py, "reshape"), new_shape)
87    } else {
88        let mut new_shape_vec = vec![height, width];
89        for v in &shape[1..] {
90            new_shape_vec.push(*v)
91        }
92        let new_shape = PyTuple::new(py, new_shape_vec)?;
93        arr.call_method1(py, intern!(py, "reshape"), new_shape)
94    }
95}
96
97/// Get the NumPy temporal data type associated with the given Polars [`DataType`].
98pub(super) fn polars_dtype_to_np_temporal_dtype<'py>(
99    py: Python<'py>,
100    dtype: &DataType,
101) -> Bound<'py, PyArrayDescr> {
102    use numpy::datetime::{Datetime, Timedelta, units};
103    match dtype {
104        DataType::Datetime(TimeUnit::Milliseconds, _) => {
105            Datetime::<units::Milliseconds>::get_dtype(py)
106        },
107        DataType::Datetime(TimeUnit::Microseconds, _) => {
108            Datetime::<units::Microseconds>::get_dtype(py)
109        },
110        DataType::Datetime(TimeUnit::Nanoseconds, _) => {
111            Datetime::<units::Nanoseconds>::get_dtype(py)
112        },
113        DataType::Duration(TimeUnit::Milliseconds) => {
114            Timedelta::<units::Milliseconds>::get_dtype(py)
115        },
116        DataType::Duration(TimeUnit::Microseconds) => {
117            Timedelta::<units::Microseconds>::get_dtype(py)
118        },
119        DataType::Duration(TimeUnit::Nanoseconds) => Timedelta::<units::Nanoseconds>::get_dtype(py),
120        _ => panic!("only Datetime/Duration inputs supported, got {}", dtype),
121    }
122}