use crate::block_quant::PackedBlockQuantFormat;
use crate::mmm::tests::display_error;
use crate::mmm::{AsInputValue, FusedKerSpec, FusedSpec, MatMatMul, MatMatMulKer, OutputStoreKer};
use crate::pack::PackedFormat;
use proptest::collection::vec;
use proptest::prelude::*;
use std::fmt::Debug;
use tract_data::internal::*;
/// Generates the packed-A x packed-B test suite for one kernel and one of its packings.
///
/// * `$ker`: expression evaluating to a reference to the kernel under test; it is passed
///   as-is to `PackedPackedProblem::kernel` / `PackedPackedProblem::frame` / `arbitrary_problem`
///   and queried for `mr()` / `nr()`.
/// * `$packing_id`: identifier used as the name of the generated test module.
/// * `$packing`: value stored in `PackedPackedProblem::packing` (used there as an index into
///   `ker.packings()`).
///
/// The generated module contains two sub-modules: `fuse`, which builds problems with
/// `PackedPackedProblem::kernel` (no explicit m/n), and `frame`, which builds problems with
/// `PackedPackedProblem::frame` (explicit m/n).
#[macro_export]
macro_rules! mmm_packed_packed_tests {
($ker:expr, $packing_id:ident : $packing: expr) => {
mod $packing_id {
use super::*;
#[allow(unused_imports)]
use proptest::prelude::*;
#[allow(unused_imports)]
use tract_data::prelude::f16;
use tract_data::prelude::*;
use tract_itertools::Itertools;
use $crate::frame::mmm::kernel::MatMatMulKer;
#[allow(unused_imports)]
use $crate::frame::mmm::tests::packed_packed::*;
// Kernel-level tests: problems are built without explicit m/n, so a single
// `mr x nr` tile is computed.
mod fuse {
use super::*;
proptest::proptest! {
// Property test over random problems with m = mr and n = nr (frame_test = false).
#[test]
fn prop(pb in arbitrary_problem(false, $ker, $packing)) {
pb.check().unwrap()
}
}
// Helper: `a` and `b` are flat row-major buffers; k is deduced from their lengths.
fn t(a: impl Into<Vec<f32>>, b: impl Into<Vec<f32>>) -> TractResult<()> {
PackedPackedProblem::kernel($ker, $packing, a, b).check()
}
// All-ones inputs, k = 1.
#[test]
fn packed_packed_1() -> TractResult<()> {
t(vec![1f32; $ker.mr()], vec![1f32; $ker.nr()])
}
// All-ones inputs, k = 2.
#[test]
fn packed_packed_2() -> TractResult<()> {
t(vec![1f32; $ker.mr() * 2], vec![1f32; $ker.nr() * 2])
}
// All-ones inputs, k = 13.
#[test]
fn packed_packed_13() -> TractResult<()> {
t(vec![1f32; $ker.mr() * 13], vec![1f32; $ker.nr() * 13])
}
// A holds 1, 2, ..., mr (k = 1) and B is all ones.
#[test]
fn packed_packed_a_scale() -> TractResult<()> {
t((1..=$ker.mr() as i64).map(|x| x as f32).collect_vec(), vec![1f32; $ker.nr()])
}
// A holds 1, 2, ..., 2 * mr (k = 2) and B is all ones.
#[test]
fn packed_packed_a_scale_times_2() -> TractResult<()> {
t(
(1..=2 * $ker.mr() as i64).map(|x| x as f32).collect_vec(),
vec![1f32; $ker.nr() * 2],
)
}
// Empty inputs, i.e. k = 0.
#[test]
fn packed_packed_empty() -> TractResult<()> {
t(vec![0f32; 0], vec![0f32; 0])
}
// The `bug_N` cases below are presumably regression inputs from past failures
// (NOTE(review): inferred from the names only).
// All-zero inputs, k = 1.
#[test]
fn packed_packed_bug_1() -> TractResult<()> {
t(vec![0f32; $ker.mr()], vec![0f32; $ker.nr()])
}
// Single non-zero value at index 0 of both A and B, k = 1.
#[test]
fn packed_packed_bug_2() -> TractResult<()> {
let mut a = vec![0f32; $ker.mr()];
a[0] = 1.;
let mut b = vec![0f32; $ker.nr()];
b[0] = 1.;
t(a, b)
}
// k = 2 with a few specific non-zero values; skipped (trivially Ok) when mr < 4.
#[test]
fn packed_packed_bug_3() -> TractResult<()> {
if $ker.mr() >= 4 {
let mut a = vec![0f32; 2 * $ker.mr()];
let mut b = vec![0f32; 2 * $ker.nr()];
a[2] = -0.7548828f32;
a[3] = 0.23547363f32;
b[2 * $ker.nr() - 1] = 0.93603516;
t(a, b)?;
}
Ok(())
}
// Non-zero value at index 16 of A; only meaningful when mr > 16 so that the index exists.
#[test]
fn packed_packed_bug_4() -> TractResult<()> {
if $ker.mr() > 16 {
let mut a = vec![0f32; $ker.mr()];
let mut b = vec![0f32; $ker.nr()];
a[16] = 1.;
b[0] = 1.;
t(a, b)?;
}
Ok(())
}
}
// Frame-level tests: problems carry explicit m and n.
mod frame {
use super::*;
proptest::proptest! {
// Property test over random problems with m in 1..3*mr and n in 1..3*nr (frame_test = true).
#[test]
fn prop(pb in arbitrary_problem(true, $ker, $packing)) {
pb.check().unwrap()
}
}
// Helper: builds an m x n frame problem from flat row-major f32 buffers and checks it.
fn t(
m: usize,
n: usize,
a: impl Into<Vec<f32>>,
b: impl Into<Vec<f32>>,
) -> TractResult<()> {
PackedPackedProblem::frame($ker, $packing, m, n, a, b).check()
}
// Helper: same as `t`, but takes integer data and converts each value to f32.
fn ti(
m: usize,
n: usize,
a: impl Into<Vec<i32>>,
b: impl Into<Vec<i32>>,
) -> TractResult<()> {
let a = a.into().into_iter().map(|i| i as f32).collect_vec();
let b = b.into().into_iter().map(|i| i as f32).collect_vec();
t(m, n, a, b)
}
// m = 1, n = 2, k = 1, all zeros.
#[test]
fn trivial_1x2() -> TractResult<()> {
ti(1, 2, [0], [0, 0])
}
// One full tile (m = mr, n = nr) with k = 0.
#[test]
fn packed_packed_empty() -> TractResult<()> {
t($ker.mr(), $ker.nr(), [], [])
}
// Two tiles in each direction with k = 0.
#[test]
fn packed_packed_empty_2() -> TractResult<()> {
t(2 * $ker.mr(), 2 * $ker.nr(), [], [])
}
// m = 3, n = 2, k = 4 with fixed integer data.
#[test]
fn mat_mul_1() -> TractResult<()> {
ti(3, 2, [-3, 3, 5, -5, 6, 0, -6, -5, 0, 0, 9, 7], [-8, 5, 5, -3, 5, 7, -8, -1])
}
// m = 1, n = 3, k = 2 with fixed integer data.
#[test]
fn mat_mul_2() -> TractResult<()> {
ti(1, 3, [122, 82], [0, 0, 37, 0, 0, 57])
}
}
}
};
}
/// One `A x B` matrix product test case for the kernel `K`, with both operands supplied as
/// flat `f32` buffers that get packed before being handed to the kernel.
#[derive(Debug, new)]
pub struct PackedPackedProblem<K: MatMatMulKer> {
    /// `Some((m, n))` to run through the kernel's `run` entry point with that output shape;
    /// `None` to call the kernel function directly, in which case `m`/`n` default to the
    /// kernel's `mr()`/`nr()`.
    pub frame_test: Option<(usize, usize)>,
    /// Kernel under test.
    pub ker: K,
    /// Index into `ker.packings()` selecting the (A, B) packing pair to exercise.
    pub packing: usize,
    /// Left operand, read as a row-major `m x k` matrix.
    pub a: Vec<f32>,
    /// Right operand, read as a row-major `k x n` matrix.
    pub b: Vec<f32>,
}
/// Proptest strategy producing random [`PackedPackedProblem`]s for `ker`.
///
/// * `frame_test`: when `true`, `m` is drawn from `1..3 * mr` and `n` from `1..3 * nr` and the
///   problem records `Some((m, n))`; when `false`, `m` and `n` are pinned to `mr` and `nr` and
///   the problem records `None`.
/// * `ker`: kernel the problems are generated for; it is cloned into every problem.
/// * `packing`: copied verbatim into every generated problem.
///
/// `k` is drawn from `0..40`. Values are drawn from `-5.0..5.0` when the kernel's internal
/// type is an integer type, and from `-1.0..1.0` otherwise.
pub fn arbitrary_problem<K: MatMatMulKer>(
    frame_test: bool,
    ker: &K,
    packing: usize,
) -> BoxedStrategy<PackedPackedProblem<K>> {
    let mr = ker.mr();
    let nr = ker.nr();
    let magnitude = if ker.internal_type().is_integer() { 5f32 } else { 1f32 };
    let values = -magnitude..magnitude;
    let (rows, cols) = match frame_test {
        true => (1usize..3 * mr, 1usize..3 * nr),
        false => (mr..mr + 1, nr..nr + 1),
    };
    let owned_ker = ker.clone();

    // First pick the dimensions, then generate buffers of exactly k*m and k*n values.
    let dims = (rows, 0usize..40, cols);
    let inputs = dims.prop_flat_map(move |(m, k, n)| {
        let lhs = vec(values.clone(), k * m);
        let rhs = vec(values.clone(), k * n);
        (lhs, rhs, Just((m, n)))
    });

    inputs
        .prop_map(move |(mut a, mut b, mn)| {
            // Both buffers are reversed after generation.
            // NOTE(review): presumably this steers which end of the data shrinks first; confirm.
            a.reverse();
            b.reverse();
            PackedPackedProblem {
                frame_test: frame_test.then_some(mn),
                ker: owned_ker.clone(),
                packing,
                a,
                b,
            }
        })
        .boxed()
}
impl<K: MatMatMulKer> PackedPackedProblem<K> {
    /// Builds a problem that calls the kernel function directly (`frame_test` is `None`), so
    /// `m` and `n` are the kernel's `mr()` and `nr()`.
    ///
    /// `a` and `b` are flat row-major buffers; `k` is deduced from their lengths (see
    /// [`Self::mkn`]).
    pub fn kernel(
        ker: &K,
        packing: usize,
        a: impl Into<Vec<f32>>,
        b: impl Into<Vec<f32>>,
    ) -> PackedPackedProblem<K> {
        PackedPackedProblem {
            frame_test: None,
            ker: ker.clone(),
            packing,
            a: a.into(),
            b: b.into(),
        }
    }

