// redstone-ml 0.0.0
//
// High-performance Machine Learning, Auto-Differentiation and Tensor Algebra crate for Rust
// Documentation
use num::NumCast;
use redstone_ml::*;
use paste::paste;


/// Must panic: `3` is passed as the first axis of an array reshaped to `[3, 4]`.
#[test]
#[should_panic]
fn test_diagonal_invalid_axis1() {
    let flat = NdArray::arange(0, 12);
    let matrix = flat.reshape([3, 4]);

    matrix.diagonal_along(3, 0);
}

/// Must panic: `3` is passed as the second axis of an array reshaped to `[3, 4]`.
#[test]
#[should_panic]
fn test_diagonal_invalid_axis2() {
    let flat = NdArray::arange(0, 12);
    let matrix = flat.reshape([3, 4]);

    matrix.diagonal_along(0, 3);
}

/// Must panic: the same axis (`0`) is passed for both diagonal axes.
#[test]
#[should_panic]
fn test_diagonal_invalid_axes() {
    let flat = NdArray::arange(0, 12);
    let matrix = flat.reshape([3, 4]);

    matrix.diagonal_along(0, 0);
}

/// Must panic: `dot` is called with a `[3, 4]` left operand and a flat right operand.
#[test]
#[should_panic]
fn test_dot_invalid_dimensions1() {
    let lhs_flat = NdArray::arange(0, 12);
    let lhs = lhs_flat.reshape([3, 4]);
    let rhs = NdArray::arange(0, 12);

    lhs.dot(rhs);
}

/// Must panic: `dot` is called with a flat left operand and a `[3, 4]` right operand.
#[test]
#[should_panic]
fn test_dot_invalid_dimensions2() {
    let lhs = NdArray::arange(0, 12);
    let rhs_flat = NdArray::arange(0, 12);
    let rhs = rhs_flat.reshape([3, 4]);

    lhs.dot(rhs);
}

/// Must panic: the two flat operands of `dot` are built from ranges of different length
/// (`0..12` versus `0..11`).
#[test]
#[should_panic]
fn test_dot_invalid_shapes() {
    let longer = NdArray::arange(0, 12);
    let shorter = NdArray::arange(0, 11);

    longer.dot(shorter);
}

/// Must panic: `matmul` is called with a `[3, 2, 2]` left operand and a flat right operand.
#[test]
#[should_panic]
fn test_matvec_matrix_not_2d() {
    let cube_flat = NdArray::arange(0, 12);
    let cube = cube_flat.reshape([3, 2, 2]);
    let vector = NdArray::arange(0, 2);

    cube.matmul(&vector);
}

/// Must panic: `matmul` is called with two operands that are both reshaped to `[2, 3]`.
#[test]
#[should_panic]
fn test_matvec_vector_not_1d() {
    let lhs_flat = NdArray::arange(0, 6);
    let lhs = lhs_flat.reshape([2, 3]);

    let rhs_flat = NdArray::arange(0, 6);
    let rhs = rhs_flat.reshape([2, 3]);

    lhs.matmul(&rhs);
}

/// Must panic: a `[2, 3]` matrix is multiplied by a flat operand built from `0..4`.
#[test]
#[should_panic]
fn test_matvec_inner_dim_mismatch() {
    let matrix_flat = NdArray::arange(0, 6);
    let matrix = matrix_flat.reshape([2, 3]);
    let vector = NdArray::arange(0, 4);

    matrix.matmul(&vector);
}

/// Must panic: `matmul` is called with a `[3, 2, 2]` left operand and a `[2, 3]` right operand.
#[test]
#[should_panic]
fn test_matmat_matrix1_not_2d() {
    let lhs_flat = NdArray::arange(0, 12);
    let lhs = lhs_flat.reshape([3, 2, 2]);

    let rhs_flat = NdArray::arange(0, 6);
    let rhs = rhs_flat.reshape([2, 3]);

    lhs.matmul(&rhs);
}

/// Must panic: `matmul` is called with a `[2, 3]` left operand and a `[2, 2, 3]` right operand.
#[test]
#[should_panic]
fn test_matmat_matrix2_not_2d() {
    let lhs_flat = NdArray::arange(0, 6);
    let lhs = lhs_flat.reshape([2, 3]);

    let rhs_flat = NdArray::arange(0, 12);
    let rhs = rhs_flat.reshape([2, 2, 3]);

    lhs.matmul(&rhs);
}

