// hpt-macros 0.1.2
//
// An internal library for generating helper functions for hpt.
// Documentation
use quote::quote;
use quote::ToTokens;
use std::fmt::Display;

/// Returns how many elements of the named scalar type fit in one SIMD register.
///
/// The register width is chosen at compile time from the enabled target
/// features, and exactly one width is ever selected (the widest available):
///
/// * 64 bytes when `avx512f` is enabled,
/// * 32 bytes when `avx2` is enabled without `avx512f`,
/// * 16 bytes for `sse`-only builds and for `arm` / `aarch64`.
///
/// Keeping the three cases mutually exclusive matters because `avx512f`
/// builds also have `avx2` enabled; selecting both would leave two competing
/// results in the function body.
///
/// `list` is a type name compared case-insensitively. Recognized names are
/// `bool`, `i8`..`i64`, `u8`..`u64`, `isize`, `usize`, `bf16`, `f16`, `f32`,
/// `f64`, and the complex types under either spelling (`c32` / `complex32`,
/// `c64` / `complex64`). Any other name yields `0`.
pub fn type_simd_lanes(list: &str) -> u8 {
    #[cfg(not(any(
        target_feature = "avx512f",
        target_feature = "avx2",
        target_feature = "sse",
        target_arch = "arm",
        target_arch = "aarch64"
    )))]
    compile_error!("type_simd_lanes: no supported SIMD target feature is enabled");

    // Width of one vector register in bytes; widest enabled extension wins.
    #[cfg(target_feature = "avx512f")]
    const REGISTER_BYTES: u8 = 64;
    #[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
    const REGISTER_BYTES: u8 = 32;
    #[cfg(all(
        any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
        not(target_feature = "avx2"),
        not(target_feature = "avx512f")
    ))]
    const REGISTER_BYTES: u8 = 16;

    let element_bytes: u8 = match list.to_lowercase().as_str() {
        "bool" | "i8" | "u8" => 1,
        "i16" | "u16" | "bf16" | "f16" => 2,
        "i32" | "u32" | "f32" => 4,
        "i64" | "u64" | "f64" => 8,
        // Pointer-sized integers follow the pointer width of the build target.
        "isize" | "usize" => std::mem::size_of::<usize>() as u8,
        // A complex value is a pair of floats, so its size does not depend on
        // the pointer width: 2 * f32 and 2 * f64 respectively.
        "c32" | "complex32" => 8,
        "c64" | "complex64" => 16,
        _ => return 0,
    };
    REGISTER_BYTES / element_bytes
}

/// Tag identifying a scalar element type.
///
/// Complex numbers are present under two spellings: `C32` / `C64` and
/// `Complex32` / `Complex64`. The predicates below document which spellings
/// each of them accepts.
#[derive(Copy, Clone, PartialEq, Eq)]
pub(crate) enum Type {
    Bool,
    // Fixed-width integers.
    I8,
    U8,
    I16,
    U16,
    I32,
    U32,
    I64,
    U64,
    // Floating point.
    BF16,
    F16,
    F32,
    F64,
    // Complex, short spelling.
    C32,
    C64,
    // Pointer-sized integers.
    Isize,
    Usize,
    // Complex, long spelling.
    Complex32,
    Complex64,
}

impl Type {
    /// Returns `true` for `BF16`, `F16`, `F32`, `F64`, `C32` and `C64`.
    ///
    /// NOTE(review): `Complex32` / `Complex64` are *not* counted here even
    /// though `is_cplx` treats them like `C32` / `C64` — confirm whether that
    /// asymmetry is intended.
    pub fn is_float(&self) -> bool {
        const FLOAT_LIKE: [Type; 6] = [
            Type::BF16,
            Type::F16,
            Type::F32,
            Type::F64,
            Type::C32,
            Type::C64,
        ];
        FLOAT_LIKE.contains(self)
    }

    /// Returns `true` only for `Bool`.
    pub fn is_bool(&self) -> bool {
        *self == Type::Bool
    }

    /// Returns `true` only for `F16`.
    pub fn is_f16(&self) -> bool {
        *self == Type::F16
    }

    /// Returns `true` only for `BF16`.
    pub fn is_bf16(&self) -> bool {
        *self == Type::BF16
    }

    /// Returns `true` for any complex variant, in either spelling.
    pub fn is_cplx(&self) -> bool {
        self.is_cplx32() || self.is_cplx64()
    }

    /// Returns `true` for `C32` and `Complex32`.
    pub fn is_cplx32(&self) -> bool {
        [Type::C32, Type::Complex32].contains(self)
    }

    /// Returns `true` for `C64` and `Complex64`.
    pub fn is_cplx64(&self) -> bool {
        [Type::C64, Type::Complex64].contains(self)
    }
}