    /// Builds a problem with an explicit `m x n` output shape that goes through the kernel's
    /// `run` entry point (`frame_test` is `Some((m, n))`).
    ///
    /// `a` and `b` are flat row-major buffers; `k` is deduced from their lengths (see
    /// [`Self::mkn`]).
    pub fn frame(
        ker: &K,
        packing: usize,
        m: usize,
        n: usize,
        a: impl Into<Vec<f32>>,
        b: impl Into<Vec<f32>>,
    ) -> PackedPackedProblem<K> {
        PackedPackedProblem {
            frame_test: Some((m, n)),
            ker: ker.clone(),
            packing,
            a: a.into(),
            b: b.into(),
        }
    }

    /// Returns `(m, k, n)` for this problem.
    ///
    /// `m` and `n` come from `frame_test`, falling back to the kernel's `mr()`/`nr()`; `k` is
    /// deduced from the buffer lengths.
    ///
    /// # Panics
    ///
    /// Panics if `m` or `n` is zero, if `a.len()` is not a multiple of `m`, if `b.len()` is not
    /// a multiple of `n`, or if the two buffers disagree on `k`.
    pub fn mkn(&self) -> (usize, usize, usize) {
        let (m, n) = self.frame_test.unwrap_or((self.ker.mr(), self.ker.nr()));
        assert!(m != 0 && n != 0);
        // Reject ragged buffers explicitly: a truncating division below would otherwise
        // silently ignore the trailing values and test something other than what was given.
        assert_eq!(
            self.a.len() % m,
            0,
            "a holds {} values, which does not split evenly into m={} rows",
            self.a.len(),
            m
        );
        assert_eq!(
            self.b.len() % n,
            0,
            "b holds {} values, which does not split evenly into n={} columns",
            self.b.len(),
            n
        );
        let k = self.a.len() / m;
        assert_eq!(self.b.len() / n, k);
        (m, k, n)
    }

