tensor-macros 0.2.0

A compile-time optimised tensor library
Documentation
#![feature(try_from)]

#[macro_use]
extern crate tensor_macros;
use tensor_macros::traits::*;

use std::convert::TryFrom;

tensor!(T2345: 2 x 3 x 4 x 5);

/// Checks the associated shape constants generated for a rank-4 tensor.
#[test]
fn tensor_dims() {
    // Four declared dimensions...
    assert_eq!(T2345::<u8>::NDIM, 4);
    // ...holding 2 * 3 * 4 * 5 = 120 elements in total.
    assert_eq!(T2345::<u8>::SIZE, 120);
}

tensor!(M23: 2 x 3);

/// A `2 x 3` declaration is expected to expose 2 rows and 3 columns.
#[test]
fn matrix_dims() {
    type Matrix = M23<u8>;

    let (rows, cols) = (Matrix::ROWS, Matrix::COLS);
    assert_eq!(rows, 2);
    assert_eq!(cols, 3);
}

tensor!(V4: 4);

/// A single-dimension declaration; this test expects the length in `COLS`.
#[test]
fn col_vector_size() {
    let cols = V4::<u8>::COLS;
    assert_eq!(cols, 4);
}

tensor!(V2Row: row 2);

/// The `row` form of a vector declaration; this test expects the length in `ROWS`.
#[test]
fn row_vector_size() {
    let rows = V2Row::<u8>::ROWS;
    assert_eq!(rows, 2);
}

tensor!(T324: 3 x 2 x 4);

/// Dimensions are reported in declaration order, both from the type itself
/// (`dims`) and from a value of that type (`get_dims`).
#[test]
fn dims() {
    let from_type = T324::<u8>::dims();
    assert_eq!(from_type, vec![3, 2, 4]);

    let default_tensor: T324<f64> = Default::default();
    let from_value = default_tensor.get_dims();
    assert_eq!(from_value, vec![3, 2, 4]);
}

/// A `Vec` whose length matches the tensor's element count (3 * 2 * 4 = 24)
/// converts successfully and keeps the elements in their original order.
#[test]
fn try_from_vec() {
    let converted = T324::<u8>::try_from((0..24).collect::<Vec<u8>>());

    let expected = T324([
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    ]);

    assert_eq!(converted, Ok(expected));
}

/// Tuple indices address elements in row-major order (the last index varies
/// fastest), while a single integer index addresses the flat storage.
#[test]
fn index() {
    let t = T324::<u8>::try_from((0..24).collect::<Vec<u8>>()).unwrap();

    // For a 3 x 2 x 4 tensor, (i, j, k) lands at flat offset 8 * i + 4 * j + k.
    for &(idx, want) in &[((0, 0, 0), 0), ((1, 1, 1), 13), ((2, 1, 3), 23)] {
        assert_eq!(t[idx], want);
    }

    assert_eq!(t[15], 15);
}

tensor!(T243: 2 x 4 x 3);
tensor!(M43: 4 x 3 x 1);
tensor!(V2: 2 x 1);

dot!(T243: 2 x 4 x 3 * M43: 4 x 3 x 1 => V2: 2 x 1);

/// Multiplies a `2 x 4 x 3` tensor by a `4 x 3 x 1` tensor, contracting the
/// shared `4 x 3` axes and leaving a `2 x 1` result.
///
/// With lhs = 0..24 and rhs = 0..12 the expected entries are
/// sum(i * i) = 506 and sum((i + 12) * i) = 1298, for i in 0..12.
#[test]
fn dot_test() {
    let mut lhs_data = [0.0_f64; 24];
    for (i, slot) in lhs_data.iter_mut().enumerate() {
        *slot = i as f64;
    }

    let mut rhs_data = [0.0_f64; 12];
    for (i, slot) in rhs_data.iter_mut().enumerate() {
        *slot = i as f64;
    }

    let product = T243::<f64>(lhs_data) * M43(rhs_data);
    assert_eq!(product, V2([506.0, 1298.0]));
}

