use crate::error::TensorError;
/// A dense, owned n-dimensional array of `f32` values.
///
/// Elements live contiguously in `data`. `strides[i]` is the number of positions in
/// `data` to advance for one step along axis `i`. The constructors in this module
/// derive `strides` from `shape` via `compute_strides`, which yields row-major order
/// (the last axis varies fastest).
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Size of each axis; the number of axes is `shape.len()`.
    shape: Vec<usize>,
    /// Per-axis element strides into `data`, computed from `shape`.
    strides: Vec<usize>,
    /// Flat element storage; the constructors keep its length equal to the product of `shape`.
    data: Vec<f32>,
}
impl Tensor {
    /// Creates a tensor of the given `shape` backed by `data`.
    ///
    /// `data` is interpreted in row-major order (the last axis varies fastest), matching
    /// the strides produced by `compute_strides`.
    ///
    /// # Errors
    ///
    /// Returns [`TensorError::DataLengthMismatch`] when `data.len()` is not equal to the
    /// product of the dimensions in `shape`.
    pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
        let product: usize = shape.iter().product();
        if data.len() != product {
            // `shape` is not needed again on this path, so move it into the error
            // instead of cloning it.
            return Err(TensorError::DataLengthMismatch {
                len: data.len(),
                shape,
                product,
            });
        }
        let strides = compute_strides(&shape);
        Ok(Self {
            shape,
            strides,
            data,
        })
    }

    /// Creates a tensor of the given `shape` with every element set to `0.0`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let product: usize = shape.iter().product();
        let strides = compute_strides(&shape);
        Self {
            shape,
            strides,
            data: vec![0.0; product],
        }
    }

    /// Returns the size of each axis.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of axes (0 for a scalar tensor).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the total number of stored elements.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` if the tensor stores no elements (some axis has size 0).
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the flat, row-major element buffer.
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Returns the flat, row-major element buffer for in-place modification.
    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Returns the element at the multi-dimensional position `indices`.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from [`ndim`](Self::ndim), or if any index is
    /// not less than the size of its axis.
    pub fn get(&self, indices: &[usize]) -> f32 {
        let offset = self.offset(indices);
        self.data[offset]
    }

    /// Overwrites the element at the multi-dimensional position `indices` with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `indices.len()` differs from [`ndim`](Self::ndim), or if any index is
    /// not less than the size of its axis.
    pub fn set(&mut self, indices: &[usize], value: f32) {
        let offset = self.offset(indices);
        self.data[offset] = value;
    }

    /// Converts a multi-dimensional index into a flat offset into `data`.
    ///
    /// Each axis is bounds-checked individually. Checking only the final flat offset
    /// would let an out-of-range index on one axis silently alias an element on
    /// another, e.g. `[0, 3]` on a `[2, 3]` tensor would read `[1, 0]`.
    fn offset(&self, indices: &[usize]) -> usize {
        assert_eq!(
            indices.len(),
            self.ndim(),
            "expected {} indices, got {}",
            self.ndim(),
            indices.len()
        );
        indices
            .iter()
            .zip(&self.shape)
            .zip(&self.strides)
            .enumerate()
            .map(|(axis, ((&index, &size), &stride))| {
                assert!(
                    index < size,
                    "index {} is out of bounds for axis {} with size {}",
                    index,
                    axis,
                    size
                );
                index * stride
            })
            .sum()
    }

    /// Returns a copy of this tensor's elements with the shape `new_shape`.
    ///
    /// Storage is contiguous and row-major, so the flat element order is unchanged;
    /// only the shape and strides are recomputed.
    ///
    /// # Errors
    ///
    /// Returns [`TensorError::ShapeMismatch`] when the product of `new_shape` differs
    /// from the current element count. `expected` holds `new_shape` and `got` holds
    /// the current shape.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
        let new_product: usize = new_shape.iter().product();
        if new_product != self.data.len() {
            return Err(TensorError::ShapeMismatch {
                expected: new_shape,
                got: self.shape.clone(),
            });
        }
        // The length was just validated, so build directly instead of re-checking in `new`.
        let strides = compute_strides(&new_shape);
        Ok(Self {
            shape: new_shape,
            strides,
            data: self.data.clone(),
        })
    }

    /// Returns a new tensor whose axes are reordered by `perm`.
    ///
    /// Axis `i` of the result is axis `perm[i]` of `self`, so for a 2-D tensor
    /// `transpose(&[1, 0])` is the matrix transpose. The elements are copied into a
    /// fresh row-major buffer.
    ///
    /// # Panics
    ///
    /// Panics if `perm` is not a permutation of `0..self.ndim()`, i.e. it has the
    /// wrong length, names an out-of-range axis, or repeats an axis. A repeated
    /// axis would otherwise make several source elements write to the same
    /// destination slot and silently corrupt the result.
    pub fn transpose(&self, perm: &[usize]) -> Self {
        let ndim = self.ndim();
        assert_eq!(
            perm.len(),
            ndim,
            "transpose: perm has {} axes but tensor has {}",
            perm.len(),
            ndim
        );
        let mut seen = vec![false; ndim];
        for &p in perm {
            assert!(
                p < ndim && !seen[p],
                "transpose: {:?} is not a permutation of 0..{}",
                perm,
                ndim
            );
            seen[p] = true;
        }

        let new_shape: Vec<usize> = perm.iter().map(|&p| self.shape[p]).collect();
        let new_strides = compute_strides(&new_shape);
        // A permutation preserves the element count, so the buffer size is unchanged.
        let mut new_data = vec![0.0f32; self.data.len()];
        let mut old_indices = vec![0usize; ndim];
        for (flat, &value) in self.data.iter().enumerate() {
            // Decompose the flat row-major offset into per-axis indices. A stride can
            // only be 0 when some later axis has size 0, in which case `data` is empty
            // and this loop never runs, so the division cannot be by zero.
            let mut rem = flat;
            for (index, &stride) in old_indices.iter_mut().zip(&self.strides) {
                *index = rem / stride;
                rem %= stride;
            }
            let new_offset: usize = perm
                .iter()
                .zip(&new_strides)
                .map(|(&old_axis, &stride)| old_indices[old_axis] * stride)
                .sum();
            new_data[new_offset] = value;
        }
        Self {
            shape: new_shape,
            strides: new_strides,
            data: new_data,
        }
    }
}
/// Computes row-major (C-contiguous) strides for `shape`.
///
/// The last axis gets stride 1, and every earlier axis gets the product of the sizes
/// of all axes after it. An empty `shape` yields an empty stride vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk from the innermost axis outwards; each axis's stride is its inner
    // neighbour's stride times that neighbour's size.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}