    /// Copies the row-major `rows x cols` matrix stored in `src` into the top-left corner of
    /// the 2-D `f32` tensor `dst`, leaving every other element of `dst` untouched.
    fn copy_row_major(
        dst: &mut Tensor,
        rows: usize,
        cols: usize,
        src: &[f32],
    ) -> TractResult<()> {
        // Nothing to copy: do not even request a view on the (possibly empty) tensor.
        if rows == 0 || cols == 0 {
            return Ok(());
        }
        // The mutable view is created once here rather than once per element.
        let mut plain = dst.try_as_plain_mut()?;
        let mut view = plain.to_array_view_mut::<f32>()?;
        for row in 0..rows {
            for col in 0..cols {
                view[[row, col]] = src[col + cols * row];
            }
        }
        Ok(())
    }

    /// Materialises `a` as an `[m, k_aligned]` tensor and `b` as a `[k_aligned, n]` tensor,
    /// where `k_aligned` is `k` rounded up to the A packing's `k_alignment()`; the padding
    /// is zero-filled.
    ///
    /// When a packing is a `PackedFormat`, the matching tensor is cast to that format's
    /// datum type.
    ///
    /// # Panics
    ///
    /// Panics if the B packing has a `k_alignment()` other than 1, or under the conditions
    /// listed on [`Self::mkn`].
    pub fn padded_inputs(&self) -> TractResult<(Tensor, Tensor)> {
        let (pack_a, pack_b) = &self.ker.packings()[self.packing];
        assert!(pack_b.k_alignment() == 1);
        let (m, k, n) = self.mkn();
        let k_aligned = k.next_multiple_of(pack_a.k_alignment());

        let mut a = Tensor::zero::<f32>(&[m, k_aligned])?;
        Self::copy_row_major(&mut a, m, k, &self.a)?;
        if let Some(pf) = pack_a.downcast_ref::<PackedFormat>() {
            a = a.cast_to_dt(pf.dt)?.into_owned();
        }

        let mut b = Tensor::zero::<f32>(&[k_aligned, n])?;
        Self::copy_row_major(&mut b, k, n, &self.b)?;
        if let Some(pf) = pack_b.downcast_ref::<PackedFormat>() {
            b = b.cast_to_dt(pf.dt)?.into_owned();
        }
        Ok((a, b))
    }

