tract-linalg 0.23.0-dev.4

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use crate::frame::mmm::fuse::FusedKerSpec;
use crate::frame::mmm::storage::*;
use crate::frame::mmm::tests::display_error;
use crate::frame::mmm::tests::store::mmm_stride_storage;
use crate::frame::mmm::*;
use num_traits::{AsPrimitive, Bounded};
use proptest::prelude::*;
use tract_data::internal::*;

#[macro_export]
macro_rules! mmm_kernel_fuse_tests {
    ($ker:expr, $tc:ty, $ti: ty) => {
        mod fuse {
            use num_traits::Zero;
            #[allow(unused_imports)]
            use tract_data::prelude::f16;
            use tract_data::prelude::tensor0;
            use $crate::frame::mmm::MatMatMulKer;
            use $crate::frame::mmm::tests::fuse as test;
            #[allow(unused_imports)]
            use $crate::frame::mmm::tests::fuse::*;

            #[test]
            fn return_zeros() {
                test::return_zeros::<_, $tc, $ti>($ker)
            }

            #[test]
            fn store_non_contiguous() {
                test::store_non_contiguous::<_, $tc, $ti>($ker)
            }
            proptest::proptest! {
                #[test]
                fn return_c_prop(c in tile::<_, $ti>($ker)) {
                    test::return_c::<_, $ti>($ker, &c)
                }
            }

            fn fmin<T: PartialOrd>(a: T, b: T) -> T {
                if a < b { a } else { b }
            }

            fn fmax<T: PartialOrd>(a: T, b: T) -> T {
                if a > b { a } else { b }
            }

            macro_rules! bin {
                ($FKS:ident, $geo:expr, $f:expr, $extra_cond:expr) => {
                    paste! {
                        #[test]
                        fn [<$FKS:snake>]() {
                            if ($ker).is_supported_here() && $extra_cond {
                                test::$geo::<_, $ti>($ker, $crate::mmm::FusedKerSpec::$FKS, $f);
                            }
                        }
                    }
                };
            }

            bin!(PerColMin, per_col, fmin, true);
            bin!(PerColMax, per_col, fmax, true);
            bin!(PerColAdd, per_col, |a, b| a + b, true);
            bin!(PerColMul, per_col, |a, b| a * b, true);
            bin!(PerColSub, per_col, |a, b| a - b, true);
            bin!(PerColSubF, per_col, |a, b| b - a, true);

            bin!(PerRowMin, per_row, fmin, true);
            bin!(PerRowMax, per_row, fmax, true);
            bin!(PerRowAdd, per_row, |a, b| a + b, true);
            bin!(PerRowMul, per_row, |a, b| a * b, true);
            bin!(PerRowSub, per_row, |a, b| a - b, true);
            bin!(PerRowSubF, per_row, |a, b| b - a, true);

            bin!(ScalarMin, scalar, fmin, true);
            bin!(ScalarMax, scalar, fmax, true);
            bin!(ScalarAdd, scalar, |a, b| a + b, true);
            bin!(ScalarMul, scalar, |a, b| a * b, true);
            bin!(ScalarSub, scalar, |a, b| a - b, true);
            bin!(ScalarSubF, scalar, |a, b| b - a, true);

            bin!(
                LeakyRelu,
                scalar,
                |a, b| if b > <$ti>::zero() { b } else { a * b },
                ($ker).can_fuse(&$crate::mmm::FusedSpec::LeakyRelu(&tensor0(<$ti>::from(1_u8))))
            );

            #[test]
            fn return_c_add_row_col_product() {
                test::return_c_add_row_col_product::<_, $ti>($ker)
            }

            #[test]
            fn return_c_plus_d() {
                test::return_c_plus_d::<_, $ti, $ti>($ker)
            }

            #[test]
            fn return_c_clear() {
                test::return_c_clear::<_, $ti>($ker)
            }
        }
    };
}

use crate::LADatum;
pub fn return_zeros<K, TC, TI>(ker: &K)
where
    K: MatMatMulKer<Acc = TI>,
    TC: LADatum,
    TI: LADatum + Bounded + PartialEq,
{
    if !ker.is_supported_here() {
        return;
    }
    let v = vec![TC::max_value(); ker.mr() * ker.nr()];
    let c = mmm_stride_storage(&v, ker.nr());
    let non_linear = tvec![FusedKerSpec::Clear, FusedKerSpec::Store(c), FusedKerSpec::Done];
    let err = ker.kernel(&non_linear);
    assert_eq!(err, 0);
    let expected = vec![TC::zero(); v.len()];
    display_error(&v, &expected, ker.mr(), ker.nr());
    assert_eq!(v, expected);
}

