/// Scalar reference model for bit shifts where the shift amount has the same
/// type as the shifted value.
///
/// The integer implementations in this file define a result for every `rhs`,
/// including amounts that are negative or at least the bit width of the type.
pub(crate) trait ReferenceShifts: Copy {
/// Reference result of shifting `self` right by `rhs` bits.
fn expected_shr_(self, rhs: Self) -> Self;
/// Reference result of shifting `self` left by `rhs` bits.
fn expected_shl_(self, rhs: Self) -> Self;
}
/// Implements `ReferenceShifts` for an unsigned integer type.
///
/// A shift amount greater than or equal to the bit width of the type yields
/// `0` in both directions; any smaller amount is an ordinary logical shift.
macro_rules! impl_shifts_unsigned {
    ($type:ty) => {
        impl ReferenceShifts for $type {
            /// Logical right shift; `0` once `rhs` reaches the bit width.
            #[inline(always)]
            fn expected_shr_(self, rhs: Self) -> Self {
                // Compare in the operand's own type. Casting `rhs` to `usize`
                // would truncate a large `u64` amount on targets with a
                // narrower `usize`, letting it slip past the guard.
                const WIDTH: $type = <$type>::BITS as $type;
                if rhs >= WIDTH {
                    0
                } else {
                    self >> rhs
                }
            }

            /// Left shift; `0` once `rhs` reaches the bit width.
            #[inline(always)]
            fn expected_shl_(self, rhs: Self) -> Self {
                // Same-width comparison for the same reason as in
                // `expected_shr_`.
                const WIDTH: $type = <$type>::BITS as $type;
                if rhs >= WIDTH {
                    0
                } else {
                    self << rhs
                }
            }
        }
    };
}
/// Implements `ReferenceShifts` for a signed integer type.
///
/// A shift amount is "in range" when it lies in `0..BITS`. Out-of-range
/// amounts (negative, or at least the bit width) produce a sign-filled value
/// for right shifts and `0` for left shifts.
macro_rules! impl_shifts_signed {
    ($type:ty) => {
        impl ReferenceShifts for $type {
            /// Arithmetic right shift. Out-of-range amounts give `-1` for a
            /// negative `self` and `0` otherwise.
            #[inline(always)]
            fn expected_shr_(self, rhs: Self) -> Self {
                let in_range = (0..(<$type>::BITS as $type)).contains(&rhs);
                if in_range {
                    return self >> rhs;
                }
                // Every bit has been replaced by a copy of the sign bit.
                if self < 0 {
                    -1
                } else {
                    0
                }
            }

            /// Left shift. Out-of-range amounts give `0`.
            #[inline(always)]
            fn expected_shl_(self, rhs: Self) -> Self {
                let in_range = (0..(<$type>::BITS as $type)).contains(&rhs);
                if in_range {
                    self << rhs
                } else {
                    0
                }
            }
        }
    };
}
// Instantiate the reference shift semantics for the 8- to 64-bit unsigned and
// signed integer types.
impl_shifts_unsigned!(u8);
impl_shifts_unsigned!(u16);
impl_shifts_unsigned!(u32);
impl_shifts_unsigned!(u64);
impl_shifts_signed!(i8);
impl_shifts_signed!(i16);
impl_shifts_signed!(i32);
impl_shifts_signed!(i64);
/// Scalar reference model for element-wise arithmetic.
///
/// In this file the integer implementations use wrapping arithmetic and the
/// floating-point implementations use the native operators plus `mul_add`.
pub(crate) trait ReferenceScalarOps: Copy {
/// Reference result of `self + rhs`.
fn expected_add_(self, rhs: Self) -> Self;
/// Reference result of `self - rhs`.
fn expected_sub_(self, rhs: Self) -> Self;
/// Reference result of `self * rhs`.
fn expected_mul_(self, rhs: Self) -> Self;
/// Reference result of the multiply-add `self * rhs + acc`.
fn expected_fma_(self, rhs: Self, acc: Self) -> Self;
/// Reference result of the maximum of `self` and `rhs`.
fn expected_max_(self, rhs: Self) -> Self;
/// Reference result of the minimum of `self` and `rhs`.
fn expected_min_(self, rhs: Self) -> Self;
}
/// Implements `ReferenceScalarOps` for an integer type.
///
/// Addition, subtraction, multiplication and multiply-add all wrap around on
/// overflow instead of panicking; `max`/`min` use the type's total order.
macro_rules! impl_expected_ops_for_integers {
    ($type:ty) => {
        impl ReferenceScalarOps for $type {
            /// Wrapping sum.
            #[inline(always)]
            fn expected_add_(self, rhs: Self) -> Self {
                <$type>::wrapping_add(self, rhs)
            }

            /// Wrapping difference.
            #[inline(always)]
            fn expected_sub_(self, rhs: Self) -> Self {
                <$type>::wrapping_sub(self, rhs)
            }

            /// Wrapping product.
            #[inline(always)]
            fn expected_mul_(self, rhs: Self) -> Self {
                <$type>::wrapping_mul(self, rhs)
            }

            /// Wrapping product followed by a wrapping sum with `acc`.
            #[inline(always)]
            fn expected_fma_(self, rhs: Self, acc: Self) -> Self {
                let product = <$type>::wrapping_mul(self, rhs);
                <$type>::wrapping_add(product, acc)
            }

            /// Larger of the two operands.
            #[inline(always)]
            fn expected_max_(self, rhs: Self) -> Self {
                Ord::max(self, rhs)
            }

            /// Smaller of the two operands.
            #[inline(always)]
            fn expected_min_(self, rhs: Self) -> Self {
                Ord::min(self, rhs)
            }
        }
    };
}
/// Implements `ReferenceScalarOps` for a floating-point type using the
/// native operators and the standard library's `mul_add`, `max` and `min`.
macro_rules! impl_expected_ops_for_floats {
    ($type:ty) => {
        impl ReferenceScalarOps for $type {
            /// Native floating-point sum.
            #[inline(always)]
            fn expected_add_(self, rhs: Self) -> Self {
                let sum = self + rhs;
                sum
            }

            /// Native floating-point difference.
            #[inline(always)]
            fn expected_sub_(self, rhs: Self) -> Self {
                let difference = self - rhs;
                difference
            }

            /// Native floating-point product.
            #[inline(always)]
            fn expected_mul_(self, rhs: Self) -> Self {
                let product = self * rhs;
                product
            }

            /// `self * rhs + acc` computed through `mul_add`.
            #[inline(always)]
            fn expected_fma_(self, rhs: Self, acc: Self) -> Self {
                <$type>::mul_add(self, rhs, acc)
            }

            /// Maximum as defined by the float type's own `max`.
            #[inline(always)]
            fn expected_max_(self, rhs: Self) -> Self {
                <$type>::max(self, rhs)
            }

            /// Minimum as defined by the float type's own `min`.
            #[inline(always)]
            fn expected_min_(self, rhs: Self) -> Self {
                <$type>::min(self, rhs)
            }
        }
    };
}
// Instantiate the reference arithmetic for the 8- to 64-bit integer types
// (wrapping) and for both floating-point types (native).
impl_expected_ops_for_integers!(u8);
impl_expected_ops_for_integers!(u16);
impl_expected_ops_for_integers!(u32);
impl_expected_ops_for_integers!(u64);
impl_expected_ops_for_integers!(i8);
impl_expected_ops_for_integers!(i16);
impl_expected_ops_for_integers!(i32);
impl_expected_ops_for_integers!(i64);
impl_expected_ops_for_floats!(f32);
impl_expected_ops_for_floats!(f64);
/// Scalar reference model for absolute value.
pub(crate) trait ReferenceAbs: Copy {
/// Reference absolute value of `self`.
///
/// The signed-integer implementations in this file return `MIN` unchanged
/// rather than overflowing.
fn expected_abs_(self) -> Self;
}
/// Implements `ReferenceAbs` for a comma-separated list of types.
///
/// * `impl_abs!(integer, T, ...)` — signed integers. The absolute value wraps,
///   so `MIN` maps to itself instead of overflowing.
/// * `impl_abs!(T, ...)` — any other type providing an inherent `abs` (the
///   floating-point types in this file).
///
/// A trailing comma is accepted in both forms.
macro_rules! impl_abs {
    (integer, $($T:ty),* $(,)?) => {
        $(
            impl ReferenceAbs for $T {
                fn expected_abs_(self) -> Self {
                    // `wrapping_abs` leaves `MIN` as `MIN`; every other value
                    // gets its ordinary absolute value.
                    self.wrapping_abs()
                }
            }
        )*
    };
    ($($T:ty),* $(,)?) => {
        $(
            impl ReferenceAbs for $T {
                fn expected_abs_(self) -> Self {
                    self.abs()
                }
            }
        )*
    };
}
// Signed integers take the `integer` arm (where `MIN` maps to itself); floats
// use the plain `abs` arm.
impl_abs!(integer, i8, i16, i32, i64);
impl_abs!(f32, f64);
/// Scalar reference model for converting a value into the type `To`.
pub(crate) trait ReferenceCast<To> {
/// Converts `self` into `To`, consuming `self`.
fn reference_cast(self) -> To;
}
/// Reference conversion from `half::f16` to `f32`.
impl ReferenceCast<f32> for half::f16 {
    /// Converts through `to_f32_const` when running under Miri and through
    /// the `From<f16> for f32` conversion otherwise.
    ///
    /// NOTE(review): presumably the Miri path exists to avoid platform
    /// intrinsics that Miri cannot execute — confirm against the `half` crate.
    #[cfg_attr(not(miri), inline(always))]
    fn reference_cast(self) -> f32 {
        if cfg!(miri) {
            self.to_f32_const()
        } else {
            f32::from(self)
        }
    }
}
/// Reference conversion from `f32` to `half::f16`.
impl ReferenceCast<half::f16> for f32 {
    /// Converts through `f16::from_f32_const` when running under Miri and
    /// through `f16::from_f32` otherwise.
    ///
    /// NOTE(review): presumably the Miri path exists to avoid platform
    /// intrinsics that Miri cannot execute — confirm against the `half` crate.
    #[cfg_attr(not(miri), inline(always))]
    fn reference_cast(self) -> half::f16 {
        if cfg!(miri) {
            half::f16::from_f32_const(self)
        } else {
            half::f16::from_f32(self)
        }
    }
}
/// Reference conversion from `i32` to `f32` using a primitive `as` cast.
impl ReferenceCast<f32> for i32 {
fn reference_cast(self) -> f32 {
// Values that `f32` cannot represent exactly are rounded by the cast.
self as f32
}
}
/// Reference conversion from `i16` to `u8` using a primitive `as` cast.
impl ReferenceCast<u8> for i16 {
fn reference_cast(self) -> u8 {
// Truncating cast: only the low 8 bits of `self` are kept.
self as u8
}
}
/// Reference conversion from `i16` to `i8` using a primitive `as` cast.
impl ReferenceCast<i8> for i16 {
fn reference_cast(self) -> i8 {
// Truncating cast: the low 8 bits of `self` are reinterpreted as `i8`.
self as i8
}
}
#[inline(always)]
pub fn cast_f16_to_f32(x: half::f16) -> f32 {
x.reference_cast()
}
#[inline(always)]
pub fn cast_f32_to_f16(x: f32) -> half::f16 {
x.reference_cast()
}
/// Reduction of a fixed-size collection to a single scalar with a binary
/// operator.
///
/// The array implementations in this file combine element `i` with element
/// `i + len / 2`, halving the array until one or two elements remain.
pub(crate) trait TreeReduce {
/// Element type produced by the reduction.
type Scalar: Copy;
/// Reduces `self` to one `Scalar` by repeatedly applying `op`.
fn tree_reduce<F>(self, op: F) -> Self::Scalar
where
F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar;
}
/// Base case: a one-element array reduces to that element.
impl<T> TreeReduce for [T; 1]
where
    T: Copy,
{
    type Scalar = T;

    /// Returns the sole element; the operator is never invoked.
    #[inline(always)]
    fn tree_reduce<F>(self, _op: F) -> Self::Scalar
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
    {
        let [only] = self;
        only
    }
}
/// Base case: a two-element array reduces with one application of the operator.
impl<T> TreeReduce for [T; 2]
where
    T: Copy,
{
    type Scalar = T;

    /// Returns `op(first, second)`.
    #[inline(always)]
    fn tree_reduce<F>(self, op: F) -> Self::Scalar
    where
        F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
    {
        let [first, second] = self;
        op(first, second)
    }
}
/// Implements `TreeReduce` for `[T; $N]` by folding the upper half of the
/// array onto the lower half and recursing on the half-sized result.
///
/// `[T; $N / 2]` must itself implement `TreeReduce`.
macro_rules! impl_tree_reduce {
    ($N:literal) => {
        impl<T> TreeReduce for [T; $N]
        where
            T: Copy,
        {
            type Scalar = T;

            /// Combines lane `i` with lane `i + $N / 2` for each `i` in the
            /// lower half, then reduces the resulting half-width array.
            #[inline(always)]
            fn tree_reduce<F>(self, op: F) -> Self::Scalar
            where
                F: Fn(Self::Scalar, Self::Scalar) -> Self::Scalar,
            {
                const HALF: usize = $N / 2;
                let (front, back) = self.split_at(HALF);
                let folded: [T; HALF] =
                    core::array::from_fn(|lane| op(front[lane], back[lane]));
                folded.tree_reduce(op)
            }
        }
    };
}
// Each width halves into the previous one: 16 -> 8 -> 4 -> the hand-written
// two-element base case.
impl_tree_reduce!(4);
impl_tree_reduce!(8);
impl_tree_reduce!(16);