impl ToTokens for Type {
    /// Appends the identifier naming this scalar type to `tokens`.
    ///
    /// Every variant emits its lowercase name, except `Complex32` and
    /// `Complex64`, which emit the capitalized identifiers `Complex32` and
    /// `Complex64` (the short complex spellings emit `c32` / `c64`).
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        tokens.extend(match *self {
            Type::Bool => quote!(bool),
            // Signed integers.
            Type::I8 => quote!(i8),
            Type::I16 => quote!(i16),
            Type::I32 => quote!(i32),
            Type::I64 => quote!(i64),
            Type::Isize => quote!(isize),
            // Unsigned integers.
            Type::U8 => quote!(u8),
            Type::U16 => quote!(u16),
            Type::U32 => quote!(u32),
            Type::U64 => quote!(u64),
            Type::Usize => quote!(usize),
            // Floating point.
            Type::BF16 => quote!(bf16),
            Type::F16 => quote!(f16),
            Type::F32 => quote!(f32),
            Type::F64 => quote!(f64),
            // Complex, both spellings.
            Type::C32 => quote!(c32),
            Type::C64 => quote!(c64),
            Type::Complex32 => quote!(Complex32),
            Type::Complex64 => quote!(Complex64),
        });
    }
}

impl Display for Type {
    /// Writes the lowercase name of the type (`"bool"`, `"i8"`, ...,
    /// `"c32"`, `"complex32"`, ...).
    ///
    /// The name is written as-is; width, fill and precision flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every name is a static string, so no intermediate `String` is needed.
        let name = match self {
            Type::Bool => "bool",
            Type::I8 => "i8",
            Type::U8 => "u8",
            Type::I16 => "i16",
            Type::U16 => "u16",
            Type::I32 => "i32",
            Type::U32 => "u32",
            Type::I64 => "i64",
            Type::U64 => "u64",
            Type::BF16 => "bf16",
            Type::F16 => "f16",
            Type::F32 => "f32",
            Type::F64 => "f64",
            Type::C32 => "c32",
            Type::C64 => "c64",
            Type::Isize => "isize",
            Type::Usize => "usize",
            Type::Complex32 => "complex32",
            Type::Complex64 => "complex64",
        };
        f.write_str(name)
    }
}

/// Tag identifying the scalar element type of a SIMD vector.
///
/// The variants mirror the scalar type names accepted by `From<&str>` below;
/// complex numbers exist under two spellings (`C32` / `C64` and
/// `Complex32` / `Complex64`).
#[derive(Copy, Clone)]
pub(crate) enum SimdType {
    Bool,
    I8,
    U8,
    I16,
    U16,
    I32,
    U32,
    I64,
    U64,
    BF16,
    F16,
    F32,
    F64,
    C32,
    C64,
    Isize,
    Usize,
    Complex32,
    Complex64,
}

impl From<&str> for SimdType {
    /// Parses a scalar type name, ignoring case.
    ///
    /// # Panics
    ///
    /// Panics when `name` is not one of the recognized type names. The panic
    /// message includes the rejected name so the faulty input can be located.
    fn from(name: &str) -> Self {
        let lowered = name.to_lowercase();
        match lowered.as_str() {
            "bool" => SimdType::Bool,
            "i8" => SimdType::I8,
            "u8" => SimdType::U8,
            "i16" => SimdType::I16,
            "u16" => SimdType::U16,
            "i32" => SimdType::I32,
            "u32" => SimdType::U32,
            "i64" => SimdType::I64,
            "u64" => SimdType::U64,
            "bf16" => SimdType::BF16,
            "f16" => SimdType::F16,
            "f32" => SimdType::F32,
            "f64" => SimdType::F64,
            "c32" => SimdType::C32,
            "c64" => SimdType::C64,
            "isize" => SimdType::Isize,
            "usize" => SimdType::Usize,
            "complex32" => SimdType::Complex32,
            "complex64" => SimdType::Complex64,
            // Report the original spelling, not the lowercased one.
            _ => unreachable!("Invalid type: {}", name),
        }
    }
}

