use std::marker::PhantomData;
use std::ops::Range;
/// Tag naming the memory arrangement associated with a `TensorView`.
///
/// In the code visible in this file the tag is only stored on the view and
/// returned from it; stride and index computations do not consult it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLayout {
/// Row-major arrangement. This is the variant returned by `Default`.
#[default]
RowMajor,
/// Column-major arrangement.
ColumnMajor,
/// Tiled arrangement described by a two-element tile size.
Tiled {
/// Tile extents. NOTE(review): presumably (rows, cols) — confirm with the
/// code that consumes this variant.
tile_size: [usize; 2],
},
}
/// Shape/stride/offset metadata for a view with four dimension slots over
/// elements of type `T`.
///
/// The struct stores no element data; `T` appears only through `PhantomData`.
#[derive(Debug)]
pub struct TensorView<T> {
// Extent of each of the four dimension slots.
shape: [usize; 4],
// Per-slot multiplier applied to an index when computing a linear index.
strides: [usize; 4],
// Constant added to every linear index produced by this view.
offset: usize,
// Layout tag carried along with the view.
layout: MemoryLayout,
// Number of leading dimension slots regarded as in use (at least 1).
ndim: usize,
// Ties the view to its element type without storing any `T`.
_marker: PhantomData<T>,
}
impl<T> TensorView<T> {
    /// Creates a view over `shape` with row-major strides, a zero offset and
    /// [`MemoryLayout::RowMajor`].
    pub fn new(shape: [usize; 4]) -> Self {
        Self::with_strides(shape, Self::compute_row_major_strides(&shape))
    }

    /// Creates a view over `shape` using caller-supplied `strides`.
    ///
    /// The offset is zero and the layout tag is [`MemoryLayout::RowMajor`]
    /// regardless of the strides passed in; use [`Self::with_layout`] to
    /// change the tag.
    pub fn with_strides(shape: [usize; 4], strides: [usize; 4]) -> Self {
        Self {
            shape,
            strides,
            offset: 0,
            layout: MemoryLayout::RowMajor,
            ndim: Self::compute_ndim(&shape),
            _marker: PhantomData,
        }
    }

    /// Creates a view of shape `[len, 1, 1, 1]`.
    pub fn new_1d(len: usize) -> Self {
        Self::new([len, 1, 1, 1])
    }

    /// Creates a view of shape `[rows, cols, 1, 1]`.
    pub fn new_2d(rows: usize, cols: usize) -> Self {
        Self::new([rows, cols, 1, 1])
    }

    /// Creates a view of shape `[d0, d1, d2, 1]`.
    pub fn new_3d(d0: usize, d1: usize, d2: usize) -> Self {
        Self::new([d0, d1, d2, 1])
    }

    /// Creates a view of shape `[d0, d1, d2, d3]`.
    pub fn new_4d(d0: usize, d1: usize, d2: usize, d3: usize) -> Self {
        Self::new([d0, d1, d2, d3])
    }

    /// Returns the extents of all four dimension slots.
    pub fn shape(&self) -> &[usize; 4] {
        &self.shape
    }

    /// Returns the strides of all four dimension slots.
    pub fn strides(&self) -> &[usize; 4] {
        &self.strides
    }

    /// Returns the constant added to every linear index.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Returns the layout tag carried by this view.
    pub fn layout(&self) -> MemoryLayout {
        self.layout
    }

    /// Returns the number of leading dimension slots regarded as in use.
    pub fn ndim(&self) -> usize {
        self.ndim
    }

    /// Returns the product of all four extents.
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` when the strides equal the row-major strides derived
    /// from the current shape.
    pub fn is_contiguous(&self) -> bool {
        self.strides == Self::compute_row_major_strides(&self.shape)
    }

    /// Returns `true` when the view addresses no elements.
    pub fn is_empty(&self) -> bool {
        self.numel() == 0
    }

    /// Returns the extent of dimension slot `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= 4`.
    pub fn dim(&self, dim: usize) -> usize {
        self.shape[dim]
    }

    /// Returns the stride of dimension slot `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= 4`.
    pub fn stride(&self, dim: usize) -> usize {
        self.strides[dim]
    }

    /// Restricts dimension 0 to `range`; shorthand for `slice_dim(0, range)`.
    ///
    /// # Panics
    ///
    /// Panics if `range.end` exceeds the extent of dimension 0 or if
    /// `range.start > range.end`.
    pub fn slice(&self, range: Range<usize>) -> Self {
        self.slice_dim(0, range)
    }

    /// Restricts dimension `dim` to `range`, keeping strides, layout and
    /// `ndim` and advancing the offset by `range.start * stride(dim)`.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= 4`, if `range.end` exceeds the extent of `dim`, or if
    /// `range.start > range.end`.
    pub fn slice_dim(&self, dim: usize, range: Range<usize>) -> Self {
        assert!(dim < 4, "Dimension out of bounds");
        assert!(range.end <= self.shape[dim], "Slice range out of bounds");
        // Without this check `end - start` underflows: an opaque overflow panic
        // in debug builds and a huge wrapped extent in release builds.
        assert!(range.start <= range.end, "Slice range start exceeds end");

        let mut shape = self.shape;
        shape[dim] = range.end - range.start;
        Self {
            shape,
            offset: self.offset + range.start * self.strides[dim],
            _marker: PhantomData,
            ..*self
        }
    }

