use std::sync::{Arc, RwLock, RwLockReadGuard};
use ndarray::ArrayD;
use dlpk::sys::DLDevice;
use crate::c_api::{mts_array_t, mts_data_origin_t, mts_data_movement_t};
use crate::Error;
use crate::errors::check_status;
use super::{ArrayRef, ArrayRefMut};
use super::origin::get_data_origin;
/// Owning wrapper around a raw `mts_array_t`.
///
/// The wrapped array is released through its `destroy` callback when this
/// value is dropped (see the `Drop` implementation for `MtsArray`). For that
/// reason this type must never be `Clone`/`Copy`: duplicating it would call
/// `destroy` twice on the same array.
pub struct MtsArray {
    // raw array owned by this wrapper; use `into_raw` to take it back without destroying it
    array: mts_array_t
}
impl Drop for MtsArray {
    /// Release the wrapped array through its `destroy` callback.
    ///
    /// Arrays that do not provide a `destroy` callback are left untouched.
    fn drop(&mut self) {
        let raw = &self.array;
        if let Some(destroy_fn) = raw.destroy {
            // SAFETY: `MtsArray` owns `raw` (it was given to `from_raw`), and
            // `into_raw` skips this `Drop`, so the array is destroyed at most once.
            unsafe {
                destroy_fn(raw.ptr);
            }
        }
    }
}
impl MtsArray {
    /// Take ownership of a raw `mts_array_t`.
    ///
    /// The returned value calls `array.destroy` (when set) on drop, so the
    /// caller must not release `array` through any other path.
    pub fn from_raw(array: mts_array_t) -> MtsArray {
        MtsArray { array }
    }

    /// Give up ownership of the wrapped `mts_array_t` without destroying it.
    ///
    /// The caller becomes responsible for eventually calling the `destroy`
    /// callback of the returned array (if any).
    pub fn into_raw(self) -> mts_array_t {
        let array = self.array;
        // skip `Drop`: ownership of `array` is transferred to the caller
        std::mem::forget(self);
        array
    }

    /// Get the Rust data stored in this array as `&dyn Any`, for downcasting.
    ///
    /// # Panics
    ///
    /// If the array was not created from Rust, i.e. its origin differs from
    /// `RUST_DATA_ORIGIN`, or if the `origin` callback is NULL.
    #[inline]
    pub fn as_any(&self) -> &dyn std::any::Any {
        // a failing `origin` callback is treated as origin 0, presumably never
        // a registered origin, so it ends up in the assertion failure below
        let origin = self.origin().unwrap_or(0);
        assert_eq!(
            origin, *super::array::RUST_DATA_ORIGIN,
            "this array was not created as a rust Array (origin is '{}')",
            get_data_origin(origin).unwrap_or_else(|_| "unknown".into())
        );

        let array = self.array.ptr.cast::<super::array::RustArray>();
        // SAFETY: the origin check above ensures `ptr` was produced by the
        // Rust side of this crate, which (as this cast assumes) stores a
        // `RustArray` behind it.
        unsafe {
            return (*array).as_any();
        }
    }

