1#![deny(missing_docs)]
16#[cfg(feature = "cuda")]
17use crate::binary_float_out::impl_cuda_float_out_binary;
18use binary_float_out::impl_float_out_binary;
19use float_unary::impl_float_out_unary;
20use from_scalar::__impl_from_scalar;
21use kernel_gen_helper::{
22 __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
23 __gen_reduce_dim_not_include_simd_helper,
24};
25use normal_out::__impl_normal_out_binary;
26use proc_macro::TokenStream;
27use scalar_convert::__impl_scalar_convert;
28use simd_bitwise::impl_simd_bitwise_out;
29use simd_convert::__impl_simd_convert;
30use simd_float_out_binary::{
31 impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
32 impl_simd_binary_out_float_rhs_scalar,
33};
34use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
35use syn::{parse, parse_macro_input, Expr, Token};
36mod binary_float_out;
37mod conv2d;
38mod float_unary;
39mod from_scalar;
40mod into_cuda_scalar;
41mod into_scalar;
42mod into_vec;
43mod kernel_gen_helper;
44mod normal_out;
45mod normal_out_unary;
46mod scalar_convert;
47mod simd_bitwise;
48
49mod simd_cmp;
50mod simd_convert;
51mod simd_eval;
52mod simd_float_out_binary;
53mod simd_float_out_unary;
54mod simd_normal_out;
55mod simd_normal_unary;
56mod type_utils;
57
58use crate::simd_cmp::impl_simd_cmp;
59use crate::simd_normal_out::impl_simd_normal_out;
60use proc_macro2::{TokenStream as TokenStream2, TokenTree};
61use quote::{format_ident, quote};
62use type_utils::TypeInfo;
63
// `NUM_REG` is selected at compile time from the enabled target features.
// NOTE(review): presumably the number of vector registers the code generators
// may assume for the target — confirm against the modules that read it.
//
// The three predicates are mutually exclusive. `avx512f` implies `avx2`, so
// without the explicit `not(avx512f)` below an AVX-512 build would define
// `NUM_REG` twice and fail to compile.

