// etensor-core 0.0.1
//
// The pure Rust tensor math and autograd engine.
// (See the crate documentation for the full API reference.)
//! Integration tests for the ETensor Core API.
//! 
//! These tests consume the crate strictly from the outside, ensuring that 
//! the public-facing types, structs, and error boundaries function as intended.

use etensor_core::shape::Shape;
use etensor_core::dtypes::DType;
use etensor_core::errors::EtensorError;
use etensor_core::dispatch::Dispatcher;
use etensor_core::backends::cpu::alloc::CpuAllocator;

// =====================================================================
// SHAPE & GEOMETRY INTEGRATION TESTS
// =====================================================================

#[test]
fn test_public_shape_instantiation() {
    // Downstream crates must be able to construct a shape and query its
    // geometry. The dimensions mimic an image batch laid out as [B, C, H, W].
    let image_batch = Shape::new(vec![4, 3, 224, 224]);

    let expected_elements = 4 * 3 * 224 * 224;

    assert_eq!(image_batch.rank(), 4);
    assert_eq!(image_batch.num_elements(), expected_elements);
    assert!(image_batch.is_contiguous());
}

#[test]
fn test_public_zero_copy_transpose_api() {
    // The transpose view must be reachable from outside the crate and must
    // record its layout change in the shape metadata.
    let t = CpuAllocator::zeros(Shape::new(vec![128, 64]), DType::F32, false).unwrap();

    let transposed = t.transpose();

    // Geometry is flipped and the view no longer reports as contiguous.
    assert_eq!(transposed.shape.dims, vec![64, 128]);
    assert!(!transposed.shape.is_contiguous());

    // A transpose only rearranges elements; the element count must not change.
    assert_eq!(transposed.shape.num_elements(), t.shape.num_elements());

    // Taking the view must leave the source tensor's geometry untouched.
    assert_eq!(t.shape.dims, vec![128, 64]);
    assert!(t.shape.is_contiguous());

    // In an autograd engine the transposed view is a new node in the graph,
    // so it receives a new unique ID.
    // NOTE(review): this test does not inspect the underlying storage, so it
    // does not by itself prove the transpose is zero-copy. Add a buffer
    // identity check once the storage handle is exposed publicly.
    assert_ne!(t.id, transposed.id);
}

#[test]
fn test_engine_shape_mismatch_rejection() {
    // Two operands whose shapes are not element-wise compatible: [2, 3] vs [3, 2].
    let lhs = CpuAllocator::zeros(Shape::new(vec![2, 3]), DType::F32, false).unwrap();
    let rhs = CpuAllocator::zeros(Shape::new(vec![3, 2]), DType::F32, false).unwrap();

    // The Dispatcher is expected to reject the pair during validation.
    let outcome = Dispatcher::add(&lhs, &rhs);

    assert!(outcome.is_err());

    // The error must carry both offending shapes verbatim.
    match outcome {
        Err(EtensorError::ShapeMismatch { expected, got }) => {
            assert_eq!(expected, vec![2, 3]);
            assert_eq!(got, vec![3, 2]);
        }
        _ => panic!("The engine allowed an invalid shape operation or returned the wrong error type!"),
    }
}

// =====================================================================
// TENSOR & DEVICE INTEGRATION TESTS
// =====================================================================

#[test]
fn test_public_tensor_creation_api() {
    // Build a small gradient-tracking tensor through the allocator, then tag it
    // with a human-readable name via the builder-style `with_name`.
    let weights = CpuAllocator::zeros(Shape::new(vec![2, 2]), DType::F32, true).unwrap();
    let weights = weights.with_name("layer_1_weights");

    // Everything configured above must be observable through public fields.
    assert!(weights.device.is_cpu());
    assert_eq!(weights.dtype, DType::F32);
    assert_eq!(weights.name.unwrap(), "layer_1_weights");
    assert!(weights.requires_grad);
}

// This test only compiles and runs when the `cuda-native` feature is active.
// It verifies that cross-device errors report device names in the exact string
// form the Python bindings expect ("cpu", "cuda_native:0").
#[cfg(feature = "cuda-native")]
#[test]
fn test_cross_device_mismatch_rejection() {
    use etensor_core::buffer::Buffer;
    // `Device` was previously used below without being imported, so this test
    // failed to compile whenever the feature was enabled. It is imported at
    // function scope so non-CUDA builds do not see an unused import.
    // NOTE(review): module path inferred from the crate's layout
    // (`shape`, `dtypes`, `buffer`, `tensor`, ...); confirm it is `device`.
    use etensor_core::device::Device;
    use etensor_core::tensor::Tensor;

    // 1. Create a CPU tensor natively.
    let t_cpu = CpuAllocator::zeros(Shape::new(vec![10, 10]), DType::F32, false).unwrap();

    // 2. Create a mock GPU tensor by hand (mocking the CUDA physical memory enum).
    let shape_gpu = Shape::new(vec![10, 10]);
    let t_gpu = Tensor::new(Buffer::CudaNative, shape_gpu, Device::CudaNative(0), DType::F32, false);

    // 3. Attempt to process them together using the real Dispatcher.
    let result = Dispatcher::add(&t_cpu, &t_gpu);

    assert!(result.is_err());

    if let Err(EtensorError::DeviceMismatch { expected, got }) = result {
        assert_eq!(expected, "cpu");
        assert_eq!(got, "cuda_native:0");
    } else {
        panic!("Expected DeviceMismatch error!");
    }
}