use std::fmt;
use num_traits::Num;
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tensor<T: Num> {
data: Vec<T>,
shape: Vec<usize>,
}
impl<T: Num + Clone> Tensor<T> {
pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
if shape.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Shape cannot be empty".to_string(),
});
}
if shape.contains(&0) {
return Err(RealizarError::InvalidShape {
reason: "Shape dimensions cannot be zero".to_string(),
});
}
let expected_size = shape.iter().product();
if data.len() != expected_size {
return Err(RealizarError::DataShapeMismatch {
data_size: data.len(),
shape,
expected: expected_size,
});
}
Ok(Self { data, shape })
}
#[must_use]
pub fn shape(&self) -> &[usize] {
&self.shape
}
#[must_use]
pub fn ndim(&self) -> usize {
self.shape.len()
}
#[must_use]
pub fn size(&self) -> usize {
self.data.len()
}
#[must_use]
pub fn data(&self) -> &[T] {
&self.data
}
pub fn data_mut(&mut self) -> &mut [T] {
&mut self.data
}
}
impl<T: Num + Clone + fmt::Display> fmt::Display for Tensor<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Tensor(shape={:?}, data=[", self.shape)?;
for (i, val) in self.data.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{val}")?;
}
write!(f, "])")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_create_tensor() {
let t = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test");
assert_eq!(t.shape(), &[2, 3]);
assert_eq!(t.ndim(), 2);
assert_eq!(t.size(), 6);
}
#[test]
fn test_empty_shape_error() {
let result = Tensor::from_vec(vec![], vec![1.0, 2.0]);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
RealizarError::InvalidShape { .. }
));
}
#[test]
fn test_zero_dimension_error() {
let result = Tensor::<f32>::from_vec(vec![2, 0], vec![]);
assert!(result.is_err());
}
#[test]
fn test_size_mismatch_error() {
let result = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0]);
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
RealizarError::DataShapeMismatch { .. }
));
}
#[test]
fn test_display() {
let t = Tensor::from_vec(vec![2], vec![1.0, 2.0]).expect("test");
let display = format!("{t}");
assert!(display.contains("shape=[2]"));
assert!(display.contains('1'));
assert!(display.contains('2'));
}
#[test]
fn test_tensor_1d() {
let t = Tensor::from_vec(vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("test");
assert_eq!(t.shape(), &[5]);
assert_eq!(t.ndim(), 1);
assert_eq!(t.size(), 5);
assert_eq!(t.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_tensor_3d() {
let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
let t = Tensor::from_vec(vec![2, 3, 4], data.clone()).expect("test");
assert_eq!(t.shape(), &[2, 3, 4]);
assert_eq!(t.ndim(), 3);
assert_eq!(t.size(), 24);
assert_eq!(t.data(), data.as_slice());
}
#[test]
fn test_tensor_4d() {
let data: Vec<f64> = (0..120).map(|x| x as f64).collect();
let t = Tensor::from_vec(vec![2, 3, 4, 5], data).expect("test");
assert_eq!(t.shape(), &[2, 3, 4, 5]);
assert_eq!(t.ndim(), 4);
assert_eq!(t.size(), 120);
}
#[test]
fn test_tensor_single_element() {
let t = Tensor::from_vec(vec![1], vec![42.0]).expect("test");
assert_eq!(t.shape(), &[1]);
assert_eq!(t.ndim(), 1);
assert_eq!(t.size(), 1);
assert_eq!(t.data(), &[42.0]);
}
#[test]
fn test_tensor_with_integers() {
let t = Tensor::from_vec(vec![2, 2], vec![1, 2, 3, 4]).expect("test");
assert_eq!(t.shape(), &[2, 2]);
assert_eq!(t.size(), 4);
assert_eq!(t.data(), &[1, 2, 3, 4]);
}
#[test]
fn test_tensor_clone() {
let t1 = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test");
let t2 = t1.clone();
assert_eq!(t1.shape(), t2.shape());
assert_eq!(t1.data(), t2.data());
}
#[test]
fn test_tensor_debug() {
let t = Tensor::from_vec(vec![2], vec![1.0, 2.0]).expect("test");
let debug = format!("{:?}", t);
assert!(debug.contains("Tensor"));
assert!(debug.contains("data"));
assert!(debug.contains("shape"));
}
#[test]
fn test_display_multiple_elements() {
let t = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("test");
let display = format!("{t}");
assert!(display.contains("shape=[3]"));
assert!(display.contains("1"));
assert!(display.contains("2"));
assert!(display.contains("3"));
assert!(display.contains(", "));
}
#[test]
fn test_display_2d() {
let t = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).expect("test");
let display = format!("{t}");
assert!(display.contains("shape=[2, 2]"));
}
#[test]
fn test_zero_in_middle_dimension() {
let result = Tensor::<f32>::from_vec(vec![2, 0, 3], vec![]);
assert!(result.is_err());
match result.unwrap_err() {
RealizarError::InvalidShape { reason } => {
assert!(reason.contains("zero"));
},
_ => panic!("Expected InvalidShape error"),
}
}
#[test]
fn test_shape_with_ones() {
let t = Tensor::from_vec(vec![1, 1, 1], vec![42.0]).expect("test");
assert_eq!(t.shape(), &[1, 1, 1]);
assert_eq!(t.ndim(), 3);
assert_eq!(t.size(), 1);
}
#[test]
fn test_data_size_too_large() {
let result = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
assert!(result.is_err());
match result.unwrap_err() {
RealizarError::DataShapeMismatch {
data_size,
expected,
..
} => {
assert_eq!(data_size, 5);
assert_eq!(expected, 4);
},
_ => panic!("Expected DataShapeMismatch error"),
}
}
#[test]
fn test_data_size_too_small() {
let result = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0]);
assert!(result.is_err());
match result.unwrap_err() {
RealizarError::DataShapeMismatch {
data_size,
expected,
..
} => {
assert_eq!(data_size, 3);
assert_eq!(expected, 4);
},
_ => panic!("Expected DataShapeMismatch error"),
}
}
#[test]
fn test_shape_returns_reference() {
let t = Tensor::from_vec(vec![2, 3, 4], vec![0.0f32; 24]).expect("test");
let shape = t.shape();
assert_eq!(shape.len(), 3);
assert_eq!(shape[0], 2);
assert_eq!(shape[1], 3);
assert_eq!(shape[2], 4);
}
#[test]
fn test_serde_roundtrip() {
let original =
Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("test");
let serialized = serde_json::to_string(&original).expect("serialize");
let deserialized: Tensor<f64> = serde_json::from_str(&serialized).expect("deserialize");
assert_eq!(original.shape(), deserialized.shape());
assert_eq!(original.data(), deserialized.data());
}
#[test]
fn test_large_tensor() {
let size = 1000;
let data: Vec<f32> = (0..size).map(|x| x as f32).collect();
let t = Tensor::from_vec(vec![10, 10, 10], data).expect("test");
assert_eq!(t.size(), 1000);
assert_eq!(t.ndim(), 3);
}
}