hpt_macros/
lib.rs

//! # Tensor Macros
//!
//! This crate provides a set of procedural macros that generate code for tensor operations.
//! These macros simplify and automate common tasks such as defining tensor operations,
//! reducing dimensionality, and optimizing numerical computations.
//!
//! ## Examples
//!
//! Here's an example of using a macro from this crate (illustrative only; it assumes the
//! `Slice` enum used by the generated code is in scope at the call site):
//!
//! ```ignore
//! // `match_selection!` expands Python-like slice syntax into an array of `Slice`
//! // variants, here `[Slice::StepByRangeFromTo((1, 10, 2)), Slice::Full]`.
//! let slices = match_selection!(1:10:2, :);
//! ```

15#![deny(missing_docs)]
16#[cfg(feature = "cuda")]
17use crate::binary_float_out::impl_cuda_float_out_binary;
18use binary_float_out::impl_float_out_binary;
19use float_unary::impl_float_out_unary;
20use from_scalar::__impl_from_scalar;
21use kernel_gen_helper::{
22    __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
23    __gen_reduce_dim_not_include_simd_helper,
24};
25use normal_out::__impl_normal_out_binary;
26use proc_macro::TokenStream;
27use scalar_convert::__impl_scalar_convert;
28use simd_bitwise::impl_simd_bitwise_out;
29use simd_convert::__impl_simd_convert;
30use simd_float_out_binary::{
31    impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
32    impl_simd_binary_out_float_rhs_scalar,
33};
34use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
35use syn::{parse, parse_macro_input, Expr, Token};
36mod binary_float_out;
37mod conv2d;
38mod float_unary;
39mod from_scalar;
40mod into_cuda_scalar;
41mod into_scalar;
42mod into_vec;
43mod kernel_gen_helper;
44mod normal_out;
45mod normal_out_unary;
46mod scalar_convert;
47mod simd_bitwise;
48
49mod simd_cmp;
50mod simd_convert;
51mod simd_eval;
52mod simd_float_out_binary;
53mod simd_float_out_unary;
54mod simd_normal_out;
55mod simd_normal_unary;
56mod type_utils;
57
58use crate::simd_cmp::impl_simd_cmp;
59use crate::simd_normal_out::impl_simd_normal_out;
60use proc_macro2::{TokenStream as TokenStream2, TokenTree};
61use quote::{format_ident, quote};
62use type_utils::TypeInfo;
63
/// number of registers available for the target architecture
// `avx512f` implies `avx2`, so exclude it here to avoid defining `NUM_REG` twice
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;
74
75struct SelectionParser {
76    start: Option<Expr>,
77    end: Option<Expr>,
78    step: Option<Expr>,
79}
80
81struct Selections {
82    selections: Vec<TokenStream>,
83}
84
85impl parse::Parse for SelectionParser {
86    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
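        // Accepted grammar per selection: `[start][:[end][:[step]]]`; each component
        // may be a literal, an identifier, a parenthesized expression, or a negated
        // value starting with `-`.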
87        let mut start: Option<Expr> = None;
88        let mut end: Option<Expr> = None;
89        let mut step: Option<Expr> = None;
90        if input.peek(syn::Lit)
91            || input.peek(syn::Ident)
92            || input.peek(syn::token::Paren)
93            || input.peek(Token![-])
94        {
95            start = Some(input.parse::<Expr>()?);
96        }
97        if input.peek(Token![:]) {
98            input.parse::<Token![:]>()?;
99        } else if input.is_empty() {
100            return Ok(Self { start, end, step });
101        } else {
102            return Err(syn::Error::new(
103                input.span(),
                "unexpected token, expected `:`, a literal, an identifier, a parenthesized expression or `-`",
105            ));
106        }
107        if input.peek(syn::Lit)
108            || input.peek(syn::Ident)
109            || input.peek(syn::token::Paren)
110            || input.peek(Token![-])
111        {
112            end = Some(input.parse::<Expr>()?);
113        }
114        if input.peek(Token![:]) {
115            input.parse::<Token![:]>()?;
116        }
117        if input.peek(syn::Lit)
118            || input.peek(syn::Ident)
119            || input.peek(syn::token::Paren)
120            || input.peek(Token![-])
121        {
122            step = Some(input.parse::<Expr>()?);
123        }
124        Ok(Self { start, end, step })
125    }
126}
127
128impl parse::Parse for Selections {
129    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
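        // Split the input on top-level commas; each selection's tokens are collected
        // verbatim and parsed later as a `SelectionParser`.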
130        let mut selections: Vec<TokenStream> = vec![];
131        let mut tokenstream = TokenStream2::new();
132        while !input.is_empty() {
133            let lookahead = input.lookahead1();
134            if lookahead.peek(Token![,]) {
135                selections.push(tokenstream.into());
136                tokenstream = TokenStream2::new();
137                input.parse::<Token![,]>()?;
138            } else {
139                let t = input.parse::<TokenTree>()?;
140                tokenstream.extend(quote!(#t));
141            }
142        }
143        selections.push(tokenstream.into());
144        Ok(Self { selections })
145    }
146}
147
/// parse the input selections and generate the corresponding array of `Slice` variants
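///
/// Each comma-separated selection follows `start:end:step`, and every component is
/// optional. A sketch of the mapping, based on the match arms below:
///
/// ```ignore
/// match_selection!(3:, ::2, 1:5)
/// // expands to: [Slice::From(3), Slice::StepByFullRange(2), Slice::Range((1, 5))]
/// ```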
149#[proc_macro]
150pub fn match_selection(input: TokenStream) -> TokenStream {
151    let res: Selections = parse_macro_input!(input as Selections);
152    let mut slices: Vec<SelectionParser> = vec![];
153    for x in res.selections {
154        slices.push(parse_macro_input!(x as SelectionParser));
155    }
156    let mut ret_stream = TokenStream2::new();
157    let len = slices.len();
158    for (idx, x) in slices.into_iter().enumerate() {
159        match (x.start, x.end, x.step) {
160            (None, None, None) => {
161                ret_stream.extend(quote!(Slice::Full));
162            }
163            (None, None, Some(step)) => {
164                ret_stream.extend(quote!(Slice::StepByFullRange(#step)));
165            }
166            (None, Some(end), None) => {
167                ret_stream.extend(quote!(Slice::RangeTo(#end)));
168            }
169            (None, Some(end), Some(step)) => {
170                ret_stream.extend(quote!(Slice::StepByRangeTo((#end, #step))));
171            }
172            (Some(start), None, None) => {
173                ret_stream.extend(quote!(Slice::From(#start)));
174            }
175            (Some(start), None, Some(step)) => {
176                ret_stream.extend(quote!(Slice::StepByRangeFrom((#start, #step))));
177            }
178            (Some(start), Some(end), None) => {
179                ret_stream.extend(quote!(Slice::Range((#start, #end))));
180            }
181            (Some(start), Some(end), Some(step)) => {
182                ret_stream.extend(quote!(Slice::StepByRangeFromTo((#start, #end, #step))));
183            }
184        }
185        if idx != len - 1 {
186            ret_stream.extend(quote!(,));
187        }
188    }
189    quote!([#ret_stream]).into()
190}
191
192/// implement float out binary trait
193#[proc_macro]
194pub fn float_out_binary(_: TokenStream) -> TokenStream {
195    impl_float_out_binary()
196}
197
198#[cfg(feature = "cuda")]
199/// implement float out binary trait for cuda
200#[proc_macro]
201pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
202    impl_cuda_float_out_binary()
203}
204
205/// implement simd float out binary trait
206#[proc_macro]
207pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
208    impl_simd_binary_out_float()
209}
210
211/// implement simd float out binary trait with rhs scalar
212#[proc_macro]
213pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
214    impl_simd_binary_out_float_rhs_scalar()
215}
216
217/// implement simd float out binary trait with lhs scalar
218#[proc_macro]
219pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
220    impl_simd_binary_out_float_lhs_scalar()
221}
222
223/// implement float out unary trait
224#[proc_macro]
225pub fn float_out_unary(_: TokenStream) -> TokenStream {
226    impl_float_out_unary()
227}
228
229#[cfg(feature = "cuda")]
230/// implement float out unary trait for cuda
231#[proc_macro]
232pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
233    crate::float_unary::impl_cuda_float_out_unary()
234}
235
236/// implement simd float out unary trait
237#[proc_macro]
238pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
239    simd_float_out_unary::impl_float_out_unary()
240}
241
242/// implement simd eval trait
243#[proc_macro]
244pub fn simd_eval(_: TokenStream) -> TokenStream {
245    simd_eval::impl_simd_eval()
246}
247
248/// implement simd bitwise trait
249#[proc_macro]
250pub fn simd_bitwise(_: TokenStream) -> TokenStream {
251    impl_simd_bitwise_out()
252}
253
/// generate normal out binary trait
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}

#[cfg(feature = "cuda")]
/// generate normal out binary trait for cuda
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}

/// generate normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}

#[cfg(feature = "cuda")]
/// generate normal out unary trait for cuda
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}

/// generate simd normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}
285
286/// implement simd normal out trait
287#[proc_macro]
288pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
289    impl_simd_normal_out()
290}
291
292/// implement simd normal out trait with rhs scalar
293#[proc_macro]
294pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
295    impl_simd_normal_out_with_rhs_scalar()
296}
297
298/// implement simd normal out trait with lhs scalar
299#[proc_macro]
300pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
301    impl_simd_normal_out_with_lhs_scalar()
302}
303
304/// implement simd convert trait
305#[proc_macro]
306pub fn impl_simd_convert(_: TokenStream) -> TokenStream {
307    __impl_simd_convert()
308}
309
310/// implement scalar convert trait
311#[proc_macro]
312pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
313    __impl_scalar_convert()
314}
315
316/// implement from scalar trait
317#[proc_macro]
318pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
319    __impl_from_scalar()
320}
321
322/// implement simd cmp trait
323#[proc_macro]
324pub fn simd_cmp(_: TokenStream) -> TokenStream {
325    impl_simd_cmp()
326}
327
/// implement into vec trait
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}

#[cfg(feature = "cuda")]
/// implement into cuda scalar trait
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}

/// implement into scalar trait
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}
346
347/// implement bitwise out trait
348#[proc_macro]
349pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
350    let mut ret = proc_macro2::TokenStream::new();
351
352    let types = [
353        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
354    ];
355
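    // When the two operand types differ, both sides are first cast to their common
    // `NormalOutPromote` output type before applying the bitwise operation.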
356    for lhs in types.iter() {
357        for rhs in types.iter() {
358            let lhs_type = TypeInfo::new(lhs);
359            let rhs_type = TypeInfo::new(rhs);
360            let lhs_dtype = lhs_type.dtype;
361            let rhs_dtype = rhs_type.dtype;
362            let res = if lhs_dtype == rhs_dtype {
363                quote! {
364                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
365                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
366                        #[inline(always)]
367                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
368                            self.__bitand(rhs)
369                        }
370                        #[inline(always)]
371                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
372                            self.__bitor(rhs)
373                        }
374                        #[inline(always)]
375                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
376                            self.__bitxor(rhs)
377                        }
378                        #[inline(always)]
379                        fn _not(self) -> Self::Output {
380                            self.__not()
381                        }
382                        #[inline(always)]
383                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
384                            self.__shl(rhs)
385                        }
386                        #[inline(always)]
387                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
388                            self.__shr(rhs)
389                        }
390                    }
391                }
392            } else {
393                quote! {
394                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
395                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
396                        #[inline(always)]
397                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
398                            let lhs: Self::Output = self.cast();
399                            let rhs: Self::Output = rhs.cast();
400                            lhs.__bitand(rhs)
401                        }
402                        #[inline(always)]
403                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
404                            let lhs: Self::Output = self.cast();
405                            let rhs: Self::Output = rhs.cast();
406                            lhs.__bitor(rhs)
407                        }
408                        #[inline(always)]
409                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
410                            let lhs: Self::Output = self.cast();
411                            let rhs: Self::Output = rhs.cast();
412                            lhs.__bitxor(rhs)
413                        }
414                        #[inline(always)]
415                        fn _not(self) -> Self::Output {
416                            let lhs: Self::Output = self.cast();
417                            lhs.__not()
418                        }
419                        #[inline(always)]
420                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
421                            let lhs: Self::Output = self.cast();
422                            let rhs: Self::Output = rhs.cast();
423                            lhs.__shl(rhs)
424                        }
425                        #[inline(always)]
426                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
427                            let lhs: Self::Output = self.cast();
428                            let rhs: Self::Output = rhs.cast();
429                            lhs.__shr(rhs)
430                        }
431                    }
432                }
433            };
434            ret.extend(res);
435        }
436    }
437
438    ret.into()
439}
440
/// implement bitwise out trait for cuda
442#[proc_macro]
443pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
444    let mut ret = proc_macro2::TokenStream::new();
445
446    let types = [
447        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
448    ];
449
450    for lhs in types.iter() {
451        for rhs in types.iter() {
452            let lhs_type = TypeInfo::new(lhs);
453            let rhs_type = TypeInfo::new(rhs);
454            let lhs_dtype = lhs_type.dtype;
455            let rhs_dtype = rhs_type.dtype;
456            let res = if lhs_dtype == rhs_dtype {
457                quote! {
458                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
459                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
460                        #[inline(always)]
461                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
462                            self.__bitand(rhs)
463                        }
464                        #[inline(always)]
465                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
466                            self.__bitor(rhs)
467                        }
468                        #[inline(always)]
469                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
470                            self.__bitxor(rhs)
471                        }
472                        #[inline(always)]
473                        fn _not(self) -> Self::Output {
474                            self.__not()
475                        }
476                        #[inline(always)]
477                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
478                            self.__shl(rhs)
479                        }
480                        #[inline(always)]
481                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
482                            self.__shr(rhs)
483                        }
484                    }
485                }
486            } else {
487                quote! {
488                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
489                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
490                        #[inline(always)]
491                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
492                            let lhs: Self::Output = self.cast();
493                            let rhs: Self::Output = rhs.cast();
494                            lhs.__bitand(rhs)
495                        }
496                        #[inline(always)]
497                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
498                            let lhs: Self::Output = self.cast();
499                            let rhs: Self::Output = rhs.cast();
500                            lhs.__bitor(rhs)
501                        }
502                        #[inline(always)]
503                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
504                            let lhs: Self::Output = self.cast();
505                            let rhs: Self::Output = rhs.cast();
506                            lhs.__bitxor(rhs)
507                        }
508                        #[inline(always)]
509                        fn _not(self) -> Self::Output {
510                            let lhs: Self::Output = self.cast();
511                            lhs.__not()
512                        }
513                        #[inline(always)]
514                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
515                            let lhs: Self::Output = self.cast();
516                            let rhs: Self::Output = rhs.cast();
517                            lhs.__shl(rhs)
518                        }
519                        #[inline(always)]
520                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
521                            let lhs: Self::Output = self.cast();
522                            let rhs: Self::Output = rhs.cast();
523                            lhs.__shr(rhs)
524                        }
525                    }
526                }
527            };
528            ret.extend(res);
529        }
530    }
531
532    ret.into()
533}
534
535/// implement compare trait
536#[proc_macro]
537pub fn impl_cmp(_: TokenStream) -> TokenStream {
538    let mut ret = proc_macro2::TokenStream::new();
539
540    let types = [
541        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
542        "usize", "bf16",
543    ];
544
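    // Same-type comparisons compare the values directly; mixed-type comparisons first
    // cast both sides to their common `NormalOutPromote` output type.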
545    for lhs in types.iter() {
546        for rhs in types.iter() {
547            let lhs_type = TypeInfo::new(lhs);
548            let rhs_type = TypeInfo::new(rhs);
549            let lhs_dtype = lhs_type.dtype;
550            let rhs_dtype = rhs_type.dtype;
551            let res = if lhs_dtype == rhs_dtype {
552                quote! {
553                    impl Cmp<#rhs_dtype> for #lhs_dtype {
554                        type Output = bool;
555                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
556                            self == rhs
557                        }
558                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
559                            self != rhs
560                        }
561                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
562                            self < rhs
563                        }
564
565                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
566                            self <= rhs
567                        }
568                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
569                            self > rhs
570                        }
571                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
572                            self >= rhs
573                        }
574                    }
575                }
576            } else {
577                quote! {
578                    impl Cmp<#rhs_dtype> for #lhs_dtype {
579                        type Output = bool;
580                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
581                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
582                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
583                            lhs == rhs
584                        }
585                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
586                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
587                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
588                            lhs != rhs
589                        }
590                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
591                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
592                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
593                            lhs < rhs
594                        }
595
596                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
597                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
598                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
599                            lhs <= rhs
600                        }
601                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
602                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
603                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
604                            lhs > rhs
605                        }
606                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
607                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
608                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
609                            lhs >= rhs
610                        }
611                    }
612                }
613            };
614            ret.extend(res);
615        }
616    }
617
618    ret.into()
619}
620
/// implement compare trait for cuda
622#[proc_macro]
623pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
624    let mut ret = proc_macro2::TokenStream::new();
625
626    let types = [
627        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
628        "usize", "bf16",
629    ];
630
631    for lhs in types.iter() {
632        for rhs in types.iter() {
633            let lhs_type = TypeInfo::new(lhs);
634            let rhs_type = TypeInfo::new(rhs);
635            let lhs_dtype = lhs_type.dtype;
636            let rhs_dtype = rhs_type.dtype;
637            let res = if lhs_dtype == rhs_dtype {
638                quote! {
639                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
640                        type Output = Scalar<bool>;
641                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
642                            self.__eq(rhs)
643                        }
644                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
645                            self.__ne(rhs)
646                        }
647                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
648                            self.__lt(rhs)
649                        }
650
651                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
652                            self.__le(rhs)
653                        }
654                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
655                            self.__gt(rhs)
656                        }
657                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
658                            self.__ge(rhs)
659                        }
660                    }
661                }
662            } else {
663                quote! {
664                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
665                        type Output = Scalar<bool>;
666                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
667                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
668                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
669                            lhs.__eq(rhs)
670                        }
671                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
672                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
673                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
674                            lhs.__ne(rhs)
675                        }
676                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
677                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
678                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
679                            lhs.__lt(rhs)
680                        }
681
682                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
683                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
684                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
685                            lhs.__le(rhs)
686                        }
687                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
688                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
689                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
690                            lhs.__gt(rhs)
691                        }
692                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
693                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
694                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
695                            lhs.__ge(rhs)
696                        }
697                    }
698                }
699            };
700            ret.extend(res);
701        }
702    }
703
704    ret.into()
705}
706
707/// implement eval trait
708#[proc_macro]
709pub fn impl_eval(_: TokenStream) -> TokenStream {
710    let mut ret = proc_macro2::TokenStream::new();
711
712    let types = [
713        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
714        "usize", "bf16",
715    ];
716
717    for lhs in types.iter() {
718        let lhs_type = TypeInfo::new(lhs);
719        let lhs_dtype = lhs_type.dtype;
720
721        let res = quote! {
722            impl Eval for #lhs_dtype {
723                type Output = bool;
724                #[inline(always)]
725                fn _is_nan(&self) -> bool {
726                    self.__is_nan()
727                }
728                #[inline(always)]
729                fn _is_true(&self) -> bool {
730                    self.__is_true()
731                }
732                #[inline(always)]
733                fn _is_inf(&self) -> bool {
734                    self.__is_inf()
735                }
736            }
737        };
738        ret.extend(res);
739    }
740
741    ret.into()
742}
743
744/// generate fast reduce simd helper
745#[proc_macro]
746pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
747    __gen_fast_reduce_simd_helper(input)
748}
749
750/// generate fast layernorm simd helper
751#[proc_macro]
752pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
753    __gen_fast_layernorm_simd_helper(input)
754}
755
756/// generate reduce dim not include simd helper
757#[proc_macro]
758pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
759    __gen_reduce_dim_not_include_simd_helper(input)
760}
761
762/// declare const values
763///
764/// const OW_BLOCK: usize = ?;
765///
766/// const OC_BLOCK: usize = ?;
767#[proc_macro]
768pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
769    conv2d::conv2d_microkernel_declare_const(input)
770}
771
772/// generate conv2d inps
773#[proc_macro]
774pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
775    conv2d::conv2d_microkernel_gen_inps(input)
776}
777
/// generate conv2d padded inps
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}

/// generate pwconv2d padded inps
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}

/// generate dwconv2d padded inps
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}

/// generate conv2d kernels
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}

/// generate conv2d repeat results
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}

/// generate dwconv2d repeat results
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}

/// generate maxpool2d repeat results
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}
820
/// generate `Save` trait implementation
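///
/// A minimal usage sketch (the struct and field types below are hypothetical; every
/// field type must itself implement `Save`):
///
/// ```ignore
/// #[derive(Save)]
/// struct Model {
///     // optional per-field settings parsed from `#[compress(...)]`:
///     // algo: "gzip" | "deflate" | "zlib" | "none"; level: numeric string; endian: "native" | "little" | "big"
///     #[compress(algo = "gzip", level = "9", endian = "little")]
///     weight: Tensor<f32>,
///     bias: Tensor<f32>,
/// }
/// ```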
822#[proc_macro_derive(Save, attributes(compress))]
823pub fn impl_save(input: TokenStream) -> TokenStream {
824    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
825    let name = &ast.ident;
826    let meta_name = format_ident!("{}Meta", name);
827
828    let visibility = &ast.vis;
829    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
830    let fields = match &ast.data {
831        syn::Data::Struct(s) => &s.fields,
832        _ => panic!("Save can only be derived for structs"),
833    };
834
835    let mut compressions = vec![];
836    let mut endians = vec![];
837    let mut compress_levels = vec![];
838
839    let meta_fields = fields
840        .iter()
841        .map(|f| {
842            let mut compression_algo = None;
843            let mut endian = None;
844            let mut level = None;
845
846            for attr in &f.attrs {
847                if attr.path().is_ident("compress") {
848                    attr.parse_nested_meta(|meta| {
849                        if meta.path.is_ident("algo") {
850                            let value: syn::LitStr = meta.value()?.parse()?;
851                            let algo = match value.value().as_str().to_lowercase().as_str() {
852                                "gzip" => quote!(Gzip),
853                                "deflate" => quote!(Deflate),
854                                "zlib" => quote!(Zlib),
855                                "none" => quote!(NoCompression),
856                                _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
857                            };
858                            compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
859                        } else if meta.path.is_ident("level") {
860                            let value: syn::LitStr = meta.value()?.parse()?;
861                            let tmp: u32 = value.value().parse().map_err(|e| {
862                                syn::Error::new(value.span(), format!("Invalid level: {}", e))
863                            })?;
864                            level = Some(quote!(#tmp));
865                        } else if meta.path.is_ident("endian") {
866                            let value: syn::LitStr = meta.value()?.parse()?;
867                            let tmp = match value.value().as_str() {
868                                "native" => quote!(Native),
869                                "little" => quote!(Little),
870                                "big" => quote!(Big),
871                                _ => panic!("Unsupported endianness, supported: native, little, big"),
872                            };
873                            endian = Some(quote!(hpt::Endian::#tmp));
874                        }
875                        Ok(())
876                    })
877                    .unwrap();
878                }
879            }
880            compressions.push(compression_algo);
881            endians.push(endian);
882            compress_levels.push(level);
883            let name = &f.ident;
884            let ty = &f.ty;
885            quote! {
886                pub #name: <#ty as Save>::Meta
887            }
888        })
889        .collect::<Vec<_>>();
890
891    let call_save = fields.iter().enumerate().map(|(idx, f)| {
892        let name = &f.ident;
893        let ty = &f.ty;
894        let ident = format_ident!("field_{}", idx);
895        let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
896        let endian = endians[idx].clone().unwrap_or(quote!(endian));
897        let level = compress_levels[idx].clone().unwrap_or(quote!(level));
898        if let Some(name) = name {
899            quote! {
900                let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
901            }
902        } else {
903            quote! {
904                let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
905            }
906        }
907    });
908
909    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
910        let name = &f.ident;
911        let ident = format_ident!("field_{}", idx);
912        quote! {
913            #name: #ident
914        }
915    });
916
917    let expanded = quote! {
918        #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
919        #visibility struct #meta_name #ty_generics #where_clause  {
920            #(#meta_fields,)*
921        }
922        impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
923            type Meta = #meta_name #ty_generics;
924            fn __save(
925                data: &Self,
926                file: &mut std::fs::File,
927                len_so_far: &mut usize,
928                global_cnt: &mut usize,
929                compression_algo: hpt::CompressionAlgo,
930                endian: hpt::Endian,
931                level: u32,
932            ) -> std::io::Result<Self::Meta> {
933                #(#call_save)*
934                Ok(Self::Meta {
935                    #(#construct_fields),*
936                })
937            }
938        }
939    };
940
941    expanded.into()
942}
943
/// generate `Load` trait implementation
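///
/// Loading goes through the `{Name}Meta` type generated by the companion `Save` derive:
/// `Load::load` parses the file header and then reads each field. A hypothetical usage
/// sketch:
///
/// ```ignore
/// use hpt::Load;
/// let model = Model::load("path/to/saved/model")?;
/// ```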
945#[proc_macro_derive(Load)]
946pub fn impl_load(input: TokenStream) -> TokenStream {
947    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
948    let name = &ast.ident;
949    let meta_name = format_ident!("{}Meta", name);
950    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
951
952    let fields = match &ast.data {
953        syn::Data::Struct(s) => &s.fields,
954        _ => panic!("Load can only be derived for structs"),
955    };
956
957    let call_load = fields.iter().enumerate().map(|(idx, f)| {
958        let name = &f.ident;
959        let ident = format_ident!("field_{}", idx);
960        if let Some(name) = name {
961            quote! {
962                let #ident = self.#name.load(file)?;
963            }
964        } else {
965            quote! {
966                let #ident = self.#idx.load(file)?;
967            }
968        }
969    });
970
971    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
972        let name = &f.ident;
973        let ident = format_ident!("field_{}", idx);
974        quote! {
975            #name: #ident
976        }
977    });
978
979    let expanded = quote! {
980        impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
981            type Output = #name #ty_generics;
982            fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
983                use hpt::MetaLoad;
984                #(#call_load)*
985                Ok(#name {
986                    #(#construct_fields),*
987                })
988            }
989        }
990        impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
991            fn load(path: &str) -> std::io::Result<Self> {
992                use hpt::MetaLoad;
993                let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
994                let mut file = std::fs::File::open(path)?;
995                meta.load(&mut file)
996            }
997        }
998    };
999
1000    expanded.into()
1001}
1002
/// generate `FromSafeTensors` trait implementation
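///
/// Fields are resolved from a `SafeTensors` buffer by name; without a `#[map(...)]`
/// attribute a field is loaded from `"{path}.{field_name}"`. A hypothetical sketch
/// (field and tensor names are illustrative):
///
/// ```ignore
/// #[derive(FromSafeTensors)]
/// struct Linear {
///     // load from an explicitly named tensor instead of `{path}.weight`
///     #[map(tensor_name = "linear.weight")]
///     weight: Tensor<f32>,
///     // assign a constant value instead of reading from the file
///     #[map(value = false)]
///     trained: bool,
/// }
/// ```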
1004#[proc_macro_derive(FromSafeTensors, attributes(map))]
1005pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
1006    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
1007    let struct_name = &ast.ident;
1008    let fields = match &ast.data {
1009        syn::Data::Struct(s) => &s.fields,
1010        _ => panic!("FromSafeTensors can only be derived for structs"),
1011    };
1012    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
1013    let mut construct_fields = vec![];
1014    for (_, field) in fields.iter().enumerate() {
1015        let ty = &field.ty;
1016        let name = &field.ident;
1017        let mut value_construct = vec![];
1018        let mut from_construct = vec![];
1019        let mut params = vec![];
1020        let mut vec_len = None;
1021        for attr in &field.attrs {
1022            if attr.path().is_ident("map") {
1023                let mut path = None;
1024                let mut value = None;
1025                let mut tensor_name = None;
1026                let mut inner_type = None;
1027                attr.parse_nested_meta(|meta| {
1028                    if meta.path.is_ident("path") {
1029                        let value: syn::LitStr = meta.value()?.parse()?;
1030                        path = Some(value.value());
1031                    } else if meta.path.is_ident("value") {
1032                        let val: syn::Expr = meta.value()?.parse()?;
1033                        value = Some(val);
1034                    } else if meta.path.is_ident("tensor_name") {
1035                        let value: syn::LitStr = meta.value()?.parse()?;
1036                        tensor_name = Some(value.value());
1037                    } else if meta.path.is_ident("vec_len") {
1038                        let value: syn::LitInt = meta.value()?.parse()?;
1039                        vec_len = Some(value.base10_parse::<usize>().unwrap());
1040                    } else if meta.path.is_ident("inner_type") {
1041                        let value: syn::Ident = meta.value()?.parse()?;
1042                        inner_type = Some(value);
1043                    }
1044                    Ok(())
1045                })
1046                .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
1047                params.push((path, value, tensor_name, vec_len, inner_type));
1048            }
1049        }
1050        let param_count = params.len();
1051        for (path, value, tensor_name, vec_len, inner_type) in params {
1052            if let Some(vec_len) = vec_len {
1053                let inner_type = inner_type.expect("inner_type is required for vec");
1054                if let Some(path) = path {
1055                    from_construct.push(quote! {
1056                        #path => {
1057                            let mut vec = vec![];
1058                            for i in 0..#vec_len {
1059                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1060                            }
1061                            vec
1062                        }
1063                    });
1064                } else {
1065                    value_construct.push(quote! {
1066                        {
1067                            let mut vec = vec![];
1068                            for i in 0..#vec_len {
1069                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1070                            }
1071                            vec
1072                        }
1073                    });
1074                }
1075            } else {
1076                match (path, value, tensor_name) {
1077                    (None, None, Some(tensor_name)) => {
1078                        value_construct.push(quote! {
1079                            <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
1080                        });
1081                    }
1082                    (None, Some(value), None) => {
1083                        if param_count > 1 {
1084                            panic!("value without path means generic assignment, there can only be one value without path");
1085                        }
1086                        value_construct.push(quote! {
1087                            #value
1088                        });
1089                    }
1090                    (Some(path), None, Some(tensor_name)) => {
1091                        from_construct.push(quote! {
1092                            #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
1093                        });
1094                    }
1095                    (Some(path), Some(value), None) => {
1096                        from_construct.push(quote! {
1097                            #path => #value,
1098                        });
1099                    }
1100
1101                    (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
1102                        panic!("value and tensor_name cannot be used together");
1103                    }
1104                    (Some(_), None, None) | (None, None, None) => {
1105                        panic!("path and value are not present");
1106                    }
1107                }
1108            }
1109        }
1110        if !value_construct.is_empty() {
1111            construct_fields.push(quote! {
1112                #name: #(#value_construct)*
1113            });
1114        } else if !from_construct.is_empty() {
1115            construct_fields.push(quote! {
1116                #name: match path {
1117                    #(#from_construct)*
1118                    _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
1119                }
1120            });
1121        } else {
1122            construct_fields.push(quote! {
1123                #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
1124            });
1125        }
1126    }
1127    let expanded = quote! {
1128        impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
1129            fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
1130                Self {
1131                    #(#construct_fields),*
1132                }
1133            }
1134        }
1135    };
1136    // let syntax_tree = syn::parse2(expanded.clone()).expect(&format!(
1137    //     "failed to parse expanded: {}",
1138    //     expanded.to_string()
1139    // ));
1140    // let formatted = prettyplease::unparse(&syntax_tree);
1141    // println!("{}", formatted);
1142    expanded.into()
1143}