use std::mem::size_of;
use std::os::raw::{
c_char, c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort,
};
use std::ptr;
#[cfg(feature = "half")]
use half::f16;
use num_traits::{Bounded, Zero};
use pyo3::{
exceptions::{PyIndexError, PyValueError},
ffi::{self, PyTuple_Size},
pyobject_native_type_extract, pyobject_native_type_named,
types::{PyDict, PyTuple, PyType},
AsPyPointer, FromPyObject, FromPyPointer, IntoPyPointer, PyAny, PyNativeType, PyObject,
PyResult, PyTypeInfo, Python, ToPyObject,
};
use crate::npyffi::{
NpyTypes, PyArray_Descr, NPY_ALIGNED_STRUCT, NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES,
PY_ARRAY_API,
};
pub use num_complex::{Complex32, Complex64};
/// Binding of NumPy's type descriptor object (registered below under the
/// Python module name `numpy`).
///
/// The struct is a transparent newtype over [`PyAny`], so a reference to it is
/// layout-compatible with a reference to the underlying Python object.
/// Instances are obtained through [`PyArrayDescr::new`], [`PyArrayDescr::of`],
/// [`PyArrayDescr::object`] or the free function [`dtype`].
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);
// PyO3 helper macro for native types.
// NOTE(review): presumably supplies the common trait impls PyO3 expects from a
// native type wrapper (e.g. `PyNativeType`, `AsPyPointer`, `Deref<Target = PyAny>`,
// `Debug`/`Display`) — confirm against the pinned PyO3 version.
pyobject_native_type_named!(PyArrayDescr);
/// Type information that lets PyO3 treat [`PyArrayDescr`] as a native Python type.
unsafe impl PyTypeInfo for PyArrayDescr {
    type AsRefTarget = Self;

    const NAME: &'static str = "PyArrayDescr";
    const MODULE: Option<&'static str> = Some("numpy");

    /// Looks up the descriptor type object through the NumPy C-API table.
    #[inline]
    fn type_object_raw(py: Python) -> *mut ffi::PyTypeObject {
        unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
    }

