1#![deny(missing_docs)]
16#[cfg(feature = "cuda")]
17use crate::binary_float_out::impl_cuda_float_out_binary;
18use binary_float_out::impl_float_out_binary;
19use float_unary::impl_float_out_unary;
20use from_scalar::__impl_from_scalar;
21use kernel_gen_helper::{
22 __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
23 __gen_reduce_dim_not_include_simd_helper,
24};
25use normal_out::__impl_normal_out_binary;
26use proc_macro::TokenStream;
27use scalar_convert::__impl_scalar_convert;
28use simd_bitwise::impl_simd_bitwise_out;
29use simd_float_out_binary::{
30 impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
31 impl_simd_binary_out_float_rhs_scalar,
32};
33use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
34use syn::{parse, parse_macro_input, Expr, Token};
35mod binary_float_out;
36mod conv2d;
37mod float_unary;
38mod from_scalar;
39mod into_cuda_scalar;
40mod into_scalar;
41mod into_vec;
42mod kernel_gen_helper;
43mod normal_out;
44mod normal_out_unary;
45mod scalar_convert;
46mod simd_bitwise;
47
48mod simd_cmp;
49mod simd_eval;
50mod simd_float_out_binary;
51mod simd_float_out_unary;
52mod simd_normal_out;
53mod simd_normal_unary;
54mod type_utils;
55
56use crate::simd_cmp::impl_simd_cmp;
57use crate::simd_normal_out::impl_simd_normal_out;
58use proc_macro2::{TokenStream as TokenStream2, TokenTree};
59use quote::{format_ident, quote};
60use type_utils::TypeInfo;
61
// Register-count constant selected from the compile-time target features
// (presumably the number of SIMD vector registers available to generated
// kernels — confirm against the generators that consume it).
//
// The three configurations are mutually exclusive: `avx512f` implies `avx2`
// (and `sse`) in rustc, so without the explicit `not(avx512f)` guards an
// AVX-512 build would define `NUM_REG` more than once.
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2"),
    not(target_feature = "avx512f")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;
