#![allow(unsafe_op_in_unsafe_fn)]
use std::ptr;
use ndarray::IntoDimension;
use numpy::npyffi::types::npy_intp;
use numpy::npyffi::{self, flags};
use numpy::{Element, PY_ARRAY_API, PyArray1, PyArrayDescrMethods, ToNpyDims};
use polars_core::prelude::*;
use polars_core::utils::arrow::types::NativeType;
use pyo3::prelude::*;
use pyo3::types::{PyNone, PyTuple};
use super::PySeries;
/// Allocates `size` elements initialised to `T::default()` and wraps them in a 1-D NumPy
/// array that views the same memory (no copy is made).
///
/// Returns the array together with the `Vec` that actually owns the memory. The array is
/// created from a raw data pointer with no base object, so nothing on the NumPy side keeps
/// `buf` alive.
///
/// # Safety
///
/// The returned `Vec` must outlive every access to the returned array, including any
/// references Python code may have taken to it. It must also not be reallocated (for
/// example by pushing past its capacity) while the array can still be accessed. Otherwise
/// the array's data pointer dangles.
///
/// # Panics
///
/// Panics if NumPy fails to create the array (pyo3's `from_owned_ptr` panics on a null
/// pointer), or if the created object is not exactly a `PyArray1<T>`.
unsafe fn aligned_array<T: Element + NativeType>(
    py: Python<'_>,
    size: usize,
) -> (Bound<'_, PyArray1<T>>, Vec<T>) {
    // Rust-owned backing storage; the NumPy array created below only points into it.
    let mut buf = vec![T::default(); size];
    let len = buf.len();
    let buffer_ptr = buf.as_mut_ptr();
    let mut dims = [len].into_dimension();
    // One element-sized step (in bytes) per index: a contiguous 1-D layout.
    let strides = [size_of::<T>() as npy_intp];
    let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
        py,
        // subtype: plain `numpy.ndarray`
        PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
        // descr: dtype for `T`; this reference is handed over to (stolen by) NumPy
        T::get_dtype(py).into_dtype_ptr(),
        // nd, dims
        dims.ndim_cint(),
        dims.as_dims_ptr(),
        // strides
        strides.as_ptr() as *mut _,
        // data: `buf`'s heap memory, used in place
        buffer_ptr as _,
        // flags: C-contiguous | aligned | writeable
        flags::NPY_ARRAY_OUT_ARRAY,
        // obj: no base object, so the array does not own or retain `buf`
        ptr::null_mut(),
    );
    (
        // Take ownership of the new reference returned by NumPy.
        Bound::from_owned_ptr(py, ptr)
            .cast_into_exact::<PyArray1<T>>()
            .unwrap(),
        buf,
    )
}
/// Returns the reference count reported for `pyarray`.
///
/// On 64-bit targets, a count of at least `2 << 60` is reported with that constant
/// subtracted. Smaller counts are returned unchanged.
///
/// NOTE(review): the offset presumably compensates for interpreter builds that bias reported
/// reference counts by a large constant; confirm which CPython versions need it.
fn get_refcnt<T>(pyarray: &Bound<'_, PyArray1<T>>) -> isize {
    let count = pyarray.get_refcnt();
    // `checked_sub` yields a non-negative value exactly when `count >= 2 << 60`.
    #[cfg(target_pointer_width = "64")]
    let count = match count.checked_sub(2 << 60) {
        Some(unbiased) if unbiased >= 0 => unbiased,
        _ => count,
    };
    count
}
/// Generates a `#[pymethods]` method on `PySeries` that applies a Python callable (used for
/// NumPy ufunc application) and returns a new series of polars type `$type`.
///
/// Macro arguments:
/// * `$name`: name of the generated method.
/// * `$type`: polars numeric type of the produced series (e.g. `Float64Type`).
/// * `$unsafe_from_ptr_method`: accepted so existing invocations keep compiling. It is not
///   used by the expansion.
///
/// The generated method takes `lambda` and `allocate_out`:
/// * `allocate_out == false`: calls `lambda(None)`, passes `(self.name(), result)` to the
///   Python-side `Series` constructor, and returns its inner `_s`.
/// * `allocate_out == true`: allocates `self.len()` default values via `aligned_array` and
///   calls `lambda(out)` with the NumPy view of that buffer. The buffer is then moved into a
///   new `ChunkedArray` that reuses this series' name and validity mask.
///
/// Memory safety: the NumPy array only borrows the Rust buffer. If, after the call, the
/// array still appears to be referenced from elsewhere, the buffer is leaked instead of
/// freed. This keeps an escaped array from pointing at freed memory.
///
/// # Errors
///
/// Propagates any Python exception raised by `lambda`, the `Series` constructor, or the
/// argument-tuple construction.
///
/// # Panics
///
/// Panics (after leaking the buffer) if `lambda` returned successfully but the output
/// array's reference count exceeds 3. That indicates `lambda` kept a reference to it.
macro_rules! impl_ufuncs {
    ($name:ident, $type:ident, $unsafe_from_ptr_method:ident) => {
        #[pymethods]
        impl PySeries {
            fn $name(&self, lambda: &Bound<PyAny>, allocate_out: bool) -> PyResult<PySeries> {
                Python::attach(|py| {
                    if !allocate_out {
                        // `lambda` produces its own output; build the series from it via
                        // the Python `Series` constructor.
                        let result = lambda.call1((PyNone::get(py),))?;
                        let series_factory = crate::py_modules::pl_series(py).bind(py);
                        return series_factory
                            .call((self.name(), result), None)?
                            .getattr("_s")?
                            .extract::<PySeries>()
                            .map_err(PyErr::from);
                    }

                    let size = self.len();
                    // SAFETY: `out_array` views `av`'s heap buffer. Below, `av` is only
                    // released when no foreign reference to `out_array` is detected.
                    // Otherwise it is leaked, or moved without reallocation into the
                    // resulting series.
                    let (out_array, av) =
                        unsafe { aligned_array::<<$type as PolarsNumericType>::Native>(py, size) };
                    debug_assert_eq!(get_refcnt(&out_array), 1);
                    let args = PyTuple::new(py, std::slice::from_ref(&out_array))?;
                    debug_assert_eq!(get_refcnt(&out_array), 2);

                    match lambda.call1(args) {
                        Ok(_) => {
                            // Tolerated holders: our `out_array` handle, the returned object
                            // (typically the `out=` array itself), and presumably the
                            // argument tuple. More means `lambda` stored the array somewhere.
                            if get_refcnt(&out_array) > 3 {
                                // Unwinding would drop `av`; leak it first so the escaped
                                // array never points at freed memory.
                                std::mem::forget(av);
                                panic!("ufunc lambda retained a reference to its output array");
                            }
                            let series = self.series.read();
                            let validity = match series.chunks().as_slice() {
                                [] => None,
                                [chunk] => chunk.validity().cloned(),
                                // A single chunk's mask only covers that chunk; concatenate
                                // first so the mask spans all `size` values.
                                _ => series.rechunk().chunks()[0].validity().cloned(),
                            };
                            let ca = ChunkedArray::<$type>::from_vec_validity(
                                series.name().clone(),
                                av,
                                validity,
                            );
                            Ok(PySeries::new(ca.into_series()))
                        },
                        Err(err) => {
                            // The exception's traceback may keep `lambda`'s frame, and with
                            // it `out`, alive after we return. If anything besides our own
                            // handle still references the array, leak the buffer rather
                            // than free it.
                            // NOTE(review): assumes `call1` has already released the
                            // argument tuple at this point.
                            if get_refcnt(&out_array) > 1 {
                                std::mem::forget(av);
                            }
                            Err(err)
                        },
                    }
                })
            }
        }
    };
}
// `apply_ufunc_*` entry points, one per supported numeric output dtype.

// Floating-point outputs.
impl_ufuncs!(apply_ufunc_f32, Float32Type, unsafe_from_ptr_f32);
impl_ufuncs!(apply_ufunc_f64, Float64Type, unsafe_from_ptr_f64);

// Signed integer outputs.
impl_ufuncs!(apply_ufunc_i8, Int8Type, unsafe_from_ptr_i8);
impl_ufuncs!(apply_ufunc_i16, Int16Type, unsafe_from_ptr_i16);
impl_ufuncs!(apply_ufunc_i32, Int32Type, unsafe_from_ptr_i32);
impl_ufuncs!(apply_ufunc_i64, Int64Type, unsafe_from_ptr_i64);

// Unsigned integer outputs.
impl_ufuncs!(apply_ufunc_u8, UInt8Type, unsafe_from_ptr_u8);
impl_ufuncs!(apply_ufunc_u16, UInt16Type, unsafe_from_ptr_u16);
impl_ufuncs!(apply_ufunc_u32, UInt32Type, unsafe_from_ptr_u32);
impl_ufuncs!(apply_ufunc_u64, UInt64Type, unsafe_from_ptr_u64);