1use proc_macro::TokenStream;
7use quote::{ToTokens, format_ident, quote};
8use syn::{
9 Attribute, FnArg, GenericParam, Ident, PatType, Signature, Token, Type, TypeParamBound,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote, token,
12};
13
14#[derive(Clone)]
20struct LightFn {
21 attrs: Vec<Attribute>,
22 vis: syn::Visibility,
23 sig: Signature,
24 brace_token: token::Brace,
25 body: proc_macro2::TokenStream,
26}
27
28impl Parse for LightFn {
29 fn parse(input: ParseStream) -> syn::Result<Self> {
30 let attrs = input.call(Attribute::parse_outer)?;
31 let vis: syn::Visibility = input.parse()?;
32 let sig: Signature = input.parse()?;
33 let content;
34 let brace_token = syn::braced!(content in input);
35 let body: proc_macro2::TokenStream = content.parse()?;
36 Ok(LightFn {
37 attrs,
38 vis,
39 sig,
40 brace_token,
41 body,
42 })
43 }
44}
45
46impl ToTokens for LightFn {
47 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
48 for attr in &self.attrs {
49 attr.to_tokens(tokens);
50 }
51 self.vis.to_tokens(tokens);
52 self.sig.to_tokens(tokens);
53 self.brace_token.surround(tokens, |tokens| {
54 self.body.to_tokens(tokens);
55 });
56 }
57}
58
59fn replace_self_in_tokens(
64 tokens: proc_macro2::TokenStream,
65 replacement: &Type,
66) -> proc_macro2::TokenStream {
67 let mut result = proc_macro2::TokenStream::new();
68 for tt in tokens {
69 match tt {
70 proc_macro2::TokenTree::Ident(ref ident) if ident == "Self" => {
71 result.extend(replacement.to_token_stream());
72 }
73 proc_macro2::TokenTree::Group(group) => {
74 let new_stream = replace_self_in_tokens(group.stream(), replacement);
75 let mut new_group = proc_macro2::Group::new(group.delimiter(), new_stream);
76 new_group.set_span(group.span());
77 result.extend(std::iter::once(proc_macro2::TokenTree::Group(new_group)));
78 }
79 other => {
80 result.extend(std::iter::once(other));
81 }
82 }
83 }
84 result
85}
86
87#[derive(Default)]
89struct ArcaneArgs {
90 inline_always: bool,
93 self_type: Option<Type>,
97 stub: bool,
100 nested: bool,
104 import_intrinsics: bool,
106 import_magetypes: bool,
109}
110
111impl Parse for ArcaneArgs {
112 fn parse(input: ParseStream) -> syn::Result<Self> {
113 let mut args = ArcaneArgs::default();
114
115 while !input.is_empty() {
116 let ident: Ident = input.parse()?;
117 match ident.to_string().as_str() {
118 "inline_always" => args.inline_always = true,
119 "stub" => args.stub = true,
120 "nested" => args.nested = true,
121 "import_intrinsics" => args.import_intrinsics = true,
122 "import_magetypes" => args.import_magetypes = true,
123 "_self" => {
124 let _: Token![=] = input.parse()?;
125 args.self_type = Some(input.parse()?);
126 }
127 other => {
128 return Err(syn::Error::new(
129 ident.span(),
130 format!("unknown arcane argument: `{}`", other),
131 ));
132 }
133 }
134 if input.peek(Token![,]) {
136 let _: Token![,] = input.parse()?;
137 }
138 }
139
140 if args.self_type.is_some() {
142 args.nested = true;
143 }
144
145 Ok(args)
146 }
147}
148
149mod generated;
152use generated::{
153 token_to_arch, token_to_features, token_to_magetypes_namespace, trait_to_arch,
154 trait_to_features, trait_to_magetypes_namespace,
155};
156
/// How a parameter's type refers to a capability token.
enum TokenTypeInfo {
    /// `impl A + B`: the last path segment of every trait bound.
    ImplTrait(Vec<String>),
    /// A named type whose last path segment is a known concrete token
    /// (e.g. `X64V3Token`).
    Concrete(String),
    /// Any other named type. It may be a generic parameter whose bounds are
    /// looked up separately.
    Generic(String),
}
166
167fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
169 match ty {
170 Type::Path(type_path) => {
171 type_path.path.segments.last().map(|seg| {
173 let name = seg.ident.to_string();
174 if token_to_features(&name).is_some() {
176 TokenTypeInfo::Concrete(name)
177 } else {
178 TokenTypeInfo::Generic(name)
180 }
181 })
182 }
183 Type::Reference(type_ref) => {
184 extract_token_type_info(&type_ref.elem)
186 }
187 Type::ImplTrait(impl_trait) => {
188 let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
190 if traits.is_empty() {
191 None
192 } else {
193 Some(TokenTypeInfo::ImplTrait(traits))
194 }
195 }
196 _ => None,
197 }
198}
199
200fn extract_trait_names_from_bounds(
202 bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
203) -> Vec<String> {
204 bounds
205 .iter()
206 .filter_map(|bound| {
207 if let TypeParamBound::Trait(trait_bound) = bound {
208 trait_bound
209 .path
210 .segments
211 .last()
212 .map(|seg| seg.ident.to_string())
213 } else {
214 None
215 }
216 })
217 .collect()
218}
219
220fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
222 for param in &sig.generics.params {
224 if let GenericParam::Type(type_param) = param
225 && type_param.ident == type_name
226 {
227 let traits = extract_trait_names_from_bounds(&type_param.bounds);
228 if !traits.is_empty() {
229 return Some(traits);
230 }
231 }
232 }
233
234 if let Some(where_clause) = &sig.generics.where_clause {
236 for predicate in &where_clause.predicates {
237 if let syn::WherePredicate::Type(pred_type) = predicate
238 && let Type::Path(type_path) = &pred_type.bounded_ty
239 && let Some(seg) = type_path.path.segments.last()
240 && seg.ident == type_name
241 {
242 let traits = extract_trait_names_from_bounds(&pred_type.bounds);
243 if !traits.is_empty() {
244 return Some(traits);
245 }
246 }
247 }
248 }
249
250 None
251}
252
253fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
255 let mut all_features = Vec::new();
256
257 for trait_name in trait_names {
258 if let Some(features) = trait_to_features(trait_name) {
259 for &feature in features {
260 if !all_features.contains(&feature) {
261 all_features.push(feature);
262 }
263 }
264 }
265 }
266
267 if all_features.is_empty() {
268 None
269 } else {
270 Some(all_features)
271 }
272}
273
/// Token traits that carry no CPU features on their own.
///
/// They cannot drive `#[target_feature]` generation, so they get a dedicated
/// diagnostic when used as the only bound on a token parameter.
const FEATURELESS_TRAIT_NAMES: &[&str] = &[
    "SimdToken",
    "IntoConcreteToken",
];
278
279fn find_featureless_trait(trait_names: &[String]) -> Option<&'static str> {
282 for name in trait_names {
283 for &featureless in FEATURELESS_TRAIT_NAMES {
284 if name == featureless {
285 return Some(featureless);
286 }
287 }
288 }
289 None
290}
291
292fn diagnose_featureless_token(sig: &Signature) -> Option<&'static str> {
295 for arg in &sig.inputs {
296 if let FnArg::Typed(PatType { ty, .. }) = arg
297 && let Some(info) = extract_token_type_info(ty)
298 {
299 match &info {
300 TokenTypeInfo::ImplTrait(names) => {
301 if let Some(name) = find_featureless_trait(names) {
302 return Some(name);
303 }
304 }
305 TokenTypeInfo::Generic(type_name) => {
306 let as_vec = vec![type_name.clone()];
309 if let Some(name) = find_featureless_trait(&as_vec) {
310 return Some(name);
311 }
312 if let Some(bounds) = find_generic_bounds(sig, type_name)
314 && let Some(name) = find_featureless_trait(&bounds)
315 {
316 return Some(name);
317 }
318 }
319 TokenTypeInfo::Concrete(_) => {}
320 }
321 }
322 }
323 None
324}
325
326struct TokenParamInfo {
328 ident: Ident,
330 features: Vec<&'static str>,
332 target_arch: Option<&'static str>,
334 token_type_name: Option<String>,
336 magetypes_namespace: Option<&'static str>,
338}
339
340fn traits_to_magetypes_namespace(trait_names: &[String]) -> Option<&'static str> {
343 for name in trait_names {
344 if let Some(ns) = trait_to_magetypes_namespace(name) {
345 return Some(ns);
346 }
347 }
348 None
349}
350
351fn traits_to_arch(trait_names: &[String]) -> Option<&'static str> {
353 for name in trait_names {
354 if let Some(arch) = trait_to_arch(name) {
355 return Some(arch);
356 }
357 }
358 None
359}
360
361fn find_token_param(sig: &Signature) -> Option<TokenParamInfo> {
363 for arg in &sig.inputs {
364 match arg {
365 FnArg::Receiver(_) => {
366 continue;
372 }
373 FnArg::Typed(PatType { pat, ty, .. }) => {
374 if let Some(info) = extract_token_type_info(ty) {
375 let (features, arch, token_name, mage_ns) = match info {
376 TokenTypeInfo::Concrete(ref name) => {
377 let features = token_to_features(name).map(|f| f.to_vec());
378 let arch = token_to_arch(name);
379 let ns = token_to_magetypes_namespace(name);
380 (features, arch, Some(name.clone()), ns)
381 }
382 TokenTypeInfo::ImplTrait(ref trait_names) => {
383 let ns = traits_to_magetypes_namespace(trait_names);
384 let arch = traits_to_arch(trait_names);
385 (traits_to_features(trait_names), arch, None, ns)
386 }
387 TokenTypeInfo::Generic(type_name) => {
388 let bounds = find_generic_bounds(sig, &type_name);
390 let features = bounds.as_ref().and_then(|t| traits_to_features(t));
391 let ns = bounds
392 .as_ref()
393 .and_then(|t| traits_to_magetypes_namespace(t));
394 let arch = bounds.as_ref().and_then(|t| traits_to_arch(t));
395 (features, arch, None, ns)
396 }
397 };
398
399 if let Some(features) = features {
400 let ident = match pat.as_ref() {
402 syn::Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
403 syn::Pat::Wild(w) => {
404 Some(Ident::new("__archmage_token", w.underscore_token.span))
405 }
406 _ => None,
407 };
408 if let Some(ident) = ident {
409 return Some(TokenParamInfo {
410 ident,
411 features,
412 target_arch: arch,
413 token_type_name: token_name,
414 magetypes_namespace: mage_ns,
415 });
416 }
417 }
418 }
419 }
420 }
421 }
422 None
423}
424
/// Shape of a method's `self` receiver.
///
/// Used by nested expansion to type the `_self` parameter of the inner function.
enum SelfReceiver {
    /// `&mut self` → `_self: &mut Type`.
    RefMut,
    /// `&self` → `_self: &Type`.
    Ref,
    /// `self` → `_self: Type`.
    Owned,
}
434
435fn generate_imports(
440 target_arch: Option<&str>,
441 magetypes_namespace: Option<&str>,
442 import_intrinsics: bool,
443 import_magetypes: bool,
444) -> proc_macro2::TokenStream {
445 let mut imports = proc_macro2::TokenStream::new();
446
447 if import_intrinsics && let Some(arch) = target_arch {
448 let arch_ident = format_ident!("{}", arch);
449 imports.extend(quote! {
450 #[allow(unused_imports)]
451 use archmage::intrinsics::#arch_ident::*;
452 });
453 }
455
456 if import_magetypes && let Some(ns) = magetypes_namespace {
457 let ns_ident = format_ident!("{}", ns);
458 imports.extend(quote! {
459 #[allow(unused_imports)]
460 use magetypes::simd::#ns_ident::*;
461 #[allow(unused_imports)]
462 use magetypes::simd::backends::*;
463 });
464 }
465
466 imports
467}
468
469fn arcane_impl(mut input_fn: LightFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
471 let has_self_receiver = input_fn
473 .sig
474 .inputs
475 .first()
476 .map(|arg| matches!(arg, FnArg::Receiver(_)))
477 .unwrap_or(false);
478
479 if has_self_receiver && args.nested && args.self_type.is_none() {
483 let msg = format!(
484 "{} with self receiver in nested mode requires `_self = Type` argument.\n\
485 Example: #[{}(nested, _self = MyType)]\n\
486 Use `_self` (not `self`) in the function body to refer to self.\n\
487 \n\
488 Alternatively, remove `nested` to use sibling expansion (default), \
489 which handles self/Self naturally.",
490 macro_name, macro_name
491 );
492 return syn::Error::new_spanned(&input_fn.sig, msg)
493 .to_compile_error()
494 .into();
495 }
496
497 let TokenParamInfo {
499 ident: _token_ident,
500 features,
501 target_arch,
502 token_type_name,
503 magetypes_namespace,
504 } = match find_token_param(&input_fn.sig) {
505 Some(result) => result,
506 None => {
507 if let Some(trait_name) = diagnose_featureless_token(&input_fn.sig) {
509 let msg = format!(
510 "`{trait_name}` cannot be used as a token bound in #[{macro_name}] \
511 because it doesn't specify any CPU features.\n\
512 \n\
513 #[{macro_name}] needs concrete features to generate #[target_feature]. \
514 Use a concrete token or a feature trait:\n\
515 \n\
516 Concrete tokens: X64V3Token, Desktop64, NeonToken, Arm64V2Token, ...\n\
517 Feature traits: impl HasX64V2, impl HasNeon, impl HasArm64V3, ..."
518 );
519 return syn::Error::new_spanned(&input_fn.sig, msg)
520 .to_compile_error()
521 .into();
522 }
523 let msg = format!(
524 "{} requires a token parameter. Supported forms:\n\
525 - Concrete: `token: X64V3Token`\n\
526 - impl Trait: `token: impl HasX64V2`\n\
527 - Generic: `fn foo<T: HasX64V2>(token: T, ...)`\n\
528 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasNeon, ...)`",
529 macro_name, macro_name
530 );
531 return syn::Error::new_spanned(&input_fn.sig, msg)
532 .to_compile_error()
533 .into();
534 }
535 };
536
537 let body_imports = generate_imports(
539 target_arch,
540 magetypes_namespace,
541 args.import_intrinsics,
542 args.import_magetypes,
543 );
544 if !body_imports.is_empty() {
545 let original_body = &input_fn.body;
546 input_fn.body = quote! {
547 #body_imports
548 #original_body
549 };
550 }
551
552 let target_feature_attrs: Vec<Attribute> = features
554 .iter()
555 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
556 .collect();
557
558 let mut wild_rename_counter = 0u32;
560 for arg in &mut input_fn.sig.inputs {
561 if let FnArg::Typed(pat_type) = arg
562 && matches!(pat_type.pat.as_ref(), syn::Pat::Wild(_))
563 {
564 let ident = format_ident!("__archmage_wild_{}", wild_rename_counter);
565 wild_rename_counter += 1;
566 *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
567 attrs: vec![],
568 by_ref: None,
569 mutability: None,
570 ident,
571 subpat: None,
572 });
573 }
574 }
575
576 let inline_attr: Attribute = if args.inline_always {
578 parse_quote!(#[inline(always)])
579 } else {
580 parse_quote!(#[inline])
581 };
582
583 if target_arch == Some("wasm32") {
587 return arcane_impl_wasm_safe(
588 input_fn,
589 &args,
590 token_type_name,
591 target_feature_attrs,
592 inline_attr,
593 );
594 }
595
596 if args.nested {
597 arcane_impl_nested(
598 input_fn,
599 &args,
600 target_arch,
601 token_type_name,
602 target_feature_attrs,
603 inline_attr,
604 )
605 } else {
606 arcane_impl_sibling(
607 input_fn,
608 &args,
609 target_arch,
610 token_type_name,
611 target_feature_attrs,
612 inline_attr,
613 )
614 }
615}
616
/// wasm32 expansion for `#[arcane]`: no wrapper or sibling is generated.
///
/// The user's function is emitted once under `#[cfg(target_arch = "wasm32")]`.
/// Its attribute order is: the `#[target_feature]` attributes, the inline
/// attribute, then the user's own attributes. Presumably no `unsafe` wrapper is
/// needed because target-feature functions are safe to call on wasm32 —
/// NOTE(review): confirm against the crate's design docs.
///
/// When `_self = Type` was given, the body is prefixed with `let _self = self;`.
/// This lets bodies written for nested mode (which use `_self`) compile
/// unchanged.
///
/// With `stub`, a `#[cfg(not(target_arch = "wasm32"))]` copy of the signature
/// is also emitted. It references the identifier parameters and panics via
/// `unreachable!`. Note that this stub does not carry the user's attributes.
fn arcane_impl_wasm_safe(
    input_fn: LightFn,
    args: &ArcaneArgs,
    token_type_name: Option<String>,
    target_feature_attrs: Vec<Attribute>,
    inline_attr: Attribute,
) -> TokenStream {
    let vis = &input_fn.vis;
    let sig = &input_fn.sig;
    let fn_name = &sig.ident;
    let attrs = &input_fn.attrs;

    // Used only in the stub's panic message.
    let token_type_str = token_type_name.as_deref().unwrap_or("UnknownToken");

    // The function keeps its real receiver here, so `_self` is just an alias.
    let body = if args.self_type.is_some() {
        let original_body = &input_fn.body;
        quote! {
            let _self = self;
            #original_body
        }
    } else {
        input_fn.body.clone()
    };

    // Codegen attributes first, then the user's attributes.
    let mut new_attrs = target_feature_attrs;
    new_attrs.push(inline_attr);
    for attr in attrs {
        new_attrs.push(attr.clone());
    }

    let stub = if args.stub {
        // Identifier parameters, referenced so the stub does not warn about
        // unused variables.
        let stub_args: Vec<proc_macro2::TokenStream> = sig
            .inputs
            .iter()
            .filter_map(|arg| match arg {
                FnArg::Typed(pat_type) => {
                    if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
                        let ident = &pat_ident.ident;
                        Some(quote!(#ident))
                    } else {
                        None
                    }
                }
                FnArg::Receiver(_) => None,
            })
            .collect();

        quote! {
            #[cfg(not(target_arch = "wasm32"))]
            #vis #sig {
                let _ = (#(#stub_args),*);
                unreachable!(
                    "BUG: {}() was called but requires {} (target_arch = \"wasm32\"). \
                     {}::summon() returns None on this architecture, so this function \
                     is unreachable in safe code. If you used forge_token_dangerously(), \
                     that is the bug.",
                    stringify!(#fn_name),
                    #token_type_str,
                    #token_type_str,
                )
            }
        }
    } else {
        quote! {}
    };

    let expanded = quote! {
        #[cfg(target_arch = "wasm32")]
        #(#new_attrs)*
        #vis #sig {
            #body
        }

        #stub
    };

    expanded.into()
}
708
/// Sibling expansion for `#[arcane]` (the default, non-nested mode).
///
/// Emits up to three items next to each other:
/// 1. A hidden sibling `__arcane_<name>` with the original parameters, output,
///    generics and body, plus the `#[target_feature]` attributes and the
///    inline attribute.
/// 2. A wrapper with the user's attributes and original signature. It calls
///    the sibling inside an `unsafe` block (as a method via `self.` when the
///    function has a receiver).
/// 3. With `stub` and a known arch, a `#[cfg(not(target_arch = ...))]` copy of
///    the wrapper that panics with `unreachable!`.
///
/// With a known `target_arch`, the sibling and wrapper are `#[cfg]`-gated to it.
/// With an unknown arch, neither is gated and no stub is emitted.
///
/// Only `syn::Pat::Ident` parameters are forwarded to the sibling. Other
/// patterns are expected to be normalized by the caller beforehand.
///
/// NOTE(review): the non-method call is the unqualified path `__arcane_<name>(...)`.
/// For an associated function without a receiver inside an `impl` block this
/// presumably fails to resolve (it would need `Self::`). Inside a trait `impl`,
/// the extra sibling item would not be allowed either — confirm; nested mode
/// with `_self` appears to be the intended alternative.
fn arcane_impl_sibling(
    input_fn: LightFn,
    args: &ArcaneArgs,
    target_arch: Option<&str>,
    token_type_name: Option<String>,
    target_feature_attrs: Vec<Attribute>,
    inline_attr: Attribute,
) -> TokenStream {
    let vis = &input_fn.vis;
    let sig = &input_fn.sig;
    let fn_name = &sig.ident;
    let generics = &sig.generics;
    let where_clause = &generics.where_clause;
    let inputs = &sig.inputs;
    let output = &sig.output;
    let body = &input_fn.body;
    let attrs = &input_fn.attrs;

    let sibling_name = format_ident!("__arcane_{}", fn_name);

    let has_self_receiver = inputs
        .first()
        .map(|arg| matches!(arg, FnArg::Receiver(_)))
        .unwrap_or(false);

    // The sibling takes exactly the same parameter list (receiver included).
    let sibling_sig_inputs = inputs;

    let sibling_call = if has_self_receiver {
        // Method call: forward every parameter after the receiver.
        let other_args: Vec<proc_macro2::TokenStream> = inputs
            .iter()
            .skip(1)
            .filter_map(|arg| {
                if let FnArg::Typed(pat_type) = arg
                    && let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref()
                {
                    let ident = &pat_ident.ident;
                    Some(quote!(#ident))
                } else {
                    None
                }
            })
            .collect();
        quote! { self.#sibling_name(#(#other_args),*) }
    } else {
        // Plain function call: forward every identifier parameter.
        let all_args: Vec<proc_macro2::TokenStream> = inputs
            .iter()
            .filter_map(|arg| {
                if let FnArg::Typed(pat_type) = arg
                    && let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref()
                {
                    let ident = &pat_ident.ident;
                    Some(quote!(#ident))
                } else {
                    None
                }
            })
            .collect();
        quote! { #sibling_name(#(#all_args),*) }
    };

    // Parameters referenced by the stub so it does not warn about unused variables.
    let stub_args: Vec<proc_macro2::TokenStream> = inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Typed(pat_type) => {
                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
                    let ident = &pat_ident.ident;
                    Some(quote!(#ident))
                } else {
                    None
                }
            }
            FnArg::Receiver(_) => None,
        })
        .collect();

    let token_type_str = token_type_name.as_deref().unwrap_or("UnknownToken");

    let expanded = if let Some(arch) = target_arch {
        let sibling_fn = quote! {
            #[cfg(target_arch = #arch)]
            #[doc(hidden)]
            #(#target_feature_attrs)*
            #inline_attr
            fn #sibling_name #generics (#sibling_sig_inputs) #output #where_clause {
                #body
            }
        };

        let wrapper_fn = quote! {
            #[cfg(target_arch = #arch)]
            #(#attrs)*
            #vis #sig {
                // SAFETY: the sibling differs only by #[target_feature], so the
                // call is sound only if the running CPU has those features. That
                // presumably rests on the token parameter: per the stub message,
                // tokens come from summon(), which returns None where the
                // features are unavailable (forge_token_dangerously() bypasses this).
                unsafe { #sibling_call }
            }
        };

        let stub = if args.stub {
            quote! {
                #[cfg(not(target_arch = #arch))]
                #(#attrs)*
                #vis #sig {
                    let _ = (#(#stub_args),*);
                    unreachable!(
                        "BUG: {}() was called but requires {} (target_arch = \"{}\"). \
                         {}::summon() returns None on this architecture, so this function \
                         is unreachable in safe code. If you used forge_token_dangerously(), \
                         that is the bug.",
                        stringify!(#fn_name),
                        #token_type_str,
                        #arch,
                        #token_type_str,
                    )
                }
            }
        } else {
            quote! {}
        };

        quote! {
            #sibling_fn
            #wrapper_fn
            #stub
        }
    } else {
        // Architecture unknown: same pair of functions, without cfg gating.
        let sibling_fn = quote! {
            #[doc(hidden)]
            #(#target_feature_attrs)*
            #inline_attr
            fn #sibling_name #generics (#sibling_sig_inputs) #output #where_clause {
                #body
            }
        };

        let wrapper_fn = quote! {
            #(#attrs)*
            #vis #sig {
                // SAFETY: see the gated wrapper above; soundness relies on the
                // token proving the required CPU features are present.
                unsafe { #sibling_call }
            }
        };

        quote! {
            #sibling_fn
            #wrapper_fn
        }
    };

    expanded.into()
}
901
/// Nested expansion for `#[arcane(nested)]` (also forced by `_self = Type`).
///
/// Emits the user's function with its original attributes and signature. Its
/// body defines an inner free function `__simd_inner_<name>` carrying the
/// `#[target_feature]` attributes and the inline attribute, then calls that
/// inner function inside an `unsafe` block.
///
/// When `_self = Type` is given:
/// - a receiver becomes an `_self` parameter of `Type`, `&Type` or `&mut Type`;
/// - `Self` is replaced by `Type` in parameters, output, `where` clause and body,
///   since the outer `Self` is not usable from an inner item;
/// - the call passes `self` for the receiver.
///
/// Without `_self`, a receiver is rejected earlier by `arcane_impl`.
///
/// With `stub` and a known arch, a `#[cfg(not(target_arch = ...))]` copy of
/// the function is also emitted that panics with `unreachable!`.
///
/// NOTE(review): without `_self`, any `Self` in the signature or body of an
/// associated function is copied unchanged into the inner fn, which presumably
/// fails to compile — confirm.
fn arcane_impl_nested(
    input_fn: LightFn,
    args: &ArcaneArgs,
    target_arch: Option<&str>,
    token_type_name: Option<String>,
    target_feature_attrs: Vec<Attribute>,
    inline_attr: Attribute,
) -> TokenStream {
    let vis = &input_fn.vis;
    let sig = &input_fn.sig;
    let fn_name = &sig.ident;
    let generics = &sig.generics;
    let where_clause = &generics.where_clause;
    let inputs = &sig.inputs;
    let output = &sig.output;
    let body = &input_fn.body;
    let attrs = &input_fn.attrs;

    // Classify the receiver (if any) so `_self` gets the matching reference type.
    let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
        FnArg::Receiver(receiver) => {
            if receiver.reference.is_none() {
                Some(SelfReceiver::Owned)
            } else if receiver.mutability.is_some() {
                Some(SelfReceiver::RefMut)
            } else {
                Some(SelfReceiver::Ref)
            }
        }
        _ => None,
    });

    // Parameter list of the inner fn.
    let inner_params: Vec<proc_macro2::TokenStream> = inputs
        .iter()
        .map(|arg| match arg {
            FnArg::Receiver(_) => {
                // Both unwraps hold: arcane_impl rejects a receiver without
                // `_self`, and a Receiver here implies self_receiver_kind is Some.
                let self_ty = args.self_type.as_ref().unwrap();
                match self_receiver_kind.as_ref().unwrap() {
                    SelfReceiver::Owned => quote!(_self: #self_ty),
                    SelfReceiver::Ref => quote!(_self: &#self_ty),
                    SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
                }
            }
            FnArg::Typed(pat_type) => {
                if let Some(ref self_ty) = args.self_type {
                    replace_self_in_tokens(quote!(#pat_type), self_ty)
                } else {
                    quote!(#pat_type)
                }
            }
        })
        .collect();

    // Arguments forwarded to the inner fn: `self` for the receiver, plus the
    // identifier parameters. Other patterns are expected to be normalized
    // by the caller beforehand.
    let inner_args: Vec<proc_macro2::TokenStream> = inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Typed(pat_type) => {
                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
                    let ident = &pat_ident.ident;
                    Some(quote!(#ident))
                } else {
                    None
                }
            }
            FnArg::Receiver(_) => Some(quote!(self)),
        })
        .collect();

    let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);

    // Output, body and where clause of the inner fn, with `Self` substituted
    // when `_self = Type` was given.
    let (inner_output, inner_body, inner_where_clause): (
        proc_macro2::TokenStream,
        proc_macro2::TokenStream,
        proc_macro2::TokenStream,
    ) = if let Some(ref self_ty) = args.self_type {
        let transformed_output = replace_self_in_tokens(output.to_token_stream(), self_ty);
        let transformed_body = replace_self_in_tokens(body.clone(), self_ty);
        let transformed_where = where_clause
            .as_ref()
            .map(|wc| replace_self_in_tokens(wc.to_token_stream(), self_ty))
            .unwrap_or_default();
        (transformed_output, transformed_body, transformed_where)
    } else {
        (
            output.to_token_stream(),
            body.clone(),
            where_clause
                .as_ref()
                .map(|wc| wc.to_token_stream())
                .unwrap_or_default(),
        )
    };

    let token_type_str = token_type_name.as_deref().unwrap_or("UnknownToken");
    let expanded = if let Some(arch) = target_arch {
        let stub = if args.stub {
            quote! {
                #[cfg(not(target_arch = #arch))]
                #(#attrs)*
                #vis #sig {
                    let _ = (#(#inner_args),*);
                    unreachable!(
                        "BUG: {}() was called but requires {} (target_arch = \"{}\"). \
                         {}::summon() returns None on this architecture, so this function \
                         is unreachable in safe code. If you used forge_token_dangerously(), \
                         that is the bug.",
                        stringify!(#fn_name),
                        #token_type_str,
                        #arch,
                        #token_type_str,
                    )
                }
            }
        } else {
            quote! {}
        };

        quote! {
            #[cfg(target_arch = #arch)]
            #(#attrs)*
            #vis #sig {
                #(#target_feature_attrs)*
                #inline_attr
                fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #inner_where_clause {
                    #inner_body
                }

                // SAFETY: the inner fn differs only by #[target_feature]; soundness
                // presumably rests on the token parameter proving the CPU supports
                // those features (see the stub message about summon() /
                // forge_token_dangerously()).
                unsafe { #inner_fn_name(#(#inner_args),*) }
            }

            #stub
        }
    } else {
        // Architecture unknown: same shape, without cfg gating or a stub.
        quote! {
            #(#attrs)*
            #vis #sig {
                #(#target_feature_attrs)*
                #inline_attr
                fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #inner_where_clause {
                    #inner_body
                }

                // SAFETY: as above; relies on the token proving feature support.
                unsafe { #inner_fn_name(#(#inner_args),*) }
            }
        }
    };

    expanded.into()
}
1067
1068#[proc_macro_attribute]
1228pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
1229 let args = parse_macro_input!(attr as ArcaneArgs);
1230 let input_fn = parse_macro_input!(item as LightFn);
1231 arcane_impl(input_fn, "arcane", args)
1232}
1233
1234#[proc_macro_attribute]
1238#[doc(hidden)]
1239pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
1240 let args = parse_macro_input!(attr as ArcaneArgs);
1241 let input_fn = parse_macro_input!(item as LightFn);
1242 arcane_impl(input_fn, "simd_fn", args)
1243}
1244
1245#[proc_macro_attribute]
1258pub fn token_target_features_boundary(attr: TokenStream, item: TokenStream) -> TokenStream {
1259 let args = parse_macro_input!(attr as ArcaneArgs);
1260 let input_fn = parse_macro_input!(item as LightFn);
1261 arcane_impl(input_fn, "token_target_features_boundary", args)
1262}
1263
1264#[proc_macro_attribute]
1335pub fn rite(attr: TokenStream, item: TokenStream) -> TokenStream {
1336 let args = parse_macro_input!(attr as RiteArgs);
1337 let input_fn = parse_macro_input!(item as LightFn);
1338 rite_impl(input_fn, args)
1339}
1340
1341#[proc_macro_attribute]
1352pub fn token_target_features(attr: TokenStream, item: TokenStream) -> TokenStream {
1353 let args = parse_macro_input!(attr as RiteArgs);
1354 let input_fn = parse_macro_input!(item as LightFn);
1355 rite_impl(input_fn, args)
1356}
1357
/// Options accepted by `#[rite(...)]`. Every flag defaults to off.
#[derive(Default)]
struct RiteArgs {
    /// `import_magetypes`: add `use magetypes::simd::<ns>::*` (and backends) to the body.
    import_magetypes: bool,
    /// `import_intrinsics`: add `use archmage::intrinsics::<arch>::*` to the body.
    import_intrinsics: bool,
    /// `stub`: also emit an `unreachable!()` stub for other architectures.
    stub: bool,
}
1370
1371impl Parse for RiteArgs {
1372 fn parse(input: ParseStream) -> syn::Result<Self> {
1373 let mut args = RiteArgs::default();
1374
1375 while !input.is_empty() {
1376 let ident: Ident = input.parse()?;
1377 match ident.to_string().as_str() {
1378 "stub" => args.stub = true,
1379 "import_intrinsics" => args.import_intrinsics = true,
1380 "import_magetypes" => args.import_magetypes = true,
1381 other => {
1382 return Err(syn::Error::new(
1383 ident.span(),
1384 format!(
1385 "unknown rite argument: `{}`. Supported: `stub`, \
1386 `import_intrinsics`, `import_magetypes`.",
1387 other
1388 ),
1389 ));
1390 }
1391 }
1392 if input.peek(Token![,]) {
1393 let _: Token![,] = input.parse()?;
1394 }
1395 }
1396
1397 Ok(args)
1398 }
1399}
1400
1401fn rite_impl(mut input_fn: LightFn, args: RiteArgs) -> TokenStream {
1403 let TokenParamInfo {
1405 features,
1406 target_arch,
1407 magetypes_namespace,
1408 ..
1409 } = match find_token_param(&input_fn.sig) {
1410 Some(result) => result,
1411 None => {
1412 if let Some(trait_name) = diagnose_featureless_token(&input_fn.sig) {
1414 let msg = format!(
1415 "`{trait_name}` cannot be used as a token bound in #[rite] \
1416 because it doesn't specify any CPU features.\n\
1417 \n\
1418 #[rite] needs concrete features to generate #[target_feature]. \
1419 Use a concrete token or a feature trait:\n\
1420 \n\
1421 Concrete tokens: X64V3Token, Desktop64, NeonToken, Arm64V2Token, ...\n\
1422 Feature traits: impl HasX64V2, impl HasNeon, impl HasArm64V3, ..."
1423 );
1424 return syn::Error::new_spanned(&input_fn.sig, msg)
1425 .to_compile_error()
1426 .into();
1427 }
1428 let msg = "rite requires a token parameter. Supported forms:\n\
1429 - Concrete: `token: X64V3Token`\n\
1430 - impl Trait: `token: impl HasX64V2`\n\
1431 - Generic: `fn foo<T: HasX64V2>(token: T, ...)`";
1432 return syn::Error::new_spanned(&input_fn.sig, msg)
1433 .to_compile_error()
1434 .into();
1435 }
1436 };
1437
1438 let target_feature_attrs: Vec<Attribute> = features
1440 .iter()
1441 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
1442 .collect();
1443
1444 let inline_attr: Attribute = parse_quote!(#[inline]);
1446
1447 let mut new_attrs = target_feature_attrs;
1449 new_attrs.push(inline_attr);
1450 new_attrs.append(&mut input_fn.attrs);
1451 input_fn.attrs = new_attrs;
1452
1453 let body_imports = generate_imports(
1455 target_arch,
1456 magetypes_namespace,
1457 args.import_intrinsics,
1458 args.import_magetypes,
1459 );
1460 if !body_imports.is_empty() {
1461 let original_body = &input_fn.body;
1462 input_fn.body = quote! {
1463 #body_imports
1464 #original_body
1465 };
1466 }
1467
1468 if let Some(arch) = target_arch {
1470 let vis = &input_fn.vis;
1471 let sig = &input_fn.sig;
1472 let attrs = &input_fn.attrs;
1473 let body = &input_fn.body;
1474
1475 let stub = if args.stub {
1476 quote! {
1477 #[cfg(not(target_arch = #arch))]
1478 #vis #sig {
1479 unreachable!(concat!(
1480 "This function requires ",
1481 #arch,
1482 " architecture"
1483 ))
1484 }
1485 }
1486 } else {
1487 quote! {}
1488 };
1489
1490 quote! {
1491 #[cfg(target_arch = #arch)]
1492 #(#attrs)*
1493 #vis #sig {
1494 #body
1495 }
1496
1497 #stub
1498 }
1499 .into()
1500 } else {
1501 quote!(#input_fn).into()
1503 }
1504}
1505
1506#[proc_macro_attribute]
1565pub fn magetypes(attr: TokenStream, item: TokenStream) -> TokenStream {
1566 let input_fn = parse_macro_input!(item as LightFn);
1567
1568 let tier_names: Vec<String> = if attr.is_empty() {
1570 DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect()
1571 } else {
1572 let parser = |input: ParseStream| input.parse_terminated(Ident::parse, Token![,]);
1573 let idents = match syn::parse::Parser::parse(parser, attr) {
1574 Ok(p) => p,
1575 Err(e) => return e.to_compile_error().into(),
1576 };
1577 idents.iter().map(|i| i.to_string()).collect()
1578 };
1579
1580 let tiers = match resolve_tiers(&tier_names, input_fn.sig.ident.span()) {
1581 Ok(t) => t,
1582 Err(e) => return e.to_compile_error().into(),
1583 };
1584
1585 magetypes_impl(input_fn, &tiers)
1586}
1587
1588fn magetypes_impl(mut input_fn: LightFn, tiers: &[&TierDescriptor]) -> TokenStream {
1589 input_fn
1592 .attrs
1593 .retain(|attr| !attr.path().is_ident("arcane") && !attr.path().is_ident("rite"));
1594
1595 let fn_name = &input_fn.sig.ident;
1596 let fn_attrs = &input_fn.attrs;
1597
1598 let fn_str = input_fn.to_token_stream().to_string();
1600
1601 let mut variants = Vec::new();
1602
1603 for tier in tiers {
1604 let suffixed_name = format!("{}_{}", fn_name, tier.suffix);
1606
1607 let mut variant_str = fn_str.clone();
1609
1610 variant_str = variant_str.replacen(&fn_name.to_string(), &suffixed_name, 1);
1612
1613 variant_str = variant_str.replace("Token", tier.token_path);
1615
1616 let variant_tokens: proc_macro2::TokenStream = match variant_str.parse() {
1618 Ok(t) => t,
1619 Err(e) => {
1620 return syn::Error::new_spanned(
1621 &input_fn,
1622 format!(
1623 "Failed to parse generated variant `{}`: {}",
1624 suffixed_name, e
1625 ),
1626 )
1627 .to_compile_error()
1628 .into();
1629 }
1630 };
1631
1632 let cfg_guard = match (tier.target_arch, tier.cargo_feature) {
1634 (Some(arch), Some(feature)) => {
1635 quote! { #[cfg(all(target_arch = #arch, feature = #feature))] }
1636 }
1637 (Some(arch), None) => {
1638 quote! { #[cfg(target_arch = #arch)] }
1639 }
1640 (None, Some(feature)) => {
1641 quote! { #[cfg(feature = #feature)] }
1642 }
1643 (None, None) => {
1644 quote! {} }
1646 };
1647
1648 variants.push(if tier.name != "scalar" {
1649 quote! {
1651 #cfg_guard
1652 #[archmage::arcane]
1653 #variant_tokens
1654 }
1655 } else {
1656 quote! {
1657 #cfg_guard
1658 #variant_tokens
1659 }
1660 });
1661 }
1662
1663 let filtered_attrs: Vec<_> = fn_attrs
1665 .iter()
1666 .filter(|a| !a.path().is_ident("magetypes"))
1667 .collect();
1668
1669 let output = quote! {
1670 #(#filtered_attrs)*
1671 #(#variants)*
1672 };
1673
1674 output.into()
1675}
1676
/// Static description of one `#[magetypes]` tier.
struct TierDescriptor {
    /// Tier name as written in `#[magetypes(...)]`; `"scalar"` skips `#[arcane]`.
    name: &'static str,
    /// Suffix appended to the function name (`<fn>_<suffix>`).
    suffix: &'static str,
    /// Path substituted for the `Token` placeholder (e.g. `archmage::X64V3Token`).
    token_path: &'static str,
    /// Cargo feature the variant additionally requires, if any.
    cargo_feature: Option<&'static str>,
    /// `target_arch` the variant is compiled for, if any.
    target_arch: Option<&'static str>,
    /// Name of the matching `as_*` token method. Not read in this part of the
    /// file; presumably used by dispatch code elsewhere — confirm.
    as_method: &'static str,
    /// Ordering weight; in `ALL_TIERS` more capable tiers of an arch get higher
    /// values. Its consumer is not visible here.
    priority: u32,
}
1702
1703const ALL_TIERS: &[TierDescriptor] = &[
1705 TierDescriptor {
1707 name: "v4x",
1708 suffix: "v4x",
1709 token_path: "archmage::X64V4xToken",
1710 as_method: "as_x64v4x",
1711 target_arch: Some("x86_64"),
1712 cargo_feature: Some("avx512"),
1713 priority: 50,
1714 },
1715 TierDescriptor {
1716 name: "v4",
1717 suffix: "v4",
1718 token_path: "archmage::X64V4Token",
1719 as_method: "as_x64v4",
1720 target_arch: Some("x86_64"),
1721 cargo_feature: Some("avx512"),
1722 priority: 40,
1723 },
1724 TierDescriptor {
1725 name: "v3_crypto",
1726 suffix: "v3_crypto",
1727 token_path: "archmage::X64V3CryptoToken",
1728 as_method: "as_x64v3_crypto",
1729 target_arch: Some("x86_64"),
1730 cargo_feature: None,
1731 priority: 35,
1732 },
1733 TierDescriptor {
1734 name: "v3",
1735 suffix: "v3",
1736 token_path: "archmage::X64V3Token",
1737 as_method: "as_x64v3",
1738 target_arch: Some("x86_64"),
1739 cargo_feature: None,
1740 priority: 30,
1741 },
1742 TierDescriptor {
1743 name: "x64_crypto",
1744 suffix: "x64_crypto",
1745 token_path: "archmage::X64CryptoToken",
1746 as_method: "as_x64_crypto",
1747 target_arch: Some("x86_64"),
1748 cargo_feature: None,
1749 priority: 25,
1750 },
1751 TierDescriptor {
1752 name: "v2",
1753 suffix: "v2",
1754 token_path: "archmage::X64V2Token",
1755 as_method: "as_x64v2",
1756 target_arch: Some("x86_64"),
1757 cargo_feature: None,
1758 priority: 20,
1759 },
1760 TierDescriptor {
1761 name: "v1",
1762 suffix: "v1",
1763 token_path: "archmage::X64V1Token",
1764 as_method: "as_x64v1",
1765 target_arch: Some("x86_64"),
1766 cargo_feature: None,
1767 priority: 10,
1768 },
1769 TierDescriptor {
1771 name: "arm_v3",
1772 suffix: "arm_v3",
1773 token_path: "archmage::Arm64V3Token",
1774 as_method: "as_arm_v3",
1775 target_arch: Some("aarch64"),
1776 cargo_feature: None,
1777 priority: 50,
1778 },
1779 TierDescriptor {
1780 name: "arm_v2",
1781 suffix: "arm_v2",
1782 token_path: "archmage::Arm64V2Token",
1783 as_method: "as_arm_v2",
1784 target_arch: Some("aarch64"),
1785 cargo_feature: None,
1786 priority: 40,
1787 },
1788 TierDescriptor {
1789 name: "neon_aes",
1790 suffix: "neon_aes",
1791 token_path: "archmage::NeonAesToken",
1792 as_method: "as_neon_aes",
1793 target_arch: Some("aarch64"),
1794 cargo_feature: None,
1795 priority: 30,
1796 },
1797 TierDescriptor {
1798 name: "neon_sha3",
1799 suffix: "neon_sha3",
1800 token_path: "archmage::NeonSha3Token",
1801 as_method: "as_neon_sha3",
1802 target_arch: Some("aarch64"),
1803 cargo_feature: None,
1804 priority: 30,
1805 },
1806 TierDescriptor {
1807 name: "neon_crc",
1808 suffix: "neon_crc",
1809 token_path: "archmage::NeonCrcToken",
1810 as_method: "as_neon_crc",
1811 target_arch: Some("aarch64"),
1812 cargo_feature: None,
1813 priority: 30,
1814 },
1815 TierDescriptor {
1816 name: "neon",
1817 suffix: "neon",
1818 token_path: "archmage::NeonToken",
1819 as_method: "as_neon",
1820 target_arch: Some("aarch64"),
1821 cargo_feature: None,
1822 priority: 20,
1823 },
1824 TierDescriptor {
1826 name: "wasm128_relaxed",
1827 suffix: "wasm128_relaxed",
1828 token_path: "archmage::Wasm128RelaxedToken",
1829 as_method: "as_wasm128_relaxed",
1830 target_arch: Some("wasm32"),
1831 cargo_feature: None,
1832 priority: 21,
1833 },
1834 TierDescriptor {
1835 name: "wasm128",
1836 suffix: "wasm128",
1837 token_path: "archmage::Wasm128Token",
1838 as_method: "as_wasm128",
1839 target_arch: Some("wasm32"),
1840 cargo_feature: None,
1841 priority: 20,
1842 },
1843 TierDescriptor {
1845 name: "scalar",
1846 suffix: "scalar",
1847 token_path: "archmage::ScalarToken",
1848 as_method: "as_scalar",
1849 target_arch: None,
1850 cargo_feature: None,
1851 priority: 0,
1852 },
1853];
1854
/// Tiers used when `incant!` / `#[autoversion]` are given no explicit tier list.
const DEFAULT_TIER_NAMES: &[&str] = &[
    "v4",      // x86_64, behind the `avx512` cargo feature
    "v3",      // x86_64
    "neon",    // aarch64
    "wasm128", // wasm32
    "scalar",  // portable fallback
];
1857
1858fn find_tier(name: &str) -> Option<&'static TierDescriptor> {
1860 ALL_TIERS.iter().find(|t| t.name == name)
1861}
1862
1863fn resolve_tiers(
1866 tier_names: &[String],
1867 error_span: proc_macro2::Span,
1868) -> syn::Result<Vec<&'static TierDescriptor>> {
1869 let mut tiers = Vec::new();
1870 for name in tier_names {
1871 match find_tier(name) {
1872 Some(tier) => tiers.push(tier),
1873 None => {
1874 let known: Vec<&str> = ALL_TIERS.iter().map(|t| t.name).collect();
1875 return Err(syn::Error::new(
1876 error_span,
1877 format!("unknown tier `{}`. Known tiers: {}", name, known.join(", ")),
1878 ));
1879 }
1880 }
1881 }
1882
1883 if !tiers.iter().any(|t| t.name == "scalar") {
1885 tiers.push(find_tier("scalar").unwrap());
1886 }
1887
1888 tiers.sort_by(|a, b| b.priority.cmp(&a.priority));
1890
1891 Ok(tiers)
1892}
1893
1894struct IncantInput {
1900 func_path: syn::Path,
1902 args: Vec<syn::Expr>,
1904 with_token: Option<syn::Expr>,
1906 tiers: Option<(Vec<String>, proc_macro2::Span)>,
1908}
1909
1910fn suffix_path(path: &syn::Path, suffix: &str) -> syn::Path {
1913 let mut suffixed = path.clone();
1914 if let Some(last) = suffixed.segments.last_mut() {
1915 last.ident = format_ident!("{}_{}", last.ident, suffix);
1916 }
1917 suffixed
1918}
1919
1920impl Parse for IncantInput {
1921 fn parse(input: ParseStream) -> syn::Result<Self> {
1922 let func_path: syn::Path = input.parse()?;
1924
1925 let content;
1927 syn::parenthesized!(content in input);
1928 let args = content
1929 .parse_terminated(syn::Expr::parse, Token![,])?
1930 .into_iter()
1931 .collect();
1932
1933 let with_token = if input.peek(Ident) {
1935 let kw: Ident = input.parse()?;
1936 if kw != "with" {
1937 return Err(syn::Error::new_spanned(kw, "expected `with` keyword"));
1938 }
1939 Some(input.parse()?)
1940 } else {
1941 None
1942 };
1943
1944 let tiers = if input.peek(Token![,]) {
1946 let _: Token![,] = input.parse()?;
1947 let bracket_content;
1948 let bracket = syn::bracketed!(bracket_content in input);
1949 let tier_idents = bracket_content.parse_terminated(Ident::parse, Token![,])?;
1950 let tier_names: Vec<String> = tier_idents.iter().map(|i| i.to_string()).collect();
1951 Some((tier_names, bracket.span.join()))
1952 } else {
1953 None
1954 };
1955
1956 Ok(IncantInput {
1957 func_path,
1958 args,
1959 with_token,
1960 tiers,
1961 })
1962 }
1963}
1964
1965#[proc_macro]
2034pub fn incant(input: TokenStream) -> TokenStream {
2035 let input = parse_macro_input!(input as IncantInput);
2036 incant_impl(input)
2037}
2038
2039#[proc_macro]
2041pub fn simd_route(input: TokenStream) -> TokenStream {
2042 let input = parse_macro_input!(input as IncantInput);
2043 incant_impl(input)
2044}
2045
2046#[proc_macro]
2054pub fn dispatch_variant(input: TokenStream) -> TokenStream {
2055 let input = parse_macro_input!(input as IncantInput);
2056 incant_impl(input)
2057}
2058
2059fn incant_impl(input: IncantInput) -> TokenStream {
2060 let func_path = &input.func_path;
2061 let args = &input.args;
2062
2063 let tier_names: Vec<String> = match &input.tiers {
2065 Some((names, _)) => names.clone(),
2066 None => DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect(),
2067 };
2068 let last_segment_span = func_path
2069 .segments
2070 .last()
2071 .map(|s| s.ident.span())
2072 .unwrap_or_else(proc_macro2::Span::call_site);
2073 let error_span = input
2074 .tiers
2075 .as_ref()
2076 .map(|(_, span)| *span)
2077 .unwrap_or(last_segment_span);
2078
2079 let tiers = match resolve_tiers(&tier_names, error_span) {
2080 Ok(t) => t,
2081 Err(e) => return e.to_compile_error().into(),
2082 };
2083
2084 if let Some(token_expr) = &input.with_token {
2087 gen_incant_passthrough(func_path, args, token_expr, &tiers)
2088 } else {
2089 gen_incant_entry(func_path, args, &tiers)
2090 }
2091}
2092
2093fn gen_incant_passthrough(
2095 func_path: &syn::Path,
2096 args: &[syn::Expr],
2097 token_expr: &syn::Expr,
2098 tiers: &[&TierDescriptor],
2099) -> TokenStream {
2100 let mut dispatch_arms = Vec::new();
2101
2102 let mut arch_groups: Vec<(Option<&str>, Option<&str>, Vec<&TierDescriptor>)> = Vec::new();
2104 for tier in tiers {
2105 if tier.name == "scalar" {
2106 continue; }
2108 let key = (tier.target_arch, tier.cargo_feature);
2109 if let Some(group) = arch_groups.iter_mut().find(|(a, f, _)| (*a, *f) == key) {
2110 group.2.push(tier);
2111 } else {
2112 arch_groups.push((tier.target_arch, tier.cargo_feature, vec![tier]));
2113 }
2114 }
2115
2116 for (target_arch, cargo_feature, group_tiers) in &arch_groups {
2117 let mut tier_checks = Vec::new();
2118 for tier in group_tiers {
2119 let fn_suffixed = suffix_path(func_path, tier.suffix);
2120 let as_method = format_ident!("{}", tier.as_method);
2121 tier_checks.push(quote! {
2122 if let Some(__t) = __incant_token.#as_method() {
2123 break '__incant #fn_suffixed(__t, #(#args),*);
2124 }
2125 });
2126 }
2127
2128 let inner = quote! { #(#tier_checks)* };
2129
2130 let guarded = match (target_arch, cargo_feature) {
2131 (Some(arch), Some(feat)) => quote! {
2132 #[cfg(target_arch = #arch)]
2133 {
2134 #[cfg(feature = #feat)]
2135 { #inner }
2136 }
2137 },
2138 (Some(arch), None) => quote! {
2139 #[cfg(target_arch = #arch)]
2140 { #inner }
2141 },
2142 (None, Some(feat)) => quote! {
2143 #[cfg(feature = #feat)]
2144 { #inner }
2145 },
2146 (None, None) => inner,
2147 };
2148
2149 dispatch_arms.push(guarded);
2150 }
2151
2152 let fn_scalar = suffix_path(func_path, "scalar");
2154 let scalar_arm = if tiers.iter().any(|t| t.name == "scalar") {
2155 quote! {
2156 if let Some(__t) = __incant_token.as_scalar() {
2157 break '__incant #fn_scalar(__t, #(#args),*);
2158 }
2159 unreachable!("Token did not match any known variant")
2160 }
2161 } else {
2162 quote! { unreachable!("Token did not match any known variant") }
2163 };
2164
2165 let expanded = quote! {
2166 '__incant: {
2167 use archmage::IntoConcreteToken;
2168 let __incant_token = #token_expr;
2169 #(#dispatch_arms)*
2170 #scalar_arm
2171 }
2172 };
2173 expanded.into()
2174}
2175
2176fn gen_incant_entry(
2178 func_path: &syn::Path,
2179 args: &[syn::Expr],
2180 tiers: &[&TierDescriptor],
2181) -> TokenStream {
2182 let mut dispatch_arms = Vec::new();
2183
2184 let mut arch_groups: Vec<(Option<&str>, Vec<&TierDescriptor>)> = Vec::new();
2187 for tier in tiers {
2188 if tier.name == "scalar" {
2189 continue;
2190 }
2191 if let Some(group) = arch_groups.iter_mut().find(|(a, _)| *a == tier.target_arch) {
2192 group.1.push(tier);
2193 } else {
2194 arch_groups.push((tier.target_arch, vec![tier]));
2195 }
2196 }
2197
2198 for (target_arch, group_tiers) in &arch_groups {
2199 let mut tier_checks = Vec::new();
2200 for tier in group_tiers {
2201 let fn_suffixed = suffix_path(func_path, tier.suffix);
2202 let token_path: syn::Path = syn::parse_str(tier.token_path).unwrap();
2203
2204 let check = quote! {
2205 if let Some(__t) = #token_path::summon() {
2206 break '__incant #fn_suffixed(__t, #(#args),*);
2207 }
2208 };
2209
2210 if let Some(feat) = tier.cargo_feature {
2211 tier_checks.push(quote! {
2212 #[cfg(feature = #feat)]
2213 { #check }
2214 });
2215 } else {
2216 tier_checks.push(check);
2217 }
2218 }
2219
2220 let inner = quote! { #(#tier_checks)* };
2221
2222 if let Some(arch) = target_arch {
2223 dispatch_arms.push(quote! {
2224 #[cfg(target_arch = #arch)]
2225 { #inner }
2226 });
2227 } else {
2228 dispatch_arms.push(inner);
2229 }
2230 }
2231
2232 let fn_scalar = suffix_path(func_path, "scalar");
2234
2235 let expanded = quote! {
2236 '__incant: {
2237 use archmage::SimdToken;
2238 #(#dispatch_arms)*
2239 #fn_scalar(archmage::ScalarToken, #(#args),*)
2240 }
2241 };
2242 expanded.into()
2243}
2244
2245struct AutoversionArgs {
2251 self_type: Option<Type>,
2253 tiers: Option<Vec<String>>,
2255}
2256
2257impl Parse for AutoversionArgs {
2258 fn parse(input: ParseStream) -> syn::Result<Self> {
2259 let mut self_type = None;
2260 let mut tier_names = Vec::new();
2261
2262 while !input.is_empty() {
2263 let ident: Ident = input.parse()?;
2264 if ident == "_self" {
2265 let _: Token![=] = input.parse()?;
2266 self_type = Some(input.parse()?);
2267 } else {
2268 tier_names.push(ident.to_string());
2270 }
2271 if input.peek(Token![,]) {
2272 let _: Token![,] = input.parse()?;
2273 }
2274 }
2275
2276 Ok(AutoversionArgs {
2277 self_type,
2278 tiers: if tier_names.is_empty() {
2279 None
2280 } else {
2281 Some(tier_names)
2282 },
2283 })
2284 }
2285}
2286
2287struct SimdTokenParamInfo {
2289 index: usize,
2291 #[allow(dead_code)]
2293 ident: Ident,
2294}
2295
2296fn find_simd_token_param(sig: &Signature) -> Option<SimdTokenParamInfo> {
2301 for (i, arg) in sig.inputs.iter().enumerate() {
2302 if let FnArg::Typed(PatType { pat, ty, .. }) = arg
2303 && let Type::Path(type_path) = ty.as_ref()
2304 && let Some(seg) = type_path.path.segments.last()
2305 && seg.ident == "SimdToken"
2306 {
2307 let ident = match pat.as_ref() {
2308 syn::Pat::Ident(pi) => pi.ident.clone(),
2309 syn::Pat::Wild(w) => Ident::new("__autoversion_token", w.underscore_token.span),
2310 _ => continue,
2311 };
2312 return Some(SimdTokenParamInfo { index: i, ident });
2313 }
2314 }
2315 None
2316}
2317
2318fn autoversion_impl(mut input_fn: LightFn, args: AutoversionArgs) -> TokenStream {
2323 let has_self = input_fn
2325 .sig
2326 .inputs
2327 .first()
2328 .is_some_and(|arg| matches!(arg, FnArg::Receiver(_)));
2329
2330 let token_param = match find_simd_token_param(&input_fn.sig) {
2335 Some(p) => p,
2336 None => {
2337 return syn::Error::new_spanned(
2338 &input_fn.sig,
2339 "autoversion requires a `SimdToken` parameter.\n\
2340 Example: fn process(token: SimdToken, data: &[f32]) -> f32 { ... }\n\n\
2341 SimdToken is the dispatch placeholder — autoversion replaces it \
2342 with concrete token types and generates a runtime dispatcher.",
2343 )
2344 .to_compile_error()
2345 .into();
2346 }
2347 };
2348
2349 let tier_names: Vec<String> = match &args.tiers {
2351 Some(names) => names.clone(),
2352 None => DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect(),
2353 };
2354 let tiers = match resolve_tiers(&tier_names, input_fn.sig.ident.span()) {
2355 Ok(t) => t,
2356 Err(e) => return e.to_compile_error().into(),
2357 };
2358
2359 input_fn
2361 .attrs
2362 .retain(|attr| !attr.path().is_ident("arcane") && !attr.path().is_ident("rite"));
2363
2364 let fn_name = &input_fn.sig.ident;
2365 let vis = input_fn.vis.clone();
2366
2367 let fn_attrs: Vec<Attribute> = input_fn.attrs.drain(..).collect();
2369
2370 let mut variants = Vec::new();
2380
2381 for tier in &tiers {
2382 let mut variant_fn = input_fn.clone();
2383
2384 variant_fn.vis = syn::Visibility::Inherited;
2386
2387 variant_fn.sig.ident = format_ident!("{}_{}", fn_name, tier.suffix);
2389
2390 let concrete_type: Type = syn::parse_str(tier.token_path).unwrap();
2392 if let FnArg::Typed(pt) = &mut variant_fn.sig.inputs[token_param.index] {
2393 *pt.ty = concrete_type;
2394 }
2395
2396 if tier.name == "scalar" && has_self && args.self_type.is_some() {
2399 let original_body = variant_fn.body.clone();
2400 variant_fn.body = quote!(let _self = self; #original_body);
2401 }
2402
2403 let cfg_guard = match (tier.target_arch, tier.cargo_feature) {
2405 (Some(arch), Some(feature)) => {
2406 quote! { #[cfg(all(target_arch = #arch, feature = #feature))] }
2407 }
2408 (Some(arch), None) => quote! { #[cfg(target_arch = #arch)] },
2409 (None, Some(feature)) => quote! { #[cfg(feature = #feature)] },
2410 (None, None) => quote! {},
2411 };
2412
2413 if tier.name != "scalar" {
2414 let arcane_attr = if let Some(ref self_type) = args.self_type {
2416 quote! { #[archmage::arcane(_self = #self_type)] }
2417 } else {
2418 quote! { #[archmage::arcane] }
2419 };
2420 variants.push(quote! {
2421 #cfg_guard
2422 #arcane_attr
2423 #variant_fn
2424 });
2425 } else {
2426 variants.push(quote! {
2427 #cfg_guard
2428 #variant_fn
2429 });
2430 }
2431 }
2432
2433 let mut dispatcher_inputs: Vec<FnArg> = input_fn.sig.inputs.iter().cloned().collect();
2439 dispatcher_inputs.remove(token_param.index);
2440
2441 let mut wild_counter = 0u32;
2443 for arg in &mut dispatcher_inputs {
2444 if let FnArg::Typed(pat_type) = arg
2445 && matches!(pat_type.pat.as_ref(), syn::Pat::Wild(_))
2446 {
2447 let ident = format_ident!("__autoversion_wild_{}", wild_counter);
2448 wild_counter += 1;
2449 *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
2450 attrs: vec![],
2451 by_ref: None,
2452 mutability: None,
2453 ident,
2454 subpat: None,
2455 });
2456 }
2457 }
2458
2459 let dispatch_args: Vec<Ident> = dispatcher_inputs
2461 .iter()
2462 .filter_map(|arg| {
2463 if let FnArg::Typed(PatType { pat, .. }) = arg
2464 && let syn::Pat::Ident(pi) = pat.as_ref()
2465 {
2466 return Some(pi.ident.clone());
2467 }
2468 None
2469 })
2470 .collect();
2471
2472 let mut arch_groups: Vec<(Option<&str>, Vec<&&TierDescriptor>)> = Vec::new();
2474 for tier in &tiers {
2475 if tier.name == "scalar" {
2476 continue;
2477 }
2478 if let Some(group) = arch_groups.iter_mut().find(|(a, _)| *a == tier.target_arch) {
2479 group.1.push(tier);
2480 } else {
2481 arch_groups.push((tier.target_arch, vec![tier]));
2482 }
2483 }
2484
2485 let mut dispatch_arms = Vec::new();
2486 for (target_arch, group_tiers) in &arch_groups {
2487 let mut tier_checks = Vec::new();
2488 for tier in group_tiers {
2489 let suffixed = format_ident!("{}_{}", fn_name, tier.suffix);
2490 let token_path: syn::Path = syn::parse_str(tier.token_path).unwrap();
2491
2492 let call = if has_self {
2493 quote! { self.#suffixed(__t, #(#dispatch_args),*) }
2494 } else {
2495 quote! { #suffixed(__t, #(#dispatch_args),*) }
2496 };
2497
2498 let check = quote! {
2499 if let Some(__t) = #token_path::summon() {
2500 break '__dispatch #call;
2501 }
2502 };
2503
2504 if let Some(feat) = tier.cargo_feature {
2505 tier_checks.push(quote! {
2506 #[cfg(feature = #feat)]
2507 { #check }
2508 });
2509 } else {
2510 tier_checks.push(check);
2511 }
2512 }
2513
2514 let inner = quote! { #(#tier_checks)* };
2515
2516 if let Some(arch) = target_arch {
2517 dispatch_arms.push(quote! {
2518 #[cfg(target_arch = #arch)]
2519 { #inner }
2520 });
2521 } else {
2522 dispatch_arms.push(inner);
2523 }
2524 }
2525
2526 let scalar_name = format_ident!("{}_scalar", fn_name);
2528 let scalar_call = if has_self {
2529 quote! { self.#scalar_name(archmage::ScalarToken, #(#dispatch_args),*) }
2530 } else {
2531 quote! { #scalar_name(archmage::ScalarToken, #(#dispatch_args),*) }
2532 };
2533
2534 let dispatcher_inputs_punct: syn::punctuated::Punctuated<FnArg, Token![,]> =
2536 dispatcher_inputs.into_iter().collect();
2537 let output = &input_fn.sig.output;
2538 let generics = &input_fn.sig.generics;
2539 let where_clause = &generics.where_clause;
2540
2541 let dispatcher = quote! {
2542 #(#fn_attrs)*
2543 #vis fn #fn_name #generics (#dispatcher_inputs_punct) #output #where_clause {
2544 '__dispatch: {
2545 use archmage::SimdToken;
2546 #(#dispatch_arms)*
2547 #scalar_call
2548 }
2549 }
2550 };
2551
2552 let expanded = quote! {
2553 #dispatcher
2554 #(#variants)*
2555 };
2556
2557 expanded.into()
2558}
2559
2560#[proc_macro_attribute]
2800pub fn autoversion(attr: TokenStream, item: TokenStream) -> TokenStream {
2801 let args = parse_macro_input!(attr as AutoversionArgs);
2802 let input_fn = parse_macro_input!(item as LightFn);
2803 autoversion_impl(input_fn, args)
2804}
2805
2806#[cfg(test)]
2811mod tests {
2812 use super::*;
2813
2814 use super::generated::{ALL_CONCRETE_TOKENS, ALL_TRAIT_NAMES};
2815 use syn::{ItemFn, ReturnType};
2816
2817 #[test]
2818 fn every_concrete_token_is_in_token_to_features() {
2819 for &name in ALL_CONCRETE_TOKENS {
2820 assert!(
2821 token_to_features(name).is_some(),
2822 "Token `{}` exists in runtime crate but is NOT recognized by \
2823 token_to_features() in the proc macro. Add it!",
2824 name
2825 );
2826 }
2827 }
2828
2829 #[test]
2830 fn every_trait_is_in_trait_to_features() {
2831 for &name in ALL_TRAIT_NAMES {
2832 assert!(
2833 trait_to_features(name).is_some(),
2834 "Trait `{}` exists in runtime crate but is NOT recognized by \
2835 trait_to_features() in the proc macro. Add it!",
2836 name
2837 );
2838 }
2839 }
2840
2841 #[test]
2842 fn token_aliases_map_to_same_features() {
2843 assert_eq!(
2845 token_to_features("Desktop64"),
2846 token_to_features("X64V3Token"),
2847 "Desktop64 and X64V3Token should map to identical features"
2848 );
2849
2850 assert_eq!(
2852 token_to_features("Server64"),
2853 token_to_features("X64V4Token"),
2854 "Server64 and X64V4Token should map to identical features"
2855 );
2856 assert_eq!(
2857 token_to_features("X64V4Token"),
2858 token_to_features("Avx512Token"),
2859 "X64V4Token and Avx512Token should map to identical features"
2860 );
2861
2862 assert_eq!(
2864 token_to_features("Arm64"),
2865 token_to_features("NeonToken"),
2866 "Arm64 and NeonToken should map to identical features"
2867 );
2868 }
2869
2870 #[test]
2871 fn trait_to_features_includes_tokens_as_bounds() {
2872 let tier_tokens = [
2876 "X64V2Token",
2877 "X64CryptoToken",
2878 "X64V3Token",
2879 "Desktop64",
2880 "Avx2FmaToken",
2881 "X64V4Token",
2882 "Avx512Token",
2883 "Server64",
2884 "X64V4xToken",
2885 "Avx512Fp16Token",
2886 "NeonToken",
2887 "Arm64",
2888 "NeonAesToken",
2889 "NeonSha3Token",
2890 "NeonCrcToken",
2891 "Arm64V2Token",
2892 "Arm64V3Token",
2893 ];
2894
2895 for &name in &tier_tokens {
2896 assert!(
2897 trait_to_features(name).is_some(),
2898 "Tier token `{}` should also be recognized in trait_to_features() \
2899 for use as a generic bound. Add it!",
2900 name
2901 );
2902 }
2903 }
2904
2905 #[test]
2906 fn trait_features_are_cumulative() {
2907 let v2_features = trait_to_features("HasX64V2").unwrap();
2909 let v4_features = trait_to_features("HasX64V4").unwrap();
2910
2911 for &f in v2_features {
2912 assert!(
2913 v4_features.contains(&f),
2914 "HasX64V4 should include v2 feature `{}` but doesn't",
2915 f
2916 );
2917 }
2918
2919 assert!(
2921 v4_features.len() > v2_features.len(),
2922 "HasX64V4 should have more features than HasX64V2"
2923 );
2924 }
2925
2926 #[test]
2927 fn x64v3_trait_features_include_v2() {
2928 let v2 = trait_to_features("HasX64V2").unwrap();
2930 let v3 = trait_to_features("X64V3Token").unwrap();
2931
2932 for &f in v2 {
2933 assert!(
2934 v3.contains(&f),
2935 "X64V3Token trait features should include v2 feature `{}` but don't",
2936 f
2937 );
2938 }
2939 }
2940
2941 #[test]
2942 fn has_neon_aes_includes_neon() {
2943 let neon = trait_to_features("HasNeon").unwrap();
2944 let neon_aes = trait_to_features("HasNeonAes").unwrap();
2945
2946 for &f in neon {
2947 assert!(
2948 neon_aes.contains(&f),
2949 "HasNeonAes should include NEON feature `{}`",
2950 f
2951 );
2952 }
2953 }
2954
2955 #[test]
2956 fn no_removed_traits_are_recognized() {
2957 let removed = [
2959 "HasSse",
2960 "HasSse2",
2961 "HasSse41",
2962 "HasSse42",
2963 "HasAvx",
2964 "HasAvx2",
2965 "HasFma",
2966 "HasAvx512f",
2967 "HasAvx512bw",
2968 "HasAvx512vl",
2969 "HasAvx512vbmi2",
2970 "HasSve",
2971 "HasSve2",
2972 ];
2973
2974 for &name in &removed {
2975 assert!(
2976 trait_to_features(name).is_none(),
2977 "Removed trait `{}` should NOT be in trait_to_features(). \
2978 It was removed in 0.3.0 — users should migrate to tier traits.",
2979 name
2980 );
2981 }
2982 }
2983
2984 #[test]
2985 fn no_nonexistent_tokens_are_recognized() {
2986 let fake = [
2988 "SveToken",
2989 "Sve2Token",
2990 "Avx512VnniToken",
2991 "X64V4ModernToken",
2992 "NeonFp16Token",
2993 ];
2994
2995 for &name in &fake {
2996 assert!(
2997 token_to_features(name).is_none(),
2998 "Non-existent token `{}` should NOT be in token_to_features()",
2999 name
3000 );
3001 }
3002 }
3003
3004 #[test]
3005 fn featureless_traits_are_not_in_registries() {
3006 for &name in FEATURELESS_TRAIT_NAMES {
3009 assert!(
3010 token_to_features(name).is_none(),
3011 "`{}` should NOT be in token_to_features() — it has no CPU features",
3012 name
3013 );
3014 assert!(
3015 trait_to_features(name).is_none(),
3016 "`{}` should NOT be in trait_to_features() — it has no CPU features",
3017 name
3018 );
3019 }
3020 }
3021
3022 #[test]
3023 fn find_featureless_trait_detects_simdtoken() {
3024 let names = vec!["SimdToken".to_string()];
3025 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
3026
3027 let names = vec!["IntoConcreteToken".to_string()];
3028 assert_eq!(find_featureless_trait(&names), Some("IntoConcreteToken"));
3029
3030 let names = vec!["HasX64V2".to_string()];
3032 assert_eq!(find_featureless_trait(&names), None);
3033
3034 let names = vec!["HasNeon".to_string()];
3035 assert_eq!(find_featureless_trait(&names), None);
3036
3037 let names = vec!["SimdToken".to_string(), "HasX64V2".to_string()];
3039 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
3040 }
3041
3042 #[test]
3043 fn arm64_v2_v3_traits_are_cumulative() {
3044 let v2_features = trait_to_features("HasArm64V2").unwrap();
3045 let v3_features = trait_to_features("HasArm64V3").unwrap();
3046
3047 for &f in v2_features {
3048 assert!(
3049 v3_features.contains(&f),
3050 "HasArm64V3 should include v2 feature `{}` but doesn't",
3051 f
3052 );
3053 }
3054
3055 assert!(
3056 v3_features.len() > v2_features.len(),
3057 "HasArm64V3 should have more features than HasArm64V2"
3058 );
3059 }
3060
3061 #[test]
3066 fn autoversion_args_empty() {
3067 let args: AutoversionArgs = syn::parse_str("").unwrap();
3068 assert!(args.self_type.is_none());
3069 assert!(args.tiers.is_none());
3070 }
3071
3072 #[test]
3073 fn autoversion_args_single_tier() {
3074 let args: AutoversionArgs = syn::parse_str("v3").unwrap();
3075 assert!(args.self_type.is_none());
3076 assert_eq!(args.tiers.as_ref().unwrap(), &["v3"]);
3077 }
3078
3079 #[test]
3080 fn autoversion_args_tiers_only() {
3081 let args: AutoversionArgs = syn::parse_str("v3, v4, neon").unwrap();
3082 assert!(args.self_type.is_none());
3083 let tiers = args.tiers.unwrap();
3084 assert_eq!(tiers, vec!["v3", "v4", "neon"]);
3085 }
3086
3087 #[test]
3088 fn autoversion_args_many_tiers() {
3089 let args: AutoversionArgs =
3090 syn::parse_str("v1, v2, v3, v4, v4x, neon, arm_v2, wasm128").unwrap();
3091 assert_eq!(
3092 args.tiers.unwrap(),
3093 vec!["v1", "v2", "v3", "v4", "v4x", "neon", "arm_v2", "wasm128"]
3094 );
3095 }
3096
3097 #[test]
3098 fn autoversion_args_trailing_comma() {
3099 let args: AutoversionArgs = syn::parse_str("v3, v4,").unwrap();
3100 assert_eq!(args.tiers.as_ref().unwrap(), &["v3", "v4"]);
3101 }
3102
3103 #[test]
3104 fn autoversion_args_self_only() {
3105 let args: AutoversionArgs = syn::parse_str("_self = MyType").unwrap();
3106 assert!(args.self_type.is_some());
3107 assert!(args.tiers.is_none());
3108 }
3109
3110 #[test]
3111 fn autoversion_args_self_and_tiers() {
3112 let args: AutoversionArgs = syn::parse_str("_self = MyType, v3, neon").unwrap();
3113 assert!(args.self_type.is_some());
3114 let tiers = args.tiers.unwrap();
3115 assert_eq!(tiers, vec!["v3", "neon"]);
3116 }
3117
3118 #[test]
3119 fn autoversion_args_tiers_then_self() {
3120 let args: AutoversionArgs = syn::parse_str("v3, neon, _self = MyType").unwrap();
3122 assert!(args.self_type.is_some());
3123 let tiers = args.tiers.unwrap();
3124 assert_eq!(tiers, vec!["v3", "neon"]);
3125 }
3126
3127 #[test]
3128 fn autoversion_args_self_with_path_type() {
3129 let args: AutoversionArgs = syn::parse_str("_self = crate::MyType").unwrap();
3130 assert!(args.self_type.is_some());
3131 assert!(args.tiers.is_none());
3132 }
3133
3134 #[test]
3135 fn autoversion_args_self_with_generic_type() {
3136 let args: AutoversionArgs = syn::parse_str("_self = Vec<u8>").unwrap();
3137 assert!(args.self_type.is_some());
3138 let ty_str = args.self_type.unwrap().to_token_stream().to_string();
3139 assert!(ty_str.contains("Vec"), "Expected Vec<u8>, got: {}", ty_str);
3140 }
3141
3142 #[test]
3143 fn autoversion_args_self_trailing_comma() {
3144 let args: AutoversionArgs = syn::parse_str("_self = MyType,").unwrap();
3145 assert!(args.self_type.is_some());
3146 assert!(args.tiers.is_none());
3147 }
3148
3149 #[test]
3154 fn find_simd_token_param_first_position() {
3155 let f: ItemFn =
3156 syn::parse_str("fn process(token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
3157 let param = find_simd_token_param(&f.sig).unwrap();
3158 assert_eq!(param.index, 0);
3159 assert_eq!(param.ident, "token");
3160 }
3161
3162 #[test]
3163 fn find_simd_token_param_second_position() {
3164 let f: ItemFn =
3165 syn::parse_str("fn process(data: &[f32], token: SimdToken) -> f32 {}").unwrap();
3166 let param = find_simd_token_param(&f.sig).unwrap();
3167 assert_eq!(param.index, 1);
3168 assert_eq!(param.ident, "token");
3169 }
3170
3171 #[test]
3172 fn find_simd_token_param_underscore_prefix() {
3173 let f: ItemFn =
3174 syn::parse_str("fn process(_token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
3175 let param = find_simd_token_param(&f.sig).unwrap();
3176 assert_eq!(param.index, 0);
3177 assert_eq!(param.ident, "_token");
3178 }
3179
3180 #[test]
3181 fn find_simd_token_param_wildcard() {
3182 let f: ItemFn = syn::parse_str("fn process(_: SimdToken, data: &[f32]) -> f32 {}").unwrap();
3183 let param = find_simd_token_param(&f.sig).unwrap();
3184 assert_eq!(param.index, 0);
3185 assert_eq!(param.ident, "__autoversion_token");
3186 }
3187
3188 #[test]
3189 fn find_simd_token_param_not_found() {
3190 let f: ItemFn = syn::parse_str("fn process(data: &[f32]) -> f32 {}").unwrap();
3191 assert!(find_simd_token_param(&f.sig).is_none());
3192 }
3193
3194 #[test]
3195 fn find_simd_token_param_no_params() {
3196 let f: ItemFn = syn::parse_str("fn process() {}").unwrap();
3197 assert!(find_simd_token_param(&f.sig).is_none());
3198 }
3199
3200 #[test]
3201 fn find_simd_token_param_concrete_token_not_matched() {
3202 let f: ItemFn =
3204 syn::parse_str("fn process(token: X64V3Token, data: &[f32]) -> f32 {}").unwrap();
3205 assert!(find_simd_token_param(&f.sig).is_none());
3206 }
3207
3208 #[test]
3209 fn find_simd_token_param_scalar_token_not_matched() {
3210 let f: ItemFn =
3211 syn::parse_str("fn process(token: ScalarToken, data: &[f32]) -> f32 {}").unwrap();
3212 assert!(find_simd_token_param(&f.sig).is_none());
3213 }
3214
3215 #[test]
3216 fn find_simd_token_param_among_many() {
3217 let f: ItemFn = syn::parse_str(
3218 "fn process(a: i32, b: f64, token: SimdToken, c: &str, d: bool) -> f32 {}",
3219 )
3220 .unwrap();
3221 let param = find_simd_token_param(&f.sig).unwrap();
3222 assert_eq!(param.index, 2);
3223 assert_eq!(param.ident, "token");
3224 }
3225
3226 #[test]
3227 fn find_simd_token_param_with_generics() {
3228 let f: ItemFn =
3229 syn::parse_str("fn process<T: Clone>(token: SimdToken, data: &[T]) -> T {}").unwrap();
3230 let param = find_simd_token_param(&f.sig).unwrap();
3231 assert_eq!(param.index, 0);
3232 assert_eq!(param.ident, "token");
3233 }
3234
3235 #[test]
3236 fn find_simd_token_param_with_where_clause() {
3237 let f: ItemFn = syn::parse_str(
3238 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default {}",
3239 )
3240 .unwrap();
3241 let param = find_simd_token_param(&f.sig).unwrap();
3242 assert_eq!(param.index, 0);
3243 }
3244
3245 #[test]
3246 fn find_simd_token_param_with_lifetime() {
3247 let f: ItemFn =
3248 syn::parse_str("fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a f32 {}")
3249 .unwrap();
3250 let param = find_simd_token_param(&f.sig).unwrap();
3251 assert_eq!(param.index, 0);
3252 }
3253
3254 #[test]
3259 fn autoversion_default_tiers_all_resolve() {
3260 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
3261 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3262 assert!(!tiers.is_empty());
3263 assert!(tiers.iter().any(|t| t.name == "scalar"));
3265 }
3266
3267 #[test]
3268 fn autoversion_scalar_always_appended() {
3269 let names = vec!["v3".to_string(), "neon".to_string()];
3270 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3271 assert!(
3272 tiers.iter().any(|t| t.name == "scalar"),
3273 "scalar must be auto-appended"
3274 );
3275 }
3276
3277 #[test]
3278 fn autoversion_scalar_not_duplicated() {
3279 let names = vec!["v3".to_string(), "scalar".to_string()];
3280 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3281 let scalar_count = tiers.iter().filter(|t| t.name == "scalar").count();
3282 assert_eq!(scalar_count, 1, "scalar must not be duplicated");
3283 }
3284
3285 #[test]
3286 fn autoversion_tiers_sorted_by_priority() {
3287 let names = vec!["neon".to_string(), "v4".to_string(), "v3".to_string()];
3288 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3289 let priorities: Vec<u32> = tiers.iter().map(|t| t.priority).collect();
3291 for window in priorities.windows(2) {
3292 assert!(
3293 window[0] >= window[1],
3294 "Tiers not sorted by priority: {:?}",
3295 priorities
3296 );
3297 }
3298 }
3299
3300 #[test]
3301 fn autoversion_unknown_tier_errors() {
3302 let names = vec!["v3".to_string(), "avx9000".to_string()];
3303 let result = resolve_tiers(&names, proc_macro2::Span::call_site());
3304 match result {
3305 Ok(_) => panic!("Expected error for unknown tier 'avx9000'"),
3306 Err(e) => {
3307 let err_msg = e.to_string();
3308 assert!(
3309 err_msg.contains("avx9000"),
3310 "Error should mention unknown tier: {}",
3311 err_msg
3312 );
3313 }
3314 }
3315 }
3316
3317 #[test]
3318 fn autoversion_all_known_tiers_resolve() {
3319 for tier in ALL_TIERS {
3321 assert!(
3322 find_tier(tier.name).is_some(),
3323 "Tier '{}' should be findable by name",
3324 tier.name
3325 );
3326 }
3327 }
3328
3329 #[test]
3330 fn autoversion_default_tier_list_is_sensible() {
3331 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
3333 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3334
3335 let has_x86 = tiers.iter().any(|t| t.target_arch == Some("x86_64"));
3336 let has_arm = tiers.iter().any(|t| t.target_arch == Some("aarch64"));
3337 let has_wasm = tiers.iter().any(|t| t.target_arch == Some("wasm32"));
3338 let has_scalar = tiers.iter().any(|t| t.name == "scalar");
3339
3340 assert!(has_x86, "Default tiers should include an x86_64 tier");
3341 assert!(has_arm, "Default tiers should include an aarch64 tier");
3342 assert!(has_wasm, "Default tiers should include a wasm32 tier");
3343 assert!(has_scalar, "Default tiers should include scalar");
3344 }
3345
3346 fn do_variant_replacement(func: &str, tier_name: &str, has_self: bool) -> ItemFn {
3354 let mut f: ItemFn = syn::parse_str(func).unwrap();
3355 let fn_name = f.sig.ident.to_string();
3356
3357 let tier = find_tier(tier_name).unwrap();
3358
3359 f.sig.ident = format_ident!("{}_{}", fn_name, tier.suffix);
3361
3362 let token_idx = find_simd_token_param(&f.sig)
3364 .unwrap_or_else(|| panic!("No SimdToken param in: {}", func))
3365 .index;
3366 let concrete_type: Type = syn::parse_str(tier.token_path).unwrap();
3367 if let FnArg::Typed(pt) = &mut f.sig.inputs[token_idx] {
3368 *pt.ty = concrete_type;
3369 }
3370
3371 if tier_name == "scalar" && has_self {
3373 let preamble: syn::Stmt = syn::parse_quote!(let _self = self;);
3374 f.block.stmts.insert(0, preamble);
3375 }
3376
3377 f
3378 }
3379
3380 #[test]
3381 fn variant_replacement_v3_renames_function() {
3382 let f = do_variant_replacement(
3383 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3384 "v3",
3385 false,
3386 );
3387 assert_eq!(f.sig.ident, "process_v3");
3388 }
3389
3390 #[test]
3391 fn variant_replacement_v3_replaces_token_type() {
3392 let f = do_variant_replacement(
3393 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3394 "v3",
3395 false,
3396 );
3397 let first_param_ty = match &f.sig.inputs[0] {
3398 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
3399 _ => panic!("Expected typed param"),
3400 };
3401 assert!(
3402 first_param_ty.contains("X64V3Token"),
3403 "Expected X64V3Token, got: {}",
3404 first_param_ty
3405 );
3406 }
3407
3408 #[test]
3409 fn variant_replacement_neon_produces_valid_fn() {
3410 let f = do_variant_replacement(
3411 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3412 "neon",
3413 false,
3414 );
3415 assert_eq!(f.sig.ident, "compute_neon");
3416 let first_param_ty = match &f.sig.inputs[0] {
3417 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
3418 _ => panic!("Expected typed param"),
3419 };
3420 assert!(
3421 first_param_ty.contains("NeonToken"),
3422 "Expected NeonToken, got: {}",
3423 first_param_ty
3424 );
3425 }
3426
3427 #[test]
3428 fn variant_replacement_wasm128_produces_valid_fn() {
3429 let f = do_variant_replacement(
3430 "fn compute(_t: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3431 "wasm128",
3432 false,
3433 );
3434 assert_eq!(f.sig.ident, "compute_wasm128");
3435 }
3436
3437 #[test]
3438 fn variant_replacement_scalar_produces_valid_fn() {
3439 let f = do_variant_replacement(
3440 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3441 "scalar",
3442 false,
3443 );
3444 assert_eq!(f.sig.ident, "compute_scalar");
3445 let first_param_ty = match &f.sig.inputs[0] {
3446 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
3447 _ => panic!("Expected typed param"),
3448 };
3449 assert!(
3450 first_param_ty.contains("ScalarToken"),
3451 "Expected ScalarToken, got: {}",
3452 first_param_ty
3453 );
3454 }
3455
3456 #[test]
3457 fn variant_replacement_v4_produces_valid_fn() {
3458 let f = do_variant_replacement(
3459 "fn transform(token: SimdToken, data: &mut [f32]) { }",
3460 "v4",
3461 false,
3462 );
3463 assert_eq!(f.sig.ident, "transform_v4");
3464 let first_param_ty = match &f.sig.inputs[0] {
3465 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
3466 _ => panic!("Expected typed param"),
3467 };
3468 assert!(
3469 first_param_ty.contains("X64V4Token"),
3470 "Expected X64V4Token, got: {}",
3471 first_param_ty
3472 );
3473 }
3474
3475 #[test]
3476 fn variant_replacement_v4x_produces_valid_fn() {
3477 let f = do_variant_replacement(
3478 "fn transform(token: SimdToken, data: &mut [f32]) { }",
3479 "v4x",
3480 false,
3481 );
3482 assert_eq!(f.sig.ident, "transform_v4x");
3483 }
3484
3485 #[test]
3486 fn variant_replacement_arm_v2_produces_valid_fn() {
3487 let f = do_variant_replacement(
3488 "fn transform(token: SimdToken, data: &mut [f32]) { }",
3489 "arm_v2",
3490 false,
3491 );
3492 assert_eq!(f.sig.ident, "transform_arm_v2");
3493 }
3494
3495 #[test]
3496 fn variant_replacement_preserves_generics() {
3497 let f = do_variant_replacement(
3498 "fn process<T: Copy + Default>(token: SimdToken, data: &[T]) -> T { T::default() }",
3499 "v3",
3500 false,
3501 );
3502 assert_eq!(f.sig.ident, "process_v3");
3503 assert!(
3505 !f.sig.generics.params.is_empty(),
3506 "Generics should be preserved"
3507 );
3508 }
3509
3510 #[test]
3511 fn variant_replacement_preserves_where_clause() {
3512 let f = do_variant_replacement(
3513 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default { T::default() }",
3514 "v3",
3515 false,
3516 );
3517 assert!(
3518 f.sig.generics.where_clause.is_some(),
3519 "Where clause should be preserved"
3520 );
3521 }
3522
3523 #[test]
3524 fn variant_replacement_preserves_return_type() {
3525 let f = do_variant_replacement(
3526 "fn process(token: SimdToken, data: &[f32]) -> Vec<f32> { vec![] }",
3527 "neon",
3528 false,
3529 );
3530 let ret = f.sig.output.to_token_stream().to_string();
3531 assert!(
3532 ret.contains("Vec"),
3533 "Return type should be preserved, got: {}",
3534 ret
3535 );
3536 }
3537
3538 #[test]
3539 fn variant_replacement_preserves_multiple_params() {
3540 let f = do_variant_replacement(
3541 "fn process(token: SimdToken, a: &[f32], b: &[f32], scale: f32) -> f32 { 0.0 }",
3542 "v3",
3543 false,
3544 );
3545 assert_eq!(f.sig.inputs.len(), 4);
3547 }
3548
3549 #[test]
3550 fn variant_replacement_preserves_no_return_type() {
3551 let f = do_variant_replacement(
3552 "fn transform(token: SimdToken, data: &mut [f32]) { }",
3553 "v3",
3554 false,
3555 );
3556 assert!(
3557 matches!(f.sig.output, ReturnType::Default),
3558 "No return type should remain as Default"
3559 );
3560 }
3561
3562 #[test]
3563 fn variant_replacement_preserves_lifetime_params() {
3564 let f = do_variant_replacement(
3565 "fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a [f32] { data }",
3566 "v3",
3567 false,
3568 );
3569 assert!(!f.sig.generics.params.is_empty());
3570 }
3571
3572 #[test]
3573 fn variant_replacement_scalar_self_injects_preamble() {
3574 let f = do_variant_replacement(
3575 "fn method(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3576 "scalar",
3577 true, );
3579 assert_eq!(f.sig.ident, "method_scalar");
3580
3581 let body_str = f.block.to_token_stream().to_string();
3583 assert!(
3584 body_str.contains("let _self = self"),
3585 "Scalar+self variant should have _self preamble, got: {}",
3586 body_str
3587 );
3588 }
3589
3590 #[test]
3591 fn variant_replacement_all_default_tiers_produce_valid_fns() {
3592 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
3593 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site()).unwrap();
3594
3595 for tier in &tiers {
3596 let f = do_variant_replacement(
3597 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3598 tier.name,
3599 false,
3600 );
3601 let expected_name = format!("process_{}", tier.suffix);
3602 assert_eq!(
3603 f.sig.ident.to_string(),
3604 expected_name,
3605 "Tier '{}' should produce function '{}'",
3606 tier.name,
3607 expected_name
3608 );
3609 }
3610 }
3611
3612 #[test]
3613 fn variant_replacement_all_known_tiers_produce_valid_fns() {
3614 for tier in ALL_TIERS {
3615 let f = do_variant_replacement(
3616 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3617 tier.name,
3618 false,
3619 );
3620 let expected_name = format!("compute_{}", tier.suffix);
3621 assert_eq!(
3622 f.sig.ident.to_string(),
3623 expected_name,
3624 "Tier '{}' should produce function '{}'",
3625 tier.name,
3626 expected_name
3627 );
3628 }
3629 }
3630
3631 #[test]
3632 fn variant_replacement_no_simdtoken_remains() {
3633 for tier in ALL_TIERS {
3634 let f = do_variant_replacement(
3635 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
3636 tier.name,
3637 false,
3638 );
3639 let full_str = f.to_token_stream().to_string();
3640 assert!(
3641 !full_str.contains("SimdToken"),
3642 "Tier '{}' variant still contains 'SimdToken': {}",
3643 tier.name,
3644 full_str
3645 );
3646 }
3647 }
3648
3649 #[test]
3654 fn tier_v3_targets_x86_64() {
3655 let tier = find_tier("v3").unwrap();
3656 assert_eq!(tier.target_arch, Some("x86_64"));
3657 assert_eq!(tier.cargo_feature, None);
3658 }
3659
3660 #[test]
3661 fn tier_v4_requires_avx512_feature() {
3662 let tier = find_tier("v4").unwrap();
3663 assert_eq!(tier.target_arch, Some("x86_64"));
3664 assert_eq!(tier.cargo_feature, Some("avx512"));
3665 }
3666
3667 #[test]
3668 fn tier_v4x_requires_avx512_feature() {
3669 let tier = find_tier("v4x").unwrap();
3670 assert_eq!(tier.cargo_feature, Some("avx512"));
3671 }
3672
3673 #[test]
3674 fn tier_neon_targets_aarch64() {
3675 let tier = find_tier("neon").unwrap();
3676 assert_eq!(tier.target_arch, Some("aarch64"));
3677 assert_eq!(tier.cargo_feature, None);
3678 }
3679
3680 #[test]
3681 fn tier_wasm128_targets_wasm32() {
3682 let tier = find_tier("wasm128").unwrap();
3683 assert_eq!(tier.target_arch, Some("wasm32"));
3684 assert_eq!(tier.cargo_feature, None);
3685 }
3686
3687 #[test]
3688 fn tier_scalar_has_no_guards() {
3689 let tier = find_tier("scalar").unwrap();
3690 assert_eq!(tier.target_arch, None);
3691 assert_eq!(tier.cargo_feature, None);
3692 assert_eq!(tier.priority, 0);
3693 }
3694
3695 #[test]
3696 fn tier_priorities_are_consistent() {
3697 let v2 = find_tier("v2").unwrap();
3699 let v3 = find_tier("v3").unwrap();
3700 let v4 = find_tier("v4").unwrap();
3701 assert!(v4.priority > v3.priority);
3702 assert!(v3.priority > v2.priority);
3703
3704 let neon = find_tier("neon").unwrap();
3705 let arm_v2 = find_tier("arm_v2").unwrap();
3706 let arm_v3 = find_tier("arm_v3").unwrap();
3707 assert!(arm_v3.priority > arm_v2.priority);
3708 assert!(arm_v2.priority > neon.priority);
3709
3710 let scalar = find_tier("scalar").unwrap();
3712 assert!(neon.priority > scalar.priority);
3713 assert!(v2.priority > scalar.priority);
3714 }
3715
3716 #[test]
3721 fn dispatcher_param_removal_free_fn() {
3722 let f: ItemFn =
3724 syn::parse_str("fn process(token: SimdToken, data: &[f32], scale: f32) -> f32 { 0.0 }")
3725 .unwrap();
3726
3727 let token_param = find_simd_token_param(&f.sig).unwrap();
3728 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
3729 dispatcher_inputs.remove(token_param.index);
3730
3731 assert_eq!(dispatcher_inputs.len(), 2);
3733
3734 for arg in &dispatcher_inputs {
3736 if let FnArg::Typed(pt) = arg {
3737 let ty_str = pt.ty.to_token_stream().to_string();
3738 assert!(
3739 !ty_str.contains("SimdToken"),
3740 "SimdToken should be removed from dispatcher, found: {}",
3741 ty_str
3742 );
3743 }
3744 }
3745 }
3746
3747 #[test]
3748 fn dispatcher_param_removal_token_only() {
3749 let f: ItemFn = syn::parse_str("fn process(token: SimdToken) -> f32 { 0.0 }").unwrap();
3750
3751 let token_param = find_simd_token_param(&f.sig).unwrap();
3752 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
3753 dispatcher_inputs.remove(token_param.index);
3754
3755 assert_eq!(dispatcher_inputs.len(), 0);
3757 }
3758
3759 #[test]
3760 fn dispatcher_param_removal_token_last() {
3761 let f: ItemFn =
3762 syn::parse_str("fn process(data: &[f32], scale: f32, token: SimdToken) -> f32 { 0.0 }")
3763 .unwrap();
3764
3765 let token_param = find_simd_token_param(&f.sig).unwrap();
3766 assert_eq!(token_param.index, 2);
3767
3768 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
3769 dispatcher_inputs.remove(token_param.index);
3770
3771 assert_eq!(dispatcher_inputs.len(), 2);
3772 }
3773
3774 #[test]
3775 fn dispatcher_dispatch_args_extraction() {
3776 let f: ItemFn =
3778 syn::parse_str("fn process(data: &[f32], scale: f32) -> f32 { 0.0 }").unwrap();
3779
3780 let dispatch_args: Vec<String> = f
3781 .sig
3782 .inputs
3783 .iter()
3784 .filter_map(|arg| {
3785 if let FnArg::Typed(PatType { pat, .. }) = arg {
3786 if let syn::Pat::Ident(pi) = pat.as_ref() {
3787 return Some(pi.ident.to_string());
3788 }
3789 }
3790 None
3791 })
3792 .collect();
3793
3794 assert_eq!(dispatch_args, vec!["data", "scale"]);
3795 }
3796
3797 #[test]
3798 fn dispatcher_wildcard_params_get_renamed() {
3799 let f: ItemFn = syn::parse_str("fn process(_: &[f32], _: f32) -> f32 { 0.0 }").unwrap();
3800
3801 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
3802
3803 let mut wild_counter = 0u32;
3804 for arg in &mut dispatcher_inputs {
3805 if let FnArg::Typed(pat_type) = arg {
3806 if matches!(pat_type.pat.as_ref(), syn::Pat::Wild(_)) {
3807 let ident = format_ident!("__autoversion_wild_{}", wild_counter);
3808 wild_counter += 1;
3809 *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
3810 attrs: vec![],
3811 by_ref: None,
3812 mutability: None,
3813 ident,
3814 subpat: None,
3815 });
3816 }
3817 }
3818 }
3819
3820 assert_eq!(wild_counter, 2);
3822
3823 let names: Vec<String> = dispatcher_inputs
3824 .iter()
3825 .filter_map(|arg| {
3826 if let FnArg::Typed(PatType { pat, .. }) = arg {
3827 if let syn::Pat::Ident(pi) = pat.as_ref() {
3828 return Some(pi.ident.to_string());
3829 }
3830 }
3831 None
3832 })
3833 .collect();
3834
3835 assert_eq!(names, vec!["__autoversion_wild_0", "__autoversion_wild_1"]);
3836 }
3837
3838 #[test]
3843 fn suffix_path_simple() {
3844 let path: syn::Path = syn::parse_str("process").unwrap();
3845 let suffixed = suffix_path(&path, "v3");
3846 assert_eq!(suffixed.to_token_stream().to_string(), "process_v3");
3847 }
3848
3849 #[test]
3850 fn suffix_path_qualified() {
3851 let path: syn::Path = syn::parse_str("module::process").unwrap();
3852 let suffixed = suffix_path(&path, "neon");
3853 let s = suffixed.to_token_stream().to_string();
3854 assert!(
3855 s.contains("process_neon"),
3856 "Expected process_neon, got: {}",
3857 s
3858 );
3859 }
3860}