tract-core 0.23.1

Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
Documentation

    use super::*;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use proptest::test_runner::{TestCaseResult, TestRunner};
    use tract_data::itertools::Itertools;

    /// Builds a strategy producing `f32` tensors of the requested `shape`.
    ///
    /// Each element is an integer drawn from `-10..=10` and converted to `f32`.
    /// The flat buffer always holds exactly as many elements as the product of
    /// `shape`, so reshaping it to `shape` is expected to succeed.
    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
        let dims = shape.to_vec();
        let volume: usize = dims.iter().product();
        let element = (-10i8..=10i8).prop_map(|value| value as f32);
        let flat_values = vec(element, volume..=volume);
        flat_values.prop_map(move |data| tensor1(&data).into_shape(&dims).unwrap()).boxed()
    }

    /// Builds a strategy producing a pair of full shapes, one per einsum input.
    ///
    /// Every axis that appears in input 0 or input 1 receives one dimension drawn
    /// from `2..6`. An axis shared by both inputs gets the same dimension in both
    /// returned shapes. Axes that appear in neither input are ignored.
    fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
        let mapping = e.clone();
        // Axes present in at least one of the two inputs.
        let input_axes = mapping
            .iter_all_axes()
            .filter(|axis| !(axis.inputs[0].is_empty() && axis.inputs[1].is_empty()))
            .cloned()
            .collect_vec();
        let dim_ranges = vec![2usize..6; input_axes.len()];
        dim_ranges
            .prop_map(move |dims| {
                // Shape of input `slot`: for each of its axes, look up the dimension
                // drawn for that axis.
                let shape_of = |slot: usize| -> Vec<usize> {
                    mapping
                        .axes(InOut::In(slot))
                        .map(|axis| {
                            let ix = input_axes.iter().position(|known| axis == known).unwrap();
                            dims[ix]
                        })
                        .collect_vec()
                };
                (shape_of(0), shape_of(1))
            })
            .boxed()
    }

    fn test_expr(expr: &str) -> TestCaseResult {
        let expr = expr.to_string();
        let mut runner = TestRunner::default();
        let axes: AxesMapping = expr.parse().unwrap();
        fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
        }
        fn is_disapearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
            let axis = axes.axis((InOut::In(input), position)).unwrap();
            axis.outputs[0].len() == 0
        }
        let cases = full_shapes(&axes)
            .prop_flat_map(|(a, b)| {
                (
                    a.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 0, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 0, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                    b.iter()
                        .enumerate()
                        .map(|(ix, d)| {
                            if is_k(&axes, 1, ix) {
                                prop_oneof![Just(*d)].boxed()
                            } else if is_disapearing_axis(&axes, 1, ix) {
                                Just(1).boxed()
                            } else {
                                prop_oneof![Just(1usize), Just(*d)].boxed()
                            }
                        })
                        .collect_vec(),
                )
            })
            .prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
            .prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
        runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
        Ok(())
    }

    /// One concrete einsum test case: an expression and its two input tensors.
    #[derive(Debug, Clone, PartialEq)]
    struct EinSumProblem {
        /// Einsum expression in textual form (e.g. `"mk,kn->mn"`); parsed when the
        /// problem is checked.
        expr: String,
        /// Tensor fed to the first input (source `"a"`).
        a: Tensor,
        /// Tensor fed to the second input (source `"b"`).
        b: Tensor,
    }

    impl EinSumProblem {
        /// Verifies that rewriting the einsum as a matmul preserves its result.
        ///
        /// Builds a model with two `f32` sources shaped like `self.a` and `self.b`
        /// feeding a single `EinSum` node, and runs it to obtain a reference
        /// output. It then applies `rewrite_einsums_as_matmul`, runs the rewritten
        /// model on the same inputs and compares the two outputs with
        /// `close_enough`.
        ///
        /// # Errors
        /// Returns an error if the expression does not parse, if building,
        /// rewriting or running the model fails, or if the two outputs differ.
        ///
        /// # Panics
        /// Panics if an `EinSum` node is still present after the rewrite.
        fn check(&self) -> TractResult<()> {
            let mut model = TypedModel::default();
            let sa = model.add_source("a", f32::fact(self.a.shape()))?;
            let sb = model.add_source("b", f32::fact(self.b.shape()))?;
            let einsum = model.wire_node(
                "einsum",
                EinSum::new(self.expr.parse()?, f32::datum_type()),
                &[sa, sb],
            )?;
            model.set_output_outlets(&einsum)?;
            let a = self.a.clone().into_tvalue().into_sharable();
            let b = self.b.clone().into_tvalue().into_sharable();
            let inputs = tvec!(a, b);
            // Reference result: the model as built, before the rewrite.
            let reference = TypedRunnableModel::new(&model)?.run(inputs.clone())?.remove(0);
            rewrite_einsums_as_matmul(&mut model)?;
            // The rewrite must not leave any EinSum node behind.
            assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
            let test = TypedRunnableModel::new(&model)?.run(inputs)?.remove(0);
            reference.close_enough(&test, true)?;
            Ok(())
        }
    }

    // Property tests: one per einsum expression, each delegating to `test_expr`.

    #[test]
    fn prop_mk_kn_mn() -> TestCaseResult {
        test_expr("mk,kn->mn")
    }

    #[test]
    fn prop_km_kn_mn() -> TestCaseResult {
        test_expr("km,kn->mn")
    }

    #[test]
    fn prop_mk_nk_mn() -> TestCaseResult {
        test_expr("mk,nk->mn")
    }

    #[test]
    fn prop_mk_kn_nm() -> TestCaseResult {
        test_expr("mk,kn->nm")
    }

    #[test]
    fn prop_k_kn_mn() -> TestCaseResult {
        test_expr("k,kn->mn")
    }

    #[test]
    fn prop_mk_k_mn() -> TestCaseResult {
        test_expr("mk,k->mn")
    }

    #[test]
    fn prop_m_n_mn() -> TestCaseResult {
        test_expr("m,n->mn")
    }

    #[test]
    fn prop_amk_akn_amn() -> TestCaseResult {
        test_expr("amk,akn->amn")
    }

    #[test]
    fn prop_mk_akn_amn() -> TestCaseResult {
        test_expr("mk,akn->amn")
    }

    #[test]
    fn prop_btgi_gih_tgh() -> TestCaseResult {
        test_expr("btgi,gih->tgh")
    }

    #[test]
    fn prop_tgi_gih_btgh() -> TestCaseResult {
        test_expr("tgi,gih->btgh")
    }

    /// Fixed-input case for `k,kn->mn`: a 2-element zero vector and a 2x2 zero matrix.
    #[test]
    fn k_kn_mn_0() -> TractResult<()> {
        let a = tensor1(&[0f32, 0f32]);
        let b = tensor2(&[[0f32, 0.], [0., 0.]]);
        let problem = EinSumProblem { expr: String::from("k,kn->mn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `mk,k->mn`: zero tensors of shapes `[2, 2]` and `[2]`.
    #[test]
    fn mk_k_mn_0() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[2, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[2]).unwrap();
        let problem = EinSumProblem { expr: String::from("mk,k->mn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `mk,k->mn`: zero tensors of shapes `[1, 2]` and `[2]`.
    #[test]
    fn mk_k_mn_1() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[1, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[2]).unwrap();
        let problem = EinSumProblem { expr: String::from("mk,k->mn"), a, b };
        problem.check()
    }

    /// Fixed-input case for the transposed-output product `mk,kn->nm`, using zero
    /// tensors of shapes `[3, 2]` and `[2, 2]`.
    #[test]
    fn mk_kn_nm_0() -> TractResult<()> {
        EinSumProblem {
            // NOTE(review): this previously read "mk,kn->mn", which did not match
            // the test name. The output is now `nm`, as the name states; confirm
            // against the failure this case was originally recorded for.
            expr: "mk,kn->nm".to_string(),
            a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
            b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
        }
        .check()
    }

    /// Fixed-input case for `amk,akn->amn`: zero tensors of shapes `[1, 1, 2]` and `[1, 2, 1]`.
    #[test]
    fn amk_akn_amn_0() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[1, 2, 1]).unwrap();
        let problem = EinSumProblem { expr: String::from("amk,akn->amn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `amk,akn->amn`: zero tensors of shapes `[2, 1, 2]` and `[1, 2, 1]`.
    #[test]
    fn amk_akn_amn_1() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[2, 1, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[1, 2, 1]).unwrap();
        let problem = EinSumProblem { expr: String::from("amk,akn->amn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `amk,akn->amn`: zero tensors of shapes `[1, 1, 2]` and `[2, 2, 2]`.
    #[test]
    fn amk_akn_amn_2() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[2, 2, 2]).unwrap();
        let problem = EinSumProblem { expr: String::from("amk,akn->amn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `amk,akn->amn`: zero tensors of shapes `[1, 1, 2]` and `[2, 2, 1]`.
    #[test]
    fn amk_akn_amn_3() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
        let b = Tensor::zero::<f32>(&[2, 2, 1]).unwrap();
        let problem = EinSumProblem { expr: String::from("amk,akn->amn"), a, b };
        problem.check()
    }

    /// Fixed-input case for `km,anbck->bmn`: zero tensors of shapes `[2, 1]` and
    /// `[1, 1, 1, 1, 2]`.
    #[test]
    fn km_anbck_bmn_0() -> TractResult<()> {
        let a = Tensor::zero::<f32>(&[2, 1]).unwrap();
        let b = Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap();
        let problem = EinSumProblem { expr: String::from("km,anbck->bmn"), a, b };
        problem.check()
    }

    /// Checks that a quantized nine-input `EinSum` is rewritten away.
    ///
    /// The node is wired with `a`, `b`, a bias, and three zero-point/scale
    /// constant pairs (`a0`/`a_scale`, `b0`/`b_scale`, `c0`/`c_scale`). The test
    /// only asserts that no `EinSum` node remains after
    /// `rewrite_einsums_as_matmul`; the model is never run.
    #[test]
    fn q() -> TractResult<()> {
        let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
        let qi8 = DatumType::QI8(qp);
        let (zero_point, scale) = qp.zp_scale();
        let op = EinSum {
            axes: "mk,kn,m,,,,,,->mn".parse()?,
            operating_dt: i32::datum_type(),
            q_params: Some(qi8),
        };

        let mut model = TypedModelPatch::default();
        let inputs = [
            model.add_source("a", qi8.fact([3, 2]))?,
            model.add_source("b", qi8.fact([2, 4]))?,
            model.add_source("bias", i32::datum_type().fact([3]))?,
            model.add_const("a0", tensor0(zero_point))?,
            model.add_const("a_scale", tensor0(scale))?,
            model.add_const("b0", tensor0(zero_point))?,
            model.add_const("b_scale", tensor0(scale))?,
            model.add_const("c0", tensor0(zero_point))?,
            model.add_const("c_scale", tensor0(scale))?,
        ];
        let wire = model.wire_node("einsum", op, &inputs)?;
        model.set_output_outlets(&wire)?;

        rewrite_einsums_as_matmul(&mut model)?;
        let einsum_left = model.nodes.iter().any(|n| n.op_is::<EinSum>());
        assert!(!einsum_left);
        Ok(())
    }