use std::marker::PhantomData;
use executorch_sys as sys;
use sys::cxx::vector::VectorElement;
use sys::cxx::{self, ExternType, SharedPtr, UniquePtr};
use super::{
Data, DataMut, DataTyped, DimOrderType, Scalar, SizesType, StridesType, TensorBase, View,
ViewMut,
};
use crate::util::{IntoCpp, IntoRust};
use crate::{Error, Result};
/// A smart pointer owning a C++ `TensorPtr` (a shared tensor) from the
/// ExecuTorch extension API.
///
/// The lifetime `'a` ties the tensor to any borrowed data backing it
/// (`'static` when the tensor owns its data); `D` encodes the access mode
/// (e.g. `View<S>` / `ViewMut<S>`).
pub struct TensorPtr<'a, D>(SharedPtr<sys::Tensor>, PhantomData<(&'a (), D)>);
impl<S: Scalar> TensorPtr<'static, View<S>> {
    /// Creates a tensor that takes ownership of the given ndarray's data.
    ///
    /// # Errors
    /// Fails if the array's strides do not form a valid tensor layout.
    #[cfg(feature = "ndarray")]
    pub fn from_array<D: ndarray::Dimension>(array: ndarray::Array<S, D>) -> Result<Self> {
        TensorPtrBuilder::<View<S>>::from_array(array).build()
    }

    /// Creates a one-dimensional tensor owning the given vector.
    pub fn from_vec(vec: Vec<S>) -> Self {
        TensorPtrBuilder::<View<S>>::from_vec(vec).build().unwrap()
    }

    /// Creates a new tensor whose data is a copy of the given tensor's,
    /// converted to scalar type `S` (conversion performed by the C++ clone).
    pub fn copy_of<D: Data>(tensor: &TensorBase<'_, D>) -> Self {
        // SAFETY: `as_cpp().ptr` points at a live C++ tensor for the duration
        // of this call; it is only read by the clone below.
        let tensor = unsafe { tensor.as_cpp().ptr.cast::<sys::Tensor>().as_ref().unwrap() };
        TensorPtr(sys::TensorPtr_clone(tensor, S::TYPE.cpp()), PhantomData)
    }
}
impl<'a, S: Scalar> TensorPtr<'a, View<S>> {
    /// Creates a tensor viewing (borrowing) the data of the given ndarray view.
    ///
    /// # Errors
    /// Fails if the view's strides do not form a valid tensor layout.
    #[cfg(feature = "ndarray")]
    pub fn from_array_view<D: ndarray::Dimension>(
        array: ndarray::ArrayView<'a, S, D>,
    ) -> Result<Self> {
        let builder = TensorPtrBuilder::<View<S>>::from_array_view(array);
        builder.build()
    }

