#![allow(unsafe_op_in_unsafe_fn)]
use std::ffi::{c_int, c_void};
use ndarray::{Dim, Dimension};
use numpy::npyffi::PyArrayObject;
use numpy::{Element, PY_ARRAY_API, PyArrayDescr, PyArrayDescrMethods, ToNpyDims, npyffi};
use polars_core::prelude::*;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
pub(super) unsafe fn create_borrowed_np_array<I>(
py: Python<'_>,
dtype: Bound<PyArrayDescr>,
mut shape: Dim<I>,
flags: c_int,
data: *mut c_void,
owner: Py<PyAny>,
) -> Py<PyAny>
where
Dim<I>: Dimension + ToNpyDims,
{
let array = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
dtype.into_dtype_ptr(),
shape.ndim_cint(),
shape.as_dims_ptr(),
std::ptr::null_mut(),
data,
flags,
std::ptr::null_mut(),
);
let owner_ptr = owner.as_ptr();
std::mem::forget(owner);
PY_ARRAY_API.PyArray_SetBaseObject(py, array as *mut PyArrayObject, owner_ptr);
Py::from_owned_ptr(py, array)
}
/// Returns whether a column of `dtype` can be exposed to numpy as a zero-copy view.
///
/// Primitive numeric types, `Datetime` and `Duration` qualify. A fixed-size
/// `Array` qualifies when its inner type does (checked recursively). Every
/// other dtype does not.
pub(super) fn dtype_supports_view(dtype: &DataType) -> bool {
    if dtype.is_primitive_numeric() {
        return true;
    }
    match dtype {
        DataType::Array(inner, _) => dtype_supports_view(inner.as_ref()),
        DataType::Datetime(..) | DataType::Duration(..) => true,
        _ => false,
    }
}
/// Returns whether `s` contains a null at any nesting level.
///
/// The series' own null count is checked first. For `Array` series, the inner
/// values are then checked recursively.
pub(super) fn series_contains_null(s: &Series) -> bool {
    let has_top_level_null = s.null_count() > 0;
    has_top_level_null
        || s
            .array()
            .is_ok_and(|ca| series_contains_null(&ca.get_inner()))
}
/// Reshapes `arr` so that its leading axis is split into `(height, width)`.
///
/// The new shape is `(height, width, *shape[1..])`. For example, a 1-D array
/// becomes `(height, width)`, and a `(height * width, k)` array becomes
/// `(height, width, k)`.
///
/// # Errors
///
/// Returns the Python error raised while reading `arr.shape`, building the
/// shape tuple, or calling `arr.reshape`. For example, numpy rejects a shape
/// whose element count does not match.
pub(super) fn reshape_numpy_array(
    py: Python<'_>,
    arr: Py<PyAny>,
    height: usize,
    width: usize,
) -> PyResult<Py<PyAny>> {
    let shape = arr
        .getattr(py, intern!(py, "shape"))?
        .extract::<Vec<usize>>(py)?;

    // Checked slicing: a 0-d array has an empty shape, where `&shape[1..]`
    // would panic instead of letting numpy report the reshape error.
    let trailing: &[usize] = shape.get(1..).unwrap_or(&[]);
    let mut new_shape = Vec::with_capacity(2 + trailing.len());
    new_shape.push(height);
    new_shape.push(width);
    new_shape.extend_from_slice(trailing);

    let new_shape = PyTuple::new(py, new_shape)?;
    arr.call_method1(py, intern!(py, "reshape"), new_shape)
}
/// Maps a polars `Datetime`/`Duration` dtype to the numpy datetime/timedelta
/// descriptor with the same time unit.
///
/// The time zone of a `Datetime` is ignored.
///
/// # Panics
///
/// Panics if `dtype` is neither `Datetime` nor `Duration`.
pub(super) fn polars_dtype_to_np_temporal_dtype<'py>(
    py: Python<'py>,
    dtype: &DataType,
) -> Bound<'py, PyArrayDescr> {
    use numpy::datetime::{Datetime, Timedelta, units};

    match dtype {
        DataType::Datetime(unit, _) => match unit {
            TimeUnit::Nanoseconds => Datetime::<units::Nanoseconds>::get_dtype(py),
            TimeUnit::Microseconds => Datetime::<units::Microseconds>::get_dtype(py),
            TimeUnit::Milliseconds => Datetime::<units::Milliseconds>::get_dtype(py),
        },
        DataType::Duration(unit) => match unit {
            TimeUnit::Nanoseconds => Timedelta::<units::Nanoseconds>::get_dtype(py),
            TimeUnit::Microseconds => Timedelta::<units::Microseconds>::get_dtype(py),
            TimeUnit::Milliseconds => Timedelta::<units::Milliseconds>::get_dtype(py),
        },
        _ => panic!("only Datetime/Duration inputs supported, got {dtype}"),
    }
}