use crate::{ToCv, TryAsRefCv, TryToCv};
use anyhow::{ensure, Error, Result};
use slice_of_array::prelude::*;
use std::{mem::ManuallyDrop, ops::Deref, slice};
macro_rules! impl_from_array {
($elem:ty, 1) => {
impl<'a, const N: usize> TryAsRefCv<'a, TensorAsArray<'a, [$elem; N]>> for tch::Tensor {
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [$elem; N]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(self.size() == &[N as i64]);
let slice: &[$elem] =
unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N) };
#[allow(unstable_name_collisions)]
let array = slice.as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<const N: usize> TryToCv<[$elem; N]> for tch::Tensor {
type Error = Error;
fn try_to_cv(&self) -> Result<[$elem; N], Self::Error> {
ensure!(self.size() == &[N as i64]);
let mut array = [Default::default(); N];
self.f_copy_data(array.as_mut(), N)?;
Ok(array)
}
}
impl<const N: usize> ToCv<tch::Tensor> for [$elem; N] {
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.as_ref())
}
}
};
($elem:ty, 2) => {
impl<'a, const N1: usize, const N2: usize> TryAsRefCv<'a, TensorAsArray<'a, [[$elem; N2]; N1]>>
for tch::Tensor
{
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[$elem; N2]; N1]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(self.size() == &[N1 as i64, N2 as i64]);
let slice: &[$elem] =
unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2) };
#[allow(unstable_name_collisions)]
let array = slice.nest().as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<const N1: usize, const N2: usize> TryToCv<[[$elem; N2]; N1]> for tch::Tensor {
type Error = Error;
fn try_to_cv(&self) -> Result<[[$elem; N2]; N1], Self::Error> {
ensure!(self.size() == &[N1 as i64, N2 as i64]);
let mut array = [[Default::default(); N2]; N1];
self.f_copy_data(array.flat_mut(), N1 * N2)?;
Ok(array)
}
}
impl<const N1: usize, const N2: usize> ToCv<tch::Tensor> for [[$elem; N2]; N1] {
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.flat()).view([N1 as i64, N2 as i64])
}
}
};
($elem:ty, 3) => {
impl<'a, const N1: usize, const N2: usize, const N3: usize>
TryAsRefCv<'a, TensorAsArray<'a, [[[$elem; N3]; N2]; N1]>> for tch::Tensor
{
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[$elem; N3]; N2]; N1]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64]);
let slice: &[$elem] =
unsafe { slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3) };
#[allow(unstable_name_collisions)]
let array = slice.nest().nest().as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<const N1: usize, const N2: usize, const N3: usize> TryToCv<[[[$elem; N3]; N2]; N1]>
for tch::Tensor
{
type Error = Error;
fn try_to_cv(&self) -> Result<[[[$elem; N3]; N2]; N1], Self::Error> {
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64]);
let mut array = [[[Default::default(); N3]; N2]; N1];
self.f_copy_data(array.flat_mut().flat_mut(), N1 * N2 * N3)?;
Ok(array)
}
}
impl<const N1: usize, const N2: usize, const N3: usize> ToCv<tch::Tensor>
for [[[$elem; N3]; N2]; N1]
{
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.flat().flat()).view([N1 as i64, N2 as i64, N3 as i64])
}
}
};
($elem:ty, 4) => {
impl<'a, const N1: usize, const N2: usize, const N3: usize, const N4: usize>
TryAsRefCv<'a, TensorAsArray<'a, [[[[$elem; N4]; N3]; N2]; N1]>> for tch::Tensor
{
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[$elem; N4]; N3]; N2]; N1]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
let slice: &[$elem] = unsafe {
slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3 * N4)
};
#[allow(unstable_name_collisions)]
let array = slice.nest().nest().nest().as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
TryToCv<[[[[$elem; N4]; N3]; N2]; N1]> for tch::Tensor
{
type Error = Error;
fn try_to_cv(&self) -> Result<[[[[$elem; N4]; N3]; N2]; N1], Self::Error> {
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64]);
let mut array = [[[[Default::default(); N4]; N3]; N2]; N1];
self.f_copy_data(array.flat_mut().flat_mut().flat_mut(), N1 * N2 * N3 * N4)?;
Ok(array)
}
}
impl<const N1: usize, const N2: usize, const N3: usize, const N4: usize>
ToCv<tch::Tensor> for [[[[$elem; N4]; N3]; N2]; N1]
{
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.flat().flat().flat())
.view([N1 as i64, N2 as i64, N3 as i64, N4 as i64])
}
}
};
($elem:ty, 5) => {
impl<
'a,
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
> TryAsRefCv<'a, TensorAsArray<'a, [[[[[$elem; N5]; N4]; N3]; N2]; N1]>> for tch::Tensor
{
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[[$elem; N5]; N4]; N3]; N2]; N1]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
let slice: &[$elem] = unsafe {
slice::from_raw_parts(self.data_ptr() as *mut $elem, N1 * N2 * N3 * N4 * N5)
};
#[allow(unstable_name_collisions)]
let array = slice.nest().nest().nest().nest().as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
> TryToCv<[[[[[$elem; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
{
type Error = Error;
fn try_to_cv(&self) -> Result<[[[[[$elem; N5]; N4]; N3]; N2]; N1], Self::Error> {
ensure!(self.size() == &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64]);
let mut array = [[[[[Default::default(); N5]; N4]; N3]; N2]; N1];
self.f_copy_data(
array.flat_mut().flat_mut().flat_mut().flat_mut(),
N1 * N2 * N3 * N4 * N5,
)?;
Ok(array)
}
}
impl<
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
> ToCv<tch::Tensor> for [[[[[$elem; N5]; N4]; N3]; N2]; N1]
{
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.flat().flat().flat().flat())
.view([N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64])
}
}
};
($elem:ty, 6) => {
impl<
'a,
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
const N6: usize,
> TryAsRefCv<'a, TensorAsArray<'a, [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]>> for tch::Tensor
{
type Error = Error;
fn try_as_ref_cv(&'a self) -> Result<TensorAsArray<'a, [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]>, Self::Error> {
ensure!(self.device() == tch::Device::Cpu);
ensure!(self.kind() == <$elem as tch::kind::Element>::KIND);
ensure!(
self.size()
== &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
);
let slice: &[$elem] = unsafe {
slice::from_raw_parts(
self.data_ptr() as *mut $elem,
N1 * N2 * N3 * N4 * N5 * N6,
)
};
#[allow(unstable_name_collisions)]
let array = slice.nest().nest().nest().nest().nest().as_array();
Ok(TensorAsArray {
data: ManuallyDrop::new(*array),
_tensor: self,
})
}
}
impl<
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
const N6: usize,
> TryToCv<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]> for tch::Tensor
{
type Error = Error;
fn try_to_cv(&self) -> Result<[[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1], Self::Error> {
ensure!(
self.size()
== &[N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64]
);
let mut array = [[[[[[Default::default(); N6]; N5]; N4]; N3]; N2]; N1];
self.f_copy_data(
array.flat_mut().flat_mut().flat_mut().flat_mut().flat_mut(),
N1 * N2 * N3 * N4 * N5 * N6,
)?;
Ok(array)
}
}
impl<
const N1: usize,
const N2: usize,
const N3: usize,
const N4: usize,
const N5: usize,
const N6: usize,
> ToCv<tch::Tensor> for [[[[[[$elem; N6]; N5]; N4]; N3]; N2]; N1]
{
fn to_cv(&self) -> tch::Tensor {
tch::Tensor::from_slice(self.flat().flat().flat().flat().flat()).view([
N1 as i64, N2 as i64, N3 as i64, N4 as i64, N5 as i64, N6 as i64,
])
}
}
};
}
impl_from_array!(u8, 1);
impl_from_array!(u8, 2);
impl_from_array!(u8, 3);
impl_from_array!(u8, 4);
impl_from_array!(u8, 5);
impl_from_array!(u8, 6);
impl_from_array!(i8, 1);
impl_from_array!(i8, 2);
impl_from_array!(i8, 3);
impl_from_array!(i8, 4);
impl_from_array!(i8, 5);
impl_from_array!(i8, 6);
impl_from_array!(i16, 1);
impl_from_array!(i16, 2);
impl_from_array!(i16, 3);
impl_from_array!(i16, 4);
impl_from_array!(i16, 5);
impl_from_array!(i16, 6);
impl_from_array!(i32, 1);
impl_from_array!(i32, 2);
impl_from_array!(i32, 3);
impl_from_array!(i32, 4);
impl_from_array!(i32, 5);
impl_from_array!(i32, 6);
impl_from_array!(i64, 1);
impl_from_array!(i64, 2);
impl_from_array!(i64, 3);
impl_from_array!(i64, 4);
impl_from_array!(i64, 5);
impl_from_array!(i64, 6);
impl_from_array!(half::f16, 1);
impl_from_array!(half::f16, 2);
impl_from_array!(half::f16, 3);
impl_from_array!(half::f16, 4);
impl_from_array!(half::f16, 5);
impl_from_array!(half::f16, 6);
impl_from_array!(f32, 1);
impl_from_array!(f32, 2);
impl_from_array!(f32, 3);
impl_from_array!(f32, 4);
impl_from_array!(f32, 5);
impl_from_array!(f32, 6);
impl_from_array!(f64, 1);
impl_from_array!(f64, 2);
impl_from_array!(f64, 3);
impl_from_array!(f64, 4);
impl_from_array!(f64, 5);
impl_from_array!(f64, 6);
impl_from_array!(bool, 1);
impl_from_array!(bool, 2);
impl_from_array!(bool, 3);
impl_from_array!(bool, 4);
impl_from_array!(bool, 5);
impl_from_array!(bool, 6);
pub use tensors::*;
mod tensors {
    use super::*;

    /// Array data read out of a [`tch::Tensor`] by the `TryAsRefCv` conversions in this
    /// file, kept together with a shared borrow of that tensor for `'a`.
    ///
    /// Dereferences (and `AsRef`s) to the array type `T`.
    #[derive(Debug)]
    pub struct TensorAsArray<'a, T> {
        /// The array value. Held in `ManuallyDrop` and dropped explicitly by this type's
        /// `Drop` impl.
        pub(crate) data: ManuallyDrop<T>,
        /// Shared borrow of the tensor the data was read from.
        pub(crate) _tensor: &'a tch::Tensor,
    }

    impl<'a, T> Drop for TensorAsArray<'a, T> {
        fn drop(&mut self) {
            // SAFETY: within this file, `data` is dropped only here, exactly once, and is
            // not accessed again afterwards because `self` is being destroyed.
            unsafe {
                ManuallyDrop::drop(&mut self.data);
            }
        }
    }

    impl<'a, T> AsRef<T> for TensorAsArray<'a, T> {
        fn as_ref(&self) -> &T {
            &self.data
        }
    }

    impl<'a, T> Deref for TensorAsArray<'a, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            &self.data
        }
    }

    /// A 3-dimensional [`tch::Tensor`] tagged with a [`TchTensorImageShape`] describing
    /// its axis layout.
    ///
    /// NOTE(review): presumably consumed by image conversions elsewhere in the crate.
    #[derive(Debug)]
    pub struct TchTensorAsImage {
        pub(crate) tensor: tch::Tensor,
        pub(crate) kind: TchTensorImageShape,
    }

    /// Axis layout of an image tensor. Each variant name lists the dimensions from first
    /// to last.
    ///
    /// NOTE(review): the letters presumably stand for W = width, H = height, C = channels.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub enum TchTensorImageShape {
        Whc,
        Hwc,
        Chw,
        Cwh,
    }

    impl TchTensorAsImage {
        /// Wraps `tensor` with the given axis layout.
        ///
        /// # Errors
        ///
        /// Returns an error if `tensor` does not have exactly 3 dimensions.
        pub fn new(tensor: tch::Tensor, kind: TchTensorImageShape) -> Result<Self> {
            let ndim = tensor.dim();
            ensure!(
                ndim == 3,
                "the tensor must have 3 dimensions, but got {}",
                ndim
            );
            Ok(Self { tensor, kind })
        }

        /// Consumes the wrapper and returns the underlying tensor.
        pub fn into_inner(self) -> tch::Tensor {
            self.tensor
        }

        /// Returns the axis layout this tensor was tagged with.
        pub fn kind(&self) -> TchTensorImageShape {
            self.kind
        }

        /// Converts via this type's `TryToCv<T>` impl. This lets callers write
        /// `image.try_to_cv::<T>()` without importing the trait.
        pub fn try_to_cv<T>(&self) -> Result<T, <Self as TryToCv<T>>::Error>
        where
            Self: TryToCv<T>,
        {
            TryToCv::try_to_cv(self)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ToCv, TryAsRefCv, TryToCv};
    use rand::prelude::*;

    /// Generates a random array of type `$ty` and converts it into a tensor. Then checks
    /// that both the owned copy (`try_to_cv`) and the borrowed view (`try_as_ref_cv`) of
    /// that tensor equal the original array.
    macro_rules! assert_round_trip {
        ($rng:ident, $ty:ty) => {{
            let original: $ty = $rng.gen();
            let tensor = original.to_cv();

            let owned: $ty = tensor.try_to_cv().unwrap();
            assert_eq!(owned, original);

            let view: TensorAsArray<$ty> = (&tensor).try_as_ref_cv().unwrap();
            assert_eq!(*view, original);
        }};
    }

    #[test]
    fn tensor_to_array_ref() {
        let mut rng = rand::thread_rng();

        assert_round_trip!(rng, [f32; 3]);
        assert_round_trip!(rng, [[f32; 3]; 2]);
        assert_round_trip!(rng, [[[f32; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[f32; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[f32; 3]; 2]; 4]; 3]; 2]);
        assert_round_trip!(rng, [[[[[[f32; 2]; 3]; 2]; 4]; 3]; 2]);
    }
}