use std::error::Error;
use std::fmt;
use std::fmt::Display;
use rten_tensor::errors::DimensionError;
use rten_tensor::storage::{Alloc, GlobalAlloc, ViewData};
use rten_tensor::{
AsView, DynIndices, DynLayout, Layout, NdTensor, NdTensorView, Storage, Tensor, TensorBase,
TensorView,
};
use smallvec::SmallVec;
use crate::buffer_pool::{Buffer, BufferPool, ExtractBuffer};
/// Element type of a tensor or sequence value.
///
/// Each variant corresponds to the Rust type registered for it through
/// [`DataTypeOf`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DataType {
/// 32-bit signed integers (`i32`).
Int32,
/// 32-bit floats (`f32`).
Float,
/// 8-bit signed integers (`i8`).
Int8,
/// 8-bit unsigned integers (`u8`).
UInt8,
}
impl DataType {
pub fn size(self) -> u8 {
match self {
DataType::Int32 | DataType::Float => 4,
DataType::Int8 | DataType::UInt8 => 1,
}
}
}
/// Type of a value: the kind of container plus its element type.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ValueType {
/// A tensor whose elements have the given type.
Tensor(DataType),
/// A sequence whose items have the given element type.
Sequence(DataType),
}
impl ValueType {
    /// Return the tensor type that has the same element type as `self`.
    ///
    /// Tensor types are returned unchanged.
    pub(crate) fn to_tensor_type(self) -> Self {
        match self {
            Self::Tensor(dtype) | Self::Sequence(dtype) => Self::Tensor(dtype),
        }
    }

