use std::ffi::c_void;
use std::fmt;
use std::marker::PhantomData;
use std::mem::{self, MaybeUninit};
use static_assertions::assert_not_impl_any;
use crate::{rtl, sys};
#[rustfmt::skip]
use crate::sys::MNumericArray_Data_Type::{
MNumericArray_Type_Bit8 as BIT8_TYPE,
MNumericArray_Type_Bit16 as BIT16_TYPE,
MNumericArray_Type_Bit32 as BIT32_TYPE,
MNumericArray_Type_Bit64 as BIT64_TYPE,
MNumericArray_Type_UBit8 as UBIT8_TYPE,
MNumericArray_Type_UBit16 as UBIT16_TYPE,
MNumericArray_Type_UBit32 as UBIT32_TYPE,
MNumericArray_Type_UBit64 as UBIT64_TYPE,
MNumericArray_Type_Real32 as REAL32_TYPE,
MNumericArray_Type_Real64 as REAL64_TYPE,
MNumericArray_Type_Complex_Real32 as COMPLEX_REAL32_TYPE,
MNumericArray_Type_Complex_Real64 as COMPLEX_REAL64_TYPE,
};
use crate::sys::MNumericArray_Convert_Method::*;
/// Owned wrapper around a raw LibraryLink `MNumericArray` handle.
///
/// `T` is the element type (see [`NumericArrayType`]). The default `T = ()` represents an
/// array whose element type is not known statically; use [`NumericArray::kind()`] or
/// [`NumericArray::try_kind()`] to obtain a typed view.
///
/// `#[repr(transparent)]` guarantees this has the same layout as `sys::MNumericArray`
/// (the `PhantomData` is zero-sized), which the reference/value reinterpretations in
/// `kind()`, `try_kind()` and `try_into_kind()` rely on.
#[repr(transparent)]
#[derive(ref_cast::RefCast)]
pub struct NumericArray<T = ()>(sys::MNumericArray, PhantomData<T>);
/// A newly allocated `MNumericArray` whose elements have not yet been initialized.
///
/// Create one with [`UninitNumericArray::from_dimensions()`] or
/// [`UninitNumericArray::try_from_dimensions()`], write the elements through
/// [`UninitNumericArray::as_slice_mut()`], then call [`UninitNumericArray::assume_init()`]
/// (or use [`UninitNumericArray::init_from_slice()`]) to obtain a [`NumericArray`].
pub struct UninitNumericArray<T: NumericArrayType>(sys::MNumericArray, PhantomData<T>);
// Neither wrapper may be `Copy`: a bitwise copy would duplicate ownership of the raw
// handle (and, for `NumericArray`, lead to it being released twice by `Drop`).
assert_not_impl_any!(NumericArray: Copy);
assert_not_impl_any!(UninitNumericArray<i64>: Copy);
/// Rust element types that can be stored in a [`NumericArray`].
///
/// The supertrait `private::Sealed` lives in a private module, so code outside this module
/// cannot add implementors; the set is fixed to the impls in this file.
pub trait NumericArrayType: private::Sealed {
/// The LibraryLink element type code that corresponds to `Self`.
const TYPE: NumericArrayDataType;
}
mod private {
    use crate::sys;

    /// Supertrait of `NumericArrayType`. Because this module is private, no code outside
    /// of it can name `Sealed`, which keeps the list of element types closed.
    pub trait Sealed {}

    /// Emit an empty `impl Sealed for $ty {}` for every listed type.
    macro_rules! seal {
        ($($ty:ty),* $(,)?) => {
            $( impl Sealed for $ty {} )*
        };
    }

