use {
std::{
fmt,
marker::{
PhantomData
},
ops::{
Deref,
DerefMut
},
slice
},
pyo3::{
prelude::*,
types::{
PyDict,
PyTuple
},
ToPyPointer
},
crate::{
backend::{
keras::{
ffi
}
},
core::{
array::{
ArrayRef,
ArrayMut,
ToArrayRef
},
data_source::{
DataSource
},
data_type::{
DataType,
Type
},
indices::{
ToIndices
},
shape::{
Shape
},
type_cast_error::{
TypeCastError
}
}
}
};
/// Memory layout requested from numpy when allocating a new array.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrayOrder {
    /// Passed to numpy as `order="C"`.
    RowMajor,
    /// Passed to numpy as `order="F"`.
    // Never constructed in this file; only matched in `PyArray::new_internal`,
    // so silence the unused-variant lint rather than drop the option.
    #[allow(dead_code)]
    ColumnMajor,
}
/// Arguments for allocating a fresh numpy array in `PyArray::new_internal`.
struct ArrayInit< 'a > {
    /// Layout; mapped to numpy's `order="C"` / `order="F"` keyword.
    order: ArrayOrder,
    /// Dimensions; passed to numpy as the `shape` keyword.
    shape: Shape,
    /// numpy dtype name (e.g. `"float32"`), passed as the `dtype` keyword.
    kind: &'a str,
}
/// Returns `obj.dtype.name` as an owned `String`.
///
/// Panics if either attribute lookup or the string extraction fails.
fn dtype( py: Python, obj: &PyObject ) -> String {
    let dtype_object = obj.getattr( py, "dtype" ).unwrap();
    let dtype_name = dtype_object.getattr( py, "name" ).unwrap();
    dtype_name.extract( py ).unwrap()
}
/// Handle to a numpy array object together with metadata cached at wrap time.
pub struct PyArray {
    /// The underlying Python (numpy) object.
    obj: PyObject,
    /// Dimensions read from the array's `dims` when the handle was created.
    shape: Shape,
    /// Element type derived from the array's `dtype.name`.
    ty: Type,
}
/// Reinterprets the object's raw pointer as a numpy `PyArrayObject`.
///
/// # Safety
///
/// `obj` must actually be a numpy array; otherwise the returned reference
/// describes memory with a different layout. The result borrows from `obj`.
unsafe fn as_array_object( obj: &PyObject ) -> &ffi::PyArrayObject {
    let raw = obj.as_ptr() as *const ffi::PyArrayObject;
    &*raw
}
/// Mutable counterpart of `as_array_object`.
///
/// # Safety
///
/// `obj` must actually be a numpy array, and no other reference to the same
/// `PyArrayObject` may be used while the returned one is alive.
unsafe fn as_array_object_mut( obj: &mut PyObject ) -> &mut ffi::PyArrayObject {
    let raw = obj.as_ptr() as *mut ffi::PyArrayObject;
    &mut *raw
}
impl PyArray {
    /// Wraps an existing Python object as a `PyArray` without validating it.
    ///
    /// The element type is derived from `obj.dtype.name`; the shape is read from the
    /// object's `nd`/`dims` fields and cached.
    ///
    /// # Safety
    ///
    /// `obj` must be a numpy array: its pointer is reinterpreted as
    /// `ffi::PyArrayObject` and its `dims` buffer is read directly.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) if the dtype name is not one of the supported types.
    pub(crate) unsafe fn from_object_unchecked( py: Python, obj: PyObject ) -> Self {
        let ty = match dtype( py, &obj ).as_str() {
            "float32" => Type::F32,
            "uint32" => Type::U32,
            "uint16" => Type::U16,
            "uint8" => Type::U8,
            "int32" => Type::I32,
            "int16" => Type::I16,
            "int8" => Type::I8,
            ty => unimplemented!( "Unhandled array type: {}", ty )
        };
        let internal = as_array_object( &obj );
        // `slice::from_raw_parts` requires a non-null pointer even for a zero-length
        // slice, and numpy may leave `dims` null for a 0-dimensional array, so only
        // build the slice when there are dimensions to read.
        let dims: &[_] = if internal.nd == 0 || internal.dims.is_null() {
            &[]
        } else {
            slice::from_raw_parts( internal.dims, internal.nd as usize )
        };
        let shape = dims.iter().cloned().map( |size| size as usize ).collect();
        PyArray { obj, shape, ty }
    }

    /// Shared view of the underlying `PyArrayObject`.
    fn as_array_object( &self ) -> &ffi::PyArrayObject {
        unsafe { as_array_object( &self.obj ) }
    }