    /// Computes the expected `[m, n]` product with a plain triple loop, accumulating in the
    /// kernel's accumulator type `K::Acc`.
    ///
    /// When the A packing is a `PackedBlockQuantFormat`, A first goes through that format's
    /// `simulate_precision_loss`.
    pub fn reference(&self) -> TractResult<Tensor> {
        let (m, k, n) = self.mkn();
        let pack_a = &self.ker.packings()[self.packing].0;
        let (mut a, b) = self.padded_inputs()?;
        let k_aligned = k.next_multiple_of(pack_a.k_alignment());
        if let Some(pbqf) = pack_a.downcast_ref::<PackedBlockQuantFormat>() {
            a = pbqf.simulate_precision_loss(a, 1)?;
        }
        let mut c = Tensor::zero::<K::Acc>(&[m, n])?;
        let a = a.cast_to::<K::Acc>()?;
        let a = a.try_as_plain()?.as_slice::<K::Acc>()?;
        let b = b.cast_to::<K::Acc>()?;
        let b = b.try_as_plain()?.as_slice::<K::Acc>()?;
        let mut c_plain = c.try_as_plain_mut()?;
        let mut view = c_plain.to_array_view_mut::<K::Acc>()?.into_dimensionality()?;
        for ix_m in 0..m {
            for ix_n in 0..n {
                for ix_k in 0..k {
                    // A is [m, k_aligned] (row stride k_aligned), B is [k_aligned, n].
                    let lhs = a[ix_k + k_aligned * ix_m];
                    let rhs = b[ix_n + n * ix_k];
                    view[(ix_m, ix_n)] += lhs * rhs;
                }
            }
        }
        Ok(c)
    }