    seal!(
        u8, u16, u32, u64,
        i8, i16, i32, i64,
        f32, f64,
        sys::mcomplex,
    );
}
impl NumericArrayType for i8 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Bit8;
}
impl NumericArrayType for i16 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Bit16;
}
impl NumericArrayType for i32 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Bit32;
}
impl NumericArrayType for i64 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Bit64;
}
impl NumericArrayType for u8 {
const TYPE: NumericArrayDataType = NumericArrayDataType::UBit8;
}
impl NumericArrayType for u16 {
const TYPE: NumericArrayDataType = NumericArrayDataType::UBit16;
}
impl NumericArrayType for u32 {
const TYPE: NumericArrayDataType = NumericArrayDataType::UBit32;
}
impl NumericArrayType for u64 {
const TYPE: NumericArrayDataType = NumericArrayDataType::UBit64;
}
impl NumericArrayType for f32 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Real32;
}
impl NumericArrayType for f64 {
const TYPE: NumericArrayDataType = NumericArrayDataType::Real64;
}
impl NumericArrayType for sys::mcomplex {
const TYPE: NumericArrayDataType = NumericArrayDataType::ComplexReal64;
}
/// Element type of a [`NumericArray`], as reported by LibraryLink.
///
/// Each discriminant is the corresponding raw LibraryLink type code, so
/// [`NumericArrayDataType::as_raw()`] can be a plain cast.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u32)]
#[allow(missing_docs)]
pub enum NumericArrayDataType {
Bit8 = BIT8_TYPE as u32,
Bit16 = BIT16_TYPE as u32,
Bit32 = BIT32_TYPE as u32,
Bit64 = BIT64_TYPE as u32,
UBit8 = UBIT8_TYPE as u32,
UBit16 = UBIT16_TYPE as u32,
UBit32 = UBIT32_TYPE as u32,
UBit64 = UBIT64_TYPE as u32,
Real32 = REAL32_TYPE as u32,
Real64 = REAL64_TYPE as u32,
ComplexReal32 = COMPLEX_REAL32_TYPE as u32,
ComplexReal64 = COMPLEX_REAL64_TYPE as u32,
}
/// Conversion method passed to `MNumericArray_convertType` by
/// [`NumericArray::convert_to()`].
///
/// Each discriminant is the corresponding raw LibraryLink convert-method code, so
/// [`NumericArrayConvertMethod::as_raw()`] can be a plain cast. The precise semantics of
/// each method are defined by LibraryLink (NOTE(review): see the LibraryLink docs for the
/// exact behavior of each variant).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u32)]
#[allow(missing_docs)]
pub enum NumericArrayConvertMethod {
Cast = MNumericArray_Convert_Cast as u32,
Check = MNumericArray_Convert_Check as u32,
Coerce = MNumericArray_Convert_Coerce as u32,
Round = MNumericArray_Convert_Round as u32,
Scale = MNumericArray_Convert_Scale as u32,
ClipAndCast = MNumericArray_Convert_Clip_Cast as u32,
ClipAndCheck = MNumericArray_Convert_Clip_Check as u32,
ClipAndCoerce = MNumericArray_Convert_Clip_Coerce as u32,
ClipAndRound = MNumericArray_Convert_Clip_Round as u32,
ClipAndScale = MNumericArray_Convert_Clip_Scale as u32,
}
/// Typed view of an untyped [`NumericArray`], with one variant per supported element
/// type. Returned by [`NumericArray::kind()`].
///
/// There is no `ComplexReal32` variant; `kind()` panics on arrays of that type.
#[allow(missing_docs)]
pub enum NumericArrayKind<'e> {
Bit8(&'e NumericArray<i8>),
Bit16(&'e NumericArray<i16>),
Bit32(&'e NumericArray<i32>),
Bit64(&'e NumericArray<i64>),
UBit8(&'e NumericArray<u8>),
UBit16(&'e NumericArray<u16>),
UBit32(&'e NumericArray<u32>),
UBit64(&'e NumericArray<u64>),
Real32(&'e NumericArray<f32>),
Real64(&'e NumericArray<f64>),
ComplexReal64(&'e NumericArray<sys::mcomplex>),
}
// `sys::mcomplex` is the element type for ComplexReal64 arrays, whose data buffer is viewed
// as `&[sys::mcomplex]`; it must therefore be exactly two `f64`s with `f64` alignment.
const _: () = assert!(mem::size_of::<sys::mcomplex>() == mem::size_of::<[f64; 2]>());
const _: () = assert!(mem::align_of::<sys::mcomplex>() == mem::align_of::<f64>());
impl NumericArray {
pub fn kind(&self) -> NumericArrayKind<'_> {
unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
std::mem::transmute(array)
}
unsafe {
use NumericArrayDataType::*;
match self.data_type() {
Bit8 => NumericArrayKind::Bit8(trans(self)),
Bit16 => NumericArrayKind::Bit16(trans(self)),
Bit32 => NumericArrayKind::Bit32(trans(self)),
Bit64 => NumericArrayKind::Bit64(trans(self)),
UBit8 => NumericArrayKind::UBit8(trans(self)),
UBit16 => NumericArrayKind::UBit16(trans(self)),
UBit32 => NumericArrayKind::UBit32(trans(self)),
UBit64 => NumericArrayKind::UBit64(trans(self)),
Real32 => NumericArrayKind::Real32(trans(self)),
Real64 => NumericArrayKind::Real64(trans(self)),
ComplexReal32 => unimplemented!(
"NumericArray::kind(): NumericArray of ComplexReal32 is not currently supported."
),
ComplexReal64 => NumericArrayKind::ComplexReal64(trans(self)),
}
}
}
pub fn try_kind<T>(&self) -> Result<&NumericArray<T>, ()>
where
T: NumericArrayType,
{
unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
std::mem::transmute(array)
}
if self.data_type() == T::TYPE {
return Ok(unsafe { trans(self) });
}
Err(())
}
pub fn try_into_kind<T>(self) -> Result<NumericArray<T>, NumericArray>
where
T: NumericArrayType,
{
unsafe fn trans<T: NumericArrayType>(array: NumericArray) -> NumericArray<T> {
std::mem::transmute(array)
}
if self.data_type() == T::TYPE {
return Ok(unsafe { trans(self) });
}
Err(self)
}
}
impl<T: NumericArrayType> NumericArray<T> {
    /// Construct a rank-1 `NumericArray<T>` containing a copy of `data`.
    ///
    /// # Panics
    ///
    /// Panics if [`NumericArray::try_from_slice()`] returns an error.
    pub fn from_slice(data: &[T]) -> NumericArray<T> {
        NumericArray::<T>::try_from_slice(data)
            .expect("failed to create NumericArray from slice")
    }

    /// Construct a rank-1 `NumericArray<T>` containing a copy of `data`.
    ///
    /// # Errors
    ///
    /// Returns the LibraryLink error code if allocating the array fails.
    pub fn try_from_slice(data: &[T]) -> Result<NumericArray<T>, sys::errcode_t> {
        let dim1 = data.len();
        NumericArray::try_from_array(&[dim1], data)
    }

    /// Construct a `NumericArray<T>` with the given `dimensions`, containing a copy of
    /// the flattened elements in `data`.
    ///
    /// # Panics
    ///
    /// Panics if [`NumericArray::try_from_array()`] panics or returns an error.
    pub fn from_array(dimensions: &[usize], data: &[T]) -> NumericArray<T> {
        NumericArray::<T>::try_from_array(dimensions, data)
            .expect("failed to create NumericArray from array")
    }

    /// Construct a `NumericArray<T>` with the given `dimensions`, containing a copy of
    /// the flattened elements in `data`.
    ///
    /// # Errors
    ///
    /// Returns the LibraryLink error code if allocating the array fails.
    ///
    /// # Panics
    ///
    /// Panics if `dimensions` is empty, or if `data.len()` differs from the flattened
    /// length of the newly allocated array.
    pub fn try_from_array(
        dimensions: &[usize],
        data: &[T],
    ) -> Result<NumericArray<T>, sys::errcode_t> {
        let uninit = UninitNumericArray::try_from_dimensions(dimensions)?;
        Ok(uninit.init_from_slice(data))
    }

    /// Borrow the elements of this array as a flat slice.
    pub fn as_slice(&self) -> &[T] {
        let ptr = self.typed_data_ptr();
        // SAFETY: a `NumericArray<T>` holds an array whose element type corresponds to
        // `T`, so its data buffer contains `flattened_length()` contiguous `T` values
        // that live as long as `self`.
        unsafe { std::slice::from_raw_parts(ptr, self.flattened_length()) }
    }

    /// Mutably borrow the elements of this array as a flat slice.
    ///
    /// Returns `None` if the array's share count is non-zero, since mutating a shared
    /// array could be observed by its other owner(s).
    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
        if self.share_count() == 0 {
            // SAFETY: the array is not shared, and `&mut self` guarantees uniqueness on
            // the Rust side.
            unsafe { Some(self.as_slice_mut_unchecked()) }
        } else {
            None
        }
    }

    /// Mutably borrow the elements of this array without checking the share count.
    ///
    /// # Safety
    ///
    /// The caller must ensure that mutating the elements is permitted, e.g. that the
    /// array is not shared (see [`NumericArray::share_count()`]).
    pub unsafe fn as_slice_mut_unchecked(&mut self) -> &mut [T] {
        let ptr = self.typed_data_ptr();
        std::slice::from_raw_parts_mut(ptr, self.flattened_length())
    }

    /// The data pointer cast to `*mut T`, with debug-build checks that it is non-null and
    /// suitably aligned for `T`.
    fn typed_data_ptr(&self) -> *mut T {
        let ptr = self.data_ptr() as *mut T;
        debug_assert!(!ptr.is_null());
        // Alignment must be checked with `align_of`, not `size_of`: `sys::mcomplex` is
        // the size of `[f64; 2]` but only requires `f64` alignment, so a `size_of` check
        // can fail on perfectly valid buffers.
        debug_assert!(ptr as usize % std::mem::align_of::<T>() == 0);
        ptr
    }
}
impl<T> NumericArray<T> {
    /// Erase the static element type, returning an untyped [`NumericArray`] that owns the
    /// same underlying `MNumericArray`.
    pub fn into_generic(self) -> NumericArray {
        let NumericArray(na, PhantomData) = self;
        // Ownership of `na` moves into the returned value, so `self` must not run `Drop`.
        std::mem::forget(self);
        NumericArray(na, PhantomData)
    }

    /// Take ownership of a raw `MNumericArray` handle.
    ///
    /// # Safety
    ///
    /// `array` must be a valid `MNumericArray` whose element type corresponds to `T`
    /// (any element type when `T = ()`). The returned value releases the handle when it
    /// is dropped, so the caller must not release it separately.
    pub unsafe fn from_raw(array: sys::MNumericArray) -> NumericArray<T> {
        NumericArray(array, PhantomData)
    }

    /// Give up ownership of the underlying raw `MNumericArray` handle without releasing
    /// it.
    ///
    /// # Safety
    ///
    /// The caller becomes responsible for eventually releasing the returned handle.
    pub unsafe fn into_raw(self) -> sys::MNumericArray {
        let NumericArray(raw, PhantomData) = self;
        // The caller now owns `raw`; skip `Drop` so it isn't freed/disowned here.
        std::mem::forget(self);
        raw
    }

    /// Raw pointer to the start of this array's data buffer (`MNumericArray_getData`).
    pub fn data_ptr(&self) -> *mut c_void {
        let NumericArray(numeric_array, _) = *self;
        // SAFETY: `numeric_array` is the live handle owned by `self`.
        unsafe { data_ptr(numeric_array) }
    }

    /// The runtime element type of this array.
    ///
    /// # Panics
    ///
    /// Panics if LibraryLink reports a type code that is not a known
    /// [`NumericArrayDataType`] variant.
    pub fn data_type(&self) -> NumericArrayDataType {
        let raw: sys::numericarray_data_t = self.data_type_raw();
        NumericArrayDataType::try_from(raw as u32)
            .expect("NumericArray data type value is not a known NumericArrayDataType variant")
    }

    /// The raw LibraryLink element type code of this array (`MNumericArray_getType`).
    pub fn data_type_raw(&self) -> sys::numericarray_data_t {
        let NumericArray(numeric_array, _) = *self;
        // SAFETY: `numeric_array` is the live handle owned by `self`.
        unsafe { rtl::MNumericArray_getType(numeric_array) }
    }

    /// Total number of elements in this array.
    ///
    /// In debug builds this is checked against the product of [`NumericArray::dimensions()`].
    pub fn flattened_length(&self) -> usize {
        let NumericArray(numeric_array, _) = *self;
        // SAFETY: `numeric_array` is the live handle owned by `self`.
        let len = unsafe { flattened_length(numeric_array) };
        debug_assert!(len == self.dimensions().iter().copied().product::<usize>());
        len
    }

    /// Number of dimensions of this array.
    ///
    /// # Panics
    ///
    /// Panics if the rank reported by LibraryLink does not fit in a `usize`.
    pub fn rank(&self) -> usize {
        let NumericArray(numeric_array, _) = *self;
        // SAFETY: `numeric_array` is the live handle owned by `self`.
        let rank: sys::mint = unsafe { rtl::MNumericArray_getRank(numeric_array) };
        usize::try_from(rank).expect("NumericArray rank overflows usize")
    }

    /// The size of each dimension of this array; the slice has length [`NumericArray::rank()`].
    pub fn dimensions(&self) -> &[usize] {
        // The `mint` dimensions buffer is reinterpreted as `usize` below, which is only
        // valid if both integer types have the same size.
        const _: () = assert!(mem::size_of::<sys::mint>() == mem::size_of::<usize>());

        let NumericArray(numeric_array, _) = *self;
        let rank = self.rank();
        debug_assert!(rank != 0);
        // SAFETY: `numeric_array` is the live handle owned by `self`.
        let dims: *const crate::sys::mint =
            unsafe { rtl::MNumericArray_getDimensions(numeric_array) };
        debug_assert!(!dims.is_null());
        // SAFETY: LibraryLink returns `rank` dimensions stored in the array, which lives as
        // long as `self`; `mint` and `usize` have the same size (asserted above). The slice
        // is read-only, so a `*const` pointer suffices.
        unsafe { std::slice::from_raw_parts(dims as *const usize, rank) }
    }

    /// The LibraryLink share count of this array (`MNumericArray_shareCount`).
    ///
    /// # Panics
    ///
    /// Panics if the count does not fit in a `usize`.
    pub fn share_count(&self) -> usize {
        let NumericArray(raw, PhantomData) = *self;
        // SAFETY: `raw` is the live handle owned by `self`.
        let count: sys::mint = unsafe { rtl::MNumericArray_shareCount(raw) };
        usize::try_from(count).expect("NumericArray share count mint overflows usize")
    }

    /// Whether `self` and `other` wrap the very same raw `MNumericArray` handle.
    pub fn ptr_eq<T2>(&self, other: &NumericArray<T2>) -> bool {
        let NumericArray(this, PhantomData) = *self;
        let NumericArray(other, PhantomData) = *other;
        this == other
    }

    /// Create a new array with element type `T2` from this one, using LibraryLink's
    /// `MNumericArray_convertType` with the given `method` and `tolerance`.
    ///
    /// # Errors
    ///
    /// Returns the LibraryLink error code if the conversion fails. A null result with a
    /// zero error code is also reported as `Err(0)`.
    pub fn convert_to<T2: NumericArrayType>(
        &self,
        method: NumericArrayConvertMethod,
        tolerance: sys::mreal,
    ) -> Result<NumericArray<T2>, sys::errcode_t> {
        let NumericArray(self_raw, PhantomData) = *self;
        let mut new_raw: sys::MNumericArray = std::ptr::null_mut();
        // SAFETY: `self_raw` is the live handle owned by `self`; `new_raw` is a valid
        // out-pointer.
        let err_code: sys::errcode_t = unsafe {
            rtl::MNumericArray_convertType(
                &mut new_raw,
                self_raw,
                T2::TYPE.as_raw(),
                method.as_raw(),
                tolerance,
            )
        };
        if err_code != 0 || new_raw.is_null() {
            return Err(err_code);
        }
        // SAFETY: the conversion succeeded and produced a fresh array of element type `T2`.
        Ok(unsafe { NumericArray::<T2>::from_raw(new_raw) })
    }
}
/// Raw pointer to the data buffer of `numeric_array` (`MNumericArray_getData`).
///
/// # Safety
///
/// `numeric_array` must be a valid, live `MNumericArray` handle.
unsafe fn data_ptr(numeric_array: sys::MNumericArray) -> *mut c_void {
    rtl::MNumericArray_getData(numeric_array)
}

/// Total number of elements in `numeric_array` (`MNumericArray_getFlattenedLength`).
///
/// # Panics
///
/// Panics if the reported length does not fit in a `usize`.
///
/// # Safety
///
/// `numeric_array` must be a valid, live `MNumericArray` handle.
unsafe fn flattened_length(numeric_array: sys::MNumericArray) -> usize {
    let raw_len: sys::mint = rtl::MNumericArray_getFlattenedLength(numeric_array);
    usize::try_from(raw_len).expect("i64 overflows usize")
}
impl<T: NumericArrayType> UninitNumericArray<T> {
pub fn from_dimensions(dimensions: &[usize]) -> UninitNumericArray<T> {
UninitNumericArray::try_from_dimensions(dimensions)
.expect("failed to create UninitNumericArray from dimensions")
}
pub fn try_from_dimensions(
dimensions: &[usize],
) -> Result<UninitNumericArray<T>, sys::errcode_t> {
assert!(!dimensions.is_empty());
let rank = dimensions.len();
debug_assert!(rank > 0);
unsafe {
let mut numeric_array: sys::MNumericArray = std::ptr::null_mut();
let err_code: sys::errcode_t = rtl::MNumericArray_new(
<T as NumericArrayType>::TYPE.as_raw(),
i64::try_from(rank).expect("usize overflows i64"),
dimensions.as_ptr() as *mut sys::mint,
&mut numeric_array,
);
if err_code != 0 || numeric_array.is_null() {
return Err(err_code);
}
Ok(UninitNumericArray(numeric_array, PhantomData))
}
}
pub fn init_from_slice(mut self, source: &[T]) -> NumericArray<T> {
let data = self.as_slice_mut();
copy_from_slice_uninit(source, data);
unsafe { self.assume_init() }
}
pub fn as_slice_mut(&mut self) -> &mut [MaybeUninit<T>] {
let UninitNumericArray(numeric_array, PhantomData) = *self;
unsafe {
let len = flattened_length(numeric_array);
let ptr: *mut c_void = data_ptr(numeric_array);
let ptr = ptr as *mut MaybeUninit<T>;
std::slice::from_raw_parts_mut(ptr, len)
}
}
pub unsafe fn assume_init(self) -> NumericArray<T> {
let UninitNumericArray(expr, PhantomData) = self;
std::mem::forget(self);
NumericArray(expr, PhantomData)
}
}
/// Bitwise-copy every element of `src` into the (uninitialized) slots of `dest`.
///
/// After this returns, every element of `dest` is initialized.
///
/// NOTE(review): there is no `T: Copy` bound, so this is only sound for trivially
/// copyable `T`. All `NumericArrayType` implementors in this module are plain numeric
/// types.
///
/// # Panics
///
/// Panics if `src` and `dest` have different lengths.
fn copy_from_slice_uninit<T>(src: &[T], dest: &mut [MaybeUninit<T>]) {
    let count = src.len();
    assert_eq!(
        count,
        dest.len(),
        "destination and source slices have different lengths"
    );

    let dst_ptr = dest.as_mut_ptr().cast::<T>();
    // SAFETY: both slices hold `count` elements; `MaybeUninit<T>` has the same layout as
    // `T`; and `dest` is a unique borrow, so it cannot overlap `src`.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst_ptr, count) }
}
impl NumericArrayDataType {
    /// The raw LibraryLink `numericarray_data_t` code for this variant.
    pub fn as_raw(self) -> sys::numericarray_data_t {
        // The discriminants are the raw codes, so a plain cast suffices.
        self as sys::numericarray_data_t
    }

    /// Human-readable name of this element type, e.g. `"Integer8"` or `"ComplexReal64"`
    /// (these match the Wolfram Language NumericArray type names).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Bit8 => "Integer8",
            Self::Bit16 => "Integer16",
            Self::Bit32 => "Integer32",
            Self::Bit64 => "Integer64",
            Self::UBit8 => "UnsignedInteger8",
            Self::UBit16 => "UnsignedInteger16",
            Self::UBit32 => "UnsignedInteger32",
            Self::UBit64 => "UnsignedInteger64",
            Self::Real32 => "Real32",
            Self::Real64 => "Real64",
            Self::ComplexReal32 => "ComplexReal32",
            Self::ComplexReal64 => "ComplexReal64",
        }
    }
}
impl NumericArrayConvertMethod {
    /// The raw LibraryLink `numericarray_convert_method_t` code for this variant.
    ///
    /// The enum's discriminants are the raw codes, so this is a plain cast.
    pub fn as_raw(self) -> sys::numericarray_convert_method_t {
        self as sys::numericarray_convert_method_t
    }
}
impl<T> Clone for NumericArray<T> {
    /// Create a new, separately owned array via `MNumericArray_clone`.
    ///
    /// # Panics
    ///
    /// Panics if LibraryLink reports an error or produces a null array.
    fn clone(&self) -> NumericArray<T> {
        let NumericArray(source, PhantomData) = *self;
        let mut copy: sys::MNumericArray = std::ptr::null_mut();

        // SAFETY: `source` is the live handle owned by `self`; `copy` is a valid out-pointer.
        let err_code: sys::errcode_t = unsafe { rtl::MNumericArray_clone(source, &mut copy) };

        if err_code == 0 && !copy.is_null() {
            // SAFETY: LibraryLink returned a new array with the same element type as `self`.
            unsafe { NumericArray::<T>::from_raw(copy) }
        } else {
            panic!("NumericArray clone failed with error code: {}", err_code)
        }
    }
}
impl<T> Drop for NumericArray<T> {
    /// Release this handle: arrays with a non-zero share count are disowned with
    /// `MNumericArray_disown`, while unshared arrays are freed with `MNumericArray_free`.
    fn drop(&mut self) {
        let is_shared = self.share_count() > 0;
        let NumericArray(raw, PhantomData) = *self;

        // SAFETY: `raw` is the handle owned by `self`, and this is its final use.
        unsafe {
            if is_shared {
                rtl::MNumericArray_disown(raw)
            } else {
                rtl::MNumericArray_free(raw)
            }
        }
    }
}
impl<T> fmt::Debug for NumericArray<T> {
    /// Formats the raw handle and the runtime element type.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("NumericArray");
        builder.field("raw", &self.0);
        // Evaluated after "raw" has been written, exactly as in a single chained call.
        builder.field("data_type", &self.data_type());
        builder.finish()
    }
}
impl TryFrom<u32> for NumericArrayDataType {
    type Error = ();

    /// Map a raw LibraryLink element type code to its [`NumericArrayDataType`] variant.
    ///
    /// Returns `Err(())` if `value` is not one of the known type codes.
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        // Searched in declaration order, so the first matching code wins.
        let known: [(u32, Self); 12] = [
            (BIT8_TYPE as u32, Self::Bit8),
            (BIT16_TYPE as u32, Self::Bit16),
            (BIT32_TYPE as u32, Self::Bit32),
            (BIT64_TYPE as u32, Self::Bit64),
            (UBIT8_TYPE as u32, Self::UBit8),
            (UBIT16_TYPE as u32, Self::UBit16),
            (UBIT32_TYPE as u32, Self::UBit32),
            (UBIT64_TYPE as u32, Self::UBit64),
            (REAL32_TYPE as u32, Self::Real32),
            (REAL64_TYPE as u32, Self::Real64),
            (COMPLEX_REAL32_TYPE as u32, Self::ComplexReal32),
            (COMPLEX_REAL64_TYPE as u32, Self::ComplexReal64),
        ];

        known
            .iter()
            .copied()
            .find(|&(raw, _)| raw == value)
            .map(|(_, data_type)| data_type)
            .ok_or(())
    }
}