use std::mem::size_of;
use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
use std::ptr;
#[cfg(feature = "half")]
use half::{bf16, f16};
use num_traits::{Bounded, Zero};
#[cfg(feature = "half")]
use pyo3::sync::PyOnceLock;
use pyo3::{
conversion::IntoPyObject,
exceptions::{PyIndexError, PyValueError},
ffi::{self, PyTuple_Size},
pyobject_native_type_named,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
Borrowed, Bound, Py, PyAny, PyResult, PyTypeInfo, Python,
};
use crate::npyffi::{
NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
};
pub use num_complex::{Complex32, Complex64};
/// A NumPy data type descriptor (an instance of `numpy.dtype`).
///
/// This is a `#[repr(transparent)]` newtype over [`PyAny`]; the wrapped object pointer is
/// reinterpreted as a `*mut PyArray_Descr` by [`PyArrayDescrMethods::as_dtype_ptr`].
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);
// Generates PyO3's boilerplate impls for a named native Python type wrapper.
pyobject_native_type_named!(PyArrayDescr);
// SAFETY (review): `type_object_raw` must return the type object of the wrapped Python type.
// It is fetched from NumPy's C-API table (`PyArrayDescr_Type`) through `PY_ARRAY_API`.
unsafe impl PyTypeInfo for PyArrayDescr {
// Name and module that PyO3 reports for this type.
const NAME: &'static str = "PyArrayDescr";
const MODULE: Option<&'static str> = Some("numpy");
/// Returns NumPy's `PyArrayDescr_Type`, looked up through the NumPy C-API table.
#[inline]
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
}
}
/// Returns the NumPy data type descriptor that corresponds to the element type `T`.
///
/// This is a free-function shorthand for [`Element::get_dtype`]
/// (equivalently [`PyArrayDescr::of`]).
#[inline]
pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
    <T as Element>::get_dtype(py)
}
impl PyArrayDescr {
    /// Converts `ob` into a data type descriptor, similar to `numpy.dtype(ob)`.
    ///
    /// Any object understood by NumPy's `PyArray_DescrConverter2` is accepted, e.g. a type
    /// string such as `"float64"`, a list of `(name, format)` pairs, or an existing dtype.
    ///
    /// # Errors
    ///
    /// - Propagates the error from converting `ob` into a Python object.
    /// - Returns the Python exception raised by NumPy if the object cannot be converted.
    /// - Returns a `ValueError` for `None`, which is not mapped to NumPy's default type.
    #[inline]
    pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
    where
        T: IntoPyObject<'py>,
    {
        // Non-generic inner function so the FFI logic is compiled once rather than per `T`.
        fn inner<'py>(
            py: Python<'py>,
            obj: Borrowed<'_, 'py, PyAny>,
        ) -> PyResult<Bound<'py, PyArrayDescr>> {
            let mut descr: *mut PyArray_Descr = ptr::null_mut();
            // SAFETY: `obj` is a valid borrowed object for the duration of the call and
            // `descr` is a valid out-pointer.
            let converted =
                unsafe { PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr) };
            if converted != 0 && descr.is_null() {
                // `PyArray_DescrConverter2` maps `None` to a NULL descriptor and reports
                // success without setting an exception. Fetching the (absent) exception
                // would otherwise yield an opaque `SystemError`.
                return Err(PyValueError::new_err(
                    "cannot convert `None` to a NumPy data type descriptor",
                ));
            }
            // SAFETY: on success `descr` is an owned reference to a `PyArray_Descr`; on
            // failure it is NULL and a Python exception is set, which is fetched here.
            unsafe {
                Bound::from_owned_ptr_or_err(py, descr.cast()).map(|any| any.cast_into_unchecked())
            }
        }

        inner(
            py,
            ob.into_pyobject(py)
                .map_err(Into::into)?
                .into_any()
                .as_borrowed(),
        )
    }

    /// Returns the descriptor of NumPy's `object` data type (`NPY_OBJECT`).
    #[inline]
    pub fn object(py: Python<'_>) -> Bound<'_, Self> {
        Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
    }

    /// Returns the descriptor for element type `T`; equivalent to [`dtype`].
    #[inline]
    pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
        T::get_dtype(py)
    }

    /// Returns NumPy's descriptor for a builtin type number via `PyArray_DescrFromType`.
    fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
        // SAFETY: `PyArray_DescrFromType` returns a new reference to a `PyArray_Descr`
        // (a NULL result makes `Bound::from_owned_ptr` panic rather than cause UB).
        unsafe {
            let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
            Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
        }
    }

    /// Creates a new descriptor object for a builtin type number via
    /// `PyArray_DescrNewFromType`.
    ///
    /// NOTE(review): unlike `from_npy_type`, this presumably yields a fresh object that
    /// callers may modify without affecting the shared builtin descriptor — confirm at
    /// call sites.
    pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
        // SAFETY: `PyArray_DescrNewFromType` returns a new reference to a `PyArray_Descr`.
        unsafe {
            let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
            Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
        }
    }
}
/// Methods available on `Bound<'py, PyArrayDescr>` (NumPy data type descriptors).
///
/// The trait extends the crate-private `sealed::Sealed` marker, so it cannot be implemented
/// outside this crate.
#[doc(alias = "PyArrayDescr")]
pub trait PyArrayDescrMethods<'py>: Sealed {
/// Returns the raw `PyArray_Descr` pointer without changing its reference count.
fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
/// Consumes `self` and returns the raw pointer; the owned reference passes to the caller.
fn into_dtype_ptr(self) -> *mut PyArray_Descr;
/// Returns `true` if both descriptors are the same object or NumPy's `PyArray_EquivTypes`
/// reports them as equivalent.
fn is_equiv_to(&self, other: &Self) -> bool;
/// Returns the scalar type object associated with this descriptor (e.g. `numpy.float64`).
fn typeobj(&self) -> Bound<'py, PyType>;
/// Returns the NumPy type number (`type_num`), e.g. `NPY_DOUBLE` for `f64`.
fn num(&self) -> c_int {
unsafe { &*self.as_dtype_ptr() }.type_num
}
/// Returns the size of a single element in bytes.
fn itemsize(&self) -> usize;
/// Returns the required alignment of an element in bytes.
fn alignment(&self) -> usize;
/// Returns the byte-order character, e.g. `b'='` (native) or `b'|'` (not applicable).
// Negative `c_char` values are clamped to 0 before the cast to `u8`.
fn byteorder(&self) -> u8 {
unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
}
/// Returns the type character code (the `type_` field), e.g. `b'd'` for `f64`.
fn char(&self) -> u8 {
unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
}
/// Returns the kind character, e.g. `b'f'` for floats or `b'V'` for void/record types.
fn kind(&self) -> u8 {
unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
}
/// Returns the descriptor's flag bits (e.g. `NPY_ITEM_HASOBJECT`, `NPY_ALIGNED_STRUCT`).
fn flags(&self) -> u64;
/// Returns the number of subarray dimensions, or 0 if this is not a subarray type.
fn ndim(&self) -> usize;
/// Returns the element descriptor of a subarray type, or `self` if this is not one.
fn base(&self) -> Bound<'py, PyArrayDescr>;
/// Returns the subarray shape, or an empty vector if this is not a subarray type.
fn shape(&self) -> Vec<usize>;
/// Returns `true` if the `NPY_ITEM_HASOBJECT` flag is set (elements contain Python objects).
fn has_object(&self) -> bool {
self.flags() & NPY_ITEM_HASOBJECT != 0
}
/// Returns `true` if the `NPY_ALIGNED_STRUCT` flag is set.
fn is_aligned_struct(&self) -> bool {
self.flags() & NPY_ALIGNED_STRUCT != 0
}
/// Returns `true` if this descriptor has subarray information.
fn has_subarray(&self) -> bool;
/// Returns `true` if this descriptor has named fields (a record type).
fn has_fields(&self) -> bool;
/// Returns whether elements use native byte order.
///
/// `b'='` means native (`Some(true)`) and `b'|'` means byte order does not apply (`None`).
/// Any other character is compared with NumPy's native byte-order character.
fn is_native_byteorder(&self) -> Option<bool> {
match self.byteorder() {
b'=' => Some(true),
b'|' => None,
byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
}
}
/// Returns the ordered field names of a record type, or `None` if there are no fields.
fn names(&self) -> Option<Vec<String>>;
/// Returns the descriptor and byte offset of the field called `name`.
///
/// # Errors
///
/// - `ValueError` if this descriptor has no fields.
/// - `IndexError` if no field is called `name`.
fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
}
// Private module: `Sealed` is not nameable outside this crate, which prevents external
// implementations of `PyArrayDescrMethods`.
mod sealed {
pub trait Sealed {}
}
use sealed::Sealed;
impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
self.as_ptr() as _
}
fn into_dtype_ptr(self) -> *mut PyArray_Descr {
self.into_ptr() as _
}
fn is_equiv_to(&self, other: &Self) -> bool {
let self_ptr = self.as_dtype_ptr();
let other_ptr = other.as_dtype_ptr();
unsafe {
self_ptr == other_ptr
|| PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
}
}
fn typeobj(&self) -> Bound<'py, PyType> {
let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
}
fn itemsize(&self) -> usize {
unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
}
fn alignment(&self) -> usize {
unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
}
fn flags(&self) -> u64 {
unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
}
fn ndim(&self) -> usize {
let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
match subarray {
None => 0,
Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
}
}
fn base(&self) -> Bound<'py, PyArrayDescr> {
let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
match subarray {
None => self.clone(),
Some(subarray) => unsafe {
Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).cast_into_unchecked()
},
}
}
fn shape(&self) -> Vec<usize> {
let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
match subarray {
None => Vec::new(),
Some(subarray) => {
let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
shape.extract().expect("Operation failed")
}
}
}
fn has_subarray(&self) -> bool {
unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
}
fn has_fields(&self) -> bool {
unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
}
fn names(&self) -> Option<Vec<String>> {
if !self.has_fields() {
return None;
}
let names = unsafe {
Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
};
names.extract().ok()
}
fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
if !self.has_fields() {
return Err(PyValueError::new_err(
"cannot get field information: type descriptor has no fields",
));
}
let dict = unsafe {
Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
};
let dict = unsafe { dict.cast_unchecked::<PyDict>() };
let tuple = dict
.get_item(name)?
.ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
.cast_into::<PyTuple>()
.expect("Operation failed");
let dtype = tuple
.get_item(0)
.expect("Operation failed")
.cast_into::<PyArrayDescr>()
.expect("Operation failed");
let offset = tuple
.get_item(1)
.expect("Operation failed")
.extract()
.expect("Operation failed");
Ok((dtype, offset))
}
}
// The only implementor of the sealed marker, and therefore of `PyArrayDescrMethods`.
impl Sealed for Bound<'_, PyArrayDescr> {}
/// Rust types that can be stored as elements of a NumPy array.
///
/// # Safety
///
/// NOTE(review): implementors presumably must guarantee that the descriptor returned by
/// [`Element::get_dtype`] matches the in-memory layout of `Self`, since array data is
/// accessed as `Self` — confirm against the array code in this crate.
pub unsafe trait Element: Sized + Send + Sync {
/// `true` if values can be duplicated by an ordinary copy.
///
/// The scalar impls in this module set it to `true`; `Py<PyAny>`, whose clones go through
/// `Py::clone_ref`, sets it to `false`.
const IS_COPY: bool;
/// Returns the NumPy data type descriptor for this element type.
fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
/// Clones one element; `py` is the PyO3 token, needed for Python-object elements.
fn clone_ref(&self, py: Python<'_>) -> Self;
/// Clones every element of `slc` into a new `Vec` using [`Element::clone_ref`].
#[inline]
fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
slc.iter().map(|elem| elem.clone_ref(py)).collect()
}
/// Builds an owned array by cloning every element of `view` using [`Element::clone_ref`].
#[inline]
fn array_from_view<D>(
py: Python<'_>,
view: ::ndarray::ArrayView<'_, Self, D>,
) -> ::ndarray::Array<Self, D>
where
D: ::ndarray::Dimension,
{
view.map(|elem| elem.clone_ref(py))
}
}
/// Selects the NumPy integer type whose C counterpart has the same size as `T`.
///
/// The C types `T0`, `T1`, `T2` are tried in that order, and the entry of `npy_types` at the
/// first matching position is returned. Both lists must therefore be given in the same order.
/// NOTE(review): the order presumably mirrors how NumPy's headers assign integer aliases —
/// confirm against `npy_common.h`.
///
/// # Panics
///
/// Panics if none of the candidate C types has the same size as `T`.
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
    let wanted = size_of::<T>();
    let candidate_sizes = [size_of::<T0>(), size_of::<T1>(), size_of::<T2>()];
    if let Some(slot) = candidate_sizes.iter().position(|&sz| sz == wanted) {
        return npy_types[slot];
    }
    panic!("Unable to match integer type descriptor: {npy_types:?}")
}
/// Maps a Rust integer type to the NumPy type number with the same signedness and width.
///
/// Signedness is detected by whether `T::min_value()` differs from zero. For 8- and 16-bit
/// widths the NumPy type is fixed. For 32- and 64-bit widths the C type of matching size is
/// looked up with `npy_int_type_lookup`.
///
/// # Panics
///
/// Panics if no candidate C type matches. It reaches `unreachable!` for widths other
/// than 8, 16, 32 or 64 bits.
fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
    let signed = T::min_value() != T::zero();
    let bits = size_of::<T>() * 8;

    if signed {
        match bits {
            8 => NPY_TYPES::NPY_BYTE,
            16 => NPY_TYPES::NPY_SHORT,
            32 => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
                NPY_TYPES::NPY_LONG,
                NPY_TYPES::NPY_INT,
                NPY_TYPES::NPY_SHORT,
            ]),
            64 => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
                NPY_TYPES::NPY_LONG,
                NPY_TYPES::NPY_LONGLONG,
                NPY_TYPES::NPY_INT,
            ]),
            _ => unreachable!(),
        }
    } else {
        match bits {
            8 => NPY_TYPES::NPY_UBYTE,
            16 => NPY_TYPES::NPY_USHORT,
            32 => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
                NPY_TYPES::NPY_ULONG,
                NPY_TYPES::NPY_UINT,
                NPY_TYPES::NPY_USHORT,
            ]),
            64 => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
                NPY_TYPES::NPY_ULONG,
                NPY_TYPES::NPY_ULONGLONG,
                NPY_TYPES::NPY_UINT,
            ]),
            _ => unreachable!(),
        }
    }
}
// Expands to `Element` method impls for a `Clone` type `$Self`. Elements, slices and array
// views are copied with plain `Clone` / `ToOwned` / `ArrayView::to_owned`, and the `Python`
// token is ignored, instead of going through per-element `clone_ref`.
macro_rules! clone_methods_impl {
($Self:ty) => {
// Single element: ordinary `Clone`.
#[inline]
fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $Self {
::std::clone::Clone::clone(self)
}
// Whole slice at once via `ToOwned`.
#[inline]
fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[$Self]) -> Vec<$Self> {
::std::borrow::ToOwned::to_owned(slc)
}
// Whole view at once via `ArrayView::to_owned`.
#[inline]
fn array_from_view<D>(
_py: ::pyo3::Python<'_>,
view: ::ndarray::ArrayView<'_, $Self, D>,
) -> ::ndarray::Array<$Self, D>
where
D: ::ndarray::Dimension,
{
::ndarray::ArrayView::to_owned(&view)
}
};
}
// Makes the macro usable by path from other modules of this crate.
pub(crate) use clone_methods_impl;
use pyo3::BoundObject;
// Implements `Element` for scalar types with `IS_COPY = true` and `Clone`-based copying.
// Forms:
// - `impl_element_scalar!(T => NPY_X, #[attr]...)`: fixed NumPy type `NPY_TYPES::NPY_X`,
//   with optional attributes (e.g. `#[doc]`) placed on the generated impl.
// - `impl_element_scalar!(T1, T2, ...)`: integer types; each impl's `get_dtype` calls
//   `npy_int_type::<T>()` to pick the NumPy type from the type's signedness and size.
// - `@impl: ...` is the internal arm shared by both forms.
macro_rules! impl_element_scalar {
(@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
$(#[$meta])*
unsafe impl Element for $ty {
const IS_COPY: bool = true;
fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
PyArrayDescr::from_npy_type(py, $npy_type)
}
clone_methods_impl!($ty);
}
};
($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
};
($($tys:ty),+) => {
$(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
};
}
// Booleans, fixed-width integers (NumPy type chosen by size/signedness) and floats.
impl_element_scalar!(bool => NPY_BOOL);
impl_element_scalar!(i8, i16, i32, i64);
impl_element_scalar!(u8, u16, u32, u64);
impl_element_scalar!(f32 => NPY_FLOAT);
impl_element_scalar!(f64 => NPY_DOUBLE);
// Half precision is only available with the optional `half` feature.
#[cfg(feature = "half")]
impl_element_scalar!(f16 => NPY_HALF);
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    /// Returns the `bfloat16` dtype. NumPy has no builtin `bfloat16`, so it is resolved by
    /// name the first time and then cached in a `PyOnceLock`.
    ///
    /// # Panics
    ///
    /// Panics if no installed package registers a `bfloat16` dtype with NumPy.
    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        static DTYPE: PyOnceLock<Py<PyArrayDescr>> = PyOnceLock::new();

        let cached: &Py<PyArrayDescr> = DTYPE.get_or_init(py, || {
            let descr = PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.");
            descr.unbind()
        });
        cached.bind(py).clone()
    }

    clone_methods_impl!(Self);
}
// Complex numbers map to NumPy's single/double complex types; the `#[doc]` attributes
// are attached to the generated impls.
impl_element_scalar!(Complex32 => NPY_CFLOAT,
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
#[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);
// Pointer-sized integers use the size-based integer lookup; only compiled for 32-/64-bit targets.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl_element_scalar!(usize, isize);
unsafe impl Element for Py<PyAny> {
const IS_COPY: bool = false;
fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
PyArrayDescr::object(py)
}
#[inline]
fn clone_ref(&self, py: Python<'_>) -> Self {
Py::clone_ref(self, py)
}
}
// Unit tests; they attach to a Python interpreter and import NumPy.
#[cfg(test)]
mod tests {
use super::*;
use pyo3::types::PyString;
use pyo3::{py_run, types::PyTypeMethods};
use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};
// Construction from a type string and from `(name, format)` pairs, plus rejection of a non-dtype object.
#[test]
fn test_dtype_new() {
Python::attach(|py| {
assert!(PyArrayDescr::new(py, "float64")
.expect("Operation failed")
.is(dtype::<f64>(py)));
let dt =
PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).expect("Operation failed");
assert_eq!(dt.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
assert!(dt.has_object());
assert!(dt
.get_field("a")
.expect("Operation failed")
.0
.is(dtype::<Py<PyAny>>(py)));
assert!(dt
.get_field("b")
.expect("Operation failed")
.0
.is(dtype::<bool>(py)));
assert!(PyArrayDescr::new(py, 123_usize).is_err());
});
}
// Each `Element` impl maps to the expected NumPy scalar type name.
#[test]
fn test_dtype_names() {
fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
dtype::<T>(py)
.typeobj()
.qualname()
.expect("Operation failed")
}
Python::attach(|py| {
// The boolean scalar type was renamed from `bool_` to `bool` in NumPy 2.
if is_numpy_2(py) {
assert_eq!(type_name::<bool>(py), "bool");
} else {
assert_eq!(type_name::<bool>(py), "bool_");
}
assert_eq!(type_name::<i8>(py), "int8");
assert_eq!(type_name::<i16>(py), "int16");
assert_eq!(type_name::<i32>(py), "int32");
assert_eq!(type_name::<i64>(py), "int64");
assert_eq!(type_name::<u8>(py), "uint8");
assert_eq!(type_name::<u16>(py), "uint16");
assert_eq!(type_name::<u32>(py), "uint32");
assert_eq!(type_name::<u64>(py), "uint64");
assert_eq!(type_name::<f32>(py), "float32");
assert_eq!(type_name::<f64>(py), "float64");
assert_eq!(type_name::<Complex32>(py), "complex64");
assert_eq!(type_name::<Complex64>(py), "complex128");
#[cfg(target_pointer_width = "32")]
{
assert_eq!(type_name::<usize>(py), "uint32");
assert_eq!(type_name::<isize>(py), "int32");
}
#[cfg(target_pointer_width = "64")]
{
assert_eq!(type_name::<usize>(py), "uint64");
assert_eq!(type_name::<isize>(py), "int64");
}
});
}
// Accessors on a plain scalar descriptor (`f64`).
#[test]
fn test_dtype_methods_scalar() {
Python::attach(|py| {
let dt = dtype::<f64>(py);
assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
assert_eq!(dt.flags(), 0);
assert_eq!(
dt.typeobj().qualname().expect("Operation failed"),
"float64"
);
assert_eq!(dt.char(), b'd');
assert_eq!(dt.kind(), b'f');
assert_eq!(dt.byteorder(), b'=');
assert_eq!(dt.is_native_byteorder(), Some(true));
assert_eq!(dt.itemsize(), 8);
assert_eq!(dt.alignment(), 8);
assert!(!dt.has_object());
assert!(dt.names().is_none());
assert!(!dt.has_fields());
assert!(!dt.is_aligned_struct());
assert!(!dt.has_subarray());
assert!(dt.base().is_equiv_to(&dt));
assert_eq!(dt.ndim(), 0);
assert_eq!(dt.shape(), Vec::<usize>::new());
});
}
// Accessors on a subarray descriptor: `f8` elements with shape (2, 3).
#[test]
fn test_dtype_methods_subarray() {
Python::attach(|py| {
let locals = PyDict::new(py);
py_run!(
py,
*locals,
"dtype = __import__('numpy').dtype(('f8', (2, 3)))"
);
let dt = locals
.get_item("dtype")
.expect("Operation failed")
.expect("Operation failed")
.cast_into::<PyArrayDescr>()
.expect("Operation failed");
assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
assert_eq!(dt.flags(), 0);
assert_eq!(dt.typeobj().qualname().expect("Operation failed"), "void");
assert_eq!(dt.char(), b'V');
assert_eq!(dt.kind(), b'V');
assert_eq!(dt.byteorder(), b'|');
assert_eq!(dt.is_native_byteorder(), None);
assert_eq!(dt.itemsize(), 48);
assert_eq!(dt.alignment(), 8);
assert!(!dt.has_object());
assert!(dt.names().is_none());
assert!(!dt.has_fields());
assert!(!dt.is_aligned_struct());
assert!(dt.has_subarray());
assert_eq!(dt.ndim(), 2);
assert_eq!(dt.shape(), vec![2, 3]);
assert!(dt.base().is_equiv_to(&dtype::<f64>(py)));
});
}
// Accessors on an aligned record descriptor with `u1`, `f8` and object fields.
#[test]
fn test_dtype_methods_record() {
Python::attach(|py| {
let locals = PyDict::new(py);
py_run!(
py,
*locals,
"dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
);
let dt = locals
.get_item("dtype")
.expect("Operation failed")
.expect("Operation failed")
.cast_into::<PyArrayDescr>()
.expect("Operation failed");
assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
assert_eq!(dt.typeobj().qualname().expect("Operation failed"), "void");
assert_eq!(dt.char(), b'V');
assert_eq!(dt.kind(), b'V');
assert_eq!(dt.byteorder(), b'|');
assert_eq!(dt.is_native_byteorder(), None);
assert_eq!(dt.itemsize(), 24);
assert_eq!(dt.alignment(), 8);
assert!(dt.has_object());
assert_eq!(
dt.names(),
Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
);
assert!(dt.has_fields());
assert!(dt.is_aligned_struct());
assert!(!dt.has_subarray());
assert_eq!(dt.ndim(), 0);
assert_eq!(dt.shape(), Vec::<usize>::new());
assert!(dt.base().is_equiv_to(&dt));
// With `align=True`, `y` and `z` are padded to 8-byte offsets.
let x = dt.get_field("x").expect("Operation failed");
assert!(x.0.is_equiv_to(&dtype::<u8>(py)));
assert_eq!(x.1, 0);
let y = dt.get_field("y").expect("Operation failed");
assert!(y.0.is_equiv_to(&dtype::<f64>(py)));
assert_eq!(y.1, 8);
let z = dt.get_field("z").expect("Operation failed");
assert!(z.0.is_equiv_to(&dtype::<Py<PyAny>>(py)));
assert_eq!(z.1, 16);
});
}
}