    /// Swaps the extents and strides of `dim0` and `dim1`.
    ///
    /// `ndim` never shrinks, and it grows when the swap moves a non-unit
    /// extent into a slot beyond the previous `ndim`.
    ///
    /// # Panics
    ///
    /// Panics if either dimension is `>= 4`.
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
        assert!(dim0 < 4 && dim1 < 4, "Dimension out of bounds");

        let mut shape = self.shape;
        let mut strides = self.strides;
        shape.swap(dim0, dim1);
        strides.swap(dim0, dim1);

        // Keep `ndim` covering every non-unit extent; otherwise `unsqueeze`
        // would silently drop the extent that was moved past it.
        let ndim = self.ndim.max(Self::compute_ndim(&shape));
        Self {
            shape,
            strides,
            ndim,
            _marker: PhantomData,
            ..*self
        }
    }

    /// Reinterprets a contiguous view under `new_shape`.
    ///
    /// Returns `None` when the element counts differ or when the view is not
    /// contiguous. The result has fresh row-major strides and keeps this
    /// view's offset and layout tag.
    pub fn reshape(&self, new_shape: [usize; 4]) -> Option<Self> {
        let new_numel: usize = new_shape.iter().product();
        if new_numel != self.numel() || !self.is_contiguous() {
            return None;
        }
        // `new` resets offset and layout; carry them over so a reshaped slice
        // still addresses the same elements.
        Some(Self {
            offset: self.offset,
            layout: self.layout,
            ..Self::new(new_shape)
        })
    }

    /// Removes every dimension whose extent is exactly 1, packing the
    /// remaining extents and strides to the front.
    ///
    /// Vacated slots get extent 1 and stride 1. The resulting `ndim` is the
    /// number of retained dimensions, with a minimum of 1.
    pub fn squeeze(&self) -> Self {
        let mut shape = [1usize; 4];
        let mut strides = [1usize; 4];
        let mut kept = 0;
        for (&extent, &stride) in self.shape.iter().zip(&self.strides) {
            // Only unit extents are removable; dropping a zero extent would
            // turn an empty view into a non-empty one.
            if extent != 1 {
                shape[kept] = extent;
                strides[kept] = stride;
                kept += 1;
            }
        }
        Self {
            shape,
            strides,
            ndim: kept.max(1),
            _marker: PhantomData,
            ..*self
        }
    }

    /// Inserts a unit dimension at position `dim`, shifting later dimensions
    /// one slot to the right.
    ///
    /// Returns `None` when `dim > ndim()` or when all four slots are already
    /// in use.
    pub fn unsqueeze(&self, dim: usize) -> Option<Self> {
        if dim > self.ndim || self.ndim >= 4 {
            return None;
        }
        let used = self.ndim;
        let mut shape = [1usize; 4];
        let mut strides = [1usize; 4];

        // Dimensions before the insertion point keep their slots.
        shape[..dim].copy_from_slice(&self.shape[..dim]);
        strides[..dim].copy_from_slice(&self.strides[..dim]);

        // The inserted unit dimension spans the whole dimension it displaces;
        // when appended past the last used slot it gets stride 1.
        strides[dim] = if dim < used {
            self.strides[dim] * self.shape[dim]
        } else {
            1
        };

        // Dimensions from the insertion point onwards move one slot right.
        shape[dim + 1..used + 1].copy_from_slice(&self.shape[dim..used]);
        strides[dim + 1..used + 1].copy_from_slice(&self.strides[dim..used]);

        Some(Self {
            shape,
            strides,
            ndim: used + 1,
            _marker: PhantomData,
            ..*self
        })
    }

    /// Returns the view with its layout tag replaced by `layout`.
    pub fn with_layout(mut self, layout: MemoryLayout) -> Self {
        self.layout = layout;
        self
    }

    /// Computes `offset + Σ indices[i] * strides[i]` over all four slots.
    ///
    /// The indices are not checked against the shape.
    pub fn linear_index(&self, indices: [usize; 4]) -> usize {
        let strided: usize = indices
            .iter()
            .zip(&self.strides)
            .map(|(index, stride)| index * stride)
            .sum();
        self.offset + strided
    }

    /// Row-major strides for `shape`: the last slot has stride 1 and every
    /// earlier slot has the product of all later extents.
    fn compute_row_major_strides(shape: &[usize; 4]) -> [usize; 4] {
        let mut strides = [1usize; 4];
        for slot in (0..3).rev() {
            strides[slot] = strides[slot + 1] * shape[slot + 1];
        }
        strides
    }

    /// One past the index of the last slot whose extent differs from 1, or 1
    /// when every extent is 1.
    fn compute_ndim(shape: &[usize; 4]) -> usize {
        // A zero extent is a real (empty) dimension, so compare against 1
        // rather than testing `> 1`.
        shape
            .iter()
            .rposition(|&extent| extent != 1)
            .map_or(1, |last| last + 1)
    }
}
// Implemented by hand because `#[derive(Clone)]` would add a `T: Clone` bound,
// even though no `T` value is stored in the view.
impl<T> Clone for TensorView<T> {
    /// Returns a copy carrying the same shape, strides, offset, layout and
    /// `ndim`.
    fn clone(&self) -> Self {
        // Every remaining field is `Copy`, so struct-update syntax copies them
        // out of the borrowed value.
        Self {
            _marker: PhantomData,
            ..*self
        }
    }
}
impl<T> Default for TensorView<T> {
    /// Returns the view built by `Self::new` for a shape whose four extents
    /// are all 1.
    fn default() -> Self {
        let unit_shape = [1usize; 4];
        Self::new(unit_shape)
    }
}
#[cfg(test)]
mod tests;