/// Value used when AVX2 is enabled but AVX-512F is not.
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
/// Value used for SSE-only x86 builds and 32-bit ARM.
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2"),
    not(target_feature = "avx512f"),
    not(target_arch = "aarch64")
))]
const NUM_REG: usize = 8;
/// Value used when AVX-512F is enabled or the target is AArch64.
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;
74
/// One parsed `start:end:step` selection.
///
/// Every component is optional; a component is `None` when the corresponding
/// position in the input was left empty.
struct SelectionParser {
    /// Expression found before the first `:` (or the only expression when no `:` is present).
    start: Option<Expr>,
    /// Expression found after the first `:`.
    end: Option<Expr>,
    /// Expression found after the optional second `:`.
    step: Option<Expr>,
}
80
/// The raw, still unparsed selections of a `match_selection!` invocation.
struct Selections {
    /// One token stream per top-level comma-separated segment of the macro input.
    selections: Vec<TokenStream>,
}
84
85impl parse::Parse for SelectionParser {
86 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
87 let mut start: Option<Expr> = None;
88 let mut end: Option<Expr> = None;
89 let mut step: Option<Expr> = None;
90 if input.peek(syn::Lit)
91 || input.peek(syn::Ident)
92 || input.peek(syn::token::Paren)
93 || input.peek(Token![-])
94 {
95 start = Some(input.parse::<Expr>()?);
96 }
97 if input.peek(Token![:]) {
98 input.parse::<Token![:]>()?;
99 } else if input.is_empty() {
100 return Ok(Self { start, end, step });
101 } else {
102 return Err(syn::Error::new(
103 input.span(),
104 "unexpected token, expected `:`, Int or Ident",
105 ));
106 }
107 if input.peek(syn::Lit)
108 || input.peek(syn::Ident)
109 || input.peek(syn::token::Paren)
110 || input.peek(Token![-])
111 {
112 end = Some(input.parse::<Expr>()?);
113 }
114 if input.peek(Token![:]) {
115 input.parse::<Token![:]>()?;
116 }
117 if input.peek(syn::Lit)
118 || input.peek(syn::Ident)
119 || input.peek(syn::token::Paren)
120 || input.peek(Token![-])
121 {
122 step = Some(input.parse::<Expr>()?);
123 }
124 Ok(Self { start, end, step })
125 }
126}
127
128impl parse::Parse for Selections {
129 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
130 let mut selections: Vec<TokenStream> = vec![];
131 let mut tokenstream = TokenStream2::new();
132 while !input.is_empty() {
133 let lookahead = input.lookahead1();
134 if lookahead.peek(Token![,]) {
135 selections.push(tokenstream.into());
136 tokenstream = TokenStream2::new();
137 input.parse::<Token![,]>()?;
138 } else {
139 let t = input.parse::<TokenTree>()?;
140 tokenstream.extend(quote!(#t));
141 }
142 }
143 selections.push(tokenstream.into());
144 Ok(Self { selections })
145 }
146}
147
148#[proc_macro]
150pub fn match_selection(input: TokenStream) -> TokenStream {
151 let res: Selections = parse_macro_input!(input as Selections);
152 let mut slices: Vec<SelectionParser> = vec![];
153 for x in res.selections {
154 slices.push(parse_macro_input!(x as SelectionParser));
155 }
156 let mut ret_stream = TokenStream2::new();
157 let len = slices.len();
158 for (idx, x) in slices.into_iter().enumerate() {
159 match (x.start, x.end, x.step) {
160 (None, None, None) => {
161 ret_stream.extend(quote!(Slice::Full));
162 }
163 (None, None, Some(step)) => {
164 ret_stream.extend(quote!(Slice::StepByFullRange(#step)));
165 }
166 (None, Some(end), None) => {
167 ret_stream.extend(quote!(Slice::RangeTo(#end)));
168 }
169 (None, Some(end), Some(step)) => {
170 ret_stream.extend(quote!(Slice::StepByRangeTo((#end, #step))));
171 }
172 (Some(start), None, None) => {
173 ret_stream.extend(quote!(Slice::From(#start)));
174 }
175 (Some(start), None, Some(step)) => {
176 ret_stream.extend(quote!(Slice::StepByRangeFrom((#start, #step))));
177 }
178 (Some(start), Some(end), None) => {
179 ret_stream.extend(quote!(Slice::Range((#start, #end))));
180 }
181 (Some(start), Some(end), Some(step)) => {
182 ret_stream.extend(quote!(Slice::StepByRangeFromTo((#start, #end, #step))));
183 }
184 }
185 if idx != len - 1 {
186 ret_stream.extend(quote!(,));
187 }
188 }
189 quote!([#ret_stream]).into()
190}
191
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `binary_float_out::impl_float_out_binary`.
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}
197
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `binary_float_out::impl_cuda_float_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}
204
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_float_out_binary::impl_simd_binary_out_float`.
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}
210
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_float_out_binary::impl_simd_binary_out_float_rhs_scalar`.
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}
216
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_float_out_binary::impl_simd_binary_out_float_lhs_scalar`.
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}
222
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `float_unary::impl_float_out_unary`.
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}
228
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `float_unary::impl_cuda_float_out_unary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}
235
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_float_out_unary::impl_float_out_unary`.
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}
241
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_eval::impl_simd_eval`.
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}
247
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_bitwise::impl_simd_bitwise_out`.
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}
253
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `normal_out::__impl_normal_out_binary`.
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}
259
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `normal_out::__impl_cuda_normal_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}
266
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `normal_out_unary::__impl_normal_out_unary`.
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}
272
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `normal_out_unary::__impl_normal_out_unary_cuda`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}
279
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_normal_unary::impl_simd_normal_out_unary`.
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}
285
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_normal_out::impl_simd_normal_out`.
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}
291
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_normal_out::impl_simd_normal_out_with_rhs_scalar`.
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}
297
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_normal_out::impl_simd_normal_out_with_lhs_scalar`.
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}
303
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_convert::__impl_simd_convert`.
#[proc_macro]
pub fn impl_simd_convert(_: TokenStream) -> TokenStream {
    __impl_simd_convert()
}
309
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `scalar_convert::__impl_scalar_convert`.
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}
315
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `from_scalar::__impl_from_scalar`.
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}
321
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `simd_cmp::impl_simd_cmp`.
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}
327
/// Function-like macro that ignores its input and expands to the tokens
/// produced by `into_vec::into_vec`.
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}
333
/// Function-like macro that ignores its input and expands to the result of
/// `into_cuda_scalar::__impl_into_cuda_scalar`, converted into the
/// `TokenStream` return type.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}
340
/// Function-like macro that ignores its input and expands to the result of
/// `into_scalar::__impl_into_scalar`, converted into the `TokenStream`
/// return type.
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}
346
347#[proc_macro]
349pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
350 let mut ret = proc_macro2::TokenStream::new();
351
352 let types = [
353 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
354 ];
355
356 for lhs in types.iter() {
357 for rhs in types.iter() {
358 let lhs_type = TypeInfo::new(lhs);
359 let rhs_type = TypeInfo::new(rhs);
360 let lhs_dtype = lhs_type.dtype;
361 let rhs_dtype = rhs_type.dtype;
362 let res = if lhs_dtype == rhs_dtype {
363 quote! {
364 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
365 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
366 #[inline(always)]
367 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
368 self.__bitand(rhs)
369 }
370 #[inline(always)]
371 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
372 self.__bitor(rhs)
373 }
374 #[inline(always)]
375 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
376 self.__bitxor(rhs)
377 }
378 #[inline(always)]
379 fn _not(self) -> Self::Output {
380 self.__not()
381 }
382 #[inline(always)]
383 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
384 self.__shl(rhs)
385 }
386 #[inline(always)]
387 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
388 self.__shr(rhs)
389 }
390 }
391 }
392 } else {
393 quote! {
394 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
395 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
396 #[inline(always)]
397 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
398 let lhs: Self::Output = self.cast();
399 let rhs: Self::Output = rhs.cast();
400 lhs.__bitand(rhs)
401 }
402 #[inline(always)]
403 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
404 let lhs: Self::Output = self.cast();
405 let rhs: Self::Output = rhs.cast();
406 lhs.__bitor(rhs)
407 }
408 #[inline(always)]
409 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
410 let lhs: Self::Output = self.cast();
411 let rhs: Self::Output = rhs.cast();
412 lhs.__bitxor(rhs)
413 }
414 #[inline(always)]
415 fn _not(self) -> Self::Output {
416 let lhs: Self::Output = self.cast();
417 lhs.__not()
418 }
419 #[inline(always)]
420 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
421 let lhs: Self::Output = self.cast();
422 let rhs: Self::Output = rhs.cast();
423 lhs.__shl(rhs)
424 }
425 #[inline(always)]
426 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
427 let lhs: Self::Output = self.cast();
428 let rhs: Self::Output = rhs.cast();
429 lhs.__shr(rhs)
430 }
431 }
432 }
433 };
434 ret.extend(res);
435 }
436 }
437
438 ret.into()
439}
440
441#[proc_macro]
443pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
444 let mut ret = proc_macro2::TokenStream::new();
445
446 let types = [
447 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
448 ];
449
450 for lhs in types.iter() {
451 for rhs in types.iter() {
452 let lhs_type = TypeInfo::new(lhs);
453 let rhs_type = TypeInfo::new(rhs);
454 let lhs_dtype = lhs_type.dtype;
455 let rhs_dtype = rhs_type.dtype;
456 let res = if lhs_dtype == rhs_dtype {
457 quote! {
458 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
459 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
460 #[inline(always)]
461 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
462 self.__bitand(rhs)
463 }
464 #[inline(always)]
465 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
466 self.__bitor(rhs)
467 }
468 #[inline(always)]
469 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
470 self.__bitxor(rhs)
471 }
472 #[inline(always)]
473 fn _not(self) -> Self::Output {
474 self.__not()
475 }
476 #[inline(always)]
477 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
478 self.__shl(rhs)
479 }
480 #[inline(always)]
481 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
482 self.__shr(rhs)
483 }
484 }
485 }
486 } else {
487 quote! {
488 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
489 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
490 #[inline(always)]
491 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
492 let lhs: Self::Output = self.cast();
493 let rhs: Self::Output = rhs.cast();
494 lhs.__bitand(rhs)
495 }
496 #[inline(always)]
497 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
498 let lhs: Self::Output = self.cast();
499 let rhs: Self::Output = rhs.cast();
500 lhs.__bitor(rhs)
501 }
502 #[inline(always)]
503 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
504 let lhs: Self::Output = self.cast();
505 let rhs: Self::Output = rhs.cast();
506 lhs.__bitxor(rhs)
507 }
508 #[inline(always)]
509 fn _not(self) -> Self::Output {
510 let lhs: Self::Output = self.cast();
511 lhs.__not()
512 }
513 #[inline(always)]
514 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
515 let lhs: Self::Output = self.cast();
516 let rhs: Self::Output = rhs.cast();
517 lhs.__shl(rhs)
518 }
519 #[inline(always)]
520 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
521 let lhs: Self::Output = self.cast();
522 let rhs: Self::Output = rhs.cast();
523 lhs.__shr(rhs)
524 }
525 }
526 }
527 };
528 ret.extend(res);
529 }
530 }
531
532 ret.into()
533}
534
535#[proc_macro]
537pub fn impl_cmp(_: TokenStream) -> TokenStream {
538 let mut ret = proc_macro2::TokenStream::new();
539
540 let types = [
541 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
542 "usize", "bf16",
543 ];
544
545 for lhs in types.iter() {
546 for rhs in types.iter() {
547 let lhs_type = TypeInfo::new(lhs);
548 let rhs_type = TypeInfo::new(rhs);
549 let lhs_dtype = lhs_type.dtype;
550 let rhs_dtype = rhs_type.dtype;
551 let res = if lhs_dtype == rhs_dtype {
552 quote! {
553 impl Cmp<#rhs_dtype> for #lhs_dtype {
554 type Output = bool;
555 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
556 self == rhs
557 }
558 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
559 self != rhs
560 }
561 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
562 self < rhs
563 }
564
565 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
566 self <= rhs
567 }
568 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
569 self > rhs
570 }
571 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
572 self >= rhs
573 }
574 }
575 }
576 } else {
577 quote! {
578 impl Cmp<#rhs_dtype> for #lhs_dtype {
579 type Output = bool;
580 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
581 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
582 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
583 lhs == rhs
584 }
585 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
586 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
587 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
588 lhs != rhs
589 }
590 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
591 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
592 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
593 lhs < rhs
594 }
595
596 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
597 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
598 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
599 lhs <= rhs
600 }
601 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
602 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
603 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
604 lhs > rhs
605 }
606 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
607 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
608 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
609 lhs >= rhs
610 }
611 }
612 }
613 };
614 ret.extend(res);
615 }
616 }
617
618 ret.into()
619}
620
621#[proc_macro]
623pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
624 let mut ret = proc_macro2::TokenStream::new();
625
626 let types = [
627 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
628 "usize", "bf16",
629 ];
630
631 for lhs in types.iter() {
632 for rhs in types.iter() {
633 let lhs_type = TypeInfo::new(lhs);
634 let rhs_type = TypeInfo::new(rhs);
635 let lhs_dtype = lhs_type.dtype;
636 let rhs_dtype = rhs_type.dtype;
637 let res = if lhs_dtype == rhs_dtype {
638 quote! {
639 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
640 type Output = Scalar<bool>;
641 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
642 self.__eq(rhs)
643 }
644 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
645 self.__ne(rhs)
646 }
647 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
648 self.__lt(rhs)
649 }
650
651 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
652 self.__le(rhs)
653 }
654 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
655 self.__gt(rhs)
656 }
657 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
658 self.__ge(rhs)
659 }
660 }
661 }
662 } else {
663 quote! {
664 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
665 type Output = Scalar<bool>;
666 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
667 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
668 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
669 lhs.__eq(rhs)
670 }
671 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
672 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
673 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
674 lhs.__ne(rhs)
675 }
676 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
677 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
678 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
679 lhs.__lt(rhs)
680 }
681
682 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
683 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
684 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
685 lhs.__le(rhs)
686 }
687 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
688 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
689 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
690 lhs.__gt(rhs)
691 }
692 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
693 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
694 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
695 lhs.__ge(rhs)
696 }
697 }
698 }
699 };
700 ret.extend(res);
701 }
702 }
703
704 ret.into()
705}
706
707#[proc_macro]
709pub fn impl_eval(_: TokenStream) -> TokenStream {
710 let mut ret = proc_macro2::TokenStream::new();
711
712 let types = [
713 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
714 "usize", "bf16",
715 ];
716
717 for lhs in types.iter() {
718 let lhs_type = TypeInfo::new(lhs);
719 let lhs_dtype = lhs_type.dtype;
720
721 let res = quote! {
722 impl Eval for #lhs_dtype {
723 type Output = bool;
724 #[inline(always)]
725 fn _is_nan(&self) -> bool {
726 self.__is_nan()
727 }
728 #[inline(always)]
729 fn _is_true(&self) -> bool {
730 self.__is_true()
731 }
732 #[inline(always)]
733 fn _is_inf(&self) -> bool {
734 self.__is_inf()
735 }
736 }
737 };
738 ret.extend(res);
739 }
740
741 ret.into()
742}
743
/// Function-like macro that forwards its input tokens unchanged to
/// `kernel_gen_helper::__gen_fast_reduce_simd_helper` and returns its result.
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}
749
/// Function-like macro that forwards its input tokens unchanged to
/// `kernel_gen_helper::__gen_fast_layernorm_simd_helper` and returns its result.
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}
755
/// Function-like macro that forwards its input tokens unchanged to
/// `kernel_gen_helper::__gen_reduce_dim_not_include_simd_helper` and returns
/// its result.
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}
761
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::conv2d_microkernel_declare_const` and returns its result.
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}
771
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::conv2d_microkernel_gen_inps` and returns its result.
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}
777
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::conv2d_microkernel_gen_pad_inps` and returns its result.
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}
783
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::pwconv2d_microkernel_gen_pad_inps` and returns its result.
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}
789
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::dwconv2d_microkernel_gen_pad_inps` and returns its result.
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}
795
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::conv2d_microkernel_gen_kernels` and returns its result.
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}
801
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::conv2d_microkernel_gen_results` and returns its result.
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}
807
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::dwconv2d_microkernel_gen_results` and returns its result.
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}
813
/// Function-like macro that forwards its input tokens unchanged to
/// `conv2d::maxpool2d_microkernel_gen_results` and returns its result.
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}
820
821#[proc_macro_derive(Save, attributes(compress))]
823pub fn impl_save(input: TokenStream) -> TokenStream {
824 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
825 let name = &ast.ident;
826 let meta_name = format_ident!("{}Meta", name);
827
828 let visibility = &ast.vis;
829 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
830 let fields = match &ast.data {
831 syn::Data::Struct(s) => &s.fields,
832 _ => panic!("Save can only be derived for structs"),
833 };
834
835 let mut compressions = vec![];
836 let mut endians = vec![];
837 let mut compress_levels = vec![];
838
839 let meta_fields = fields
840 .iter()
841 .map(|f| {
842 let mut compression_algo = None;
843 let mut endian = None;
844 let mut level = None;
845
846 for attr in &f.attrs {
847 if attr.path().is_ident("compress") {
848 attr.parse_nested_meta(|meta| {
849 if meta.path.is_ident("algo") {
850 let value: syn::LitStr = meta.value()?.parse()?;
851 let algo = match value.value().as_str().to_lowercase().as_str() {
852 "gzip" => quote!(Gzip),
853 "deflate" => quote!(Deflate),
854 "zlib" => quote!(Zlib),
855 "none" => quote!(NoCompression),
856 _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
857 };
858 compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
859 } else if meta.path.is_ident("level") {
860 let value: syn::LitStr = meta.value()?.parse()?;
861 let tmp: u32 = value.value().parse().map_err(|e| {
862 syn::Error::new(value.span(), format!("Invalid level: {}", e))
863 })?;
864 level = Some(quote!(#tmp));
865 } else if meta.path.is_ident("endian") {
866 let value: syn::LitStr = meta.value()?.parse()?;
867 let tmp = match value.value().as_str() {
868 "native" => quote!(Native),
869 "little" => quote!(Little),
870 "big" => quote!(Big),
871 _ => panic!("Unsupported endianness, supported: native, little, big"),
872 };
873 endian = Some(quote!(hpt::Endian::#tmp));
874 }
875 Ok(())
876 })
877 .unwrap();
878 }
879 }
880 compressions.push(compression_algo);
881 endians.push(endian);
882 compress_levels.push(level);
883 let name = &f.ident;
884 let ty = &f.ty;
885 quote! {
886 pub #name: <#ty as Save>::Meta
887 }
888 })
889 .collect::<Vec<_>>();
890
891 let call_save = fields.iter().enumerate().map(|(idx, f)| {
892 let name = &f.ident;
893 let ty = &f.ty;
894 let ident = format_ident!("field_{}", idx);
895 let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
896 let endian = endians[idx].clone().unwrap_or(quote!(endian));
897 let level = compress_levels[idx].clone().unwrap_or(quote!(level));
898 if let Some(name) = name {
899 quote! {
900 let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
901 }
902 } else {
903 quote! {
904 let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
905 }
906 }
907 });
908
909 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
910 let name = &f.ident;
911 let ident = format_ident!("field_{}", idx);
912 quote! {
913 #name: #ident
914 }
915 });
916
917 let expanded = quote! {
918 #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
919 #visibility struct #meta_name #ty_generics #where_clause {
920 #(#meta_fields,)*
921 }
922 impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
923 type Meta = #meta_name #ty_generics;
924 fn __save(
925 data: &Self,
926 file: &mut std::fs::File,
927 len_so_far: &mut usize,
928 global_cnt: &mut usize,
929 compression_algo: hpt::CompressionAlgo,
930 endian: hpt::Endian,
931 level: u32,
932 ) -> std::io::Result<Self::Meta> {
933 #(#call_save)*
934 Ok(Self::Meta {
935 #(#construct_fields),*
936 })
937 }
938 }
939 };
940
941 expanded.into()
942}
943
944#[proc_macro_derive(Load)]
946pub fn impl_load(input: TokenStream) -> TokenStream {
947 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
948 let name = &ast.ident;
949 let meta_name = format_ident!("{}Meta", name);
950 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
951
952 let fields = match &ast.data {
953 syn::Data::Struct(s) => &s.fields,
954 _ => panic!("Load can only be derived for structs"),
955 };
956
957 let call_load = fields.iter().enumerate().map(|(idx, f)| {
958 let name = &f.ident;
959 let ident = format_ident!("field_{}", idx);
960 if let Some(name) = name {
961 quote! {
962 let #ident = self.#name.load(file)?;
963 }
964 } else {
965 quote! {
966 let #ident = self.#idx.load(file)?;
967 }
968 }
969 });
970
971 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
972 let name = &f.ident;
973 let ident = format_ident!("field_{}", idx);
974 quote! {
975 #name: #ident
976 }
977 });
978
979 let expanded = quote! {
980 impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
981 type Output = #name #ty_generics;
982 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
983 use hpt::MetaLoad;
984 #(#call_load)*
985 Ok(#name {
986 #(#construct_fields),*
987 })
988 }
989 }
990 impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
991 fn load(path: &str) -> std::io::Result<Self> {
992 use hpt::MetaLoad;
993 let meta = hpt::parse_header_compressed::<Self>(path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
994 let mut file = std::fs::File::open(path)?;
995 meta.load(&mut file)
996 }
997 }
998 };
999
1000 expanded.into()
1001}
1002
1003#[proc_macro_derive(FromSafeTensors, attributes(map))]
1005pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
1006 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
1007 let struct_name = &ast.ident;
1008 let fields = match &ast.data {
1009 syn::Data::Struct(s) => &s.fields,
1010 _ => panic!("FromSafeTensors can only be derived for structs"),
1011 };
1012 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
1013 let mut construct_fields = vec![];
1014 for (_, field) in fields.iter().enumerate() {
1015 let ty = &field.ty;
1016 let name = &field.ident;
1017 let mut value_construct = vec![];
1018 let mut from_construct = vec![];
1019 let mut params = vec![];
1020 let mut vec_len = None;
1021 for attr in &field.attrs {
1022 if attr.path().is_ident("map") {
1023 let mut path = None;
1024 let mut value = None;
1025 let mut tensor_name = None;
1026 let mut inner_type = None;
1027 attr.parse_nested_meta(|meta| {
1028 if meta.path.is_ident("path") {
1029 let value: syn::LitStr = meta.value()?.parse()?;
1030 path = Some(value.value());
1031 } else if meta.path.is_ident("value") {
1032 let val: syn::Expr = meta.value()?.parse()?;
1033 value = Some(val);
1034 } else if meta.path.is_ident("tensor_name") {
1035 let value: syn::LitStr = meta.value()?.parse()?;
1036 tensor_name = Some(value.value());
1037 } else if meta.path.is_ident("vec_len") {
1038 let value: syn::LitInt = meta.value()?.parse()?;
1039 vec_len = Some(value.base10_parse::<usize>().unwrap());
1040 } else if meta.path.is_ident("inner_type") {
1041 let value: syn::Ident = meta.value()?.parse()?;
1042 inner_type = Some(value);
1043 }
1044 Ok(())
1045 })
1046 .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
1047 params.push((path, value, tensor_name, vec_len, inner_type));
1048 }
1049 }
1050 let param_count = params.len();
1051 for (path, value, tensor_name, vec_len, inner_type) in params {
1052 if let Some(vec_len) = vec_len {
1053 let inner_type = inner_type.expect("inner_type is required for vec");
1054 if let Some(path) = path {
1055 from_construct.push(quote! {
1056 #path => {
1057 let mut vec = vec![];
1058 for i in 0..#vec_len {
1059 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1060 }
1061 vec
1062 }
1063 });
1064 } else {
1065 value_construct.push(quote! {
1066 {
1067 let mut vec = vec![];
1068 for i in 0..#vec_len {
1069 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1070 }
1071 vec
1072 }
1073 });
1074 }
1075 } else {
1076 match (path, value, tensor_name) {
1077 (None, None, Some(tensor_name)) => {
1078 value_construct.push(quote! {
1079 <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
1080 });
1081 }
1082 (None, Some(value), None) => {
1083 if param_count > 1 {
1084 panic!("value without path means generic assignment, there can only be one value without path");
1085 }
1086 value_construct.push(quote! {
1087 #value
1088 });
1089 }
1090 (Some(path), None, Some(tensor_name)) => {
1091 from_construct.push(quote! {
1092 #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
1093 });
1094 }
1095 (Some(path), Some(value), None) => {
1096 from_construct.push(quote! {
1097 #path => #value,
1098 });
1099 }
1100
1101 (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
1102 panic!("value and tensor_name cannot be used together");
1103 }
1104 (Some(_), None, None) | (None, None, None) => {
1105 panic!("path and value are not present");
1106 }
1107 }
1108 }
1109 }
1110 if !value_construct.is_empty() {
1111 construct_fields.push(quote! {
1112 #name: #(#value_construct)*
1113 });
1114 } else if !from_construct.is_empty() {
1115 construct_fields.push(quote! {
1116 #name: match path {
1117 #(#from_construct)*
1118 _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
1119 }
1120 });
1121 } else {
1122 construct_fields.push(quote! {
1123 #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
1124 });
1125 }
1126 }
1127 let expanded = quote! {
1128 impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
1129 fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
1130 Self {
1131 #(#construct_fields),*
1132 }
1133 }
1134 }
1135 };
1136 expanded.into()
1143}