use super::{bool32x4, common_types::ConstUnionHack128bit, macros::*, num_traits::*};
#[cfg(not(feature = "scalar-math"))]
use crate::wide::{CmpEq, CmpGt, CmpLt};
use auto_ops_det::impl_op_ex;
#[cfg(not(target_arch = "spirv"))]
use core::fmt;
use core::ops;
use core::ops::{Add, Div, Mul, Rem, Sub};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
// Lane storage for `i32x4`. With the `scalar-math` feature the lanes are a plain `[i32; 4]`
// (all operations take the scalar code paths below); otherwise they live in the SIMD-backed
// `crate::wide::i32x4`. Code that must work with either representation goes through
// `AsRef<[i32; 4]>` / `AsMut<[i32; 4]>`.
#[cfg(feature = "scalar-math")]
type Inner = [i32; 4];
#[cfg(not(feature = "scalar-math"))]
type Inner = crate::wide::i32x4;
/// A vector of four `i32` lanes.
///
/// Wraps `Inner`: `crate::wide::i32x4` by default, or `[i32; 4]` when the `scalar-math`
/// feature is enabled.
// The lowercase name mirrors the lane-type naming of the wrapped `crate::wide::i32x4`,
// which is why the camel-case lint is silenced.
#[allow(non_camel_case_types)]
// Layout: transparent over `Inner` in the std + SIMD configuration; in the libm / scalar
// configurations the struct is instead forced to 16-byte alignment.
// NOTE(review): a build with none of `std`, `libm_fallback`, `libm_force`, `scalar-math`
// matches neither `cfg_attr` (and neither `impl_i32_ops` module) — confirm that configuration
// is rejected elsewhere.
#[cfg_attr(
all(
feature = "std",
not(feature = "libm_force"),
not(feature = "scalar-math")
),
repr(transparent)
)]
#[cfg_attr(
any(
feature = "libm_force",
feature = "scalar-math",
all(feature = "libm_fallback", not(feature = "std"))
),
repr(align(16))
)]
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct i32x4(Inner);
/// Defines one `pub const` per listed name, each splatting the `i32` associated constant of
/// the same name (e.g. `MIN` -> `i32::MIN`) into every lane.
///
/// Must be invoked inside an `impl` whose `Self` has a `const fn const_splat(i32) -> Self`.
macro_rules! define_const {
    ( $( $const_name:ident ),* ) => {
        $(
            // Use the associated constant (`i32::MIN`) rather than the legacy
            // `core::i32::MIN` module constant, which is slated for deprecation.
            pub const $const_name: Self = Self::const_splat(i32::$const_name);
        )*
    };
}
impl i32x4 {
    define_const!(MIN, MAX);

    /// Applies `f` to each lane independently, in lane order 0..4.
    #[inline]
    fn map(self, f: impl Fn(i32) -> i32) -> Self {
        let &[l0, l1, l2, l3]: &[i32; 4] = self.as_ref();
        Self::from([f(l0), f(l1), f(l2), f(l3)])
    }

    /// Combines corresponding lanes of `self` and `rhs` with `f`, in lane order 0..4.
    #[inline]
    fn zip_map(self, rhs: Self, f: impl Fn(i32, i32) -> i32) -> Self {
        let &[a0, a1, a2, a3]: &[i32; 4] = self.as_ref();
        let &[b0, b1, b2, b3]: &[i32; 4] = rhs.as_ref();
        Self::from([f(a0, b0), f(a1, b1), f(a2, b2), f(a3, b3)])
    }

