1#![deny(missing_docs)]
16#[cfg(feature = "cuda")]
17use crate::binary_float_out::impl_cuda_float_out_binary;
18use binary_float_out::impl_float_out_binary;
19use float_unary::impl_float_out_unary;
20use from_scalar::__impl_from_scalar;
21use kernel_gen_helper::{
22 __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
23 __gen_reduce_dim_not_include_simd_helper,
24};
25use normal_out::__impl_normal_out_binary;
26use proc_macro::TokenStream;
27use scalar_convert::__impl_scalar_convert;
28use simd_bitwise::impl_simd_bitwise_out;
29use simd_float_out_binary::{
30 impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
31 impl_simd_binary_out_float_rhs_scalar,
32};
33use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
34use syn::{parse, parse_macro_input, Expr, Token};
35mod binary_float_out;
36mod conv2d;
37mod float_unary;
38mod from_scalar;
39mod into_cuda_scalar;
40mod into_scalar;
41mod into_vec;
42mod kernel_gen_helper;
43mod normal_out;
44mod normal_out_unary;
45mod scalar_convert;
46mod simd_bitwise;
47
48mod simd_cmp;
49mod simd_eval;
50mod simd_float_out_binary;
51mod simd_float_out_unary;
52mod simd_normal_out;
53mod simd_normal_unary;
54mod type_utils;
55
56use crate::simd_cmp::impl_simd_cmp;
57use crate::simd_normal_out::impl_simd_normal_out;
58use proc_macro2::{TokenStream as TokenStream2, TokenTree};
59use quote::{format_ident, quote};
60use type_utils::TypeInfo;
61
// Register-count constant chosen from the features of the target this crate is built for
// (presumably the number of SIMD vector registers -- confirm against the users of `NUM_REG`).
//
// The three predicates are mutually exclusive: a build with both `avx2` and `avx512f`
// enabled would otherwise define `NUM_REG` twice and fail with a duplicate-definition error.
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;
#[cfg(all(
    target_feature = "avx2",
    not(any(target_feature = "avx512f", target_arch = "aarch64"))
))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2"),
    not(any(target_feature = "avx512f", target_arch = "aarch64"))
))]
const NUM_REG: usize = 8;
72
/// Parsed form of one slice item handed to `select!`, written as `start:end:step`
/// with every component optional, or as a bare `..`.
struct SelectionParser {
    /// Expression parsed before the first `:`, if one was present.
    start: Option<Expr>,
    /// Expression parsed after the first `:`, if one was present.
    end: Option<Expr>,
    /// Expression parsed after the second `:`, if one was present.
    step: Option<Expr>,
    /// `true` when the item was a bare `..`.
    skip: bool,
}
79
/// The raw input of `select!` split at top-level commas.
struct Selections {
    /// One unparsed token stream per comma-separated item, in input order.
    selections: Vec<TokenStream>,
}
83
84impl parse::Parse for SelectionParser {
85 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
86 let mut start: Option<Expr> = None;
87 let mut end: Option<Expr> = None;
88 let mut step: Option<Expr> = None;
89 if input.peek(Token![..]) {
90 input.parse::<Token![..]>()?;
91 return Ok(Self {
92 start,
93 end,
94 step,
95 skip: true,
96 });
97 }
98 if input.peek(syn::Lit)
99 || input.peek(syn::Ident)
100 || input.peek(syn::token::Paren)
101 || input.peek(Token![-])
102 {
103 start = Some(input.parse::<Expr>()?);
104 }
105 if input.peek(Token![:]) {
106 input.parse::<Token![:]>()?;
107 } else if input.is_empty() {
108 return Ok(Self {
109 start,
110 end,
111 step,
112 skip: false,
113 });
114 } else {
115 return Err(syn::Error::new(
116 input.span(),
117 "unexpected token, expected `:`, Int or Ident",
118 ));
119 }
120 if input.peek(syn::Lit)
121 || input.peek(syn::Ident)
122 || input.peek(syn::token::Paren)
123 || input.peek(Token![-])
124 {
125 end = Some(input.parse::<Expr>()?);
126 }
127 if input.peek(Token![:]) {
128 input.parse::<Token![:]>()?;
129 }
130 if input.peek(syn::Lit)
131 || input.peek(syn::Ident)
132 || input.peek(syn::token::Paren)
133 || input.peek(Token![-])
134 {
135 step = Some(input.parse::<Expr>()?);
136 }
137 Ok(Self {
138 start,
139 end,
140 step,
141 skip: false,
142 })
143 }
144}
145
146impl parse::Parse for Selections {
147 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
148 let mut selections: Vec<TokenStream> = vec![];
149
150 while !input.is_empty() {
151 let mut item_tokens = TokenStream2::new();
152 while !input.is_empty() && !input.peek(Token![,]) {
153 let token = input.parse::<TokenTree>()?;
154 item_tokens.extend(quote!(#token));
155 }
156
157 selections.push(item_tokens.into());
158
159 if input.peek(Token![,]) {
160 input.parse::<Token![,]>()?;
161 }
162 }
163
164 Ok(Self { selections })
165 }
166}
167
168#[proc_macro]
170pub fn select(input: TokenStream) -> TokenStream {
171 let res: Selections = parse_macro_input!(input as Selections);
172 let mut slices: Vec<SelectionParser> = vec![];
173 for x in res.selections {
174 slices.push(parse_macro_input!(x as SelectionParser));
175 }
176 let mut ret_stream = TokenStream2::new();
177 let len = slices.len();
178 let mut skipped = false;
179 for (idx, x) in slices.into_iter().enumerate() {
180 if x.skip {
181 if skipped {
182 return syn::Error::new(
183 proc_macro2::Span::call_site(),
184 "unexpected token, slicing only support `..` once",
185 )
186 .to_compile_error()
187 .into();
188 }
189 ret_stream.extend(quote!((0, 0, 0x7FFFFFFFFFFFFFFF)));
190 skipped = true;
191 if idx != len - 1 {
192 ret_stream.extend(quote!(,));
193 }
194 continue;
195 }
196 match (x.start, x.end, x.step) {
197 (None, None, None) => {
198 ret_stream.extend(quote!(((0, 0x7FFFFFFFFFFFFFFF, 1))));
199 }
200 (None, None, Some(step)) => {
201 ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, #step)));
202 }
203 (None, Some(end), None) => {
204 ret_stream.extend(quote!((0, #end, 1)));
205 }
206 (None, Some(end), Some(step)) => {
207 ret_stream.extend(quote!((0, #end, #step)));
208 }
209 (Some(start), None, None) => {
210 ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, 1)));
211 }
212 (Some(start), None, Some(step)) => {
213 ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, #step)));
214 }
215 (Some(start), Some(end), None) => {
216 ret_stream.extend(quote!((#start, #end, 1)));
217 }
218 (Some(start), Some(end), Some(step)) => {
219 ret_stream.extend(quote!((#start, #end, #step)));
220 }
221 }
222 if idx != len - 1 {
223 ret_stream.extend(quote!(,));
224 }
225 }
226 quote!([#ret_stream]).into()
227}
228
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_float_out_binary`.
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}
234
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_cuda_float_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}
241
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_binary_out_float`.
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}
247
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_binary_out_float_rhs_scalar`.
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}
253
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_binary_out_float_lhs_scalar`.
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}
259
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_float_out_unary` from the `float_unary` module.
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}
265
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `float_unary::impl_cuda_float_out_unary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}
272
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `simd_float_out_unary::impl_float_out_unary`.
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}
278
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `simd_eval::impl_simd_eval`.
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}
284
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_bitwise_out`.
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}
290
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `__impl_normal_out_binary`.
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}
296
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `normal_out::__impl_cuda_normal_out_binary`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}
303
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `normal_out_unary::__impl_normal_out_unary`.
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}
309
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `normal_out_unary::__impl_normal_out_unary_cuda`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}
316
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `simd_normal_unary::impl_simd_normal_out_unary`.
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}
322
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_normal_out`.
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}
328
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_normal_out_with_rhs_scalar`.
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}
334
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_normal_out_with_lhs_scalar`.
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}
340
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `__impl_scalar_convert`.
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}
346
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `__impl_from_scalar`.
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}
352
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `impl_simd_cmp`.
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}
358
/// Function-like macro that ignores its input and expands to the tokens
/// returned by `into_vec::into_vec`.
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}
364
/// Function-like macro that ignores its input and expands to the result of
/// `into_cuda_scalar::__impl_into_cuda_scalar`, converted into a `TokenStream`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}
371
/// Function-like macro that ignores its input and expands to the result of
/// `into_scalar::__impl_into_scalar`, converted into a `TokenStream`.
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}
377
378#[proc_macro]
380pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
381 let mut ret = proc_macro2::TokenStream::new();
382
383 let types = [
384 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
385 ];
386
387 for lhs in types.iter() {
388 for rhs in types.iter() {
389 let lhs_type = TypeInfo::new(lhs);
390 let rhs_type = TypeInfo::new(rhs);
391 let lhs_dtype = lhs_type.dtype;
392 let rhs_dtype = rhs_type.dtype;
393 let res = if lhs_dtype == rhs_dtype {
394 quote! {
395 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
396 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
397 #[inline(always)]
398 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
399 self.__bitand(rhs)
400 }
401 #[inline(always)]
402 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
403 self.__bitor(rhs)
404 }
405 #[inline(always)]
406 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
407 self.__bitxor(rhs)
408 }
409 #[inline(always)]
410 fn _not(self) -> Self::Output {
411 self.__not()
412 }
413 #[inline(always)]
414 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
415 self.__shl(rhs)
416 }
417 #[inline(always)]
418 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
419 self.__shr(rhs)
420 }
421 }
422 }
423 } else {
424 quote! {
425 impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
426 type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
427 #[inline(always)]
428 fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
429 let lhs: Self::Output = self.cast();
430 let rhs: Self::Output = rhs.cast();
431 lhs.__bitand(rhs)
432 }
433 #[inline(always)]
434 fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
435 let lhs: Self::Output = self.cast();
436 let rhs: Self::Output = rhs.cast();
437 lhs.__bitor(rhs)
438 }
439 #[inline(always)]
440 fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
441 let lhs: Self::Output = self.cast();
442 let rhs: Self::Output = rhs.cast();
443 lhs.__bitxor(rhs)
444 }
445 #[inline(always)]
446 fn _not(self) -> Self::Output {
447 let lhs: Self::Output = self.cast();
448 lhs.__not()
449 }
450 #[inline(always)]
451 fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
452 let lhs: Self::Output = self.cast();
453 let rhs: Self::Output = rhs.cast();
454 lhs.__shl(rhs)
455 }
456 #[inline(always)]
457 fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
458 let lhs: Self::Output = self.cast();
459 let rhs: Self::Output = rhs.cast();
460 lhs.__shr(rhs)
461 }
462 }
463 }
464 };
465 ret.extend(res);
466 }
467 }
468
469 ret.into()
470}
471
472#[proc_macro]
474pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
475 let mut ret = proc_macro2::TokenStream::new();
476
477 let types = [
478 "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
479 ];
480
481 for lhs in types.iter() {
482 for rhs in types.iter() {
483 let lhs_type = TypeInfo::new(lhs);
484 let rhs_type = TypeInfo::new(rhs);
485 let lhs_dtype = lhs_type.dtype;
486 let rhs_dtype = rhs_type.dtype;
487 let res = if lhs_dtype == rhs_dtype {
488 quote! {
489 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
490 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
491 #[inline(always)]
492 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
493 self.__bitand(rhs)
494 }
495 #[inline(always)]
496 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
497 self.__bitor(rhs)
498 }
499 #[inline(always)]
500 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
501 self.__bitxor(rhs)
502 }
503 #[inline(always)]
504 fn _not(self) -> Self::Output {
505 self.__not()
506 }
507 #[inline(always)]
508 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
509 self.__shl(rhs)
510 }
511 #[inline(always)]
512 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
513 self.__shr(rhs)
514 }
515 }
516 }
517 } else {
518 quote! {
519 impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
520 type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
521 #[inline(always)]
522 fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
523 let lhs: Self::Output = self.cast();
524 let rhs: Self::Output = rhs.cast();
525 lhs.__bitand(rhs)
526 }
527 #[inline(always)]
528 fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
529 let lhs: Self::Output = self.cast();
530 let rhs: Self::Output = rhs.cast();
531 lhs.__bitor(rhs)
532 }
533 #[inline(always)]
534 fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
535 let lhs: Self::Output = self.cast();
536 let rhs: Self::Output = rhs.cast();
537 lhs.__bitxor(rhs)
538 }
539 #[inline(always)]
540 fn _not(self) -> Self::Output {
541 let lhs: Self::Output = self.cast();
542 lhs.__not()
543 }
544 #[inline(always)]
545 fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
546 let lhs: Self::Output = self.cast();
547 let rhs: Self::Output = rhs.cast();
548 lhs.__shl(rhs)
549 }
550 #[inline(always)]
551 fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
552 let lhs: Self::Output = self.cast();
553 let rhs: Self::Output = rhs.cast();
554 lhs.__shr(rhs)
555 }
556 }
557 }
558 };
559 ret.extend(res);
560 }
561 }
562
563 ret.into()
564}
565
566#[proc_macro]
568pub fn impl_cmp(_: TokenStream) -> TokenStream {
569 let mut ret = proc_macro2::TokenStream::new();
570
571 let types = [
572 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
573 "usize", "bf16",
574 ];
575
576 for lhs in types.iter() {
577 for rhs in types.iter() {
578 let lhs_type = TypeInfo::new(lhs);
579 let rhs_type = TypeInfo::new(rhs);
580 let lhs_dtype = lhs_type.dtype;
581 let rhs_dtype = rhs_type.dtype;
582 let res = if lhs_dtype == rhs_dtype {
583 quote! {
584 impl Cmp<#rhs_dtype> for #lhs_dtype {
585 type Output = bool;
586 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
587 self == rhs
588 }
589 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
590 self != rhs
591 }
592 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
593 self < rhs
594 }
595
596 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
597 self <= rhs
598 }
599 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
600 self > rhs
601 }
602 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
603 self >= rhs
604 }
605 }
606 }
607 } else {
608 quote! {
609 impl Cmp<#rhs_dtype> for #lhs_dtype {
610 type Output = bool;
611 fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
612 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
613 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
614 lhs == rhs
615 }
616 fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
617 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
618 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
619 lhs != rhs
620 }
621 fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
622 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
623 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
624 lhs < rhs
625 }
626
627 fn _le(self, rhs: #rhs_dtype) -> Self::Output {
628 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
629 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
630 lhs <= rhs
631 }
632 fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
633 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
634 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
635 lhs > rhs
636 }
637 fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
638 let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
639 let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
640 lhs >= rhs
641 }
642 }
643 }
644 };
645 ret.extend(res);
646 }
647 }
648
649 ret.into()
650}
651
652#[proc_macro]
654pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
655 let mut ret = proc_macro2::TokenStream::new();
656
657 let types = [
658 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
659 "usize", "bf16",
660 ];
661
662 for lhs in types.iter() {
663 for rhs in types.iter() {
664 let lhs_type = TypeInfo::new(lhs);
665 let rhs_type = TypeInfo::new(rhs);
666 let lhs_dtype = lhs_type.dtype;
667 let rhs_dtype = rhs_type.dtype;
668 let res = if lhs_dtype == rhs_dtype {
669 quote! {
670 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
671 type Output = Scalar<bool>;
672 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
673 self.__eq(rhs)
674 }
675 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
676 self.__ne(rhs)
677 }
678 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
679 self.__lt(rhs)
680 }
681
682 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
683 self.__le(rhs)
684 }
685 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
686 self.__gt(rhs)
687 }
688 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
689 self.__ge(rhs)
690 }
691 }
692 }
693 } else {
694 quote! {
695 impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
696 type Output = Scalar<bool>;
697 fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
698 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
699 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
700 lhs.__eq(rhs)
701 }
702 fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
703 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
704 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
705 lhs.__ne(rhs)
706 }
707 fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
708 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
709 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
710 lhs.__lt(rhs)
711 }
712
713 fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
714 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
715 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
716 lhs.__le(rhs)
717 }
718 fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
719 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
720 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
721 lhs.__gt(rhs)
722 }
723 fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
724 let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
725 let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
726 lhs.__ge(rhs)
727 }
728 }
729 }
730 };
731 ret.extend(res);
732 }
733 }
734
735 ret.into()
736}
737
738#[proc_macro]
740pub fn impl_eval(_: TokenStream) -> TokenStream {
741 let mut ret = proc_macro2::TokenStream::new();
742
743 let types = [
744 "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
745 "usize", "bf16",
746 ];
747
748 for lhs in types.iter() {
749 let lhs_type = TypeInfo::new(lhs);
750 let lhs_dtype = lhs_type.dtype;
751
752 let res = quote! {
753 impl Eval for #lhs_dtype {
754 type Output = bool;
755 #[inline(always)]
756 fn _is_nan(&self) -> bool {
757 self.__is_nan()
758 }
759 #[inline(always)]
760 fn _is_true(&self) -> bool {
761 self.__is_true()
762 }
763 #[inline(always)]
764 fn _is_inf(&self) -> bool {
765 self.__is_inf()
766 }
767 }
768 };
769 ret.extend(res);
770 }
771
772 ret.into()
773}
774
/// Function-like macro that forwards its input tokens to
/// `__gen_fast_reduce_simd_helper` and returns the result.
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}
780
/// Function-like macro that forwards its input tokens to
/// `__gen_fast_layernorm_simd_helper` and returns the result.
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}
786
/// Function-like macro that forwards its input tokens to
/// `__gen_reduce_dim_not_include_simd_helper` and returns the result.
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}
792
/// Function-like macro that forwards its input tokens to
/// `conv2d::conv2d_microkernel_declare_const` and returns the result.
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}
802
/// Function-like macro that forwards its input tokens to
/// `conv2d::conv2d_microkernel_gen_inps` and returns the result.
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}
808
/// Function-like macro that forwards its input tokens to
/// `conv2d::conv2d_microkernel_gen_pad_inps` and returns the result.
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}
814
/// Function-like macro that forwards its input tokens to
/// `conv2d::pwconv2d_microkernel_gen_pad_inps` and returns the result.
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}
820
/// Function-like macro that forwards its input tokens to
/// `conv2d::dwconv2d_microkernel_gen_pad_inps` and returns the result.
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}
826
/// Function-like macro that forwards its input tokens to
/// `conv2d::conv2d_microkernel_gen_kernels` and returns the result.
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}
832
/// Function-like macro that forwards its input tokens to
/// `conv2d::conv2d_microkernel_gen_results` and returns the result.
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}
838
/// Function-like macro that forwards its input tokens to
/// `conv2d::dwconv2d_microkernel_gen_results` and returns the result.
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}
844
/// Function-like macro that forwards its input tokens to
/// `conv2d::maxpool2d_microkernel_gen_results` and returns the result.
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}
851
852#[proc_macro_derive(Save, attributes(compress))]
854pub fn impl_save(input: TokenStream) -> TokenStream {
855 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
856 let name = &ast.ident;
857 let meta_name = format_ident!("{}Meta", name);
858
859 let visibility = &ast.vis;
860 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
861 let fields = match &ast.data {
862 syn::Data::Struct(s) => &s.fields,
863 _ => panic!("Save can only be derived for structs"),
864 };
865
866 let mut compressions = vec![];
867 let mut compress_levels = vec![];
868
869 let meta_fields = fields
870 .iter()
871 .map(|f| {
872 let mut compression_algo = None;
873 let mut level = None;
874
875 for attr in &f.attrs {
876 if attr.path().is_ident("compress") {
877 attr.parse_nested_meta(|meta| {
878 if meta.path.is_ident("algo") {
879 let value: syn::LitStr = meta.value()?.parse()?;
880 let algo = match value.value().as_str().to_lowercase().as_str() {
881 "gzip" => quote!(Gzip),
882 "deflate" => quote!(Deflate),
883 "zlib" => quote!(Zlib),
884 "none" => quote!(NoCompression),
885 _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
886 };
887 compression_algo = Some(quote!(hpt::save_load::CompressionAlgo::#algo));
888 } else if meta.path.is_ident("level") {
889 let value: syn::LitStr = meta.value()?.parse()?;
890 let tmp: u32 = value.value().parse().map_err(|e| {
891 syn::Error::new(value.span(), format!("Invalid level: {}", e))
892 })?;
893 level = Some(quote!(#tmp));
894 }
895 Ok(())
896 })
897 .unwrap();
898 }
899 }
900 compressions.push(compression_algo);
901 compress_levels.push(level);
902 let name = &f.ident;
903 let ty = &f.ty;
904 quote! {
905 pub #name: <#ty as Save>::Meta
906 }
907 })
908 .collect::<Vec<_>>();
909
910 let call_save = fields.iter().enumerate().map(|(idx, f)| {
911 let name = &f.ident;
912 let ty = &f.ty;
913 let ident = format_ident!("field_{}", idx);
914 let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
915 let level = compress_levels[idx].clone().unwrap_or(quote!(level));
916 if let Some(name) = name {
917 quote! {
918 let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #level)?;
919 }
920 } else {
921 quote! {
922 let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #level)?;
923 }
924 }
925 });
926
927 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
928 let name = &f.ident;
929 let ident = format_ident!("field_{}", idx);
930 quote! {
931 #name: #ident
932 }
933 });
934
935 let expanded = quote! {
936 #[derive(hpt::re_exports::serde::Deserialize, hpt::re_exports::serde::Serialize)]
937 #[serde(crate = "hpt::re_exports::serde")]
938 #visibility struct #meta_name #ty_generics #where_clause {
939 #(#meta_fields,)*
940 }
941 impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
942 type Meta = #meta_name #ty_generics;
943 fn __save(
944 data: &Self,
945 file: &mut std::fs::File,
946 len_so_far: &mut usize,
947 global_cnt: &mut usize,
948 compression_algo: hpt::save_load::CompressionAlgo,
949 level: u32,
950 ) -> std::io::Result<Self::Meta> {
951 #(#call_save)*
952 Ok(Self::Meta {
953 #(#construct_fields),*
954 })
955 }
956 }
957 };
958
959 expanded.into()
960}
961
962#[proc_macro_derive(Load)]
964pub fn impl_load(input: TokenStream) -> TokenStream {
965 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
966 let name = &ast.ident;
967 let meta_name = format_ident!("{}Meta", name);
968 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
969
970 let fields = match &ast.data {
971 syn::Data::Struct(s) => &s.fields,
972 _ => panic!("Load can only be derived for structs"),
973 };
974
975 let call_load = fields.iter().enumerate().map(|(idx, f)| {
976 let name = &f.ident;
977 let ident = format_ident!("field_{}", idx);
978 if let Some(name) = name {
979 quote! {
980 let #ident = self.#name.load(file)?;
981 }
982 } else {
983 quote! {
984 let #ident = self.#idx.load(file)?;
985 }
986 }
987 });
988
989 let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
990 let name = &f.ident;
991 let ident = format_ident!("field_{}", idx);
992 quote! {
993 #name: #ident
994 }
995 });
996
997 let expanded = quote! {
998 impl #impl_generics hpt::save_load::MetaLoad for #meta_name #ty_generics #where_clause {
999 type Output = #name #ty_generics;
1000 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
1001 use hpt::save_load::MetaLoad;
1002 #(#call_load)*
1003 Ok(#name {
1004 #(#construct_fields),*
1005 })
1006 }
1007 }
1008 impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
1009 fn load<P: Into<std::path::PathBuf>>(path: P) -> std::io::Result<Self> {
1010 use hpt::save_load::MetaLoad;
1011 let path: std::path::PathBuf = path.into();
1012 let meta = hpt::save_load::parse_header_compressed::<Self, _>(&path).expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
1013 let mut file = std::fs::File::open(path)?;
1014 meta.load(&mut file)
1015 }
1016 }
1017 };
1018
1019 expanded.into()
1020}
1021
1022#[proc_macro_derive(FromSafeTensors, attributes(map))]
1024pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
1025 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
1026 let struct_name = &ast.ident;
1027 let fields = match &ast.data {
1028 syn::Data::Struct(s) => &s.fields,
1029 _ => panic!("FromSafeTensors can only be derived for structs"),
1030 };
1031 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
1032 let mut construct_fields = vec![];
1033 for (_, field) in fields.iter().enumerate() {
1034 let ty = &field.ty;
1035 let name = &field.ident;
1036 let mut value_construct = vec![];
1037 let mut from_construct = vec![];
1038 let mut params = vec![];
1039 let mut vec_len = None;
1040 for attr in &field.attrs {
1041 if attr.path().is_ident("map") {
1042 let mut path = None;
1043 let mut value = None;
1044 let mut tensor_name = None;
1045 let mut inner_type = None;
1046 attr.parse_nested_meta(|meta| {
1047 if meta.path.is_ident("path") {
1048 let value: syn::LitStr = meta.value()?.parse()?;
1049 path = Some(value.value());
1050 } else if meta.path.is_ident("value") {
1051 let val: syn::Expr = meta.value()?.parse()?;
1052 value = Some(val);
1053 } else if meta.path.is_ident("tensor_name") {
1054 let value: syn::LitStr = meta.value()?.parse()?;
1055 tensor_name = Some(value.value());
1056 } else if meta.path.is_ident("vec_len") {
1057 let value: syn::LitInt = meta.value()?.parse()?;
1058 vec_len = Some(value.base10_parse::<usize>().unwrap());
1059 } else if meta.path.is_ident("inner_type") {
1060 let value: syn::Ident = meta.value()?.parse()?;
1061 inner_type = Some(value);
1062 }
1063 Ok(())
1064 })
1065 .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
1066 params.push((path, value, tensor_name, vec_len, inner_type));
1067 }
1068 }
1069 let param_count = params.len();
1070 for (path, value, tensor_name, vec_len, inner_type) in params {
1071 if let Some(vec_len) = vec_len {
1072 let inner_type = inner_type.expect("inner_type is required for vec");
1073 if let Some(path) = path {
1074 from_construct.push(quote! {
1075 #path => {
1076 let mut vec = vec![];
1077 for i in 0..#vec_len {
1078 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1079 }
1080 vec
1081 }
1082 });
1083 } else {
1084 value_construct.push(quote! {
1085 {
1086 let mut vec = vec![];
1087 for i in 0..#vec_len {
1088 vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
1089 }
1090 vec
1091 }
1092 });
1093 }
1094 } else {
1095 match (path, value, tensor_name) {
1096 (None, None, Some(tensor_name)) => {
1097 value_construct.push(quote! {
1098 <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
1099 });
1100 }
1101 (None, Some(value), None) => {
1102 if param_count > 1 {
1103 panic!("value without path means generic assignment, there can only be one value without path");
1104 }
1105 value_construct.push(quote! {
1106 #value
1107 });
1108 }
1109 (Some(path), None, Some(tensor_name)) => {
1110 from_construct.push(quote! {
1111 #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
1112 });
1113 }
1114 (Some(path), Some(value), None) => {
1115 from_construct.push(quote! {
1116 #path => #value,
1117 });
1118 }
1119
1120 (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
1121 panic!("value and tensor_name cannot be used together");
1122 }
1123 (Some(_), None, None) | (None, None, None) => {
1124 panic!("path and value are not present");
1125 }
1126 }
1127 }
1128 }
1129 if !value_construct.is_empty() {
1130 construct_fields.push(quote! {
1131 #name: #(#value_construct)*
1132 });
1133 } else if !from_construct.is_empty() {
1134 construct_fields.push(quote! {
1135 #name: match path {
1136 #(#from_construct)*
1137 _ => panic!("unknown field for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
1138 }
1139 });
1140 } else {
1141 construct_fields.push(quote! {
1142 #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
1143 });
1144 }
1145 }
1146 let expanded = quote! {
1147 impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
1148 fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
1149 Self {
1150 #(#construct_fields),*
1151 }
1152 }
1153 }
1154 };
1155 expanded.into()
1162}