use crate::tensor::{
ephemeral_view::{EphemeralView, ToEphemeralView},
tensor_shape::TensorShape,
};
use std::ops::{Index, IndexMut};
/// Owned, flat buffer of tensor elements.
///
/// The storage holds only the elements themselves (no shape); it can be
/// indexed by a flat `usize` offset through [`Index`] and [`IndexMut`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TensorStorage<Dtype> {
    /// The elements, stored contiguously in a `Vec`.
    pub vec: Vec<Dtype>,
}

/// Read access to a single element by flat offset.
///
/// No bound is placed on `Dtype`: indexing only hands out a reference, so
/// the element type does not need to be `Copy`.
///
/// # Panics
///
/// Panics if `index` is out of bounds for the underlying vector.
impl<Dtype> Index<usize> for TensorStorage<Dtype> {
    type Output = Dtype;

    fn index(&self, index: usize) -> &Self::Output {
        &self.vec[index]
    }
}

/// Write access to a single element by flat offset.
///
/// # Panics
///
/// Panics if `index` is out of bounds for the underlying vector.
impl<Dtype> IndexMut<usize> for TensorStorage<Dtype> {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.vec[index]
    }
}
/// Implemented by types that can lend out a [`TensorStorage`] by reference.
pub trait HasTensorData<Dtype> {
/// Returns a shared reference to the `TensorStorage` held by `self`.
fn data(&self) -> &TensorStorage<Dtype>;
}
/// Conversion of an owned value into a [`TensorStorage`].
pub trait IntoTensorStorage<Dtype> {
/// Consumes `self` and returns it as a `TensorStorage`.
fn into_tensor_storage(self) -> TensorStorage<Dtype>;
}
/// Lets a plain `Vec` be turned directly into a [`TensorStorage`].
impl<Dtype> IntoTensorStorage<Dtype> for Vec<Dtype> {
    /// Moves the vector into a new `TensorStorage`; the elements are not copied.
    fn into_tensor_storage(self) -> TensorStorage<Dtype> {
        let vec = self;
        TensorStorage { vec }
    }
}
/// Allows a [`TensorStorage`] to be viewed under a caller-supplied shape.
impl<'a, Dtype> ToEphemeralView<'a, Dtype> for TensorStorage<Dtype> {
    /// Builds an [`EphemeralView`] that borrows this storage and pairs it
    /// with `shape` (converted into a [`TensorShape`]).
    ///
    /// # Panics
    ///
    /// Panics if the number of elements described by `shape` differs from
    /// the length of the stored vector.
    fn as_shape<S: Into<TensorShape>>(&'a self, shape: S) -> EphemeralView<'a, Dtype> {
        let requested: TensorShape = shape.into();

        // The view reinterprets the existing buffer, so the element counts
        // must agree exactly.
        let expected_len = requested.num_elements();
        let stored_len = self.vec.len();
        assert_eq!(
            expected_len, stored_len,
            "number of elements in target shape mismatch"
        );

        EphemeralView {
            shape: requested,
            // `self` is already `&'a Self`, so it can be stored as-is.
            data: self,
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::tensor::{
        ephemeral_view::ToEphemeralView, indexable_tensor::IndexableTensor,
        tensor_shape::HasTensorShape, tensor_storage::IntoTensorStorage,
    };

    /// Checks the shape metadata and element lookup of views created with
    /// `as_shape` for a 2-D array shape, a 3-D array shape, and a 3-D shape
    /// passed as a `Vec`.
    #[test]
    fn test_to_ephemeral_view() {
        // Case 1: six elements viewed as 2 x 3.
        {
            let storage = vec![1, 2, 3, 4, 5, 6].into_tensor_storage();
            let view = storage.as_shape([2, 3]);
            let shape = view.shape();

            assert_eq!(shape.dims(), vec![2, 3]);
            assert_eq!(shape.strides(), vec![3, 1]);
            assert_eq!(shape.ndim(), 2);

            assert_eq!(view.at([0, 0]), 1);
            assert_eq!(view.at([0, 2]), 3);
            assert_eq!(view.at([1, 2]), 6);
        }

        // Case 2: 0..24 viewed as 2 x 3 x 4.
        {
            let elements: Vec<i32> = (0..24).collect();
            let storage = elements.into_tensor_storage();
            let view = storage.as_shape([2, 3, 4]);
            let shape = view.shape();

            assert_eq!(shape.dims(), vec![2, 3, 4]);
            assert_eq!(shape.strides(), vec![12, 4, 1]);
            assert_eq!(shape.ndim(), 3);

            // First outer slice.
            assert_eq!(view.at([0, 0, 0]), 0);
            assert_eq!(view.at([0, 0, 1]), 1);
            assert_eq!(view.at([0, 0, 2]), 2);
            assert_eq!(view.at([0, 0, 3]), 3);
            assert_eq!(view.at([0, 1, 0]), 4);
            assert_eq!(view.at([0, 1, 1]), 5);
            assert_eq!(view.at([0, 1, 2]), 6);
            assert_eq!(view.at([0, 1, 3]), 7);
            assert_eq!(view.at([0, 2, 3]), 11);

            // Second outer slice.
            assert_eq!(view.at([1, 0, 0]), 12);
            assert_eq!(view.at([1, 0, 1]), 13);
            assert_eq!(view.at([1, 0, 2]), 14);
            assert_eq!(view.at([1, 0, 3]), 15);
            assert_eq!(view.at([1, 1, 0]), 16);
            assert_eq!(view.at([1, 1, 1]), 17);
            assert_eq!(view.at([1, 1, 2]), 18);
            assert_eq!(view.at([1, 1, 3]), 19);
            assert_eq!(view.at([1, 2, 3]), 23);
        }

        // Case 3: the shape is supplied as a `Vec` instead of an array.
        {
            let elements: Vec<i32> = (0..30).collect();
            let storage = elements.into_tensor_storage();
            let view = storage.as_shape(vec![2, 3, 5]);
            let shape = view.shape();

            assert_eq!(shape.dims(), vec![2, 3, 5]);
            assert_eq!(shape.strides(), vec![15, 5, 1]);
            assert_eq!(shape.ndim(), 3);
        }
    }
}