use core::fmt::{Debug, Formatter, Result as FmtResult};
use core::iter::{Product, Sum};
use core::mem::{self, MaybeUninit};
use core::ops::*;
use core::ptr;
use num_traits::Float;
use self::align::Align;
use crate::inner::Repr;
use crate::Mask;
pub mod align {
    /// Marker trait implemented by the zero-sized alignment tags of this module.
    ///
    /// A tag carries no data; the only thing that matters about it is the
    /// `#[repr(align(N))]` attached to its definition.
    pub trait Align: Copy {}

    // Stamps out one unit struct per `Name => bytes` pair, aligned to `bytes`, and marks it
    // as an `Align` tag.
    macro_rules! alignment_tags {
        ($($name: ident => $align: expr),+ $(,)?) => {
            $(
                /// Zero-sized alignment tag.
                #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
                #[repr(align($align))]
                pub struct $name;

                impl Align for $name {}
            )+
        };
    }

    alignment_tags! {
        Align1 => 1,
        Align2 => 2,
        Align4 => 4,
        Align8 => 8,
        Align16 => 16,
        Align32 => 32,
        Align64 => 64,
        Align128 => 128,
    }
}
/// Associates a type with the mask type produced by its lane-wise comparisons.
///
/// In this file it is implemented for `Vector`, whose comparison methods return
/// `<Self as Masked>::Mask`.
pub trait Masked {
/// The mask type paired with the implementor.
type Mask;
}
// Generates the lane-wise implementations of a binary operator trait (`$tr`/`$meth`) and of
// its compound-assignment counterpart (`$tr_assign`/`$meth_assign`) for `Vector`.
//
// Each operator comes in two flavours: vector-with-vector (lane `i` of the left side is
// combined with lane `i` of the right side) and vector-with-scalar (every lane is combined
// with the same scalar).
macro_rules! bin_op_impl {
    ($tr: ident, $meth: ident, $tr_assign: ident, $meth_assign: ident) => {
        // Vector with vector.
        impl<A: Align, B: $tr<Output = B> + Repr, const S: usize> $tr for Vector<A, B, S> {
            type Output = Self;

            #[inline]
            fn $meth(self, rhs: Self) -> Self {
                let mut out = MaybeUninit::<Self>::uninit();
                let lanes = out.as_mut_ptr().cast::<B>();
                for lane in 0..S {
                    let value = $tr::$meth(self.data[lane], rhs.data[lane]);
                    // SAFETY: `Vector` is `#[repr(C)]` with a zero-sized first field, so a
                    // pointer to the vector is also a pointer to lane 0, and `lane < S`
                    // keeps the write inside the lane array.
                    unsafe { ptr::write(lanes.add(lane), value) };
                }
                // SAFETY: the loop above has written every one of the `S` lanes.
                unsafe { out.assume_init() }
            }
        }

        // Vector with scalar: the same `rhs` is applied to every lane.
        impl<A: Align, B: $tr<Output = B> + Repr, const S: usize> $tr<B> for Vector<A, B, S> {
            type Output = Self;

            #[inline]
            fn $meth(self, rhs: B) -> Self {
                let mut out = MaybeUninit::<Self>::uninit();
                let lanes = out.as_mut_ptr().cast::<B>();
                for lane in 0..S {
                    let value = $tr::$meth(self.data[lane], rhs);
                    // SAFETY: same layout argument as the vector-with-vector case above;
                    // `lane < S`.
                    unsafe { ptr::write(lanes.add(lane), value) };
                }
                // SAFETY: the loop above has written every one of the `S` lanes.
                unsafe { out.assume_init() }
            }
        }

        // In-place, vector with vector.
        impl<A: Align, B: $tr_assign + Repr, const S: usize> $tr_assign for Vector<A, B, S> {
            #[inline]
            fn $meth_assign(&mut self, rhs: Self) {
                for (lhs, rhs) in self.data.iter_mut().zip(rhs.data.iter()) {
                    $tr_assign::$meth_assign(lhs, *rhs);
                }
            }
        }

        // In-place, vector with scalar.
        impl<A: Align, B: $tr_assign + Repr, const S: usize> $tr_assign<B> for Vector<A, B, S> {
            #[inline]
            fn $meth_assign(&mut self, rhs: B) {
                for lhs in self.data.iter_mut() {
                    $tr_assign::$meth_assign(lhs, rhs);
                }
            }
        }
    };
}
// Generates the lane-wise implementation of a unary operator trait (`$tr`/`$meth`) for
// `Vector`: the operator is applied to each lane independently.
macro_rules! una_op_impl {
    ($tr: ident, $meth: ident) => {
        impl<A: Align, B: $tr<Output = B> + Repr, const S: usize> $tr for Vector<A, B, S> {
            type Output = Self;

            #[inline]
            fn $meth(self) -> Self {
                let mut out = MaybeUninit::<Self>::uninit();
                let lanes = out.as_mut_ptr().cast::<B>();
                for lane in 0..S {
                    let value = $tr::$meth(self.data[lane]);
                    // SAFETY: `Vector` is `#[repr(C)]` with a zero-sized first field, so a
                    // pointer to the vector is also a pointer to lane 0, and `lane < S`
                    // keeps the write inside the lane array.
                    unsafe { ptr::write(lanes.add(lane), value) };
                }
                // SAFETY: the loop above has written every one of the `S` lanes.
                unsafe { out.assume_init() }
            }
        }
    };
}
// Generates lane-wise comparison methods. Meant to be invoked inside an
// `impl<A, B, const S: usize> Vector<A, B, S>` block.
//
// Each `Trait => method;` entry (optionally preceded by attributes such as doc comments)
// becomes a method `method(self, other) -> <Self as Masked>::Mask` that is available when
// `B: Trait` and stores `B::Mask::from_bool(self[i].method(&other[i]))` in lane `i`.
macro_rules! cmp_op {
    ($($(#[ $meta: meta ])* $tr: ident => $op: ident;)*) => {
        $(
            $(#[ $meta ])*
            #[inline]
            pub fn $op(self, other: Self) -> <Self as Masked>::Mask
            where
                B: $tr,
            {
                let mut out = MaybeUninit::<<Self as Masked>::Mask>::uninit();
                let lanes = out.as_mut_ptr().cast::<B::Mask>();
                for lane in 0..S {
                    let verdict = self.data[lane].$op(&other.data[lane]);
                    // SAFETY: the mask type of a `Vector<A, B, S>` is
                    // `Vector<A, B::Mask, S>`, a `#[repr(C)]` struct whose zero-sized
                    // first field puts lane 0 at the start, and `lane < S`.
                    unsafe { ptr::write(lanes.add(lane), B::Mask::from_bool(verdict)) };
                }
                // SAFETY: the loop above has written every one of the `S` mask lanes.
                unsafe { out.assume_init() }
            }
        )*
    };
}
/// A fixed-size vector of `S` lanes of the base type `B`.
///
/// `A` is a zero-sized alignment marker (see the `align` module). It is stored as a
/// zero-length array, so it can only influence the alignment of the struct, never its
/// contents. `#[repr(C)]` fixes the field order, which puts the lanes at offset 0; code in
/// this file relies on that when it casts a pointer to the vector into a pointer to its
/// first lane.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct Vector<A, B, const S: usize>
where
A: Align,
B: Repr,
{
// Zero-length array: carries the alignment of `A` without holding any element.
_align: [A; 0],
// The lanes themselves.
data: [B; S],
}
impl<A, B, const S: usize> Vector<A, B, S>
where
A: Align,
B: Repr,
{
/// Number of lanes (elements) in the vector; always equal to the `S` parameter.
pub const LANES: usize = S;
/// Checks the layout invariants every constructor relies on.
///
/// # Panics
///
/// * If the vector has no lanes (`S == 0`).
/// * If the size of the vector is not below `isize::MAX`.
/// * If the vector is larger than its lane array, i.e. the alignment marker `A`
///   introduced padding.
#[inline(always)]
fn assert_size() {
    assert!(S > 0);
    let vector_bytes = mem::size_of::<Self>();
    assert!(isize::MAX as usize > vector_bytes, "Vector type too huge");
    assert_eq!(
        vector_bytes,
        mem::size_of::<[B; S]>(),
        "Must not contain paddings/invalid Align parameter",
    );
}
#[inline]
pub unsafe fn new_unchecked(input: *const B) -> Self {
Self::assert_size();
Self {
_align: [],
data: ptr::read(input.cast()),
}
}
/// Loads a vector from anything that can be viewed as a slice of exactly `S` elements.
///
/// # Panics
///
/// If the slice length differs from `S`, or if the layout checks of `assert_size` fail.
#[inline]
pub fn new<I>(input: I) -> Self
where
    I: AsRef<[B]>,
{
    let elements = input.as_ref();
    assert_eq!(
        elements.len(),
        S,
        "Creating vector from the wrong sized slice (expected {}, got {})",
        S,
        elements.len(),
    );
    // SAFETY: the slice has just been checked to hold exactly `S` elements, and a slice
    // pointer is valid and aligned for its elements.
    unsafe { Self::new_unchecked(elements.as_ptr()) }
}
/// Builds a vector with every lane set to `value`.
///
/// # Panics
///
/// If the layout checks of `assert_size` fail for this `A`/`B`/`S` combination.
#[inline]
pub fn splat(value: B) -> Self {
    Self::assert_size();
    let data = [value; S];
    Self { _align: [], data }
}
/// Builds a vector by picking elements of `input` at the given indices: lane `i` of the
/// result is `input[idx[i]]`.
///
/// # Panics
///
/// * If `idx` does not contain exactly `S` indices.
/// * If any index is out of bounds for `input`.
/// * If the layout checks of `assert_size` fail.
#[inline]
pub fn gather_load<I, Idx>(input: I, idx: Idx) -> Self
where
    I: AsRef<[B]>,
    Idx: AsRef<[usize]>,
{
    Self::assert_size();
    let source = input.as_ref();
    let indices = idx.as_ref();
    assert_eq!(
        S,
        indices.len(),
        "Gathering vector from wrong number of indexes"
    );
    let in_bounds = indices.iter().all(|&at| at < source.len());
    assert!(in_bounds, "Gather out of bounds");

    let mut out = MaybeUninit::<Self>::uninit();
    let lanes = out.as_mut_ptr().cast::<B>();
    for lane in 0..S {
        // SAFETY: `indices` holds exactly `S` entries and each of them was checked against
        // `source.len()` above; `lanes` points at lane 0 of `out` (the lane array sits at
        // offset 0 of the `#[repr(C)]` struct) and `lane < S`.
        unsafe {
            let at = *indices.get_unchecked(lane);
            ptr::write(lanes.add(lane), *source.get_unchecked(at));
        }
    }
    // SAFETY: the loop above has written every one of the `S` lanes.
    unsafe { out.assume_init() }
}
/// Like a gather load, but only for the lanes enabled in `mask`.
///
/// For every lane `i` whose mask is set, the lane is replaced by `input[idx[i]]`; lanes with
/// the mask unset keep the value they have in `self`. Indices of disabled lanes are never
/// used to access `input`.
///
/// # Panics
///
/// * If `idx` or `mask` does not contain exactly `S` elements.
/// * If the index of an enabled lane is out of bounds for `input`.
#[inline]
pub fn gather_load_masked<I, Idx, M, MB>(mut self, input: I, idx: Idx, mask: M) -> Self
where
    I: AsRef<[B]>,
    Idx: AsRef<[usize]>,
    M: AsRef<[MB]>,
    MB: Mask,
{
    let source = input.as_ref();
    let indices = idx.as_ref();
    let enabled = mask.as_ref();
    assert_eq!(S, indices.len(), "Gathering vector from wrong number of indexes");
    assert_eq!(S, enabled.len(), "Gathering with wrong sized mask");
    for lane in 0..S {
        // SAFETY: `enabled` was asserted to hold exactly `S` elements and `lane < S`.
        let take = unsafe { enabled.get_unchecked(lane) }.bool();
        if take {
            // SAFETY: `indices` was asserted to hold exactly `S` elements and `lane < S`.
            let at = unsafe { *indices.get_unchecked(lane) };
            // Regular (checked) indexing: an out-of-range index panics here.
            self[lane] = source[at];
        }
    }
    self
}
/// Copies all lanes into `output`.
///
/// # Panics
///
/// If `output` does not hold exactly `S` elements (the length check of
/// `copy_from_slice`).
#[inline]
pub fn store<O: AsMut<[B]>>(self, mut output: O) {
    let target = output.as_mut();
    target.copy_from_slice(&self.data[..]);
}
/// Writes the lanes to scattered positions of `output`: lane `i` goes to `output[idx[i]]`.
///
/// Lanes are written in order, so if an index occurs more than once the highest lane wins.
///
/// # Panics
///
/// * If `idx` does not contain exactly `S` indices.
/// * If any index is out of bounds for `output`.
#[inline]
pub fn scatter_store<O, Idx>(self, mut output: O, idx: Idx)
where
    O: AsMut<[B]>,
    Idx: AsRef<[usize]>,
{
    let target = output.as_mut();
    let indices = idx.as_ref();
    assert_eq!(S, indices.len(), "Scattering vector to wrong number of indexes");
    let in_bounds = indices.iter().all(|&at| at < target.len());
    assert!(in_bounds, "Scatter out of bounds");
    // Both sequences hold exactly `S` items, so the zip visits every lane once, in order.
    for (&at, &value) in indices.iter().zip(self.data.iter()) {
        // SAFETY: every index was checked against `target.len()` above.
        unsafe {
            *target.get_unchecked_mut(at) = value;
        }
    }
}
/// Like a scatter store, but only for the lanes enabled in `mask`.
///
/// For every lane `i` whose mask is set, the lane is written to `output[idx[i]]`. Disabled
/// lanes are skipped and their indices are neither bounds-checked nor used.
///
/// # Panics
///
/// * If `idx` or `mask` does not contain exactly `S` elements.
/// * If the index of an enabled lane is out of bounds for `output`.
#[inline]
pub fn scatter_store_masked<O, Idx, M, MB>(self, mut output: O, idx: Idx, mask: M)
where
    O: AsMut<[B]>,
    Idx: AsRef<[usize]>,
    M: AsRef<[MB]>,
    MB: Mask,
{
    let target = output.as_mut();
    let indices = idx.as_ref();
    let enabled = mask.as_ref();
    assert_eq!(S, indices.len(), "Scattering vector to wrong number of indexes");
    assert_eq!(S, enabled.len(), "Scattering vector with wrong sized mask");
    // Only the indices of enabled lanes need to be valid.
    let in_bounds = indices
        .iter()
        .zip(enabled.iter())
        .all(|(&at, on)| !on.bool() || at < target.len());
    assert!(in_bounds, "Scatter out of bounds");
    // All three sequences hold exactly `S` items, so every lane is visited once, in order.
    for ((&at, on), &value) in indices.iter().zip(enabled.iter()).zip(self.data.iter()) {
        if on.bool() {
            // SAFETY: the index of every enabled lane was checked against `target.len()`
            // above.
            unsafe {
                *target.get_unchecked_mut(at) = value;
            }
        }
    }
}
/// Combines two vectors lane by lane according to `mask`.
///
/// Lanes whose mask is set are taken from `other`; the remaining lanes are taken from
/// `self`. Only the first `S` elements of `mask` are consulted.
///
/// # Panics
///
/// If `mask` holds fewer than `S` elements (out-of-bounds index).
#[inline]
pub fn blend<M, MB>(self, other: Self, mask: M) -> Self
where
    M: AsRef<[MB]>,
    MB: Mask,
{
    let mut out = MaybeUninit::<Self>::uninit();
    let lanes = out.as_mut_ptr().cast::<B>();
    let mask = mask.as_ref();
    for lane in 0..S {
        let picked = if mask[lane].bool() {
            other[lane]
        } else {
            self[lane]
        };
        // SAFETY: `lanes` points at lane 0 of `out` (the lane array sits at offset 0 of the
        // `#[repr(C)]` struct) and `lane < S`.
        unsafe { ptr::write(lanes.add(lane), picked) };
    }
    // SAFETY: the loop above has written every one of the `S` lanes.
    unsafe { out.assume_init() }
}
/// Lane-wise maximum of two vectors.
///
/// A lane is taken from `other` only where `self < other` holds; in every other case
/// (including lanes where the comparison is false because the values are not comparable)
/// the lane of `self` is kept.
#[inline]
pub fn maximum(self, other: Self) -> Self
where
    B: PartialOrd,
{
    self.blend(other, self.lt(other))
}
/// Lane-wise minimum of two vectors.
///
/// A lane is taken from `other` only where `self > other` holds; in every other case
/// (including lanes where the comparison is false because the values are not comparable)
/// the lane of `self` is kept.
#[inline]
pub fn minimum(self, other: Self) -> Self
where
    B: PartialOrd,
{
    self.blend(other, self.gt(other))
}
/// Adds all lanes of the vector together and returns the scalar result.
///
/// The lanes are combined pairwise: the lane array is split in the middle, each half is
/// summed on its own (recursively) and the two partial sums are added. For types where
/// addition is not associative this fixes the order of operations.
#[inline]
pub fn horizontal_sum(self) -> B
where
    B: Add<Output = B>,
{
    #[inline(always)]
    fn reduce<T: Copy + Add<Output = T>>(lanes: &[T]) -> T {
        if let [only] = lanes {
            return *only;
        }
        let (front, back) = lanes.split_at(lanes.len() / 2);
        reduce(front) + reduce(back)
    }
    reduce(&self.data)
}
/// Multiplies all lanes of the vector together and returns the scalar result.
///
/// The lanes are combined pairwise: the lane array is split in the middle, each half is
/// multiplied on its own (recursively) and the two partial products are multiplied. For
/// types where multiplication is not associative this fixes the order of operations.
#[inline]
pub fn horizontal_product(self) -> B
where
    B: Mul<Output = B>,
{
    #[inline(always)]
    fn reduce<T: Copy + Mul<Output = T>>(lanes: &[T]) -> T {
        if let [only] = lanes {
            return *only;
        }
        let (front, back) = lanes.split_at(lanes.len() / 2);
        reduce(front) * reduce(back)
    }
    reduce(&self.data)
}
cmp_op!(
/// Lane-wise `==`: lane `i` of the returned mask is set when `self[i] == other[i]`.
PartialEq => eq;
/// Lane-wise `<`: lane `i` of the returned mask is set when `self[i] < other[i]`.
PartialOrd => lt;
/// Lane-wise `>`: lane `i` of the returned mask is set when `self[i] > other[i]`.
PartialOrd => gt;
/// Lane-wise `<=`: lane `i` of the returned mask is set when `self[i] <= other[i]`.
PartialOrd => le;
/// Lane-wise `>=`: lane `i` of the returned mask is set when `self[i] >= other[i]`.
PartialOrd => ge;
);
}
impl<A, B, const S: usize> Vector<A, B, S>
where
    A: Align,
    B: Repr + Float,
{
    /// Lane-wise multiply-add: lane `i` of the result is `self[i].mul_add(a[i], b[i])`,
    /// i.e. `self[i] * a[i] + b[i]` as computed by `Float::mul_add`.
    ///
    /// The result is built by updating a copy of `self` in place, so no placeholder vector
    /// has to be created first.
    #[inline]
    pub fn mul_add(self, a: Self, b: Self) -> Self {
        let mut result = self;
        for ((lane, &mul), &add) in result
            .data
            .iter_mut()
            .zip(a.data.iter())
            .zip(b.data.iter())
        {
            *lane = (*lane).mul_add(mul, add);
        }
        result
    }
}
// The mask of a `Vector<A, B, S>` is a vector of `B::Mask` with the same alignment marker and
// the same number of lanes.
impl<A: Align, B: Repr, const S: usize> Masked for Vector<A, B, S> {
type Mask = Vector<A, B::Mask, S>;
}
// The default vector has every lane set to `B::default()`.
impl<A: Align, B: Default + Repr, const S: usize> Default for Vector<A, B, S> {
#[inline]
fn default() -> Self {
Self::splat(Default::default())
}
}
// Formats as a `Vector(..)` tuple wrapping the lane array; the alignment marker is not shown.
impl<A: Align, B: Debug + Repr, const S: usize> Debug for Vector<A, B, S> {
fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
fmt.debug_tuple("Vector").field(&self.data).finish()
}
}
// Read-only access to the lanes as a plain array; this also makes the array and slice
// methods callable directly on a vector.
impl<A: Align, B: Repr, const S: usize> Deref for Vector<A, B, S> {
type Target = [B; S];
#[inline]
fn deref(&self) -> &[B; S] {
&self.data
}
}
// Mutable access to the lanes as a plain array.
impl<A: Align, B: Repr, const S: usize> DerefMut for Vector<A, B, S> {
#[inline]
fn deref_mut(&mut self) -> &mut [B; S] {
&mut self.data
}
}
// Borrow the lanes as a slice.
impl<A: Align, B: Repr, const S: usize> AsRef<[B]> for Vector<A, B, S> {
#[inline]
fn as_ref(&self) -> &[B] {
&self.data
}
}
// Borrow the lanes as a fixed-size array.
impl<A: Align, B: Repr, const S: usize> AsRef<[B; S]> for Vector<A, B, S> {
#[inline]
fn as_ref(&self) -> &[B; S] {
&self.data
}
}
// Mutably borrow the lanes as a slice.
impl<A: Align, B: Repr, const S: usize> AsMut<[B]> for Vector<A, B, S> {
#[inline]
fn as_mut(&mut self) -> &mut [B] {
&mut self.data
}
}
// Mutably borrow the lanes as a fixed-size array.
impl<A: Align, B: Repr, const S: usize> AsMut<[B; S]> for Vector<A, B, S> {
#[inline]
fn as_mut(&mut self) -> &mut [B; S] {
&mut self.data
}
}
// Wraps an array of exactly `S` elements into a vector. Panics if the layout checks of
// `assert_size` fail for this `A`/`B`/`S` combination.
impl<A: Align, B: Repr, const S: usize> From<[B; S]> for Vector<A, B, S> {
#[inline]
fn from(data: [B; S]) -> Self {
Self::assert_size();
Self { _align: [], data }
}
}
// Unwraps a vector back into its lane array.
impl<A: Align, B: Repr, const S: usize> From<Vector<A, B, S>> for [B; S] {
#[inline]
fn from(vector: Vector<A, B, S>) -> [B; S] {
vector.data
}
}
// Indexing is forwarded to the lane array, so a vector accepts every index type the array
// `[B; S]` accepts and yields the same output type.
impl<I, A, B, const S: usize> Index<I> for Vector<A, B, S>
where
A: Align,
B: Repr,
[B; S]: Index<I>,
{
type Output = <[B; S] as Index<I>>::Output;
#[inline]
fn index(&self, idx: I) -> &Self::Output {
&self.data[idx]
}
}
// Mutable indexing, forwarded to the lane array in the same way.
impl<I, A, B, const S: usize> IndexMut<I> for Vector<A, B, S>
where
A: Align,
B: Repr,
[B; S]: IndexMut<I>,
{
#[inline]
fn index_mut(&mut self, idx: I) -> &mut Self::Output {
&mut self.data[idx]
}
}
// Lane-wise sum of a sequence of vectors. The accumulator starts as `Self::default()`, which
// is also the result for an empty iterator.
impl<A: Align, B: AddAssign + Default + Repr, const S: usize> Sum for Vector<A, B, S> {
    #[inline]
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        iter.fold(Self::default(), |mut total, item| {
            total += item;
            total
        })
    }
}

// Lane-wise product of a sequence of vectors. The accumulator starts with every lane set to
// `B::ONE`, which is also the result for an empty iterator.
impl<A: Align, B: MulAssign + Repr, const S: usize> Product for Vector<A, B, S> {
    #[inline]
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        iter.fold(Self::splat(B::ONE), |mut total, item| {
            total *= item;
            total
        })
    }
}
// Lane-wise binary operators. Each line provides the operator and its compound-assignment
// form, both for a vector right-hand side and for a scalar right-hand side.
bin_op_impl!(Add, add, AddAssign, add_assign);
bin_op_impl!(Sub, sub, SubAssign, sub_assign);
bin_op_impl!(Mul, mul, MulAssign, mul_assign);
bin_op_impl!(Div, div, DivAssign, div_assign);
bin_op_impl!(Rem, rem, RemAssign, rem_assign);
bin_op_impl!(BitAnd, bitand, BitAndAssign, bitand_assign);
bin_op_impl!(BitOr, bitor, BitOrAssign, bitor_assign);
bin_op_impl!(BitXor, bitxor, BitXorAssign, bitxor_assign);
bin_op_impl!(Shl, shl, ShlAssign, shl_assign);
bin_op_impl!(Shr, shr, ShrAssign, shr_assign);
// Lane-wise unary operators.
una_op_impl!(Neg, neg);
una_op_impl!(Not, not);
// Whole-vector equality: true only when the lane arrays compare equal. This is the `==`
// operator; the inherent `eq` method on `Vector` instead returns a per-lane mask.
impl<A: Align, B: PartialEq + Repr, const S: usize> PartialEq for Vector<A, B, S> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.data == other.data
}
}
impl<A: Align, B: Eq + Repr, const S: usize> Eq for Vector<A, B, S> {}
// Comparison of a vector with a plain array of the same length.
impl<A: Align, B: PartialEq + Repr, const S: usize> PartialEq<[B; S]> for Vector<A, B, S> {
#[inline]
fn eq(&self, other: &[B; S]) -> bool {
self.data == *other
}
}
// The same comparison with the array on the left-hand side.
impl<A: Align, B: PartialEq + Repr, const S: usize> PartialEq<Vector<A, B, S>> for [B; S] {
#[inline]
fn eq(&self, other: &Vector<A, B, S>) -> bool {
*self == other.data
}
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): the concrete aliases used below (`u16x4`, `u32x4`, `f32x4`, `m32`, `m32x4`)
// presumably come from the crate prelude - confirm.
use crate::prelude::*;
type V = u16x4;
// `new` must reject a slice whose length differs from the lane count.
#[test]
#[should_panic(expected = "Creating vector from the wrong sized slice (expected 4, got 3)")]
fn wrong_size_new() {
V::new([1, 2, 3]);
}
// Array -> vector -> array must give back the original array.
#[test]
fn round_trip() {
let orig = [1, 2, 3, 4];
assert_eq!(<[u16; 4]>::from(u16x4::from(orig)), orig);
}
// Gathering from another vector permutes (or duplicates) its lanes.
#[test]
fn shuffle() {
let v1 = V::new([1, 2, 3, 4]);
let v2 = V::gather_load(v1, [3, 1, 2, 0]);
assert_eq!(v2.deref(), &[4, 2, 3, 1]);
let v3 = V::gather_load(v2, [0, 0, 2, 2]);
assert_eq!(v3.deref(), &[4, 4, 3, 3]);
}
// Gathering from a longer input picks the elements at the given indices.
#[test]
fn gather() {
let data = (1..=10).collect::<Vec<_>>();
let v = V::gather_load(data, [0, 2, 4, 6]);
assert_eq!(v, [1, 3, 5, 7]);
}
// Scattering writes only the addressed positions of the output.
#[test]
fn scatter() {
let v = V::new([1, 2, 3, 4]);
let mut output = [0; 10];
v.scatter_store(&mut output, [1, 3, 5, 7]);
assert_eq!(output, [0, 1, 0, 2, 0, 3, 0, 4, 0, 0]);
}
// An index past the end of the gather input must panic.
#[test]
#[should_panic(expected = "Gather out of bounds")]
fn gather_oob() {
V::gather_load([1, 2, 3], [0, 1, 2, 3]);
}
// The number of gather indices must match the lane count.
#[test]
#[should_panic(expected = "Gathering vector from wrong number of indexes")]
fn gather_idx_cnt() {
V::gather_load([0, 1, 2, 3, 4], [0, 1]);
}
// An index past the end of the scatter output must panic.
#[test]
#[should_panic(expected = "Scatter out of bounds")]
fn scatter_oob() {
let mut out = [0; 10];
V::new([1, 2, 3, 4]).scatter_store(&mut out, [0, 1, 2, 15]);
}
// The number of scatter indices must match the lane count.
#[test]
#[should_panic(expected = "Scattering vector to wrong number of indexes")]
fn scatter_idx_cnt() {
let mut out = [0; 10];
V::new([1, 2, 3, 4]).scatter_store(&mut out, [0, 1, 2]);
}
// Shorthands for the two mask values used in the expectations below.
const T: m32 = m32::TRUE;
const F: m32 = m32::FALSE;
// Lane-wise comparisons return one mask lane per compared pair.
#[test]
fn cmp() {
let v1 = u32x4::new([1, 3, 5, 7]);
let v2 = u32x4::new([2, 3, 4, 5]);
assert_eq!(v1.eq(v2), m32x4::new([F, T, F, F]));
assert_eq!(v1.le(v2), m32x4::new([T, T, F, F]));
assert_eq!(v1.ge(v2), m32x4::new([F, T, T, T]));
}
// Blending takes the enabled lanes from the argument; a mask vector and a plain array of
// `bool` must give the same result.
#[test]
fn blend() {
let v1 = u32x4::new([1, 2, 3, 4]);
let v2 = u32x4::new([5, 6, 7, 8]);
let b1 = v1.blend(v2, m32x4::new([F, T, F, T]));
assert_eq!(b1, u32x4::new([1, 6, 3, 8]));
let b2 = v1.blend(v2, [false, true, false, true]);
assert_eq!(b1, b2);
}
// Lane-wise multiply-add: a * b + c.
#[test]
fn fma() {
let a = f32x4::new([1.0, 2.0, 3.0, 4.0]);
let b = f32x4::new([5.0, 6.0, 7.0, 8.0]);
let c = f32x4::new([9.0, 10.0, 11.0, 12.0]);
assert_eq!(a.mul_add(b, c), f32x4::new([14.0, 22.0, 32.0, 44.0]));
}
}