    /// Creates a one-dimensional tensor viewing (borrowing) the given slice.
    pub fn from_slice(data: &'a [S]) -> Self {
        let built = TensorPtrBuilder::<View<S>>::from_slice(data).build();
        built.unwrap()
    }
}
impl<D> TensorPtr<'_, D> {
    /// Returns an immutable tensor view of the underlying C++ tensor.
    pub fn as_tensor(&self) -> TensorBase<'_, D::Immutable>
    where
        D: Data,
    {
        let tensor = self.0.as_ref().unwrap();
        let tensor = sys::TensorRef {
            ptr: tensor as *const sys::Tensor as *const _,
        };
        // SAFETY: the pointer targets the C++ tensor owned by `self.0`, which
        // outlives the returned view (bound to the borrow of `self`).
        unsafe { TensorBase::from_inner_ref(tensor) }
    }

    /// Returns a mutable tensor view of the underlying C++ tensor.
    pub fn as_tensor_mut(&mut self) -> TensorBase<'_, D>
    where
        D: DataMut,
    {
        let tensor = self.0.as_ref().unwrap();
        // NOTE(review): const-to-mut cast on a value behind `SharedPtr`; the
        // `&mut self` receiver prevents aliasing through *this* handle, but
        // soundness assumes no other live handle mutates the same tensor —
        // TODO confirm against `SharedPtr` usage elsewhere.
        let tensor = sys::TensorRefMut {
            ptr: tensor as *const sys::Tensor as *mut sys::Tensor as *mut _,
        };
        unsafe { TensorBase::from_inner_ref_mut(tensor) }
    }
}
// SAFETY(review): asserted unconditionally for any `D`; this relies on the
// C++ `shared_ptr<Tensor>` control block being thread-safe and the tensor
// having no thread affinity — TODO confirm no `!Send` state is reachable
// through `D` or borrowed data.
unsafe impl<D> Send for TensorPtr<'_, D> {}
/// A builder for [`TensorPtr`], collecting sizes, strides, and a data source
/// before constructing the C++ tensor.
pub struct TensorPtrBuilder<'a, D: DataTyped> {
    // Tensor sizes (shape), one entry per dimension.
    sizes: UniquePtr<cxx::Vector<SizesType>>,
    // The tensor data: owned, borrowed, or a raw pointer (see TensorPtrBuilderData).
    data: TensorPtrBuilderData<'a, D>,
    // Explicit strides; `None` means standard (row-major) layout at build time.
    strides: Option<UniquePtr<cxx::Vector<StridesType>>>,
    // Shape dynamism forwarded to the C++ tensor; every constructor in this
    // file sets it to STATIC.
    dynamism: sys::TensorShapeDynamism,
}
/// The data source backing a [`TensorPtrBuilder`].
enum TensorPtrBuilderData<'a, D: DataTyped> {
    // Owned data; `offset` is the index of the first logical element (e.g. an
    // ndarray's base offset into its backing vector).
    Vec { data: Vec<D::Scalar>, offset: usize },
    // Borrowed immutable slice.
    Slice(&'a [D::Scalar]),
    // Borrowed mutable slice.
    SliceMut(&'a mut [D::Scalar]),
    // Raw const pointer; the `PhantomData` ties it to the builder lifetime `'a`.
    Ptr(*const D::Scalar, PhantomData<&'a ()>),
    // Raw mutable pointer; the `PhantomData` ties it to the builder lifetime `'a`.
    PtrMut(*mut D::Scalar, PhantomData<&'a ()>),
}
impl<D: DataTyped> TensorPtrBuilder<'static, D> {
    /// Creates a builder that takes ownership of the given ndarray's data.
    ///
    /// Sizes and strides are taken from the array itself, so non-standard
    /// layouts are preserved (and validated at build time).
    #[cfg(feature = "ndarray")]
    pub fn from_array<Dim: ndarray::Dimension>(array: ndarray::Array<D::Scalar, Dim>) -> Self {
        Self {
            sizes: cxx_vec(array.shape().iter().map(|&s| s as SizesType)),
            strides: Some(cxx_vec(
                ndarray::ArrayBase::strides(&array)
                    .iter()
                    .map(|&s| s as StridesType),
            )),
            data: {
                let (data, data_offset) = array.into_raw_vec_and_offset();
                let data_offset = data_offset.unwrap_or(0);
                // `<=` (not `<`) so zero-element arrays (len == 0, offset == 0)
                // are accepted; an offset equal to the length is a one-past-end
                // position that is never dereferenced (the element bound
                // computed at build time is 0).
                assert!(data_offset <= data.len());
                TensorPtrBuilderData::Vec {
                    data,
                    offset: data_offset,
                }
            },
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }

    /// Creates a builder for a one-dimensional tensor owning the given vector.
    pub fn from_vec(data: Vec<D::Scalar>) -> Self {
        Self {
            sizes: cxx_vec([data.len() as SizesType]),
            data: TensorPtrBuilderData::Vec { data, offset: 0 },
            strides: None,
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }
}
impl<'a, S: Scalar> TensorPtrBuilder<'a, View<S>> {
    /// Creates a builder that borrows the data of the given ndarray view.
    ///
    /// Sizes and strides are taken from the view itself.
    #[cfg(feature = "ndarray")]
    pub fn from_array_view<Dim: ndarray::Dimension>(array: ndarray::ArrayView<'a, S, Dim>) -> Self {
        let sizes = cxx_vec(array.shape().iter().map(|&s| s as SizesType));
        let strides = cxx_vec(
            ndarray::ArrayBase::strides(&array)
                .iter()
                .map(|&s| s as StridesType),
        );
        Self {
            sizes,
            data: TensorPtrBuilderData::Ptr(array.as_ptr(), PhantomData),
            strides: Some(strides),
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }

    /// Creates a builder for a one-dimensional tensor borrowing the given slice.
    pub fn from_slice(data: &'a [S]) -> Self {
        let len = data.len();
        Self {
            sizes: cxx_vec([len as SizesType]),
            data: TensorPtrBuilderData::Slice(data),
            strides: None,
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }

    /// Creates a builder from a raw const pointer and tensor sizes.
    ///
    /// # Safety
    /// `data` must point to a buffer valid for `'a` and large enough for the
    /// given sizes (and any strides set later on the builder).
    pub unsafe fn from_ptr(data: *const S, sizes: impl IntoIterator<Item = SizesType>) -> Self {
        let sizes = cxx_vec(sizes);
        Self {
            sizes,
            data: TensorPtrBuilderData::Ptr(data, PhantomData),
            strides: None,
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }
}
impl<'a, S: Scalar> TensorPtrBuilder<'a, ViewMut<S>> {
    /// Creates a builder that mutably borrows the data of the given ndarray view.
    ///
    /// Sizes and strides are taken from the view itself.
    #[cfg(feature = "ndarray")]
    pub fn from_array_view_mut<Dim: ndarray::Dimension>(
        mut array: ndarray::ArrayViewMut<'a, S, Dim>,
    ) -> Self {
        let sizes = cxx_vec(array.shape().iter().map(|&s| s as SizesType));
        let strides = cxx_vec(
            ndarray::ArrayBase::strides(&array)
                .iter()
                .map(|&s| s as StridesType),
        );
        Self {
            sizes,
            data: TensorPtrBuilderData::PtrMut(array.as_mut_ptr(), PhantomData),
            strides: Some(strides),
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }

    /// Creates a builder for a one-dimensional tensor mutably borrowing the
    /// given slice.
    pub fn from_slice_mut(data: &'a mut [S]) -> Self {
        let len = data.len();
        Self {
            sizes: cxx_vec([len as SizesType]),
            data: TensorPtrBuilderData::SliceMut(data),
            strides: None,
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }

    /// Creates a builder from a raw mutable pointer and tensor sizes.
    ///
    /// # Safety
    /// `data` must point to a buffer valid for `'a`, large enough for the
    /// given sizes (and any strides set later), and not aliased mutably
    /// elsewhere while the tensor is alive.
    pub unsafe fn from_ptr_mut(data: *mut S, sizes: impl IntoIterator<Item = SizesType>) -> Self {
        let sizes = cxx_vec(sizes);
        Self {
            sizes,
            data: TensorPtrBuilderData::PtrMut(data, PhantomData),
            strides: None,
            dynamism: sys::TensorShapeDynamism::TensorShapeDynamism_STATIC,
        }
    }
}
impl<'a, D: DataTyped> TensorPtrBuilder<'a, D> {
    /// Overrides the tensor sizes (shape).
    ///
    /// # Safety
    /// The caller must ensure the sizes match the underlying data buffer; at
    /// build time they are validated against the strides but not against the
    /// data length.
    pub unsafe fn sizes(mut self, sizes: impl IntoIterator<Item = SizesType>) -> Self {
        self.sizes = cxx_vec(sizes);
        self
    }

    /// Overrides the tensor strides.
    ///
    /// # Safety
    /// The caller must ensure the strides are consistent with the sizes and
    /// never address memory outside the data buffer.
    pub unsafe fn strides(mut self, strides: impl IntoIterator<Item = StridesType>) -> Self {
        self.strides = Some(cxx_vec(strides));
        self
    }

    /// Builds an immutable tensor.
    ///
    /// Missing strides default to the standard (row-major) layout. The dim
    /// order is derived from the strides on the C++ side, so non-standard
    /// (e.g. column-major) layouts are accepted when valid.
    ///
    /// # Panics
    /// Panics if an explicitly-set strides length does not match the number of
    /// dimensions.
    ///
    /// # Errors
    /// Returns [`Error::InvalidArgument`] if the sizes/dim-order/strides
    /// combination is invalid.
    #[track_caller]
    pub fn build(self) -> Result<TensorPtr<'a, View<D::Scalar>>> {
        let ndim = self.sizes.len();
        let strides = self
            .strides
            .unwrap_or_else(|| standard_layout_strides(&self.sizes));
        assert_eq!(ndim, strides.len(), "Invalid strides length");
        // Zero-filled buffer that the C++ helper populates with the dim order
        // derived from the strides.
        let mut dim_order = cxx_vec(std::iter::repeat(0 as DimOrderType).take(ndim));
        unsafe {
            sys::executorch_stride_to_dim_order(
                strides.as_ref().unwrap().as_slice().as_ptr(),
                ndim,
                dim_order.as_mut().unwrap().as_mut_slice().as_mut_ptr(),
            )
        }
        .rs()?;
        debug_assert_eq!(ndim, dim_order.len());
        // Resolve the data pointer. For owned data the Vec is moved (not
        // reallocated), so the pointer taken here stays valid; the Vec itself
        // is handed to the tensor below to keep the buffer alive.
        let (data_ptr, allocation_vec, _data_bound) = match self.data {
            TensorPtrBuilderData::Vec { data, offset } => {
                let bound = data.len().checked_sub(offset).unwrap();
                let ptr = unsafe { data.as_ptr().add(offset) };
                (ptr, data, Some(bound))
            }
            TensorPtrBuilderData::Slice(data) => (data.as_ptr(), Vec::new(), Some(data.len())),
            TensorPtrBuilderData::SliceMut(data) => (data.as_ptr(), Vec::new(), Some(data.len())),
            TensorPtrBuilderData::Ptr(data, _) => (data, Vec::new(), None),
            TensorPtrBuilderData::PtrMut(data, _) => (data as *const _, Vec::new(), None),
        };
        let valid_strides = unsafe {
            sys::executorch_is_valid_dim_order_and_strides(
                ndim,
                self.sizes.as_ref().unwrap().as_slice().as_ptr(),
                dim_order.as_ref().unwrap().as_slice().as_ptr(),
                strides.as_ref().unwrap().as_slice().as_ptr(),
            )
        };
        if !valid_strides {
            crate::log::error!("Invalid strides");
            return Err(Error::InvalidArgument);
        }
        // SAFETY: sizes/dim_order/strides were validated above, and
        // `allocation_vec` (boxed into the tensor) keeps owned data alive for
        // the tensor's lifetime.
        let tensor = unsafe {
            sys::TensorPtr_new(
                self.sizes,
                data_ptr as *const u8 as *mut u8,
                dim_order,
                strides,
                D::Scalar::TYPE.cpp(),
                self.dynamism,
                Box::new(sys::util::RustAny::new(Box::new(allocation_vec))),
            )
        };
        Ok(TensorPtr(tensor, PhantomData))
    }

    /// Builds a mutable tensor.
    ///
    /// Unlike [`build`](Self::build), the dim order is fixed to the identity
    /// `(0..ndim)`, so only strides consistent with that order are accepted.
    ///
    /// # Panics
    /// Panics if the data source is immutable (a shared slice or const
    /// pointer), or if an explicitly-set strides length does not match the
    /// number of dimensions.
    ///
    /// # Errors
    /// Returns [`Error::InvalidArgument`] if the strides are invalid for the
    /// sizes under the identity dim order.
    #[track_caller]
    pub fn build_mut(self) -> Result<TensorPtr<'a, ViewMut<D::Scalar>>>
    where
        D: DataMut,
    {
        let ndim = self.sizes.len();
        // Identity dim order: dimension i is the i-th most significant in memory.
        let dim_order = cxx_vec((0..ndim).map(|s| s as DimOrderType));
        let strides = self
            .strides
            .unwrap_or_else(|| standard_layout_strides(&self.sizes));
        assert_eq!(ndim, dim_order.len(), "Invalid dim order length");
        assert_eq!(ndim, strides.len(), "Invalid strides length");
        // Resolve a mutable data pointer; owned Vec handled as in `build`.
        let (data_ptr, allocation_vec, _data_bound) = match self.data {
            TensorPtrBuilderData::Vec { mut data, offset } => {
                let bound = data.len().checked_sub(offset).unwrap();
                let ptr = unsafe { data.as_mut_ptr().add(offset) };
                (ptr, data, Some(bound))
            }
            TensorPtrBuilderData::Slice(_) => {
                panic!("Cannot create a mutable tensor from an immutable slice")
            }
            TensorPtrBuilderData::SliceMut(data) => {
                (data.as_mut_ptr(), Vec::new(), Some(data.len()))
            }
            TensorPtrBuilderData::Ptr(_, _) => {
                panic!("Cannot create a mutable tensor from an immutable pointer")
            }
            TensorPtrBuilderData::PtrMut(data, _) => (data, Vec::new(), None),
        };
        let valid_strides = unsafe {
            sys::executorch_is_valid_dim_order_and_strides(
                ndim,
                self.sizes.as_ref().unwrap().as_slice().as_ptr(),
                dim_order.as_ref().unwrap().as_slice().as_ptr(),
                strides.as_ref().unwrap().as_slice().as_ptr(),
            )
        };
        if !valid_strides {
            crate::log::error!("Invalid strides");
            return Err(Error::InvalidArgument);
        }
        // SAFETY: sizes/dim_order/strides were validated above, and
        // `allocation_vec` keeps owned data alive for the tensor's lifetime.
        let tensor = unsafe {
            sys::TensorPtr_new(
                self.sizes,
                data_ptr as *const u8 as *mut u8,
                dim_order,
                strides,
                D::Scalar::TYPE.cpp(),
                self.dynamism,
                Box::new(sys::util::RustAny::new(Box::new(allocation_vec))),
            )
        };
        Ok(TensorPtr(tensor, PhantomData))
    }
}
// SAFETY(review): the builder can hold raw pointers (`Ptr`/`PtrMut`) whose
// referents were vouched for by the `unsafe` constructors; sending the builder
// assumes those buffers are safe to access from another thread — TODO confirm.
unsafe impl<D: DataTyped> Send for TensorPtrBuilder<'_, D> {}
/// Collects an iterator of trivially-movable elements into a heap-allocated
/// C++ `std::vector`.
fn cxx_vec<T>(elms: impl IntoIterator<Item = T>) -> UniquePtr<cxx::Vector<T>>
where
    T: ExternType<Kind = cxx::kind::Trivial> + VectorElement,
{
    let mut vec = cxx::Vector::new();
    for elm in elms {
        vec.pin_mut().push(elm);
    }
    vec
}
/// Computes standard-layout (row-major, C-contiguous) strides for the given
/// sizes: the innermost dimension has stride 1, each outer dimension's stride
/// is the product of all inner sizes.
fn standard_layout_strides(sizes: &cxx::Vector<SizesType>) -> UniquePtr<cxx::Vector<StridesType>> {
    // Zero-initialize, then fill from the innermost dimension outwards.
    // The zero literal is cast to `StridesType` (not `SizesType` as before):
    // the vector being built holds strides, and the old cast compiled only
    // because the two aliases currently name the same integer type.
    let mut strides = cxx_vec(std::iter::repeat(0 as StridesType).take(sizes.len()));
    let mut stride = 1;
    for i in (0..sizes.len()).rev() {
        strides.as_mut().unwrap().index_mut(i).unwrap().set(stride);
        stride *= sizes.get(i).unwrap();
    }
    strides
}
/// Creates a [`TensorPtr`](crate::tensor::TensorPtr) with `ndarray::array!`
/// element syntax, panicking on invalid layouts.
///
/// NOTE(review): the expansion references `ndarray::array!` unqualified, so
/// callers must have the `ndarray` crate in scope under that name — consider
/// routing through a `$crate` re-export for macro hygiene.
#[cfg(feature = "ndarray")]
#[macro_export]
macro_rules! tensor_ptr {
    ($($args:expr),*) => {
        $crate::tensor::TensorPtr::<$crate::tensor::View<_>>::from_array(ndarray::array![$($args),*]).unwrap()
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    // Renamed from the typo'd `fron_array`.
    #[cfg(feature = "ndarray")]
    #[test]
    fn from_array() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let tensor_ptr = TensorPtr::from_array(array.clone()).unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(array, tensor.as_array::<ndarray::Ix2>());
    }

    // Renamed from the typo'd `fron_vec`.
    #[test]
    fn from_vec() {
        let vec = vec![1, 2, 3, 4];
        let tensor_ptr = TensorPtr::from_vec(vec.clone());
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(
            vec,
            (0..vec.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    // Renamed from the typo'd `fron_array_view`.
    #[cfg(feature = "ndarray")]
    #[test]
    fn from_array_view() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let tensor_ptr = TensorPtr::from_array_view(array.view()).unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(array, tensor.as_array::<ndarray::Ix2>());
    }

    // Renamed from the typo'd `fron_slice`.
    #[test]
    fn from_slice() {
        let data = [1, 2, 3, 4];
        let tensor_ptr = TensorPtr::from_slice(&data);
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(
            data.to_vec(),
            (0..data.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    #[test]
    fn as_tensor_mut() {
        let mut data = [1, 2, 3, 4];
        let mut tensor_ptr = TensorPtrBuilder::<ViewMut<_>>::from_slice_mut(&mut data)
            .build_mut()
            .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        tensor[&[2]] = 50;
        drop(tensor);
        assert_eq!(data, [1, 2, 50, 4]);
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn builder_from_array() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let tensor_ptr = TensorPtrBuilder::<View<_>>::from_array(array.clone())
            .build()
            .unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(array, tensor.as_array::<ndarray::Ix2>());
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn builder_from_array_build_mut() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let mut tensor_ptr = TensorPtrBuilder::<ViewMut<_>>::from_array(array.clone())
            .build_mut()
            .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        assert_eq!(array, tensor.as_array::<ndarray::Ix2>());
        tensor[&[1, 1]] = 50;
        assert_ne!(array, tensor.as_array::<ndarray::Ix2>());
        assert_eq!(
            tensor.as_array::<ndarray::Ix2>(),
            ndarray::array![[1, 2], [3, 50]]
        );
    }

    #[test]
    fn builder_from_vec() {
        let vec = vec![1, 2, 3, 4];
        let tensor_ptr = TensorPtrBuilder::<View<_>>::from_vec(vec.clone())
            .build()
            .unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(
            vec,
            (0..vec.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    #[test]
    fn builder_from_vec_build_mut() {
        let vec = vec![1, 2, 3, 4];
        let mut tensor_ptr = TensorPtrBuilder::<ViewMut<_>>::from_vec(vec.clone())
            .build_mut()
            .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        assert_eq!(
            vec,
            (0..vec.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
        tensor[&[2]] = 50;
        assert_eq!(
            vec![1, 2, 50, 4],
            (0..vec.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn builder_from_array_view() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let tensor_ptr = TensorPtrBuilder::from_array_view(array.view())
            .build()
            .unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(array, tensor.as_array::<ndarray::Ix2>());
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn builder_from_array_view_mut() {
        let array_orig = ndarray::array![[1, 2], [3, 4]];
        let mut array = array_orig.clone();
        let mut tensor_ptr = TensorPtrBuilder::from_array_view_mut(array.view_mut())
            .build_mut()
            .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        assert_eq!(array_orig, tensor.as_array::<ndarray::Ix2>());
        tensor[&[1, 1]] = 50;
        assert_eq!(
            tensor.as_array::<ndarray::Ix2>(),
            ndarray::array![[1, 2], [3, 50]]
        );
        drop(tensor);
        // Mutation through the tensor is visible in the borrowed array.
        assert_eq!(array, ndarray::array![[1, 2], [3, 50]]);
    }

    #[test]
    fn builder_from_slice() {
        let data = [1, 2, 3, 4];
        let tensor_ptr = TensorPtrBuilder::from_slice(&data).build().unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(
            data.to_vec(),
            (0..data.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    #[test]
    fn builder_from_slice_mut() {
        let data_orig = [1, 2, 3, 4];
        let mut data = data_orig;
        let mut tensor_ptr = TensorPtrBuilder::from_slice_mut(&mut data)
            .build_mut()
            .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        assert_eq!(
            data_orig.to_vec(),
            (0..data_orig.len())
                .map(|i| tensor[&[i]])
                .collect::<Vec<_>>()
        );
        tensor[&[2]] = 50;
        assert_eq!(
            vec![1, 2, 50, 4],
            (0..data_orig.len())
                .map(|i| tensor[&[i]])
                .collect::<Vec<_>>()
        );
        drop(tensor);
        assert_eq!([1, 2, 50, 4], data);
    }

    #[test]
    fn builder_from_ptr() {
        let data = [1, 2, 3, 4];
        let tensor_ptr =
            unsafe { TensorPtrBuilder::from_ptr(data.as_ptr(), [data.len() as SizesType]) }
                .build()
                .unwrap();
        let tensor = tensor_ptr.as_tensor();
        assert_eq!(
            data.to_vec(),
            (0..data.len()).map(|i| tensor[&[i]]).collect::<Vec<_>>()
        );
    }

    #[test]
    fn builder_from_ptr_mut() {
        let data_orig = [1, 2, 3, 4];
        let mut data = data_orig;
        let mut tensor_ptr =
            unsafe { TensorPtrBuilder::from_ptr_mut(data.as_mut_ptr(), [data.len() as SizesType]) }
                .build_mut()
                .unwrap();
        let mut tensor = tensor_ptr.as_tensor_mut();
        assert_eq!(
            data_orig.to_vec(),
            (0..data_orig.len())
                .map(|i| tensor[&[i]])
                .collect::<Vec<_>>()
        );
        tensor[&[2]] = 50;
        assert_eq!(
            vec![1, 2, 50, 4],
            (0..data_orig.len())
                .map(|i| tensor[&[i]])
                .collect::<Vec<_>>()
        );
        drop(tensor);
        assert_eq!([1, 2, 50, 4], data);
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn from_array_invalid_strides() {
        use ndarray::{Array, ShapeBuilder};
        assert!(TensorPtr::from_array(
            Array::from_shape_vec((3,).strides((1,)), (0..3).collect()).unwrap()
        )
        .is_ok());
        assert!(TensorPtr::from_array(
            Array::from_shape_vec((3,).strides((10,)), (0..30).collect()).unwrap()
        )
        .is_err());
        assert!(TensorPtr::from_array(
            Array::from_shape_vec((2, 3).strides((3, 1)), (0..6).collect()).unwrap()
        )
        .is_ok());
        // Column-major layout is valid for immutable `build` (dim order is
        // derived from the strides)...
        assert!(TensorPtr::from_array(
            Array::from_shape_vec((2, 3).strides((1, 2)), (0..6).collect()).unwrap()
        )
        .is_ok());
        assert!(TensorPtr::from_array(
            Array::from_shape_vec((2, 3).strides((2, 4)), (0..12).collect()).unwrap()
        )
        .is_err());
        assert!(TensorPtrBuilder::<ViewMut<i32>>::from_array(
            Array::from_shape_vec((3,).strides((1,)), (0..3).collect()).unwrap()
        )
        .build_mut()
        .is_ok());
        assert!(TensorPtrBuilder::<ViewMut<i32>>::from_array(
            Array::from_shape_vec((3,).strides((10,)), (0..30).collect()).unwrap()
        )
        .build_mut()
        .is_err());
        assert!(TensorPtrBuilder::<ViewMut<i32>>::from_array(
            Array::from_shape_vec((2, 3).strides((3, 1)), (0..6).collect()).unwrap()
        )
        .build_mut()
        .is_ok());
        // ...but rejected by `build_mut`, which fixes the identity dim order.
        assert!(TensorPtrBuilder::<ViewMut<i32>>::from_array(
            Array::from_shape_vec((2, 3).strides((1, 2)), (0..6).collect()).unwrap()
        )
        .build_mut()
        .is_err());
        assert!(TensorPtrBuilder::<ViewMut<i32>>::from_array(
            Array::from_shape_vec((2, 3).strides((2, 4)), (0..12).collect()).unwrap()
        )
        .build_mut()
        .is_err());
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn copy_of() {
        let array = ndarray::array![[1, 2], [3, 4]];
        let tensor1 = TensorPtr::from_array_view(array.view()).unwrap();
        let tensor2 = TensorPtr::<View<i32>>::copy_of(&tensor1.as_tensor());
        assert_eq!(tensor1.as_tensor().as_array::<ndarray::Ix2>(), array);
        assert_eq!(tensor2.as_tensor().as_array::<ndarray::Ix2>(), array);
        let tensor2 = TensorPtr::<View<u8>>::copy_of(&tensor1.as_tensor());
        assert_eq!(
            tensor2.as_tensor().as_array::<ndarray::Ix2>(),
            array.map(|&x| x as u8)
        );
        let tensor2 = TensorPtr::<View<f32>>::copy_of(&tensor1.as_tensor());
        assert_eq!(
            tensor2.as_tensor().as_array::<ndarray::Ix2>(),
            array.map(|&x| x as f32)
        );
    }

    #[cfg(feature = "ndarray")]
    #[test]
    fn tensor_ptr_macro() {
        use ndarray::array;
        assert_eq!(tensor_ptr!(1.0).as_tensor().as_array(), array![1.0]);
        assert_eq!(tensor_ptr!(1u8).as_tensor().as_array_dyn().shape(), &[1]);
        assert_eq!(tensor_ptr!(1u64, 2).as_tensor().as_array(), array![1, 2]);
        let t: TensorPtr<'_, View<i8>> = tensor_ptr!([1i8, 2]);
        assert_eq!(t.as_tensor().as_array_dyn().shape(), &[1, 2]);
    }
}