hpt_macros/
lib.rs

//! # Tensor Macros
//!
//! This crate provides a set of macros to generate code for tensor operations.
//! These macros simplify and automate common tasks such as defining tensor
//! operations, reducing dimensionality, and optimizing numerical computations.
//!
//! ## Examples
//!
//! Here's an example of using a macro from this crate (illustrative only:
//! the `Slice` type lives in the consuming crate, so this is not a runnable doctest):
//!
//! ```rust,ignore
//! use hpt_macros::match_selection;
//!
//! // `match_selection!(1:10:2, :)` expands to
//! // `[Slice::StepByRangeFromTo((1, 10, 2)), Slice::Full]`
//! let slices = match_selection!(1:10:2, :);
//! ```

#![deny(missing_docs)]
#[cfg(feature = "cuda")]
use crate::binary_float_out::impl_cuda_float_out_binary;
use binary_float_out::impl_float_out_binary;
use float_unary::impl_float_out_unary;
use from_scalar::__impl_from_scalar;
use kernel_gen_helper::{
    __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
    __gen_reduce_dim_not_include_simd_helper,
};
use normal_out::__impl_normal_out_binary;
use proc_macro::TokenStream;
use scalar_convert::__impl_scalar_convert;
use simd_bitwise::impl_simd_bitwise_out;
use simd_float_out_binary::{
    impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
    impl_simd_binary_out_float_rhs_scalar,
};
use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
use syn::{parse, parse_macro_input, Expr, Token};
mod binary_float_out;
mod conv2d;
mod float_unary;
mod from_scalar;
mod into_cuda_scalar;
mod into_scalar;
mod into_vec;
mod kernel_gen_helper;
mod normal_out;
mod normal_out_unary;
mod scalar_convert;
mod simd_bitwise;

mod simd_cmp;
mod simd_eval;
mod simd_float_out_binary;
mod simd_float_out_unary;
mod simd_normal_out;
mod simd_normal_unary;
mod type_utils;

use crate::simd_cmp::impl_simd_cmp;
use crate::simd_normal_out::impl_simd_normal_out;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{format_ident, quote};
use type_utils::TypeInfo;

/// number of registers available for the target architecture
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;

struct SelectionParser {
    start: Option<Expr>,
    end: Option<Expr>,
    step: Option<Expr>,
}

struct Selections {
    selections: Vec<TokenStream>,
}

impl parse::Parse for SelectionParser {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut start: Option<Expr> = None;
        let mut end: Option<Expr> = None;
        let mut step: Option<Expr> = None;
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            start = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        } else if input.is_empty() {
            return Ok(Self { start, end, step });
        } else {
            return Err(syn::Error::new(
                input.span(),
                "unexpected token, expected `:`, Int or Ident",
            ));
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            end = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            step = Some(input.parse::<Expr>()?);
        }
        Ok(Self { start, end, step })
    }
}

impl parse::Parse for Selections {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut selections: Vec<TokenStream> = vec![];
        let mut tokenstream = TokenStream2::new();
        while !input.is_empty() {
            let lookahead = input.lookahead1();
            if lookahead.peek(Token![,]) {
                selections.push(tokenstream.into());
                tokenstream = TokenStream2::new();
                input.parse::<Token![,]>()?;
            } else {
                let t = input.parse::<TokenTree>()?;
                tokenstream.extend(quote!(#t));
            }
        }
        selections.push(tokenstream.into());
        Ok(Self { selections })
    }
}
/// parse the input and generate the corresponding `Slice` expressions
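///
/// A sketch of the expected expansion (illustrative only: `Slice` is defined
/// in the consuming crate, so this is not a runnable doctest):
///
/// ```ignore
/// // `match_selection!(1:10:2, :, 3:)` expands to:
/// let slices = [
///     Slice::StepByRangeFromTo((1, 10, 2)),
///     Slice::Full,
///     Slice::From(3),
/// ];
/// ```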
#[proc_macro]
pub fn match_selection(input: TokenStream) -> TokenStream {
    let res: Selections = parse_macro_input!(input as Selections);
    let mut slices: Vec<SelectionParser> = vec![];
    for x in res.selections {
        slices.push(parse_macro_input!(x as SelectionParser));
    }
    let mut ret_stream = TokenStream2::new();
    let len = slices.len();
    for (idx, x) in slices.into_iter().enumerate() {
        match (x.start, x.end, x.step) {
            (None, None, None) => {
                ret_stream.extend(quote!(Slice::Full));
            }
            (None, None, Some(step)) => {
                ret_stream.extend(quote!(Slice::StepByFullRange(#step)));
            }
            (None, Some(end), None) => {
                ret_stream.extend(quote!(Slice::RangeTo(#end)));
            }
            (None, Some(end), Some(step)) => {
                ret_stream.extend(quote!(Slice::StepByRangeTo((#end, #step))));
            }
            (Some(start), None, None) => {
                ret_stream.extend(quote!(Slice::From(#start)));
            }
            (Some(start), None, Some(step)) => {
                ret_stream.extend(quote!(Slice::StepByRangeFrom((#start, #step))));
            }
            (Some(start), Some(end), None) => {
                ret_stream.extend(quote!(Slice::Range((#start, #end))));
            }
            (Some(start), Some(end), Some(step)) => {
                ret_stream.extend(quote!(Slice::StepByRangeFromTo((#start, #end, #step))));
            }
        }
        if idx != len - 1 {
            ret_stream.extend(quote!(,));
        }
    }
    quote!([#ret_stream]).into()
}

/// implement float out binary trait
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}

#[cfg(feature = "cuda")]
/// implement float out binary trait for cuda
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}

/// implement simd float out binary trait
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}

/// implement simd float out binary trait with rhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}

/// implement simd float out binary trait with lhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}

/// implement float out unary trait
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}

#[cfg(feature = "cuda")]
/// implement float out unary trait for cuda
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}

/// implement simd float out unary trait
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}

/// implement simd eval trait
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}

/// implement simd bitwise trait
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}

/// generate normal out trait
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}

#[cfg(feature = "cuda")]
/// generate normal out trait for cuda
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}

/// generate normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}

#[cfg(feature = "cuda")]
/// generate normal out unary trait for cuda
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}

/// generate simd normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}

/// implement simd normal out trait
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}

/// implement simd normal out trait with rhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}

/// implement simd normal out trait with lhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}

/// implement scalar convert trait
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}

/// implement from scalar trait
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}

/// implement simd cmp trait
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}

/// implement into vec trait
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}

#[cfg(feature = "cuda")]
/// implement into cuda scalar trait
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}

/// implement into scalar trait
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}

/// implement bitwise out trait
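///
/// Invoked without arguments; a sketch of one generated impl (assuming the
/// `BitWiseOut`, `NormalOutPromote`, and `Cast` machinery from the consuming crate):
///
/// ```ignore
/// impl_bitwise_out!();
/// // expands to, among others:
/// // impl BitWiseOut<u8> for i32 {
/// //     type Output = <i32 as NormalOutPromote<u8>>::Output;
/// //     fn _bitand(self, rhs: u8) -> Self::Output { /* cast both sides, then __bitand */ }
/// //     // ...
/// // }
/// ```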
#[proc_macro]
pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement bitwise out trait for cuda
#[proc_macro]
pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait
#[proc_macro]
pub fn impl_cmp(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            self == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            self != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            self < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            self <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            self > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            self >= rhs
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs >= rhs
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait for cuda
#[proc_macro]
pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ge(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ge(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement eval trait
#[proc_macro]
pub fn impl_eval(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        let lhs_type = TypeInfo::new(lhs);
        let lhs_dtype = lhs_type.dtype;

        let res = quote! {
            impl Eval for #lhs_dtype {
                type Output = bool;
                #[inline(always)]
                fn _is_nan(&self) -> bool {
                    self.__is_nan()
                }
                #[inline(always)]
                fn _is_true(&self) -> bool {
                    self.__is_true()
                }
                #[inline(always)]
                fn _is_inf(&self) -> bool {
                    self.__is_inf()
                }
            }
        };
        ret.extend(res);
    }

    ret.into()
}

/// generate fast reduce simd helper
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}

/// generate fast layernorm simd helper
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}

/// generate reduce dim not include simd helper
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}

/// declare const values
///
/// const OW_BLOCK: usize = ?;
///
/// const OC_BLOCK: usize = ?;
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}

/// generate conv2d inps
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}

/// generate conv2d padded inps
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}

/// generate pwconv2d padded inps
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}

/// generate dwconv2d padded inps
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}

/// generate conv2d kernels
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}

/// generate conv2d repeat results
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}

/// generate dwconv2d repeat results
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}

/// generate maxpool2d repeat results
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}

/// generate the `Save` trait implementation
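///
/// A sketch of the expected usage (illustrative; `Tensor` is a placeholder for
/// any field type that itself implements `Save`):
///
/// ```ignore
/// #[derive(Save)]
/// struct Model {
///     // per-field settings parsed from `#[compress(...)]`:
///     // algo: "gzip" | "deflate" | "zlib" | "none", endian: "native" | "little" | "big"
///     #[compress(algo = "gzip", level = "6", endian = "little")]
///     weight: Tensor<f32>,
///     // fields without the attribute fall back to the arguments passed to `__save`
///     bias: Tensor<f32>,
/// }
/// ```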
#[proc_macro_derive(Save, attributes(compress))]
pub fn impl_save(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);

    let visibility = &ast.vis;
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Save can only be derived for structs"),
    };

    let mut compressions = vec![];
    let mut endians = vec![];
    let mut compress_levels = vec![];

    let meta_fields = fields
        .iter()
        .map(|f| {
            let mut compression_algo = None;
            let mut endian = None;
            let mut level = None;

            for attr in &f.attrs {
                if attr.path().is_ident("compress") {
                    attr.parse_nested_meta(|meta| {
                        if meta.path.is_ident("algo") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let algo = match value.value().as_str().to_lowercase().as_str() {
                                "gzip" => quote!(Gzip),
                                "deflate" => quote!(Deflate),
                                "zlib" => quote!(Zlib),
                                "none" => quote!(NoCompression),
                                _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
                            };
                            compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
                        } else if meta.path.is_ident("level") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp: u32 = value.value().parse().map_err(|e| {
                                syn::Error::new(value.span(), format!("Invalid level: {}", e))
                            })?;
                            level = Some(quote!(#tmp));
                        } else if meta.path.is_ident("endian") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp = match value.value().as_str() {
                                "native" => quote!(Native),
                                "little" => quote!(Little),
                                "big" => quote!(Big),
                                _ => panic!("Unsupported endianness, supported: native, little, big"),
                            };
                            endian = Some(quote!(hpt::Endian::#tmp));
                        }
                        Ok(())
                    })
                    .unwrap();
                }
            }
            compressions.push(compression_algo);
            endians.push(endian);
            compress_levels.push(level);
            let name = &f.ident;
            let ty = &f.ty;
            quote! {
                pub #name: <#ty as Save>::Meta
            }
        })
        .collect::<Vec<_>>();

    let call_save = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ty = &f.ty;
        let ident = format_ident!("field_{}", idx);
        let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
        let endian = endians[idx].clone().unwrap_or(quote!(endian));
        let level = compress_levels[idx].clone().unwrap_or(quote!(level));
        if let Some(name) = name {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        } else {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
        #visibility struct #meta_name #ty_generics #where_clause  {
            #(#meta_fields,)*
        }
        impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
            type Meta = #meta_name #ty_generics;
            fn __save(
                data: &Self,
                file: &mut std::fs::File,
                len_so_far: &mut usize,
                global_cnt: &mut usize,
                compression_algo: hpt::CompressionAlgo,
                endian: hpt::Endian,
                level: u32,
            ) -> std::io::Result<Self::Meta> {
                #(#call_save)*
                Ok(Self::Meta {
                    #(#construct_fields),*
                })
            }
        }
    };

    expanded.into()
}

/// generate the `Load` trait implementation
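///
/// A sketch of the generated entry point (illustrative; `Model` is assumed to
/// also derive `Save`, which defines the matching `ModelMeta` type):
///
/// ```ignore
/// #[derive(Save, Load)]
/// struct Model { /* ... */ }
///
/// let model = Model::load("path/to/checkpoint")?;
/// ```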
#[proc_macro_derive(Load)]
pub fn impl_load(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Load can only be derived for structs"),
    };

    let call_load = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        if let Some(name) = name {
            quote! {
                let #ident = self.#name.load(file)?;
            }
        } else {
            quote! {
                let #ident = self.#idx.load(file)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
            type Output = #name #ty_generics;
            fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
                use hpt::MetaLoad;
                #(#call_load)*
                Ok(#name {
                    #(#construct_fields),*
                })
            }
        }
        impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
            fn load(path: &str) -> std::io::Result<Self> {
                use hpt::MetaLoad;
                let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
                let mut file = std::fs::File::open(path)?;
                meta.load(&mut file)
            }
        }
    };

    expanded.into()
}

/// generate the `FromSafeTensors` trait implementation
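///
/// A sketch of the supported `#[map(...)]` keys (illustrative; `Tensor` and
/// `Linear` are placeholders for types implementing `FromSafeTensors`):
///
/// ```ignore
/// #[derive(FromSafeTensors)]
/// struct Block {
///     // look up a tensor by an explicit name, ignoring the current path
///     #[map(tensor_name = "embedding.weight")]
///     embedding: Tensor<f32>,
///     // constant value, no tensor lookup
///     #[map(value = 1e-5)]
///     eps: f64,
///     // load a Vec of sub-modules from "<path>.0" .. "<path>.3"
///     #[map(vec_len = 4, inner_type = Linear)]
///     layers: Vec<Linear>,
///     // fields without `#[map]` load from "<path>.<field_name>"
///     norm: Tensor<f32>,
/// }
/// ```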
#[proc_macro_derive(FromSafeTensors, attributes(map))]
pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let struct_name = &ast.ident;
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("FromSafeTensors can only be derived for structs"),
    };
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let mut construct_fields = vec![];
    for field in fields.iter() {
        let ty = &field.ty;
        let name = &field.ident;
        let mut value_construct = vec![];
        let mut from_construct = vec![];
        let mut params = vec![];
        let mut vec_len = None;
        for attr in &field.attrs {
            if attr.path().is_ident("map") {
                let mut path = None;
                let mut value = None;
                let mut tensor_name = None;
                let mut inner_type = None;
                attr.parse_nested_meta(|meta| {
                    if meta.path.is_ident("path") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        path = Some(value.value());
                    } else if meta.path.is_ident("value") {
                        let val: syn::Expr = meta.value()?.parse()?;
                        value = Some(val);
                    } else if meta.path.is_ident("tensor_name") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        tensor_name = Some(value.value());
                    } else if meta.path.is_ident("vec_len") {
                        let value: syn::LitInt = meta.value()?.parse()?;
                        vec_len = Some(value.base10_parse::<usize>().unwrap());
                    } else if meta.path.is_ident("inner_type") {
                        let value: syn::Ident = meta.value()?.parse()?;
                        inner_type = Some(value);
                    }
                    Ok(())
                })
                .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
                params.push((path, value, tensor_name, vec_len, inner_type));
            }
        }
        let param_count = params.len();
        for (path, value, tensor_name, vec_len, inner_type) in params {
            if let Some(vec_len) = vec_len {
                let inner_type = inner_type.expect("inner_type is required for vec");
                if let Some(path) = path {
                    from_construct.push(quote! {
                        #path => {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                } else {
                    value_construct.push(quote! {
                        {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                }
            } else {
                match (path, value, tensor_name) {
                    (None, None, Some(tensor_name)) => {
                        value_construct.push(quote! {
                            <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
                        });
                    }
                    (None, Some(value), None) => {
                        if param_count > 1 {
                            panic!("value without path means generic assignment, there can only be one value without path");
                        }
                        value_construct.push(quote! {
                            #value
                        });
                    }
                    (Some(path), None, Some(tensor_name)) => {
                        from_construct.push(quote! {
                            #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
                        });
                    }
                    (Some(path), Some(value), None) => {
                        from_construct.push(quote! {
                            #path => #value,
                        });
                    }

                    (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
                        panic!("value and tensor_name cannot be used together");
                    }
                    (Some(_), None, None) | (None, None, None) => {
                        panic!("missing `value` or `tensor_name` for this `map` attribute");
                    }
                }
            }
        }
        if !value_construct.is_empty() {
            construct_fields.push(quote! {
                #name: #(#value_construct)*
            });
        } else if !from_construct.is_empty() {
            construct_fields.push(quote! {
                #name: match path {
                    #(#from_construct)*
                    _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
                }
            });
        } else {
            construct_fields.push(quote! {
                #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
            });
        }
    }
    let expanded = quote! {
        impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
            fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
                Self {
                    #(#construct_fields),*
                }
            }
        }
    };
    // let syntax_tree = syn::parse2(expanded.clone()).expect(&format!(
    //     "failed to parse expanded: {}",
    //     expanded.to_string()
    // ));
    // let formatted = prettyplease::unparse(&syntax_tree);
    // println!("{}", formatted);
    expanded.into()
}