// etensor-core 0.0.1
//
// The pure Rust tensor math and autograd engine.
// Documentation
//! Integration tests for the ETensor Autograd Engine.
//! 
//! These tests verify that the thread-local Tape, the Reverse-Mode execution engine,
//! and the gradient accumulation maps all work together seamlessly across multiple
//! layers of mathematical history using the native Dispatcher.

use etensor_core::tensor::Tensor;
use etensor_core::buffer::Buffer;
use etensor_core::shape::Shape;
use etensor_core::device::Device;
use etensor_core::dtypes::DType;
use etensor_core::autograd::engine::backward;
use etensor_core::autograd::tape;
use etensor_core::dispatch::Dispatcher;

// =====================================================================
// TEST HELPERS
// =====================================================================

/// Builds a single-element `F32` tensor on the CPU that holds `val`.
///
/// The tensor is created with shape `[1]`, and `requires_grad` is passed
/// through to `Tensor::new` unchanged.
fn make_scalar(val: f32, requires_grad: bool) -> Tensor {
    let storage = Buffer::from_f32_vec(vec![val]);
    let shape = Shape::new(vec![1]);
    Tensor::new(storage, shape, Device::Cpu, DType::F32, requires_grad)
}

/// Builds an `F32` tensor on the CPU from `data`, shaped according to `dims`.
///
/// This helper performs no checks of its own: `data` and `dims` are handed
/// straight to `Buffer::from_f32_vec` and `Shape::new`, and `requires_grad`
/// is passed through to `Tensor::new` unchanged.
fn make_matrix(data: Vec<f32>, dims: Vec<usize>, requires_grad: bool) -> Tensor {
    let storage = Buffer::from_f32_vec(data);
    let shape = Shape::new(dims);
    Tensor::new(storage, shape, Device::Cpu, DType::F32, requires_grad)
}

// =====================================================================
// INTEGRATION TESTS
// =====================================================================

/// Differentiates through two chained operations: `f(x, y, z) = (x + y) * z`.
///
/// With `x = 2`, `y = 3`, `z = 4` the expected partial derivatives are
/// `df/dz = x + y = 5` and `df/dx = df/dy = z = 4`.
#[test]
fn test_nested_autograd_history() {
    // Take (and drop) whatever is on the tape so the test starts clean.
    let _ = tape::take();

    // Leaf inputs; every one of them takes part in differentiation.
    let x = make_scalar(2.0, true).with_name("x");
    let y = make_scalar(3.0, true).with_name("y");
    let z = make_scalar(4.0, true).with_name("z");

    // Forward pass, built with the overloaded operators on tensor references.
    let sum_xy = &x + &y;
    let product = &sum_xy * &z;

    // Backward pass starting from the final result.
    let gradients = backward(&product).expect("Backward pass failed!");

    // df/dz = x + y
    let z_grad = gradients.get(&z.id).unwrap();
    let z_values = z_grad.as_f32_slice().unwrap();
    assert_eq!(z_values[0], 5.0, "Gradient for z is incorrect");

    // df/dx = z
    let x_grad = gradients.get(&x.id).unwrap();
    let x_values = x_grad.as_f32_slice().unwrap();
    assert_eq!(x_values[0], 4.0, "Gradient for x is incorrect");

    // df/dy = z
    let y_grad = gradients.get(&y.id).unwrap();
    let y_values = y_grad.as_f32_slice().unwrap();
    assert_eq!(y_values[0], 4.0, "Gradient for y is incorrect");
}

/// Checks that a tensor used as both operands of one operation receives the
/// sum of both contributions.
///
/// For `w = x * x` the derivative is `2x`, so `x = 3` must yield a gradient
/// of `3 + 3 = 6`.
#[test]
fn test_gradient_accumulation_chain_rule() {
    // Take (and drop) whatever is on the tape so the test starts clean.
    let _ = tape::take();

    let x = make_scalar(3.0, true);

    // `x` appears on both sides of the product.
    let squared = &x * &x;

    let gradients = backward(&squared).unwrap();

    // Both contributions must have been added together: 3.0 + 3.0 = 6.0.
    let x_grad = gradients.get(&x.id).unwrap();
    let x_values = x_grad.as_f32_slice().unwrap();
    assert_eq!(x_values[0], 6.0, "Gradient accumulation failed on identical variables!");
}

/// Runs a backward pass through a tiny dense layer: `loss = sum(relu(X @ W))`.
///
/// `X` is a fixed 1x2 input created with `requires_grad = false`; `W` is a
/// 2x2 weight matrix created with `requires_grad = true`. The test checks
/// `dW` against the hand-derived value and checks that the gradient map has
/// no entry for `X`.
#[test]
fn test_mini_neural_network_backward() {
    // Take (and drop) whatever is on the tape so the test starts clean.
    let _ = tape::take();

    // X = [1.0, 2.0], shape 1x2, no gradient requested.
    let x = make_matrix(vec![1.0, 2.0], vec![1, 2], false);

    // W, shape 2x2, gradient requested:
    //   [0.5, -0.5]
    //   [1.0,  2.0]
    let w = make_matrix(vec![0.5, -0.5, 1.0, 2.0], vec![2, 2], true);

    // Forward pass.
    // X @ W = [1*0.5 + 2*1.0, 1*(-0.5) + 2*2.0] = [2.5, 3.5]
    let hidden = Dispatcher::matmul(&x, &w).unwrap();
    // Both entries are positive, so ReLU leaves them as [2.5, 3.5].
    let activated = Dispatcher::relu(&hidden).unwrap();
    // 2.5 + 3.5 = 6.0
    let loss = Dispatcher::sum_all(&activated).unwrap();

    // Backward pass.
    let gradients = backward(&loss).unwrap();

    // Hand derivation of dW:
    //   d_loss      = 1.0
    //   d_activated = [1.0, 1.0]   (sum passes 1.0 to every element)
    //   d_hidden    = [1.0, 1.0]   (ReLU derivative is 1.0 for positive inputs)
    //   dW          = X^T @ d_hidden
    //     [1.0] @ [1.0, 1.0] = [1.0, 1.0]
    //     [2.0]                [2.0, 2.0]
    let w_grad = gradients.get(&w.id).unwrap();
    let w_values = w_grad.as_f32_slice().unwrap();
    assert_eq!(w_values, &[1.0, 1.0, 2.0, 2.0], "Matrix Chain Rule failed!");

    // X was created with requires_grad = false, so it must have no gradient entry.
    assert!(gradients.get(&x.id).is_none(), "Computed gradients for an input that didn't request them!");
}