use std::fmt;
use std::fmt::Debug;
use num_traits::Zero;
use super::MatMatMulKer;
/// A post-processing operation fused at the tail of a matrix-matrix
/// multiplication, expressed in the accumulator type `TI`.
///
/// Per-row / per-column coefficient vectors are indexed in full
/// output-matrix coordinates; `ScratchSpaceFusedNonLinear::for_tile`
/// slices them down to the kernel tile and lowers each variant to the
/// kernel-level `FusedKerSpec`.
#[derive(PartialEq, Clone)]
pub enum FusedSpec<TI: Copy + Debug> {
/// Elementwise minimum with a scalar (caps each value at the scalar).
Min(TI),
/// Elementwise maximum with a scalar (floors each value at the scalar).
Max(TI),
/// Add the values already present in the C output buffer.
AddC,
/// Multiply every element of row `i` by the `i`-th coefficient.
PerRowMul(Vec<TI>),
/// Add the `i`-th coefficient to every element of row `i`.
PerRowAdd(Vec<TI>),
/// Multiply every element of column `j` by the `j`-th coefficient.
PerColMul(Vec<TI>),
/// Add the `j`-th coefficient to every element of column `j`.
PerColAdd(Vec<TI>),
/// Rank-one update: add `rows[i] * cols[j]` to element `(i, j)`.
AddRowColProducts(Vec<TI>, Vec<TI>),
/// Multiply every element by a scalar.
ScalarMul(TI),
/// Add a scalar to every element.
ScalarAdd(TI),
/// Fixed-point re-quantization: multiply by the scalar then shift right
/// by the given amount, ties rounding towards even. See the
/// `return_q_towards_even` test for the bit-level reference semantics.
QTowardsEven(TI, usize),
/// As `QTowardsEven`, but ties round towards +infinity. See
/// `QTowardsPlusInfProblem::reference` for the reference semantics.
QTowardsPlusInf(TI, usize),
}
impl<TI: Copy + Debug> Debug for FusedSpec<TI> {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
FusedSpec::Min(t) => write!(fmt, "Min({:?})", t),
FusedSpec::Max(t) => write!(fmt, "Max({:?})", t),
FusedSpec::AddC => write!(fmt, "AddC"),
FusedSpec::PerRowMul(_) => write!(fmt, "PerRowMul"),
FusedSpec::PerRowAdd(_) => write!(fmt, "PerRowAdd"),
FusedSpec::PerColMul(_) => write!(fmt, "PerColMul"),
FusedSpec::PerColAdd(_) => write!(fmt, "PerColAdd"),
FusedSpec::AddRowColProducts(_, _) => write!(fmt, "AddRowColProducts"),
FusedSpec::ScalarMul(_) => write!(fmt, "ScalarMul"),
FusedSpec::ScalarAdd(_) => write!(fmt, "ScalarAdd"),
FusedSpec::QTowardsEven(_, _) => write!(fmt, "QTowardsEven"),
FusedSpec::QTowardsPlusInf(_, _) => write!(fmt, "QTowardsPlusInf"),
}
}
}
// Hash is hand-written because FusedSpec only derives PartialEq: TI may be a
// float, which is not Eq, so the derive machinery cannot provide Hash.
impl<TI: Copy + Debug> std::hash::Hash for FusedSpec<TI> {
fn hash<H>(&self, state: &mut H) where H: std::hash::Hasher {
use FusedSpec::*;
// Hash a slice of TI as its length followed by its raw bytes.
// NOTE(review): for float TI this makes 0.0 and -0.0 (PartialEq-equal)
// hash differently, and distinct NaN payloads hash differently —
// presumably acceptable since TI is not Eq anyway; confirm no HashMap
// relies on PartialEq/Hash agreement here.
fn h<TI: Copy, H: std::hash::Hasher>(it: &[TI], state: &mut H) {
// SAFETY: `it` is a valid, initialized slice of Copy TI, so viewing its
// backing memory as `len * size_of::<TI>()` bytes is in-bounds.
// NOTE(review): if TI had padding bytes this would read uninitialized
// memory — assumed not the case for the numeric TI used here.
unsafe {
it.len().hash(state);
let bytes:&[u8] = std::slice::from_raw_parts(it.as_ptr() as _, it.len() * std::mem::size_of::<TI>());
bytes.hash(state)
}
}
// Hash the variant tag first so different variants with identical payloads
// do not collide.
std::mem::discriminant(self).hash(state);
match self {
AddC => (),
Min(a) | Max(a) | ScalarMul(a) | ScalarAdd(a) => {
h(&[*a], state)
},
PerRowMul(a)| PerRowAdd(a)| PerColMul(a)| PerColAdd(a) => {
h(&*a, state)
},
AddRowColProducts(a, b) => {
h(&*a, state);
h(&*b, state);
},
// `b` is the usize shift amount; it goes through the same byte-wise helper.
QTowardsEven(a,b ) | QTowardsPlusInf(a, b) => {
h(&[*a], state);
h(&[*b], state);
}
}
}
}
/// Kernel-level counterpart of `FusedSpec`: a C-compatible tagged union that
/// kernels walk as a `Done`-terminated array (see `for_tile`).
///
/// `repr(C, usize)` pins the layout to a usize discriminant followed by the
/// payload; the `check_non_linear_enum_size` test asserts the whole thing is
/// three machine words for f32. NOTE(review): this layout is presumably
/// relied upon by non-Rust (assembly) kernel implementations — do not
/// reorder variants or change the repr without auditing them.
#[repr(C, usize)]
#[derive(PartialEq, Copy, Clone, Debug)]
pub enum FusedKerSpec<TI: Copy> {
/// Terminator: marks the end of the fused-op list.
Done,
Min(TI),
Max(TI),
AddC,
// The pointer variants reference tiles of at least mr (rows) or nr (cols)
// elements, prepared and kept alive by ScratchSpaceFusedNonLinear.
PerRowMul(*const TI),
PerRowAdd(*const TI),
PerColMul(*const TI),
PerColAdd(*const TI),
AddRowColProducts(*const TI, *const TI),
ScalarMul(TI),
ScalarAdd(TI),
QTowardsEven(TI, usize),
QTowardsPlusInf(TI, usize),
}
/// Scratch storage used to lower `FusedSpec`s into the kernel-ready
/// `FusedKerSpec` list for one tile at a time.
pub struct ScratchSpaceFusedNonLinear<TI: Copy> {
// Kernel-ready list, rebuilt by each `for_tile` call, Done-terminated.
uspecs: Vec<FusedKerSpec<TI>>,
// Zero-padded copies of partial edge-tile vectors. Kept alive (never
// cleared) so the raw pointers stored in `uspecs` stay valid for the
// kernel call; dropping this invalidates previously returned specs.
non_linear_buffers: Vec<Vec<TI>>,
}
impl<TI: Copy> Default for ScratchSpaceFusedNonLinear<TI> {
    /// Fresh scratch space with no specs and no padding buffers.
    fn default() -> ScratchSpaceFusedNonLinear<TI> {
        ScratchSpaceFusedNonLinear { uspecs: Vec::new(), non_linear_buffers: Vec::new() }
    }
}
impl<TI: Copy> ScratchSpaceFusedNonLinear<TI> {
    /// Returns a pointer to `tile` elements of `v` starting at `offset`.
    ///
    /// On a full interior tile this is just `v.as_ptr() + offset`. On a
    /// partial edge tile (fewer than `tile` elements remain past `offset`)
    /// the tail is copied into a zero-padded buffer owned by `self`, so the
    /// kernel can always read a complete tile. Pushing the Vec into
    /// `non_linear_buffers` moves only its header, not its heap allocation,
    /// so the pointer taken before the push stays valid until the scratch
    /// space is dropped.
    ///
    /// As in the original per-variant code, `offset > v.len()` is a caller
    /// bug (the subtraction below panics in debug builds).
    fn buf_ptr(&mut self, v: &[TI], offset: usize, tile: usize) -> *const TI
    where
        TI: Zero,
    {
        let have = v.len() - offset;
        if have < tile {
            let mut buf = vec![TI::zero(); tile];
            buf[..have].copy_from_slice(&v[offset..][..have]);
            let ptr = buf.as_ptr();
            self.non_linear_buffers.push(buf);
            ptr
        } else {
            // SAFETY: have >= tile implies offset + tile <= v.len(), so the
            // offset pointer is inside the slice's allocation.
            unsafe { v.as_ptr().add(offset) }
        }
    }