    /// Return the sequence type that has the same element type as `self`.
    ///
    /// Sequence types are returned unchanged.
    pub(crate) fn to_sequence_type(self) -> Self {
        match self {
            Self::Tensor(dtype) | Self::Sequence(dtype) => Self::Sequence(dtype),
        }
    }
}
impl std::fmt::Display for ValueType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Tensor(dtype) => write!(f, "tensor({})", dtype),
Self::Sequence(dtype) => write!(f, "sequence({})", dtype),
}
}
}
/// Maps a Rust element type to the [`DataType`] variant that describes it.
pub trait DataTypeOf {
/// Return the [`DataType`] associated with the implementing type.
fn dtype_of() -> DataType;
}
/// Implement [`DataTypeOf`] for the element type `$elem`, returning the
/// `DataType` variant named `$variant`.
macro_rules! impl_data_type_of {
    ($elem:ty, $variant:ident) => {
        impl DataTypeOf for $elem {
            fn dtype_of() -> DataType {
                DataType::$variant
            }
        }
    };
}
// Register the `DataType` tag for each supported element type.
impl_data_type_of!(f32, Float);
impl_data_type_of!(i32, Int32);
impl_data_type_of!(i8, Int8);
impl_data_type_of!(u8, UInt8);
impl std::fmt::Display for DataType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
DataType::Float => "f32",
DataType::Int32 => "i32",
DataType::Int8 => "i8",
DataType::UInt8 => "u8",
}
)
}
}
/// Summary of a value's type and shape, without its data.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ValueMeta {
/// Type of the value (tensor or sequence, plus element type).
pub(crate) dtype: ValueType,
/// Shape of the value as reported by its `Layout::shape` implementation.
pub(crate) shape: Vec<usize>,
}
/// Formats the metadata as the value type followed by the debug form of the shape.
impl Display for ValueMeta {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { dtype, shape } = self;
        write!(f, "{} {:?}", dtype, shape)
    }
}
/// Errors returned when converting a value or value view into a more
/// specific type fails.
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TryFromValueError {
/// The value has a different number of dimensions than the target type.
WrongRank { actual: usize, expected: usize },
/// The value's type does not match the type required by the target.
WrongType {
actual: ValueType,
expected: ValueType,
},
/// A sequence was required but the value is not a sequence.
ExpectedSequence,
}
/// Produces a one-line, human-readable description of the conversion failure.
impl Display for TryFromValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ExpectedSequence => write!(f, "value is not a sequence"),
            Self::WrongType { actual, expected } => write!(
                f,
                "expected tensor with type {} but has type {}",
                expected, actual
            ),
            Self::WrongRank { actual, expected } => write!(
                f,
                "expected tensor with {} dims but has {} dims",
                expected, actual
            ),
        }
    }
}
// Relies on the default `Error` methods; the message comes from `Display`.
impl Error for TryFromValueError {}
impl From<DimensionError> for TryFromValueError {
fn from(val: DimensionError) -> TryFromValueError {
let DimensionError { actual, expected } = val;
TryFromValueError::WrongRank { actual, expected }
}
}
/// Layout of a value, as consumed by the `impl_proxy_layout!` macro.
enum ValueLayout<'a> {
/// Layout borrowed from a tensor.
Tensor(&'a DynLayout),
/// Layout of a one-dimensional container with the given number of items.
/// Used for sequences.
Vector(usize),
}
impl<'a> From<&'a DynLayout> for ValueLayout<'a> {
fn from(layout: &'a DynLayout) -> Self {
Self::Tensor(layout)
}
}
/// Generate the items of a `Layout` impl for a type that has an inherent
/// `layout(&self) -> ValueLayout<'_>` method.
///
/// For `ValueLayout::Tensor` every query is forwarded to the wrapped
/// `DynLayout`. A `ValueLayout::Vector(len)` is treated as a one-dimensional
/// layout with shape `[len]` and stride `1`.
macro_rules! impl_proxy_layout {
    () => {
        type Index<'b> = SmallVec<[usize; 4]>;
        type Indices = DynIndices;

        fn ndim(&self) -> usize {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.ndim(),
                ValueLayout::Vector(_) => 1,
            }
        }

        fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.offset(&index),
                // A vector has exactly one dimension, so an index is valid
                // only if it has a single coordinate that is in bounds.
                // Indices with zero or several coordinates are rejected
                // rather than silently using the first coordinate.
                ValueLayout::Vector(len) => match index.as_slice() {
                    &[idx] if idx < len => Some(idx),
                    _ => None,
                },
            }
        }

        fn len(&self) -> usize {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.len(),
                ValueLayout::Vector(len) => len,
            }
        }

        fn is_empty(&self) -> bool {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.is_empty(),
                ValueLayout::Vector(len) => len == 0,
            }
        }

        fn shape(&self) -> Self::Index<'_> {
            match self.layout() {
                ValueLayout::Tensor(layout) => SmallVec::from_slice(layout.shape()),
                ValueLayout::Vector(len) => SmallVec::from_slice(&[len]),
            }
        }

        // Panics if `dim` is not 0 for a vector layout.
        fn size(&self, dim: usize) -> usize {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.size(dim),
                ValueLayout::Vector(len) => [len][dim],
            }
        }

        fn strides(&self) -> Self::Index<'_> {
            match self.layout() {
                ValueLayout::Tensor(layout) => SmallVec::from_slice(layout.strides()),
                ValueLayout::Vector(_) => SmallVec::from_slice(&[1]),
            }
        }

        // Panics if `dim` is not 0 for a vector layout.
        fn stride(&self, dim: usize) -> usize {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.stride(dim),
                ValueLayout::Vector(_) => [1][dim],
            }
        }

        fn indices(&self) -> Self::Indices {
            match self.layout() {
                ValueLayout::Tensor(layout) => layout.indices(),
                ValueLayout::Vector(len) => DynIndices::from_shape(&[len]),
            }
        }
    };
}
/// A borrowed value: a tensor view of one of the supported element types,
/// or a reference to a [`Sequence`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ValueView<'a> {
/// View of a tensor of `f32` elements.
FloatTensor(TensorView<'a, f32>),
/// View of a tensor of `i32` elements.
Int32Tensor(TensorView<'a, i32>),
/// View of a tensor of `i8` elements.
Int8Tensor(TensorView<'a, i8>),
/// View of a tensor of `u8` elements.
UInt8Tensor(TensorView<'a, u8>),
/// Reference to a sequence of tensors.
Sequence(&'a Sequence),
}
impl<'a> ValueView<'a> {
pub fn from_shape<T>(
shape: impl AsRef<[usize]>,
data: &'a [T],
) -> Result<ValueView<'a>, impl Error + Send + Sync + 'static>
where
ValueView<'a>: From<TensorView<'a, T>>,
{
TensorView::try_from_data(shape.as_ref(), data).map(|tensor| tensor.into())
}
pub fn dtype(&self) -> ValueType {
let t = ValueType::Tensor;
match self {
Self::FloatTensor(_) => t(DataType::Float),
Self::Int32Tensor(_) => t(DataType::Int32),
Self::Int8Tensor(_) => t(DataType::Int8),
Self::UInt8Tensor(_) => t(DataType::UInt8),
Self::Sequence(seq) => ValueType::Sequence(seq.dtype()),
}
}
pub fn to_owned(&self) -> Value {
self.to_owned_in(GlobalAlloc::new())
}
pub fn to_owned_in<A: Alloc>(&self, alloc: A) -> Value {
match self {
ValueView::FloatTensor(t) => t.to_tensor_in(alloc).into(),
ValueView::Int32Tensor(t) => t.to_tensor_in(alloc).into(),
ValueView::Int8Tensor(t) => t.to_tensor_in(alloc).into(),
ValueView::UInt8Tensor(t) => t.to_tensor_in(alloc).into(),
ValueView::Sequence(seq) => (*seq).clone().into(),
}
}
pub(crate) fn to_meta(&self) -> ValueMeta {
ValueMeta {
shape: self.shape().to_vec(),
dtype: self.dtype(),
}
}
fn layout(&self) -> ValueLayout<'_> {
match self {
ValueView::FloatTensor(t) => t.layout().into(),
ValueView::Int32Tensor(t) => t.layout().into(),
ValueView::Int8Tensor(t) => t.layout().into(),
ValueView::UInt8Tensor(t) => t.layout().into(),
ValueView::Sequence(seq) => ValueLayout::Vector(seq.len()),
}
}
}
// `Layout` methods are generated by `impl_proxy_layout!`, which dispatches
// on the result of `ValueView::layout`.
impl Layout for ValueView<'_> {
impl_proxy_layout!();
}
/// Generate conversions between [`ValueView`] and tensor views or scalars
/// of element type `$elem`, stored in the `ValueView::$variant` variant.
macro_rules! impl_value_view_conversions {
    ($variant:ident, $elem:ty) => {
        // ValueView -> dynamic-rank tensor view. Fails with `WrongType` if
        // the view holds a different variant.
        impl<'a> TryFrom<ValueView<'a>> for TensorView<'a, $elem> {
            type Error = TryFromValueError;

            fn try_from(input: ValueView<'a>) -> Result<TensorView<'a, $elem>, Self::Error> {
                match input {
                    ValueView::$variant(view) => Ok(view),
                    other => Err(TryFromValueError::WrongType {
                        actual: other.dtype(),
                        expected: ValueType::Tensor(<$elem as DataTypeOf>::dtype_of()),
                    }),
                }
            }
        }

        // ValueView -> static-rank tensor view. Fails with `WrongType` if
        // the variant differs, or `WrongRank` if the rank is not `N`.
        impl<'a, const N: usize> TryFrom<ValueView<'a>> for NdTensorView<'a, $elem, N> {
            type Error = TryFromValueError;

            fn try_from(
                input: ValueView<'a>,
            ) -> Result<NdTensorView<'a, $elem, N>, Self::Error> {
                let view: TensorView<'a, $elem> = input.try_into()?;
                let actual = view.ndim();
                view.try_into().map_err(|_| TryFromValueError::WrongRank {
                    actual,
                    expected: N,
                })
            }
        }

        // ValueView -> scalar. Fails with `WrongRank` if the tensor's
        // `item()` accessor returns `None`.
        impl<'a> TryFrom<ValueView<'a>> for $elem {
            type Error = TryFromValueError;

            fn try_from(input: ValueView<'a>) -> Result<$elem, Self::Error> {
                let view: TensorView<'a, $elem> = input.try_into()?;
                match view.item() {
                    Some(&scalar) => Ok(scalar),
                    None => Err(TryFromValueError::WrongRank {
                        actual: view.ndim(),
                        expected: 0,
                    }),
                }
            }
        }

        impl<'a> From<&'a Tensor<$elem>> for ValueView<'a> {
            fn from(tensor: &'a Tensor<$elem>) -> ValueView<'a> {
                ValueView::$variant(tensor.view())
            }
        }

        impl<'a> From<TensorView<'a, $elem>> for ValueView<'a> {
            fn from(view: TensorView<'a, $elem>) -> ValueView<'a> {
                ValueView::$variant(view)
            }
        }

        impl<'a, const N: usize> From<NdTensorView<'a, $elem, N>> for ValueView<'a> {
            fn from(view: NdTensorView<'a, $elem, N>) -> ValueView<'a> {
                ValueView::$variant(view.as_dyn())
            }
        }
    };
}
// Generate `ValueView` conversions for each supported element type.
impl_value_view_conversions!(FloatTensor, f32);
impl_value_view_conversions!(Int32Tensor, i32);
impl_value_view_conversions!(Int8Tensor, i8);
impl_value_view_conversions!(UInt8Tensor, u8);
impl<'a> From<&'a Value> for ValueView<'a> {
fn from(output: &'a Value) -> ValueView<'a> {
match output {
Value::FloatTensor(t) => ValueView::FloatTensor(t.view()),
Value::Int32Tensor(t) => ValueView::Int32Tensor(t.view()),
Value::Int8Tensor(t) => ValueView::Int8Tensor(t.view()),
Value::UInt8Tensor(t) => ValueView::UInt8Tensor(t.view()),
Value::Sequence(seq) => ValueView::Sequence(seq),
}
}
}
/// An owned value: a tensor of one of the supported element types, or a
/// [`Sequence`] of tensors.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum Value {
/// Tensor of `f32` elements.
FloatTensor(Tensor<f32>),
/// Tensor of `i32` elements.
Int32Tensor(Tensor<i32>),
/// Tensor of `i8` elements.
Int8Tensor(Tensor<i8>),
/// Tensor of `u8` elements.
UInt8Tensor(Tensor<u8>),
/// Sequence of tensors that share an element type.
Sequence(Sequence),
}
impl Value {
    /// Create a tensor value from a shape and a vector of elements.
    ///
    /// Returns the error produced by `Tensor::try_from_data` if it rejects
    /// the combination of `shape` and `data`.
    pub fn from_shape<T>(
        shape: impl AsRef<[usize]>,
        data: Vec<T>,
    ) -> Result<Self, impl Error + Send + Sync + 'static>
    where
        Self: From<Tensor<T>>,
    {
        let dims: &[usize] = shape.as_ref();
        Tensor::try_from_data(dims, data).map(|tensor| tensor.into())
    }

