hpt_macros/
lib.rs

//! # Tensor Macros
//!
//! This crate provides a set of macros to generate code for tensor operations.
//! These macros are used to simplify and automate common tasks such as defining
//! tensor operations, reducing dimensionality, and optimizing numerical computations.
//!
//! ## Examples
//!
//! A minimal sketch of using the `select!` macro from this crate (the tensor
//! type and its `slice` method here are hypothetical):
//!
//! ```rust,ignore
//! // `select!` expands to an array of `(start, end, step)` tuples; a bare `..`
//! // (allowed at most once) expands to a special skip marker.
//! let ranges = select!(0:10:2, .., 3);
//! let view = tensor.slice(&ranges);
//! ```

#![deny(missing_docs)]
#[cfg(feature = "cuda")]
use crate::binary_float_out::impl_cuda_float_out_binary;
use binary_float_out::impl_float_out_binary;
use float_unary::impl_float_out_unary;
use from_scalar::__impl_from_scalar;
use kernel_gen_helper::{
    __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
    __gen_reduce_dim_not_include_simd_helper,
};
use normal_out::__impl_normal_out_binary;
use proc_macro::TokenStream;
use scalar_convert::__impl_scalar_convert;
use simd_bitwise::impl_simd_bitwise_out;
use simd_float_out_binary::{
    impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
    impl_simd_binary_out_float_rhs_scalar,
};
use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
use syn::{parse, parse_macro_input, Expr, Token};
mod binary_float_out;
mod conv2d;
mod float_unary;
mod from_scalar;
mod into_cuda_scalar;
mod into_scalar;
mod into_vec;
mod kernel_gen_helper;
mod normal_out;
mod normal_out_unary;
mod scalar_convert;
mod simd_bitwise;

mod simd_cmp;
mod simd_eval;
mod simd_float_out_binary;
mod simd_float_out_unary;
mod simd_normal_out;
mod simd_normal_unary;
mod type_utils;

use crate::simd_cmp::impl_simd_cmp;
use crate::simd_normal_out::impl_simd_normal_out;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{format_ident, quote};
use type_utils::TypeInfo;

/// number of registers available for the target architecture
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;

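/// One parsed selection inside `select!`: an optional `start:end:step` triple,
/// or a bare `..` recorded via `skip`.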
struct SelectionParser {
    start: Option<Expr>,
    end: Option<Expr>,
    step: Option<Expr>,
    skip: bool,
}

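/// The comma-separated list of selections passed to `select!`, kept as raw
/// token streams so that each item can be re-parsed as a `SelectionParser`.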
struct Selections {
    selections: Vec<TokenStream>,
}

impl parse::Parse for SelectionParser {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut start: Option<Expr> = None;
        let mut end: Option<Expr> = None;
        let mut step: Option<Expr> = None;
        if input.peek(Token![..]) {
            input.parse::<Token![..]>()?;
            return Ok(Self {
                start,
                end,
                step,
                skip: true,
            });
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            start = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        } else if input.is_empty() {
            return Ok(Self {
                start,
                end,
                step,
                skip: false,
            });
        } else {
            return Err(syn::Error::new(
                input.span(),
                "unexpected token, expected `:`, Int or Ident",
            ));
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            end = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            step = Some(input.parse::<Expr>()?);
        }
        Ok(Self {
            start,
            end,
            step,
            skip: false,
        })
    }
}

impl parse::Parse for Selections {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut selections: Vec<TokenStream> = vec![];

        while !input.is_empty() {
            let mut item_tokens = TokenStream2::new();
            while !input.is_empty() && !input.peek(Token![,]) {
                let token = input.parse::<TokenTree>()?;
                item_tokens.extend(quote!(#token));
            }

            selections.push(item_tokens.into());

            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }

        Ok(Self { selections })
    }
}

/// parse the input and generate the corresponding slice
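///
/// A minimal sketch of the expansion (shown values are illustrative):
///
/// ```rust,ignore
/// // `select!(1:10:2, ..)` expands to an array of `(start, end, step)` tuples,
/// // with `..` turned into a skip marker; roughly:
/// // [(1, 10, 2), (0, 0, 0x7FFFFFFFFFFFFFFF)]
/// ```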
#[proc_macro]
pub fn select(input: TokenStream) -> TokenStream {
    let res: Selections = parse_macro_input!(input as Selections);
    let mut slices: Vec<SelectionParser> = vec![];
    for x in res.selections {
        slices.push(parse_macro_input!(x as SelectionParser));
    }
    let mut ret_stream = TokenStream2::new();
    let len = slices.len();
    let mut skipped = false;
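    // 0x7FFFFFFFFFFFFFFF (i64::MAX) serves as a sentinel: it stands in for a missing
    // end bound, and `(0, 0, 0x7FFFFFFFFFFFFFFF)` marks the `..` skip entry.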
    for (idx, x) in slices.into_iter().enumerate() {
        if x.skip {
            if skipped {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected token, slicing only supports `..` once",
                )
                .to_compile_error()
                .into();
            }
            ret_stream.extend(quote!((0, 0, 0x7FFFFFFFFFFFFFFF)));
            skipped = true;
            if idx != len - 1 {
                ret_stream.extend(quote!(,));
            }
            continue;
        }
        match (x.start, x.end, x.step) {
            (None, None, None) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (None, None, Some(step)) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (None, Some(end), None) => {
                ret_stream.extend(quote!((0, #end, 1)));
            }
            (None, Some(end), Some(step)) => {
                ret_stream.extend(quote!((0, #end, #step)));
            }
            (Some(start), None, None) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (Some(start), None, Some(step)) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (Some(start), Some(end), None) => {
                ret_stream.extend(quote!((#start, #end, 1)));
            }
            (Some(start), Some(end), Some(step)) => {
                ret_stream.extend(quote!((#start, #end, #step)));
            }
        }
        if idx != len - 1 {
            ret_stream.extend(quote!(,));
        }
    }
    quote!([#ret_stream]).into()
}

/// implement float out binary trait
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}

#[cfg(feature = "cuda")]
/// implement float out binary trait for cuda
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}

/// implement simd float out binary trait
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}

/// implement simd float out binary trait with rhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}

/// implement simd float out binary trait with lhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}

/// implement float out unary trait
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}

#[cfg(feature = "cuda")]
/// implement float out unary trait for cuda
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}

/// implement simd float out unary trait
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}

/// implement simd eval trait
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}

/// implement simd bitwise trait
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}

/// generate normal out trait
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}

#[cfg(feature = "cuda")]
/// generate normal out trait for cuda
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}

/// generate normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}

#[cfg(feature = "cuda")]
/// generate normal out unary trait for cuda
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}

/// generate simd normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}

/// implement simd normal out trait
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}

/// implement simd normal out trait with rhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}

/// implement simd normal out trait with lhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}

/// implement scalar convert trait
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}

/// implement from scalar trait
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}

/// implement simd cmp trait
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}

/// implement into vec trait
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}

#[cfg(feature = "cuda")]
/// implement into cuda scalar trait
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}

/// implement into scalar trait
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}

/// implement bitwise out trait
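///
/// For mixed operand types, both sides are first cast to the common
/// `<Lhs as NormalOutPromote<Rhs>>::Output` type before the bitwise op is applied.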
#[proc_macro]
pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement bitwise out trait for cuda
#[proc_macro]
pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait
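///
/// When the operand types differ, both sides are cast to
/// `<Lhs as NormalOutPromote<Rhs>>::Output` before comparing; the result is `bool`.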
#[proc_macro]
pub fn impl_cmp(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            self == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            self != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            self < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            self <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            self > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            self >= rhs
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs >= rhs
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait for cuda
#[proc_macro]
pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ge(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ge(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement eval trait
#[proc_macro]
pub fn impl_eval(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        let lhs_type = TypeInfo::new(lhs);
        let lhs_dtype = lhs_type.dtype;

        let res = quote! {
            impl Eval for #lhs_dtype {
                type Output = bool;
                #[inline(always)]
                fn _is_nan(&self) -> bool {
                    self.__is_nan()
                }
                #[inline(always)]
                fn _is_true(&self) -> bool {
                    self.__is_true()
                }
                #[inline(always)]
                fn _is_inf(&self) -> bool {
                    self.__is_inf()
                }
            }
        };
        ret.extend(res);
    }

    ret.into()
}

/// generate fast reduce simd helper
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}

/// generate fast layernorm simd helper
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}

/// generate reduce dim not include simd helper
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}

/// declare const values
///
/// const OW_BLOCK: usize = ?;
///
/// const OC_BLOCK: usize = ?;
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}

/// generate conv2d inps
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}

/// generate conv2d padded inps
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}

/// generate pwconv2d padded inps
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}

/// generate dwconv2d padded inps
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}

/// generate conv2d kernels
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}

/// generate conv2d repeat results
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}

/// generate dwconv2d repeat results
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}

/// generate maxpool2d repeat results
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}

/// generate save trait
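///
/// A minimal usage sketch (the `Tensor` field type is hypothetical; the
/// `compress` attribute keys mirror what this derive parses):
///
/// ```rust,ignore
/// #[derive(Save)]
/// struct Model {
///     #[compress(algo = "gzip", level = "9")]
///     weight: Tensor<f32>,
///     bias: Tensor<f32>,
/// }
/// ```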
#[proc_macro_derive(Save, attributes(compress))]
pub fn impl_save(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);

    let visibility = &ast.vis;
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Save can only be derived for structs"),
    };

    let mut compressions = vec![];
    let mut compress_levels = vec![];

    let meta_fields = fields
        .iter()
        .map(|f| {
            let mut compression_algo = None;
            let mut level = None;

            for attr in &f.attrs {
                if attr.path().is_ident("compress") {
                    attr.parse_nested_meta(|meta| {
                        if meta.path.is_ident("algo") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let algo = match value.value().as_str().to_lowercase().as_str() {
                                "gzip" => quote!(Gzip),
                                "deflate" => quote!(Deflate),
                                "zlib" => quote!(Zlib),
                                "none" => quote!(NoCompression),
                                _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
                            };
                            compression_algo = Some(quote!(hpt::save_load::CompressionAlgo::#algo));
                        } else if meta.path.is_ident("level") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp: u32 = value.value().parse().map_err(|e| {
                                syn::Error::new(value.span(), format!("Invalid level: {}", e))
                            })?;
                            level = Some(quote!(#tmp));
                        }
                        Ok(())
                    })
                    .unwrap();
                }
            }
            compressions.push(compression_algo);
            compress_levels.push(level);
            let name = &f.ident;
            let ty = &f.ty;
            quote! {
                pub #name: <#ty as Save>::Meta
            }
        })
        .collect::<Vec<_>>();

    let call_save = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ty = &f.ty;
        let ident = format_ident!("field_{}", idx);
        let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
        let level = compress_levels[idx].clone().unwrap_or(quote!(level));
        if let Some(name) = name {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #level)?;
            }
        } else {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #level)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        #[derive(hpt::re_exports::serde::Deserialize, hpt::re_exports::serde::Serialize)]
        #[serde(crate = "hpt::re_exports::serde")]
        #visibility struct #meta_name #ty_generics #where_clause {
            #(#meta_fields,)*
        }
        impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
            type Meta = #meta_name #ty_generics;
            fn __save(
                data: &Self,
                file: &mut std::fs::File,
                len_so_far: &mut usize,
                global_cnt: &mut usize,
                compression_algo: hpt::save_load::CompressionAlgo,
                level: u32,
            ) -> std::io::Result<Self::Meta> {
                #(#call_save)*
                Ok(Self::Meta {
                    #(#construct_fields),*
                })
            }
        }
    };

    expanded.into()
}

/// generate load trait
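///
/// A minimal usage sketch (`Model` is assumed to also derive `Save`, which
/// generates the matching `ModelMeta` type; the file path is illustrative):
///
/// ```rust,ignore
/// let model = Model::load("path/to/saved/model")?;
/// ```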
#[proc_macro_derive(Load)]
pub fn impl_load(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Load can only be derived for structs"),
    };

    let call_load = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        if let Some(name) = name {
            quote! {
                let #ident = self.#name.load(file)?;
            }
        } else {
            quote! {
                let #ident = self.#idx.load(file)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        impl #impl_generics hpt::save_load::MetaLoad for #meta_name #ty_generics #where_clause {
            type Output = #name #ty_generics;
            fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
                use hpt::save_load::MetaLoad;
                #(#call_load)*
                Ok(#name {
                    #(#construct_fields),*
                })
            }
        }
        impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
            fn load<P: Into<std::path::PathBuf>>(path: P) -> std::io::Result<Self> {
                use hpt::save_load::MetaLoad;
                let path: std::path::PathBuf = path.into();
                let meta = hpt::save_load::parse_header_compressed::<Self, _>(&path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
                let mut file = std::fs::File::open(path)?;
                meta.load(&mut file)
            }
        }
    };

    expanded.into()
}

/// generate from safetensors trait
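///
/// A minimal usage sketch (the field types and tensor names are hypothetical;
/// the `map` attribute keys mirror what this derive parses: `path`, `value`,
/// `tensor_name`, `vec_len`, `inner_type`):
///
/// ```rust,ignore
/// #[derive(FromSafeTensors)]
/// struct Linear {
///     #[map(tensor_name = "weight")]
///     weight: Tensor<f32>,
///     #[map(value = None)]
///     bias: Option<Tensor<f32>>,
/// }
/// ```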
#[proc_macro_derive(FromSafeTensors, attributes(map))]
pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let struct_name = &ast.ident;
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("FromSafeTensors can only be derived for structs"),
    };
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let mut construct_fields = vec![];
    for field in fields.iter() {
        let ty = &field.ty;
        let name = &field.ident;
        let mut value_construct = vec![];
        let mut from_construct = vec![];
        let mut params = vec![];
        let mut vec_len = None;
        for attr in &field.attrs {
            if attr.path().is_ident("map") {
                let mut path = None;
                let mut value = None;
                let mut tensor_name = None;
                let mut inner_type = None;
                attr.parse_nested_meta(|meta| {
                    if meta.path.is_ident("path") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        path = Some(value.value());
                    } else if meta.path.is_ident("value") {
                        let val: syn::Expr = meta.value()?.parse()?;
                        value = Some(val);
                    } else if meta.path.is_ident("tensor_name") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        tensor_name = Some(value.value());
                    } else if meta.path.is_ident("vec_len") {
                        let value: syn::LitInt = meta.value()?.parse()?;
                        vec_len = Some(value.base10_parse::<usize>().unwrap());
                    } else if meta.path.is_ident("inner_type") {
                        let value: syn::Ident = meta.value()?.parse()?;
                        inner_type = Some(value);
                    }
                    Ok(())
                })
                .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
                params.push((path, value, tensor_name, vec_len, inner_type));
            }
        }
        let param_count = params.len();
        for (path, value, tensor_name, vec_len, inner_type) in params {
            if let Some(vec_len) = vec_len {
                let inner_type = inner_type.expect("inner_type is required for vec");
                if let Some(path) = path {
                    from_construct.push(quote! {
                        #path => {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                } else {
                    value_construct.push(quote! {
                        {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                }
            } else {
                match (path, value, tensor_name) {
                    (None, None, Some(tensor_name)) => {
                        value_construct.push(quote! {
                            <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
                        });
                    }
                    (None, Some(value), None) => {
                        if param_count > 1 {
                            panic!("value without path means generic assignment, there can only be one value without path");
                        }
                        value_construct.push(quote! {
                            #value
                        });
                    }
                    (Some(path), None, Some(tensor_name)) => {
                        from_construct.push(quote! {
                            #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
                        });
                    }
                    (Some(path), Some(value), None) => {
                        from_construct.push(quote! {
                            #path => #value,
                        });
                    }

                    (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
                        panic!("value and tensor_name cannot be used together");
                    }
                    (Some(_), None, None) | (None, None, None) => {
                        panic!("path and value are not present");
                    }
                }
            }
        }
        if !value_construct.is_empty() {
            construct_fields.push(quote! {
                #name: #(#value_construct)*
            });
        } else if !from_construct.is_empty() {
            construct_fields.push(quote! {
                #name: match path {
                    #(#from_construct)*
                    _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
                }
            });
        } else {
            construct_fields.push(quote! {
                #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
            });
        }
    }
    let expanded = quote! {
        impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
            fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
                Self {
                    #(#construct_fields),*
                }
            }
        }
    };
    // let syntax_tree = syn::parse2(expanded.clone()).expect(&format!(
    //     "failed to parse expanded: {}",
    //     expanded.to_string()
    // ));
    // let formatted = prettyplease::unparse(&syntax_tree);
    // println!("{}", formatted);
    expanded.into()
}