use etensor_core::shape::Shape;
use etensor_core::dtypes::DType;
use etensor_core::errors::EtensorError;
use etensor_core::dispatch::Dispatcher;
use etensor_core::backends::cpu::alloc::CpuAllocator;
/// A freshly constructed `Shape` reports its rank, its total element count
/// (the product of its dims), and a contiguous layout.
#[test]
fn test_public_shape_instantiation() {
    let dims = vec![4, 3, 224, 224];
    let expected_elements = 4 * 3 * 224 * 224;
    let shape = Shape::new(dims);
    assert!(shape.is_contiguous());
    assert_eq!(shape.rank(), 4);
    assert_eq!(shape.num_elements(), expected_elements);
}
/// Transposing a 2-D tensor swaps its dims, yields a non-contiguous shape,
/// and gives the result an `id` distinct from the source tensor's.
#[test]
fn test_public_zero_copy_transpose_api() {
    let source_shape = Shape::new(vec![128, 64]);
    let source = CpuAllocator::zeros(source_shape, DType::F32, false).unwrap();
    let flipped = source.transpose();
    assert_eq!(flipped.shape.dims, vec![64, 128]);
    assert!(!flipped.shape.is_contiguous());
    assert_ne!(source.id, flipped.id);
}
/// Adding tensors with incompatible shapes ([2, 3] + [3, 2]) must be rejected
/// with `EtensorError::ShapeMismatch` carrying the left operand's shape as
/// `expected` and the right operand's shape as `got`.
#[test]
fn test_engine_shape_mismatch_rejection() {
    let a = CpuAllocator::zeros(Shape::new(vec![2, 3]), DType::F32, false).unwrap();
    let b = CpuAllocator::zeros(Shape::new(vec![3, 2]), DType::F32, false).unwrap();
    // One exhaustive `match` distinguishes the three outcomes (wrong success,
    // wrong error, right error), so no separate `is_err()` pre-check is needed.
    match Dispatcher::add(&a, &b) {
        Err(EtensorError::ShapeMismatch { expected, got }) => {
            assert_eq!(expected, vec![2, 3]);
            assert_eq!(got, vec![3, 2]);
        }
        // Surface the actual error so a regression is diagnosable from the log.
        // NOTE(review): relies on `EtensorError: Debug`, which `.unwrap()` above
        // already requires if the allocator shares this error type.
        Err(other) => panic!(
            "The engine allowed an invalid shape operation or returned the wrong error type! got: {other:?}"
        ),
        Ok(_) => panic!(
            "The engine allowed an invalid shape operation or returned the wrong error type!"
        ),
    }
}
/// `CpuAllocator::zeros` followed by `with_name` yields a CPU tensor with the
/// requested dtype, gradient flag, and name.
#[test]
fn test_public_tensor_creation_api() {
    let zeroed = CpuAllocator::zeros(Shape::new(vec![2, 2]), DType::F32, true).unwrap();
    let tensor = zeroed.with_name("layer_1_weights");
    assert!(tensor.requires_grad);
    assert!(tensor.device.is_cpu());
    assert_eq!(tensor.dtype, DType::F32);
    // Checked last: unwrapping moves `name` out of the tensor.
    assert_eq!(tensor.name.unwrap(), "layer_1_weights");
}
/// Adding a CPU tensor to a CUDA tensor must be rejected with
/// `EtensorError::DeviceMismatch`, naming the left operand's device as
/// `expected` and the right operand's device as `got`.
#[cfg(feature = "cuda-native")]
#[test]
fn test_cross_device_mismatch_rejection() {
    use etensor_core::buffer::Buffer;
    // `Device` was used below but never brought into scope, so this test did
    // not compile with `cuda-native` enabled.
    // NOTE(review): module path assumed to mirror `buffer`/`tensor`; confirm
    // against the crate's actual layout.
    use etensor_core::device::Device;
    use etensor_core::tensor::Tensor;

    let t_cpu = CpuAllocator::zeros(Shape::new(vec![10, 10]), DType::F32, false).unwrap();
    let shape_gpu = Shape::new(vec![10, 10]);
    let t_gpu = Tensor::new(Buffer::CudaNative, shape_gpu, Device::CudaNative(0), DType::F32, false);

    // One exhaustive `match` replaces the redundant `is_err()` pre-check and
    // reports the actual error when the wrong variant comes back.
    match Dispatcher::add(&t_cpu, &t_gpu) {
        Err(EtensorError::DeviceMismatch { expected, got }) => {
            assert_eq!(expected, "cpu");
            assert_eq!(got, "cuda_native:0");
        }
        Err(other) => panic!("Expected DeviceMismatch error! got: {other:?}"),
        Ok(_) => panic!("Expected DeviceMismatch error!"),
    }
}