    /// `const` counterpart of `From<i32>`: every lane is set to `val`.
    #[inline]
    pub const fn const_splat(val: i32) -> Self {
        // SAFETY: writes the `i32a4` view of `ConstUnionHack128bit` and reads the `i32x4` view.
        // NOTE(review): this relies on the union (defined in `common_types`) being a
        // 128-bit plain-old-data overlay in which both fields cover the same bytes — confirm.
        unsafe { ConstUnionHack128bit { i32a4: [val; 4] }.i32x4 }
    }
}
impl NumConstEx for i32x4 {
    // Small constants, each splatted into all four lanes at compile time.
    const ZERO: Self = Self::const_splat(0);
    const ONE: Self = Self::const_splat(1);
    const TWO: Self = Self::const_splat(2);
}
// Construct a vector from an array of four lanes.
impl From<[i32; 4]> for i32x4 {
/// Converts `vals` into the backing `Inner` representation via the `add_into!` macro
/// (defined elsewhere in the crate).
#[inline]
fn from(vals: [i32; 4]) -> Self {
Self(add_into!(vals))
}
}
// Extract the four lanes of a vector as a plain array.
impl From<i32x4> for [i32; 4] {
/// Converts the backing `Inner` of `val` into `[i32; 4]` via the `add_into!` macro
/// (defined elsewhere in the crate).
#[inline]
fn from(val: i32x4) -> [i32; 4] {
add_into!(val.0)
}
}
impl From<i32> for i32x4 {
    /// Splats `val` into all four lanes (runtime counterpart of `i32x4::const_splat`).
    #[inline]
    fn from(val: i32) -> Self {
        let lanes = [val; 4];
        lanes.into()
    }
}
impl Num for i32x4 {
    type Element = i32;
    type Bool = bool32x4;

    /// Number of `i32` lanes: always 4.
    #[inline]
    fn lanes() -> usize {
        4
    }

    /// Broadcasts `val` into every lane.
    #[inline]
    fn splat(val: Self::Element) -> Self {
        val.into()
    }

    /// Returns lane `i`.
    ///
    /// # Panics
    /// Panics if `i >= 4`.
    #[inline]
    fn extract(&self, i: usize) -> Self::Element {
        let lanes: &[i32; 4] = self.as_ref();
        lanes[i]
    }

    /// Returns lane `i` without bounds checking.
    ///
    /// # Safety
    /// The caller must guarantee `i < 4`.
    #[inline]
    unsafe fn extract_unchecked(&self, i: usize) -> Self::Element {
        let lanes: &[i32; 4] = self.as_ref();
        *lanes.get_unchecked(i)
    }

    /// Overwrites lane `i` with `val`.
    ///
    /// # Panics
    /// Panics if `i >= 4`.
    #[inline]
    fn replace(&mut self, i: usize, val: Self::Element) {
        let lanes: &mut [i32; 4] = self.as_mut();
        lanes[i] = val;
    }

    /// Overwrites lane `i` with `val` without bounds checking.
    ///
    /// # Safety
    /// The caller must guarantee `i < 4`.
    #[inline]
    unsafe fn replace_unchecked(&mut self, i: usize, val: Self::Element) {
        let lanes: &mut [i32; 4] = self.as_mut();
        *lanes.get_unchecked_mut(i) = val;
    }