/// Must panic: a `[2, 3]` matrix is multiplied by a `[4, 2]` matrix.
#[test]
#[should_panic]
fn test_matmat_inner_dim_mismatch() {
    let lhs_flat = NdArray::arange(0, 6);
    let lhs = lhs_flat.reshape([2, 3]);

    let rhs_flat = NdArray::arange(0, 8);
    let rhs = rhs_flat.reshape([4, 2]);

    lhs.matmul(&rhs);
}

// Dot product of two fixed 4-element arrays, checked in both operand orders and
// with the right operand passed by reference and by value.
test_for_all_numeric_dtypes!(
    test_dot_basic, {
        let lhs = NdArray::new([5, 8, 1, 2]).astype::<T>();
        let rhs = NdArray::new([4, 2, 7, 3]).astype::<T>();

        // Element-wise products written out: 5*4, 8*2, 1*7, 2*3.
        let want = NdArray::scalar(5 * 4 + 8 * 2 + 1 * 7 + 2 * 3).astype::<T>();

        assert_eq!(lhs.dot(&rhs), want);
        assert_eq!(rhs.dot(&lhs), want);

        // Passing `rhs` by value consumes it, so this check comes last.
        assert_eq!(lhs.dot(rhs), want);
    }
);

// Dot product of `6 * [0, len)` with `[5, 5 + len)` for several lengths, compared
// against a closed-form value; `matmul` on the same operands must agree.
test_for_common_numeric_dtypes!(
    test_dot, {
        for len in (1..30).step_by(3) {
            let lhs = (NdArray::arange(0, len) * 6).astype::<T>();
            let rhs = NdArray::arange(5, 5 + len).astype::<T>();

            // sum over i in [0, len) of 6 * i * (i + 5)
            let closed_form = (len - 1) * len * (2 * len + 14);
            let want = NdArray::scalar(closed_form).astype::<T>();

            assert_eq!(lhs.dot(&rhs), want);
            assert_eq!(rhs.dot(&lhs), want);
            assert_eq!(lhs.matmul(&rhs), want);

            // Passing `rhs` by value consumes it, so this check comes last.
            assert_eq!(lhs.dot(rhs), want);
        }
    }
);

// Matrix-vector `matmul` compared against an `einsum` reference for several
// `[rows, cols]` sizes.
test_for_common_numeric_dtypes!(
    test_matvec, {
        for rows in (1..28).step_by(9) {
            for cols in (1..30).step_by(3) {
                let matrix = (NdArray::arange(0, rows * cols) * 6)
                    .reshape([rows, cols])
                    .astype::<T>();
                let vector = NdArray::arange(5, 5 + cols).astype::<T>();

                assert_eq!(
                    matrix.matmul(&vector),
                    einsum([&matrix, &vector], (["ij", "j"], "i"))
                );
            }
        }
    }
);


// Matrix-vector `matmul` where both operands are slices taken out of larger arrays,
// compared against an `einsum` reference.
test_for_common_numeric_dtypes!(
    test_matvec_strided_views, {
        for rows in (1..28).step_by(9) {
            for cols in (1..30).step_by(3) {
                // Left operand: index 0 of the last axis, skipping the first two
                // entries of the middle axis of a `[rows, cols + 2, 2]` array.
                let matrix_base = (NdArray::arange(0, rows * (cols + 2) * 2) * 6)
                    .reshape([rows, cols + 2, 2])
                    .astype::<T>();
                let matrix = matrix_base.slice(s![.., 2.., 0]);
                assert_eq!(matrix.has_uniform_stride(), None);

                // Right operand: index 0 along axis 1 of a `[cols, 2]` array.
                let vector_base = NdArray::arange(0, cols * 2)
                    .reshape([cols, 2])
                    .astype::<T>();
                let vector = vector_base.slice_along(Axis(1), 0);
                assert!(!vector.is_contiguous());

                let reference = einsum([&matrix, &vector], (["ij", "j"], "i"));
                assert_eq!(matrix.matmul(&vector), reference);

                // Passing `vector` by value consumes it, so this check comes last.
                assert_eq!(matrix.matmul(vector), reference);
            }
        }
    }
);

