use std::sync::{Arc, RwLock, TryLockError};
use dlpk::sys::{DLDevice, DLPackVersion, DLDataType};
use dlpk::{DLDataTypeCode, DLPackPointerCast, DLPackTensor, GetDLPackDataType};
use crate::errors::Error;
use crate::c_api::mts_data_movement_t;
use super::{Array, MtsArray};
/// Wrap an owned `ndarray::ArrayD<T>` into an [`MtsArray`].
///
/// The array is placed behind an `Arc<RwLock<..>>`, boxed as a `dyn Array`
/// and then converted with `MtsArray::from`.
impl<T> From<ndarray::ArrayD<T>> for MtsArray
where
    T: 'static + Clone + Send + Default + Sync + GetDLPackDataType + DLPackPointerCast,
{
    fn from(value: ndarray::ArrayD<T>) -> Self {
        let shared: Box<dyn Array> = Box::new(Arc::new(RwLock::new(value)));
        MtsArray::from(shared)
    }
}
impl<T> Array for Arc<RwLock<ndarray::ArrayD<T>>>
where
T: 'static + Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast,
{
/// Return this array as a `&dyn Any`.
///
/// `move_data` uses this on its `input` argument to downcast it back to the
/// concrete `Arc<RwLock<ndarray::ArrayD<T>>>` type.
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Return this array as a `&mut dyn Any`, the mutable counterpart of
/// `as_any`.
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
/// Create a new array with the given `shape`, with every element set to the
/// scalar stored in `fill_value`.
///
/// `self` is not read: only the element type `T` of the implementing type
/// is used.
///
/// # Panics
///
/// Panics if `fill_value` can not be exported as a CPU DLPack tensor, if it
/// is not a 0-dimensional (scalar) array, if it is not on the CPU device, or
/// if its data can not be viewed as `T`.
fn create(&self, shape: &[usize], fill_value: MtsArray) -> Box<dyn Array> {
    let cpu = DLDevice::cpu();
    let fill_tensor = fill_value
        .as_dlpack(cpu, None, DLPackVersion::current())
        .expect("failed to extract fill_value as DLPack");

    assert_eq!(fill_tensor.shape(), &[], "fill_value must be a single scalar");
    assert_eq!(fill_tensor.device(), cpu, "fill_value must be on CPU");

    let scalar_ptr = fill_tensor
        .data_ptr::<T>()
        .expect("dtype mismatch between array and fill_value");
    // NOTE(review): this relies on `data_ptr::<T>()` returning a pointer that
    // is valid for reading one `T`. `ptr::read` makes a bitwise copy, which
    // is presumably fine since DLPack element types are plain values; to be
    // confirmed against the `dlpk` documentation.
    let scalar = unsafe { std::ptr::read(scalar_ptr) };

    Box::new(Arc::new(RwLock::new(ndarray::Array::from_elem(shape, scalar))))
}
fn copy(&self, device: DLDevice) -> Box<dyn Array> {
assert_eq!(device, DLDevice::cpu(), "Rust ndarray data can only be copied to CPU device");
return Box::new(self.clone());
}
/// Get the current shape of the array as an owned vector.
///
/// # Panics
///
/// Panics if the lock is poisoned or if a read lock can not be acquired
/// without blocking.
fn shape(&self) -> Vec<usize> {
    let guard = self.try_read().unwrap_or_else(|error| match error {
        TryLockError::Poisoned(_) => panic!("array lock is poisoned"),
        TryLockError::WouldBlock => panic!("array is already locked"),
    });

    guard.shape().to_vec()
}
/// Change the shape of the array to `shape`.
///
/// # Panics
///
/// Panics if the lock is poisoned, if a write lock can not be acquired
/// without blocking, or if `shape` is not valid for the data
/// (`"invalid shape"`).
fn reshape(&mut self, shape: &[usize]) {
    let mut guard = self.try_write().unwrap_or_else(|error| match error {
        TryLockError::Poisoned(_) => panic!("array lock is poisoned"),
        TryLockError::WouldBlock => panic!("array is already locked"),
    });

    // `into_shape_clone` consumes the array, so move it out of the lock
    // (leaving a default-constructed placeholder behind) and store the
    // reshaped array back afterwards.
    let reshaped = std::mem::take(&mut *guard)
        .into_shape_clone(shape)
        .expect("invalid shape");
    *guard = reshaped;
}
/// Swap the axes `axis_1` and `axis_2` of the array, in place.
///
/// # Panics
///
/// Panics if the lock is poisoned or if a write lock can not be acquired
/// without blocking.
fn swap_axes(&mut self, axis_1: usize, axis_2: usize) {
    let mut guard = self.try_write().unwrap_or_else(|error| match error {
        TryLockError::Poisoned(_) => panic!("array lock is poisoned"),
        TryLockError::WouldBlock => panic!("array is already locked"),
    });

    guard.swap_axes(axis_1, axis_2);
}
/// Copy data from `input` into this array, as described by `movements`.
///
/// Each movement copies `properties_length` entries along the last axis,
/// starting at `properties_start_in` in the sample (index along the first
/// axis) `sample_in` of `input`, to the entries starting at
/// `properties_start_out` in the sample `sample_out` of this array. All
/// axes in between are copied in full.
///
/// When every movement uses the same property range and the samples are
/// consecutive in both the input and the output, everything is copied with
/// a single block assignment; otherwise movements are applied one by one.
///
/// # Panics
///
/// Panics if `input` is not an `Arc<RwLock<ndarray::ArrayD<T>>>` with the
/// same `T`, if either lock is poisoned or can not be acquired without
/// blocking, or if ndarray rejects one of the requested indexes/ranges.
fn move_data(
    &mut self,
    input: &dyn Array,
    movements: &[mts_data_movement_t],
) {
    use ndarray::{Axis, Slice};

    let source = input
        .as_any()
        .downcast_ref::<Self>()
        .expect("input must be a ndarray of the same type");
    let source = source.try_read().unwrap_or_else(|error| match error {
        TryLockError::Poisoned(_) => panic!("input array lock is poisoned"),
        TryLockError::WouldBlock => panic!("input array is already locked"),
    });
    let mut destination = self.try_write().unwrap_or_else(|error| match error {
        TryLockError::Poisoned(_) => panic!("output array lock is poisoned"),
        TryLockError::WouldBlock => panic!("output array is already locked"),
    });

    let first = match movements.first() {
        Some(first) => first,
        None => return,
    };

    // do all movements share the property range of the first one?
    let constant_properties = movements.iter().all(|movement| {
        movement.properties_start_in == first.properties_start_in
            && movement.properties_start_out == first.properties_start_out
            && movement.properties_length == first.properties_length
    });

    // Are the samples consecutive on both sides? This only matters when the
    // property range is constant. Every pair is inspected (no
    // short-circuit), using `&` on purpose.
    let single_block = constant_properties && {
        let (contiguous_in, contiguous_out) = movements.windows(2).fold(
            (true, true),
            |(contiguous_in, contiguous_out), pair| {
                (
                    contiguous_in & (pair[1].sample_in == pair[0].sample_in + 1),
                    contiguous_out & (pair[1].sample_out == pair[0].sample_out + 1),
                )
            },
        );
        contiguous_in && contiguous_out
    };

    let property_axis = destination.shape().len() - 1;

    if single_block {
        let sample_count = movements.len();

        let source_samples = source.slice_axis(
            Axis(0),
            Slice::from(first.sample_in..(first.sample_in + sample_count)),
        );
        let mut destination_samples = destination.slice_axis_mut(
            Axis(0),
            Slice::from(first.sample_out..(first.sample_out + sample_count)),
        );

        // slicing keeps the sample axis, so properties are still the last axis
        let values = source_samples.slice_axis(
            Axis(property_axis),
            Slice::from(
                first.properties_start_in..(first.properties_start_in + first.properties_length),
            ),
        );
        destination_samples
            .slice_axis_mut(
                Axis(property_axis),
                Slice::from(
                    first.properties_start_out
                        ..(first.properties_start_out + first.properties_length),
                ),
            )
            .assign(&values);
    } else {
        for movement in movements {
            let source_sample = source.index_axis(Axis(0), movement.sample_in);
            let mut destination_sample = destination.index_axis_mut(Axis(0), movement.sample_out);

            // `index_axis` removed the sample axis, shifting the property
            // axis down by one
            let values = source_sample.slice_axis(
                Axis(property_axis - 1),
                Slice::from(
                    movement.properties_start_in
                        ..(movement.properties_start_in + movement.properties_length),
                ),
            );
            destination_sample
                .slice_axis_mut(
                    Axis(property_axis - 1),
                    Slice::from(
                        movement.properties_start_out
                            ..(movement.properties_start_out + movement.properties_length),
                    ),
                )
                .assign(&values);
        }
    }
}
/// Device on which the data lives: this implementation always reports the
/// CPU device.
fn device(&self) -> DLDevice {
DLDevice::cpu()
}
/// DLPack data type of the elements, as given by `T`'s
/// `GetDLPackDataType` implementation.
fn dtype(&self) -> DLDataType {
T::get_dlpack_data_type()
}
fn as_dlpack(
&self,
device: DLDevice,
stream: Option<i64>,
max_version: DLPackVersion,
) -> Result<DLPackTensor, Error> {
if stream.is_some() {
return Err(Error {
code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
message: "CPU arrays can not be used with a stream".into(),
});
}
let vendored_version = DLPackVersion::current();
let major_mismatch = max_version.major != vendored_version.major;
let minor_too_high = max_version.minor < vendored_version.minor;
if major_mismatch || minor_too_high {
return Err(Error {
code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
message: format!(
"Metatensor supports DLPack version {}.{}. Caller requested incompatible version {}.{}",
vendored_version.major, vendored_version.minor, max_version.major, max_version.minor
),
});
}
let ndarray_device = DLDevice::cpu();
if device.device_type != ndarray_device.device_type || device.device_id != ndarray_device.device_id {
return Err(Error {
code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
message: format!(
"Requested DLPack device ({}) does not match array device ({})",
device, ndarray_device
),
});
}
let tensor: DLPackTensor = Arc::clone(self).try_into().map_err(|e| Error {
code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
message: format!("failed to convert ndarray to DLPack: {:?}", e),
})?;
Ok(tensor)
}
/// Create a new ndarray-backed array from a DLPack tensor.
///
/// The element type of the result is selected from the dtype of
/// `dlpack_tensor` (`f32`/`f64`, `i8`..`i64`, `u8`..`u64` or `bool`), not
/// from the element type `T` of `self`.
///
/// # Errors
///
/// Returns an error with code `MTS_INVALID_PARAMETER_ERROR` if the tensor
/// has `lanes != 1`, if its dtype is not one of the supported ones, or if
/// the conversion to `ndarray::ArrayD` fails.
#[allow(clippy::enum_glob_use)]
fn from_dlpack(&self, dlpack_tensor: DLPackTensor) -> Result<Box<dyn Array>, Error> {
    use DLDataTypeCode::*;

    let dtype = dlpack_tensor.dtype();
    if dtype.lanes != 1 {
        return Err(Error {
            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
            message: "Only DLPack tensors with lanes == 1 are supported".into(),
        });
    }

    let conversion_error = |e| Error {
        code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
        message: format!("failed to convert DLPack to ndarray: {:?}", e),
    };

    // Convert the tensor to an `ndarray::ArrayD` with the given element
    // type, and wrap it as a boxed `Array`.
    macro_rules! import_as {
        ($element:ty) => {{
            let array: ndarray::ArrayD<$element> =
                dlpack_tensor.try_into().map_err(conversion_error)?;
            let boxed: Box<dyn Array> = Box::new(Arc::new(RwLock::new(array)));
            Ok(boxed)
        }};
    }

    if dtype.code == kDLFloat && dtype.bits == 64 {
        import_as!(f64)
    } else if dtype.code == kDLFloat && dtype.bits == 32 {
        import_as!(f32)
    } else if dtype.code == kDLInt && dtype.bits == 8 {
        import_as!(i8)
    } else if dtype.code == kDLInt && dtype.bits == 16 {
        import_as!(i16)
    } else if dtype.code == kDLInt && dtype.bits == 32 {
        import_as!(i32)
    } else if dtype.code == kDLInt && dtype.bits == 64 {
        import_as!(i64)
    } else if dtype.code == kDLUInt && dtype.bits == 8 {
        import_as!(u8)
    } else if dtype.code == kDLUInt && dtype.bits == 16 {
        import_as!(u16)
    } else if dtype.code == kDLUInt && dtype.bits == 32 {
        import_as!(u32)
    } else if dtype.code == kDLUInt && dtype.bits == 64 {
        import_as!(u64)
    } else if dtype.code == kDLBool && dtype.bits == 8 {
        import_as!(bool)
    } else {
        Err(Error {
            code: Some(crate::c_api::MTS_INVALID_PARAMETER_ERROR),
            message: format!("Unsupported DLPack dtype {}", dtype),
        })
    }
}
}
#[cfg(test)]
mod tests {
use dlpk::{DLPackPointerCast, GetDLPackDataType, sys::{DLDataTypeCode, DLDevice, DLPackVersion}};
use crate::MtsArray;
#[test]
fn ndarray_as_mts_array() {
    // wrapping an ndarray keeps its shape
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![2, 3, 4]));
    assert_eq!(wrapped.shape().unwrap(), [2, 3, 4]);

    // the fill value for `create` is passed as a 0-dimensional array
    let scalar = MtsArray::from(ndarray::Array::from_elem(vec![], 42.0));
    let filled = wrapped.create(&[2, 3, 4], scalar.as_ref()).unwrap();
    assert_eq!(filled.shape().unwrap(), [2, 3, 4]);
}
#[test]
fn ndarray_as_mts_array_dlpack() {
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![4, 5, 6]));

    let tensor = wrapped
        .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
        .unwrap();

    // the exported tensor describes a 3-dimensional array of f64
    assert_eq!(tensor.n_dims(), 3);
    assert_eq!(tensor.shape(), [4, 5, 6]);

    let dtype = tensor.dtype();
    assert_eq!(dtype.code, DLDataTypeCode::kDLFloat);
    assert_eq!(dtype.bits, 64);
    assert_eq!(dtype.lanes, 1);
}
#[test]
fn ndarray_all_dtypes() {
    /// Check that arrays of `T` (both wrapped directly and created through
    /// `create`) export a DLPack dtype with the expected `code` and `bits`.
    fn test_for_dtype<T>(code: DLDataTypeCode, bits: u8)
    where
        T: Send + Sync + Clone + Default + GetDLPackDataType + DLPackPointerCast + 'static,
    {
        let wrapped = MtsArray::from(ndarray::Array::<T, _>::from_elem(vec![2, 2], T::default()));
        assert_eq!(wrapped.shape().unwrap(), [2, 2]);

        let tensor = wrapped
            .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
            .unwrap();
        let dtype = tensor.dtype();
        assert_eq!(dtype.code, code);
        assert_eq!(dtype.bits, bits);
        assert_eq!(dtype.lanes, 1);

        // an array created from this one must have the same dtype
        let scalar = MtsArray::from(ndarray::Array::from_elem(vec![], T::default()));
        let created = wrapped.create(&[1, 1], scalar.as_ref()).unwrap();

        let tensor = created
            .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
            .unwrap();
        let dtype = tensor.dtype();
        assert_eq!(dtype.code, code);
        assert_eq!(dtype.bits, bits);
        assert_eq!(dtype.lanes, 1);
    }

    test_for_dtype::<bool>(DLDataTypeCode::kDLBool, 8);

    test_for_dtype::<f64>(DLDataTypeCode::kDLFloat, 64);
    test_for_dtype::<f32>(DLDataTypeCode::kDLFloat, 32);

    test_for_dtype::<i8>(DLDataTypeCode::kDLInt, 8);
    test_for_dtype::<i16>(DLDataTypeCode::kDLInt, 16);
    test_for_dtype::<i32>(DLDataTypeCode::kDLInt, 32);
    test_for_dtype::<i64>(DLDataTypeCode::kDLInt, 64);

    test_for_dtype::<u8>(DLDataTypeCode::kDLUInt, 8);
    test_for_dtype::<u16>(DLDataTypeCode::kDLUInt, 16);
    test_for_dtype::<u32>(DLDataTypeCode::kDLUInt, 32);
    test_for_dtype::<u64>(DLDataTypeCode::kDLUInt, 64);
}
#[test]
fn ndarray_device() {
    // arrays backed by ndarray report the CPU device
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![2, 3]));
    let device = wrapped.device().unwrap();
    assert_eq!(device, DLDevice::cpu());
}
#[test]
fn as_dlpack_rejects_stream() {
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![2, 3]));

    // passing any stream must be refused
    let result = wrapped.as_dlpack(DLDevice::cpu(), Some(42), DLPackVersion::current());
    let error = match result {
        Ok(_) => panic!("expected error for non-null stream"),
        Err(error) => error,
    };
    assert!(error.message.contains("stream"), "{}", error.message);
}
#[test]
fn as_dlpack_rejects_wrong_device() {
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![2, 3]));

    // requesting the data on a CUDA device must be refused
    let cuda = DLDevice {
        device_type: dlpk::sys::DLDeviceType::kDLCUDA,
        device_id: 0,
    };
    let result = wrapped.as_dlpack(cuda, None, DLPackVersion::current());
    let error = match result {
        Ok(_) => panic!("expected error for CUDA device on CPU array"),
        Err(error) => error,
    };
    assert!(error.message.contains("does not match"), "{}", error.message);
}
#[test]
fn as_dlpack_rejects_incompatible_version() {
    let wrapped = MtsArray::from(ndarray::Array::<f64, _>::zeros(vec![2, 3]));

    // a `max_version` with major version 99 must be refused
    let unsupported = DLPackVersion { major: 99, minor: 0 };
    let result = wrapped.as_dlpack(DLDevice::cpu(), None, unsupported);
    let error = match result {
        Ok(_) => panic!("expected error for incompatible DLPack version"),
        Err(error) => error,
    };
    assert!(error.message.contains("version"), "{}", error.message);
}
#[test]
#[allow(clippy::float_cmp)]
fn from_dlpack() {
    // two source arrays with different dtypes and a few recognizable values
    let mut floats = ndarray::Array::<f64, _>::zeros(vec![2, 3]);
    floats[[0, 0]] = 1.573;
    floats[[1, 2]] = -42.0;
    let float_array = MtsArray::from(floats);

    let mut shorts = ndarray::Array::<i16, _>::zeros(vec![2, 5, 10]);
    shorts[[0, 1, 3]] = 3;
    shorts[[1, 2, 4]] = -42;
    let short_array = MtsArray::from(shorts);

    // export both arrays to DLPack, and import them back
    let float_tensor = float_array
        .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
        .unwrap();
    let short_tensor = short_array
        .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
        .unwrap();

    let float_roundtrip = float_array.from_dlpack(float_tensor).unwrap();
    let short_roundtrip = short_array.from_dlpack(short_tensor).unwrap();

    // all of these arrays have the same origin
    assert_eq!(float_array.origin().unwrap(), short_array.origin().unwrap());
    assert_eq!(float_roundtrip.origin().unwrap(), float_array.origin().unwrap());
    assert_eq!(short_roundtrip.origin().unwrap(), short_array.origin().unwrap());

    // export the imported arrays again to get back to plain ndarray data
    let float_tensor = float_roundtrip
        .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
        .unwrap();
    let short_tensor = short_roundtrip
        .as_dlpack(DLDevice::cpu(), None, DLPackVersion::current())
        .unwrap();

    let float_values: ndarray::ArrayD<f64> = float_tensor.try_into().unwrap();
    let short_values: ndarray::ArrayD<i16> = short_tensor.try_into().unwrap();

    // the values survived the round trip
    assert_eq!(float_values[[0, 0]], 1.573);
    assert_eq!(float_values[[1, 2]], -42.0);
    assert_eq!(short_values[[0, 1, 3]], 3);
    assert_eq!(short_values[[1, 2, 4]], -42);
}
}