use torsh_core::dtype::TensorElement;
use torsh_core::error::{Result, TorshError};
use torsh_core::shape::Shape;
/// Read-only view over a borrowed slice of tensor elements.
///
/// The view pairs the borrowed buffer with a shape, a list of per-dimension
/// strides and a starting offset; it does not own or copy the data.
#[derive(Debug)]
pub struct TensorView<'a, T: TensorElement> {
/// Borrowed backing buffer the view reads from.
data: &'a [T],
/// Logical dimensions of the view; `len()` is derived from it via `numel()`.
shape: Shape,
/// Per-dimension step, in elements, multiplied with each index by `get_at`.
strides: Vec<usize>,
/// Position in `data` at which the view starts.
offset: usize,
}
/// Mutable view over an exclusively borrowed slice of tensor elements.
///
/// Carries the same metadata as the read-only view (shape, strides, offset)
/// but holds a `&mut` borrow so elements can be modified in place.
#[derive(Debug)]
pub struct TensorViewMut<'a, T: TensorElement> {
/// Exclusively borrowed backing buffer the view reads from and writes to.
data: &'a mut [T],
/// Logical dimensions of the view; `len()` is derived from it via `numel()`.
shape: Shape,
/// Per-dimension step, in elements; checked by `is_contiguous`.
strides: Vec<usize>,
/// Position in `data` at which the view starts.
offset: usize,
}
impl<'a, T: TensorElement> TensorView<'a, T> {
    /// Creates a read-only view over `data`.
    ///
    /// `shape` gives the logical dimensions, `strides` the per-dimension step
    /// (in elements) used by [`get_at`](Self::get_at), and `offset` the position
    /// in `data` where the view starts. The arguments are stored as given; no
    /// consistency check between them is performed here.
    pub fn new(data: &'a [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Returns the shape of the view.
    #[inline]
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the per-dimension strides of the view.
    #[inline]
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Returns the number of logical elements, i.e. `shape.numel()`.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.numel()
    }

    /// Returns `true` when the view has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the backing buffer from the view's offset to the end of the buffer.
    ///
    /// The slice is not truncated to [`len`](Self::len) elements.
    ///
    /// # Panics
    ///
    /// Panics if the offset is larger than the backing buffer.
    #[inline]
    pub fn data(&self) -> &[T] {
        &self.data[self.offset..]
    }

    /// Returns `true` when the strides describe a densely packed layout in which
    /// the last dimension has stride 1 and every other stride equals the product
    /// of the dimensions after it.
    ///
    /// A view without dimensions is reported as contiguous. A view whose stride
    /// list is shorter than its shape is reported as not contiguous.
    pub fn is_contiguous(&self) -> bool {
        let dims = self.shape.dims();
        // A stride list that does not cover every dimension cannot match the
        // dense layout; answer `false` rather than indexing past its end.
        if self.strides.len() < dims.len() {
            return false;
        }
        let mut expected_stride = 1;
        for (&dim, &stride) in dims.iter().zip(&self.strides[..dims.len()]).rev() {
            if stride != expected_stride {
                return false;
            }
            expected_stride *= dim;
        }
        true
    }

    /// Returns the element `index` positions after the view's offset.
    ///
    /// The index is applied directly to the backing buffer; strides are not
    /// consulted, so this matches [`get_at`](Self::get_at) only for contiguous
    /// views.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::IndexError`] when `index >= len()`, or when the
    /// resulting position lies outside the backing buffer.
    pub fn get(&self, index: usize) -> Result<&T> {
        let len = self.len();
        if index >= len {
            return Err(TorshError::IndexError { index, size: len });
        }
        let position = self.checked_position(index)?;
        Ok(&self.data[position])
    }