    /// Mutable view of the underlying `PyArrayObject`.
    fn as_array_object_mut( &mut self ) -> &mut ffi::PyArrayObject {
        unsafe { as_array_object_mut( &mut self.obj ) }
    }

    /// Allocates a new row-major numpy array with the given shape and element type.
    pub(crate) fn new( py: Python, shape: Shape, ty: Type ) -> PyArray {
        PyArray::new_internal( py, ArrayInit {
            order: ArrayOrder::RowMajor,
            shape,
            kind: py_type_name( ty )
        })
    }

    /// Calls `numpy.ndarray(shape=..., order=..., dtype=...)` and wraps the result.
    ///
    /// NOTE(review): numpy's low-level `ndarray` constructor does not fill the buffer,
    /// so the contents should be treated as uninitialized until written.
    ///
    /// Panics if importing numpy or constructing the array fails.
    fn new_internal( py: Python, init: ArrayInit ) -> PyArray {
        let np = py.import( "numpy" ).unwrap();
        let kwargs = PyDict::new( py );
        let shape = PyTuple::new( py, &init.shape );
        kwargs.set_item( "shape", shape ).unwrap();
        let order = match init.order {
            ArrayOrder::RowMajor => "C",
            ArrayOrder::ColumnMajor => "F"
        };
        kwargs.set_item( "order", order ).unwrap();
        kwargs.set_item( "dtype", init.kind ).unwrap();
        let obj = np.get( "ndarray" ).unwrap().call( (), Some( &kwargs ) ).unwrap().to_object( py );
        unsafe { PyArray::from_object_unchecked( py, obj ) }
    }

    /// Number of dimensions, as reported by the live array object (`nd`).
    pub fn dimension_count( &self ) -> usize {
        self.as_array_object().nd as _
    }

    /// Returns a copy of the shape cached when this handle was created.
    pub fn shape( &self ) -> Shape {
        self.shape.clone()
    }

    /// Calls numpy's `reshape` with `shape` and wraps the returned array.
    ///
    /// # Panics
    ///
    /// Panics if the element counts (products of the dimensions) of the current
    /// and requested shapes differ, or if the Python call fails.
    pub fn reshape< S >( &self, py: Python, shape: S ) -> PyArray where S: Into< Shape > {
        let shape = shape.into();
        let current_shape = self.shape();
        assert_eq!(
            shape.product(),
            current_shape.product(),
            "Tried to reshape a PyArray from {} into {} where their products don't match ({} != {})",
            current_shape,
            shape,
            current_shape.product(),
            shape.product()
        );
        let shape = PyTuple::new( py, &shape );
        let obj = self.obj.getattr( py, "reshape" ).unwrap().call( py, (shape,), None ).unwrap().to_object( py );
        unsafe { PyArray::from_object_unchecked( py, obj ) }
    }

    /// Converts into a `TypedPyArray<T>` if the element type matches `T::TYPE`.
    ///
    /// # Errors
    ///
    /// Returns a `TypeCastError` carrying `self` back when the types differ.
    pub fn into_typed< T: DataType >( self ) -> Result< TypedPyArray< T >, TypeCastError< Self > > {
        if self.ty == T::TYPE {
            Ok( TypedPyArray( self, PhantomData ) )
        } else {
            Err( TypeCastError {
                source: "an array",
                target: "a typed array",
                source_ty: self.ty.into(),
                target_ty: T::TYPE,
                obj: self
            })
        }
    }

    /// Whether the element type equals `T::TYPE`.
    pub fn data_is< T: DataType >( &self ) -> bool {
        self.ty == T::TYPE
    }

    /// Element type of the array.
    pub fn data_type( &self ) -> Type {
        self.ty
    }

    /// Raw bytes of the array data: `shape.product() * byte_size` bytes from `data`.
    ///
    /// NOTE(review): assumes the data is laid out contiguously; confirm for arrays
    /// obtained via `from_object_unchecked`.
    pub fn as_bytes( &self ) -> &[u8] {
        let length = self.byte_length();
        unsafe { slice::from_raw_parts( self.as_array_object().data as *const u8, length ) }
    }

    /// Mutable raw bytes of the array data; same layout assumption as `as_bytes`.
    pub fn as_bytes_mut( &mut self ) -> &mut [u8] {
        let length = self.byte_length();
        let data = self.as_array_object_mut().data as *mut u8;
        unsafe { slice::from_raw_parts_mut( data, length ) }
    }