impl ToTokens for SimdType {
    /// Appends the path of the SIMD vector type for this element type.
    ///
    /// The emitted path has the form `<name>::<name>`, where `<name>` is the
    /// element name followed by `x` and a lane count (for example
    /// `f32x8::f32x8`). Which lane counts are used is decided at compile time
    /// from the enabled target features. Exactly one table is compiled, the
    /// widest one available: `avx512f`, then `avx2`, then the 128-bit table
    /// used for `sse`, `arm` and `aarch64`.
    ///
    /// Both complex spellings map to the same `cplx32x*` / `cplx64x*` types.
    ///
    /// NOTE(review): the `isize` / `usize` entries always use the lane count
    /// of an 8-byte element, whereas `type_simd_lanes` distinguishes 32-bit
    /// pointer widths — confirm which is intended for 32-bit targets.
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        // 512-bit registers. Checked first because `avx512f` builds also have
        // `avx2` enabled and must not compile the narrower tables as well.
        #[cfg(target_feature = "avx512f")]
        let token = match self {
            SimdType::Bool => quote!(boolx64::boolx64),
            SimdType::I8 => quote!(i8x64::i8x64),
            SimdType::U8 => quote!(u8x64::u8x64),
            SimdType::I16 => quote!(i16x32::i16x32),
            SimdType::U16 => quote!(u16x32::u16x32),
            SimdType::I32 => quote!(i32x16::i32x16),
            SimdType::U32 => quote!(u32x16::u32x16),
            SimdType::I64 => quote!(i64x8::i64x8),
            SimdType::U64 => quote!(u64x8::u64x8),
            SimdType::BF16 => quote!(bf16x32::bf16x32),
            SimdType::F16 => quote!(f16x32::f16x32),
            SimdType::F32 => quote!(f32x16::f32x16),
            SimdType::F64 => quote!(f64x8::f64x8),
            SimdType::C32 => quote!(cplx32x8::cplx32x8),
            SimdType::C64 => quote!(cplx64x4::cplx64x4),
            SimdType::Isize => quote!(isizex8::isizex8),
            SimdType::Usize => quote!(usizex8::usizex8),
            SimdType::Complex32 => quote!(cplx32x8::cplx32x8),
            SimdType::Complex64 => quote!(cplx64x4::cplx64x4),
        };
        // 256-bit registers.
        #[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
        let token = match self {
            SimdType::Bool => quote!(boolx32::boolx32),
            SimdType::I8 => quote!(i8x32::i8x32),
            SimdType::U8 => quote!(u8x32::u8x32),
            SimdType::I16 => quote!(i16x16::i16x16),
            SimdType::U16 => quote!(u16x16::u16x16),
            SimdType::I32 => quote!(i32x8::i32x8),
            SimdType::U32 => quote!(u32x8::u32x8),
            SimdType::I64 => quote!(i64x4::i64x4),
            SimdType::U64 => quote!(u64x4::u64x4),
            SimdType::BF16 => quote!(bf16x16::bf16x16),
            SimdType::F16 => quote!(f16x16::f16x16),
            SimdType::F32 => quote!(f32x8::f32x8),
            SimdType::F64 => quote!(f64x4::f64x4),
            SimdType::C32 => quote!(cplx32x4::cplx32x4),
            SimdType::C64 => quote!(cplx64x2::cplx64x2),
            SimdType::Isize => quote!(isizex4::isizex4),
            SimdType::Usize => quote!(usizex4::usizex4),
            SimdType::Complex32 => quote!(cplx32x4::cplx32x4),
            SimdType::Complex64 => quote!(cplx64x2::cplx64x2),
        };
        // 128-bit registers.
        #[cfg(all(
            any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
            not(target_feature = "avx2"),
            not(target_feature = "avx512f")
        ))]
        let token = match self {
            SimdType::Bool => quote!(boolx16::boolx16),
            SimdType::I8 => quote!(i8x16::i8x16),
            SimdType::U8 => quote!(u8x16::u8x16),
            SimdType::I16 => quote!(i16x8::i16x8),
            SimdType::U16 => quote!(u16x8::u16x8),
            SimdType::I32 => quote!(i32x4::i32x4),
            SimdType::U32 => quote!(u32x4::u32x4),
            SimdType::I64 => quote!(i64x2::i64x2),
            SimdType::U64 => quote!(u64x2::u64x2),
            SimdType::BF16 => quote!(bf16x8::bf16x8),
            SimdType::F16 => quote!(f16x8::f16x8),
            SimdType::F32 => quote!(f32x4::f32x4),
            SimdType::F64 => quote!(f64x2::f64x2),
            SimdType::C32 => quote!(cplx32x2::cplx32x2),
            SimdType::C64 => quote!(cplx64x1::cplx64x1),
            SimdType::Isize => quote!(isizex2::isizex2),
            SimdType::Usize => quote!(usizex2::usizex2),
            SimdType::Complex32 => quote!(cplx32x2::cplx32x2),
            SimdType::Complex64 => quote!(cplx64x1::cplx64x1),
        };
        tokens.extend(token);
    }
}

/// Wrapper holding the scalar type parsed from a type name.
#[derive(Copy, Clone)]
pub(crate) struct TypeInfo {
    /// The parsed scalar type.
    pub(crate) dtype: Type,
}

impl TypeInfo {
    /// Builds a `TypeInfo` from a scalar type name, ignoring case.
    ///
    /// # Panics
    ///
    /// Panics when `name` is not one of the recognized type names. The panic
    /// message includes the rejected name so the faulty input can be located.
    pub(crate) fn new(name: &str) -> Self {
        let dtype = match name.to_lowercase().as_str() {
            "bool" => Type::Bool,
            "i8" => Type::I8,
            "u8" => Type::U8,
            "i16" => Type::I16,
            "u16" => Type::U16,
            "i32" => Type::I32,
            "u32" => Type::U32,
            "i64" => Type::I64,
            "u64" => Type::U64,
            "bf16" => Type::BF16,
            "f16" => Type::F16,
            "f32" => Type::F32,
            "f64" => Type::F64,
            "c32" => Type::C32,
            "c64" => Type::C64,
            "isize" => Type::Isize,
            "usize" => Type::Usize,
            "complex32" => Type::Complex32,
            "complex64" => Type::Complex64,
            // Report the original spelling, not the lowercased one.
            _ => unreachable!("Invalid type: {}", name),
        };
        Self { dtype }
    }
}