/// Pins the `Debug` layout of a rank-4 tensor. Every value is followed by a
/// tab, and every innermost row ends with a newline. Each enclosing slice below
/// the outermost dimension is closed by one further newline.
#[test]
fn debug() {
    let values: Vec<u8> = (0..120).collect();
    let t = T2345::try_from(values).unwrap();

    let expected = "0\t1\t2\t3\t4\t
5\t6\t7\t8\t9\t
10\t11\t12\t13\t14\t
15\t16\t17\t18\t19\t

20\t21\t22\t23\t24\t
25\t26\t27\t28\t29\t
30\t31\t32\t33\t34\t
35\t36\t37\t38\t39\t

40\t41\t42\t43\t44\t
45\t46\t47\t48\t49\t
50\t51\t52\t53\t54\t
55\t56\t57\t58\t59\t


60\t61\t62\t63\t64\t
65\t66\t67\t68\t69\t
70\t71\t72\t73\t74\t
75\t76\t77\t78\t79\t

80\t81\t82\t83\t84\t
85\t86\t87\t88\t89\t
90\t91\t92\t93\t94\t
95\t96\t97\t98\t99\t

100\t101\t102\t103\t104\t
105\t106\t107\t108\t109\t
110\t111\t112\t113\t114\t
115\t116\t117\t118\t119\t


";

    let rendered = format!("{:?}", t);
    assert_eq!(rendered, expected);
}

/// `cwise_mul` multiplies element by element. `T243::from(5)` is expected to
/// act as a tensor filled with 5, so l ⊙ (l ⊙ 5) must equal 5 * x * x per element.
#[test]
fn cwise() {
    let l = T243::try_from((0u64..24).collect::<Vec<u64>>()).unwrap();
    let r = T243::from(5);

    let actual = l.cwise_mul(l.cwise_mul(r));

    let squares_times_five: Vec<u64> = (0u64..24).map(|x| 5 * x * x).collect();
    let expected = T243::try_from(squares_times_five).unwrap();

    assert_eq!(actual, expected);
}

transpose!(T243: 2 x 4 x 3 => T243T);

/// Transposing `T243` (2 x 4 x 3) into `T243T` (3 x 4 x 2) reverses the
/// index order: `t[(i, j, k)] == t.transpose()[(k, j, i)]`.
///
/// The whole index space is checked, not a single spot, so a wrong axis
/// permutation cannot slip through by agreeing at one coordinate.
#[test]
fn index_transpose() {
    let t = T243::try_from((0u8..24).collect::<Vec<u8>>()).unwrap();
    // Transpose once up front; `t` is still indexed afterwards, as in `transpose_dot`.
    let u = t.transpose();

    for i in 0..2 {
        for j in 0..4 {
            for k in 0..3 {
                assert_eq!(
                    t[(i, j, k)],
                    u[(k, j, i)],
                    "transpose mismatch at ({}, {}, {})",
                    i,
                    j,
                    k
                );
            }
        }
    }
}

/// Pins the `Debug` layout of `T243` and of its transpose `T243T`.
///
/// `T243` prints 2 slices of 4 rows x 3 columns. The transpose prints 3 slices
/// of 4 rows x 2 columns, where its element `(k, j, i)` is the original `(i, j, k)`.
#[test]
fn debug_transpose() {
    let original = T243::try_from((0u8..24).collect::<Vec<u8>>()).unwrap();

    let expected_original = "0\t1\t2\t
3\t4\t5\t
6\t7\t8\t
9\t10\t11\t

12\t13\t14\t
15\t16\t17\t
18\t19\t20\t
21\t22\t23\t

";
    assert_eq!(format!("{:?}", original), expected_original);

    let transposed = original.transpose();

    let expected_transposed = "0\t12\t
3\t15\t
6\t18\t
9\t21\t

1\t13\t
4\t16\t
7\t19\t
10\t22\t

2\t14\t
5\t17\t
8\t20\t
11\t23\t

";
    assert_eq!(format!("{:?}", transposed), expected_transposed);
}

transpose!(M23: 2 x 3 => M32);
tensor!(M33: 3 x 3);
dot!(M32: 3 x 2 * M23: 2 x 3 => M33: 3 x 3);

/// For t = [[0, 1, 2], [3, 4, 5]], the product tᵀ · t must be the symmetric
/// matrix [[9, 12, 15], [12, 17, 22], [15, 22, 29]].
#[test]
fn transpose_dot() {
    let t = M23::try_from((0u8..6).collect::<Vec<u8>>()).unwrap();
    let gram = t.transpose() * t;
    assert_eq!(gram, M33([9, 12, 15, 12, 17, 22, 15, 22, 29]));
}