pub fn store_non_contiguous<K, TC, TI>(ker: &K)
where
    K: MatMatMulKer<Acc = TI>,
    TC: LADatum,
    TI: LADatum + Bounded + PartialEq,
{
    if !ker.is_supported_here() {
        return;
    }
    let v = vec![TC::max_value(); ker.mr() * 5 * ker.nr() * 3];
    let c = OutputStoreKer {
        ptr: v.as_ptr() as _,
        row_byte_stride: (std::mem::size_of::<TC>() * 3 * ker.nr() * 5) as isize,
        col_byte_stride: std::mem::size_of::<TC>() as isize * 3,
        item_size: std::mem::size_of::<TC>(),
    };
    let non_linear = tvec![FusedKerSpec::Clear, FusedKerSpec::Store(c), FusedKerSpec::Done];
    let err = ker.kernel(&non_linear);
    assert_eq!(err, 0);
    let mut expected = vec![TC::max_value(); v.len()];
    for c in 0..ker.nr() {
        for r in 0..ker.mr() {
            expected[c * 3 + r * 3 * 5 * ker.nr()] = TC::zero();
        }
    }
    assert_eq!(v, expected);
}

pub fn fused_ops<K, TI, E>(ker: &K, c: &[TI], ops: &[FusedKerSpec<TI>], expect: E)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    E: Fn(usize, usize, TI) -> TI,
{
    if !ker.is_supported_here() {
        return;
    }
    assert!(c.len() == ker.mr() * ker.nr());
    let v = c.to_vec();
    let c = mmm_stride_storage(&v, ker.nr());
    let mut ops = ops.to_vec();
    ops.insert(0, FusedKerSpec::AddUnicast(c));
    ops.insert(0, FusedKerSpec::Clear);
    ops.push(FusedKerSpec::Store(c));
    ops.push(FusedKerSpec::Done);
    let expected =
        (0..v.len()).map(|ix| expect(ix / ker.nr(), ix % ker.nr(), v[ix])).collect::<Vec<TI>>();
    let err = ker.kernel(&ops);
    assert_eq!(err, 0);
    display_error(&v, &expected, ker.mr(), ker.nr());
    assert_eq!(v, expected);
}

pub fn return_c<K, TI>(ker: &K, v: &[TI])
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    usize: AsPrimitive<TI>,
{
    fused_ops::<K, TI, _>(ker, v, &[], |_, _, c| c + 1.as_() - 1.as_())
}

pub fn return_c_plus_d<K, TI, TD>(ker: &K)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    TD: LADatum + AsPrimitive<TI>,
    usize: AsPrimitive<TI> + AsPrimitive<TD>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len).map(|f| f.as_()).collect();
    let d: Vec<TD> = (0..len).map(|f| ((3 * f) % 7).as_()).collect();
    fused_ops::<K, TI, _>(
        ker,
        &v,
        &[FusedKerSpec::AddUnicast(mmm_stride_storage(&d, ker.nr()))],
        |row, col, c| c + d[row * ker.nr() + col].as_(),
    );
}

pub fn per_col<K, TI>(ker: &K, op: impl Fn(*const TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    usize: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len).map(|f| f.as_()).collect();
    let bias: Vec<TI> = (0..ker.nr()).map(|f| (f + 1).as_()).collect();
    fused_ops::<K, TI, _>(ker, &v, &[op(bias.as_ptr())], |_, col, c| f(bias[col], c))
}

pub fn per_row<K, TI>(ker: &K, op: impl Fn(*const TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    usize: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len).map(|f| f.as_()).collect();
    let bias: Vec<TI> = (0..ker.mr()).map(|f| (f + 1).as_()).collect();
    fused_ops::<K, TI, _>(ker, &v, &[op(bias.as_ptr())], |row, _, c| f(bias[row], c))
}

pub fn scalar<K, TI>(ker: &K, op: impl Fn(TI) -> FusedKerSpec<TI>, f: impl Fn(TI, TI) -> TI)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    isize: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len as isize).map(|f| (f - len as isize / 2).as_()).collect();
    let five: TI = 5.as_();
    fused_ops::<K, TI, _>(ker, &v, &[op(five)], |_, _, c| f(five, c))
}

pub fn return_c_add_row_col_product<K, TI>(ker: &K)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    usize: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len).map(|f| (f + 1).as_()).collect();
    let rows: Vec<TI> = (0..ker.mr()).map(|f| (f + 3).as_()).collect();
    let cols: Vec<TI> = (0..ker.nr()).map(|f| (f + 2).as_()).collect();
    fused_ops::<K, TI, _>(
        ker,
        &v,
        &[FusedKerSpec::AddRowColProducts(rows.as_ptr(), cols.as_ptr())],
        |row, col, c| c + cols[col] * rows[row],
    )
}

pub fn return_c_clear<K, TI>(ker: &K)
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    usize: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    let v: Vec<TI> = (0..len).map(|f| f.as_()).collect();
    fused_ops::<K, TI, _>(ker, &v, &[FusedKerSpec::Clear], |_, _, _| 0.as_())
}

pub fn tile<K, TI>(ker: &K) -> BoxedStrategy<Vec<TI>>
where
    K: MatMatMulKer<Acc = TI>,
    TI: LADatum,
    i8: AsPrimitive<TI>,
{
    let len = ker.mr() * ker.nr();
    proptest::collection::vec(any::<i8>().prop_map(|c| c.as_()), len..=len).boxed()
}