npsimd 0.3.0

An ergonomic library for architecture-specific vectorization.
Documentation
//! Primitive SIMD types.

use core::marker::Freeze;

use super::Element;

/// A primitive SIMD vector.
///
/// A primitive vector is a fixed-length sequence of [`PrimitiveVector::LEN`]
/// lanes, each holding a value of type [`PrimitiveVector::Element`].
///
/// # Safety
///
/// Implementors must be `#[repr(simd)]` types whose lanes are exactly
/// [`PrimitiveVector::LEN`] values of [`PrimitiveVector::Element`], as the
/// `decl_simd_prim!` macro in this module produces for every type it declares.
/// NOTE(review): the `unsafe` code that relies on this trait is not visible in
/// this module — confirm this contract against it.
pub unsafe trait PrimitiveVector: Copy + Freeze {
    /// The element (lane) type of the vector.
    type Element: PrimitiveElement;

    /// The number of lanes in the vector.
    const LEN: usize;
}

/// A SIMD-compatible primitive element.
///
/// This is the scalar type stored in each lane of a primitive SIMD vector.
/// It is implemented for the fixed-width integers `u8`, `u16`, `u32`, `u64`,
/// `i8`, `i16`, `i32`, `i64` and for the floating-point types `f32` and `f64`.
///
/// # Safety
///
/// Implementors must be primitive scalar types that are valid as the lane type
/// of a `#[repr(simd)]` vector. NOTE(review): the precise invariant relied on by
/// `unsafe` code elsewhere in the crate is not stated in this module — confirm
/// and document it here.
pub unsafe trait PrimitiveElement: Copy + Freeze {}

/// Implements [`PrimitiveElement`] for each listed built-in numeric type.
macro_rules! impl_primitive_element {
    ($($ty:ty),* $(,)?) => {
        $(
            // SAFETY: `$ty` is a built-in fixed-width numeric primitive; such types
            // are the lane types this module uses inside `#[repr(simd)]` vectors.
            unsafe impl PrimitiveElement for $ty {}
        )*
    };
}

impl_primitive_element!(u8, u16, u32, u64);
impl_primitive_element!(i8, i16, i32, i64);
impl_primitive_element!(f32, f64);

