use std::cmp::Ordering;
use half::f16;
use crate::{BitMask, Const, SIMDMask, SupportedLaneCount, arch, reference::ReferenceCast};
pub(crate) trait USizeConvertTo<T> {
fn test_convert(self) -> T;
}
macro_rules! use_as_conversion {
($type:ty) => {
impl USizeConvertTo<$type> for usize {
fn test_convert(self) -> $type {
self as $type
}
}
};
}
impl USizeConvertTo<f16> for usize {
fn test_convert(self) -> f16 {
crate::cast_f32_to_f16(self as f32)
}
}
use_as_conversion!(f32);
use_as_conversion!(f64);
use_as_conversion!(i8);
use_as_conversion!(i16);
use_as_conversion!(i32);
use_as_conversion!(i64);
use_as_conversion!(u8);
use_as_conversion!(u16);
use_as_conversion!(u32);
use_as_conversion!(u64);
pub(crate) fn iota_iter<'a, T, I>(x: I)
where
T: 'a,
usize: USizeConvertTo<T>,
I: Iterator<Item = &'a mut T>,
{
x.enumerate().for_each(|(i, d)| {
*d = i.test_convert();
});
}
pub(crate) fn iota_slice<T>(x: &mut [T])
where
usize: USizeConvertTo<T>,
{
iota_iter(x.iter_mut());
}
pub(crate) fn promote_to_array<const N: usize, A>(x: BitMask<N, A>) -> [bool; N]
where
A: arch::Sealed,
Const<N>: SupportedLaneCount,
BitMask<N, A>: SIMDMask<Arch = A>,
{
core::array::from_fn(|i| x.get_unchecked(i))
}
pub(crate) trait ScalarTraits: Copy + std::cmp::PartialEq + std::fmt::Debug {
fn splat_test_values() -> impl Iterator<Item = Self>;
fn exact_eq(self, other: Self) -> bool {
self == other
}
}
impl ScalarTraits for f16 {
fn splat_test_values() -> impl Iterator<Item = Self> {
let v: Vec<f32> = vec![-127.567, -1.0, -0.0, 0.0, 1.25, 10.5];
v.into_iter().map(|x| x.reference_cast())
}
fn exact_eq(self, other: Self) -> bool {
let x: f32 = self.reference_cast();
let y: f32 = other.reference_cast();
x.total_cmp(&y) == Ordering::Equal
}
}
impl ScalarTraits for f32 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[
f32::MIN,
f32::MAX,
0.0,
-0.0,
12.0,
f32::MIN_POSITIVE,
f32::INFINITY,
f32::NEG_INFINITY,
]
.into_iter()
}
fn exact_eq(self, other: Self) -> bool {
if self.is_nan() && other.is_nan() {
return true;
}
self.total_cmp(&other) == Ordering::Equal
}
}
impl ScalarTraits for f64 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[
f64::MIN,
f64::MAX,
0.0,
-0.0,
12.0,
f64::MIN_POSITIVE,
f64::INFINITY,
f64::NEG_INFINITY,
]
.into_iter()
}
fn exact_eq(self, other: Self) -> bool {
#[cfg(miri)]
if self.is_nan() && other.is_nan() {
return true;
}
self.total_cmp(&other) == Ordering::Equal
}
}
impl ScalarTraits for bool {
fn splat_test_values() -> impl Iterator<Item = Self> {
[false, true].into_iter()
}
}
impl ScalarTraits for i8 {
fn splat_test_values() -> impl Iterator<Item = Self> {
-128..=127
}
}
impl ScalarTraits for i16 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[i16::MIN, -10, -1, 0, 1, 10, i16::MAX].into_iter()
}
}
impl ScalarTraits for i32 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[i32::MIN, -10, -1, 0, 1, 10, i32::MAX].into_iter()
}
}
impl ScalarTraits for i64 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[i64::MIN, -10, -1, 0, 1, 10, i64::MAX].into_iter()
}
}
impl ScalarTraits for u8 {
fn splat_test_values() -> impl Iterator<Item = Self> {
0..=255
}
}
impl ScalarTraits for u16 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[0, 1, 10, 255, 9128, u16::MAX].into_iter()
}
}
impl ScalarTraits for u32 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[0, 1, 10, 255, 9128, u32::MAX].into_iter()
}
}
impl ScalarTraits for u64 {
fn splat_test_values() -> impl Iterator<Item = Self> {
[0, 1, 10, 255, 9128, u64::MAX].into_iter()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_exact_eq_f16() {
let zero: f16 = 0.0f32.reference_cast();
let minus_zero: f16 = -0.0f32.reference_cast();
assert!(zero.exact_eq(zero));
assert!(!zero.exact_eq(minus_zero));
assert!(!minus_zero.exact_eq(zero));
}
#[test]
fn test_exact_eq_f32() {
assert_eq!(0.0f32, -0.0f32);
assert!(!(0.0f32.exact_eq(-0.0f32)));
assert!(f32::NAN.exact_eq(f32::NAN));
}
}