    /// Lowers high-level `specs` into the kernel-level, `Done`-terminated
    /// `FusedKerSpec` list for the tile at grid coordinates (`down`,
    /// `right`), slicing per-row vectors at `down * mr` and per-column
    /// vectors at `right * nr`.
    ///
    /// The returned pointer aliases `self.uspecs` and is invalidated by the
    /// next call to `for_tile` (or by dropping `self`); hence `unsafe`.
    pub unsafe fn for_tile<TA, TB, TC, K: MatMatMulKer<TA, TB, TC, TI>>(
        &mut self,
        specs: &[FusedSpec<TI>],
        down: usize,
        right: usize,
    ) -> *const FusedKerSpec<TI>
    where
        TA: Copy,
        TB: Copy,
        TC: Copy + Debug,
        TI: Copy + Debug + Zero,
    {
        self.uspecs.clear();
        for spec in specs {
            let s = match spec {
                FusedSpec::Min(m) => FusedKerSpec::Min(*m),
                FusedSpec::Max(m) => FusedKerSpec::Max(*m),
                FusedSpec::AddC => FusedKerSpec::AddC,
                FusedSpec::PerRowMul(v) => {
                    FusedKerSpec::PerRowMul(self.buf_ptr(v, down * K::mr(), K::mr()))
                }
                FusedSpec::PerRowAdd(v) => {
                    FusedKerSpec::PerRowAdd(self.buf_ptr(v, down * K::mr(), K::mr()))
                }
                FusedSpec::PerColMul(v) => {
                    FusedKerSpec::PerColMul(self.buf_ptr(v, right * K::nr(), K::nr()))
                }
                FusedSpec::PerColAdd(v) => {
                    FusedKerSpec::PerColAdd(self.buf_ptr(v, right * K::nr(), K::nr()))
                }
                FusedSpec::AddRowColProducts(rows, cols) => {
                    let row_ptr = self.buf_ptr(rows, down * K::mr(), K::mr());
                    let col_ptr = self.buf_ptr(cols, right * K::nr(), K::nr());
                    FusedKerSpec::AddRowColProducts(row_ptr, col_ptr)
                }
                FusedSpec::ScalarMul(t) => FusedKerSpec::ScalarMul(*t),
                FusedSpec::ScalarAdd(t) => FusedKerSpec::ScalarAdd(*t),
                FusedSpec::QTowardsEven(m, s) => FusedKerSpec::QTowardsEven(*m, *s),
                FusedSpec::QTowardsPlusInf(m, s) => FusedKerSpec::QTowardsPlusInf(*m, *s),
            };
            self.uspecs.push(s);
        }
        self.uspecs.push(FusedKerSpec::Done);
        self.uspecs.as_ptr()
    }
}
#[cfg(test)]
#[macro_use]
pub mod test {
use super::*;
use crate::frame::mmm::storage::*;
use crate::frame::mmm::*;
use num_traits::{AsPrimitive, Bounded, Zero};
use proptest::prelude::*;
use std::fmt;
use std::ops::{Add, Mul, Sub};
// FusedKerSpec must keep the layout the kernels expect: a usize
// discriminant plus two payload words.
#[test]
fn check_non_linear_enum_size() {
    let three_words = 3 * std::mem::size_of::<usize>();
    assert_eq!(std::mem::size_of::<super::FusedKerSpec<f32>>(), three_words)
}
/// Generates a `fuse` test module exercising every non-quantized fused
/// operation for one concrete kernel / type combination.
///
/// `$cond` is evaluated at runtime to gate each test (e.g. CPU feature
/// detection); `$ker` is the kernel type, `$ta`/`$tb`/`$tc`/`$ti` the
/// operand, output and accumulator types.
#[macro_export]
macro_rules! mmm_kernel_fuse_tests {
($cond:expr, $ker:ty, $ta:ty, $tb:ty, $tc:ty, $ti: ty) => {
mod fuse {
#[allow(unused_imports)]
use crate::frame::mmm::fuse::test;
use proptest::prelude::*;
#[test]
fn return_zeros() {
if $cond {
test::return_zeros::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c() {
if $cond {
test::return_c::<$ker, $ta, $tb, $tc, $ti>()
}
}
// Property-based variant of return_c: arbitrary C content must round-trip
// through the kernel (within 1e-7 once widened to f32).
proptest::proptest! {
#[test]
fn return_c_prop(pb in any::<test::ReturnCProblem<$ker, $ta, $tb, $tc, $ti>>()) {
if $cond {
let got = pb.run();
prop_assert!(got.iter().zip(pb.c.iter()).all(|(g,e)| (*g as f32 - *e as f32).abs() < 1e-7),
"got: {:?}\nexpected: {:?}", pb.run(), pb.c)
}
}
}
#[test]
fn return_c_mul_row() {
if $cond {
test::return_c_mul_row::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_add_row() {
if $cond {
test::return_c_add_row::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_mul_col() {
if $cond {
test::return_c_mul_col::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_add_col() {
if $cond {
test::return_c_add_col::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_add_row_col_product() {
if $cond {
test::return_c_add_row_col_product::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_max() {
if $cond {
test::return_c_max::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_min() {
if $cond {
test::return_c_min::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_scalar_mul() {
if $cond {
test::return_c_scalar_mul::<$ker, $ta, $tb, $tc, $ti>()
}
}
#[test]
fn return_c_scalar_add() {
if $cond {
test::return_c_scalar_add::<$ker, $ta, $tb, $tc, $ti>()
}
}
}
};
}
/// Generates a `fuseq` test module exercising the quantization fused ops
/// (QTowardsEven / QTowardsPlusInf) for one concrete kernel / type combo.
/// `$cond` gates the tests at runtime.
#[macro_export]
macro_rules! qmmm_kernel_fuse_tests {
($cond:expr, $ker:ty, $ta:ty, $tb:ty, $tc:ty, $ti: ty) => {
mod fuseq {
#[allow(unused_imports)]
use crate::frame::mmm::fuse::test;
use crate::frame::mmm::fuse::test::QTowardsPlusInfProblem;
use crate::frame::mmm::kernel::MatMatMulKer;
use num_traits::AsPrimitive;
use proptest::prelude::*;
// NOTE(review): ignored — presumably QTowardsEven is not implemented by
// every kernel yet; confirm before un-ignoring.
#[test]
#[ignore]
fn return_q_towards_even() {
if $cond {
test::return_q_towards_even::<$ker, $ta, $tb, $tc, $ti>()
}
}
// Deterministic tile: values centered on zero, clamped to $tc's range.
#[test]
fn return_q_towards_plusinf() {
if $cond {
let len = <$ker>::mr() * <$ker>::nr();
let half_len: $ti = (len / 2).as_();
let v: Vec<$tc> = (0..len)
.map(|f| {
(<usize as AsPrimitive<$ti>>::as_(f) - half_len)
.min(<$tc>::max_value().as_())
.max(<$tc>::min_value().as_())
.as_()
})
.collect();
let pb = QTowardsPlusInfProblem::<$ker, $ta, $tb, $tc, $ti>::new(v);
assert_eq!(pb.run(), pb.reference())
}
}
// Minimal regression case: all zeros except a single trailing 2.
#[test]
fn return_q_towards_plusinf_1() {
if $cond {
let len = <$ker>::mr() * <$ker>::nr();
let mut v = vec!(0; len - 1);
v.push(2);
let pb = QTowardsPlusInfProblem::<$ker, $ta, $tb, $tc, $ti>::new(v);
assert_eq!(pb.run(), pb.reference())
}
}
proptest::proptest! {
#[test]
fn return_q_towards_plusinf_prop(pb in any::<QTowardsPlusInfProblem<$ker, $ta, $tb, $tc, $ti>>()) {
if $cond {
prop_assert_eq!(pb.run(), pb.reference())
}
}
}
}
};
}
/// A `PanelStore::Packed` wrapping a null pointer, standing in for the A or
/// B operand in tests. The tests always run the kernel with k == 0, so the
/// kernel is expected never to dereference these operands.
pub fn null_packed_storage<T: Copy>() -> PanelStore<T> {
PanelStore::Packed { ptr: std::ptr::null::<T>() as _ }
}
/// Wraps `v` as a strided C output panel: row-major with `rsc` elements per
/// row, contiguous columns.
pub fn mmm_stride_storage<T: Copy>(v: &mut [T], rsc: usize) -> PanelStore<T> {
    let item_size = std::mem::size_of::<T>();
    PanelStore::Strides {
        ptr: v.as_mut_ptr(),
        row_byte_stride: (item_size * rsc) as isize,
        col_byte_stride: item_size as isize,
        item_size,
    }
}
/// With a null fused-op list and k == 0, the kernel must overwrite the whole
/// C tile with zeros, regardless of its previous content.
pub fn return_zeros<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + Debug + Bounded + Zero + PartialEq,
    TI: Copy + Debug,
{
    // Poison the tile with max_value so untouched cells are detected.
    let mut output = vec![TC::max_value(); K::mr() * K::nr()];
    let mut store = mmm_stride_storage(&mut output, K::nr());
    let status = K::kernel(&MatMatMulKerSpec {
        a: &null_packed_storage(),
        b: &null_packed_storage(),
        c: &mut store,
        linear: &LinearSpec::k(0),
        non_linear: std::ptr::null(),
    });
    assert_eq!(status, 0);
    let zeros = vec![TC::zero(); K::mr() * K::nr()];
    assert_eq!(output, zeros);
}
/// Runs kernel `K` over a copy of `c` with `ops` as fused post-processing
/// and returns the resulting tile.
///
/// `c` must be exactly one tile (mr * nr values). The kernel is run with
/// k == 0 (no actual multiplication); an `AddC` op is prepended so the
/// result starts from the provided `c` content (return_zeros shows the
/// accumulator is zeroed when k == 0), and a `Done` terminator is appended.
pub fn fused_ops<K, TA, TB, TC, TI>(c: &[TC], ops: &[FusedKerSpec<TI>]) -> Vec<TC>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy,
TB: Copy,
TC: Copy + 'static + PartialEq,
TI: Copy + Debug,
usize: AsPrimitive<TC>,
{
assert!(c.len() == K::mr() * K::nr());
let mut v = c.to_vec();
let mut c = mmm_stride_storage(&mut v, K::nr());
let mut ops = ops.to_vec();
// Load the original C content first, then apply the caller's ops on top.
ops.insert(0, FusedKerSpec::AddC);
ops.push(FusedKerSpec::Done);
let err = K::kernel(&MatMatMulKerSpec {
a: &null_packed_storage(),
b: &null_packed_storage(),
c: &mut c,
linear: &LinearSpec::k(0),
non_linear: ops.as_ptr(),
});
// Kernels return 0 on success.
assert_eq!(err, 0);
v
}
/// With no extra fused op, the implicit AddC must reproduce C exactly.
pub fn return_c<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + Debug + 'static + PartialEq,
    TI: Copy + Debug,
    usize: AsPrimitive<TC>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(&input, &[]);
    assert_eq!(output, input);
}
/// PerRowMul must scale every element of row `i` by the `i`-th coefficient.
pub fn return_c_mul_row<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + Debug + 'static + PartialEq,
    TI: Copy + Add + Mul<Output = TI> + Zero + Debug + fmt::Display + 'static + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let multipliers: Vec<TI> = (0..K::mr()).map(|row| row.as_()).collect();
    let output =
        fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::PerRowMul(multipliers.as_ptr())]);
    let expected: Vec<TC> = (0..output.len())
        .map(|ix| {
            let row = ix / K::nr();
            let cell: TI = ix.as_();
            (cell * multipliers[row]).as_()
        })
        .collect();
    assert_eq!(output, expected);
}
/// PerRowAdd must add the `i`-th coefficient to every element of row `i`.
pub fn return_c_add_row<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static,
    TI: Copy + Add + Mul + Zero + Debug + fmt::Display + PartialEq + 'static + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let addends: Vec<TI> = (0..K::mr()).map(|row| row.as_()).collect();
    let output =
        fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::PerRowAdd(addends.as_ptr())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let row = ix / K::nr();
        let cell: TI = ix.as_();
        got == (cell + addends[row]).as_()
    });
    assert!(all_match);
}
/// PerColMul must scale every element of column `j` by the `j`-th coefficient.
pub fn return_c_mul_col<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + 'static + PartialEq,
    TI: Copy + Add + Mul<Output = TI> + Zero + Debug + fmt::Display + 'static + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let multipliers: Vec<TI> = (0..K::nr()).map(|col| col.as_()).collect();
    let output =
        fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::PerColMul(multipliers.as_ptr())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let col = ix % K::nr();
        let cell: TI = ix.as_();
        got == (cell * multipliers[col]).as_()
    });
    assert!(all_match);
}
/// PerColAdd must add the `j`-th coefficient to every element of column `j`.
pub fn return_c_add_col<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static + Debug,
    TI: Copy
        + Add
        + Mul
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>
        + Debug,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let addends: Vec<TI> = (0..K::nr()).map(|col| col.as_()).collect();
    let output =
        fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::PerColAdd(addends.as_ptr())]);
    let mut expected: Vec<TC> = Vec::with_capacity(output.len());
    for ix in 0..output.len() {
        let col = ix % K::nr();
        let cell: TI = ix.as_();
        expected.push((cell + addends[col]).as_());
    }
    assert_eq!(output, expected);
}
/// AddRowColProducts must add `rows[i] * cols[j]` to element (i, j).
pub fn return_c_add_row_col_product<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + Debug + PartialEq + 'static,
    TI: Copy
        + Add
        + Mul<Output = TI>
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let row_coeffs: Vec<TI> = (0..K::mr()).map(|row| row.as_()).collect();
    let col_coeffs: Vec<TI> = (0..K::nr()).map(|col| col.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(
        &input,
        &[FusedKerSpec::AddRowColProducts(row_coeffs.as_ptr(), col_coeffs.as_ptr())],
    );
    let expected: Vec<TC> = (0..output.len())
        .map(|ix| {
            let (row, col) = (ix / K::nr(), ix % K::nr());
            let cell: TI = ix.as_();
            (cell + col_coeffs[col] * row_coeffs[row]).as_()
        })
        .collect();
    assert_eq!(output, expected);
}
/// ScalarMul must scale every element by the scalar (here 5).
pub fn return_c_scalar_mul<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static + Debug,
    TI: Copy
        + Add
        + Mul<Output = TI>
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::ScalarMul(5.as_())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let cell: TI = ix.as_();
        got == (cell * 5.as_()).as_()
    });
    assert!(all_match);
}
/// Max(5) must floor every element at 5 (elementwise maximum with 5).
pub fn return_c_max<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static,
    TI: Copy
        + Add
        + Mul<Output = TI>
        + std::cmp::PartialOrd
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::Max(5.as_())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let cell: TI = ix.as_();
        let floor: TI = 5.as_();
        got == if cell > floor { cell.as_() } else { floor.as_() }
    });
    assert!(all_match);
}
/// Min(5) must cap every element at 5 (elementwise minimum with 5).
pub fn return_c_min<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static,
    TI: Copy
        + Add
        + Mul<Output = TI>
        + std::cmp::PartialOrd
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::Min(5.as_())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let cell: TI = ix.as_();
        let cap: TI = 5.as_();
        got == if cell < cap { cell.as_() } else { cap.as_() }
    });
    assert!(all_match);
}
/// ScalarAdd must add the scalar (here 5) to every element.
pub fn return_c_scalar_add<K, TA, TB, TC, TI>()
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy,
    TB: Copy,
    TC: Copy + PartialEq + 'static,
    TI: Copy
        + Add
        + Mul<Output = TI>
        + Zero
        + Debug
        + fmt::Display
        + PartialEq
        + 'static
        + AsPrimitive<TC>,
    usize: AsPrimitive<TC> + AsPrimitive<TI>,
{
    let len = K::mr() * K::nr();
    let input: Vec<TC> = (0..len).map(|ix| ix.as_()).collect();
    let output = fused_ops::<K, TA, TB, TC, TI>(&input, &[FusedKerSpec::ScalarAdd(5.as_())]);
    let all_match = output.iter().enumerate().all(|(ix, &got)| {
        let cell: TI = ix.as_();
        got == (cell + 5.as_()).as_()
    });
    assert!(all_match);
}
/// Exercises the QTowardsEven fused op against a bit-level Rust reference.
///
/// The kernel pipeline is ScalarMul(2) then QTowardsEven(1 << 30, 2).
/// NOTE(review): the reference below is stated directly on the *original*
/// input; the exact fixed-point convention of the multiplier (1 << 30) is a
/// kernel contract not visible here — confirm against the kernel docs.
pub fn return_q_towards_even<K, TA, TB, TC, TI>()
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy,
TB: Copy,
TC: Copy
+ PartialEq
+ 'static
+ Bounded
+ Debug
+ Sub<Output = TC>
+ AsPrimitive<TI>
+ Mul<Output = TC>,
TI: Copy
+ Add
+ Sub<Output = TI>
+ Mul<Output = TI>
+ Debug
+ fmt::Display
+ Ord
+ PartialEq
+ 'static
+ AsPrimitive<TC>
+ AsPrimitive<i64>,
usize: AsPrimitive<TC> + AsPrimitive<TI>,
i64: AsPrimitive<TC>,
{
let len = K::mr() * K::nr();
let half_len: TI = (len / 2).as_();
// Test vector centered on zero so both signs are exercised, clamped to
// TC's representable range.
let v: Vec<TC> = (0..len)
.map(|f| {
(<usize as AsPrimitive<TI>>::as_(f) - half_len)
.min(TC::max_value().as_())
.max(TC::min_value().as_())
.as_()
})
.collect();
let found = fused_ops::<K, TA, TB, TC, TI>(
&*v,
&[FusedKerSpec::ScalarMul(2.as_()), FusedKerSpec::QTowardsEven((1 << 30).as_(), 2)],
);
assert!(found.iter().zip(v.iter()).all(|(&found, input)| {
let input: TI = input.as_();
let input: i64 = input.as_();
// First halving: plain truncating shift.
let input = input >> 1;
// Second halving on the magnitude, with round-to-even: nudge up only
// when the dropped bit is set AND the resulting low bit would be odd
// (i.e. the low two bits are 0b11).
let trunc = input.abs();
let nudge = (trunc & 0x3 == 0x3) as i64;
let mut trunc = (trunc + nudge) >> 1;
// Restore the sign (rounding was done on the magnitude).
if input.is_negative() {
trunc = -trunc;
}
trunc.as_() == found
}));
}
/// Proptest harness comparing the QTowardsPlusInf fused op against a pure
/// Rust reference (`reference`) on an arbitrary C tile (`run`).
#[derive(Debug, new)]
pub struct QTowardsPlusInfProblem<K, TA, TB, TC, TI>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy + Debug,
TB: Copy + Debug,
TC: Copy + Debug + 'static,
TI: Copy + Debug,
i64: AsPrimitive<TC>,
{
// Initial content of the C tile; length must be mr * nr.
pub c: Vec<TC>,
// Zero-sized marker carrying the kernel and type parameters.
pub boo: std::marker::PhantomData<(K, TA, TB, TC, TI)>,
}
impl<K, TA, TB, TC, TI> Arbitrary for QTowardsPlusInfProblem<K, TA, TB, TC, TI>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy + Debug,
TB: Copy + Debug,
TC: Copy + Debug + 'static,
TI: Copy + Debug,
i64: AsPrimitive<TC>,
{
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
// Generates exactly one tile (mr * nr values) of small integers in
// [-20, 20) — presumably small so the ScalarMul(2) + rounding pipeline
// cannot overflow TC.
fn arbitrary_with(_p: ()) -> Self::Strategy {
let len = K::mr() * K::nr();
proptest::collection::vec((-20i64..20).prop_map(|i| i.as_()), len..=len)
.prop_map(|c| QTowardsPlusInfProblem { c, boo: std::marker::PhantomData })
.boxed()
}
}
impl<K, TA, TB, TC, TI> QTowardsPlusInfProblem<K, TA, TB, TC, TI>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy + Debug,
TB: Copy + Debug,
TC: Copy + Debug + 'static + AsPrimitive<TI> + PartialEq,
TI: Copy + Debug + 'static + AsPrimitive<i64>,
usize: AsPrimitive<TC> + AsPrimitive<TI>,
i64: AsPrimitive<TC>,
{
/// Pure-Rust reference for `run`: `((x >> 1) + 1) >> 1` on the original
/// value — a truncating halve followed by a halve that rounds ties
/// towards +infinity. NOTE(review): this is stated on the raw input,
/// i.e. the ScalarMul(2) and the (1 << 30) multiplier are assumed to
/// cancel into this expression — confirm against kernel convention.
pub fn reference(&self) -> Vec<TC> {
self.c
.iter()
.map(|input| {
let input: TI = input.as_();
let input: i64 = input.as_();
(((input >> 1) + 1) >> 1).as_()
})
.collect()
}
/// Runs the actual kernel: ScalarMul(2) then QTowardsPlusInf(1 << 30, 2).
pub fn run(&self) -> Vec<TC> {
fused_ops::<K, TA, TB, TC, TI>(
&*self.c,
&[
FusedKerSpec::ScalarMul(2.as_()),
FusedKerSpec::QTowardsPlusInf((1 << 30).as_(), 2),
],
)
}
}
/// Proptest harness checking that an arbitrary C tile survives a kernel run
/// with no fused op beyond the implicit AddC (see `run`).
#[derive(Debug, new)]
pub struct ReturnCProblem<K, TA, TB, TC, TI>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy + Debug,
TB: Copy + Debug,
TC: Copy + Debug + 'static,
TI: Copy + Debug,
{
// Initial content of the C tile; length must be mr * nr.
pub c: Vec<TC>,
// Zero-sized marker carrying the kernel and type parameters.
pub boo: std::marker::PhantomData<(K, TA, TB, TC, TI)>,
}
impl<K, TA, TB, TC, TI> Arbitrary for ReturnCProblem<K, TA, TB, TC, TI>
where
K: MatMatMulKer<TA, TB, TC, TI>,
TA: Copy + Debug,
TB: Copy + Debug,
TC: Copy + Debug + 'static + Arbitrary,
TI: Copy + Debug,
{
type Parameters = ();
type Strategy = BoxedStrategy<Self>;
// Generates exactly one tile (mr * nr values) of fully arbitrary TC.
fn arbitrary_with(_p: ()) -> Self::Strategy {
let len = K::mr() * K::nr();
proptest::collection::vec(any::<TC>(), len..=len)
.prop_map(|c| ReturnCProblem { c, boo: std::marker::PhantomData })
.boxed()
}
}
impl<K, TA, TB, TC, TI> ReturnCProblem<K, TA, TB, TC, TI>
where
    K: MatMatMulKer<TA, TB, TC, TI>,
    TA: Copy + Debug,
    TB: Copy + Debug,
    TC: Copy + Debug + 'static + AsPrimitive<TI> + PartialEq,
    TI: Copy + Debug + 'static + AsPrimitive<i64>,
    usize: AsPrimitive<TC>,
{
    /// Runs the kernel over `c` with no fused op beyond the implicit AddC;
    /// the output should be a plain copy of `c`.
    pub fn run(&self) -> Vec<TC> {
        fused_ops::<K, TA, TB, TC, TI>(self.c.as_slice(), &[])
    }
}
}