knok 0.1.0

Compile-time linalg graphs for Rust
use knok::prelude::*;
use knok::Error;

#[test]
fn tensor3_from_array_and_vec_share_row_major_layout() {
    let from_array = Tensor3::from_array([[[1.0, 2.0], [3.0, 4.0]]]);
    let from_vec = Tensor3::<f32, 1, 2, 2>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    assert_eq!(from_array.as_slice(), from_vec.as_slice());
    assert_eq!(from_array.into_vec(), vec![1.0, 2.0, 3.0, 4.0]);
}

#[test]
fn tensor4_from_array_and_vec_share_row_major_layout() {
    let from_array = Tensor4::from_array([[[[1.0], [2.0]], [[3.0], [4.0]]]]);
    let from_vec = Tensor4::<f32, 1, 2, 2, 1>::from_vec(vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    assert_eq!(from_array.as_slice(), from_vec.as_slice());
    assert_eq!(from_array.into_vec(), vec![1.0, 2.0, 3.0, 4.0]);
}

#[test]
fn higher_rank_tensor_from_vec_validates_element_count() {
    let error = Tensor4::<f32, 1, 2, 2, 1>::from_vec(vec![1.0, 2.0, 3.0]).unwrap_err();

    assert!(matches!(
        error,
        Error::Shape {
            expected: &[1, 2, 2, 1],
            ..
        }
    ));
}

#[test]
fn tensor_convenience_constructors_work() {
    let zeros = Tensor2::<f32, 2, 2>::zeros();
    assert_eq!(zeros.as_slice(), &[0.0, 0.0, 0.0, 0.0]);

    let ones = Tensor3::<f32, 1, 2, 2>::ones();
    assert_eq!(ones.as_slice(), &[1.0, 1.0, 1.0, 1.0]);

    let filled = Tensor4::<i32, 1, 1, 2, 2>::filled(7);
    assert_eq!(filled.into_vec(), vec![7, 7, 7, 7]);
}

#[test]
fn tensor_try_from_vec_and_indexing_work() {
    let mut tensor = Tensor3::<f32, 1, 2, 2>::try_from(vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    assert_eq!(tensor.get(0, 1, 0), Some(&3.0));
    assert_eq!(tensor.get(1, 0, 0), None);

    *tensor.get_mut(0, 1, 1).unwrap() = 9.0;
    assert_eq!(tensor.as_slice(), &[1.0, 2.0, 3.0, 9.0]);

    tensor.as_mut_slice()[0] = 5.0;
    assert_eq!(tensor.as_slice(), &[5.0, 2.0, 3.0, 9.0]);
}

#[test]
fn tensor_debug_includes_shape() {
    let tensor = Tensor2::from_array([[1.0, 2.0], [3.0, 4.0]]);

    let debug = format!("{tensor:?}");

    assert!(debug.contains("Tensor2"));
    assert!(debug.contains("shape"));
    assert!(debug.contains("[2, 2]"));
}