/// Declares one or more primitive SIMD vector types.
///
/// Each `vector NAME[ELEM; LEN];` entry (optionally preceded by `///` doc
/// comments, which are forwarded to the generated struct) expands to:
///
/// - a public `#[repr(simd)]` tuple struct `NAME([ELEM; LEN])`,
/// - an `Element<LEN>` impl for `ELEM` whose `Primitive` type is `NAME`, and
/// - a `PrimitiveVector` impl for `NAME` with `Element = ELEM` and `LEN = LEN`.
///
/// Only `#[doc = ...]` attributes are accepted before an entry.
macro_rules! decl_simd_prim {
    {
        $(
            $(#[doc = $doc:literal])*
            vector $name:ident[$elem:ty; $size:literal];
        )*
    } => {
        $(
            $(#[doc = $doc])*
            // Lane-style names such as `u8x16` are intentional, not CamelCase.
            #[allow(non_camel_case_types)]
            // The private inner array is never read through the field in this
            // module, which would otherwise trigger the dead-code lint.
            // TODO: Remove once the field is used.
            #[allow(dead_code)]
            #[derive(Copy, Clone, Debug)]
            #[repr(simd)]
            pub struct $name([$elem; $size]);

            // SAFETY: `$name` is a `#[repr(simd)]` wrapper around `[$elem; $size]`,
            // i.e. the primitive vector of `$size` lanes of `$elem`.
            // NOTE(review): `Element`'s contract lives in the parent module — confirm.
            unsafe impl Element<$size> for $elem {
                type Primitive = $name;
            }

            // SAFETY: `$name` is `#[repr(simd)]` with exactly `$size` lanes of `$elem`,
            // matching the declared `Element` and `LEN`.
            unsafe impl PrimitiveVector for $name {
                type Element = $elem;
                const LEN: usize = $size;
            }
        )*
    };
}

// 8-bit integer vectors, declared as unsigned/signed pairs for each lane count.
decl_simd_prim! {
    /// A single-lane primitive vector holding one unsigned 8-bit integer (8 bits total).
    vector u8x1[u8; 1];

    /// A single-lane primitive vector holding one signed 8-bit integer (8 bits total).
    vector i8x1[i8; 1];

    /// A 2-lane primitive vector of unsigned 8-bit integers (16 bits total).
    vector u8x2[u8; 2];

    /// A 2-lane primitive vector of signed 8-bit integers (16 bits total).
    vector i8x2[i8; 2];

    /// A 4-lane primitive vector of unsigned 8-bit integers (32 bits total).
    vector u8x4[u8; 4];

    /// A 4-lane primitive vector of signed 8-bit integers (32 bits total).
    vector i8x4[i8; 4];

    /// An 8-lane primitive vector of unsigned 8-bit integers (64 bits total).
    vector u8x8[u8; 8];

    /// An 8-lane primitive vector of signed 8-bit integers (64 bits total).
    vector i8x8[i8; 8];

    /// A 16-lane primitive vector of unsigned 8-bit integers (128 bits total).
    vector u8x16[u8; 16];

    /// A 16-lane primitive vector of signed 8-bit integers (128 bits total).
    vector i8x16[i8; 16];

    /// A 32-lane primitive vector of unsigned 8-bit integers (256 bits total).
    vector u8x32[u8; 32];

    /// A 32-lane primitive vector of signed 8-bit integers (256 bits total).
    vector i8x32[i8; 32];

    /// A 64-lane primitive vector of unsigned 8-bit integers (512 bits total).
    vector u8x64[u8; 64];

    /// A 64-lane primitive vector of signed 8-bit integers (512 bits total).
    vector i8x64[i8; 64];

    /// A 128-lane primitive vector of unsigned 8-bit integers (1024 bits total).
    vector u8x128[u8; 128];

    /// A 128-lane primitive vector of signed 8-bit integers (1024 bits total).
    vector i8x128[i8; 128];
}

// 16-bit integer vectors, declared as unsigned/signed pairs for each lane count.
decl_simd_prim! {
    /// A single-lane primitive vector holding one unsigned 16-bit integer (16 bits total).
    vector u16x1[u16; 1];

    /// A single-lane primitive vector holding one signed 16-bit integer (16 bits total).
    vector i16x1[i16; 1];

    /// A 2-lane primitive vector of unsigned 16-bit integers (32 bits total).
    vector u16x2[u16; 2];

    /// A 2-lane primitive vector of signed 16-bit integers (32 bits total).
    vector i16x2[i16; 2];

    /// A 4-lane primitive vector of unsigned 16-bit integers (64 bits total).
    vector u16x4[u16; 4];

    /// A 4-lane primitive vector of signed 16-bit integers (64 bits total).
    vector i16x4[i16; 4];

    /// An 8-lane primitive vector of unsigned 16-bit integers (128 bits total).
    vector u16x8[u16; 8];

    /// An 8-lane primitive vector of signed 16-bit integers (128 bits total).
    vector i16x8[i16; 8];

    /// A 16-lane primitive vector of unsigned 16-bit integers (256 bits total).
    vector u16x16[u16; 16];

    /// A 16-lane primitive vector of signed 16-bit integers (256 bits total).
    vector i16x16[i16; 16];

    /// A 32-lane primitive vector of unsigned 16-bit integers (512 bits total).
    vector u16x32[u16; 32];

    /// A 32-lane primitive vector of signed 16-bit integers (512 bits total).
    vector i16x32[i16; 32];

    /// A 64-lane primitive vector of unsigned 16-bit integers (1024 bits total).
    vector u16x64[u16; 64];

    /// A 64-lane primitive vector of signed 16-bit integers (1024 bits total).
    vector i16x64[i16; 64];
}

// 32-bit integer vectors, declared as unsigned/signed pairs for each lane count.
decl_simd_prim! {
    /// A single-lane primitive vector holding one unsigned 32-bit integer (32 bits total).
    vector u32x1[u32; 1];

    /// A single-lane primitive vector holding one signed 32-bit integer (32 bits total).
    vector i32x1[i32; 1];

    /// A 2-lane primitive vector of unsigned 32-bit integers (64 bits total).
    vector u32x2[u32; 2];

    /// A 2-lane primitive vector of signed 32-bit integers (64 bits total).
    vector i32x2[i32; 2];

    /// A 4-lane primitive vector of unsigned 32-bit integers (128 bits total).
    vector u32x4[u32; 4];

    /// A 4-lane primitive vector of signed 32-bit integers (128 bits total).
    vector i32x4[i32; 4];

    /// An 8-lane primitive vector of unsigned 32-bit integers (256 bits total).
    vector u32x8[u32; 8];

    /// An 8-lane primitive vector of signed 32-bit integers (256 bits total).
    vector i32x8[i32; 8];

    /// A 16-lane primitive vector of unsigned 32-bit integers (512 bits total).
    vector u32x16[u32; 16];

    /// A 16-lane primitive vector of signed 32-bit integers (512 bits total).
    vector i32x16[i32; 16];

    /// A 32-lane primitive vector of unsigned 32-bit integers (1024 bits total).
    vector u32x32[u32; 32];

    /// A 32-lane primitive vector of signed 32-bit integers (1024 bits total).
    vector i32x32[i32; 32];
}

// 64-bit integer vectors, declared as unsigned/signed pairs for each lane count.
decl_simd_prim! {
    /// A single-lane primitive vector holding one unsigned 64-bit integer (64 bits total).
    vector u64x1[u64; 1];

    /// A single-lane primitive vector holding one signed 64-bit integer (64 bits total).
    vector i64x1[i64; 1];

    /// A 2-lane primitive vector of unsigned 64-bit integers (128 bits total).
    vector u64x2[u64; 2];

    /// A 2-lane primitive vector of signed 64-bit integers (128 bits total).
    vector i64x2[i64; 2];

    /// A 4-lane primitive vector of unsigned 64-bit integers (256 bits total).
    vector u64x4[u64; 4];

    /// A 4-lane primitive vector of signed 64-bit integers (256 bits total).
    vector i64x4[i64; 4];

    /// An 8-lane primitive vector of unsigned 64-bit integers (512 bits total).
    vector u64x8[u64; 8];

    /// An 8-lane primitive vector of signed 64-bit integers (512 bits total).
    vector i64x8[i64; 8];

    /// A 16-lane primitive vector of unsigned 64-bit integers (1024 bits total).
    vector u64x16[u64; 16];

    /// A 16-lane primitive vector of signed 64-bit integers (1024 bits total).
    vector i64x16[i64; 16];
}

// Primitive vectors of 32-bit floating-point (`f32`) values.
decl_simd_prim! {
    /// A primitive vector of one `f32` (32-bit floating-point) value.
    vector f32x1[f32; 1];

    /// A primitive vector of 2 `f32` (32-bit floating-point) values.
    vector f32x2[f32; 2];

    /// A primitive vector of 4 `f32` (32-bit floating-point) values.
    vector f32x4[f32; 4];

    /// A primitive vector of 8 `f32` (32-bit floating-point) values.
    vector f32x8[f32; 8];

    /// A primitive vector of 16 `f32` (32-bit floating-point) values.
    vector f32x16[f32; 16];

    /// A primitive vector of 32 `f32` (32-bit floating-point) values.
    vector f32x32[f32; 32];
}

// Primitive vectors of 64-bit floating-point (`f64`) values.
decl_simd_prim! {
    /// A primitive vector of one `f64` (64-bit floating-point) value.
    vector f64x1[f64; 1];

    /// A primitive vector of 2 `f64` (64-bit floating-point) values.
    vector f64x2[f64; 2];

    /// A primitive vector of 4 `f64` (64-bit floating-point) values.
    vector f64x4[f64; 4];

    /// A primitive vector of 8 `f64` (64-bit floating-point) values.
    vector f64x8[f64; 8];

    /// A primitive vector of 16 `f64` (64-bit floating-point) values.
    vector f64x16[f64; 16];
}