    /// Lane-wise choice: takes `self`'s lane where `cond` is set, `other`'s lane elsewhere.
    ///
    /// In the scalar configuration a lane counts as set only when its mask word is
    /// `u32::MAX`; the SIMD configuration delegates to `blend`.
    #[inline]
    fn select(self, cond: Self::Bool, other: Self) -> Self {
        #[cfg(feature = "scalar-math")]
        {
            let pick = |i: usize| {
                if cond.0[i] == u32::MAX {
                    self.0[i]
                } else {
                    other.0[i]
                }
            };
            Self([pick(0), pick(1), pick(2), pick(3)])
        }
        #[cfg(not(feature = "scalar-math"))]
        {
            let mask = cond.cast_i32x4().0;
            Self(mask.blend(self.0, other.0))
        }
    }
}
// Operator implementations for the std + SIMD configuration: `-` (unary), `+`, `-`, `*` and
// their compound forms map directly onto the `crate::wide::i32x4` operators.
// NOTE(review): overflow behaviour of those operators is defined by `crate::wide`
// (presumably wrapping) — confirm.
#[cfg(all(
feature = "std",
not(feature = "libm_force"),
not(feature = "scalar-math")
))]
mod impl_i32_ops {
use super::*;
impl_op_ex!(-|a: &i32x4| -> i32x4 { i32x4(-a.0) });
impl_op_ex!(+ |a: &i32x4, b: &i32x4| -> i32x4 { i32x4(a.0 + b.0) });
impl_op_ex!(-|a: &i32x4, b: &i32x4| -> i32x4 { i32x4(a.0 - b.0) });
impl_op_ex!(*|a: &i32x4, b: &i32x4| -> i32x4 { i32x4(a.0 * b.0) });
// Division and remainder are not delegated to the wide type; they run lane by lane with
// plain `i32` semantics, so they panic on a zero divisor and on `i32::MIN / -1`
// (`i32::MIN % -1`).
impl_op_ex!(/ |a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, |x, y| x / y) });
impl_op_ex!(% |a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, |x, y| x % y) });
// Compound assignments: `+=`, `-=`, `*=` update the wide value in place; `/=` and `%=` reuse
// the lane-wise binary operators above.
impl_op_ex!(+= |a: &mut i32x4, b: &i32x4| { a.0 += b.0 });
impl_op_ex!(-= |a: &mut i32x4, b: &i32x4| { a.0 -= b.0 });
impl_op_ex!(/= |a: &mut i32x4, b: &i32x4| { *a = *a / *b });
impl_op_ex!(*= |a: &mut i32x4, b: &i32x4| { a.0 *= b.0 });
impl_op_ex!(%= |a: &mut i32x4, b: &i32x4| { *a = *a % *b });
}
// Portable operator implementations (libm / scalar configurations): every operator is applied
// lane by lane through `i32x4::map` / `i32x4::zip_map`.
#[cfg(any(
    feature = "libm_force",
    feature = "scalar-math",
    all(feature = "libm_fallback", not(feature = "std"))
))]
mod impl_i32_ops {
    use super::*;

    // Negation, addition, subtraction and multiplication use the `wrapping_*` forms so that
    // overflow wraps in every build profile. Plain `-x` / `x + y` / `x - y` / `x * y` would
    // panic on overflow in debug builds, so the same program could panic under `libm_force`
    // or `scalar-math` but not in the SIMD configuration, which uses the wide operators.
    // NOTE(review): assumes the `crate::wide` operators wrap — confirm.
    impl_op_ex!(-|a: &i32x4| -> i32x4 { i32x4::map(*a, i32::wrapping_neg) });
    impl_op_ex!(+ |a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, i32::wrapping_add) });
    impl_op_ex!(-|a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, i32::wrapping_sub) });
    impl_op_ex!(*|a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, i32::wrapping_mul) });

    // Division and remainder keep plain `i32` semantics, identical to the SIMD configuration:
    // they panic on a zero divisor and on `i32::MIN / -1` (`i32::MIN % -1`).
    impl_op_ex!(/ |a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, |x, y| x / y) });
    impl_op_ex!(% |a: &i32x4, b: &i32x4| -> i32x4 { i32x4::zip_map(*a, *b, |x, y| x % y) });

    // Compound assignments reuse the binary operators above.
    impl_op_ex!(+= |a: &mut i32x4, b: &i32x4| { *a = *a + *b });
    impl_op_ex!(-= |a: &mut i32x4, b: &i32x4| { *a = *a - *b });
    impl_op_ex!(/= |a: &mut i32x4, b: &i32x4| { *a = *a / *b });
    impl_op_ex!(*= |a: &mut i32x4, b: &i32x4| { *a = *a * *b });
    impl_op_ex!(%= |a: &mut i32x4, b: &i32x4| { *a = *a % *b });
}
// SIMD comparisons: each predicate produces a lane mask in a wide `i32x4`, which
// `cast_bool32x4` reinterprets as the `bool32x4` result.
#[cfg(not(feature = "scalar-math"))]
impl PartialOrdEx for i32x4 {
    #[inline]
    fn gt(self, other: Self) -> Self::Bool {
        let mask = self.0.cmp_gt(other.0);
        Self(mask).cast_bool32x4()
    }