    /// Convert this value into a `(shape, data)` pair for a tensor of
    /// element type `T` and rank `N`.
    ///
    /// Fails if the conversion to `NdTensor<T, N>` fails.
    pub fn into_shape_vec<T: Clone, const N: usize>(
        self,
    ) -> Result<([usize; N], Vec<T>), TryFromValueError>
    where
        NdTensor<T, N>: TryFrom<Value, Error = TryFromValueError>,
    {
        let tensor: NdTensor<T, N> = self.try_into()?;
        let shape = tensor.shape();
        let data = tensor.into_data();
        Ok((shape, data))
    }

    /// Return the type of this value.
    pub fn dtype(&self) -> ValueType {
        let elem_type = match self {
            Self::Sequence(seq) => return ValueType::Sequence(seq.dtype()),
            Self::FloatTensor(_) => DataType::Float,
            Self::Int32Tensor(_) => DataType::Int32,
            Self::Int8Tensor(_) => DataType::Int8,
            Self::UInt8Tensor(_) => DataType::UInt8,
        };
        ValueType::Tensor(elem_type)
    }

    /// Borrow this value as a [`ValueView`].
    pub fn as_view(&self) -> ValueView<'_> {
        match self {
            Self::Sequence(seq) => ValueView::Sequence(seq),
            Self::FloatTensor(tensor) => ValueView::FloatTensor(tensor.view()),
            Self::Int32Tensor(tensor) => ValueView::Int32Tensor(tensor.view()),
            Self::Int8Tensor(tensor) => ValueView::Int8Tensor(tensor.view()),
            Self::UInt8Tensor(tensor) => ValueView::UInt8Tensor(tensor.view()),
        }
    }

    /// Return the shape and type of this value.
    pub(crate) fn to_meta(&self) -> ValueMeta {
        let shape = self.shape().to_vec();
        let dtype = self.dtype();
        ValueMeta { dtype, shape }
    }

    /// Consume this value and add its buffer, if one can be extracted, to
    /// `pool`. For sequences, each item is handled by `Sequence::add_to_pool`.
    pub(crate) fn add_to_pool(self, pool: &BufferPool) {
        let buffer = match self {
            Self::FloatTensor(tensor) => tensor.extract_buffer(),
            Self::Int32Tensor(tensor) => tensor.extract_buffer(),
            Self::Int8Tensor(tensor) => tensor.extract_buffer(),
            Self::UInt8Tensor(tensor) => tensor.extract_buffer(),
            Self::Sequence(seq) => {
                seq.add_to_pool(pool);
                None
            }
        };
        if let Some(buf) = buffer {
            pool.add(buf);
        }
    }

    /// Convert this value into a tensor of element type `T`, or return
    /// `None` if the conversion fails.
    pub fn into_tensor<T>(self) -> Option<Tensor<T>>
    where
        Tensor<T>: TryFrom<Self>,
    {
        <Tensor<T>>::try_from(self).ok()
    }

    /// Borrow this value as a tensor view of element type `T`, or return
    /// `None` if the conversion fails.
    pub fn as_tensor_view<'a, T>(&'a self) -> Option<TensorView<'a, T>>
    where
        TensorView<'a, T>: TryFrom<&'a Self>,
    {
        <TensorView<'a, T>>::try_from(self).ok()
    }

    /// Return the number of elements multiplied by the element size in
    /// bytes. For sequences this is the sum over all items.
    pub(crate) fn bytes(&self) -> usize {
        match self {
            Self::FloatTensor(tensor) => tensor_bytes(tensor),
            Self::Int32Tensor(tensor) => tensor_bytes(tensor),
            Self::Int8Tensor(tensor) => tensor_bytes(tensor),
            Self::UInt8Tensor(tensor) => tensor_bytes(tensor),
            Self::Sequence(seq) => seq.bytes(),
        }
    }

    /// Return the layout used to implement `Layout` for this type.
    fn layout(&self) -> ValueLayout<'_> {
        match self {
            Self::FloatTensor(tensor) => ValueLayout::from(tensor.layout()),
            Self::Int32Tensor(tensor) => ValueLayout::from(tensor.layout()),
            Self::Int8Tensor(tensor) => ValueLayout::from(tensor.layout()),
            Self::UInt8Tensor(tensor) => ValueLayout::from(tensor.layout()),
            Self::Sequence(seq) => ValueLayout::Vector(seq.len()),
        }
    }
}
/// Return the number of elements in `tensor` multiplied by the size in
/// bytes of one element.
fn tensor_bytes<S: Storage, L: Layout>(tensor: &TensorBase<S, L>) -> usize {
    let elem_size = std::mem::size_of::<S::Elem>();
    tensor.len() * elem_size
}
// `Layout` methods are generated by `impl_proxy_layout!`, which dispatches
// on the result of `Value::layout`.
impl Layout for Value {
impl_proxy_layout!();
}
/// Extracts the buffer of a tensor value. Sequences yield `None`.
impl ExtractBuffer for Value {
    fn extract_buffer(self) -> Option<Buffer> {
        match self {
            Self::FloatTensor(tensor) => tensor.extract_buffer(),
            Self::Int32Tensor(tensor) => tensor.extract_buffer(),
            Self::Int8Tensor(tensor) => tensor.extract_buffer(),
            Self::UInt8Tensor(tensor) => tensor.extract_buffer(),
            // A sequence may hold several tensors, so no single buffer is
            // returned here. `Value::add_to_pool` handles sequence items
            // individually.
            Self::Sequence(_) => None,
        }
    }
}
/// Generate conversions between [`Value`] and scalars, tensors and tensor
/// views of element type `$elem`, stored in the `Value::$variant` variant.
macro_rules! impl_value_conversions {
    ($variant:ident, $elem:ty) => {
        impl From<$elem> for Value {
            fn from(scalar: $elem) -> Value {
                let tensor = Tensor::from(scalar);
                Value::$variant(tensor)
            }
        }

        impl From<Tensor<$elem>> for Value {
            fn from(tensor: Tensor<$elem>) -> Value {
                Value::$variant(tensor)
            }
        }

        impl<const N: usize> From<NdTensor<$elem, N>> for Value {
            fn from(tensor: NdTensor<$elem, N>) -> Value {
                Value::$variant(tensor.into_dyn())
            }
        }

        // Value -> dynamic-rank tensor. Fails with `WrongType` if the value
        // holds a different variant.
        impl TryFrom<Value> for Tensor<$elem> {
            type Error = TryFromValueError;

            fn try_from(o: Value) -> Result<Tensor<$elem>, Self::Error> {
                match o {
                    Value::$variant(tensor) => Ok(tensor),
                    other => Err(TryFromValueError::WrongType {
                        actual: other.dtype(),
                        expected: ValueType::Tensor(<$elem as DataTypeOf>::dtype_of()),
                    }),
                }
            }
        }

        // Value -> static-rank tensor. Fails with `WrongType` if the variant
        // differs, or `WrongRank` if the rank is not `N`.
        impl<const N: usize> TryFrom<Value> for NdTensor<$elem, N> {
            type Error = TryFromValueError;

            fn try_from(o: Value) -> Result<NdTensor<$elem, N>, TryFromValueError> {
                let dyn_tensor: Tensor<$elem> = o.try_into()?;
                let actual = dyn_tensor.ndim();
                dyn_tensor
                    .try_into()
                    .map_err(|_| TryFromValueError::WrongRank {
                        actual,
                        expected: N,
                    })
            }
        }

        // &Value -> dynamic-rank tensor view.
        impl<'a> TryFrom<&'a Value> for TensorView<'a, $elem> {
            type Error = TryFromValueError;

            fn try_from(o: &'a Value) -> Result<TensorView<'a, $elem>, TryFromValueError> {
                match o {
                    Value::$variant(tensor) => Ok(tensor.view()),
                    other => Err(TryFromValueError::WrongType {
                        actual: other.dtype(),
                        expected: ValueType::Tensor(<$elem as DataTypeOf>::dtype_of()),
                    }),
                }
            }
        }

        // &Value -> static-rank tensor view.
        impl<'a, const N: usize> TryFrom<&'a Value> for NdTensorView<'a, $elem, N> {
            type Error = TryFromValueError;

            fn try_from(
                o: &'a Value,
            ) -> Result<NdTensorView<'a, $elem, N>, TryFromValueError> {
                let dyn_view: TensorView<'a, $elem> = o.try_into()?;
                let actual = dyn_view.ndim();
                dyn_view
                    .try_into()
                    .map_err(|_| TryFromValueError::WrongRank {
                        actual,
                        expected: N,
                    })
            }
        }
    };
}
// Generate `Value` conversions for each supported element type.
impl_value_conversions!(FloatTensor, f32);
impl_value_conversions!(Int32Tensor, i32);
impl_value_conversions!(Int8Tensor, i8);
impl_value_conversions!(UInt8Tensor, u8);
impl From<Sequence> for Value {
fn from(seq: Sequence) -> Value {
Value::Sequence(seq)
}
}
/// Either a borrowed [`ValueView`] or an owned [`Value`].
#[derive(Clone, Debug)]
pub enum ValueOrView<'a> {
/// A borrowed view of a value.
View(ValueView<'a>),
/// An owned value.
Value(Value),
}
impl<'a> ValueOrView<'a> {
pub fn as_view(&self) -> ValueView<'_> {
match self {
ValueOrView::View(inp) => inp.clone(),
ValueOrView::Value(outp) => outp.as_view(),
}
}
pub fn to_owned(&self) -> Value {
match self {
ValueOrView::View(inp) => inp.to_owned(),
ValueOrView::Value(outp) => outp.clone(),
}
}
pub fn into_owned(self) -> Value {
match self {
ValueOrView::View(view) => view.to_owned(),
ValueOrView::Value(value) => value,
}
}
fn layout(&self) -> ValueLayout<'_> {
match self {
Self::View(inp) => inp.layout(),
Self::Value(outp) => outp.layout(),
}
}
}
impl<'a> From<ValueView<'a>> for ValueOrView<'a> {
fn from(val: ValueView<'a>) -> Self {
ValueOrView::View(val)
}
}
/// Borrows any tensor as a [`ValueOrView::View`] with a dynamic-rank layout.
impl<'a, T: 'static, S: Storage<Elem = T>, L: Layout + Clone> From<&'a TensorBase<S, L>>
    for ValueOrView<'a>