    /// Returns the element addressed by one index per dimension.
    ///
    /// The position is `offset + sum(indices[i] * strides[i])`.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::InvalidArgument`] when the number of indices differs
    /// from the number of dimensions, when a stride is missing, or when the
    /// position computation overflows. Returns [`TorshError::IndexError`] when an
    /// index exceeds its dimension or the position lies outside the backing buffer.
    pub fn get_at(&self, indices: &[usize]) -> Result<&T> {
        if indices.len() != self.shape.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Expected {} indices, got {}",
                self.shape.ndim(),
                indices.len()
            )));
        }
        let flat_index = self.compute_flat_index(indices)?;
        let position = self.checked_position(flat_index)?;
        Ok(&self.data[position])
    }

    /// Folds per-dimension indices into a single position relative to the offset,
    /// validating each index against its dimension along the way.
    fn compute_flat_index(&self, indices: &[usize]) -> Result<usize> {
        let dims = self.shape.dims();
        let mut flat_index: usize = 0;
        for (axis, &idx) in indices.iter().enumerate() {
            let dim = *dims.get(axis).ok_or_else(|| {
                TorshError::InvalidArgument(format!("Shape has no dimension {}", axis))
            })?;
            if idx >= dim {
                return Err(TorshError::IndexError {
                    index: idx,
                    size: dim,
                });
            }
            let stride = *self.strides.get(axis).ok_or_else(|| {
                TorshError::InvalidArgument(format!(
                    "Missing stride for dimension {} (view has {} strides)",
                    axis,
                    self.strides.len()
                ))
            })?;
            // Checked arithmetic: a wrapped product could alias a valid position.
            flat_index = idx
                .checked_mul(stride)
                .and_then(|step| flat_index.checked_add(step))
                .ok_or_else(|| {
                    TorshError::InvalidArgument(format!(
                        "Flat index overflows usize at dimension {}",
                        axis
                    ))
                })?;
        }
        Ok(flat_index)
    }

    /// Converts a position relative to the view's offset into a position in the
    /// backing buffer, confirming that the buffer actually contains it.
    fn checked_position(&self, relative: usize) -> Result<usize> {
        let capacity = self.data.len();
        match self.offset.checked_add(relative) {
            Some(position) if position < capacity => Ok(position),
            _ => Err(TorshError::IndexError {
                index: self.offset.saturating_add(relative),
                size: capacity,
            }),
        }
    }

    /// Returns an iterator over `len()` consecutive elements starting at the
    /// view's offset. Strides are not consulted.
    pub fn iter(&self) -> TensorViewIter<'a, T> {
        TensorViewIter {
            data: self.data,
            offset: self.offset,
            len: self.len(),
            current: 0,
        }
    }

    /// Copies `len()` consecutive elements starting at the view's offset into a
    /// new vector. Strides are not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len()` exceeds the backing buffer.
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Copy,
    {
        self.data[self.offset..self.offset + self.len()].to_vec()
    }
}
impl<'a, T: TensorElement> TensorViewMut<'a, T> {
    /// Creates a mutable view over `data`.
    ///
    /// `shape` gives the logical dimensions, `strides` the per-dimension step
    /// (in elements), and `offset` the position in `data` where the view starts.
    /// The arguments are stored as given; no consistency check between them is
    /// performed here.
    pub fn new(data: &'a mut [T], shape: Shape, strides: Vec<usize>, offset: usize) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Returns the shape of the view.
    #[inline]
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the per-dimension strides of the view.
    #[inline]
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Returns the number of logical elements, i.e. `shape.numel()`.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape.numel()
    }

    /// Returns `true` when the view has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the backing buffer from the view's offset to the end of the buffer.
    ///
    /// Unlike [`data_mut`](Self::data_mut), the slice is not truncated to
    /// [`len`](Self::len) elements.
    ///
    /// # Panics
    ///
    /// Panics if the offset is larger than the backing buffer.
    #[inline]
    pub fn data(&self) -> &[T] {
        &self.data[self.offset..]
    }

    /// Returns the `len()` elements starting at the view's offset, mutably.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len()` exceeds the backing buffer.
    #[inline]
    pub fn data_mut(&mut self) -> &mut [T] {
        let len = self.len();
        &mut self.data[self.offset..self.offset + len]
    }

    /// Returns `true` when the strides describe a densely packed layout in which
    /// the last dimension has stride 1 and every other stride equals the product
    /// of the dimensions after it.
    ///
    /// A view without dimensions is reported as contiguous. A view whose stride
    /// list is shorter than its shape is reported as not contiguous.
    pub fn is_contiguous(&self) -> bool {
        let dims = self.shape.dims();
        // A stride list that does not cover every dimension cannot match the
        // dense layout; answer `false` rather than indexing past its end.
        if self.strides.len() < dims.len() {
            return false;
        }
        let mut expected_stride = 1;
        for (&dim, &stride) in dims.iter().zip(&self.strides[..dims.len()]).rev() {
            if stride != expected_stride {
                return false;
            }
            expected_stride *= dim;
        }
        true
    }