72
/// One parsed selection of the form `start:end:step`, where every component
/// is optional.
///
/// Filled in by the `parse::Parse` implementation for this type.
struct SelectionParser {
    /// Expression found before the first `:`, if any.
    start: Option<Expr>,
    /// Expression found after the first `:`, if any.
    end: Option<Expr>,
    /// Expression found after the second `:`, if any.
    step: Option<Expr>,
}
78
/// The raw macro input split into one token stream per comma-separated
/// selection.
struct Selections {
    /// Unparsed tokens of each selection, in input order.
    selections: Vec<TokenStream>,
}
82
83impl parse::Parse for SelectionParser {
84 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
85 let mut start: Option<Expr> = None;
86 let mut end: Option<Expr> = None;
87 let mut step: Option<Expr> = None;
88 if input.peek(syn::Lit)
89 || input.peek(syn::Ident)
90 || input.peek(syn::token::Paren)
91 || input.peek(Token![-])
92 {
93 start = Some(input.parse::<Expr>()?);
94 }
95 if input.peek(Token![:]) {
96 input.parse::<Token![:]>()?;
97 } else if input.is_empty() {
98 return Ok(Self { start, end, step });
99 } else {
100 return Err(syn::Error::new(
101 input.span(),
102 "unexpected token, expected `:`, Int or Ident",
103 ));
104 }
105 if input.peek(syn::Lit)
106 || input.peek(syn::Ident)
107 || input.peek(syn::token::Paren)
108 || input.peek(Token![-])
109 {
110 end = Some(input.parse::<Expr>()?);
111 }
112 if input.peek(Token![:]) {
113 input.parse::<Token![:]>()?;
114 }
115 if input.peek(syn::Lit)
116 || input.peek(syn::Ident)
117 || input.peek(syn::token::Paren)
118 || input.peek(Token![-])
119 {
120 step = Some(input.parse::<Expr>()?);
121 }
122 Ok(Self { start, end, step })
123 }
124}
125
126impl parse::Parse for Selections {
127 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
128 let mut selections: Vec<TokenStream> = vec![];
129 let mut tokenstream = TokenStream2::new();
130 while !input.is_empty() {
131 let lookahead = input.lookahead1();
132 if lookahead.peek(Token![,]) {
133 selections.push(tokenstream.into());
134 tokenstream = TokenStream2::new();
135 input.parse::<Token![,]>()?;
136 } else {
137 let t = input.parse::<TokenTree>()?;
138 tokenstream.extend(quote!(#t));
139 }
140 }
141 selections.push(tokenstream.into());
142 Ok(Self { selections })
143 }
144}
145
146#[proc_macro]
148pub fn match_selection(input: TokenStream) -> TokenStream {
149 let res: Selections = parse_macro_input!(input as Selections);
150 let mut slices: Vec<SelectionParser> = vec![];
151 for x in res.selections {
152 slices.push(parse_macro_input!(x as SelectionParser));
153 }
154 let mut ret_stream = TokenStream2::new();
155 let len = slices.len();
156 for (idx, x) in slices.into_iter().enumerate() {
157 match (x.start, x.end, x.step) {
158 (None, None, None) => {
159 ret_stream.extend(quote!(Slice::Full));
160 }
161 (None, None, Some(step)) => {
162 ret_stream.extend(quote!(Slice::StepByFullRange(#step)));
163 }
164 (None, Some(end), None) => {
165 ret_stream.extend(quote!(Slice::RangeTo(#end)));
166 }
167 (None, Some(end), Some(step)) => {
168 ret_stream.extend(quote!(Slice::StepByRangeTo((#end, #step))));
169 }
170 (Some(start), None, None) => {
171 ret_stream.extend(quote!(Slice::From(#start)));
172 }
173 (Some(start), None, Some(step)) => {
174 ret_stream.extend(quote!(Slice::StepByRangeFrom((#start, #step))));
175 }
176 (Some(start), Some(end), None) => {
177 ret_stream.extend(quote!(Slice::Range((#start, #end))));
178 }
179 (Some(start), Some(end), Some(step)) => {
180 ret_stream.extend(quote!(Slice::StepByRangeFromTo((#start, #end, #step))));
181 }
182 }
183 if idx != len - 1 {
184 ret_stream.extend(quote!(,));
185 }
186 }
187 quote!([#ret_stream]).into()
188}
189
#[proc_macro]
/// Expands to the tokens returned by `binary_float_out::impl_float_out_binary`.
///
/// The macro input is ignored.
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}
195
#[cfg(feature = "cuda")]
#[proc_macro]
/// Expands to the tokens returned by
/// `binary_float_out::impl_cuda_float_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled. The macro input is
/// ignored.
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}
202
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_float_out_binary::impl_simd_binary_out_float`.
///
/// The macro input is ignored.
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}
208
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_float_out_binary::impl_simd_binary_out_float_rhs_scalar`.
///
/// The macro input is ignored.
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}
214
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_float_out_binary::impl_simd_binary_out_float_lhs_scalar`.
///
/// The macro input is ignored.
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}
220
#[proc_macro]
/// Expands to the tokens returned by `float_unary::impl_float_out_unary`.
///
/// The macro input is ignored.
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}
226
#[cfg(feature = "cuda")]
#[proc_macro]
/// Expands to the tokens returned by `float_unary::impl_cuda_float_out_unary`.
///
/// Only compiled when the `cuda` feature is enabled. The macro input is
/// ignored.
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}
233
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_float_out_unary::impl_float_out_unary` (the function in the
/// `simd_float_out_unary` module, not the one in `float_unary`).
///
/// The macro input is ignored.
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}
239
#[proc_macro]
/// Expands to the tokens returned by `simd_eval::impl_simd_eval`.
///
/// The macro input is ignored.
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}
245
#[proc_macro]
/// Expands to the tokens returned by `simd_bitwise::impl_simd_bitwise_out`.
///
/// The macro input is ignored.
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}
251
#[proc_macro]
/// Expands to the tokens returned by `normal_out::__impl_normal_out_binary`.
///
/// The macro input is ignored.
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}
257
#[cfg(feature = "cuda")]
#[proc_macro]
/// Expands to the tokens returned by
/// `normal_out::__impl_cuda_normal_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled. The macro input is
/// ignored.
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}
264
#[proc_macro]
/// Expands to the tokens returned by
/// `normal_out_unary::__impl_normal_out_unary`.
///
/// The macro input is ignored.
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}
270
#[cfg(feature = "cuda")]
#[proc_macro]
/// Expands to the tokens returned by
/// `normal_out_unary::__impl_normal_out_unary_cuda`.
///
/// Only compiled when the `cuda` feature is enabled. The macro input is
/// ignored.
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}
277
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_normal_unary::impl_simd_normal_out_unary`.
///
/// The macro input is ignored.
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}
283
#[proc_macro]
/// Expands to the tokens returned by `simd_normal_out::impl_simd_normal_out`.
///
/// The macro input is ignored.
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}
289
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_normal_out::impl_simd_normal_out_with_rhs_scalar`.
///
/// The macro input is ignored.
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}
295
#[proc_macro]
/// Expands to the tokens returned by
/// `simd_normal_out::impl_simd_normal_out_with_lhs_scalar`.
///
/// The macro input is ignored.
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}
301
#[proc_macro]
/// Expands to the tokens returned by `scalar_convert::__impl_scalar_convert`.
///
/// The macro input is ignored.
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}
307
#[proc_macro]
/// Expands to the tokens returned by `from_scalar::__impl_from_scalar`.
///
/// The macro input is ignored.
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}
313
#[proc_macro]
/// Expands to the tokens returned by `simd_cmp::impl_simd_cmp`.
///
/// The macro input is ignored.
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}
319
#[proc_macro]
/// Expands to the tokens returned by `into_vec::into_vec`.
///
/// The macro input is ignored.
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}
325
#[cfg(feature = "cuda")]
#[proc_macro]
/// Expands to the tokens returned by
/// `into_cuda_scalar::__impl_into_cuda_scalar`, converted into a
/// `proc_macro::TokenStream`.
///
/// Only compiled when the `cuda` feature is enabled. The macro input is
/// ignored.
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}
332
#[proc_macro]
/// Expands to the tokens returned by `into_scalar::__impl_into_scalar`,
/// converted into a `proc_macro::TokenStream`.
///
/// The macro input is ignored.
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}
338
339#[proc_macro]
341pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
342 let mut ret = proc_macro2::TokenStream::new();
343
344 let types = [
345 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
346 ];
347
348 for lhs in types.iter() {
349 for rhs in types.iter() {
350 let lhs_type = TypeInfo::new(lhs);
351 let rhs_type = TypeInfo::new(rhs);
352 let lhs_dtype = lhs_type.dtype;
353 let rhs_dtype = rhs_type.dtype;
354 let res = if lhs_dtype == rhs_dtype {
355 quote! {
356 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
357 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
358 #[inline(always)]
359 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
360 self.__bitand(rhs)
361 }
362 #[inline(always)]
363 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
364 self.__bitor(rhs)
365 }
366 #[inline(always)]
367 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
368 self.__bitxor(rhs)
369 }
370 #[inline(always)]
371 fn _not(self) -> Self::Output {
372 self.__not()
373 }
374 #[inline(always)]
375 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
376 self.__shl(rhs)
377 }
378 #[inline(always)]
379 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
380 self.__shr(rhs)
381 }
382 }
383 }
384 } else {
385 quote! {
386 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
387 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
388 #[inline(always)]
389 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
390 let lhs: Self::Output = self.cast();
391 let rhs: Self::Output = rhs.cast();
392 lhs.__bitand(rhs)
393 }
394 #[inline(always)]
395 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
396 let lhs: Self::Output = self.cast();
397 let rhs: Self::Output = rhs.cast();
398 lhs.__bitor(rhs)
399 }
400 #[inline(always)]
401 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
402 let lhs: Self::Output = self.cast();
403 let rhs: Self::Output = rhs.cast();
404 lhs.__bitxor(rhs)
405 }
406 #[inline(always)]
407 fn _not(self) -> Self::Output {
408 let lhs: Self::Output = self.cast();
409 lhs.__not()
410 }
411 #[inline(always)]
412 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
413 let lhs: Self::Output = self.cast();
414 let rhs: Self::Output = rhs.cast();
415 lhs.__shl(rhs)
416 }
417 #[inline(always)]
418 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
419 let lhs: Self::Output = self.cast();
420 let rhs: Self::Output = rhs.cast();
421 lhs.__shr(rhs)
422 }
423 }
424 }
425 };
426 ret.extend(res);
427 }
428 }
429
430 ret.into()
431}
432
433#[proc_macro]
435pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
436 let mut ret = proc_macro2::TokenStream::new();
437
438 let types = [
439 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
440 ];
441
442 for lhs in types.iter() {
443 for rhs in types.iter() {
444 let lhs_type = TypeInfo::new(lhs);
445 let rhs_type = TypeInfo::new(rhs);
446 let lhs_dtype = lhs_type.dtype;
447 let rhs_dtype = rhs_type.dtype;
448 let res = if lhs_dtype == rhs_dtype {
449 quote! {
450 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
451 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
452 #[inline(always)]
453 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
454 self.__bitand(rhs)
455 }
456 #[inline(always)]
457 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
458 self.__bitor(rhs)
459 }
460 #[inline(always)]
461 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
462 self.__bitxor(rhs)
463 }
464 #[inline(always)]
465 fn _not(self) -> Self::Output {
466 self.__not()
467 }
468 #[inline(always)]
469 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
470 self.__shl(rhs)
471 }
472 #[inline(always)]
473 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
474 self.__shr(rhs)
475 }
476 }
477 }
478 } else {
479 quote! {
480 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
481 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
482 #[inline(always)]
483 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
484 let lhs: Self::Output = self.cast();
485 let rhs: Self::Output = rhs.cast();
486 lhs.__bitand(rhs)
487 }
488 #[inline(always)]
489 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
490 let lhs: Self::Output = self.cast();
491 let rhs: Self::Output = rhs.cast();
492 lhs.__bitor(rhs)
493 }
494 #[inline(always)]
495 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
496 let lhs: Self::Output = self.cast();
497 let rhs: Self::Output = rhs.cast();
498 lhs.__bitxor(rhs)
499 }
500 #[inline(always)]
501 fn _not(self) -> Self::Output {
502 let lhs: Self::Output = self.cast();
503 lhs.__not()
504 }
505 #[inline(always)]
506 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
507 let lhs: Self::Output = self.cast();
508 let rhs: Self::Output = rhs.cast();
509 lhs.__shl(rhs)
510 }
511 #[inline(always)]
512 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
513 let lhs: Self::Output = self.cast();
514 let rhs: Self::Output = rhs.cast();
515 lhs.__shr(rhs)
516 }
517 }
518 }
519 };
520 ret.extend(res);
521 }
522 }
523
524 ret.into()
525}
526
527#[proc_macro]
529pub fn impl_cmp(_: TokenStream) -> TokenStream {
530 let mut ret = proc_macro2::TokenStream::new();
531
532 let types = [
533 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
534 "usize", "bf16",
535 ];
536
537 for lhs in types.iter() {
538 for rhs in types.iter() {
539 let lhs_type = TypeInfo::new(lhs);
540 let rhs_type = TypeInfo::new(rhs);
541 let lhs_dtype = lhs_type.dtype;
542 let rhs_dtype = rhs_type.dtype;
543 let res = if lhs_dtype == rhs_dtype {
544 quote! {
545 impl Cmp<#rhs_dtype> for #lhs_dtype {
546 type Output = bool;
547 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
548 self == rhs
549 }
550 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
551 self != rhs
552 }
553 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
554 self < rhs
555 }
556
557 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
558 self <= rhs
559 }
560 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
561 self > rhs
562 }
563 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
564 self >= rhs
565 }
566 }
567 }
568 } else {
569 quote! {
570 impl Cmp<#rhs_dtype> for #lhs_dtype {
571 type Output = bool;
572 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
573 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
574 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
575 lhs == rhs
576 }
577 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
578 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
579 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
580 lhs != rhs
581 }
582 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
583 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
584 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
585 lhs < rhs
586 }
587
588 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
589 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
590 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
591 lhs <= rhs
592 }
593 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
594 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
595 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
596 lhs > rhs
597 }
598 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
599 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
600 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
601 lhs >= rhs
602 }
603 }
604 }
605 };
606 ret.extend(res);
607 }
608 }
609
610 ret.into()
611}
612
613#[proc_macro]
615pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
616 let mut ret = proc_macro2::TokenStream::new();
617
618 let types = [
619 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
620 "usize", "bf16",
621 ];
622
623 for lhs in types.iter() {
624 for rhs in types.iter() {
625 let lhs_type = TypeInfo::new(lhs);
626 let rhs_type = TypeInfo::new(rhs);
627 let lhs_dtype = lhs_type.dtype;
628 let rhs_dtype = rhs_type.dtype;
629 let res = if lhs_dtype == rhs_dtype {
630 quote! {
631 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
632 type Output = Scalar<bool>;
633 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
634 self.__eq(rhs)
635 }
636 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
637 self.__ne(rhs)
638 }
639 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
640 self.__lt(rhs)
641 }
642
643 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
644 self.__le(rhs)
645 }
646 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
647 self.__gt(rhs)
648 }
649 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
650 self.__ge(rhs)
651 }
652 }
653 }
654 } else {
655 quote! {
656 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
657 type Output = Scalar<bool>;
658 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
659 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
660 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
661 lhs.__eq(rhs)
662 }
663 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
664 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
665 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
666 lhs.__ne(rhs)
667 }
668 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
669 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
670 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
671 lhs.__lt(rhs)
672 }
673
674 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
675 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
676 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
677 lhs.__le(rhs)
678 }
679 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
680 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
681 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
682 lhs.__gt(rhs)
683 }
684 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
685 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
686 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
687 lhs.__ge(rhs)
688 }
689 }
690 }
691 };
692 ret.extend(res);
693 }
694 }
695
696 ret.into()
697}
698
699#[proc_macro]
701pub fn impl_eval(_: TokenStream) -> TokenStream {
702 let mut ret = proc_macro2::TokenStream::new();
703
704 let types = [
705 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
706 "usize", "bf16",
707 ];
708
709 for lhs in types.iter() {
710 let lhs_type = TypeInfo::new(lhs);
711 let lhs_dtype = lhs_type.dtype;
712
713 let res = quote! {
714 impl Eval for #lhs_dtype {
715 type Output = bool;
716 #[inline(always)]
717 fn _is_nan(&self) -> bool {
718 self.__is_nan()
719 }
720 #[inline(always)]
721 fn _is_true(&self) -> bool {
722 self.__is_true()
723 }
724 #[inline(always)]
725 fn _is_inf(&self) -> bool {
726 self.__is_inf()
727 }
728 }
729 };
730 ret.extend(res);
731 }
732
733 ret.into()
734}
735
#[proc_macro]
/// Forwards the macro input to
/// `kernel_gen_helper::__gen_fast_reduce_simd_helper` and returns its output.
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}
741
#[proc_macro]
/// Forwards the macro input to
/// `kernel_gen_helper::__gen_fast_layernorm_simd_helper` and returns its
/// output.
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}
747
#[proc_macro]
/// Forwards the macro input to
/// `kernel_gen_helper::__gen_reduce_dim_not_include_simd_helper` and returns
/// its output.
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}
753
#[proc_macro]
/// Forwards the macro input to `conv2d::conv2d_microkernel_declare_const` and
/// returns its output.
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}
763
#[proc_macro]
/// Forwards the macro input to `conv2d::conv2d_microkernel_gen_inps` and
/// returns its output.
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}
769
#[proc_macro]
/// Forwards the macro input to `conv2d::conv2d_microkernel_gen_pad_inps` and
/// returns its output.
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}
775
#[proc_macro]
/// Forwards the macro input to `conv2d::pwconv2d_microkernel_gen_pad_inps`
/// and returns its output.
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}
781
#[proc_macro]
/// Forwards the macro input to `conv2d::dwconv2d_microkernel_gen_pad_inps`
/// and returns its output.
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}
787
#[proc_macro]
/// Forwards the macro input to `conv2d::conv2d_microkernel_gen_kernels` and
/// returns its output.
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}
793
#[proc_macro]
/// Forwards the macro input to `conv2d::conv2d_microkernel_gen_results` and
/// returns its output.
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}
799
#[proc_macro]
/// Forwards the macro input to `conv2d::dwconv2d_microkernel_gen_results` and
/// returns its output.
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}
805
#[proc_macro]
/// Forwards the macro input to `conv2d::maxpool2d_microkernel_gen_results`
/// and returns its output.
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}
812
813#[proc_macro_derive(Save, attributes(compress))]
815pub fn impl_save(input: TokenStream) -> TokenStream {
816 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
817 let name = &ast.ident;
818 let meta_name = format_ident!("{}Meta", name);
819
820 let visibility = &ast.vis;
821 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
822 let fields = match &ast.data {
823 syn::Data::Struct(s) => &s.fields,
824 _ => panic!("Save can only be derived for structs"),
825 };
826
827 let mut compressions = vec![];
828 let mut endians = vec![];
829 let mut compress_levels = vec![];
830
831 let meta_fields = fields
832 .iter()
833 .map(|f| {
834 let mut compression_algo = None;
835 let mut endian = None;
836 let mut level = None;
837
838 for attr in &f.attrs {
839 if attr.path().is_ident("compress") {
840 attr.parse_nested_meta(|meta| {
841 if meta.path.is_ident("algo") {
842 let value: syn::LitStr = meta.value()?.parse()?;
843 let algo = match value.value().as_str().to_lowercase().as_str() {
844 "gzip" => quote!(Gzip),
845 "deflate" => quote!(Deflate),
846 "zlib" => quote!(Zlib),
847 "none" => quote!(NoCompression),
848 _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
849 };
850 compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
851 } else if meta.path.is_ident("level") {
852 let value: syn::LitStr = meta.value()?.parse()?;
853 let tmp: u32 = value.value().parse().map_err(|e| {
854 syn::Error::new(value.span(), format!("Invalid level: {}", e))
855 })?;
856 level = Some(quote!(#tmp));
857 } else if meta.path.is_ident("endian") {
858 let value: syn::LitStr = meta.value()?.parse()?;
859 let tmp = match value.value().as_str() {
860 "native" => quote!(Native),
861 "little" => quote!(Little),
862 "big" => quote!(Big),
863 _ => panic!("Unsupported endianness, supported: native, little, big"),
864 };
865 endian = Some(quote!(hpt::Endian::#tmp));
866 }
867 Ok(())
868 })
869 .unwrap();
870 }
871 }
872 compressions.push(compression_algo);
873 endians.push(endian);
874 compress_levels.push(level);
875 let name = &f.ident;
876 let ty = &f.ty;
877 quote! {
878 pub #name: <#ty as Save>::Meta
879 }
880 })
881 .collect::<Vec<_>>();
882
883 let call_save = fields.iter().enumerate().map(|(idx, f)| {
884 let name = &f.ident;
885 let ty = &f.ty;
886 let ident = format_ident!("field_{}", idx);
887 let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
888 let endian = endians[idx].clone().unwrap_or(quote!(endian));
889 let level = compress_levels[idx].clone().unwrap_or(quote!(level));
890 if let Some(name) = name {
891 quote! {
892 let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
893 }
894 } else {
895 quote! {
896 let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
897 }
898 }
899 });
900
901 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
902 let name = &f.ident;
903 let ident = format_ident!("field_{}", idx);
904 quote! {
905 #name: #ident
906 }
907 });
908
909 let expanded = quote! {
910 #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
911 #[serde(crate = "hpt::serde")]
912 #visibility struct #meta_name #ty_generics #where_clause {
913 #(#meta_fields,)*
914 }
915 impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
916 type Meta = #meta_name #ty_generics;
917 fn __save(
918 data: &Self,
919 file: &mut std::fs::File,
920 len_so_far: &mut usize,
921 global_cnt: &mut usize,
922 compression_algo: hpt::CompressionAlgo,
923 endian: hpt::Endian,
924 level: u32,
925 ) -> std::io::Result<Self::Meta> {
926 #(#call_save)*
927 Ok(Self::Meta {
928 #(#construct_fields),*
929 })
930 }
931 }
932 };
933
934 expanded.into()
935}
936
937#[proc_macro_derive(Load)]
939pub fn impl_load(input: TokenStream) -> TokenStream {
940 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
941 let name = &ast.ident;
942 let meta_name = format_ident!("{}Meta", name);
943 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
944
945 let fields = match &ast.data {
946 syn::Data::Struct(s) => &s.fields,
947 _ => panic!("Load can only be derived for structs"),
948 };
949
950 let call_load = fields.iter().enumerate().map(|(idx, f)| {
951 let name = &f.ident;
952 let ident = format_ident!("field_{}", idx);
953 if let Some(name) = name {
954 quote! {
955 let #ident = self.#name.load(file)?;
956 }
957 } else {
958 quote! {
959 let #ident = self.#idx.load(file)?;
960 }
961 }
962 });
963
964 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
965 let name = &f.ident;
966 let ident = format_ident!("field_{}", idx);
967 quote! {
968 #name: #ident
969 }
970 });
971
972 let expanded = quote! {
973 impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
974 type Output = #name #ty_generics;
975 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
976 use hpt::MetaLoad;
977 #(#call_load)*
978 Ok(#name {
979 #(#construct_fields),*
980 })
981 }
982 }
983 impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
984 fn load(path: &str) -> std::io::Result<Self> {
985 use hpt::MetaLoad;
986 let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
987 let mut file = std::fs::File::open(path)?;
988 meta.load(&mut file)
989 }
990 }
991 };
992
993 expanded.into()
994}
995
996#[proc_macro_derive(FromSafeTensors, attributes(map))]
998pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
999 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
1000 let struct_name = &ast.ident;
1001 let fields = match &ast.data {
1002 syn::Data::Struct(s) => &s.fields,
1003 _ => panic!("FromSafeTensors can only be derived for structs"),
1004 };
1005 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
1006 let mut construct_fields = vec![];
1007 for (_, field) in fields.iter().enumerate() {
1008 let ty = &field.ty;
1009 let name = &field.ident;
1010 let mut value_construct = vec![];
1011 let mut from_construct = vec![];
1012 let mut params = vec![];
1013 let mut vec_len = None;
1014 for attr in &field.attrs {
1015 if attr.path().is_ident("map") {
1016 let mut path = None;
1017 let mut value = None;
1018 let mut tensor_name = None;
1019 let mut inner_type = None;
1020 attr.parse_nested_meta(|meta| {
1021 if meta.path.is_ident("path") {
1022 let value: syn::LitStr = meta.value()?.parse()?;
1023 path = Some(value.value());
1024 } else if meta.path.is_ident("value") {
1025 let val: syn::Expr = meta.value()?.parse()?;
1026 value = Some(val);
1027 } else if meta.path.is_ident("tensor_name") {
1028 let value: syn::LitStr = meta.value()?.parse()?;
1029 tensor_name = Some(value.value());
1030 } else if meta.path.is_ident("vec_len") {
1031 let value: syn::LitInt = meta.value()?.parse()?;
1032 vec_len = Some(value.base10_parse::<usize>().unwrap());
1033 } else if meta.path.is_ident("inner_type") {
1034 let value: syn::Ident = meta.value()?.parse()?;
1035 inner_type = Some(value);
1036 }
1037 Ok(())
1038 })
1039 .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
1040 params.push((path, value, tensor_name, vec_len, inner_type));
1041 }
1042 }
1043 let param_count = params.len();
1044 for (path, value, tensor_name, vec_len, inner_type) in params {
1045 if let Some(vec_len) = vec_len {
1046 let inner_type = inner_type.expect("inner_type is required for vec");
1047 if let Some(path) = path {
1048 from_construct.push(quote! {
1049 #path => {
1050 let mut vec = vec![];
1051 for i in 0..#vec_len {
1052 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1053 }
1054 vec
1055 }
1056 });
1057 } else {
1058 value_construct.push(quote! {
1059 {
1060 let mut vec = vec![];
1061 for i in 0..#vec_len {
1062 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1063 }
1064 vec
1065 }
1066 });
1067 }
1068 } else {
1069 match (path, value, tensor_name) {
1070 (None, None, Some(tensor_name)) => {
1071 value_construct.push(quote! {
1072 <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
1073 });
1074 }
1075 (None, Some(value), None) => {
1076 if param_count > 1 {
1077 panic!("value without path means generic assignment, there can only be one value without path");
1078 }
1079 value_construct.push(quote! {
1080 #value
1081 });
1082 }
1083 (Some(path), None, Some(tensor_name)) => {
1084 from_construct.push(quote! {
1085 #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
1086 });
1087 }
1088 (Some(path), Some(value), None) => {
1089 from_construct.push(quote! {
1090 #path => #value,
1091 });
1092 }
1093
1094 (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
1095 panic!("value and tensor_name cannot be used together");
1096 }
1097 (Some(_), None, None) | (None, None, None) => {
1098 panic!("path and value are not present");
1099 }
1100 }
1101 }
1102 }
1103 if !value_construct.is_empty() {
1104 construct_fields.push(quote! {
1105 #name: #(#value_construct)*
1106 });
1107 } else if !from_construct.is_empty() {
1108 construct_fields.push(quote! {
1109 #name: match path {
1110 #(#from_construct)*
1111 _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
1112 }
1113 });
1114 } else {
1115 construct_fields.push(quote! {
1116 #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
1117 });
1118 }
1119 }
1120 let expanded = quote! {
1121 impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
1122 fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
1123 Self {
1124 #(#construct_fields),*
1125 }
1126 }
1127 }
1128 };
1129 expanded.into()
1136}