1#![deny(missing_docs)]
16#[cfg(feature = "cuda")]
17use crate::binary_float_out::impl_cuda_float_out_binary;
18use binary_float_out::impl_float_out_binary;
19use float_unary::impl_float_out_unary;
20use from_scalar::__impl_from_scalar;
21use kernel_gen_helper::{
22 __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
23 __gen_reduce_dim_not_include_simd_helper,
24};
25use normal_out::__impl_normal_out_binary;
26use proc_macro::TokenStream;
27use scalar_convert::__impl_scalar_convert;
28use simd_bitwise::impl_simd_bitwise_out;
29use simd_float_out_binary::{
30 impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
31 impl_simd_binary_out_float_rhs_scalar,
32};
33use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
34use syn::{parse, parse_macro_input, Expr, Token};
35mod binary_float_out;
36mod conv2d;
37mod float_unary;
38mod from_scalar;
39mod into_cuda_scalar;
40mod into_scalar;
41mod into_vec;
42mod kernel_gen_helper;
43mod normal_out;
44mod normal_out_unary;
45mod scalar_convert;
46mod simd_bitwise;
47
48mod simd_cmp;
49mod simd_eval;
50mod simd_float_out_binary;
51mod simd_float_out_unary;
52mod simd_normal_out;
53mod simd_normal_unary;
54mod type_utils;
55
56use crate::simd_cmp::impl_simd_cmp;
57use crate::simd_normal_out::impl_simd_normal_out;
58use proc_macro2::{TokenStream as TokenStream2, TokenTree};
59use quote::{format_ident, quote};
60use type_utils::TypeInfo;
61
/// Number of SIMD vector registers assumed by the kernel generators on targets
/// with AVX-512 or on AArch64.
///
/// The three `NUM_REG` definitions are mutually exclusive and together cover every
/// target. In Rust's target-feature model `avx512f` implies `avx2`, so the AVX2 arm
/// below must explicitly exclude AVX-512; otherwise `NUM_REG` would be defined twice.
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;
/// Number of SIMD vector registers assumed by the kernel generators on AVX2 targets
/// without AVX-512.
#[cfg(all(
    target_feature = "avx2",
    not(any(target_feature = "avx512f", target_arch = "aarch64"))
))]
const NUM_REG: usize = 16;
/// Number of SIMD vector registers assumed by the kernel generators everywhere else:
/// SSE-only x86, 32-bit ARM, and — as a conservative default — any other target.
#[cfg(not(any(
    target_feature = "avx2",
    target_feature = "avx512f",
    target_arch = "aarch64"
)))]
const NUM_REG: usize = 8;
72
73struct SelectionParser {
74 start: Option<Expr>,
75 end: Option<Expr>,
76 step: Option<Expr>,
77}
78
79struct Selections {
80 selections: Vec<TokenStream>,
81}
82
83impl parse::Parse for SelectionParser {
84 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
85 let mut start: Option<Expr> = None;
86 let mut end: Option<Expr> = None;
87 let mut step: Option<Expr> = None;
88 if input.peek(syn::Lit)
89 || input.peek(syn::Ident)
90 || input.peek(syn::token::Paren)
91 || input.peek(Token![-])
92 {
93 start = Some(input.parse::<Expr>()?);
94 }
95 if input.peek(Token![:]) {
96 input.parse::<Token![:]>()?;
97 } else if input.is_empty() {
98 return Ok(Self { start, end, step });
99 } else {
100 return Err(syn::Error::new(
101 input.span(),
102 "unexpected token, expected `:`, Int or Ident",
103 ));
104 }
105 if input.peek(syn::Lit)
106 || input.peek(syn::Ident)
107 || input.peek(syn::token::Paren)
108 || input.peek(Token![-])
109 {
110 end = Some(input.parse::<Expr>()?);
111 }
112 if input.peek(Token![:]) {
113 input.parse::<Token![:]>()?;
114 }
115 if input.peek(syn::Lit)
116 || input.peek(syn::Ident)
117 || input.peek(syn::token::Paren)
118 || input.peek(Token![-])
119 {
120 step = Some(input.parse::<Expr>()?);
121 }
122 Ok(Self { start, end, step })
123 }
124}
125
126impl parse::Parse for Selections {
127 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
128 let mut selections: Vec<TokenStream> = vec![];
129 let mut tokenstream = TokenStream2::new();
130 while !input.is_empty() {
131 let lookahead = input.lookahead1();
132 if lookahead.peek(Token![,]) {
133 selections.push(tokenstream.into());
134 tokenstream = TokenStream2::new();
135 input.parse::<Token![,]>()?;
136 } else {
137 let t = input.parse::<TokenTree>()?;
138 tokenstream.extend(quote!(#t));
139 }
140 }
141 selections.push(tokenstream.into());
142 Ok(Self { selections })
143 }
144}
145
146#[proc_macro]
148pub fn match_selection(input: TokenStream) -> TokenStream {
149 let res: Selections = parse_macro_input!(input as Selections);
150 let mut slices: Vec<SelectionParser> = vec![];
151 for x in res.selections {
152 slices.push(parse_macro_input!(x as SelectionParser));
153 }
154 let mut ret_stream = TokenStream2::new();
155 let len = slices.len();
156 for (idx, x) in slices.into_iter().enumerate() {
157 match (x.start, x.end, x.step) {
158 (None, None, None) => {
159 ret_stream.extend(quote!(Slice::Full));
160 }
161 (None, None, Some(step)) => {
162 ret_stream.extend(quote!(Slice::StepByFullRange(#step)));
163 }
164 (None, Some(end), None) => {
165 ret_stream.extend(quote!(Slice::RangeTo(#end)));
166 }
167 (None, Some(end), Some(step)) => {
168 ret_stream.extend(quote!(Slice::StepByRangeTo((#end, #step))));
169 }
170 (Some(start), None, None) => {
171 ret_stream.extend(quote!(Slice::From(#start)));
172 }
173 (Some(start), None, Some(step)) => {
174 ret_stream.extend(quote!(Slice::StepByRangeFrom((#start, #step))));
175 }
176 (Some(start), Some(end), None) => {
177 ret_stream.extend(quote!(Slice::Range((#start, #end))));
178 }
179 (Some(start), Some(end), Some(step)) => {
180 ret_stream.extend(quote!(Slice::StepByRangeFromTo((#start, #end, #step))));
181 }
182 }
183 if idx != len - 1 {
184 ret_stream.extend(quote!(,));
185 }
186 }
187 quote!([#ret_stream]).into()
188}
189
190#[proc_macro]
192pub fn float_out_binary(_: TokenStream) -> TokenStream {
193 impl_float_out_binary()
194}
195
/// Expands to the implementations generated by
/// `binary_float_out::impl_cuda_float_out_binary`.
///
/// Only available with the `cuda` feature. The macro input is ignored.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_binary_cuda(_input: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}
202
203#[proc_macro]
205pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
206 impl_simd_binary_out_float()
207}
208
209#[proc_macro]
211pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
212 impl_simd_binary_out_float_rhs_scalar()
213}
214
215#[proc_macro]
217pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
218 impl_simd_binary_out_float_lhs_scalar()
219}
220
221#[proc_macro]
223pub fn float_out_unary(_: TokenStream) -> TokenStream {
224 impl_float_out_unary()
225}
226
/// Expands to the implementations generated by
/// `float_unary::impl_cuda_float_out_unary`.
///
/// Only available with the `cuda` feature. The macro input is ignored.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_unary_cuda(_input: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}
233
234#[proc_macro]
236pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
237 simd_float_out_unary::impl_float_out_unary()
238}
239
240#[proc_macro]
242pub fn simd_eval(_: TokenStream) -> TokenStream {
243 simd_eval::impl_simd_eval()
244}
245
246#[proc_macro]
248pub fn simd_bitwise(_: TokenStream) -> TokenStream {
249 impl_simd_bitwise_out()
250}
251
252#[proc_macro]
254pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
255 __impl_normal_out_binary()
256}
257
/// Expands to the implementations generated by
/// `normal_out::__impl_cuda_normal_out_binary`.
///
/// Only available with the `cuda` feature. The macro input is ignored.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_input: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}
264
265#[proc_macro]
267pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
268 normal_out_unary::__impl_normal_out_unary()
269}
270
/// Expands to the implementations generated by
/// `normal_out_unary::__impl_normal_out_unary_cuda`.
///
/// Only available with the `cuda` feature. The macro input is ignored.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_input: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}
277
278#[proc_macro]
280pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
281 simd_normal_unary::impl_simd_normal_out_unary()
282}
283
284#[proc_macro]
286pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
287 impl_simd_normal_out()
288}
289
290#[proc_macro]
292pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
293 impl_simd_normal_out_with_rhs_scalar()
294}
295
296#[proc_macro]
298pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
299 impl_simd_normal_out_with_lhs_scalar()
300}
301
302#[proc_macro]
304pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
305 __impl_scalar_convert()
306}
307
308#[proc_macro]
310pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
311 __impl_from_scalar()
312}
313
314#[proc_macro]
316pub fn simd_cmp(_: TokenStream) -> TokenStream {
317 impl_simd_cmp()
318}
319
320#[proc_macro]
322pub fn impl_into_vec(_: TokenStream) -> TokenStream {
323 into_vec::into_vec()
324}
325
/// Expands to the implementations generated by
/// `into_cuda_scalar::__impl_into_cuda_scalar`, converted into a compiler token
/// stream.
///
/// Only available with the `cuda` feature. The macro input is ignored.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_into_cuda_scalar(_input: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}
332
333#[proc_macro]
335pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
336 into_scalar::__impl_into_scalar().into()
337}
338
339#[proc_macro]
341pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
342 let mut ret = proc_macro2::TokenStream::new();
343
344 let types = [
345 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
346 ];
347
348 for lhs in types.iter() {
349 for rhs in types.iter() {
350 let lhs_type = TypeInfo::new(lhs);
351 let rhs_type = TypeInfo::new(rhs);
352 let lhs_dtype = lhs_type.dtype;
353 let rhs_dtype = rhs_type.dtype;
354 let res = if lhs_dtype == rhs_dtype {
355 quote! {
356 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
357 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
358 #[inline(always)]
359 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
360 self.__bitand(rhs)
361 }
362 #[inline(always)]
363 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
364 self.__bitor(rhs)
365 }
366 #[inline(always)]
367 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
368 self.__bitxor(rhs)
369 }
370 #[inline(always)]
371 fn _not(self) -> Self::Output {
372 self.__not()
373 }
374 #[inline(always)]
375 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
376 self.__shl(rhs)
377 }
378 #[inline(always)]
379 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
380 self.__shr(rhs)
381 }
382 }
383 }
384 } else {
385 quote! {
386 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
387 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
388 #[inline(always)]
389 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
390 let lhs: Self::Output = self.cast();
391 let rhs: Self::Output = rhs.cast();
392 lhs.__bitand(rhs)
393 }
394 #[inline(always)]
395 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
396 let lhs: Self::Output = self.cast();
397 let rhs: Self::Output = rhs.cast();
398 lhs.__bitor(rhs)
399 }
400 #[inline(always)]
401 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
402 let lhs: Self::Output = self.cast();
403 let rhs: Self::Output = rhs.cast();
404 lhs.__bitxor(rhs)
405 }
406 #[inline(always)]
407 fn _not(self) -> Self::Output {
408 let lhs: Self::Output = self.cast();
409 lhs.__not()
410 }
411 #[inline(always)]
412 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
413 let lhs: Self::Output = self.cast();
414 let rhs: Self::Output = rhs.cast();
415 lhs.__shl(rhs)
416 }
417 #[inline(always)]
418 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
419 let lhs: Self::Output = self.cast();
420 let rhs: Self::Output = rhs.cast();
421 lhs.__shr(rhs)
422 }
423 }
424 }
425 };
426 ret.extend(res);
427 }
428 }
429
430 ret.into()
431}
432
433#[proc_macro]
435pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
436 let mut ret = proc_macro2::TokenStream::new();
437
438 let types = [
439 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
440 ];
441
442 for lhs in types.iter() {
443 for rhs in types.iter() {
444 let lhs_type = TypeInfo::new(lhs);
445 let rhs_type = TypeInfo::new(rhs);
446 let lhs_dtype = lhs_type.dtype;
447 let rhs_dtype = rhs_type.dtype;
448 let res = if lhs_dtype == rhs_dtype {
449 quote! {
450 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
451 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
452 #[inline(always)]
453 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
454 self.__bitand(rhs)
455 }
456 #[inline(always)]
457 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
458 self.__bitor(rhs)
459 }
460 #[inline(always)]
461 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
462 self.__bitxor(rhs)
463 }
464 #[inline(always)]
465 fn _not(self) -> Self::Output {
466 self.__not()
467 }
468 #[inline(always)]
469 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
470 self.__shl(rhs)
471 }
472 #[inline(always)]
473 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
474 self.__shr(rhs)
475 }
476 }
477 }
478 } else {
479 quote! {
480 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
481 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
482 #[inline(always)]
483 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
484 let lhs: Self::Output = self.cast();
485 let rhs: Self::Output = rhs.cast();
486 lhs.__bitand(rhs)
487 }
488 #[inline(always)]
489 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
490 let lhs: Self::Output = self.cast();
491 let rhs: Self::Output = rhs.cast();
492 lhs.__bitor(rhs)
493 }
494 #[inline(always)]
495 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
496 let lhs: Self::Output = self.cast();
497 let rhs: Self::Output = rhs.cast();
498 lhs.__bitxor(rhs)
499 }
500 #[inline(always)]
501 fn _not(self) -> Self::Output {
502 let lhs: Self::Output = self.cast();
503 lhs.__not()
504 }
505 #[inline(always)]
506 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
507 let lhs: Self::Output = self.cast();
508 let rhs: Self::Output = rhs.cast();
509 lhs.__shl(rhs)
510 }
511 #[inline(always)]
512 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
513 let lhs: Self::Output = self.cast();
514 let rhs: Self::Output = rhs.cast();
515 lhs.__shr(rhs)
516 }
517 }
518 }
519 };
520 ret.extend(res);
521 }
522 }
523
524 ret.into()
525}
526
527#[proc_macro]
529pub fn impl_cmp(_: TokenStream) -> TokenStream {
530 let mut ret = proc_macro2::TokenStream::new();
531
532 let types = [
533 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
534 "usize", "bf16",
535 ];
536
537 for lhs in types.iter() {
538 for rhs in types.iter() {
539 let lhs_type = TypeInfo::new(lhs);
540 let rhs_type = TypeInfo::new(rhs);
541 let lhs_dtype = lhs_type.dtype;
542 let rhs_dtype = rhs_type.dtype;
543 let res = if lhs_dtype == rhs_dtype {
544 quote! {
545 impl Cmp<#rhs_dtype> for #lhs_dtype {
546 type Output = bool;
547 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
548 self == rhs
549 }
550 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
551 self != rhs
552 }
553 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
554 self < rhs
555 }
556
557 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
558 self <= rhs
559 }
560 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
561 self > rhs
562 }
563 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
564 self >= rhs
565 }
566 }
567 }
568 } else {
569 quote! {
570 impl Cmp<#rhs_dtype> for #lhs_dtype {
571 type Output = bool;
572 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
573 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
574 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
575 lhs == rhs
576 }
577 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
578 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
579 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
580 lhs != rhs
581 }
582 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
583 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
584 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
585 lhs < rhs
586 }
587
588 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
589 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
590 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
591 lhs <= rhs
592 }
593 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
594 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
595 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
596 lhs > rhs
597 }
598 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
599 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
600 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
601 lhs >= rhs
602 }
603 }
604 }
605 };
606 ret.extend(res);
607 }
608 }
609
610 ret.into()
611}
612
613#[proc_macro]
615pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
616 let mut ret = proc_macro2::TokenStream::new();
617
618 let types = [
619 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
620 "usize", "bf16",
621 ];
622
623 for lhs in types.iter() {
624 for rhs in types.iter() {
625 let lhs_type = TypeInfo::new(lhs);
626 let rhs_type = TypeInfo::new(rhs);
627 let lhs_dtype = lhs_type.dtype;
628 let rhs_dtype = rhs_type.dtype;
629 let res = if lhs_dtype == rhs_dtype {
630 quote! {
631 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
632 type Output = Scalar<bool>;
633 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
634 self.__eq(rhs)
635 }
636 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
637 self.__ne(rhs)
638 }
639 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
640 self.__lt(rhs)
641 }
642
643 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
644 self.__le(rhs)
645 }
646 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
647 self.__gt(rhs)
648 }
649 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
650 self.__ge(rhs)
651 }
652 }
653 }
654 } else {
655 quote! {
656 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
657 type Output = Scalar<bool>;
658 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
659 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
660 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
661 lhs.__eq(rhs)
662 }
663 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
664 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
665 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
666 lhs.__ne(rhs)
667 }
668 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
669 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
670 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
671 lhs.__lt(rhs)
672 }
673
674 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
675 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
676 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
677 lhs.__le(rhs)
678 }
679 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
680 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
681 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
682 lhs.__gt(rhs)
683 }
684 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
685 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
686 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
687 lhs.__ge(rhs)
688 }
689 }
690 }
691 };
692 ret.extend(res);
693 }
694 }
695
696 ret.into()
697}
698
699#[proc_macro]
701pub fn impl_eval(_: TokenStream) -> TokenStream {
702 let mut ret = proc_macro2::TokenStream::new();
703
704 let types = [
705 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
706 "usize", "bf16",
707 ];
708
709 for lhs in types.iter() {
710 let lhs_type = TypeInfo::new(lhs);
711 let lhs_dtype = lhs_type.dtype;
712
713 let res = quote! {
714 impl Eval for #lhs_dtype {
715 type Output = bool;
716 #[inline(always)]
717 fn _is_nan(&self) -> bool {
718 self.__is_nan()
719 }
720 #[inline(always)]
721 fn _is_true(&self) -> bool {
722 self.__is_true()
723 }
724 #[inline(always)]
725 fn _is_inf(&self) -> bool {
726 self.__is_inf()
727 }
728 }
729 };
730 ret.extend(res);
731 }
732
733 ret.into()
734}
735
736#[proc_macro]
738pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
739 __gen_fast_reduce_simd_helper(input)
740}
741
742#[proc_macro]
744pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
745 __gen_fast_layernorm_simd_helper(input)
746}
747
748#[proc_macro]
750pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
751 __gen_reduce_dim_not_include_simd_helper(input)
752}
753
754#[proc_macro]
760pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
761 conv2d::conv2d_microkernel_declare_const(input)
762}
763
764#[proc_macro]
766pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
767 conv2d::conv2d_microkernel_gen_inps(input)
768}
769
770#[proc_macro]
772pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
773 conv2d::conv2d_microkernel_gen_pad_inps(input)
774}
775
776#[proc_macro]
778pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
779 conv2d::pwconv2d_microkernel_gen_pad_inps(input)
780}
781
782#[proc_macro]
784pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
785 conv2d::dwconv2d_microkernel_gen_pad_inps(input)
786}
787
788#[proc_macro]
790pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
791 conv2d::conv2d_microkernel_gen_kernels(input)
792}
793
794#[proc_macro]
796pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
797 conv2d::conv2d_microkernel_gen_results(input)
798}
799
800#[proc_macro]
802pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
803 conv2d::dwconv2d_microkernel_gen_results(input)
804}
805
806#[proc_macro]
809pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
810 conv2d::maxpool2d_microkernel_gen_results(input)
811}
812
813#[proc_macro_derive(Save, attributes(compress))]
815pub fn impl_save(input: TokenStream) -> TokenStream {
816 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
817 let name = &ast.ident;
818 let meta_name = format_ident!("{}Meta", name);
819
820 let visibility = &ast.vis;
821 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
822 let fields = match &ast.data {
823 syn::Data::Struct(s) => &s.fields,
824 _ => panic!("Save can only be derived for structs"),
825 };
826
827 let mut compressions = vec![];
828 let mut endians = vec![];
829 let mut compress_levels = vec![];
830
831 let meta_fields = fields
832 .iter()
833 .map(|f| {
834 let mut compression_algo = None;
835 let mut endian = None;
836 let mut level = None;
837
838 for attr in &f.attrs {
839 if attr.path().is_ident("compress") {
840 attr.parse_nested_meta(|meta| {
841 if meta.path.is_ident("algo") {
842 let value: syn::LitStr = meta.value()?.parse()?;
843 let algo = match value.value().as_str().to_lowercase().as_str() {
844 "gzip" => quote!(Gzip),
845 "deflate" => quote!(Deflate),
846 "zlib" => quote!(Zlib),
847 "none" => quote!(NoCompression),
848 _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
849 };
850 compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
851 } else if meta.path.is_ident("level") {
852 let value: syn::LitStr = meta.value()?.parse()?;
853 let tmp: u32 = value.value().parse().map_err(|e| {
854 syn::Error::new(value.span(), format!("Invalid level: {}", e))
855 })?;
856 level = Some(quote!(#tmp));
857 } else if meta.path.is_ident("endian") {
858 let value: syn::LitStr = meta.value()?.parse()?;
859 let tmp = match value.value().as_str() {
860 "native" => quote!(Native),
861 "little" => quote!(Little),
862 "big" => quote!(Big),
863 _ => panic!("Unsupported endianness, supported: native, little, big"),
864 };
865 endian = Some(quote!(hpt::Endian::#tmp));
866 }
867 Ok(())
868 })
869 .unwrap();
870 }
871 }
872 compressions.push(compression_algo);
873 endians.push(endian);
874 compress_levels.push(level);
875 let name = &f.ident;
876 let ty = &f.ty;
877 quote! {
878 pub #name: <#ty as Save>::Meta
879 }
880 })
881 .collect::<Vec<_>>();
882
883 let call_save = fields.iter().enumerate().map(|(idx, f)| {
884 let name = &f.ident;
885 let ty = &f.ty;
886 let ident = format_ident!("field_{}", idx);
887 let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
888 let endian = endians[idx].clone().unwrap_or(quote!(endian));
889 let level = compress_levels[idx].clone().unwrap_or(quote!(level));
890 if let Some(name) = name {
891 quote! {
892 let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
893 }
894 } else {
895 quote! {
896 let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
897 }
898 }
899 });
900
901 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
902 let name = &f.ident;
903 let ident = format_ident!("field_{}", idx);
904 quote! {
905 #name: #ident
906 }
907 });
908
909 let expanded = quote! {
910 #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
911 #visibility struct #meta_name #ty_generics #where_clause {
912 #(#meta_fields,)*
913 }
914 impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
915 type Meta = #meta_name #ty_generics;
916 fn __save(
917 data: &Self,
918 file: &mut std::fs::File,
919 len_so_far: &mut usize,
920 global_cnt: &mut usize,
921 compression_algo: hpt::CompressionAlgo,
922 endian: hpt::Endian,
923 level: u32,
924 ) -> std::io::Result<Self::Meta> {
925 #(#call_save)*
926 Ok(Self::Meta {
927 #(#construct_fields),*
928 })
929 }
930 }
931 };
932
933 expanded.into()
934}
935
936#[proc_macro_derive(Load)]
938pub fn impl_load(input: TokenStream) -> TokenStream {
939 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
940 let name = &ast.ident;
941 let meta_name = format_ident!("{}Meta", name);
942 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
943
944 let fields = match &ast.data {
945 syn::Data::Struct(s) => &s.fields,
946 _ => panic!("Load can only be derived for structs"),
947 };
948
949 let call_load = fields.iter().enumerate().map(|(idx, f)| {
950 let name = &f.ident;
951 let ident = format_ident!("field_{}", idx);
952 if let Some(name) = name {
953 quote! {
954 let #ident = self.#name.load(file)?;
955 }
956 } else {
957 quote! {
958 let #ident = self.#idx.load(file)?;
959 }
960 }
961 });
962
963 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
964 let name = &f.ident;
965 let ident = format_ident!("field_{}", idx);
966 quote! {
967 #name: #ident
968 }
969 });
970
971 let expanded = quote! {
972 impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
973 type Output = #name #ty_generics;
974 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
975 use hpt::MetaLoad;
976 #(#call_load)*
977 Ok(#name {
978 #(#construct_fields),*
979 })
980 }
981 }
982 impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
983 fn load(path: &str) -> std::io::Result<Self> {
984 use hpt::MetaLoad;
985 let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
986 let mut file = std::fs::File::open(path)?;
987 meta.load(&mut file)
988 }
989 }
990 };
991
992 expanded.into()
993}
994
995#[proc_macro_derive(FromSafeTensors, attributes(map))]
997pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
998 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
999 let struct_name = &ast.ident;
1000 let fields = match &ast.data {
1001 syn::Data::Struct(s) => &s.fields,
1002 _ => panic!("FromSafeTensors can only be derived for structs"),
1003 };
1004 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
1005 let mut construct_fields = vec![];
1006 for (_, field) in fields.iter().enumerate() {
1007 let ty = &field.ty;
1008 let name = &field.ident;
1009 let mut value_construct = vec![];
1010 let mut from_construct = vec![];
1011 let mut params = vec![];
1012 let mut vec_len = None;
1013 for attr in &field.attrs {
1014 if attr.path().is_ident("map") {
1015 let mut path = None;
1016 let mut value = None;
1017 let mut tensor_name = None;
1018 let mut inner_type = None;
1019 attr.parse_nested_meta(|meta| {
1020 if meta.path.is_ident("path") {
1021 let value: syn::LitStr = meta.value()?.parse()?;
1022 path = Some(value.value());
1023 } else if meta.path.is_ident("value") {
1024 let val: syn::Expr = meta.value()?.parse()?;
1025 value = Some(val);
1026 } else if meta.path.is_ident("tensor_name") {
1027 let value: syn::LitStr = meta.value()?.parse()?;
1028 tensor_name = Some(value.value());
1029 } else if meta.path.is_ident("vec_len") {
1030 let value: syn::LitInt = meta.value()?.parse()?;
1031 vec_len = Some(value.base10_parse::<usize>().unwrap());
1032 } else if meta.path.is_ident("inner_type") {
1033 let value: syn::Ident = meta.value()?.parse()?;
1034 inner_type = Some(value);
1035 }
1036 Ok(())
1037 })
1038 .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
1039 params.push((path, value, tensor_name, vec_len, inner_type));
1040 }
1041 }
1042 let param_count = params.len();
1043 for (path, value, tensor_name, vec_len, inner_type) in params {
1044 if let Some(vec_len) = vec_len {
1045 let inner_type = inner_type.expect("inner_type is required for vec");
1046 if let Some(path) = path {
1047 from_construct.push(quote! {
1048 #path => {
1049 let mut vec = vec![];
1050 for i in 0..#vec_len {
1051 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1052 }
1053 vec
1054 }
1055 });
1056 } else {
1057 value_construct.push(quote! {
1058 {
1059 let mut vec = vec![];
1060 for i in 0..#vec_len {
1061 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1062 }
1063 vec
1064 }
1065 });
1066 }
1067 } else {
1068 match (path, value, tensor_name) {
1069 (None, None, Some(tensor_name)) => {
1070 value_construct.push(quote! {
1071 <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
1072 });
1073 }
1074 (None, Some(value), None) => {
1075 if param_count > 1 {
1076 panic!("value without path means generic assignment, there can only be one value without path");
1077 }
1078 value_construct.push(quote! {
1079 #value
1080 });
1081 }
1082 (Some(path), None, Some(tensor_name)) => {
1083 from_construct.push(quote! {
1084 #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
1085 });
1086 }
1087 (Some(path), Some(value), None) => {
1088 from_construct.push(quote! {
1089 #path => #value,
1090 });
1091 }
1092
1093 (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
1094 panic!("value and tensor_name cannot be used together");
1095 }
1096 (Some(_), None, None) | (None, None, None) => {
1097 panic!("path and value are not present");
1098 }
1099 }
1100 }
1101 }
1102 if !value_construct.is_empty() {
1103 construct_fields.push(quote! {
1104 #name: #(#value_construct)*
1105 });
1106 } else if !from_construct.is_empty() {
1107 construct_fields.push(quote! {
1108 #name: match path {
1109 #(#from_construct)*
1110 _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
1111 }
1112 });
1113 } else {
1114 construct_fields.push(quote! {
1115 #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
1116 });
1117 }
1118 }
1119 let expanded = quote! {
1120 impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
1121 fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
1122 Self {
1123 #(#construct_fields),*
1124 }
1125 }
1126 }
1127 };
1128 expanded.into()
1135}