    /// Returns the element `index` positions after the view's offset.
    ///
    /// The index is applied directly to the backing buffer; strides are not
    /// consulted.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::IndexError`] when `index >= len()`, or when the
    /// resulting position lies outside the backing buffer.
    pub fn get(&self, index: usize) -> Result<&T> {
        let position = self.checked_position(index)?;
        Ok(&self.data[position])
    }

    /// Returns a mutable reference to the element `index` positions after the
    /// view's offset.
    ///
    /// The index is applied directly to the backing buffer; strides are not
    /// consulted.
    ///
    /// # Errors
    ///
    /// Returns [`TorshError::IndexError`] when `index >= len()`, or when the
    /// resulting position lies outside the backing buffer.
    pub fn get_mut(&mut self, index: usize) -> Result<&mut T> {
        let position = self.checked_position(index)?;
        Ok(&mut self.data[position])
    }

    /// Validates `index` against the view length and the backing buffer, and
    /// returns the corresponding position in the backing buffer.
    fn checked_position(&self, index: usize) -> Result<usize> {
        let len = self.len();
        if index >= len {
            return Err(TorshError::IndexError { index, size: len });
        }
        let capacity = self.data.len();
        match self.offset.checked_add(index) {
            Some(position) if position < capacity => Ok(position),
            _ => Err(TorshError::IndexError {
                index: self.offset.saturating_add(index),
                size: capacity,
            }),
        }
    }

    /// Overwrites the `len()` elements starting at the view's offset with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len()` exceeds the backing buffer.
    pub fn fill(&mut self, value: T)
    where
        T: Copy,
    {
        let len = self.len();
        self.data[self.offset..self.offset + len].fill(value);
    }

    /// Returns an iterator over `len()` consecutive elements starting at the
    /// view's offset. Strides are not consulted.
    pub fn iter(&self) -> TensorViewIter<'_, T> {
        TensorViewIter {
            data: &*self.data,
            offset: self.offset,
            len: self.len(),
            current: 0,
        }
    }

    /// Returns a mutable iterator over the `len()` elements starting at the
    /// view's offset. Strides are not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len()` exceeds the backing buffer.
    pub fn iter_mut(&mut self) -> TensorViewIterMut<'_, T> {
        let len = self.len();
        TensorViewIterMut {
            data: &mut self.data[self.offset..self.offset + len],
            current: 0,
        }
    }
}
/// Shared-reference iterator that yields `len` consecutive elements of `data`,
/// starting at position `offset`.
pub struct TensorViewIter<'a, T: TensorElement> {
/// Backing buffer being iterated.
data: &'a [T],
/// Position in `data` of the first element to yield.
offset: usize,
/// Total number of elements this iterator yields.
len: usize,
/// Number of elements already yielded.
current: usize,
}
impl<'a, T: TensorElement> Iterator for TensorViewIter<'a, T> {
    type Item = &'a T;