    /// Total size in bytes of all elements, based on the cached shape and element type.
    fn byte_length( &self ) -> usize {
        self.shape.product() * self.ty.byte_size()
    }

    /// Borrows the underlying Python object.
    pub(crate) fn as_py_obj( &self ) -> &PyObject {
        &self.obj
    }
}
/// A data source backed by a buffer detached from a numpy array (see `PyArraySource::new`).
pub struct PyArraySource {
    /// Start of the detached buffer; released with `libc::free` on drop.
    pointer: *mut u8,
    /// Size of the original array's first axis.
    length: usize,
    /// The original array's shape without its first axis.
    shape: Shape,
    /// Element type of the buffer.
    data_type: Type,
}

// SAFETY: the buffer behind `pointer` was detached from the numpy array in
// `PyArraySource::new` (the array's data pointer is nulled there), so this value is
// its only owner, and it is only read through `&self`.
unsafe impl Sync for PyArraySource {}
unsafe impl Send for PyArraySource {}
impl Drop for PyArraySource {
    /// Frees the buffer that was detached from the numpy array.
    fn drop( &mut self ) {
        let buffer = self.pointer as *mut libc::c_void;
        // SAFETY: `buffer` is exclusively owned by this value. NOTE(review): this assumes
        // numpy allocated it with the C allocator (`malloc`); confirm for the numpy version in use.
        unsafe {
            libc::free( buffer );
        }
    }
}
impl PyArraySource {
    /// Detaches the data buffer from `array` and wraps it as a data source.
    ///
    /// The size of `array`'s first axis becomes `len()` and the remaining axes become
    /// `shape()`. Ownership of the buffer moves into the returned value: the numpy
    /// array is left with a null data pointer and all of its dimensions set to zero,
    /// and the buffer is released by `PyArraySource`'s `Drop` impl instead.
    ///
    /// # Panics
    ///
    /// Panics if `array` has fewer than two dimensions, does not own its data, is not
    /// C-contiguous (row-major), or its dtype has the `NPY_ITEM_REFCOUNT` flag set.
    pub fn new( mut array: PyArray ) -> Self {
        assert!( array.dimension_count() > 1 );
        let original_shape = array.shape();
        let length = original_shape.into_iter().next().unwrap();
        let shape = original_shape.into_iter().skip( 1 ).collect();
        let data_type = array.data_type();
        let pointer = {
            let array_object = array.as_array_object_mut();
            // Only a buffer owned by the array itself can be taken over and freed here.
            assert_ne!( array_object.flags & ffi::NPY_ARRAY_OWNDATA, 0 );
            // The buffer is later read as one dense row-major block (see `as_bytes`), so a
            // Fortran-ordered array would be silently misread. 0x0001 is numpy's
            // `NPY_ARRAY_C_CONTIGUOUS` flag.
            assert_ne!(
                array_object.flags & 0x0001,
                0,
                "PyArraySource requires a C-contiguous (row-major) array"
            );
            // Presumably rejects dtypes whose items hold Python references, which could
            // not be handed over as plain bytes.
            assert_eq!( unsafe { &*array_object.descr }.flags & ffi::NPY_ITEM_REFCOUNT, 0 );

            // Zero every dimension so the numpy array no longer describes any elements.
            let dims = unsafe { slice::from_raw_parts_mut( array_object.dims, array_object.nd as usize ) };
            for dim in dims.iter_mut() {
                *dim = 0;
            }

            let data = array_object.data;
            // Ownership moves to Rust, so stop tracking the buffer in numpy's tracemalloc domain.
            unsafe {
                ffi::PyTraceMalloc_Untrack( ffi::NPY_TRACE_DOMAIN, data as libc::uintptr_t );
            }
            // Presumably numpy's deallocator skips a null data pointer, so this keeps it
            // from freeing the buffer we now own. Confirm against the numpy version in use.
            array_object.data = std::ptr::null_mut();
            data
        };
        PyArraySource {
            pointer,
            length,
            shape,
            data_type
        }
    }

    /// The whole detached buffer: `length * shape.product() * byte_size` bytes.
    fn as_bytes( &self ) -> &[u8] {
        let byte_count = self.length * self.shape.product() * self.data_type.byte_size();
        unsafe { slice::from_raw_parts( self.pointer, byte_count ) }
    }
}
impl DataSource for PyArraySource {
    /// Element type of the detached buffer.
    fn data_type( &self ) -> Type {
        self.data_type
    }

    /// Shape of one entry (the original shape without its first axis).
    fn shape( &self ) -> Shape {
        Clone::clone( &self.shape )
    }