    #[inline]
    fn lt(self, other: Self) -> Self::Bool {
        let mask = self.0.cmp_lt(other.0);
        Self(mask).cast_bool32x4()
    }

    /// `>=` is built as the union of the `==` and `>` masks.
    #[inline]
    fn ge(self, other: Self) -> Self::Bool {
        let (lhs, rhs) = (self.0, other.0);
        Self(lhs.cmp_eq(rhs) | lhs.cmp_gt(rhs)).cast_bool32x4()
    }

    /// `<=` is built as the union of the `==` and `<` masks.
    #[inline]
    fn le(self, other: Self) -> Self::Bool {
        let (lhs, rhs) = (self.0, other.0);
        Self(lhs.cmp_eq(rhs) | lhs.cmp_lt(rhs)).cast_bool32x4()
    }

    #[inline]
    fn eq(self, other: Self) -> Self::Bool {
        let mask = self.0.cmp_eq(other.0);
        Self(mask).cast_bool32x4()
    }

    /// `!=` is the bitwise complement of the `==` mask.
    #[inline]
    fn ne(self, other: Self) -> Self::Bool {
        let equal = self.0.cmp_eq(other.0);
        Self(!equal).cast_bool32x4()
    }

    #[inline]
    fn max(self, other: Self) -> Self {
        let larger = self.0.max(other.0);
        Self(larger)
    }

    #[inline]
    fn min(self, other: Self) -> Self {
        let smaller = self.0.min(other.0);
        Self(smaller)
    }

    /// Caps each lane at `max` first, then raises it to at least `min`.
    #[inline]
    fn clamp(self, min: Self, max: Self) -> Self {
        let capped = self.min(max);
        capped.max(min)
    }
}
// Scalar comparisons: every predicate is evaluated independently per lane.
#[cfg(feature = "scalar-math")]
impl PartialOrdEx for i32x4 {
    #[inline]
    fn gt(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a > b)
    }

    #[inline]
    fn lt(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a < b)
    }

    #[inline]
    fn ge(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a >= b)
    }

    #[inline]
    fn le(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a <= b)
    }

    #[inline]
    fn eq(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a == b)
    }

    #[inline]
    fn ne(self, other: Self) -> Self::Bool {
        scalar_lane_mask(self, other, |a, b| a != b)
    }

    /// Lane-wise maximum.
    #[inline]
    fn max(self, other: Self) -> Self {
        // `self.0.max(other.0)` would resolve to `Ord::max` on `[i32; 4]`, which compares the
        // arrays lexicographically and returns one operand whole, e.g.
        // max([1, 9, 9, 9], [2, 0, 0, 0]) == [2, 0, 0, 0]. Take the maximum per lane instead.
        Self::zip_map(self, other, core::cmp::max)
    }

    /// Lane-wise minimum.
    #[inline]
    fn min(self, other: Self) -> Self {
        // Per lane for the same reason as `max`: `Ord::min` on arrays is lexicographic.
        Self::zip_map(self, other, core::cmp::min)
    }

    /// Caps each lane at `max` first, then raises it to at least `min`.
    #[inline]
    fn clamp(self, min: Self, max: Self) -> Self {
        self.min(max).max(min)
    }
}