where
    ValueView<'a>: From<TensorView<'a, T>>,
{
    fn from(val: &'a TensorBase<S, L>) -> Self {
        let view: ValueView<'a> = val.as_dyn().into();
        Self::View(view)
    }
}
/// Converts a tensor view of any layout into a [`ValueOrView::View`] with a
/// dynamic-rank layout.
impl<'a, T, L: Layout + Clone> From<TensorBase<ViewData<'a, T>, L>> for ValueOrView<'a>
where
    ValueView<'a>: From<TensorView<'a, T>>,
{
    fn from(val: TensorBase<ViewData<'a, T>, L>) -> Self {
        let view: ValueView<'a> = val.as_dyn().into();
        Self::View(view)
    }
}
/// Converts an owned tensor of any layout into a [`ValueOrView::Value`]
/// with a dynamic-rank layout.
impl<T, L: Layout + Clone> From<TensorBase<Vec<T>, L>> for ValueOrView<'static>
where
    Value: From<Tensor<T>>,
    DynLayout: From<L>,
{
    fn from(val: TensorBase<Vec<T>, L>) -> Self {
        let dyn_tensor = val.into_dyn();
        Self::Value(dyn_tensor.into())
    }
}
impl From<Value> for ValueOrView<'static> {
fn from(val: Value) -> Self {
ValueOrView::Value(val)
}
}
/// Borrows an owned value as a [`ValueOrView::View`].
impl<'a> From<&'a Value> for ValueOrView<'a> {
    fn from(val: &'a Value) -> Self {
        Self::View(ValueView::from(val))
    }
}
// `Layout` methods are generated by `impl_proxy_layout!`, which dispatches
// on the result of `ValueOrView::layout`.
impl Layout for ValueOrView<'_> {
impl_proxy_layout!();
}
/// Extracts the buffer from an owned value. Views do not own a buffer and
/// yield `None`.
impl ExtractBuffer for ValueOrView<'_> {
    fn extract_buffer(self) -> Option<Buffer> {
        match self {
            Self::Value(owned) => owned.extract_buffer(),
            Self::View(_) => None,
        }
    }
}
/// A scalar value that is either an `i32` or an `f32`.
///
/// Both payloads are `Copy`, so the enum derives `Clone` and `Copy` to let
/// callers pass it by value without giving up ownership.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Scalar {
    /// A 32-bit signed integer.
    Int(i32),
    /// A 32-bit float.
    Float(f32),
}
impl Scalar {
pub fn dtype(&self) -> DataType {
match self {
Self::Int(_) => DataType::Int32,
Self::Float(_) => DataType::Float,
}
}
}
/// Errors returned when modifying a [`Sequence`].
///
/// Implements [`Display`] and [`Error`] so it can be reported to users and
/// propagated with `?` into `Box<dyn Error>` or other error wrappers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SequenceError {
    /// The requested position is outside the valid range for the operation.
    InvalidPosition,
    /// The value's type does not match the type of the sequence.
    InvalidType,
}

