use crate::error::{CnnError, CnnResult};
/// Dense, contiguous `f32` tensor stored in row-major order.
///
/// The 4-D accessors interpret the axes as `(batch, height, width, channels)`.
/// Two tensors compare equal when their data, shape and strides all match.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    /// Flat element storage; its length equals the product of `shape`.
    data: Vec<f32>,
    /// Extent of each axis.
    shape: Vec<usize>,
    /// Row-major element stride of each axis (last axis has stride 1).
    strides: Vec<usize>,
}
impl Tensor {
    /// Creates a tensor of the given `shape` with every element set to `0.0`.
    ///
    /// The element count is the product of `shape`, so an empty `shape`
    /// yields a rank-0 tensor holding exactly one element.
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, 0.0)
    }

    /// Creates a tensor of the given `shape` with every element set to `1.0`.
    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, 1.0)
    }

    /// Wraps `data` as a contiguous row-major tensor of the given `shape`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error when `data.len()` differs from the
    /// product of `shape`.
    pub fn from_data(data: Vec<f32>, shape: &[usize]) -> CnnResult<Self> {
        let expected_numel: usize = shape.iter().product();
        if data.len() != expected_numel {
            return Err(CnnError::invalid_shape(
                format!("data length {}", expected_numel),
                format!("data length {}", data.len()),
            ));
        }
        Ok(Self::from_vec_with_shape(data, shape))
    }

    /// Creates a tensor of the given `shape` with every element set to `value`.
    pub fn full(shape: &[usize], value: f32) -> Self {
        let numel: usize = shape.iter().product();
        Self::from_vec_with_shape(vec![value; numel], shape)
    }

    /// Assembles a tensor from parts the caller has already validated.
    ///
    /// The caller must guarantee that `data.len()` equals the product of
    /// `shape`. This is only checked in debug builds.
    fn from_vec_with_shape(data: Vec<f32>, shape: &[usize]) -> Self {
        debug_assert_eq!(data.len(), shape.iter().product::<usize>());
        Self {
            data,
            strides: Self::compute_strides(shape),
            shape: shape.to_vec(),
        }
    }

    /// Computes contiguous row-major strides.
    ///
    /// The last axis has stride 1, and each earlier axis steps over the full
    /// extent of the axes after it.
    fn compute_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        // `saturating_sub` keeps rank-0 shapes from underflowing the range.
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    /// Returns the extent of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the row-major element stride of each axis.
    #[inline]
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Returns the number of axes.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the total number of stored elements.
    #[inline]
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Returns the flat element storage.
    #[inline]
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Returns the flat element storage mutably.
    #[inline]
    pub fn data_mut(&mut self) -> &mut [f32] {
        &mut self.data
    }

    /// Reads the element at `(n, h, w, c)` of a 4-D tensor.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the tensor is not 4-D or if any index is
    /// outside its axis. In all builds, panics if the computed flat offset is
    /// past the end of the data.
    #[inline]
    pub fn get_4d(&self, n: usize, h: usize, w: usize, c: usize) -> f32 {
        self.data[self.flat_index_4d(n, h, w, c)]
    }

    /// Writes `value` at `(n, h, w, c)` of a 4-D tensor.
    ///
    /// # Panics
    ///
    /// The same conditions as [`Tensor::get_4d`] apply.
    #[inline]
    pub fn set_4d(&mut self, n: usize, h: usize, w: usize, c: usize, value: f32) {
        let idx = self.flat_index_4d(n, h, w, c);
        self.data[idx] = value;
    }

    /// Converts a 4-D index into a flat offset into `data`.
    ///
    /// Without the per-axis check, an out-of-range `h`, `w` or `c` would
    /// silently alias a different, in-bounds element.
    #[inline]
    fn flat_index_4d(&self, n: usize, h: usize, w: usize, c: usize) -> usize {
        debug_assert!(self.shape.len() == 4);
        debug_assert!(
            n < self.shape[0] && h < self.shape[1] && w < self.shape[2] && c < self.shape[3],
            "4-D index ({}, {}, {}, {}) out of bounds for shape {:?}",
            n,
            h,
            w,
            c,
            self.shape
        );
        // Storage is contiguous, so the last axis always has stride 1.
        n * self.strides[0] + h * self.strides[1] + w * self.strides[2] + c
    }

    /// Returns the extent of axis 0, or `0` for a rank-0 tensor.
    #[inline]
    pub fn batch_size(&self) -> usize {
        self.shape.first().copied().unwrap_or(0)
    }

    /// Returns the extent of axis 1, or `1` when the tensor has fewer than 2 axes.
    #[inline]
    pub fn height(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(1)
    }

    /// Returns the extent of axis 2, or `1` when the tensor has fewer than 3 axes.
    #[inline]
    pub fn width(&self) -> usize {
        self.shape.get(2).copied().unwrap_or(1)
    }

    /// Returns the extent of axis 3, or `1` when the tensor has fewer than 4 axes.
    #[inline]
    pub fn channels(&self) -> usize {
        self.shape.get(3).copied().unwrap_or(1)
    }

    /// Returns a copy of this tensor with `new_shape`.
    ///
    /// The data is cloned, not shared.
    ///
    /// # Errors
    ///
    /// Returns an invalid-shape error when the product of `new_shape` differs
    /// from [`Tensor::numel`].
    pub fn reshape(&self, new_shape: &[usize]) -> CnnResult<Self> {
        let new_numel: usize = new_shape.iter().product();
        if new_numel != self.numel() {
            return Err(CnnError::invalid_shape(
                format!("numel {}", self.numel()),
                format!("numel {}", new_numel),
            ));
        }
        Ok(Self::from_vec_with_shape(self.data.clone(), new_shape))
    }

    /// Same as [`Tensor::reshape`].
    ///
    /// Despite the name, this copies the data rather than returning a view
    /// that shares storage.
    pub fn view(&self, new_shape: &[usize]) -> CnnResult<Self> {
        self.reshape(new_shape)
    }

    /// Copies entries `start..end` along axis 0 into a new tensor.
    ///
    /// # Errors
    ///
    /// - Returns an invalid-shape error for a rank-0 tensor.
    /// - Returns `IndexOutOfBounds` (reporting `end`) when `start >= end` or
    ///   when `end` exceeds the extent of axis 0.
    pub fn slice_batch(&self, start: usize, end: usize) -> CnnResult<Self> {
        if self.shape.is_empty() {
            return Err(CnnError::invalid_shape("non-empty tensor", "empty tensor"));
        }
        if start >= end || end > self.shape[0] {
            return Err(CnnError::IndexOutOfBounds {
                index: end,
                size: self.shape[0],
            });
        }
        let batch_stride = self.strides[0];
        let mut new_shape = self.shape.clone();
        new_shape[0] = end - start;
        let data = self.data[start * batch_stride..end * batch_stride].to_vec();
        Ok(Self::from_vec_with_shape(data, &new_shape))
    }

    /// Returns a new tensor with `f` applied to every element.
    ///
    /// `f` is called once per element, in storage order. Shape and strides
    /// are preserved.
    pub fn map<F>(&self, mut f: F) -> Self
    where
        F: FnMut(f32) -> f32,
    {
        Self {
            data: self.data.iter().map(|&x| f(x)).collect(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
        }
    }

    /// Replaces every element `x` with `f(x)`, in storage order.
    pub fn map_inplace<F>(&mut self, mut f: F)
    where
        F: FnMut(f32) -> f32,
    {
        self.data.iter_mut().for_each(|x| *x = f(*x));
    }

    /// Returns the element-wise sum of two tensors.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error when the shapes differ.
    pub fn add(&self, other: &Self) -> CnnResult<Self> {
        self.elementwise_binary(other, "add", |a, b| a + b)
    }

    /// Returns the element-wise product of two tensors.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error when the shapes differ.
    pub fn mul(&self, other: &Self) -> CnnResult<Self> {
        self.elementwise_binary(other, "mul", |a, b| a * b)
    }

    /// Combines two tensors of identical shape element by element.
    ///
    /// The `op` name is used only to build the error message.
    fn elementwise_binary<F>(&self, other: &Self, op: &str, f: F) -> CnnResult<Self>
    where
        F: Fn(f32, f32) -> f32,
    {
        if self.shape != other.shape {
            return Err(CnnError::shape_mismatch(format!(
                "{}: {:?} vs {:?}",
                op, self.shape, other.shape
            )));
        }
        // Equal shapes imply equal data lengths, so the shape and strides can
        // be reused without re-validating.
        let data = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(&a, &b)| f(a, b))
            .collect();
        Ok(Self {
            data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
        })
    }

    /// Returns a new tensor with every element multiplied by `scalar`.
    pub fn scale(&self, scalar: f32) -> Self {
        self.map(|x| x * scalar)
    }

    /// Returns the sum of all elements (`0.0` for an empty tensor).
    pub fn sum(&self) -> f32 {
        self.data.iter().sum()
    }

    /// Returns the arithmetic mean of all elements.
    ///
    /// An empty tensor yields NaN, since that computes `0.0 / 0.0`.
    pub fn mean(&self) -> f32 {
        self.sum() / self.numel() as f32
    }

    /// Returns the largest element.
    ///
    /// An empty tensor yields `f32::NEG_INFINITY`. NaN elements are skipped,
    /// following `f32::max` semantics.
    pub fn max(&self) -> f32 {
        self.data.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Returns the smallest element.
    ///
    /// An empty tensor yields `f32::INFINITY`. NaN elements are skipped,
    /// following `f32::min` semantics.
    pub fn min(&self) -> f32 {
        self.data.iter().copied().fold(f32::INFINITY, f32::min)
    }
}
impl Default for Tensor {
fn default() -> Self {
Self::zeros(&[])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!(t.shape(), &[2, 3, 4, 5]);
        assert_eq!(t.numel(), 2 * 3 * 4 * 5);
        assert!(t.data().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_tensor_ones() {
        let t = Tensor::ones(&[2, 2, 2, 2]);
        assert!(t.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_tensor_strides() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!(t.strides(), &[60, 20, 5, 1]);
    }

    #[test]
    fn test_tensor_get_set_4d() {
        let mut t = Tensor::zeros(&[2, 3, 4, 5]);
        t.set_4d(1, 2, 3, 4, 42.0);
        assert_eq!(t.get_4d(1, 2, 3, 4), 42.0);
    }

    #[test]
    fn test_tensor_reshape() {
        let t = Tensor::ones(&[2, 3, 4, 5]);
        let reshaped = t.reshape(&[6, 4, 5]).unwrap();
        assert_eq!(reshaped.shape(), &[6, 4, 5]);
        assert_eq!(reshaped.numel(), t.numel());
    }

    #[test]
    fn test_tensor_reshape_rejects_numel_change() {
        let t = Tensor::ones(&[2, 3]);
        assert!(t.reshape(&[4, 2]).is_err());
    }

    #[test]
    fn test_tensor_from_data() {
        assert!(Tensor::from_data(vec![1.0; 3], &[2, 2]).is_err());
        let t = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.strides(), &[2, 1]);
        assert_eq!(t.data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_tensor_slice_batch() {
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let t = Tensor::from_data(data, &[3, 2, 2]).unwrap();
        let s = t.slice_batch(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2, 2]);
        assert_eq!(s.data(), &t.data()[4..12]);
        assert!(matches!(
            t.slice_batch(0, 4),
            Err(CnnError::IndexOutOfBounds { index: 4, size: 3 })
        ));
        assert!(t.slice_batch(2, 2).is_err());
    }

    #[test]
    fn test_tensor_map() {
        let t = Tensor::full(&[2, 2], 2.0);
        let squared = t.map(|x| x * x);
        assert!(squared.data().iter().all(|&x| x == 4.0));
    }

    #[test]
    fn test_tensor_map_inplace() {
        let mut t = Tensor::full(&[3], 2.0);
        t.map_inplace(|x| x + 1.0);
        assert!(t.data().iter().all(|&x| x == 3.0));
    }

    #[test]
    fn test_tensor_add() {
        let a = Tensor::ones(&[2, 2]);
        let b = Tensor::ones(&[2, 2]);
        let c = a.add(&b).unwrap();
        assert!(c.data().iter().all(|&x| x == 2.0));
    }

    #[test]
    fn test_tensor_mul_and_scale() {
        let a = Tensor::full(&[2, 2], 3.0);
        let b = Tensor::full(&[2, 2], 2.0);
        assert!(a.mul(&b).unwrap().data().iter().all(|&x| x == 6.0));
        assert!(a.scale(0.5).data().iter().all(|&x| x == 1.5));
    }

    #[test]
    fn test_tensor_elementwise_shape_mismatch() {
        let a = Tensor::ones(&[2, 2]);
        let b = Tensor::ones(&[4]);
        assert!(a.add(&b).is_err());
        assert!(a.mul(&b).is_err());
    }

    #[test]
    fn test_tensor_sum_mean() {
        let t = Tensor::ones(&[2, 3]);
        assert_eq!(t.sum(), 6.0);
        assert_eq!(t.mean(), 1.0);
    }

    #[test]
    fn test_tensor_min_max() {
        let t = Tensor::from_data(vec![3.0, -1.0, 7.5, 0.0], &[4]).unwrap();
        assert_eq!(t.max(), 7.5);
        assert_eq!(t.min(), -1.0);
    }

    #[test]
    fn test_tensor_dim_accessors() {
        let t = Tensor::zeros(&[2, 3, 4, 5]);
        assert_eq!((t.batch_size(), t.height(), t.width(), t.channels()), (2, 3, 4, 5));
        let v = Tensor::zeros(&[7]);
        assert_eq!((v.batch_size(), v.height(), v.width(), v.channels()), (7, 1, 1, 1));
    }

    #[test]
    fn test_tensor_default_is_single_element() {
        let t = Tensor::default();
        assert!(t.shape().is_empty());
        assert_eq!(t.numel(), 1);
        assert_eq!(t.data(), &[0.0]);
        assert_eq!(t.batch_size(), 0);
    }
}