hpt-macros 0.1.2

An internal library for generating helper functions for hpt
Documentation
use crate::type_utils::type_simd_lanes;
use crate::TokenStream2;
use crate::TypeInfo;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;

pub fn impl_simd_cmp() -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        (format!("boolx{}", type_simd_lanes("bool")), "bool"),
        (format!("bf16x{}", type_simd_lanes("bf16")), "bf16"),
        (format!("f16x{}", type_simd_lanes("f16")), "f16"),
        (format!("f32x{}", type_simd_lanes("f32")), "f32"),
        (format!("f64x{}", type_simd_lanes("f64")), "f64"),
        (format!("i8x{}", type_simd_lanes("i8")), "i8"),
        (format!("i16x{}", type_simd_lanes("i16")), "i16"),
        (format!("i32x{}", type_simd_lanes("i32")), "i32"),
        (format!("i64x{}", type_simd_lanes("i64")), "i64"),
        (format!("u8x{}", type_simd_lanes("u8")), "u8"),
        (format!("u16x{}", type_simd_lanes("u16")), "u16"),
        (format!("u32x{}", type_simd_lanes("u32")), "u32"),
        (format!("u64x{}", type_simd_lanes("u64")), "u64"),
        (format!("isizex{}", type_simd_lanes("isize")), "isize"),
        (format!("usizex{}", type_simd_lanes("usize")), "usize"),
    ];

    for (lhs_simd_ty, lhs) in types.iter() {
        for (rhs_simd_ty, rhs) in types.iter() {
            let lhs_simd_ty = Ident::new(&lhs_simd_ty, proc_macro2::Span::call_site());
            let rhs_simd_ty = Ident::new(&rhs_simd_ty, proc_macro2::Span::call_site());
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_lanes = type_simd_lanes(lhs);
            let rhs_lanes = type_simd_lanes(rhs);
            if lhs_lanes != rhs_lanes {
                ret.extend(impl_unreachable(lhs_simd_ty, rhs_simd_ty));
                continue;
            }
            let res = if lhs_type.dtype == rhs_type.dtype {
                quote! {
                    impl SimdCmp<#rhs_simd_ty::#rhs_simd_ty> for #lhs_simd_ty::#lhs_simd_ty {
                        type Output = <#lhs_simd_ty::#lhs_simd_ty as SimdCmpPromote<#rhs_simd_ty::#rhs_simd_ty>>::Output;
                        fn _eq(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_eq(rhs)
                        }
                        fn _ne(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_ne(rhs)
                        }
                        fn _lt(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_lt(rhs)
                        }
                        fn _le(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_le(rhs)
                        }
                        fn _gt(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_gt(rhs)
                        }
                        fn _ge(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            self.simd_ge(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl SimdCmp<#rhs_simd_ty::#rhs_simd_ty> for #lhs_simd_ty::#lhs_simd_ty {
                        type Output = <#lhs_simd_ty::#lhs_simd_ty as SimdCmpPromote<#rhs_simd_ty::#rhs_simd_ty>>::Output;
                        fn _eq(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_eq(rhs)
                        }
                        fn _ne(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_ne(rhs)
                        }
                        fn _lt(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_lt(rhs)
                        }
                        fn _le(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_le(rhs)
                        }
                        fn _gt(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_gt(rhs)
                        }
                        fn _ge(self, rhs: #rhs_simd_ty::#rhs_simd_ty) -> Self::Output {
                            let lhs: Self::Output = self.into_vec();
                            let rhs: Self::Output = rhs.into_vec();
                            lhs.simd_ge(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

fn impl_unreachable(lhs_simd: Ident, rhs_simd: Ident) -> TokenStream2 {
    quote! {
        impl SimdCmp<#rhs_simd::#rhs_simd> for #lhs_simd::#lhs_simd {
            type Output = <#lhs_simd::#lhs_simd as SimdCmpPromote<#rhs_simd::#rhs_simd>>::Output;
            fn _eq(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }
            fn _ne(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }
            fn _lt(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }

            fn _le(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }
            fn _gt(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }
            fn _ge(self, rhs: #rhs_simd::#rhs_simd) -> Self::Output {
                unreachable!()
            }
        }
    }
}