    /// Number of entries (size of the original first axis).
    fn len( &self ) -> usize {
        self.length
    }

    /// Copies the entries selected by `indices` into `output` using `ArrayMut::gather_from`.
    fn gather_bytes_into< I >( &self, indices: I, output: &mut [u8] ) where I: ToIndices {
        let entry_shape = self.shape();
        let element_type = self.data_type();
        let source = ArrayRef::new( entry_shape.clone(), element_type, self.as_bytes() );
        let mut target = ArrayMut::new( entry_shape, element_type, output );
        target.gather_from( indices, &source );
    }
}
impl ToArrayRef for PyArraySource {
    /// Borrows the whole detached buffer, described by the per-entry shape and element type.
    fn to_array_ref( &self ) -> ArrayRef {
        let bytes = self.as_bytes();
        let element_type = self.data_type();
        ArrayRef::new( self.shape(), element_type, bytes )
    }
}
impl ToPyObject for PyArray {
    /// Returns a new reference to the same underlying Python object.
    fn to_object( &self, py: Python ) -> PyObject {
        let object = &self.obj;
        object.clone_ref( py )
    }
}
/// numpy dtype name for each supported element type (the inverse of the mapping
/// used by `PyArray::from_object_unchecked`).
fn py_type_name( ty: Type ) -> &'static str {
    match ty {
        Type::U8 => "uint8",
        Type::U16 => "uint16",
        Type::U32 => "uint32",
        Type::I8 => "int8",
        Type::I16 => "int16",
        Type::I32 => "int32",
        Type::F32 => "float32",
    }
}
/// A [`PyArray`] whose element type is known to be `T`.
///
/// Built by `TypedPyArray::new` (allocates with `T::TYPE`) or `PyArray::into_typed`
/// (checks the type); `PhantomData` records `T` without storing values of it.
pub struct TypedPyArray< T >(
    PyArray,
    PhantomData< T >
);
impl< T: DataType > TypedPyArray< T > {
    /// Allocates a new row-major numpy array of element type `T::TYPE`.
    pub fn new( py: Python, shape: Shape ) -> Self {
        TypedPyArray( PyArray::new( py, shape, T::TYPE ), PhantomData )
    }

    /// Copies the elements into a new `Vec`.
    pub fn to_vec( &self ) -> Vec< T > {
        let elements = self.as_slice();
        elements.to_vec()
    }

    /// Views the data as `shape().product()` elements of `T` starting at the data pointer.
    pub fn as_slice( &self ) -> &[T] {
        let count = self.shape().product();
        let start = self.as_array_object().data as *const T;
        unsafe { slice::from_raw_parts( start, count ) }
    }

    /// Mutable counterpart of `as_slice`.
    pub fn as_slice_mut( &mut self ) -> &mut [T] {
        let count = self.shape().product();
        let start = self.as_array_object().data as *mut T;
        unsafe { slice::from_raw_parts_mut( start, count ) }
    }
}
impl< T > Deref for TypedPyArray< T > {
type Target = PyArray;
#[inline]
fn deref( &self ) -> &Self::Target {
&self.0
}
}
impl< T > DerefMut for TypedPyArray< T > {
#[inline]
fn deref_mut( &mut self ) -> &mut Self::Target {
&mut self.0
}
}
impl< T: DataType > Into< Vec< T > > for TypedPyArray< T > {
#[inline]
fn into( self ) -> Vec< T > {
self.to_vec()
}
}
impl< 'a, T: DataType > Into< Vec< T > > for &'a TypedPyArray< T > {
#[inline]
fn into( self ) -> Vec< T > {
self.to_vec()
}
}
impl< 'a, T: DataType > Into< Vec< T > > for &'a mut TypedPyArray< T > {
#[inline]
fn into( self ) -> Vec< T > {
self.to_vec()
}
}
impl< T: DataType > Into< PyArray > for TypedPyArray< T > {
#[inline]
fn into( self ) -> PyArray {
self.0
}
}
/// Formats the elements as a list, e.g. `[1, 2, 3]`.
impl< T: DataType + fmt::Debug > fmt::Debug for TypedPyArray< T > {
    fn fmt( &self, formatter: &mut fmt::Formatter ) -> fmt::Result {
        let mut list = formatter.debug_list();
        list.entries( self.as_slice() );
        list.finish()
    }
}
impl< T > ToPyObject for TypedPyArray< T > {
fn to_object( &self, py: Python ) -> PyObject {
self.0.obj.clone_ref( py )
}
}