// Matrix-matrix `matmul` of `[rows, inner]` by `[inner, cols]` operands, compared
// against an `einsum` reference for several sizes.
test_for_common_numeric_dtypes!(
    test_matmat, {
        for rows in (1..25).step_by(6) {
            for inner in (1..28).step_by(9) {
                for cols in (1..28).step_by(9) {
                    let lhs = (NdArray::arange(0, rows * inner) * 3)
                        .reshape([rows, inner])
                        .astype::<T>();
                    let rhs = (NdArray::arange(0, inner * cols) * 7)
                        .reshape([inner, cols])
                        .astype::<T>();

                    let reference = einsum([&lhs, &rhs], (["ik", "kj"], "ij"));
                    assert_eq!(lhs.matmul(&rhs), reference);

                    // Passing `rhs` by value consumes it, so this check comes last.
                    assert_eq!(lhs.matmul(rhs), reference);
                }
            }
        }
    }
);

// `bmm` on `randint`-generated `[batch, rows, inner]` x `[batch, inner, cols]` operands,
// compared against an `einsum` reference with a tolerance of 0.1.
test_for_common_numeric_dtypes!(
    test_batched_matmat, {
        // Bounds handed to `randint`: the dtype's default value and 10.
        let lower = T::default();
        let upper: T = NumCast::from(10).unwrap();

        for batch in (1..37).step_by(9) {
            for rows in (1..28).step_by(9) {
                for inner in (1..21).step_by(5) {
                    for cols in (1..25).step_by(6) {
                        let lhs = NdArray::<T>::randint([batch, rows, inner], lower, upper);
                        let rhs = NdArray::<T>::randint([batch, inner, cols], lower, upper);

                        let reference = einsum([&lhs, &rhs], (["bik", "bkj"], "bij"));
                        assert_almost_eq!(lhs.bmm(&rhs), reference, 0.1);
                    }
                }
            }
        }
    }
);

// Matrix-matrix `matmul` where both operands are slices taken out of larger 3-D
// arrays, compared against an `einsum` reference.
test_for_common_numeric_dtypes!(
    test_matmat_strided_views, {
        for rows in (1..37).step_by(9) {
            for inner in (1..21).step_by(4) {
                for cols in (1..25).step_by(6) {
                    // Left operand: a `[rows, inner]` window cut from a
                    // `[rows, inner + 3, 2]` array.
                    let lhs_base = (NdArray::arange(0, rows * (inner + 3) * 2) * 6)
                        .reshape([rows, inner + 3, 2])
                        .astype::<T>();
                    let lhs = lhs_base.slice(s![.., 2..inner + 2, 0]);
                    assert_eq!(lhs.shape(), &[rows, inner]);
                    assert_eq!(lhs.has_uniform_stride(), None);

                    // Right operand: an `[inner, cols]` window cut from an
                    // `[inner, cols + 2, 2]` array.
                    let rhs_base = (NdArray::arange(0, inner * (cols + 2) * 2) * 7)
                        .reshape([inner, cols + 2, 2])
                        .astype::<T>();
                    let rhs = rhs_base.slice(s![.., ..cols, 0]);
                    assert_eq!(rhs.shape(), &[inner, cols]);
                    assert!(!rhs.is_contiguous());

                    assert_eq!(
                        lhs.matmul(&rhs),
                        einsum([&lhs, &rhs], (["ik", "kj"], "ij"))
                    );
                }
            }
        }
    }
);

// Dot product of an array with a view obtained from that same array.
test_for_common_numeric_dtypes!(
    test_dot_mem_overlap, {
        for len in (1..31).step_by(5) {
            let owner = NdArray::arange(0, len).astype::<T>();
            let alias = (&owner).view();

            // Six times the sum of squares of 0..len, in closed form.
            let six_sum_sq = NdArray::scalar((len - 1) * len * (2 * len - 1)).astype::<T>();
            let six = NdArray::scalar(6).astype::<T>();
            assert_eq!(owner.dot(&alias) * six, six_sum_sq);

            // Operand order must not matter.
            assert_eq!(alias.dot(&owner), owner.dot(&alias));

            // Passing `alias` by value consumes it, so this check comes last.
            assert_eq!(owner.dot(&alias), owner.dot(alias));
        }
    }
);

