redstone-ml 0.0.0

High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
Documentation
use redstone_ml::*;

/// Broadcasting a length-3 array to `[3, 3]` yields three copies of the row.
#[test]
fn test_broadcast() {
    let source = NdArray::new([1, 2, 3]);
    let expanded = source.broadcast_to(&[3, 3]);
    let expected = NdArray::new([[1, 2, 3], [1, 2, 3], [1, 2, 3]]);

    assert_eq!(expanded.shape(), [3, 3]);
    assert_eq!(expanded, expected);
}

/// Broadcasting a single-element array to `[2, 3]` fills every position with
/// that element.
#[test]
fn test_broadcast_scalar_to_higher_dims() {
    let single = NdArray::new([42]);
    let filled = single.broadcast_to(&[2, 3]);
    let expected = NdArray::new([[42, 42, 42], [42, 42, 42]]);

    assert_eq!(filled.shape(), [2, 3]);
    assert_eq!(filled, expected);
}

/// Broadcasting a 2x2 matrix to `[3, 2, 2]` stacks three copies of the matrix
/// along a new leading axis.
#[test]
fn test_broadcast_matrix_to_higher_dims() {
    let matrix = NdArray::new([[1, 2], [3, 4]]);
    let stacked = matrix.broadcast_to(&[3, 2, 2]);
    let expected = NdArray::new([
        [[1, 2], [3, 4]],
        [[1, 2], [3, 4]],
        [[1, 2], [3, 4]],
    ]);

    assert_eq!(stacked.shape(), [3, 2, 2]);
    assert_eq!(stacked, expected);
}

/// Expects a panic when a length-3 array is broadcast to `[3, 5]`.
#[test]
#[should_panic]
fn test_broadcast_incompatible_shapes1() {
    let source = NdArray::new([1, 2, 3]);
    // The result is irrelevant; the call itself is expected to panic.
    let _ = source.broadcast_to(&[3, 5]);
}

/// Expects a panic when a `[3, 3]` array of ones is broadcast to `[1, 3]`.
#[test]
#[should_panic]
fn test_broadcast_incompatible_shapes2() {
    let ones: NdArray<f32> = NdArray::ones([3, 3]);
    // The result is irrelevant; the call itself is expected to panic.
    let _ = ones.broadcast_to(&[1, 3]);
}

/// Expects a panic when a `[1, 1, 3]` array is broadcast to the
/// lower-rank shape `[1, 3]`.
#[test]
#[should_panic]
fn test_broadcast_incompatible_shapes3() {
    let ones: NdArray<bool> = NdArray::ones([1, 1, 3]);
    // The result is irrelevant; the call itself is expected to panic.
    let _ = ones.broadcast_to(&[1, 3]);
}


/// Broadcasting a length-3 array to its own shape `[3]` leaves shape and
/// contents as they were.
#[test]
fn test_broadcast_identity() {
    let source = NdArray::new([1, 2, 3]);
    let same = source.broadcast_to(&[3]);
    let expected = NdArray::new([1, 2, 3]);

    assert_eq!(same.shape(), [3]);
    assert_eq!(same, expected);
}

/// Broadcasting a length-3 array to `[1, 3]` adds a leading axis of size one
/// while keeping the elements in order.
#[test]
fn test_broadcast_unchanged() {
    let source = NdArray::new([1, 2, 3]);
    let row = source.broadcast_to(&[1, 3]);
    let expected = NdArray::new([[1, 2, 3]]);

    assert_eq!(row.shape(), [1, 3]);
    assert_eq!(row, expected);
}

/// Broadcasting a `[3, 1]` column to `[3, 3, 3]` repeats each row value across
/// the last axis and repeats the resulting 3x3 block along the first axis.
#[test]
fn test_broadcast_high_dims() {
    let column = NdArray::new([[1], [2], [3]]);
    let cube = column.broadcast_to(&[3, 3, 3]);
    let expected = NdArray::new([
        [[1, 1, 1], [2, 2, 2], [3, 3, 3]],
        [[1, 1, 1], [2, 2, 2], [3, 3, 3]],
        [[1, 1, 1], [2, 2, 2], [3, 3, 3]],
    ]);

    assert_eq!(cube.shape(), [3, 3, 3]);
    assert_eq!(cube, expected);
}

/// Broadcasting a length-3 array to `[3, 1, 3]` produces three single-row
/// blocks, each holding the original elements.
#[test]
fn test_broadcast_single_dimensional_expansion() {
    let source = NdArray::new([1, 2, 3]);
    let expanded = source.broadcast_to(&[3, 1, 3]);
    let expected = NdArray::new([[[1, 2, 3]], [[1, 2, 3]], [[1, 2, 3]]]);

    assert_eq!(expanded.shape(), [3, 1, 3]);
    assert_eq!(expanded, expected);
}