    /// Downcast the Rust data of this array to `Arc<RwLock<ArrayD<T>>>`.
    ///
    /// # Panics
    ///
    /// In the same cases as `as_any`, or if the data is not an
    /// `Arc<RwLock<ArrayD<T>>>`.
    #[inline]
    fn as_lock<T>(&self) -> &Arc<RwLock<ArrayD<T>>> where T: 'static {
        self.as_any().downcast_ref().expect("this is not an Arc<RwLock<ArrayD>>")
    }

    /// Get a read guard to the `ArrayD<T>` stored in this array.
    ///
    /// # Panics
    ///
    /// In the same cases as `as_lock`, or if the lock was poisoned.
    #[inline]
    pub fn as_ndarray<T>(&self) -> RwLockReadGuard<'_, ArrayD<T>> where T: 'static {
        return self.as_lock().read().expect("lock was poisoned");
    }

    /// Get a reference to the wrapped raw `mts_array_t`.
    pub fn as_raw(&self) -> &mts_array_t {
        &self.array
    }

    /// Get a mutable reference to the wrapped raw `mts_array_t`.
    pub fn as_raw_mut(&mut self) -> &mut mts_array_t {
        &mut self.array
    }

    /// Get an `ArrayRef` to this array, borrowing `self` for its lifetime.
    pub fn as_ref(&'_ self) -> ArrayRef<'_> {
        // SAFETY: `self.array` stays alive at least as long as the returned
        // `ArrayRef`, which borrows `self`.
        unsafe { ArrayRef::from_raw(self.array) }
    }

    /// Get an `ArrayRefMut` to this array, mutably borrowing `self` for its lifetime.
    pub fn as_mut(&'_ mut self) -> ArrayRefMut<'_> {
        // SAFETY: `self.array` stays alive at least as long as the returned
        // `ArrayRefMut`, which holds the unique borrow of `self`.
        unsafe { ArrayRefMut::from_raw(self.array) }
    }

    /// Get the data origin of this array, through its `origin` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `origin` callback is NULL.
    pub fn origin(&self) -> Result<mts_data_origin_t, Error> {
        let function = self.array.origin.expect("mts_array_t.origin function is NULL");
        let mut origin = 0;
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, &mut origin))?;
        }
        return Ok(origin);
    }

    /// Get the device on which the data lives, through the `device` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `device` callback is NULL.
    pub fn device(&self) -> Result<DLDevice, Error> {
        let function = self.array.device.expect("mts_array_t.device function is NULL");
        // pre-filled value, overwritten by the callback
        let mut device = DLDevice::cpu();
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, &mut device))?;
        }
        return Ok(device);
    }

    /// Get the data type of this array, through the `dtype` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `dtype` callback is NULL.
    pub fn dtype(&self) -> Result<dlpk::sys::DLDataType, Error> {
        let function = self.array.dtype.expect("mts_array_t.dtype function is NULL");
        // placeholder value, overwritten by the callback
        let mut dtype = dlpk::sys::DLDataType { code: dlpk::sys::DLDataTypeCode::kDLFloat, bits: 0, lanes: 0 };
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, &mut dtype))?;
        }
        return Ok(dtype);
    }

    /// Export this array as a DLPack tensor, through the `as_dlpack` callback.
    ///
    /// `stream` is forwarded to the callback as a pointer, or NULL when `None`.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `as_dlpack` callback is NULL, or if it reports success without
    /// producing a tensor.
    pub fn as_dlpack(
        &self,
        device: DLDevice,
        stream: Option<i64>,
        max_version: dlpk::sys::DLPackVersion,
    ) -> Result<dlpk::DLPackTensor, Error> {
        let function = self.array.as_dlpack.expect("mts_array_t.as_dlpack function is NULL");
        let mut tensor = std::ptr::null_mut();
        // points into `stream`, which outlives the callback invocation below
        let stream_c = stream.as_ref().map_or(std::ptr::null(), |s| s as *const i64);
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`, and
        // `stream_c` is either NULL or valid for the duration of the call
        unsafe {
            check_status(function(self.array.ptr, &mut tensor, device, stream_c, max_version))?;
        }
        // a callback reporting success must have produced a tensor; check this
        // before handing the pointer over, like `shape()` does for its pointer
        assert!(!tensor.is_null(), "mts_array_t.as_dlpack returned a NULL tensor");
        // SAFETY: `tensor` is non-null and was just produced by the callback
        let tensor = unsafe {
            dlpk::DLPackTensor::from_ptr(tensor)
        };
        return Ok(tensor);
    }

    /// Create a new array from a DLPack tensor, through this array's
    /// `from_dlpack` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `from_dlpack` callback is NULL.
    pub fn from_dlpack(&self, dlpack_tensor: dlpk::DLPackTensor) -> Result<MtsArray, Error> {
        let function = self.array.from_dlpack.expect("mts_array_t.from_dlpack function is NULL");
        let mut new_array = mts_array_t::null();
        // NOTE(review): `into_raw()` gives up Rust ownership of the tensor
        // before the call; if the callback fails without consuming it, the
        // tensor is leaked — confirm the C-side ownership contract.
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, dlpack_tensor.into_raw().as_ptr(), &mut new_array))?;
        }
        return Ok(MtsArray::from_raw(new_array));
    }

    /// Get the shape of this array, through the `shape` callback.
    ///
    /// The returned slice borrows `self`, so it cannot outlive a call to
    /// `reshape`/`swap_axes` (which take `&mut self`).
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `shape` callback is NULL, or if it reports a non-zero number of
    /// dimensions with a NULL pointer.
    pub fn shape(&self) -> Result<&[usize], Error> {
        let function = self.array.shape.expect("mts_array_t.shape function is NULL");
        let mut shape = std::ptr::null();
        let mut shape_count: usize = 0;
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, &mut shape, &mut shape_count))?;
        }

        if shape_count == 0 {
            // the pointer may be NULL/dangling when there are no dimensions
            return Ok(&[]);
        } else {
            assert!(!shape.is_null());
            // SAFETY: the callback reported `shape_count` elements at the
            // non-null pointer `shape`
            let shape = unsafe {
                std::slice::from_raw_parts(shape, shape_count)
            };
            return Ok(shape);
        }
    }

    /// Change the shape of this array to `shape`, through the `reshape` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `reshape` callback is NULL.
    pub fn reshape(&mut self, shape: &[usize]) -> Result<(), Error> {
        let function = self.array.reshape.expect("mts_array_t.reshape function is NULL");
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`, and
        // `shape` is valid for `shape.len()` elements
        unsafe {
            check_status(function(self.array.ptr, shape.as_ptr(), shape.len()))?;
        }
        return Ok(());
    }

    /// Swap the axes `axis_1` and `axis_2`, through the `swap_axes` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `swap_axes` callback is NULL.
    pub fn swap_axes(&mut self, axis_1: usize, axis_2: usize) -> Result<(), Error> {
        let function = self.array.swap_axes.expect("mts_array_t.swap_axes function is NULL");
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, axis_1, axis_2))?;
        }
        return Ok(());
    }

    /// Create a new array through this array's `create` callback, with the
    /// given `shape` and `fill_value`.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `create` callback is NULL.
    pub fn create(&self, shape: &[usize], fill_value: ArrayRef<'_>) -> Result<MtsArray, Error> {
        let function = self.array.create.expect("mts_array_t.create function is NULL");
        let mut new_array = mts_array_t::null();
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`,
        // `shape` is valid for `shape.len()` elements, and `fill_value` is
        // borrowed for the whole call
        unsafe {
            check_status(function(
                self.array.ptr,
                shape.as_ptr(),
                shape.len(),
                *fill_value.as_raw(),
                &mut new_array
            ))?;
        }
        return Ok(MtsArray::from_raw(new_array));
    }

    /// Copy this array to `device`, through the `copy` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `copy` callback is NULL.
    pub fn copy(&self, device: DLDevice) -> Result<MtsArray, Error> {
        let function = self.array.copy.expect("mts_array_t.copy function is NULL");
        let mut new_array = mts_array_t::null();
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`
        unsafe {
            check_status(function(self.array.ptr, device, &mut new_array))?;
        }
        return Ok(MtsArray::from_raw(new_array));
    }

    /// Move data from `input` into this array as described by `moves`,
    /// through the `move_data` callback.
    ///
    /// # Errors
    ///
    /// If the callback reports a failure status.
    ///
    /// # Panics
    ///
    /// If the `move_data` callback is NULL.
    pub fn move_data<'input>(
        &mut self,
        input: impl Into<ArrayRef<'input>>,
        moves: &[mts_data_movement_t],
    ) -> Result<(), Error> {
        let function = self.array.move_data.expect("mts_array_t.move_data function is NULL");
        let input = input.into();
        // SAFETY: `function` comes from the same `mts_array_t` as `ptr`,
        // `input` is borrowed for the whole call, and `moves` is valid for
        // `moves.len()` elements
        unsafe {
            check_status(function(
                self.array.ptr,
                input.as_raw().ptr,
                moves.as_ptr(),
                moves.len(),
            ))?;
        }
        return Ok(());
    }
}
impl<'a> From<&'a MtsArray> for ArrayRef<'a> {
fn from(array: &'a MtsArray) -> ArrayRef<'a> {
array.as_ref()
}
}
impl<'a> From<&'a mut MtsArray> for ArrayRefMut<'a> {
fn from(array: &'a mut MtsArray) -> ArrayRefMut<'a> {
array.as_mut()
}
}