    /// Yields the next element of the window, or `None` once `len` elements
    /// have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        // Guard clause: nothing left to hand out.
        if self.current >= self.len {
            return None;
        }
        let position = self.offset + self.current;
        let element = &self.data[position];
        self.current += 1;
        Some(element)
    }

    /// Reports the exact number of elements still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.len - self.current;
        (left, Some(left))
    }
}
impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIter<'a, T> {
    /// Number of elements still to be yielded; taken from the exact lower
    /// bound reported by `size_hint`.
    fn len(&self) -> usize {
        let (remaining, _) = self.size_hint();
        remaining
    }
}
/// Mutable iterator over the elements of an exclusively borrowed slice.
pub struct TensorViewIterMut<'a, T: TensorElement> {
/// Slice whose elements are handed out as `&mut T`.
data: &'a mut [T],
/// Index into `data` of the next element to yield.
current: usize,
}
impl<'a, T: TensorElement> Iterator for TensorViewIterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields a mutable reference to the next element, or `None` when the
    /// slice is exhausted.
    ///
    /// Implemented without `unsafe`: the slice is moved out of the iterator,
    /// split into the element to yield and the remainder, and only the
    /// remainder is stored back. Each yielded `&mut T` is therefore disjoint
    /// from everything the iterator still holds, so previously returned
    /// references stay valid while iteration continues.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = std::mem::take(&mut self.data);
        // Honour a pending skip count, then reset it: once the front of the
        // slice has been split off, the stored slice itself marks the position.
        let skip = std::mem::take(&mut self.current);
        let window = remaining.get_mut(skip..)?;
        let (item, rest) = window.split_first_mut()?;
        self.data = rest;
        Some(item)
    }

    /// Reports the exact number of elements still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.data.len().saturating_sub(self.current);
        (remaining, Some(remaining))
    }
}
impl<'a, T: TensorElement> ExactSizeIterator for TensorViewIterMut<'a, T> {
    /// Number of elements still to be yielded; taken from the exact lower
    /// bound reported by `size_hint`.
    fn len(&self) -> usize {
        let (remaining, _) = self.size_hint();
        remaining
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built 2x2 view reports its element count and shape.
    #[test]
    fn test_tensor_view_creation() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&buffer, Shape::new(vec![2, 2]), vec![2, 1], 0);
        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
        assert_eq!(view.shape().dims(), &[2, 2]);
    }

    /// Row-major strides for a 2x2 shape are recognised as contiguous.
    #[test]
    fn test_tensor_view_contiguous() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&buffer, Shape::new(vec![2, 2]), vec![2, 1], 0);
        assert!(view.is_contiguous());
    }

    /// Flat `get` returns every element in order and rejects index == len.
    #[test]
    fn test_tensor_view_get() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&buffer, Shape::new(vec![4]), vec![1], 0);
        for (position, expected) in [1.0, 2.0, 3.0, 4.0].iter().enumerate() {
            assert_eq!(*view.get(position).expect("get should succeed"), *expected);
        }
        assert!(view.get(4).is_err());
    }

    /// Iterating a 1-D view yields the buffer contents in order.
    #[test]
    fn test_tensor_view_iter() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&buffer, Shape::new(vec![4]), vec![1], 0);
        let collected: Vec<_> = view.iter().copied().collect();
        assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// A freshly built mutable 2x2 view reports its element count.
    #[test]
    fn test_tensor_view_mut_creation() {
        let mut buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorViewMut::new(&mut buffer, Shape::new(vec![2, 2]), vec![2, 1], 0);
        assert_eq!(view.len(), 4);
        assert!(!view.is_empty());
    }

    /// `fill` overwrites every element of the underlying buffer.
    #[test]
    fn test_tensor_view_mut_fill() {
        let mut buffer = vec![1.0, 2.0, 3.0, 4.0];
        let mut view = TensorViewMut::new(&mut buffer, Shape::new(vec![4]), vec![1], 0);
        view.fill(0.0);
        assert_eq!(buffer, vec![0.0, 0.0, 0.0, 0.0]);
    }

    /// Writes through `get_mut` land in the underlying buffer.
    #[test]
    fn test_tensor_view_mut_get_mut() {
        let mut buffer = vec![1.0, 2.0, 3.0, 4.0];
        let mut view = TensorViewMut::new(&mut buffer, Shape::new(vec![4]), vec![1], 0);
        for (position, replacement) in [10.0, 20.0].iter().enumerate() {
            *view.get_mut(position).expect("get_mut should succeed") = *replacement;
        }
        assert_eq!(buffer, vec![10.0, 20.0, 3.0, 4.0]);
    }

    /// Mutating through `iter_mut` updates every element in place.
    #[test]
    fn test_tensor_view_mut_iter_mut() {
        let mut buffer = vec![1.0, 2.0, 3.0, 4.0];
        let mut view = TensorViewMut::new(&mut buffer, Shape::new(vec![4]), vec![1], 0);
        view.iter_mut().for_each(|elem| *elem *= 2.0);
        assert_eq!(buffer, vec![2.0, 4.0, 6.0, 8.0]);
    }

    /// `to_vec` copies the viewed elements into an owned vector.
    #[test]
    fn test_tensor_view_to_vec() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0];
        let view = TensorView::new(&buffer, Shape::new(vec![4]), vec![1], 0);
        assert_eq!(view.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// A non-zero offset shifts which buffer elements the view exposes.
    #[test]
    fn test_tensor_view_with_offset() {
        let buffer = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let view = TensorView::new(&buffer, Shape::new(vec![2]), vec![1], 2);
        assert_eq!(view.len(), 2);
        for (position, expected) in [3.0, 4.0].iter().enumerate() {
            assert_eq!(*view.get(position).expect("get should succeed"), *expected);
        }
    }
}