use ndarray::{ArrayView, ArrayViewMut, Dim, Dimension, IntoDimension, Ix, IxDyn};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::{mem, slice};
use tch::{kind::Element, Device, Kind, Tensor};
use thiserror::Error;
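/// A [`Tensor`] with exclusively owned, contiguous CPU memory, tagged with a
/// statically known element type `E` and `ndarray` dimension `D`.
///
/// Exclusive ownership is what makes the safe slice and view accessors sound:
/// no other handle can mutate or free the buffer while a borrow is alive.
///
/// # Example
///
/// A minimal sketch of the intended round trip, mirroring the tests below
/// (import paths assumed, hence `ignore`):
///
/// ```ignore
/// let mut t = ExclusiveTensor::<f32, _>::zeros([2, 3]);
/// t.as_slice_mut()[0] = 1.0;
/// let tensor: tch::Tensor = t.into();
/// ```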
#[derive(Debug)]
pub struct ExclusiveTensor<E, D>
where
D: Dimension,
{
tensor: Tensor,
shape: D,
num_elements: usize,
element_type: PhantomData<E>,
}
impl<E, D> ExclusiveTensor<E, D>
where
E: Element,
D: Dimension + IntoTorchShape,
{
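    /// Allocate a zero-filled tensor of the given shape on the CPU.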
pub fn zeros<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
unsafe {
Self::from_tensor_fn(shape, |shape, kind| {
Tensor::zeros(shape, (kind, Device::Cpu))
})
}
}
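    /// Allocate a one-filled tensor of the given shape on the CPU.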
pub fn ones<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
unsafe {
Self::from_tensor_fn(shape, |shape, kind| {
Tensor::ones(shape, (kind, Device::Cpu))
})
}
}
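    /// Build an `ExclusiveTensor` around the tensor produced by `f`,
    /// validating the shape first.
    ///
    /// # Safety
    ///
    /// `f` must return a freshly allocated, contiguous CPU tensor with the
    /// given shape and kind, and the returned tensor must not be shared with
    /// any other handle.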
unsafe fn from_tensor_fn<Sh, F>(shape: Sh, f: F) -> Self
where
Sh: IntoDimension<Dim = D>,
F: FnOnce(&[i64], Kind) -> Tensor,
{
let shape = shape.into_dimension();
        // Both `slice::from_raw_parts` and `ndarray` require the total buffer
        // size to stay below `isize::MAX` bytes, so reject oversized shapes up
        // front instead of risking undefined behaviour in the view methods.
        let num_elements = match shape.size_checked() {
            Some(size) if size < isize::MAX as usize => size,
            _ => panic!("number of elements must be less than isize::MAX"),
        };
        match num_elements.checked_mul(mem::size_of::<E>()) {
            Some(size) if size < isize::MAX as usize => {}
            _ => panic!("size of allocated memory must be less than isize::MAX"),
        }
let tensor = f(shape.clone().into_torch_shape().as_ref(), E::KIND);
Self {
tensor,
shape,
num_elements,
element_type: PhantomData,
}
}
}
impl<E, D: Dimension> ExclusiveTensor<E, D> {
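    /// Consume `self`, returning the underlying [`Tensor`].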
    #[allow(clippy::missing_const_for_fn)]
    pub fn into_tensor(self) -> Tensor {
self.tensor
}
}
impl<E> ExclusiveTensor<E, IxDyn>
where
E: Element,
{
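    /// Copy `tensor` into freshly allocated, exclusively owned CPU memory.
    ///
    /// The source may live on any device; the copy always lands on the CPU.
    /// Returns an error if the source kind does not match `E::KIND`.
    ///
    /// A minimal sketch (import paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// let src = tch::Tensor::of_slice(&[1.0f32, 2.0]);
    /// let copy = ExclusiveTensor::<f32, _>::try_copy_from(&src).unwrap();
    /// ```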
pub fn try_copy_from(tensor: &Tensor) -> Result<Self, ExclusiveTensorError> {
let kind = tensor.kind();
if kind != E::KIND {
return Err(ExclusiveTensorError::MismatchedKind {
expected: E::KIND,
actual: kind,
});
}
        // Torch reports sizes as `i64`; convert each dimension to `usize`.
        let shape_vec: Vec<usize> = tensor
            .size()
            .into_iter()
            .map(|d| d.try_into().expect("negative dimension"))
            .collect();
let shape = IxDyn(&shape_vec);
unsafe {
Ok(Self::from_tensor_fn(shape, |shape, kind| {
let mut new_tensor = Tensor::zeros(shape, (kind, Device::Cpu));
new_tensor.copy_(tensor);
new_tensor
}))
}
}
}
impl<E, D> ExclusiveTensor<E, D>
where
E: Element,
D: Dimension,
{
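    /// View the tensor's buffer as a flat slice in row-major order.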
pub fn as_slice(&self) -> &[E] {
unsafe { slice::from_raw_parts(self.data_ptr().as_ptr(), self.num_elements) }
}
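    /// View the tensor's buffer as a mutable flat slice in row-major order.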
pub fn as_slice_mut(&mut self) -> &mut [E] {
unsafe { slice::from_raw_parts_mut(self.data_ptr().as_ptr(), self.num_elements) }
}
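    /// Borrow the tensor as an [`ArrayView`] with the statically known shape.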
pub fn array_view(&self) -> ArrayView<E, D> {
unsafe { ArrayView::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }
}
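    /// Borrow the tensor as a mutable [`ArrayViewMut`] with the statically
    /// known shape.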
pub fn array_view_mut(&mut self) -> ArrayViewMut<E, D> {
unsafe { ArrayViewMut::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }
}
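    /// Pointer to the start of the tensor's buffer.
    ///
    /// Torch may hand back a null pointer for empty tensors, so in that case
    /// return a well-aligned dangling pointer instead, which is valid for
    /// zero-length slices and views.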
fn data_ptr(&self) -> NonNull<E> {
if self.num_elements == 0 {
NonNull::dangling()
} else {
NonNull::new(self.tensor.data_ptr().cast()).expect("unexpected null data_ptr")
}
}
}
impl<E, D: Dimension> From<ExclusiveTensor<E, D>> for Tensor {
fn from(exclusive: ExclusiveTensor<E, D>) -> Self {
exclusive.into_tensor()
}
}
impl<'a, E, D> From<&'a ExclusiveTensor<E, D>> for ArrayView<'a, E, D>
where
E: Element,
D: Dimension,
{
fn from(exclusive: &'a ExclusiveTensor<E, D>) -> Self {
exclusive.array_view()
}
}
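/// Convert an `ndarray` index into Torch's `i64`, panicking on overflow.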
fn to_i64(x: Ix) -> i64 {
x.try_into().expect("dimension too large")
}
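/// Conversion of an `ndarray` dimension into the `[i64]`-based shape that
/// Torch expects, without heap allocation for fixed-size dimensions.
///
/// A minimal sketch (import path for this trait assumed, hence `ignore`):
///
/// ```ignore
/// use ndarray::Dim;
/// assert_eq!(Dim([2, 3]).into_torch_shape(), [2_i64, 3]);
/// ```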
pub trait IntoTorchShape {
type TorchDim: AsRef<[i64]>;
fn into_torch_shape(self) -> Self::TorchDim;
}
impl IntoTorchShape for IxDyn {
type TorchDim = Vec<i64>;
fn into_torch_shape(self) -> Self::TorchDim {
self.as_array_view()
.into_iter()
.map(|&x| to_i64(x))
.collect()
}
}
impl IntoTorchShape for Dim<[Ix; 0]> {
type TorchDim = [i64; 0];
fn into_torch_shape(self) -> Self::TorchDim {
[]
}
}
impl IntoTorchShape for Dim<[Ix; 1]> {
type TorchDim = [i64; 1];
fn into_torch_shape(self) -> Self::TorchDim {
        [to_i64(self.into_pattern())]
}
}
impl IntoTorchShape for Dim<[Ix; 2]> {
type TorchDim = [i64; 2];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b) = self.into_pattern();
[to_i64(a), to_i64(b)]
}
}
impl IntoTorchShape for Dim<[Ix; 3]> {
type TorchDim = [i64; 3];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c)]
}
}
impl IntoTorchShape for Dim<[Ix; 4]> {
type TorchDim = [i64; 4];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c), to_i64(d)]
}
}
impl IntoTorchShape for Dim<[Ix; 5]> {
type TorchDim = [i64; 5];
#[allow(clippy::many_single_char_names)]
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d, e) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c), to_i64(d), to_i64(e)]
}
}
impl IntoTorchShape for Dim<[Ix; 6]> {
type TorchDim = [i64; 6];
#[allow(clippy::many_single_char_names)]
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d, e, f) = self.into_pattern();
[
to_i64(a),
to_i64(b),
to_i64(c),
to_i64(d),
to_i64(e),
to_i64(f),
]
}
}
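/// Error returned when building an [`ExclusiveTensor`] from an existing
/// [`Tensor`].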
#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExclusiveTensorError {
#[error("expected kind {expected:?} but got {actual:?}")]
MismatchedKind { expected: Kind, actual: Kind },
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array};
#[test]
fn zeros() {
let u = ExclusiveTensor::<f32, _>::zeros([2, 4, 3]);
let tensor: Tensor = u.into();
assert_eq!(tensor.size(), [2, 4, 3]);
assert_eq!(tensor.kind(), Kind::Float);
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(
tensor,
Tensor::zeros(&[2, 4, 3], (Kind::Float, Device::Cpu))
);
}
#[test]
fn ones() {
let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
let tensor: Tensor = u.into();
assert_eq!(tensor.size(), [2, 4, 3]);
assert_eq!(tensor.kind(), Kind::Float);
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(tensor, Tensor::ones(&[2, 4, 3], (Kind::Float, Device::Cpu)));
}
#[test]
fn try_copy_from() {
let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
let copy = ExclusiveTensor::<f32, _>::try_copy_from(&src).unwrap();
let tensor: Tensor = copy.into();
assert_eq!(tensor.size(), [2, 3]);
assert_eq!(tensor.kind(), Kind::Float);
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(tensor, src);
}
#[test]
fn try_copy_from_cuda_if_available() {
let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0])
.to_device(Device::cuda_if_available());
let copy = ExclusiveTensor::<f32, _>::try_copy_from(&src).unwrap();
let tensor: Tensor = copy.into();
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(tensor, src.to_device(Device::Cpu));
}
#[test]
fn try_copy_from_mismatched_type() {
let src = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert_eq!(
ExclusiveTensor::<f64, _>::try_copy_from(&src).unwrap_err(),
ExclusiveTensorError::MismatchedKind {
expected: Kind::Double,
actual: Kind::Float
}
);
}
#[test]
#[allow(clippy::float_cmp)]
fn slice_f64() {
let u = ExclusiveTensor::<f64, _>::ones([3, 1, 2]);
assert_eq!(u.as_slice().len(), 6);
assert_eq!(u.as_slice(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
#[test]
fn slice_mut_i16() {
let mut u = ExclusiveTensor::<i16, _>::ones([3, 1, 2]);
assert_eq!(u.as_slice_mut().len(), 6);
for (i, x) in u.as_slice_mut().iter_mut().enumerate() {
*x = i.try_into().unwrap()
}
assert_eq!(u.as_slice(), &[0, 1, 2, 3, 4, 5]);
let tensor: Tensor = u.into();
assert_eq!(
tensor,
Tensor::of_slice(&[0, 1, 2, 3, 4, 5]).reshape(&[3, 1, 2])
);
}
#[test]
fn array_view_f32() {
let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
let view = u.array_view();
assert_eq!(view.dim(), (2, 4, 3));
assert_eq!(view, Array::<f32, _>::ones((2, 4, 3)));
}
#[test]
#[allow(clippy::unit_cmp)]
fn array_view_i64_scalar() {
let u = ExclusiveTensor::<i64, _>::ones([]);
let view = u.array_view();
assert_eq!(view.dim(), ());
assert_eq!(view.into_scalar(), &1);
}
#[test]
fn array_view_f32_empty() {
let u = ExclusiveTensor::<f32, _>::ones([0]);
let view = u.array_view();
assert_eq!(view.dim(), 0);
assert!(view.as_slice().unwrap().is_empty());
}
#[test]
fn array_view_mut() {
let mut u = ExclusiveTensor::<i32, _>::ones([3, 4]);
let mut view = u.array_view_mut();
for (i, mut row) in view.rows_mut().into_iter().enumerate() {
for (j, cell) in row.iter_mut().enumerate() {
*cell = (i * 10 + j).try_into().unwrap();
}
}
let expected = arr2(&[[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]);
        assert_eq!(view, expected);
        let t: Tensor = u.into();
let expected: Tensor = expected.try_into().unwrap();
        assert_eq!(t, expected);
    }
#[test]
fn array_view_mut_empty() {
let mut u = ExclusiveTensor::<f32, _>::ones([2, 0, 3]);
let mut view = u.array_view_mut();
assert!(view.as_slice_mut().unwrap().is_empty());
}
}