impl Display for SequenceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidPosition => "sequence position is out of range",
            Self::InvalidType => "value type does not match sequence type",
        };
        f.write_str(message)
    }
}

impl Error for SequenceError {}
/// A list of tensors that all share the same element type.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum Sequence {
/// Sequence of `f32` tensors.
Float(Vec<Tensor<f32>>),
/// Sequence of `i32` tensors.
Int32(Vec<Tensor<i32>>),
/// Sequence of `i8` tensors.
Int8(Vec<Tensor<i8>>),
/// Sequence of `u8` tensors.
UInt8(Vec<Tensor<u8>>),
}
impl Sequence {
    /// Create an empty sequence whose items have element type `dtype`.
    pub fn new(dtype: DataType) -> Sequence {
        match dtype {
            DataType::Float => Self::Float(Vec::new()),
            DataType::Int32 => Self::Int32(Vec::new()),
            DataType::Int8 => Self::Int8(Vec::new()),
            DataType::UInt8 => Self::UInt8(Vec::new()),
        }
    }

    /// Return the element type of the tensors in this sequence.
    pub fn dtype(&self) -> DataType {
        match self {
            Self::Float(_) => DataType::Float,
            Self::Int32(_) => DataType::Int32,
            Self::Int8(_) => DataType::Int8,
            Self::UInt8(_) => DataType::UInt8,
        }
    }

