use alloc::{vec, vec::Vec};
use core::fmt;
#[derive(Clone, Debug, PartialEq)]
struct TensorData<T> {
data: Vec<T>,
}
#[derive(Clone, PartialEq)]
pub struct Tensor1<T, const D0: usize> {
storage: TensorData<T>,
}
#[derive(Clone, PartialEq)]
pub struct Tensor2<T, const D0: usize, const D1: usize> {
storage: TensorData<T>,
}
#[derive(Clone, PartialEq)]
pub struct Tensor3<T, const D0: usize, const D1: usize, const D2: usize> {
storage: TensorData<T>,
}
#[derive(Clone, PartialEq)]
pub struct Tensor4<T, const D0: usize, const D1: usize, const D2: usize, const D3: usize> {
storage: TensorData<T>,
}
impl<T> TensorData<T> {
fn from_vec<const R: usize>(data: Vec<T>, shape: &'static [usize; R]) -> crate::Result<Self> {
let expected_len: usize = shape.iter().product();
if data.len() != expected_len {
return Err(crate::Error::Shape {
expected: shape,
actual: vec![data.len()],
});
}
Ok(Self { data })
}
fn as_slice(&self) -> &[T] {
&self.data
}
fn into_vec(self) -> Vec<T> {
self.data
}
fn filled<const R: usize>(value: T, shape: &'static [usize; R]) -> Self
where
T: Clone,
{
let len: usize = shape.iter().product();
Self {
data: vec![value; len],
}
}
}
impl<T, const D0: usize> Tensor1<T, D0> {
pub const SHAPE: &'static [usize] = &[D0];
pub fn from_vec(data: Vec<T>) -> crate::Result<Self> {
Ok(Self {
storage: TensorData::from_vec(data, &[D0])?,
})
}
pub fn from_array(data: [T; D0]) -> Self {
Self {
storage: TensorData { data: data.into() },
}
}
pub fn filled(value: T) -> Self
where
T: Clone,
{
Self {
storage: TensorData::filled(value, &[D0]),
}
}
pub fn as_slice(&self) -> &[T] {
self.storage.as_slice()
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.storage.data
}
pub fn into_vec(self) -> Vec<T> {
self.storage.into_vec()
}
pub fn get(&self, index: usize) -> Option<&T> {
self.storage.data.get(index)
}
pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
self.storage.data.get_mut(index)
}
}
impl<T, const D0: usize, const D1: usize> Tensor2<T, D0, D1> {
pub const SHAPE: &'static [usize] = &[D0, D1];
pub fn from_vec(data: Vec<T>) -> crate::Result<Self> {
Ok(Self {
storage: TensorData::from_vec(data, &[D0, D1])?,
})
}
pub fn from_array(data: [[T; D1]; D0]) -> Self {
Self {
storage: TensorData {
data: data.into_iter().flat_map(IntoIterator::into_iter).collect(),
},
}
}
pub fn filled(value: T) -> Self
where
T: Clone,
{
Self {
storage: TensorData::filled(value, &[D0, D1]),
}
}
pub fn as_slice(&self) -> &[T] {
self.storage.as_slice()
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.storage.data
}
pub fn into_vec(self) -> Vec<T> {
self.storage.into_vec()
}
pub fn get(&self, row: usize, col: usize) -> Option<&T> {
(row < D0 && col < D1).then(|| &self.storage.data[row * D1 + col])
}
pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
(row < D0 && col < D1).then(|| &mut self.storage.data[row * D1 + col])
}
}
impl<T, const D0: usize, const D1: usize, const D2: usize> Tensor3<T, D0, D1, D2> {
pub const SHAPE: &'static [usize] = &[D0, D1, D2];
pub fn from_vec(data: Vec<T>) -> crate::Result<Self> {
Ok(Self {
storage: TensorData::from_vec(data, &[D0, D1, D2])?,
})
}
pub fn from_array(data: [[[T; D2]; D1]; D0]) -> Self {
Self {
storage: TensorData {
data: data
.into_iter()
.flat_map(IntoIterator::into_iter)
.flat_map(IntoIterator::into_iter)
.collect(),
},
}
}
pub fn filled(value: T) -> Self
where
T: Clone,
{
Self {
storage: TensorData::filled(value, &[D0, D1, D2]),
}
}
pub fn as_slice(&self) -> &[T] {
self.storage.as_slice()
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.storage.data
}
pub fn into_vec(self) -> Vec<T> {
self.storage.into_vec()
}
pub fn get(&self, d0: usize, d1: usize, d2: usize) -> Option<&T> {
(d0 < D0 && d1 < D1 && d2 < D2).then(|| &self.storage.data[(d0 * D1 + d1) * D2 + d2])
}
pub fn get_mut(&mut self, d0: usize, d1: usize, d2: usize) -> Option<&mut T> {
(d0 < D0 && d1 < D1 && d2 < D2).then(|| &mut self.storage.data[(d0 * D1 + d1) * D2 + d2])
}
}
impl<T, const D0: usize, const D1: usize, const D2: usize, const D3: usize>
Tensor4<T, D0, D1, D2, D3>
{
pub const SHAPE: &'static [usize] = &[D0, D1, D2, D3];
pub fn from_vec(data: Vec<T>) -> crate::Result<Self> {
Ok(Self {
storage: TensorData::from_vec(data, &[D0, D1, D2, D3])?,
})
}
pub fn from_array(data: [[[[T; D3]; D2]; D1]; D0]) -> Self {
Self {
storage: TensorData {
data: data
.into_iter()
.flat_map(IntoIterator::into_iter)
.flat_map(IntoIterator::into_iter)
.flat_map(IntoIterator::into_iter)
.collect(),
},
}
}
pub fn filled(value: T) -> Self
where
T: Clone,
{
Self {
storage: TensorData::filled(value, &[D0, D1, D2, D3]),
}
}
pub fn as_slice(&self) -> &[T] {
self.storage.as_slice()
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.storage.data
}
pub fn into_vec(self) -> Vec<T> {
self.storage.into_vec()
}
pub fn get(&self, d0: usize, d1: usize, d2: usize, d3: usize) -> Option<&T> {
(d0 < D0 && d1 < D1 && d2 < D2 && d3 < D3)
.then(|| &self.storage.data[((d0 * D1 + d1) * D2 + d2) * D3 + d3])
}
pub fn get_mut(&mut self, d0: usize, d1: usize, d2: usize, d3: usize) -> Option<&mut T> {
(d0 < D0 && d1 < D1 && d2 < D2 && d3 < D3)
.then(|| &mut self.storage.data[((d0 * D1 + d1) * D2 + d2) * D3 + d3])
}
}
impl<const D0: usize> Tensor1<f32, D0> {
pub fn zeros() -> Self {
Self::filled(0.0)
}
pub fn ones() -> Self {
Self::filled(1.0)
}
}
impl<const D0: usize, const D1: usize> Tensor2<f32, D0, D1> {
pub fn zeros() -> Self {
Self::filled(0.0)
}
pub fn ones() -> Self {
Self::filled(1.0)
}
}
impl<const D0: usize, const D1: usize, const D2: usize> Tensor3<f32, D0, D1, D2> {
pub fn zeros() -> Self {
Self::filled(0.0)
}
pub fn ones() -> Self {
Self::filled(1.0)
}
}
impl<const D0: usize, const D1: usize, const D2: usize, const D3: usize>
Tensor4<f32, D0, D1, D2, D3>
{
pub fn zeros() -> Self {
Self::filled(0.0)
}
pub fn ones() -> Self {
Self::filled(1.0)
}
}
impl<T, const D0: usize> TryFrom<Vec<T>> for Tensor1<T, D0> {
type Error = crate::Error;
fn try_from(data: Vec<T>) -> crate::Result<Self> {
Self::from_vec(data)
}
}
impl<T, const D0: usize, const D1: usize> TryFrom<Vec<T>> for Tensor2<T, D0, D1> {
type Error = crate::Error;
fn try_from(data: Vec<T>) -> crate::Result<Self> {
Self::from_vec(data)
}
}
impl<T, const D0: usize, const D1: usize, const D2: usize> TryFrom<Vec<T>>
for Tensor3<T, D0, D1, D2>
{
type Error = crate::Error;
fn try_from(data: Vec<T>) -> crate::Result<Self> {
Self::from_vec(data)
}
}
impl<T, const D0: usize, const D1: usize, const D2: usize, const D3: usize> TryFrom<Vec<T>>
for Tensor4<T, D0, D1, D2, D3>
{
type Error = crate::Error;
fn try_from(data: Vec<T>) -> crate::Result<Self> {
Self::from_vec(data)
}
}
impl<T: fmt::Debug, const D0: usize> fmt::Debug for Tensor1<T, D0> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("Tensor1")
.field("shape", &Self::SHAPE)
.field("data", &self.storage.data)
.finish()
}
}
impl<T: fmt::Debug, const D0: usize, const D1: usize> fmt::Debug for Tensor2<T, D0, D1> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("Tensor2")
.field("shape", &Self::SHAPE)
.field("data", &self.storage.data)
.finish()
}
}
impl<T: fmt::Debug, const D0: usize, const D1: usize, const D2: usize> fmt::Debug
for Tensor3<T, D0, D1, D2>
{
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("Tensor3")
.field("shape", &Self::SHAPE)
.field("data", &self.storage.data)
.finish()
}
}
impl<T: fmt::Debug, const D0: usize, const D1: usize, const D2: usize, const D3: usize> fmt::Debug
for Tensor4<T, D0, D1, D2, D3>
{
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("Tensor4")
.field("shape", &Self::SHAPE)
.field("data", &self.storage.data)
.finish()
}
}