hpt_macros/
lib.rs

//! # Tensor Macros
//!
//! This crate provides a set of macros to generate code for tensor operations.
//! These macros are used to simplify and automate common tasks such as defining
//! tensor operations, reducing dimensionality, and optimizing numerical computations.
//!
//! ## Examples
//!
//! Here's an example of using a macro from this crate:
//!
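//! The sketch below is illustrative only and is marked `ignore`: `select!` is
//! normally invoked through the re-exporting `hpt` crate, and the
//! `(start, end, step)` tuples it expands to are consumed by the slicing logic
//! there (an assumption for illustration, not a dependency of this crate).
//!
//! ```rust,ignore
//! // `1:10:2` becomes a `(start, end, step)` tuple; `..` keeps the whole axis.
//! let slices = select!(1:10:2, ..);
//! // expands to: [(1, 10, 2), (0, 0, 0x7FFFFFFFFFFFFFFF)]
//! ```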

#![deny(missing_docs)]
#[cfg(feature = "cuda")]
use crate::binary_float_out::impl_cuda_float_out_binary;
use binary_float_out::impl_float_out_binary;
use float_unary::impl_float_out_unary;
use from_scalar::__impl_from_scalar;
use kernel_gen_helper::{
    __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
    __gen_reduce_dim_not_include_simd_helper,
};
use normal_out::__impl_normal_out_binary;
use proc_macro::TokenStream;
use scalar_convert::__impl_scalar_convert;
use simd_bitwise::impl_simd_bitwise_out;
use simd_float_out_binary::{
    impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
    impl_simd_binary_out_float_rhs_scalar,
};
use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
use syn::{parse, parse_macro_input, Expr, Token};
mod binary_float_out;
mod conv2d;
mod float_unary;
mod from_scalar;
mod into_cuda_scalar;
mod into_scalar;
mod into_vec;
mod kernel_gen_helper;
mod normal_out;
mod normal_out_unary;
mod scalar_convert;
mod simd_bitwise;

mod simd_cmp;
mod simd_eval;
mod simd_float_out_binary;
mod simd_float_out_unary;
mod simd_normal_out;
mod simd_normal_unary;
mod type_utils;

use crate::simd_cmp::impl_simd_cmp;
use crate::simd_normal_out::impl_simd_normal_out;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{format_ident, quote};
use type_utils::TypeInfo;

/// number of registers available for the target architecture
#[cfg(target_feature = "avx2")]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;

struct SelectionParser {
    start: Option<Expr>,
    end: Option<Expr>,
    step: Option<Expr>,
    skip: bool,
}

struct Selections {
    selections: Vec<TokenStream>,
}

impl parse::Parse for SelectionParser {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut start: Option<Expr> = None;
        let mut end: Option<Expr> = None;
        let mut step: Option<Expr> = None;
        if input.peek(Token![..]) {
            input.parse::<Token![..]>()?;
            return Ok(Self {
                start,
                end,
                step,
                skip: true,
            });
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            start = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        } else if input.is_empty() {
            return Ok(Self {
                start,
                end,
                step,
                skip: false,
            });
        } else {
            return Err(syn::Error::new(
                input.span(),
                "unexpected token, expected `:`, Int or Ident",
            ));
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            end = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            step = Some(input.parse::<Expr>()?);
        }
        Ok(Self {
            start,
            end,
            step,
            skip: false,
        })
    }
}

impl parse::Parse for Selections {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut selections: Vec<TokenStream> = vec![];

        while !input.is_empty() {
            let mut item_tokens = TokenStream2::new();
            while !input.is_empty() && !input.peek(Token![,]) {
                let token = input.parse::<TokenTree>()?;
                item_tokens.extend(quote!(#token));
            }

            selections.push(item_tokens.into());

            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }

        Ok(Self { selections })
    }
}

/// parse the input and generate the corresponding slice
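///
/// Each comma-separated selection expands to a `(start, end, step)` tuple;
/// omitted parts default to `0`, `0x7FFFFFFFFFFFFFFF` (i.e. "to the end") and
/// `1` respectively, and a bare `..` (allowed at most once) expands to
/// `(0, 0, 0x7FFFFFFFFFFFFFFF)`. A minimal sketch of the expansion, marked
/// `ignore` because the macro is normally re-exported and its output consumed
/// by the `hpt` crate (an assumption for illustration):
///
/// ```rust,ignore
/// let slices = select!(1:, :5, ::2);
/// // expands to:
/// // [(1, 0x7FFFFFFFFFFFFFFF, 1), (0, 5, 1), (0, 0x7FFFFFFFFFFFFFFF, 2)]
/// ```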
#[proc_macro]
pub fn select(input: TokenStream) -> TokenStream {
    let res: Selections = parse_macro_input!(input as Selections);
    let mut slices: Vec<SelectionParser> = vec![];
    for x in res.selections {
        slices.push(parse_macro_input!(x as SelectionParser));
    }
    let mut ret_stream = TokenStream2::new();
    let len = slices.len();
    let mut skipped = false;
    for (idx, x) in slices.into_iter().enumerate() {
        if x.skip {
            if skipped {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected token, slicing only supports `..` once",
                )
                .to_compile_error()
                .into();
            }
            ret_stream.extend(quote!((0, 0, 0x7FFFFFFFFFFFFFFF)));
            skipped = true;
            if idx != len - 1 {
                ret_stream.extend(quote!(,));
            }
            continue;
        }
        match (x.start, x.end, x.step) {
            (None, None, None) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (None, None, Some(step)) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (None, Some(end), None) => {
                ret_stream.extend(quote!((0, #end, 1)));
            }
            (None, Some(end), Some(step)) => {
                ret_stream.extend(quote!((0, #end, #step)));
            }
            (Some(start), None, None) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (Some(start), None, Some(step)) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (Some(start), Some(end), None) => {
                ret_stream.extend(quote!((#start, #end, 1)));
            }
            (Some(start), Some(end), Some(step)) => {
                ret_stream.extend(quote!((#start, #end, #step)));
            }
        }
        if idx != len - 1 {
            ret_stream.extend(quote!(,));
        }
    }
    quote!([#ret_stream]).into()
}

/// implement float out binary trait
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}

#[cfg(feature = "cuda")]
/// implement float out binary trait for cuda
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}

/// implement simd float out binary trait
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}

/// implement simd float out binary trait with rhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}

/// implement simd float out binary trait with lhs scalar
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}

/// implement float out unary trait
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}

#[cfg(feature = "cuda")]
/// implement float out unary trait for cuda
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}

/// implement simd float out unary trait
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}

/// implement simd eval trait
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}

/// implement simd bitwise trait
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}

/// generate normal out trait
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}

#[cfg(feature = "cuda")]
/// generate normal out trait for cuda
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}

/// generate normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}

#[cfg(feature = "cuda")]
/// generate normal out unary trait for cuda
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}

/// generate simd normal out unary trait
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}

/// implement simd normal out trait
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}

/// implement simd normal out trait with rhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}

/// implement simd normal out trait with lhs scalar
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}

/// implement scalar convert trait
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}

/// implement from scalar trait
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}

/// implement simd cmp trait
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}

/// implement into vec trait
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}

#[cfg(feature = "cuda")]
/// implement into cuda scalar trait
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}

/// implement into scalar trait
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}

/// implement bitwise out trait
#[proc_macro]
pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement bitwise out trait for cuda
#[proc_macro]
pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait
#[proc_macro]
pub fn impl_cmp(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            self == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            self != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            self < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            self <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            self > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            self >= rhs
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs < rhs
                        }

                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs >= rhs
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement compare trait for cuda
#[proc_macro]
pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ge(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__lt(rhs)
                        }

                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ge(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

/// implement eval trait
#[proc_macro]
pub fn impl_eval(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        let lhs_type = TypeInfo::new(lhs);
        let lhs_dtype = lhs_type.dtype;

        let res = quote! {
            impl Eval for #lhs_dtype {
                type Output = bool;
                #[inline(always)]
                fn _is_nan(&self) -> bool {
                    self.__is_nan()
                }
                #[inline(always)]
                fn _is_true(&self) -> bool {
                    self.__is_true()
                }
                #[inline(always)]
                fn _is_inf(&self) -> bool {
                    self.__is_inf()
                }
            }
        };
        ret.extend(res);
    }

    ret.into()
}

/// generate fast reduce simd helper
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}

/// generate fast layernorm simd helper
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}

/// generate reduce dim not include simd helper
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}

/// declare const values
///
/// const OW_BLOCK: usize = ?;
///
/// const OC_BLOCK: usize = ?;
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}

/// generate conv2d inps
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}

/// generate conv2d padded inps
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}

/// generate pwconv2d padded inps
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}

/// generate dwconv2d padded inps
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}

/// generate conv2d kernels
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}

/// generate conv2d repeat results
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}

/// generate dwconv2d repeat results
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}

/// generate maxpool2d repeat results
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}

/// generate save trait
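///
/// Fields can be annotated with `#[compress(...)]` to override, per field, the
/// compression algorithm (`"gzip"`, `"deflate"`, `"zlib"`, `"none"`), the
/// compression level and the endianness (`"native"`, `"little"`, `"big"`) used
/// when the field is written. A minimal, hedged sketch (the `Tensor` type and
/// the `hpt` re-exports are assumptions made for illustration, not checked
/// here):
///
/// ```rust,ignore
/// #[derive(Save)]
/// struct Model {
///     // override the save settings for this field only
///     #[compress(algo = "gzip", level = "9", endian = "little")]
///     weight: Tensor<f32>,
///     // uses the compression settings passed down by the caller
///     bias: Tensor<f32>,
/// }
/// ```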
#[proc_macro_derive(Save, attributes(compress))]
pub fn impl_save(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);

    let visibility = &ast.vis;
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Save can only be derived for structs"),
    };

    let mut compressions = vec![];
    let mut endians = vec![];
    let mut compress_levels = vec![];

    let meta_fields = fields
        .iter()
        .map(|f| {
            let mut compression_algo = None;
            let mut endian = None;
            let mut level = None;

            for attr in &f.attrs {
                if attr.path().is_ident("compress") {
                    attr.parse_nested_meta(|meta| {
                        if meta.path.is_ident("algo") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let algo = match value.value().as_str().to_lowercase().as_str() {
                                "gzip" => quote!(Gzip),
                                "deflate" => quote!(Deflate),
                                "zlib" => quote!(Zlib),
                                "none" => quote!(NoCompression),
                                _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
                            };
                            compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
                        } else if meta.path.is_ident("level") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp: u32 = value.value().parse().map_err(|e| {
                                syn::Error::new(value.span(), format!("Invalid level: {}", e))
                            })?;
                            level = Some(quote!(#tmp));
                        } else if meta.path.is_ident("endian") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp = match value.value().as_str() {
                                "native" => quote!(Native),
                                "little" => quote!(Little),
                                "big" => quote!(Big),
                                _ => panic!("Unsupported endianness, supported: native, little, big"),
                            };
                            endian = Some(quote!(hpt::Endian::#tmp));
                        }
                        Ok(())
                    })
                    .unwrap();
                }
            }
            compressions.push(compression_algo);
            endians.push(endian);
            compress_levels.push(level);
            let name = &f.ident;
            let ty = &f.ty;
            quote! {
                pub #name: <#ty as Save>::Meta
            }
        })
        .collect::<Vec<_>>();

    let call_save = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ty = &f.ty;
        let ident = format_ident!("field_{}", idx);
        let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
        let endian = endians[idx].clone().unwrap_or(quote!(endian));
        let level = compress_levels[idx].clone().unwrap_or(quote!(level));
        if let Some(name) = name {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        } else {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
        #[serde(crate = "hpt::serde")]
        #visibility struct #meta_name #ty_generics #where_clause {
            #(#meta_fields,)*
        }
        impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
            type Meta = #meta_name #ty_generics;
            fn __save(
                data: &Self,
                file: &mut std::fs::File,
                len_so_far: &mut usize,
                global_cnt: &mut usize,
                compression_algo: hpt::CompressionAlgo,
                endian: hpt::Endian,
                level: u32,
            ) -> std::io::Result<Self::Meta> {
                #(#call_save)*
                Ok(Self::Meta {
                    #(#construct_fields),*
                })
            }
        }
    };

    expanded.into()
}

/// generate load trait
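///
/// `Load` is the counterpart of `Save`: it implements `hpt::MetaLoad` for the
/// generated `*Meta` struct and `hpt::Load` for the type itself. A hedged
/// sketch of the resulting API (the type name and file path are illustrative
/// assumptions):
///
/// ```rust,ignore
/// let model = Model::load("path/to/saved/model")?;
/// ```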
#[proc_macro_derive(Load)]
pub fn impl_load(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Load can only be derived for structs"),
    };

    let call_load = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        if let Some(name) = name {
            quote! {
                let #ident = self.#name.load(file)?;
            }
        } else {
            quote! {
                let #ident = self.#idx.load(file)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
            type Output = #name #ty_generics;
            fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
                use hpt::MetaLoad;
                #(#call_load)*
                Ok(#name {
                    #(#construct_fields),*
                })
            }
        }
        impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
            fn load(path: &str) -> std::io::Result<Self> {
                use hpt::MetaLoad;
                let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
                let mut file = std::fs::File::open(path)?;
                meta.load(&mut file)
            }
        }
    };

    expanded.into()
}

/// generate from safetensors trait
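///
/// Fields can be annotated with `#[map(...)]` using the keys `path`, `value`,
/// `tensor_name`, `vec_len` and `inner_type` to control how each field is
/// reconstructed from the safetensors data; unannotated fields are loaded from
/// `"{path}.{field_name}"`. A hedged sketch (the `Tensor` type and the
/// surrounding `hpt` items are assumptions made for illustration):
///
/// ```rust,ignore
/// #[derive(FromSafeTensors)]
/// struct Linear {
///     // always load this field from a fixed tensor name
///     #[map(tensor_name = "model.linear.weight")]
///     weight: Tensor<f32>,
///     // assign a constant expression instead of reading from the file
///     #[map(value = None)]
///     bias: Option<Tensor<f32>>,
/// }
/// ```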
#[proc_macro_derive(FromSafeTensors, attributes(map))]
pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let struct_name = &ast.ident;
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("FromSafeTensors can only be derived for structs"),
    };
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let mut construct_fields = vec![];
    for (_, field) in fields.iter().enumerate() {
        let ty = &field.ty;
        let name = &field.ident;
        let mut value_construct = vec![];
        let mut from_construct = vec![];
        let mut params = vec![];
        let mut vec_len = None;
        for attr in &field.attrs {
            if attr.path().is_ident("map") {
                let mut path = None;
                let mut value = None;
                let mut tensor_name = None;
                let mut inner_type = None;
                attr.parse_nested_meta(|meta| {
                    if meta.path.is_ident("path") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        path = Some(value.value());
                    } else if meta.path.is_ident("value") {
                        let val: syn::Expr = meta.value()?.parse()?;
                        value = Some(val);
                    } else if meta.path.is_ident("tensor_name") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        tensor_name = Some(value.value());
                    } else if meta.path.is_ident("vec_len") {
                        let value: syn::LitInt = meta.value()?.parse()?;
                        vec_len = Some(value.base10_parse::<usize>().unwrap());
                    } else if meta.path.is_ident("inner_type") {
                        let value: syn::Ident = meta.value()?.parse()?;
                        inner_type = Some(value);
                    }
                    Ok(())
                })
                .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
                params.push((path, value, tensor_name, vec_len, inner_type));
            }
        }
        let param_count = params.len();
        for (path, value, tensor_name, vec_len, inner_type) in params {
            if let Some(vec_len) = vec_len {
                let inner_type = inner_type.expect("inner_type is required for vec");
                if let Some(path) = path {
                    from_construct.push(quote! {
                        #path => {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                } else {
                    value_construct.push(quote! {
                        {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                }
            } else {
                match (path, value, tensor_name) {
                    (None, None, Some(tensor_name)) => {
                        value_construct.push(quote! {
                            <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
                        });
                    }
                    (None, Some(value), None) => {
                        if param_count > 1 {
                            panic!("value without path means generic assignment, there can only be one value without path");
                        }
                        value_construct.push(quote! {
                            #value
                        });
                    }
                    (Some(path), None, Some(tensor_name)) => {
                        from_construct.push(quote! {
                            #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
                        });
                    }
                    (Some(path), Some(value), None) => {
                        from_construct.push(quote! {
                            #path => #value,
                        });
                    }

                    (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
                        panic!("value and tensor_name cannot be used together");
                    }
                    (Some(_), None, None) | (None, None, None) => {
                        panic!("path and value are not present");
                    }
                }
            }
        }
        if !value_construct.is_empty() {
            construct_fields.push(quote! {
                #name: #(#value_construct)*
            });
        } else if !from_construct.is_empty() {
            construct_fields.push(quote! {
                #name: match path {
                    #(#from_construct)*
                    _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
                }
            });
        } else {
            construct_fields.push(quote! {
                #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
            });
        }
    }
    let expanded = quote! {
        impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
            fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
                Self {
                    #(#construct_fields),*
                }
            }
        }
    };
    // let syntax_tree = syn::parse2(expanded.clone()).expect(&format!(
    //     "failed to parse expanded: {}",
    //     expanded.to_string()
    // ));
    // let formatted = prettyplease::unparse(&syntax_tree);
    // println!("{}", formatted);
    expanded.into()
}