    /// Return the number of tensors in this sequence.
    pub fn len(&self) -> usize {
        match self {
            Self::Float(items) => items.len(),
            Self::Int32(items) => items.len(),
            Self::Int8(items) => items.len(),
            Self::UInt8(items) => items.len(),
        }
    }

    /// Return true if the sequence contains no tensors.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Return a view of the tensor at `index`, or `None` if `index` is out
    /// of bounds.
    pub fn at(&self, index: usize) -> Option<ValueView<'_>> {
        match self {
            Self::Float(items) => Self::at_impl(items, index),
            Self::Int32(items) => Self::at_impl(items, index),
            Self::Int8(items) => Self::at_impl(items, index),
            Self::UInt8(items) => Self::at_impl(items, index),
        }
    }

    /// Look up `items[index]` and convert the reference into a view.
    fn at_impl<T>(items: &[T], index: usize) -> Option<ValueView<'_>>
    where
        for<'a> ValueView<'a>: From<&'a T>,
    {
        let item = items.get(index)?;
        Some(item.into())
    }

    /// Insert `val` at position `index`.
    ///
    /// Returns [`SequenceError::InvalidPosition`] if `index` is greater than
    /// the sequence length, or [`SequenceError::InvalidType`] if `val` is
    /// not a tensor with the same element type as the sequence. The
    /// position is checked before the type.
    pub fn insert(&mut self, index: usize, val: Value) -> Result<(), SequenceError> {
        if index > self.len() {
            return Err(SequenceError::InvalidPosition);
        }
        match (self, val) {
            (Self::Float(items), Value::FloatTensor(tensor)) => items.insert(index, tensor),
            (Self::Int32(items), Value::Int32Tensor(tensor)) => items.insert(index, tensor),
            (Self::Int8(items), Value::Int8Tensor(tensor)) => items.insert(index, tensor),
            (Self::UInt8(items), Value::UInt8Tensor(tensor)) => items.insert(index, tensor),
            _ => return Err(SequenceError::InvalidType),
        }
        Ok(())
    }

    /// Remove and return the tensor at position `index`.
    ///
    /// Returns [`SequenceError::InvalidPosition`] if `index` is out of
    /// bounds.
    pub fn remove(&mut self, index: usize) -> Result<Value, SequenceError> {
        if index >= self.len() {
            return Err(SequenceError::InvalidPosition);
        }
        Ok(match self {
            Self::Float(items) => Value::FloatTensor(items.remove(index)),
            Self::Int32(items) => Value::Int32Tensor(items.remove(index)),
            Self::Int8(items) => Value::Int8Tensor(items.remove(index)),
            Self::UInt8(items) => Value::UInt8Tensor(items.remove(index)),
        })
    }

    /// Return an iterator over views of the tensors in this sequence.
    pub fn iter(&self) -> impl Iterator<Item = ValueView<'_>> {
        let count = self.len();
        // Every index in `0..count` is in bounds, so `at` returns `Some`.
        (0..count).map(move |idx| self.at(idx).unwrap())
    }

    /// Return the sum of `tensor_bytes` over all tensors in the sequence.
    fn bytes(&self) -> usize {
        match self {
            Self::Float(items) => items.iter().map(tensor_bytes).sum(),
            Self::Int32(items) => items.iter().map(tensor_bytes).sum(),
            Self::Int8(items) => items.iter().map(tensor_bytes).sum(),
            Self::UInt8(items) => items.iter().map(tensor_bytes).sum(),
        }
    }

    /// Consume the sequence and add the buffer of each tensor to `pool`.
    fn add_to_pool(self, pool: &BufferPool) {
        match self {
            Self::Float(items) => Self::add_items_to_pool(items, pool),
            Self::Int32(items) => Self::add_items_to_pool(items, pool),
            Self::Int8(items) => Self::add_items_to_pool(items, pool),
            Self::UInt8(items) => Self::add_items_to_pool(items, pool),
        }
    }

    /// Add the buffer of each item to `pool`, skipping items from which no
    /// buffer can be extracted.
    fn add_items_to_pool<T: ExtractBuffer>(items: Vec<T>, pool: &BufferPool) {
        items
            .into_iter()
            .filter_map(|item| item.extract_buffer())
            .for_each(|buf| {
                pool.add(buf);
            });
    }
}
/// Generate `From` conversions that build a `Sequence::$variant` from a
/// vector or a fixed-size array of `$item` tensors.
macro_rules! impl_sequence_conversions {
    ($variant:ident, $item:ty) => {
        impl From<Vec<$item>> for Sequence {
            fn from(items: Vec<$item>) -> Sequence {
                Sequence::$variant(items)
            }
        }

        impl<const N: usize> From<[$item; N]> for Sequence {
            fn from(items: [$item; N]) -> Sequence {
                Sequence::$variant(Vec::from(items))
            }
        }
    };
}
// Generate `Sequence` conversions for each supported tensor type.
impl_sequence_conversions!(Float, Tensor<f32>);
impl_sequence_conversions!(Int32, Tensor<i32>);
impl_sequence_conversions!(Int8, Tensor<i8>);
impl_sequence_conversions!(UInt8, Tensor<u8>);
/// Extracts the sequence reference from a view, failing with
/// [`TryFromValueError::ExpectedSequence`] if the view is not a sequence.
impl<'a> TryFrom<ValueView<'a>> for &'a Sequence {
    type Error = TryFromValueError;

    fn try_from(val: ValueView<'a>) -> Result<Self, Self::Error> {
        if let ValueView::Sequence(seq) = val {
            Ok(seq)
        } else {
            Err(TryFromValueError::ExpectedSequence)
        }
    }
}
/// Extracts the sequence from an owned value, failing with
/// [`TryFromValueError::ExpectedSequence`] if the value is not a sequence.
impl TryFrom<Value> for Sequence {
    type Error = TryFromValueError;