    /// Packs the inputs, runs the kernel and returns the `[m, n]` result tensor, typed with
    /// the kernel's internal type.
    ///
    /// With `frame_test` set, the product goes through `ker.run` as an `AddMatMul` followed
    /// by a `Store`. Otherwise the kernel function is called directly with
    /// `Clear, AddMatMul, Store, Done` on panel 0 of each packed input.
    ///
    /// # Panics
    ///
    /// Panics if the B packing has a `k_alignment()` other than 1, if the direct kernel call
    /// returns a non-zero status, or under the conditions listed on [`Self::mkn`].
    pub fn run(&self) -> TractResult<Tensor> {
        let (m, k, n) = self.mkn();
        let (pack_a, pack_b) = &self.ker.packings()[self.packing];
        assert!(pack_b.k_alignment() == 1);
        let k_aligned = k.next_multiple_of(pack_a.k_alignment());
        let (a, b) = self.padded_inputs()?;
        let pa = pack_a.prepare_one(&a, 1, 0)?;
        let pb = pack_b.prepare_one(&b, 0, 1)?;
        // SAFETY: NOTE(review): the tensor is left uninitialised on the assumption that the
        // kernel's store writes all m * n elements before anything reads them; confirm
        // against the kernel contract.
        let mut v = unsafe { Tensor::uninitialized_dt(self.ker.internal_type(), &[m, n])? };
        let item_size = self.ker.internal_type().size_of();
        if self.frame_test.is_some() {
            // SAFETY: NOTE(review): `v` was allocated above as `[m, n]` of the kernel's internal
            // type and outlives the call; the exact requirements of `wrap`/`run` are declared
            // outside this file; confirm.
            unsafe {
                let c = self.ker.c_view(Some(0), Some(1)).wrap(&v.view_mut());
                let ops = tvec!(
                    FusedSpec::AddMatMul {
                        a: AsInputValue::Borrowed(&*pa),
                        b: AsInputValue::Borrowed(&*pb),
                        packing: self.packing
                    },
                    FusedSpec::Store(c)
                );
                self.ker.run(m, n, &ops)?;
            }
        } else {
            // Row-major output tile: one row is `nr` items wide.
            let c = OutputStoreKer {
                ptr: v.as_bytes_mut().as_mut_ptr(),
                row_byte_stride: (item_size * self.ker.nr()) as isize,
                col_byte_stride: item_size as isize,
                item_size,
            };
            let non_linear_ops = tvec!(
                FusedKerSpec::Clear,
                FusedKerSpec::AddMatMul {
                    k: k_aligned,
                    pa: pa.panel_bytes(0, None)?,
                    pb: pb.panel_bytes(0, None)?,
                    packing: self.packing
                },
                FusedKerSpec::Store(c),
                FusedKerSpec::Done
            );
            let err = self.ker.kernel(&non_linear_ops);
            assert_eq!(err, 0);
        }
        Ok(v)
    }

    /// Runs the kernel and compares its output with [`Self::reference`].
    ///
    /// Returns `Ok(())` without running anything when the kernel reports it is not supported
    /// on this machine. The comparison uses `Approximation::SuperApproximate` when the
    /// accumulator type is `f16` and `Approximation::Approximate` otherwise. On mismatch the
    /// two results are passed to `display_error` before the error is returned.
    pub fn check(&self) -> TractResult<()> {
        if !self.ker.is_supported_here() {
            return Ok(());
        }
        let expected = self.reference()?;
        let found = self.run()?;
        let app = if K::Acc::datum_type() == f16::datum_type() {
            Approximation::SuperApproximate
        } else {
            Approximation::Approximate
        };
        let result = found.close_enough(&expected, app);
        if result.is_err() {
            let exp = expected.try_as_plain()?.as_slice::<K::Acc>()?;
            let found = found.try_as_plain()?.as_slice::<K::Acc>()?;
            let (m, _, n) = self.mkn();
            display_error(found, exp, m, n);
        }
        result
    }
}