use super::*;
use proptest::collection::vec;
use proptest::prelude::*;
use proptest::test_runner::{TestCaseResult, TestRunner};
use tract_data::itertools::Itertools;
/// Strategy producing `f32` tensors of the given `shape`.
///
/// Every element is an integer drawn from `-10..=10` and converted to `f32`; the flat
/// vector of values is then reshaped to `shape`.
pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
    let dims = shape.to_vec();
    let volume: usize = dims.iter().product();
    let element = (-10i8..=10i8).prop_map(f32::from);
    vec(element, volume..=volume)
        .prop_map(move |values| tensor1(&values).into_shape(&dims).unwrap())
        .boxed()
}
/// Strategy producing a pair of "full" shapes for the two inputs of the mapping `e`.
///
/// Each axis that appears in at least one of the first two inputs is given a dimension
/// drawn from `2..6`; an axis shared by both inputs gets the same dimension in both
/// shapes. The returned tuple is `(shape of input 0, shape of input 1)`.
fn full_shapes(e: &AxesMapping) -> BoxedStrategy<(Vec<usize>, Vec<usize>)> {
    let mapping = e.clone();
    // Axes that only exist in the output do not need a dimension here.
    let input_axes = mapping
        .iter_all_axes()
        .filter(|axis| !(axis.inputs[0].is_empty() && axis.inputs[1].is_empty()))
        .cloned()
        .collect_vec();
    let dim_ranges = vec![2usize..6; input_axes.len()];
    dim_ranges
        .prop_map(move |dims| {
            // Look up, for every axis of the requested input, the dimension drawn for it.
            let shape_of = |input: usize| -> Vec<usize> {
                mapping
                    .axes(InOut::In(input))
                    .map(|axis| {
                        let ix = input_axes.iter().position(|known| axis == known).unwrap();
                        dims[ix]
                    })
                    .collect_vec()
            };
            (shape_of(0), shape_of(1))
        })
        .boxed()
}
fn test_expr(expr: &str) -> TestCaseResult {
let expr = expr.to_string();
let mut runner = TestRunner::default();
let axes: AxesMapping = expr.parse().unwrap();
fn is_k(axes: &AxesMapping, input: usize, position: usize) -> bool {
let axis = axes.axis((InOut::In(input), position)).unwrap();
axis.inputs[1 - input].len() == 1 && axis.outputs[0].len() == 0
}
fn is_disapearing_axis(axes: &AxesMapping, input: usize, position: usize) -> bool {
let axis = axes.axis((InOut::In(input), position)).unwrap();
axis.outputs[0].len() == 0
}
let cases = full_shapes(&axes)
.prop_flat_map(|(a, b)| {
(
a.iter()
.enumerate()
.map(|(ix, d)| {
if is_k(&axes, 0, ix) {
prop_oneof![Just(*d)].boxed()
} else if is_disapearing_axis(&axes, 0, ix) {
Just(1).boxed()
} else {
prop_oneof![Just(1usize), Just(*d)].boxed()
}
})
.collect_vec(),
b.iter()
.enumerate()
.map(|(ix, d)| {
if is_k(&axes, 1, ix) {
prop_oneof![Just(*d)].boxed()
} else if is_disapearing_axis(&axes, 1, ix) {
Just(1).boxed()
} else {
prop_oneof![Just(1usize), Just(*d)].boxed()
}
})
.collect_vec(),
)
})
.prop_flat_map(|(a_shape, b_shape)| (tensor(&a_shape), tensor(&b_shape)))
.prop_map(|(a, b)| EinSumProblem { expr: expr.clone(), a, b });
runner.run(&cases, |pb| pb.check().map_err(|e| TestCaseError::fail(e.to_string())))?;
Ok(())
}
/// One concrete einsum test case: an expression together with its two input tensors.
#[derive(Debug, Clone, PartialEq)]
struct EinSumProblem {
/// Einsum expression in textual form (e.g. `"mk,kn->mn"`); parsed when the case is checked.
expr: String,
/// First input tensor, fed to the model source named `"a"`.
a: Tensor,
/// Second input tensor, fed to the model source named `"b"`.
b: Tensor,
}
impl EinSumProblem {
    /// Verifies that rewriting the einsum as a matmul does not change its result.
    ///
    /// Builds a model with two `f32` sources shaped like `self.a` and `self.b` feeding a
    /// single `EinSum` node, runs it to obtain a reference output, applies
    /// `rewrite_einsums_as_matmul`, runs the rewritten model on the same inputs and
    /// compares both outputs with `close_enough`.
    ///
    /// # Errors
    ///
    /// Returns an error if the expression does not parse, if the model cannot be built or
    /// run, if the rewrite fails, or if the two outputs are not close enough.
    ///
    /// # Panics
    ///
    /// Panics if an `EinSum` node is still present after the rewrite.
    fn check(&self) -> TractResult<()> {
        let mut model = TypedModel::default();
        let sa = model.add_source("a", f32::fact(self.a.shape()))?;
        let sb = model.add_source("b", f32::fact(self.b.shape()))?;
        let einsum = model.wire_node(
            "einsum",
            EinSum::new(self.expr.parse()?, f32::datum_type()),
            &[sa, sb],
        )?;
        model.set_output_outlets(&einsum)?;
        let a = self.a.clone().into_tvalue().into_sharable();
        let b = self.b.clone().into_tvalue().into_sharable();
        let inputs = tvec!(a, b);
        // Reference result, computed before the model is rewritten in place.
        let reference = TypedRunnableModel::new(&model)?.run(inputs.clone())?.remove(0);
        rewrite_einsums_as_matmul(&mut model)?;
        // The rewrite must have replaced every EinSum node.
        assert!(model.nodes.iter().all(|n| !n.op_is::<EinSum>()));
        let test = TypedRunnableModel::new(&model)?.run(inputs)?.remove(0);
        reference.close_enough(&test, true)?;
        Ok(())
    }
}
/// Property test of the matmul rewrite for `mk,kn->mn`.
#[rustfmt::skip] #[test] fn prop_mk_kn_mn() -> TestCaseResult { test_expr("mk,kn->mn") }
/// Property test of the matmul rewrite for `km,kn->mn` (first input has `k` leading).
#[rustfmt::skip] #[test] fn prop_km_kn_mn() -> TestCaseResult { test_expr("km,kn->mn") }
/// Property test of the matmul rewrite for `mk,nk->mn` (second input has `k` trailing).
#[rustfmt::skip] #[test] fn prop_mk_nk_mn() -> TestCaseResult { test_expr("mk,nk->mn") }
/// Property test of the matmul rewrite for `mk,kn->nm` (output axes in `n`, `m` order).
#[rustfmt::skip] #[test] fn prop_mk_kn_nm() -> TestCaseResult { test_expr("mk,kn->nm") }
/// Property test of the matmul rewrite for `k,kn->mn` (`m` appears only in the output).
#[rustfmt::skip] #[test] fn prop_k_kn_mn() -> TestCaseResult { test_expr("k,kn->mn") }
/// Property test of the matmul rewrite for `mk,k->mn` (`n` appears only in the output).
#[rustfmt::skip] #[test] fn prop_mk_k_mn() -> TestCaseResult { test_expr("mk,k->mn") }
/// Property test of the matmul rewrite for `m,n->mn` (no axis shared between the inputs).
#[rustfmt::skip] #[test] fn prop_m_n_mn() -> TestCaseResult { test_expr("m,n->mn") }
/// Property test of the matmul rewrite for `amk,akn->amn` (axis `a` present in both inputs and the output).
#[rustfmt::skip] #[test] fn prop_amk_akn_amn() -> TestCaseResult { test_expr("amk,akn->amn") }
/// Property test of the matmul rewrite for `mk,akn->amn` (axis `a` present only in the second input and the output).
#[rustfmt::skip] #[test] fn prop_mk_akn_amn() -> TestCaseResult { test_expr("mk,akn->amn") }
/// Property test of the matmul rewrite for `btgi,gih->tgh` (axis `b` of the first input is absent from the output).
#[rustfmt::skip] #[test] fn prop_btgi_gih_tgh() -> TestCaseResult { test_expr("btgi,gih->tgh") }
/// Property test of the matmul rewrite for `tgi,gih->btgh` (axis `b` appears only in the output).
#[rustfmt::skip] #[test] fn prop_tgi_gih_btgh() -> TestCaseResult { test_expr("tgi,gih->btgh") }
/// Fixed case for `k,kn->mn`: all-zero `a` of shape `[2]` and all-zero `b` of shape `[2, 2]`.
#[test]
fn k_kn_mn_0() -> TractResult<()> {
    let a = tensor1(&[0f32, 0f32]);
    let b = tensor2(&[[0f32, 0.], [0., 0.]]);
    let problem = EinSumProblem { expr: "k,kn->mn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `mk,k->mn`: zero tensors of shape `[2, 2]` and `[2]`.
#[test]
fn mk_k_mn_0() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[2, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[2]).unwrap();
    let problem = EinSumProblem { expr: "mk,k->mn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `mk,k->mn` with a size-1 `m`: zero tensors of shape `[1, 2]` and `[2]`.
#[test]
fn mk_k_mn_1() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[1, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[2]).unwrap();
    let problem = EinSumProblem { expr: "mk,k->mn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `mk,kn->nm` (transposed output): zero tensors of shape `[3, 2]` and
/// `[2, 2]`.
///
/// NOTE(review): the expression previously read `"mk,kn->mn"`, which did not match the
/// test name and left the transposed-output layout without a fixed case. It now matches
/// the name; confirm this was the intended case.
#[test]
fn mk_kn_nm_0() -> TractResult<()> {
    EinSumProblem {
        expr: "mk,kn->nm".to_string(),
        a: Tensor::zero::<f32>(&[3, 2]).unwrap(),
        b: Tensor::zero::<f32>(&[2, 2]).unwrap(),
    }
    .check()
}
/// Fixed case for `amk,akn->amn`: zero tensors of shape `[1, 1, 2]` and `[1, 2, 1]`.
#[test]
fn amk_akn_amn_0() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[1, 2, 1]).unwrap();
    let problem = EinSumProblem { expr: "amk,akn->amn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `amk,akn->amn` where axis `a` is 2 in the first input and 1 in the
/// second: zero tensors of shape `[2, 1, 2]` and `[1, 2, 1]`.
#[test]
fn amk_akn_amn_1() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[2, 1, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[1, 2, 1]).unwrap();
    let problem = EinSumProblem { expr: "amk,akn->amn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `amk,akn->amn` where axis `a` is 1 in the first input and 2 in the
/// second: zero tensors of shape `[1, 1, 2]` and `[2, 2, 2]`.
#[test]
fn amk_akn_amn_2() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[2, 2, 2]).unwrap();
    let problem = EinSumProblem { expr: "amk,akn->amn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `amk,akn->amn` with a size-1 `n` and axis `a` of 1 versus 2: zero
/// tensors of shape `[1, 1, 2]` and `[2, 2, 1]`.
#[test]
fn amk_akn_amn_3() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[1, 1, 2]).unwrap();
    let b = Tensor::zero::<f32>(&[2, 2, 1]).unwrap();
    let problem = EinSumProblem { expr: "amk,akn->amn".to_string(), a, b };
    problem.check()
}
/// Fixed case for `km,anbck->bmn`: zero tensors of shape `[2, 1]` and `[1, 1, 1, 1, 2]`.
#[test]
fn km_anbck_bmn_0() -> TractResult<()> {
    let a = Tensor::zero::<f32>(&[2, 1]).unwrap();
    let b = Tensor::zero::<f32>(&[1, 1, 1, 1, 2]).unwrap();
    let problem = EinSumProblem { expr: "km,anbck->bmn".to_string(), a, b };
    problem.check()
}
/// Checks that a quantized einsum is rewritten away by `rewrite_einsums_as_matmul`.
///
/// The einsum takes nine inputs: two `QI8` sources (`a`, `b`), an `i32` `bias` source, and
/// six scalar constants named `a0`/`a_scale`, `b0`/`b_scale`, `c0`/`c_scale`, all taken
/// from the same zero-point/scale pair. Only the absence of `EinSum` nodes after the
/// rewrite is asserted; the model is not run.
#[test]
fn q() -> TractResult<()> {
    let qp = QParams::ZpScale { zero_point: 0, scale: 0.1 };
    let (zero_point, scale) = qp.zp_scale();
    let op = EinSum {
        axes: "mk,kn,m,,,,,,->mn".parse()?,
        operating_dt: i32::datum_type(),
        q_params: Some(DatumType::QI8(qp)),
    };
    let mut model = TypedModelPatch::default();
    let inputs = [
        model.add_source("a", DatumType::QI8(qp).fact([3, 2]))?,
        model.add_source("b", DatumType::QI8(qp).fact([2, 4]))?,
        model.add_source("bias", i32::datum_type().fact([3]))?,
        model.add_const("a0", tensor0(zero_point))?,
        model.add_const("a_scale", tensor0(scale))?,
        model.add_const("b0", tensor0(zero_point))?,
        model.add_const("b_scale", tensor0(scale))?,
        model.add_const("c0", tensor0(zero_point))?,
        model.add_const("c_scale", tensor0(scale))?,
    ];
    let wire = model.wire_node("einsum", op, &inputs)?;
    model.set_output_outlets(&wire)?;
    rewrite_einsums_as_matmul(&mut model)?;
    let einsum_left = model.nodes.iter().any(|node| node.op_is::<EinSum>());
    assert!(!einsum_left);
    Ok(())
}