/// Evaluates `pred` on each pair of corresponding lanes and packs the results into a mask.
#[cfg(feature = "scalar-math")]
#[inline]
fn scalar_lane_mask(a: i32x4, b: i32x4, pred: impl Fn(i32, i32) -> bool) -> bool32x4 {
    let (a, b) = (a.0, b.0);
    [
        pred(a[0], b[0]),
        pred(a[1], b[1]),
        pred(a[2], b[2]),
        pred(a[3], b[3]),
    ]
    .into()
}
impl Signed for i32x4 {
    /// Lane-wise absolute value.
    ///
    /// `i32::MIN` has no positive `i32` counterpart. In the `scalar-math` configuration it is
    /// mapped with `i32::wrapping_abs`, so it stays `i32::MIN`. Plain `i32::abs` would instead
    /// panic on overflow in debug builds.
    /// NOTE(review): assumes `crate::wide::i32x4::abs` also wraps — confirm.
    #[inline]
    fn absf(self) -> Self {
        #[cfg(not(feature = "scalar-math"))]
        {
            Self(self.0.abs())
        }
        #[cfg(feature = "scalar-math")]
        {
            Self(self.0.map(i32::wrapping_abs))
        }
    }

    /// Lane-wise sign: `-1`, `0` or `1` per lane, as `i32::signum`.
    #[inline]
    fn signumf(self) -> Self {
        self.map(i32::signum)
    }
}
// Equality for `i32x4`, generated by a macro defined elsewhere in the crate.
// NOTE(review): presumably `PartialEq` comparing all lanes — confirm against the macro.
impl_wide_partial_eq!(i32x4);
// `NumEx` and `SignedEx` are implemented with empty bodies, so everything they provide comes
// from the traits' default methods.
impl NumEx for i32x4 {}
impl SignedEx for i32x4 {}
#[cfg(not(target_arch = "spirv"))]
impl fmt::Display for i32x4 {
    /// In the scalar configuration, prints the lanes as `(l0, l1, l2, l3)`. Otherwise, defers
    /// to the `Display` impl of the wide backing type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        #[cfg(feature = "scalar-math")]
        {
            let [l0, l1, l2, l3] = self.0;
            write!(f, "({}, {}, {}, {})", l0, l1, l2, l3)
        }
        #[cfg(not(feature = "scalar-math"))]
        {
            write!(f, "{}", self.0)
        }
    }
}
// Shift operators for `i32x4`, generated by a crate macro ("interge" is the macro's actual
// spelling). NOTE(review): presumably one `<<` / `>>` impl per listed shift-amount type — confirm.
impl_wide_interge_shift_ops!(i32x4, i8, i16, i32, u8, u16, u32);
// Bitwise operators (macro defined elsewhere); `IntegerBitOps` needs nothing beyond its defaults.
impl_wide_bit_ops!(i32x4);
impl IntegerBitOps for i32x4 {}
// Vector-scalar operators with `i32` operands, generated by a crate macro.
impl_wide_scalar_ops!(i32x4, i32);
// `Default` impl, generated by a crate macro.
// NOTE(review): presumably all-zero lanes — confirm.
impl_default!(i32x4);
// Borrow the lanes as a plain array, whichever `Inner` representation is active.
impl AsRef<[i32; 4]> for i32x4 {
/// Reinterprets the backing `Inner` as `&[i32; 4]` via the `as_array_ref!` crate macro.
#[inline]
fn as_ref(&self) -> &[i32; 4] {
as_array_ref!(self.0)
}
}
// Mutably borrow the lanes as a plain array, whichever `Inner` representation is active.
impl AsMut<[i32; 4]> for i32x4 {
/// Reinterprets the backing `Inner` as `&mut [i32; 4]` via the `as_array_mut!` crate macro.
#[inline]
fn as_mut(&mut self) -> &mut [i32; 4] {
as_array_mut!(self.0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `signumf` keeps 1 and 0, maps -1 to -1, and negating zero still gives zero.
    #[test]
    fn test_i32x4_signum() {
        let one = i32x4::ONE;
        let zero = i32x4::ZERO;
        let minus_one = -one;
        let neg_zero = -zero;

        assert_eq!(one.signumf(), one);
        assert_eq!(minus_one.signumf(), -one);
        assert_eq!(zero.signumf(), zero);
        assert_eq!(neg_zero.signumf(), zero);
    }
}