    /// Returns `true` if `ob` passes `PyObject_TypeCheck` against the descriptor type.
    fn is_type_of(ob: &PyAny) -> bool {
        let descr_type = Self::type_object_raw(ob.py());
        let verdict = unsafe { ffi::PyObject_TypeCheck(ob.as_ptr(), descr_type) };
        verdict > 0
    }
}
// PyO3 helper macro for native types.
// NOTE(review): presumably provides the `FromPyObject` impl for `&PyArrayDescr`
// that `PyArrayDescr::get_field` relies on when extracting a field's
// descriptor — confirm against the pinned PyO3 version.
pyobject_native_type_extract!(PyArrayDescr);
/// Returns the type descriptor associated with the element type `T`.
///
/// This is a free-function shorthand that forwards to [`Element::get_dtype`].
pub fn dtype<T: Element>(py: Python) -> &PyArrayDescr {
    <T as Element>::get_dtype(py)
}
impl PyArrayDescr {
    /// Creates a type descriptor from an arbitrary object.
    ///
    /// `ob` is converted to a Python object and handed to NumPy's
    /// `PyArray_DescrConverter2`. The tests in this file exercise a type name
    /// (`"float64"`) and a list of `(name, format)` pairs.
    ///
    /// # Errors
    ///
    /// Returns an error when the converter leaves the output pointer null,
    /// e.g. for an integer such as `123_usize`.
    #[inline]
    pub fn new<'py, T: ToPyObject + ?Sized>(py: Python<'py>, ob: &T) -> PyResult<&'py Self> {
        // Non-generic worker: the FFI call does not depend on `T`.
        fn convert<'py>(py: Python<'py>, source: PyObject) -> PyResult<&'py PyArrayDescr> {
            let mut raw: *mut PyArray_Descr = ptr::null_mut();
            unsafe {
                // The status code is ignored on purpose; a failed conversion is
                // detected through `raw` still being null afterwards.
                PY_ARRAY_API.PyArray_DescrConverter2(py, source.as_ptr(), &mut raw as *mut _);
                py.from_owned_ptr_or_err(raw as _)
            }
        }

        convert(py, ob.to_object(py))
    }

    /// Returns `self` as a raw `*mut PyArray_Descr` without touching the reference count.
    pub fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
        self.as_ptr() as *mut PyArray_Descr
    }

    /// Returns `self` as a raw `*mut PyArray_Descr` obtained via `into_ptr`.
    ///
    /// NOTE(review): PyO3's `into_ptr` on a reference is expected to hand out a
    /// new strong reference, which is what C-API functions that steal the
    /// descriptor need — confirm against the pinned PyO3 version.
    pub fn into_dtype_ptr(&self) -> *mut PyArray_Descr {
        self.into_ptr() as *mut PyArray_Descr
    }

    /// Shortcut for the descriptor of the `NPY_OBJECT` type.
    pub fn object(py: Python) -> &Self {
        Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
    }

    /// Returns the type descriptor associated with the element type `T`.
    pub fn of<T: Element>(py: Python) -> &Self {
        <T as Element>::get_dtype(py)
    }

    /// Returns `true` if both descriptors are the same object or NumPy's
    /// `PyArray_EquivTypes` reports them as equivalent.
    pub fn is_equiv_to(&self, other: &Self) -> bool {
        let lhs = self.as_dtype_ptr();
        let rhs = other.as_dtype_ptr();

        // Pointer identity short-circuits the C-API call.
        if lhs == rhs {
            return true;
        }

        unsafe { PY_ARRAY_API.PyArray_EquivTypes(self.py(), lhs, rhs) != 0 }
    }

    /// Fetches the descriptor for a type number via `PyArray_DescrFromType`.
    fn from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
        let raw = unsafe { PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _) };
        unsafe { py.from_owned_ptr(raw as _) }
    }

    /// Creates a descriptor for a type number via `PyArray_DescrNewFromType`.
    pub(crate) fn new_from_npy_type(py: Python, npy_type: NPY_TYPES) -> &Self {
        let raw = unsafe { PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _) };
        unsafe { py.from_owned_ptr(raw as _) }
    }

    /// Returns the Python type object stored in the descriptor's `typeobj` field
    /// (e.g. the type named `float64` for the `f64` descriptor).
    pub fn typeobj(&self) -> &PyType {
        unsafe {
            let scalar_type = (*self.as_dtype_ptr()).typeobj;
            PyType::from_type_ptr(self.py(), scalar_type)
        }
    }

    /// Returns the descriptor's type number (its `type_num` field).
    pub fn num(&self) -> c_int {
        unsafe { (*self.as_dtype_ptr()).type_num }
    }

    /// Returns the element size (`elsize` field); negative values are clamped to zero.
    pub fn itemsize(&self) -> usize {
        let raw = unsafe { (*self.as_dtype_ptr()).elsize };
        raw.max(0) as usize
    }

    /// Returns the required alignment (`alignment` field); negative values are clamped to zero.
    pub fn alignment(&self) -> usize {
        let raw = unsafe { (*self.as_dtype_ptr()).alignment };
        raw.max(0) as usize
    }

    /// Returns the byte-order character (`byteorder` field), e.g. `b'='` or `b'|'`.
    ///
    /// Negative values are clamped to zero before the cast.
    pub fn byteorder(&self) -> u8 {
        let raw = unsafe { (*self.as_dtype_ptr()).byteorder };
        raw.max(0) as u8
    }

    /// Returns the type character (`type_` field), e.g. `b'd'` for `f64`.
    ///
    /// Negative values are clamped to zero before the cast.
    pub fn char(&self) -> u8 {
        let raw = unsafe { (*self.as_dtype_ptr()).type_ };
        raw.max(0) as u8
    }

    /// Returns the kind character (`kind` field), e.g. `b'f'` for `f64`.
    ///
    /// Negative values are clamped to zero before the cast.
    pub fn kind(&self) -> u8 {
        let raw = unsafe { (*self.as_dtype_ptr()).kind };
        raw.max(0) as u8
    }

    /// Returns the raw flag bits (`flags` field) of the descriptor.
    pub fn flags(&self) -> c_char {
        unsafe { (*self.as_dtype_ptr()).flags }
    }

    /// Returns the number of sub-array dimensions, or `0` if there is no sub-array.
    pub fn ndim(&self) -> usize {
        if self.has_subarray() {
            let len = unsafe { PyTuple_Size((*(*self.as_dtype_ptr()).subarray).shape) };
            // A negative size is clamped to zero.
            len.max(0) as usize
        } else {
            0
        }
    }

    /// Returns the descriptor of the sub-array's elements, or `self` if there is no sub-array.
    pub fn base(&self) -> &PyArrayDescr {
        if self.has_subarray() {
            unsafe {
                Self::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).base as _)
            }
        } else {
            self
        }
    }

    /// Returns the sub-array shape, or an empty vector if there is no sub-array.
    ///
    /// # Panics
    ///
    /// Panics if the stored shape tuple cannot be extracted as `Vec<usize>`.
    pub fn shape(&self) -> Vec<usize> {
        if !self.has_subarray() {
            return Vec::new();
        }

        let dims = unsafe {
            PyTuple::from_borrowed_ptr(self.py(), (*(*self.as_dtype_ptr()).subarray).shape)
        };
        dims.extract().unwrap()
    }

    /// Returns `true` if the `NPY_ITEM_HASOBJECT` flag is set.
    pub fn has_object(&self) -> bool {
        (self.flags() & NPY_ITEM_HASOBJECT) != 0
    }

    /// Returns `true` if the `NPY_ALIGNED_STRUCT` flag is set.
    pub fn is_aligned_struct(&self) -> bool {
        (self.flags() & NPY_ALIGNED_STRUCT) != 0
    }

    /// Returns `true` if the descriptor's `subarray` pointer is non-null.
    pub fn has_subarray(&self) -> bool {
        let subarray = unsafe { (*self.as_dtype_ptr()).subarray };
        !subarray.is_null()
    }

    /// Returns `true` if the descriptor's `names` pointer is non-null.
    pub fn has_fields(&self) -> bool {
        let names = unsafe { (*self.as_dtype_ptr()).names };
        !names.is_null()
    }

    /// Reports whether the byte order is native.
    ///
    /// Returns `None` for `b'|'`, `Some(true)` for `b'='`, and otherwise compares
    /// the byte-order character against `NPY_NATBYTE`.
    pub fn is_native_byteorder(&self) -> Option<bool> {
        match self.byteorder() {
            b'|' => None,
            b'=' => Some(true),
            explicit => Some(explicit == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
        }
    }

    /// Returns the field names, or `None` if the descriptor has no fields
    /// (or the names tuple cannot be extracted as strings).
    pub fn names(&self) -> Option<Vec<&str>> {
        if self.has_fields() {
            let field_names =
                unsafe { PyTuple::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).names) };
            field_names.extract().ok()
        } else {
            None
        }
    }

    /// Returns the descriptor and offset of the field called `name`.
    ///
    /// # Errors
    ///
    /// * `PyValueError` if the descriptor has no fields.
    /// * `PyIndexError` (carrying `name`) if no such field exists.
    ///
    /// # Panics
    ///
    /// Panics if the entry in the `fields` dictionary is not a tuple whose first
    /// two items extract as a descriptor and a `usize`.
    pub fn get_field(&self, name: &str) -> PyResult<(&PyArrayDescr, usize)> {
        if !self.has_fields() {
            return Err(PyValueError::new_err(
                "cannot get field information: type descriptor has no fields",
            ));
        }

        let fields = unsafe { PyDict::from_borrowed_ptr(self.py(), (*self.as_dtype_ptr()).fields) };
        let entry = fields
            .get_item(name)
            .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?;
        let entry = entry.downcast::<PyTuple>().unwrap();

        // Only the first two items are read; the items are pulled out one by
        // one rather than extracting the tuple as a whole.
        let field_dtype: &PyArrayDescr = entry.as_ref().get_item(0).unwrap().extract().unwrap();
        let field_offset: usize = entry.as_ref().get_item(1).unwrap().extract().unwrap();

        Ok((field_dtype, field_offset))
    }
}
/// Marks a Rust type that has an associated NumPy type descriptor.
///
/// Implementations in this file cover `bool`, the fixed-width and
/// pointer-sized integers, `f32`/`f64`, the complex types and `PyObject`.
///
/// # Safety
///
/// NOTE(review): the trait is `unsafe`, presumably because array code trusts
/// that the memory layout of `Self` matches the descriptor returned by
/// `get_dtype` — confirm against the crate-level documentation.
pub unsafe trait Element: Clone + Send {
/// Whether values of this type can be copied bitwise.
///
/// The scalar implementations in this file set it to `true`; the `PyObject`
/// implementation sets it to `false`.
const IS_COPY: bool;
/// Returns the type descriptor associated with this element type.
fn get_dtype(py: Python) -> &PyArrayDescr;
}
/// Picks the entry of `npy_types` whose candidate type has the same size as `T`.
///
/// The candidates `T0`, `T1`, `T2` are tried in that order and the first one
/// whose size equals `size_of::<T>()` selects the entry at the same index.
///
/// # Panics
///
/// Panics if none of the three candidates has the size of `T`.
fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
    let wanted = size_of::<T>();
    let candidates = [size_of::<T0>(), size_of::<T1>(), size_of::<T2>()];

    match candidates.iter().position(|&width| width == wanted) {
        Some(slot) => npy_types[slot],
        None => panic!("Unable to match integer type descriptor: {:?}", npy_types),
    }
}
/// Maps a Rust integer type onto the matching NumPy type number.
///
/// Signedness is derived from whether `T::min_value()` equals zero, the width
/// from `size_of::<T>()`. For 8 and 16 bits the mapping is fixed; for 32 and 64
/// bits the C integer types are probed via [`npy_int_type_lookup`] because
/// their sizes differ between platforms.
///
/// # Panics
///
/// Hits `unreachable!()` for any width other than 8, 16, 32 or 64 bits.
fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
    let bits = size_of::<T>() * 8;
    let unsigned = T::min_value() == T::zero();

    if unsigned {
        match bits {
            8 => NPY_TYPES::NPY_UBYTE,
            16 => NPY_TYPES::NPY_USHORT,
            32 => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
                NPY_TYPES::NPY_ULONG,
                NPY_TYPES::NPY_UINT,
                NPY_TYPES::NPY_USHORT,
            ]),
            64 => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
                NPY_TYPES::NPY_ULONG,
                NPY_TYPES::NPY_ULONGLONG,
                NPY_TYPES::NPY_UINT,
            ]),
            _ => unreachable!(),
        }
    } else {
        match bits {
            8 => NPY_TYPES::NPY_BYTE,
            16 => NPY_TYPES::NPY_SHORT,
            32 => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
                NPY_TYPES::NPY_LONG,
                NPY_TYPES::NPY_INT,
                NPY_TYPES::NPY_SHORT,
            ]),
            64 => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
                NPY_TYPES::NPY_LONG,
                NPY_TYPES::NPY_LONGLONG,
                NPY_TYPES::NPY_INT,
            ]),
            _ => unreachable!(),
        }
    }
}
// Generates `unsafe impl Element` blocks for scalar types.
//
// Two public forms are accepted:
//   * `Type => NPY_VARIANT [, #[attr]...]` for a fixed NumPy type number, and
//   * `TypeA, TypeB, ...` for integer types whose type number is computed by
//     `npy_int_type`.
macro_rules! impl_element_scalar {
// Internal arm: emits the actual impl with `IS_COPY = true` and a `get_dtype`
// that looks the descriptor up by the given type-number expression. Any
// trailing attributes are attached to the generated impl.
(@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
$(#[$meta])*
unsafe impl Element for $ty {
const IS_COPY: bool = true;
fn get_dtype(py: Python) -> &PyArrayDescr {
PyArrayDescr::from_npy_type(py, $npy_type)
}
}
};
// `Type => NPY_VARIANT` form: qualifies the identifier with `NPY_TYPES::` and
// forwards to the internal arm together with any attributes.
($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
};
// List form: one impl per listed type, with the type number resolved through
// `npy_int_type::<Type>()` inside the generated `get_dtype`.
($($tys:ty),+) => {
$(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
};
}
// Booleans map to a fixed NumPy type number.
impl_element_scalar!(bool => NPY_BOOL);
// Fixed-width integers: the type number is resolved by `npy_int_type`, since
// the matching C integer type depends on the platform.
impl_element_scalar!(i8, i16, i32, i64);
impl_element_scalar!(u8, u16, u32, u64);
// Floating-point types map to fixed NumPy type numbers.
impl_element_scalar!(f32 => NPY_FLOAT);
impl_element_scalar!(f64 => NPY_DOUBLE);
// Half-precision floats are only available with the `half` feature enabled.
#[cfg(feature = "half")]
impl_element_scalar!(f16 => NPY_HALF);
// Complex types; the `#[doc]` attributes end up on the generated impls.
impl_element_scalar!(Complex32 => NPY_CFLOAT,
#[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
#[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);
// Pointer-sized integers are restricted to 32- and 64-bit targets;
// `npy_int_type` only handles widths of 8, 16, 32 and 64 bits.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl_element_scalar!(usize, isize);
unsafe impl Element for PyObject {
const IS_COPY: bool = false;
fn get_dtype(py: Python) -> &PyArrayDescr {
PyArrayDescr::object(py)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::py_run;

    use crate::npyffi::NPY_NEEDS_PYAPI;

    /// Descriptors can be built from a type name and from a list of fields;
    /// an integer is rejected.
    #[test]
    fn test_dtype_new() {
        Python::with_gil(|py| {
            let float64 = PyArrayDescr::new(py, "float64").unwrap();
            assert!(float64.is(dtype::<f64>(py)));

            let record = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(record.names(), Some(vec!["a", "b"]));
            assert!(record.has_object());

            let (a_descr, _) = record.get_field("a").unwrap();
            assert!(a_descr.is(dtype::<PyObject>(py)));
            let (b_descr, _) = record.get_field("b").unwrap();
            assert!(b_descr.is(dtype::<bool>(py)));

            assert!(PyArrayDescr::new(py, &123_usize).is_err());
        });
    }

    /// Every scalar element type maps to the expected NumPy scalar type name.
    #[test]
    fn test_dtype_names() {
        fn scalar_type_name<T: Element>(py: Python) -> &str {
            dtype::<T>(py).typeobj().name().unwrap()
        }

        Python::with_gil(|py| {
            let cases = [
                (scalar_type_name::<bool>(py), "bool_"),
                (scalar_type_name::<i8>(py), "int8"),
                (scalar_type_name::<i16>(py), "int16"),
                (scalar_type_name::<i32>(py), "int32"),
                (scalar_type_name::<i64>(py), "int64"),
                (scalar_type_name::<u8>(py), "uint8"),
                (scalar_type_name::<u16>(py), "uint16"),
                (scalar_type_name::<u32>(py), "uint32"),
                (scalar_type_name::<u64>(py), "uint64"),
                (scalar_type_name::<f32>(py), "float32"),
                (scalar_type_name::<f64>(py), "float64"),
                (scalar_type_name::<Complex32>(py), "complex64"),
                (scalar_type_name::<Complex64>(py), "complex128"),
            ];
            for &(actual, expected) in cases.iter() {
                assert_eq!(actual, expected);
            }

            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(scalar_type_name::<usize>(py), "uint32");
                assert_eq!(scalar_type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(scalar_type_name::<usize>(py), "uint64");
                assert_eq!(scalar_type_name::<isize>(py), "int64");
            }
        });
    }

    /// Accessors on a plain scalar descriptor (`f64`).
    #[test]
    fn test_dtype_methods_scalar() {
        Python::with_gil(|py| {
            let descr = dtype::<f64>(py);

            assert_eq!(descr.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(descr.flags(), 0);
            assert_eq!(descr.typeobj().name().unwrap(), "float64");
            assert_eq!(descr.char(), b'd');
            assert_eq!(descr.kind(), b'f');
            assert_eq!(descr.byteorder(), b'=');
            assert_eq!(descr.is_native_byteorder(), Some(true));
            assert_eq!(descr.itemsize(), 8);
            assert_eq!(descr.alignment(), 8);
            assert!(!descr.has_object());
            assert_eq!(descr.names(), None);
            assert!(!descr.has_fields());
            assert!(!descr.is_aligned_struct());
            assert!(!descr.has_subarray());
            assert!(descr.base().is_equiv_to(descr));
            assert_eq!(descr.ndim(), 0);
            assert_eq!(descr.shape(), Vec::<usize>::new());
        });
    }

    /// Accessors on a sub-array descriptor (`('f8', (2, 3))`).
    #[test]
    fn test_dtype_methods_subarray() {
        Python::with_gil(|py| {
            let scope = PyDict::new(py);
            py_run!(
                py,
                *scope,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let descr = scope
                .get_item("dtype")
                .unwrap()
                .downcast::<PyArrayDescr>()
                .unwrap();

            assert_eq!(descr.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(descr.flags(), 0);
            assert_eq!(descr.typeobj().name().unwrap(), "void");
            assert_eq!(descr.char(), b'V');
            assert_eq!(descr.kind(), b'V');
            assert_eq!(descr.byteorder(), b'|');
            assert_eq!(descr.is_native_byteorder(), None);
            assert_eq!(descr.itemsize(), 48);
            assert_eq!(descr.alignment(), 8);
            assert!(!descr.has_object());
            assert_eq!(descr.names(), None);
            assert!(!descr.has_fields());
            assert!(!descr.is_aligned_struct());
            assert!(descr.has_subarray());
            assert_eq!(descr.ndim(), 2);
            assert_eq!(descr.shape(), vec![2, 3]);
            assert!(descr.base().is_equiv_to(dtype::<f64>(py)));
        });
    }

    /// Accessors on an aligned record descriptor with an object field.
    #[test]
    fn test_dtype_methods_record() {
        Python::with_gil(|py| {
            let scope = PyDict::new(py);
            py_run!(
                py,
                *scope,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let descr = scope
                .get_item("dtype")
                .unwrap()
                .downcast::<PyArrayDescr>()
                .unwrap();

            assert_eq!(descr.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_ne!(descr.flags() & NPY_ITEM_HASOBJECT, 0);
            assert_ne!(descr.flags() & NPY_NEEDS_PYAPI, 0);
            assert_ne!(descr.flags() & NPY_ALIGNED_STRUCT, 0);
            assert_eq!(descr.typeobj().name().unwrap(), "void");
            assert_eq!(descr.char(), b'V');
            assert_eq!(descr.kind(), b'V');
            assert_eq!(descr.byteorder(), b'|');
            assert_eq!(descr.is_native_byteorder(), None);
            assert_eq!(descr.itemsize(), 24);
            assert_eq!(descr.alignment(), 8);
            assert!(descr.has_object());
            assert_eq!(descr.names(), Some(vec!["x", "y", "z"]));
            assert!(descr.has_fields());
            assert!(descr.is_aligned_struct());
            assert!(!descr.has_subarray());
            assert_eq!(descr.ndim(), 0);
            assert_eq!(descr.shape(), Vec::<usize>::new());
            assert!(descr.base().is_equiv_to(descr));

            let (x_descr, x_offset) = descr.get_field("x").unwrap();
            assert!(x_descr.is_equiv_to(dtype::<u8>(py)));
            assert_eq!(x_offset, 0);

            let (y_descr, y_offset) = descr.get_field("y").unwrap();
            assert!(y_descr.is_equiv_to(dtype::<f64>(py)));
            assert_eq!(y_offset, 8);

            let (z_descr, z_offset) = descr.get_field("z").unwrap();
            assert!(z_descr.is_equiv_to(dtype::<PyObject>(py)));
            assert_eq!(z_offset, 16);
        });
    }
}