// Main and offset diagonals of a `[3, 4]` array filled from `arange(0, 12)`.
test_for_all_numeric_dtypes!(
    test_diagonal, {
        let flat = NdArray::arange(0, 12).astype::<T>();
        let grid = flat.reshape([3, 4]);

        // Every way of asking for the main diagonal must give the same values.
        let main_diag = NdArray::new([0, 5, 10]).astype::<T>();
        assert_eq!(grid.diagonal(), main_diag);
        assert!(grid.diagonal().is_view());

        assert_eq!(grid.diagonal_along(0, 1), main_diag);
        assert_eq!(grid.diagonal_along(-2, 1), main_diag);
        assert_eq!(grid.offset_diagonal(0), main_diag);

        // Offset +1.
        let plus_one = NdArray::new([1, 6, 11]).astype::<T>();
        assert_eq!(grid.offset_diagonal(1), plus_one);

        // Offset -1.
        let minus_one = NdArray::new([4, 9]).astype::<T>();
        assert_eq!(grid.offset_diagonal(-1), minus_one);
    }
);

// Diagonals of a `[2, 2, 2]` array filled from `arange(0, 8)`, with default and
// explicitly chosen axes.
test_for_all_numeric_dtypes!(
    test_diagonal_3d, {
        let flat = NdArray::arange(0, 8).astype::<T>();
        let cube = flat.reshape([2, 2, 2]);

        // Default axes, offsets 0, +1 and -1.
        let main_diag = NdArray::new([[0, 6], [1, 7]]).astype::<T>();
        assert_eq!(cube.diagonal(), main_diag);

        let plus_one = NdArray::new([[2], [3]]).astype::<T>();
        assert_eq!(cube.offset_diagonal(1), plus_one);

        let minus_one = NdArray::new([[4], [5]]).astype::<T>();
        assert_eq!(cube.offset_diagonal(-1), minus_one);

        // Axes 1 and 2, given in either order.
        let inner_diag = NdArray::new([[0, 3], [4, 7]]).astype::<T>();
        assert_eq!(cube.diagonal_along(1, 2), inner_diag);
        assert_eq!(cube.diagonal_along(2, 1), inner_diag);

        // Offset +1 over axes 0 and 2.
        let outer_plus_one = NdArray::new([[1], [3]]).astype::<T>();
        assert_eq!(cube.offset_diagonal_along(1, Axis(0), Axis(2)), outer_plus_one);
    }
);

// Trace and offset traces of a `[3, 4]` array filled from `arange(0, 12)`.
test_for_all_numeric_dtypes!(
    test_trace, {
        let flat = NdArray::arange(0, 12).astype::<T>();
        let grid = flat.reshape([3, 4]);

        // Every way of asking for the offset-0 trace must give the same value.
        let main_trace = NdArray::scalar(15).astype::<T>();
        assert_eq!(grid.trace(), main_trace);

        assert_eq!(grid.trace_along(0, 1), main_trace);
        assert_eq!(grid.trace_along(-2, 1), main_trace);
        assert_eq!(grid.offset_trace(0), main_trace);

        // Offset +1.
        let plus_one = NdArray::scalar(18).astype::<T>();
        assert_eq!(grid.offset_trace(1), plus_one);

        // Offset -1.
        let minus_one = NdArray::scalar(13).astype::<T>();
        assert_eq!(grid.offset_trace(-1), minus_one);
    }
);


// Traces of a `[2, 2, 2]` array filled from `arange(0, 8)`, with default and
// explicitly chosen axes.
test_for_all_numeric_dtypes!(
    test_trace_3d, {
        let flat = NdArray::arange(0, 8).astype::<T>();
        let cube = flat.reshape([2, 2, 2]);

        // Default axes, offsets 0, +1 and -1.
        let main_trace = NdArray::new([6, 8]).astype::<T>();
        assert_eq!(cube.trace(), main_trace);

        let plus_one = NdArray::new([2, 3]).astype::<T>();
        assert_eq!(cube.offset_trace(1), plus_one);

        let minus_one = NdArray::new([4, 5]).astype::<T>();
        assert_eq!(cube.offset_trace(-1), minus_one);

        // Axes 1 and 2, given in either order.
        let inner_trace = NdArray::new([3, 11]).astype::<T>();
        assert_eq!(cube.trace_along(1, 2), inner_trace);
        assert_eq!(cube.trace_along(2, 1), inner_trace);

        // Offset +1 over axes 0 and 2.
        let outer_plus_one = NdArray::new([1, 3]).astype::<T>();
        assert_eq!(cube.offset_trace_along(1, Axis(0), Axis(2)), outer_plus_one);
    }
);