#![deny(missing_docs)]
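//! Procedural macros used by the `hpt` crates: type-promotion trait
//! implementations for scalar, SIMD and (optionally) CUDA scalar types,
//! conv2d/pooling micro-kernel generators, and the `Save`/`Load`/
//! `FromSafeTensors` derive macros.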
#[cfg(feature = "cuda")]
use crate::binary_float_out::impl_cuda_float_out_binary;
use binary_float_out::impl_float_out_binary;
use float_unary::impl_float_out_unary;
use from_scalar::__impl_from_scalar;
use kernel_gen_helper::{
    __gen_fast_layernorm_simd_helper, __gen_fast_reduce_simd_helper,
    __gen_reduce_dim_not_include_simd_helper,
};
use normal_out::__impl_normal_out_binary;
use proc_macro::TokenStream;
use scalar_convert::__impl_scalar_convert;
use simd_bitwise::impl_simd_bitwise_out;
use simd_float_out_binary::{
    impl_simd_binary_out_float, impl_simd_binary_out_float_lhs_scalar,
    impl_simd_binary_out_float_rhs_scalar,
};
use simd_normal_out::{impl_simd_normal_out_with_lhs_scalar, impl_simd_normal_out_with_rhs_scalar};
use syn::{parse, parse_macro_input, Expr, Token};
mod binary_float_out;
mod conv2d;
mod float_unary;
mod from_scalar;
mod into_cuda_scalar;
mod into_scalar;
mod into_vec;
mod kernel_gen_helper;
mod normal_out;
mod normal_out_unary;
mod scalar_convert;
mod simd_bitwise;

mod simd_cmp;
mod simd_eval;
mod simd_float_out_binary;
mod simd_float_out_unary;
mod simd_normal_out;
mod simd_normal_unary;
mod type_utils;

use crate::simd_cmp::impl_simd_cmp;
use crate::simd_normal_out::impl_simd_normal_out;
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{format_ident, quote};
use type_utils::TypeInfo;

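// Number of SIMD vector registers assumed to be available on the target
// architecture (presumably consumed by the kernel-generation modules).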
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
const NUM_REG: usize = 16;
#[cfg(all(
    any(target_feature = "sse", target_arch = "arm"),
    not(target_feature = "avx2")
))]
const NUM_REG: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
const NUM_REG: usize = 32;

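/// A single parsed `start:end:step` (or bare `..`) selection from the `select!` input.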
struct SelectionParser {
    start: Option<Expr>,
    end: Option<Expr>,
    step: Option<Expr>,
    skip: bool,
}

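/// The comma-separated raw token streams handed to `select!`, one per selection.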
struct Selections {
    selections: Vec<TokenStream>,
}

impl parse::Parse for SelectionParser {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut start: Option<Expr> = None;
        let mut end: Option<Expr> = None;
        let mut step: Option<Expr> = None;
        if input.peek(Token![..]) {
            input.parse::<Token![..]>()?;
            return Ok(Self {
                start,
                end,
                step,
                skip: true,
            });
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            start = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        } else if input.is_empty() {
            return Ok(Self {
                start,
                end,
                step,
                skip: false,
            });
        } else {
            return Err(syn::Error::new(
                input.span(),
                "unexpected token, expected `:`, Int or Ident",
            ));
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            end = Some(input.parse::<Expr>()?);
        }
        if input.peek(Token![:]) {
            input.parse::<Token![:]>()?;
        }
        if input.peek(syn::Lit)
            || input.peek(syn::Ident)
            || input.peek(syn::token::Paren)
            || input.peek(Token![-])
        {
            step = Some(input.parse::<Expr>()?);
        }
        Ok(Self {
            start,
            end,
            step,
            skip: false,
        })
    }
}

impl parse::Parse for Selections {
    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
        let mut selections: Vec<TokenStream> = vec![];

        while !input.is_empty() {
            let mut item_tokens = TokenStream2::new();
            while !input.is_empty() && !input.peek(Token![,]) {
                let token = input.parse::<TokenTree>()?;
                item_tokens.extend(quote!(#token));
            }

            selections.push(item_tokens.into());

            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }

        Ok(Self { selections })
    }
}

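/// Expands slice-like selection syntax into an array of `(start, end, step)` tuples.
///
/// Each comma-separated argument is either `start:end:step` (any part may be
/// omitted) or a single `..`, which may appear at most once. Missing bounds
/// default to `0`, `0x7FFFFFFFFFFFFFFF` (i64::MAX, treated as "until the end")
/// and `1` respectively.
///
/// A sketch of the expansion (the consuming tensor API lives downstream):
///
/// ```ignore
/// // select!(1:10:2, .., 3) expands to:
/// [(1, 10, 2), (0, 0, 0x7FFFFFFFFFFFFFFF), (3, 0x7FFFFFFFFFFFFFFF, 1)]
/// ```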
#[proc_macro]
pub fn select(input: TokenStream) -> TokenStream {
    let res: Selections = parse_macro_input!(input as Selections);
    let mut slices: Vec<SelectionParser> = vec![];
    for x in res.selections {
        slices.push(parse_macro_input!(x as SelectionParser));
    }
    let mut ret_stream = TokenStream2::new();
    let len = slices.len();
    let mut skipped = false;
    for (idx, x) in slices.into_iter().enumerate() {
        if x.skip {
            if skipped {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    "unexpected token, slicing only supports `..` once",
                )
                .to_compile_error()
                .into();
            }
            ret_stream.extend(quote!((0, 0, 0x7FFFFFFFFFFFFFFF)));
            skipped = true;
            if idx != len - 1 {
                ret_stream.extend(quote!(,));
            }
            continue;
        }
        match (x.start, x.end, x.step) {
            (None, None, None) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (None, None, Some(step)) => {
                ret_stream.extend(quote!((0, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (None, Some(end), None) => {
                ret_stream.extend(quote!((0, #end, 1)));
            }
            (None, Some(end), Some(step)) => {
                ret_stream.extend(quote!((0, #end, #step)));
            }
            (Some(start), None, None) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, 1)));
            }
            (Some(start), None, Some(step)) => {
                ret_stream.extend(quote!((#start, 0x7FFFFFFFFFFFFFFF, #step)));
            }
            (Some(start), Some(end), None) => {
                ret_stream.extend(quote!((#start, #end, 1)));
            }
            (Some(start), Some(end), Some(step)) => {
                ret_stream.extend(quote!((#start, #end, #step)));
            }
        }
        if idx != len - 1 {
            ret_stream.extend(quote!(,));
        }
    }
    quote!([#ret_stream]).into()
}

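/// Generates the scalar float-output binary operation impls (see `binary_float_out`).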
#[proc_macro]
pub fn float_out_binary(_: TokenStream) -> TokenStream {
    impl_float_out_binary()
}

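/// CUDA variant of `float_out_binary`; only built with the `cuda` feature.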
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_binary_cuda(_: TokenStream) -> TokenStream {
    impl_cuda_float_out_binary()
}

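/// SIMD variant of `float_out_binary`.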
#[proc_macro]
pub fn float_out_binary_simd(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float()
}

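/// SIMD float-output binary impls where the right-hand operand is a scalar.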
#[proc_macro]
pub fn float_out_binary_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_rhs_scalar()
}

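/// SIMD float-output binary impls where the left-hand operand is a scalar.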
#[proc_macro]
pub fn float_out_binary_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_binary_out_float_lhs_scalar()
}

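/// Generates the scalar float-output unary operation impls (see `float_unary`).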
#[proc_macro]
pub fn float_out_unary(_: TokenStream) -> TokenStream {
    impl_float_out_unary()
}

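/// CUDA variant of `float_out_unary`; only built with the `cuda` feature.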
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn float_out_unary_cuda(_: TokenStream) -> TokenStream {
    crate::float_unary::impl_cuda_float_out_unary()
}

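/// SIMD variant of `float_out_unary` (see `simd_float_out_unary`).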
#[proc_macro]
pub fn simd_float_out_unary(_: TokenStream) -> TokenStream {
    simd_float_out_unary::impl_float_out_unary()
}

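/// Generates the SIMD `Eval`-style impls (see `simd_eval`).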
#[proc_macro]
pub fn simd_eval(_: TokenStream) -> TokenStream {
    simd_eval::impl_simd_eval()
}

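/// Generates the SIMD bitwise operation impls (see `simd_bitwise`).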
#[proc_macro]
pub fn simd_bitwise(_: TokenStream) -> TokenStream {
    impl_simd_bitwise_out()
}

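/// Generates the scalar "normal output" binary operation impls (see `normal_out`).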
#[proc_macro]
pub fn impl_normal_out_binary(_: TokenStream) -> TokenStream {
    __impl_normal_out_binary()
}

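/// CUDA variant of `impl_normal_out_binary`; only built with the `cuda` feature.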
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_cuda_normal_out_binary(_: TokenStream) -> TokenStream {
    crate::normal_out::__impl_cuda_normal_out_binary()
}

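/// Generates the scalar "normal output" unary operation impls (see `normal_out_unary`).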
#[proc_macro]
pub fn impl_normal_out_unary(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary()
}

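/// CUDA variant of `impl_normal_out_unary`; only built with the `cuda` feature.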
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_normal_out_unary_cuda(_: TokenStream) -> TokenStream {
    normal_out_unary::__impl_normal_out_unary_cuda()
}

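/// SIMD variant of `impl_normal_out_unary` (see `simd_normal_unary`).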
#[proc_macro]
pub fn impl_normal_out_unary_simd(_: TokenStream) -> TokenStream {
    simd_normal_unary::impl_simd_normal_out_unary()
}

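/// SIMD variant of `impl_normal_out_binary` (see `simd_normal_out`).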
#[proc_macro]
pub fn impl_normal_out_simd(_: TokenStream) -> TokenStream {
    impl_simd_normal_out()
}

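/// SIMD "normal output" binary impls where the right-hand operand is a scalar.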
#[proc_macro]
pub fn impl_normal_out_simd_with_rhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_rhs_scalar()
}

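/// SIMD "normal output" binary impls where the left-hand operand is a scalar.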
#[proc_macro]
pub fn impl_normal_out_simd_with_lhs_scalar(_: TokenStream) -> TokenStream {
    impl_simd_normal_out_with_lhs_scalar()
}

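/// Generates the scalar-to-scalar conversion impls (see `scalar_convert`).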
#[proc_macro]
pub fn impl_scalar_convert(_: TokenStream) -> TokenStream {
    __impl_scalar_convert()
}

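/// Generates the from-scalar conversion impls (see `from_scalar`).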
#[proc_macro]
pub fn impl_from_scalar(_: TokenStream) -> TokenStream {
    __impl_from_scalar()
}

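/// Generates the SIMD comparison impls (see `simd_cmp`).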
#[proc_macro]
pub fn simd_cmp(_: TokenStream) -> TokenStream {
    impl_simd_cmp()
}

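/// Generates the into-SIMD-vector conversion impls (see `into_vec`).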
#[proc_macro]
pub fn impl_into_vec(_: TokenStream) -> TokenStream {
    into_vec::into_vec()
}

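/// Generates the into-CUDA-scalar conversion impls; only built with the `cuda` feature.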
#[cfg(feature = "cuda")]
#[proc_macro]
pub fn impl_into_cuda_scalar(_: TokenStream) -> TokenStream {
    into_cuda_scalar::__impl_into_cuda_scalar().into()
}

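/// Generates the into-scalar conversion impls (see `into_scalar`).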
#[proc_macro]
pub fn impl_into_scalar(_: TokenStream) -> TokenStream {
    into_scalar::__impl_into_scalar().into()
}

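/// Implements `BitWiseOut` for every pair of `bool`/integer primitive types.
///
/// When the two operand types differ, both sides are first cast to the
/// `NormalOutPromote` output type before the underlying `__bitand`/`__bitor`/
/// `__bitxor`/`__shl`/`__shr`/`__not` methods are applied.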
#[proc_macro]
pub fn impl_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<#rhs_dtype> for #lhs_dtype {
                        type Output = <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

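/// CUDA counterpart of `impl_bitwise_out`, implementing `BitWiseOut` for
/// `Scalar<T>` wrappers over the same primitive types.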
#[proc_macro]
pub fn impl_cuda_bitwise_out(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize", "usize",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            self.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__shr(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl BitWiseOut<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output;
                        #[inline(always)]
                        fn _bitand(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitand(rhs)
                        }
                        #[inline(always)]
                        fn _bitor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitor(rhs)
                        }
                        #[inline(always)]
                        fn _bitxor(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__bitxor(rhs)
                        }
                        #[inline(always)]
                        fn _not(self) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            lhs.__not()
                        }
                        #[inline(always)]
                        fn _shl(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shl(rhs)
                        }
                        #[inline(always)]
                        fn _shr(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: Self::Output = self.cast();
                            let rhs: Self::Output = rhs.cast();
                            lhs.__shr(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

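/// Implements `Cmp` (output `bool`) for every pair of primitive scalar types,
/// promoting both operands via `NormalOutPromote` before comparing when the
/// types differ.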
#[proc_macro]
pub fn impl_cmp(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            self == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            self != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            self < rhs
                        }
                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            self <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            self > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            self >= rhs
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<#rhs_dtype> for #lhs_dtype {
                        type Output = bool;
                        fn _eq(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs == rhs
                        }
                        fn _ne(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs != rhs
                        }
                        fn _lt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs < rhs
                        }
                        fn _le(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs <= rhs
                        }
                        fn _gt(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs > rhs
                        }
                        fn _ge(self, rhs: #rhs_dtype) -> Self::Output {
                            let lhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = self.cast();
                            let rhs: <#lhs_dtype as NormalOutPromote<#rhs_dtype>>::Output = rhs.cast();
                            lhs >= rhs
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

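/// CUDA counterpart of `impl_cmp`, implementing `Cmp` for `Scalar<T>` with
/// `Scalar<bool>` as the output type.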
#[proc_macro]
pub fn impl_cmp_cuda(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        for rhs in types.iter() {
            let lhs_type = TypeInfo::new(lhs);
            let rhs_type = TypeInfo::new(rhs);
            let lhs_dtype = lhs_type.dtype;
            let rhs_dtype = rhs_type.dtype;
            let res = if lhs_dtype == rhs_dtype {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__lt(rhs)
                        }
                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            self.__ge(rhs)
                        }
                    }
                }
            } else {
                quote! {
                    impl Cmp<Scalar<#rhs_dtype>> for Scalar<#lhs_dtype> {
                        type Output = Scalar<bool>;
                        fn _eq(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__eq(rhs)
                        }
                        fn _ne(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ne(rhs)
                        }
                        fn _lt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__lt(rhs)
                        }
                        fn _le(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__le(rhs)
                        }
                        fn _gt(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__gt(rhs)
                        }
                        fn _ge(self, rhs: Scalar<#rhs_dtype>) -> Self::Output {
                            let lhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = self.cast();
                            let rhs: <Scalar<#lhs_dtype> as NormalOutPromote<Scalar<#rhs_dtype>>>::Output = rhs.cast();
                            lhs.__ge(rhs)
                        }
                    }
                }
            };
            ret.extend(res);
        }
    }

    ret.into()
}

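/// Implements `Eval` (`_is_nan`, `_is_true`, `_is_inf`) for every primitive scalar type.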
#[proc_macro]
pub fn impl_eval(_: TokenStream) -> TokenStream {
    let mut ret = proc_macro2::TokenStream::new();

    let types = [
        "bool", "f16", "f32", "f64", "i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "isize",
        "usize", "bf16",
    ];

    for lhs in types.iter() {
        let lhs_type = TypeInfo::new(lhs);
        let lhs_dtype = lhs_type.dtype;

        let res = quote! {
            impl Eval for #lhs_dtype {
                type Output = bool;
                #[inline(always)]
                fn _is_nan(&self) -> bool {
                    self.__is_nan()
                }
                #[inline(always)]
                fn _is_true(&self) -> bool {
                    self.__is_true()
                }
                #[inline(always)]
                fn _is_inf(&self) -> bool {
                    self.__is_inf()
                }
            }
        };
        ret.extend(res);
    }

    ret.into()
}

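/// Generates the SIMD helper used by the fast reduce kernels (see `kernel_gen_helper`).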
#[proc_macro]
pub fn gen_fast_reduce_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_reduce_simd_helper(input)
}

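/// Generates the SIMD helper used by the fast layernorm kernels (see `kernel_gen_helper`).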
#[proc_macro]
pub fn gen_fast_layernorm_simd_helper(input: TokenStream) -> TokenStream {
    __gen_fast_layernorm_simd_helper(input)
}

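/// Generates the SIMD helper for the "reduce dim not include" kernels (see `kernel_gen_helper`).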
#[proc_macro]
pub fn gen_reduce_dim_not_include_simd_helper(input: TokenStream) -> TokenStream {
    __gen_reduce_dim_not_include_simd_helper(input)
}

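/// Declares the compile-time constants used inside a conv2d micro-kernel.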
#[proc_macro]
pub fn conv2d_microkernel_declare_const(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_declare_const(input)
}

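/// Generates the input-loading code for a conv2d micro-kernel.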
#[proc_macro]
pub fn conv2d_microkernel_gen_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_inps(input)
}

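/// Generates the padded input-loading code for a conv2d micro-kernel.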
#[proc_macro]
pub fn conv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_pad_inps(input)
}

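/// Generates the padded input-loading code for a pointwise conv2d micro-kernel.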
#[proc_macro]
pub fn pwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::pwconv2d_microkernel_gen_pad_inps(input)
}

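/// Generates the padded input-loading code for a depthwise conv2d micro-kernel.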
#[proc_macro]
pub fn dwconv2d_microkernel_gen_pad_inps(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_pad_inps(input)
}

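/// Generates the weight-loading code for a conv2d micro-kernel.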
#[proc_macro]
pub fn conv2d_microkernel_gen_kernels(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_kernels(input)
}

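/// Generates the result-accumulation code for a conv2d micro-kernel.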
#[proc_macro]
pub fn conv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::conv2d_microkernel_gen_results(input)
}

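/// Generates the result-accumulation code for a depthwise conv2d micro-kernel.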
#[proc_macro]
pub fn dwconv2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::dwconv2d_microkernel_gen_results(input)
}

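/// Generates the result-accumulation code for a maxpool2d micro-kernel.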
#[proc_macro]
pub fn maxpool2d_microkernel_gen_results(input: TokenStream) -> TokenStream {
    conv2d::maxpool2d_microkernel_gen_results(input)
}

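/// Derives `hpt::Save` for a struct, generating a `<Name>Meta` companion type
/// and a `__save` implementation that serializes each field in order.
///
/// Per-field `#[compress(...)]` attributes override the compression algorithm
/// (`gzip`, `deflate`, `zlib`, `none`), the compression `level` and the
/// `endian` (`native`, `little`, `big`) passed to `__save`.
///
/// A minimal sketch (field types are placeholders):
///
/// ```ignore
/// #[derive(Save)]
/// struct Model {
///     #[compress(algo = "gzip", level = "6", endian = "little")]
///     weights: Tensor,
///     bias: Tensor,
/// }
/// ```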
#[proc_macro_derive(Save, attributes(compress))]
pub fn impl_save(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);

    let visibility = &ast.vis;
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Save can only be derived for structs"),
    };

    let mut compressions = vec![];
    let mut endians = vec![];
    let mut compress_levels = vec![];

    let meta_fields = fields
        .iter()
        .map(|f| {
            let mut compression_algo = None;
            let mut endian = None;
            let mut level = None;

            for attr in &f.attrs {
                if attr.path().is_ident("compress") {
                    attr.parse_nested_meta(|meta| {
                        if meta.path.is_ident("algo") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let algo = match value.value().as_str().to_lowercase().as_str() {
                                "gzip" => quote!(Gzip),
                                "deflate" => quote!(Deflate),
                                "zlib" => quote!(Zlib),
                                "none" => quote!(NoCompression),
                                _ => panic!("Unsupported compression algorithm, supported: gzip, deflate, zlib, none"),
                            };
                            compression_algo = Some(quote!(hpt::CompressionAlgo::#algo));
                        } else if meta.path.is_ident("level") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp: u32 = value.value().parse().map_err(|e| {
                                syn::Error::new(value.span(), format!("Invalid level: {}", e))
                            })?;
                            level = Some(quote!(#tmp));
                        } else if meta.path.is_ident("endian") {
                            let value: syn::LitStr = meta.value()?.parse()?;
                            let tmp = match value.value().as_str() {
                                "native" => quote!(Native),
                                "little" => quote!(Little),
                                "big" => quote!(Big),
                                _ => panic!("Unsupported endianness, supported: native, little, big"),
                            };
                            endian = Some(quote!(hpt::Endian::#tmp));
                        }
                        Ok(())
                    })
                    .unwrap();
                }
            }
            compressions.push(compression_algo);
            endians.push(endian);
            compress_levels.push(level);
            let name = &f.ident;
            let ty = &f.ty;
            quote! {
                pub #name: <#ty as Save>::Meta
            }
        })
        .collect::<Vec<_>>();

    let call_save = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ty = &f.ty;
        let ident = format_ident!("field_{}", idx);
        let compression_algo = compressions[idx].clone().unwrap_or(quote!(compression_algo));
        let endian = endians[idx].clone().unwrap_or(quote!(endian));
        let level = compress_levels[idx].clone().unwrap_or(quote!(level));
        if let Some(name) = name {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#name, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        } else {
            quote! {
                let #ident = <#ty as Save>::__save(&data.#idx, file, len_so_far, global_cnt, #compression_algo, #endian, #level)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        #[derive(hpt::serde::Deserialize, hpt::serde::Serialize)]
        #[serde(crate = "hpt::serde")]
        #visibility struct #meta_name #ty_generics #where_clause {
            #(#meta_fields,)*
        }
        impl #impl_generics hpt::Save for #name #ty_generics #where_clause {
            type Meta = #meta_name #ty_generics;
            fn __save(
                data: &Self,
                file: &mut std::fs::File,
                len_so_far: &mut usize,
                global_cnt: &mut usize,
                compression_algo: hpt::CompressionAlgo,
                endian: hpt::Endian,
                level: u32,
            ) -> std::io::Result<Self::Meta> {
                #(#call_save)*
                Ok(Self::Meta {
                    #(#construct_fields),*
                })
            }
        }
    };

    expanded.into()
}

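/// Derives `hpt::Load` (plus `hpt::MetaLoad` for the generated `<Name>Meta`
/// type) so a struct previously written with `Save` can be read back from a
/// file via `Load::load(path)`.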
#[proc_macro_derive(Load)]
pub fn impl_load(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let name = &ast.ident;
    let meta_name = format_ident!("{}Meta", name);
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();

    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("Load can only be derived for structs"),
    };

    let call_load = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        if let Some(name) = name {
            quote! {
                let #ident = self.#name.load(file)?;
            }
        } else {
            quote! {
                let #ident = self.#idx.load(file)?;
            }
        }
    });

    let construct_fields = fields.iter().enumerate().map(|(idx, f)| {
        let name = &f.ident;
        let ident = format_ident!("field_{}", idx);
        quote! {
            #name: #ident
        }
    });

    let expanded = quote! {
        impl #impl_generics hpt::MetaLoad for #meta_name #ty_generics #where_clause {
            type Output = #name #ty_generics;
            fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
                use hpt::MetaLoad;
                #(#call_load)*
                Ok(#name {
                    #(#construct_fields),*
                })
            }
        }
        impl #impl_generics hpt::Load for #name #ty_generics #where_clause {
            fn load(path: &str) -> std::io::Result<Self> {
                use hpt::MetaLoad;
                let meta = hpt::parse_header_compressed::<Self>(path)
                    .expect(format!("failed to parse header for {}", stringify!(#name)).as_str());
                let mut file = std::fs::File::open(path)?;
                meta.load(&mut file)
            }
        }
    };

    expanded.into()
}

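/// Derives `FromSafeTensors`, constructing each field from a `SafeTensors`
/// archive rooted at `path`.
///
/// Field behaviour is customised with `#[map(...)]`: `path`/`tensor_name`
/// redirect lookups, `value` supplies a literal default, and `vec_len` plus
/// `inner_type` build a `Vec` of nested `FromSafeTensors` values. Without an
/// attribute, the field is loaded from `"{path}.{field_name}"`.
///
/// A minimal sketch (types and tensor names are illustrative only):
///
/// ```ignore
/// #[derive(FromSafeTensors)]
/// struct Linear {
///     #[map(tensor_name = "weight")]
///     w: Tensor,
///     #[map(value = 1e-5)]
///     eps: f64,
/// }
/// ```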
#[proc_macro_derive(FromSafeTensors, attributes(map))]
pub fn impl_from_safetensors(input: TokenStream) -> TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let struct_name = &ast.ident;
    let fields = match &ast.data {
        syn::Data::Struct(s) => &s.fields,
        _ => panic!("FromSafeTensors can only be derived for structs"),
    };
    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
    let mut construct_fields = vec![];
    for field in fields.iter() {
        let ty = &field.ty;
        let name = &field.ident;
        let mut value_construct = vec![];
        let mut from_construct = vec![];
        let mut params = vec![];
        let mut vec_len = None;
        for attr in &field.attrs {
            if attr.path().is_ident("map") {
                let mut path = None;
                let mut value = None;
                let mut tensor_name = None;
                let mut inner_type = None;
                attr.parse_nested_meta(|meta| {
                    if meta.path.is_ident("path") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        path = Some(value.value());
                    } else if meta.path.is_ident("value") {
                        let val: syn::Expr = meta.value()?.parse()?;
                        value = Some(val);
                    } else if meta.path.is_ident("tensor_name") {
                        let value: syn::LitStr = meta.value()?.parse()?;
                        tensor_name = Some(value.value());
                    } else if meta.path.is_ident("vec_len") {
                        let value: syn::LitInt = meta.value()?.parse()?;
                        vec_len = Some(value.base10_parse::<usize>().unwrap());
                    } else if meta.path.is_ident("inner_type") {
                        let value: syn::Ident = meta.value()?.parse()?;
                        inner_type = Some(value);
                    }
                    Ok(())
                })
                .unwrap_or_else(|err| println!("Failed to parse attribute: {}", err));
                params.push((path, value, tensor_name, vec_len, inner_type));
            }
        }
        let param_count = params.len();
        for (path, value, tensor_name, vec_len, inner_type) in params {
            if let Some(vec_len) = vec_len {
                let inner_type = inner_type.expect("inner_type is required for vec");
                if let Some(path) = path {
                    from_construct.push(quote! {
                        #path => {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                } else {
                    value_construct.push(quote! {
                        {
                            let mut vec = vec![];
                            for i in 0..#vec_len {
                                vec.push(<#inner_type as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, i)));
                            }
                            vec
                        }
                    });
                }
            } else {
                match (path, value, tensor_name) {
                    (None, None, Some(tensor_name)) => {
                        value_construct.push(quote! {
                            <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name)
                        });
                    }
                    (None, Some(value), None) => {
                        if param_count > 1 {
                            panic!("value without path means generic assignment, there can only be one value without path");
                        }
                        value_construct.push(quote! {
                            #value
                        });
                    }
                    (Some(path), None, Some(tensor_name)) => {
                        from_construct.push(quote! {
                            #path => <#ty as FromSafeTensors>::from_safe_tensors(data, #tensor_name),
                        });
                    }
                    (Some(path), Some(value), None) => {
                        from_construct.push(quote! {
                            #path => #value,
                        });
                    }
                    (None, Some(_), Some(_)) | (Some(_), Some(_), Some(_)) => {
                        panic!("value and tensor_name cannot be used together");
                    }
                    (Some(_), None, None) | (None, None, None) => {
                        panic!("path and value are not present");
                    }
                }
            }
        }
        if !value_construct.is_empty() {
            construct_fields.push(quote! {
                #name: #(#value_construct)*
            });
        } else if !from_construct.is_empty() {
            construct_fields.push(quote! {
                #name: match path {
                    #(#from_construct)*
                    _ => panic!("unknown path for field {} in struct {}: `path: {}`", stringify!(#name), stringify!(#struct_name), path),
                }
            });
        } else {
            construct_fields.push(quote! {
                #name: <#ty as FromSafeTensors>::from_safe_tensors(data, &format!("{}.{}", path, stringify!(#name)))
            });
        }
    }
    let expanded = quote! {
        impl #impl_generics FromSafeTensors for #struct_name #ty_generics #where_clause {
            fn from_safe_tensors(data: &SafeTensors, path: &str) -> Self {
                Self {
                    #(#construct_fields),*
                }
            }
        }
    };
    expanded.into()
}