    fn try_from(val: Value) -> Result<Self, Self::Error> {
        if let Value::Sequence(seq) = val {
            Ok(seq)
        } else {
            Err(TryFromValueError::ExpectedSequence)
        }
    }
}
// Unit tests for constructing `Value` / `ValueView` and converting them
// back into tensors and tensor views.
#[cfg(test)]
mod tests {
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Tensor, TensorView};
use super::{DataType, TryFromValueError, Value, ValueType, ValueView};
// `ValueView::from_shape` accepts data whose length matches the shape and
// rejects a mismatched shape.
#[test]
fn test_value_view_from_shape() {
let value = ValueView::from_shape([2, 2], &[1, 2, 3, 4]).unwrap();
assert!(matches!(value, ValueView::Int32Tensor(_)));
assert_eq!(value.shape().as_slice(), &[2, 2]);
let value = ValueView::from_shape([2, 3], &[1, 2, 3, 4]);
assert!(value.is_err());
}
// Static-rank tensor views convert into the matching `ValueView` variant
// and keep their shape.
#[test]
fn test_value_view_from_tensor() {
let tensor = NdTensor::<i32, 3>::zeros([1, 2, 3]);
let input: ValueView = tensor.view().into();
assert!(matches!(input, ValueView::Int32Tensor(_)));
assert_eq!(input.shape().as_slice(), &[1, 2, 3]);
let tensor = NdTensor::<f32, 2>::zeros([5, 5]);
let input: ValueView = tensor.view().into();
assert!(matches!(input, ValueView::FloatTensor(_)));
assert_eq!(input.shape().as_slice(), &[5, 5]);
}
// `Value::from_shape` accepts data whose length matches the shape and
// rejects a mismatched shape.
#[test]
fn test_value_from_shape() {
let value = Value::from_shape([2, 2], vec![1, 2, 3, 4]).unwrap();
assert!(matches!(value, Value::Int32Tensor(_)));
assert_eq!(value.shape().as_slice(), &[2, 2]);
let value = Value::from_shape([2, 3], vec![1, 2, 3, 4]);
assert!(value.is_err());
}
// `Value::into_shape_vec` returns the shape and data on success, and
// reports element type and rank mismatches.
#[test]
fn test_value_into_shape_vec() {
let value = Value::from_shape([2, 2], vec![1, 2, 3, 4]).unwrap();
let (shape, data) = value.into_shape_vec::<i32, 2>().unwrap();
assert_eq!(shape, [2, 2]);
assert_eq!(data, [1, 2, 3, 4]);
// Wrong element type.
let value = Value::from_shape([2, 2], vec![1, 2, 3, 4]).unwrap();
let err = value.into_shape_vec::<f32, 2>().err().unwrap();
assert!(matches!(err, TryFromValueError::WrongType { .. }));
// Wrong rank.
let value = Value::from_shape([2, 2], vec![1, 2, 3, 4]).unwrap();
let err = value.into_shape_vec::<i32, 3>().err().unwrap();
assert!(matches!(err, TryFromValueError::WrongRank { .. }));
}
// Owned conversions from `Value` into dynamic-rank and static-rank
// tensors, including the type and rank error cases.
#[test]
fn test_tensor_from_value() {
let original = NdTensor::from([[1., 2.], [3., 4.]]);
let output: Value = original.clone().into();
let mat_dyn: Tensor<f32> = output.clone().try_into().unwrap();
assert_eq!(mat_dyn, original);
let mat: NdTensor<f32, 2> = output.clone().try_into().unwrap();
assert_eq!(mat, original);
let err: Result<NdTensor<i32, 2>, _> = output.clone().try_into();
assert_eq!(
err,
Err(TryFromValueError::WrongType {
actual: ValueType::Tensor(DataType::Float),
expected: ValueType::Tensor(DataType::Int32),
})
);
let err: Result<NdTensor<f32, 3>, _> = output.clone().try_into();
assert_eq!(
err,
Err(TryFromValueError::WrongRank {
actual: 2,
expected: 3
})
);
}
// Borrowed conversions from `&Value` into dynamic-rank and static-rank
// tensor views, including the type and rank error cases.
#[test]
fn test_tensor_view_from_value() {
let original = NdTensor::from([[1., 2.], [3., 4.]]);
let output: Value = original.clone().into();
let mat_dyn: TensorView<f32> = (&output).try_into().unwrap();
assert_eq!(mat_dyn, original);
let mat: NdTensorView<f32, 2> = (&output).try_into().unwrap();
assert_eq!(mat, original);
let err: Result<NdTensorView<i32, 2>, _> = (&output).try_into();
assert_eq!(
err,
Err(TryFromValueError::WrongType {
actual: ValueType::Tensor(DataType::Float),
expected: ValueType::Tensor(DataType::Int32),
})
);
let err: Result<NdTensorView<f32, 3>, _> = (&output).try_into();
assert_eq!(
err,
Err(TryFromValueError::WrongRank {
actual: 2,
expected: 3
})
);
}
}