use super::common::{USizeConvertTo, bytes, iota_slice};
use crate::{
bitmask,
constant::Const,
traits::{ArrayType, BitMaskType, SIMDMask, SIMDVector},
};
/// Verify that `output` is zero everywhere except for a single window holding `target`.
///
/// The window starts at byte `offset` and spans `target.len()` bytes. Three checks run in
/// order: the bytes before the window are all zero, the window equals `target`, and the
/// bytes after the window are all zero.
///
/// `message` is appended to every failure message to describe the scenario under test.
///
/// # Panics
///
/// Panics if any of the three checks fails, or if the window does not fit inside `output`
/// (the slice indexing itself panics in that case).
fn check(output: &[u8], target: &[u8], offset: usize, message: &dyn std::fmt::Display) {
    fn all_zero(bytes: &[u8]) -> bool {
        bytes.iter().all(|&byte| byte == 0)
    }

    // Everything ahead of the window must be untouched.
    let prefix = &output[..offset];
    assert!(
        all_zero(prefix),
        "prefix of {:?} up to {} is not zero -- {}",
        output,
        offset,
        message
    );

    // The window itself must hold exactly the target bytes.
    let end = offset + target.len();
    let window = &output[offset..end];
    assert_eq!(
        window, target,
        "output window from {} not equal to target -- {}",
        offset, message
    );

    // Everything behind the window must be untouched as well.
    let suffix = &output[end..];
    assert!(
        all_zero(suffix),
        "suffix of {:?} starting from {} is not zero -- {}",
        output,
        end,
        message
    );
}
/// Failure context for an unmasked store: wraps the byte offset into the output buffer at
/// which the full vector was written.
struct FullStore(usize);

impl std::fmt::Display for FullStore {
    /// Render the context as a short phrase naming the byte offset.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(byte_offset) = self;
        f.write_fmt(format_args!(
            "full SIMD store at byte offset {}",
            byte_offset
        ))
    }
}
/// Failure context for a predicated (partial) store.
struct PredicatedStore {
    /// Name of the store API being exercised.
    api: &'static str,
    /// The `keep_first` value used for the store.
    keep_first: usize,
    /// Number of bytes subtracted when computing the destination offset.
    sub: usize,
}

impl std::fmt::Display for PredicatedStore {
    /// Render the API name (in backticks) followed by the `keep_first` and `sub` values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            api,
            keep_first,
            sub,
        } = self;
        f.write_fmt(format_args!(
            "`{}` with `keep_first` = {}, `sub` = {}",
            api, keep_first, sub
        ))
    }
}
/// Exercise the unmasked and predicated store methods of the SIMD vector type `V`.
///
/// The vector under test is built from an array filled by `iota_slice` with every lane
/// then incremented by one. Two phases follow:
///
/// 1. **Full stores.** `store_simd` is called at every byte offset in
///    `0..=N * size_of::<T>()` of a zeroed buffer twice the size of the vector. After each
///    store the buffer must contain the lane bytes at that offset and zeros elsewhere.
/// 2. **Predicated stores.** For each `keep_first` in `0..=N + 5` and each `sub` in
///    `0..=size_of::<T>()`, the destination is chosen so that the bytes expected to be
///    written end exactly `sub` bytes before the end of an `(N + 1)`-element buffer.
///    `store_simd_first`, `store_simd_masked_logical` and `store_simd_masked` are each run
///    against a freshly zeroed buffer. The expected contents are the bytes of the first
///    `min(keep_first, N)` lanes, so values of `keep_first` above `N` are expected to
///    behave like `N`.
///
/// # Panics
///
/// Panics through [`check`] whenever a buffer does not match the expected layout.
pub(crate) fn test_store_simd<T, const N: usize, V>(arch: V::Arch)
where
    T: Default
        + std::marker::Copy
        + std::cmp::PartialEq
        + std::fmt::Debug
        + std::ops::AddAssign
        + bytemuck::Pod,
    usize: USizeConvertTo<T>,
    Const<N>: ArrayType<T, Type = [T; N]>,
    bitmask::BitMask<N, V::Arch>: SIMDMask<Arch = V::Arch>,
    V: SIMDVector<Scalar = T, ConstLanes = Const<N>>,
{
    let lane_bytes = std::mem::size_of::<T>();

    // Build the lane values: whatever `iota_slice` produces, each bumped by one.
    // NOTE(review): presumably the bump keeps every lane distinct from the zero fill,
    // assuming `iota_slice` starts counting at zero -- confirm against its definition.
    let mut lanes = [T::default(); N];
    iota_slice(lanes.as_mut_slice());
    for lane in &mut lanes {
        *lane += 1.test_convert();
    }
    let vector = V::from_array(arch, lanes);

    // Phase 1: unmasked stores at every byte offset that keeps N lanes inside the buffer.
    let mut full_buf = vec![0u8; lane_bytes * 2 * N];
    let last_offset = N * lane_bytes;
    for byte_offset in 0..=last_offset {
        full_buf.fill(0);
        // SAFETY: `byte_offset <= N * lane_bytes` and the buffer holds `2 * N * lane_bytes`
        // bytes, so the offset pointer and the N lanes behind it are within the allocation.
        // NOTE(review): assumes `store_simd` writes exactly N lanes and accepts destinations
        // that are not aligned for `T` -- confirm against the trait's safety contract.
        let dst = unsafe { full_buf.as_mut_ptr().add(byte_offset) }.cast::<T>();
        unsafe { vector.store_simd(dst) };
        check(
            &full_buf,
            bytes(&lanes),
            byte_offset,
            &FullStore(byte_offset),
        );
    }

    // Phase 2: predicated stores positioned against the end of the buffer.
    let mut masked_buf = vec![0u8; lane_bytes * (N + 1)];
    // Taken once, before the loops; every destination pointer below derives from it.
    let base = masked_buf.as_mut_ptr();
    for keep_first in 0..=N + 5 {
        // Requests beyond the lane count are expected to store all N lanes.
        let stored = keep_first.min(N);
        let expected = bytes(&lanes[..stored]);

        for sub in 0..=lane_bytes {
            // The expected bytes span `start..lane_bytes * (N + 1) - sub`, i.e. they finish
            // `sub` bytes short of the end of the buffer. Whenever `stored < N`, lanes that
            // are not stored may map past the end of the allocation.
            let start = lane_bytes * (N - stored + 1) - sub;
            // SAFETY: `start <= lane_bytes * (N + 1)`, which is at most one past the end of
            // the allocation, so computing the pointer is in bounds.
            let dst = unsafe { base.add(start) }.cast::<T>();
            let context = |api| PredicatedStore {
                api,
                keep_first,
                sub,
            };

            // NOTE(review): the three stores below rely on the predicated APIs not touching
            // memory for disabled lanes -- confirm against the trait's safety contract.

            masked_buf.fill(0);
            unsafe { vector.store_simd_first(dst, keep_first) };
            check(&masked_buf, expected, start, &context("store_simd_first"));

            masked_buf.fill(0);
            let logical_mask = V::Mask::keep_first(arch, keep_first);
            unsafe { vector.store_simd_masked_logical(dst, logical_mask) };
            check(
                &masked_buf,
                expected,
                start,
                &context("store_simd_masked_logical"),
            );

            masked_buf.fill(0);
            let bit_mask = <Const<N> as BitMaskType<V::Arch>>::Type::keep_first(arch, keep_first);
            unsafe { vector.store_simd_masked(dst, bit_mask) };
            check(&masked_buf, expected, start, &context("store_simd_masked"));
        }
    }
}