use crate::error::RusTorchError;
use crate::tensor::memory::optimization::MemoryOptimization;
use crate::tensor::operations::zero_copy::{TensorIterOps, ZeroCopyOps};
use crate::tensor::Tensor;
#[cfg(test)]
mod tests {
    //! Unit tests for the fallible constructors, device helpers, memory
    //! optimization and zero-copy / in-place operations exposed on `Tensor`.

    use super::*;

    /// Asserts that every element yielded by `values` equals `expected`.
    ///
    /// Shared by the tests that check a whole tensor was filled or updated
    /// uniformly; `assert_eq!` reports the first offending value.
    fn assert_all_equal<'a>(values: impl IntoIterator<Item = &'a f32>, expected: f32) {
        for &value in values {
            assert_eq!(value, expected);
        }
    }

    // ---- Fallible construction ------------------------------------------------

    /// A buffer whose length matches the shape's element count is accepted.
    #[test]
    fn test_try_from_vec_success() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        let tensor =
            Tensor::try_from_vec(data, shape).expect("4 elements must fit a 2x2 shape");
        assert_eq!(tensor.shape(), &[2, 2]);
    }

    /// A buffer shorter than the shape's element count yields `ShapeMismatch`.
    #[test]
    fn test_try_from_vec_shape_mismatch() {
        // 3 elements cannot fill a 2x2 (= 4 element) shape.
        let data = vec![1.0f32, 2.0, 3.0];
        let shape = vec![2, 2];
        let err = Tensor::try_from_vec(data, shape)
            .expect_err("a mismatched element count must be rejected");
        match err {
            RusTorchError::ShapeMismatch { expected, actual } => {
                // The error carries element counts, not the requested shape.
                assert_eq!(expected, vec![4]);
                assert_eq!(actual, vec![3]);
            }
            other => panic!("Expected ShapeMismatch error, got {:?}", other),
        }
    }

    /// An empty shape is rejected with a `TensorOp` error.
    #[test]
    fn test_try_from_vec_empty_shape() {
        let data = vec![1.0f32];
        let shape = vec![];
        let err = Tensor::try_from_vec(data, shape).expect_err("an empty shape must be rejected");
        match err {
            RusTorchError::TensorOp { message, .. } => {
                assert!(message.contains("Shape cannot be empty"));
            }
            other => panic!("Expected TensorOp error, got {:?}", other),
        }
    }

    /// A shape containing a zero-sized dimension is rejected, even when the
    /// (empty) buffer would technically match the element count.
    #[test]
    fn test_try_from_vec_zero_dimension() {
        let data: Vec<f32> = vec![];
        let shape = vec![0, 2];
        let err = Tensor::try_from_vec(data, shape)
            .expect_err("a zero dimension must be rejected");
        match err {
            RusTorchError::TensorOp { message, .. } => {
                assert!(message.contains("zero dimension"));
            }
            other => panic!("Expected TensorOp error, got {:?}", other),
        }
    }

    // ---- Views ----------------------------------------------------------------

    /// Viewing a 2x2 tensor as a flat 4-element tensor succeeds.
    #[test]
    fn test_try_view_success() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
        let reshaped = tensor
            .try_view(&[4])
            .expect("a view with the same element count must succeed");
        assert_eq!(reshaped.shape(), &[4]);
    }

    /// A view whose element count differs from the source yields `ShapeMismatch`.
    #[test]
    fn test_try_view_shape_mismatch() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
        let err = tensor
            .try_view(&[3])
            .expect_err("a view with a different element count must fail");
        assert!(
            matches!(err, RusTorchError::ShapeMismatch { .. }),
            "Expected ShapeMismatch error, got {:?}",
            err
        );
    }

    // ---- Fallible filled constructors -------------------------------------------

    /// `try_zeros` produces a tensor of the requested shape filled with `0.0`.
    #[test]
    fn test_try_zeros_success() {
        let tensor = Tensor::<f32>::try_zeros(&[2, 3]).expect("2x3 zeros must be creatable");
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_all_equal(tensor.as_slice().unwrap(), 0.0);
    }

    /// `try_zeros` rejects an empty shape.
    #[test]
    fn test_try_zeros_empty_shape() {
        let result = Tensor::<f32>::try_zeros(&[]);
        assert!(result.is_err());
    }

    /// `try_zeros` rejects a shape containing a zero-sized dimension.
    #[test]
    fn test_try_zeros_zero_dimension() {
        let result = Tensor::<f32>::try_zeros(&[2, 0, 3]);
        assert!(result.is_err());
    }

    /// `try_ones` produces a tensor of the requested shape filled with `1.0`.
    #[test]
    fn test_try_ones_success() {
        let tensor = Tensor::<f32>::try_ones(&[2, 2]).expect("2x2 ones must be creatable");
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_all_equal(tensor.as_slice().unwrap(), 1.0);
    }

    /// A 2048x2048 request may succeed or be refused; when it succeeds the
    /// result must be well-formed.
    #[test]
    fn test_try_zeros_too_large() {
        // Both outcomes are deliberately accepted: an `Err` here (presumably a
        // size limit — TODO confirm) is not a test failure, so only the `Ok`
        // case is checked.
        if let Ok(tensor) = Tensor::<f32>::try_zeros(&[2048, 2048]) {
            assert_eq!(tensor.shape(), &[2048, 2048]);
            assert_eq!(tensor.numel(), 2048 * 2048);
        }
    }

    /// A moderately sized 100x100 request succeeds.
    #[test]
    fn test_try_zeros_reasonable_size() {
        let tensor =
            Tensor::<f32>::try_zeros(&[100, 100]).expect("100x100 zeros must be creatable");
        assert_eq!(tensor.shape(), &[100, 100]);
    }

    // ---- Device selection -----------------------------------------------------

    /// A small tensor created with automatic device selection stays on the CPU.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_auto_device_selection() {
        let small_tensor = Tensor::<f32>::zeros_auto(&[2, 2]);
        assert_eq!(small_tensor.device_type(), "cpu");
        assert!(!small_tensor.is_on_gpu());
    }

    /// Without a GPU, `try_to_gpu` reports a `Device` error naming the GPU.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_try_to_gpu_no_gpu_available() {
        let tensor = Tensor::<f32>::zeros(&[2, 2]);
        let err = tensor
            .try_to_gpu()
            .expect_err("moving to a GPU must fail when none is available");
        match err {
            RusTorchError::Device { device, message } => {
                assert_eq!(device, "GPU");
                assert!(message.contains("No GPU devices available"));
            }
            other => panic!("Expected Device error, got {:?}", other),
        }
    }

    /// `from_vec_auto` keeps the shape and places a small tensor on the CPU.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_from_vec_auto() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];
        let tensor = Tensor::from_vec_auto(data, shape);
        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(tensor.device_type(), "cpu");
    }

    /// `ones_auto` fills with `1.0` and places a small tensor on the CPU.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_ones_auto() {
        let tensor = Tensor::<f32>::ones_auto(&[3, 3]);
        assert_eq!(tensor.shape(), &[3, 3]);
        assert_eq!(tensor.device_type(), "cpu");
        assert_all_equal(tensor.as_slice().unwrap(), 1.0);
    }

    // ---- Memory information & optimization --------------------------------------

    /// `memory_info` reports element counts, byte sizes and placement for a
    /// freshly created CPU tensor.
    #[test]
    fn test_memory_info() {
        let tensor = Tensor::<f32>::zeros(&[100, 100]);
        let info = tensor.memory_info();
        assert_eq!(info.total_elements, 10000);
        // f32 is 4 bytes, so 10_000 elements occupy 40_000 bytes.
        assert_eq!(info.element_size, 4);
        assert_eq!(info.total_bytes, 40000);
        assert!(info.is_contiguous);
        assert!(info.alignment >= 1);
        assert!(!info.is_on_gpu);
        assert_eq!(info.device, "cpu");
    }

    /// Tiny tensors are not candidates for memory optimization.
    #[test]
    fn test_can_optimize_memory() {
        let small_tensor = Tensor::<f32>::zeros(&[2, 2]);
        let large_tensor = Tensor::<f32>::zeros(&[100, 100]);
        assert!(!small_tensor.can_optimize_memory());
        // NOTE(review): the expected answer for a 100x100 tensor is not pinned
        // down here, so this only checks the call completes — assert the
        // intended value once the optimization threshold is confirmed.
        let _ = large_tensor.can_optimize_memory();
    }

    /// `optimize_memory` preserves shape and element count.
    #[test]
    fn test_optimize_memory() {
        let tensor = Tensor::<f32>::zeros(&[100, 100]);
        let optimized = tensor.optimize_memory();
        assert_eq!(tensor.shape(), optimized.shape());
        assert_eq!(tensor.numel(), optimized.numel());
    }

    /// `try_optimize_memory` succeeds and preserves shape.
    #[test]
    fn test_try_optimize_memory() {
        let tensor = Tensor::<f32>::zeros(&[100, 100]);
        let optimized = tensor
            .try_optimize_memory()
            .expect("optimizing a 100x100 tensor must succeed");
        assert_eq!(tensor.shape(), optimized.shape());
    }

    /// Despite the name, this checks that a small 10x10 tensor is accepted by
    /// `try_optimize_memory`.
    #[test]
    fn test_try_optimize_memory_too_large() {
        let tensor = Tensor::<f32>::zeros(&[10, 10]);
        let result = tensor.try_optimize_memory();
        assert!(result.is_ok());
    }

    // ---- In-place operations ----------------------------------------------------

    /// Adding ones into zeros in place leaves every element equal to `1.0`.
    #[test]
    fn test_inplace_add() {
        let mut tensor1 = Tensor::<f32>::zeros(&[2, 2]);
        let tensor2 = Tensor::<f32>::ones(&[2, 2]);
        tensor1
            .inplace_add(&tensor2)
            .expect("same-shape in-place add must succeed");
        assert_all_equal(tensor1.as_slice().unwrap(), 1.0);
    }

    /// In-place addition of differently shaped tensors yields `ShapeMismatch`.
    #[test]
    fn test_inplace_add_shape_mismatch() {
        let mut tensor1 = Tensor::<f32>::zeros(&[2, 2]);
        let tensor2 = Tensor::<f32>::ones(&[3, 3]);
        let err = tensor1
            .inplace_add(&tensor2)
            .expect_err("in-place add of mismatched shapes must fail");
        assert!(
            matches!(err, RusTorchError::ShapeMismatch { .. }),
            "Expected ShapeMismatch error, got {:?}",
            err
        );
    }

    /// Scaling ones by `2.0` in place leaves every element equal to `2.0`.
    #[test]
    fn test_inplace_mul_scalar() {
        let mut tensor = Tensor::<f32>::ones(&[2, 2]);
        tensor.inplace_mul_scalar(2.0);
        assert_all_equal(tensor.as_slice().unwrap(), 2.0);
    }

    /// Applying `x * 2.0` in place to ones leaves every element equal to `2.0`.
    #[test]
    fn test_inplace_apply() {
        let mut tensor = Tensor::<f32>::ones(&[2, 2]);
        let result = tensor.inplace_apply(|x| x * 2.0);
        assert!(result.is_ok());
        assert_all_equal(tensor.as_slice().unwrap(), 2.0);
    }

    // ---- Zero-copy views & iteration --------------------------------------------

    /// Slicing rows `0..1` and columns `0..2` of a 2x3 tensor yields a 1x2 view.
    #[test]
    fn test_slice_view() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let sliced = tensor
            .slice_view(&[0..1, 0..2])
            .expect("in-bounds ranges must produce a view");
        assert_eq!(sliced.shape(), &[1, 2]);
    }

    /// A range extending past a dimension's size is rejected.
    #[test]
    fn test_slice_view_invalid_range() {
        let tensor = Tensor::<f32>::zeros(&[2, 2]);
        let result = tensor.slice_view(&[0..3, 0..2]);
        assert!(result.is_err());
    }

    /// Two independently created tensors do not share storage.
    #[test]
    fn test_shares_memory_with() {
        let tensor1 = Tensor::<f32>::zeros(&[2, 2]);
        let tensor2 = Tensor::<f32>::zeros(&[2, 2]);
        let tensor3 = tensor1.clone();
        assert!(!tensor1.shares_memory_with(&tensor2));
        // NOTE(review): whether `clone` shares or copies storage is not pinned
        // down here, so this only checks the call completes — assert the
        // intended value once `Clone` semantics are confirmed.
        let _ = tensor1.shares_memory_with(&tensor3);
    }

    /// `detach` yields a tensor with the same shape, size and values.
    #[test]
    fn test_detach() {
        let tensor1 = Tensor::<f32>::ones(&[2, 2]);
        let tensor2 = tensor1.detach();
        assert_eq!(tensor1.shape(), tensor2.shape());
        assert_eq!(tensor1.numel(), tensor2.numel());
        for (a, b) in tensor1.iter().zip(tensor2.iter()) {
            assert_eq!(a, b);
        }
    }

    /// `iter` reads every element and `iter_mut` updates every element in place.
    #[test]
    fn test_iter_and_iter_mut() {
        let mut tensor = Tensor::<f32>::ones(&[2, 2]);
        let sum: f32 = tensor.iter().sum();
        assert_eq!(sum, 4.0);

        for value in tensor.iter_mut() {
            *value *= 2.0;
        }

        let sum: f32 = tensor.iter().sum();
        assert_eq!(sum, 8.0);
    }
}