hpt-macros 0.1.2

An internal library for generating helper functions for hpt
Documentation
use proc_macro::TokenStream;
use quote::quote;
use syn::parse_macro_input;
use syn::Ident;

use crate::NUM_REG;

pub fn __gen_fast_reduce_simd_helper(stream: TokenStream) -> TokenStream {
    let input = parse_macro_input!(stream as Ident);

    #[cfg(target_feature = "avx2")]
    let num_registers = 16;
    #[cfg(all(
        any(target_feature = "sse", target_arch = "arm"),
        not(target_feature = "avx2")
    ))]
    let num_registers = 8;
    #[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
    let num_registers = 32;

    let mut body = proc_macro2::TokenStream::new();
    for i in 0..num_registers as isize {
        let i = i + 1;
        let elements = 1..=i;
        let arr = quote! {
            [ #(#elements),* ]
        };
        let i_u32 = i as u32;
        body.extend(
            quote! {
            #i_u32 => {
                gen_kernel!(1, #i, inp_ptr, res_ptr, vec_size, outer_loop_size, vec_preop, vec_cumulate, inp_strides, inp_shape, prg, vec_post, #arr);
            }
        }
        );
    }
    let ret = quote! {
        match #input {
            #body
            0 => {}
            _ => unreachable!()
        }
    };
    ret.into()
}

pub fn __gen_fast_layernorm_simd_helper(stream: TokenStream) -> TokenStream {
    let input = parse_macro_input!(stream as Ident);

    #[cfg(target_feature = "avx2")]
    let num_registers = 16;
    #[cfg(all(
        any(target_feature = "sse", target_arch = "arm"),
        not(target_feature = "avx2")
    ))]
    let num_registers = 8;
    #[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
    let num_registers = 32;

    let mut body = proc_macro2::TokenStream::new();
    for i in 0..num_registers as isize {
        let i = i + 1;
        let elements = 1..=i;
        let arr = quote! {
            [ #(#elements),* ]
        };
        let i_u32 = i as u32;
        body.extend(
            quote! {
            #i_u32 => {
                gen_kernel!(1, #i, inp_ptr, res_ptr, vec_size, outer_loop_size, inp_strides, inp_shape, prg, #arr);
            }
        }
        );
    }
    let ret = quote! {
        match #input {
            #body
            0 => {}
            _ => unreachable!()
        }
    };
    ret.into()
}

pub fn __gen_reduce_dim_not_include_simd_helper(stream: TokenStream) -> TokenStream {
    let input = parse_macro_input!(stream as Ident);

    let mut body = proc_macro2::TokenStream::new();
    for i in 0..NUM_REG as isize {
        let i = i + 1;
        let elements = 1..=i;
        let arr = quote! {
            [ #(#elements),* ]
        };
        let i_u32 = i as u32;
        body.extend(quote! {
            #i_u32 => {
                gen_kernel3!(
                    1,
                    #i,
                    outer_loop_size,
                    inp_ptr,
                    res_ptr,
                    <O as TypeCommon>::Vec::SIZE as isize,
                    intermediate_size,
                    vec_preop,
                    vec_cumulate,
                    inp_strides,
                    inp_shape,
                    prg1,
                    prg2,
                    shape_len,
                    inner_loop_size,
                    vec_post,
                    #arr
                );
            }
        });
    }
    let ret = quote! {
        match #input {
            #body,
            0 => {}
            _ => unreachable!()
        }
    };
    ret.into()
}