1mod arcane;
6mod autoversion;
7mod common;
8mod generated;
9mod incant;
10mod magetypes;
11mod rewrite;
12mod rite;
13mod tiers;
14mod token_discovery;
15
16use proc_macro::TokenStream;
17use syn::parse_macro_input;
18
19use arcane::*;
20use autoversion::*;
21use common::*;
22use incant::*;
23use magetypes::*;
24use rite::*;
25use tiers::*;
26
27#[cfg(test)]
29use generated::{token_to_features, trait_to_features};
30#[cfg(test)]
31use quote::{ToTokens, format_ident};
32#[cfg(test)]
33use syn::{FnArg, PatType, Type};
34#[cfg(test)]
35use token_discovery::*;
36
37#[proc_macro_attribute]
202pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
203 let args = parse_macro_input!(attr as ArcaneArgs);
204 let input_fn = parse_macro_input!(item as LightFn);
205 arcane_impl(input_fn, "arcane", args)
206}
207
208#[proc_macro_attribute]
212#[doc(hidden)]
213pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
214 let args = parse_macro_input!(attr as ArcaneArgs);
215 let input_fn = parse_macro_input!(item as LightFn);
216 arcane_impl(input_fn, "simd_fn", args)
217}
218
219#[proc_macro_attribute]
232pub fn token_target_features_boundary(attr: TokenStream, item: TokenStream) -> TokenStream {
233 let args = parse_macro_input!(attr as ArcaneArgs);
234 let input_fn = parse_macro_input!(item as LightFn);
235 arcane_impl(input_fn, "token_target_features_boundary", args)
236}
237
238#[proc_macro_attribute]
317pub fn rite(attr: TokenStream, item: TokenStream) -> TokenStream {
318 let args = parse_macro_input!(attr as RiteArgs);
319 let input_fn = parse_macro_input!(item as LightFn);
320 rite_impl(input_fn, args)
321}
322
323#[proc_macro_attribute]
334pub fn token_target_features(attr: TokenStream, item: TokenStream) -> TokenStream {
335 let args = parse_macro_input!(attr as RiteArgs);
336 let input_fn = parse_macro_input!(item as LightFn);
337 rite_impl(input_fn, args)
338}
339
340#[proc_macro_attribute]
401pub fn magetypes(attr: TokenStream, item: TokenStream) -> TokenStream {
402 let input_fn = parse_macro_input!(item as LightFn);
403
404 let (rite_flag, defines, tier_names) =
421 match syn::parse::Parser::parse(parse_magetypes_attr, attr) {
422 Ok(parsed) => parsed,
423 Err(e) => return e.to_compile_error().into(),
424 };
425
426 let tier_names = if tier_names.is_empty() {
427 DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect()
428 } else {
429 tier_names
430 };
431
432 let tiers = match resolve_tiers(
434 &tier_names,
435 input_fn.sig.ident.span(),
436 true, ) {
438 Ok(t) => t,
439 Err(e) => return e.to_compile_error().into(),
440 };
441
442 magetypes_impl(input_fn, &tiers, rite_flag, &defines)
443}
444
445fn parse_magetypes_attr(
450 input: syn::parse::ParseStream,
451) -> syn::Result<(bool, Vec<String>, Vec<String>)> {
452 use syn::Token;
453 let mut rite_flag = false;
454 let mut defines = Vec::new();
455 let mut tier_names = Vec::new();
456
457 while !input.is_empty() {
458 let peek_rite = input.peek(syn::Ident) && {
461 let fork = input.fork();
462 fork.parse::<syn::Ident>()
463 .is_ok_and(|i| i == "rite" && !fork.peek(syn::token::Paren))
464 };
465 let peek_define = input.peek(syn::Ident) && {
466 let fork = input.fork();
467 fork.parse::<syn::Ident>()
468 .is_ok_and(|i| i == "define" && fork.peek(syn::token::Paren))
469 };
470
471 if peek_rite {
472 let _: syn::Ident = input.parse()?;
473 rite_flag = true;
474 } else if peek_define {
475 let _: syn::Ident = input.parse()?;
476 let content;
477 syn::parenthesized!(content in input);
478 while !content.is_empty() {
479 let ty: syn::Ident = content.parse()?;
480 defines.push(ty.to_string());
481 if content.peek(Token![,]) {
482 let _: Token![,] = content.parse()?;
483 }
484 }
485 } else {
486 tier_names.push(parse_one_tier(input)?);
488 }
489
490 if input.peek(Token![,]) {
491 let _: Token![,] = input.parse()?;
492 }
493 }
494
495 Ok((rite_flag, defines, tier_names))
496}
497
498#[proc_macro]
608pub fn incant(input: TokenStream) -> TokenStream {
609 let input = parse_macro_input!(input as IncantInput);
610 incant_impl(input)
611}
612
613#[proc_macro]
615pub fn simd_route(input: TokenStream) -> TokenStream {
616 let input = parse_macro_input!(input as IncantInput);
617 incant_impl(input)
618}
619
620#[proc_macro]
628pub fn dispatch_variant(input: TokenStream) -> TokenStream {
629 let input = parse_macro_input!(input as IncantInput);
630 incant_impl(input)
631}
632
633#[proc_macro_attribute]
764pub fn autoversion(attr: TokenStream, item: TokenStream) -> TokenStream {
765 let args = parse_macro_input!(attr as AutoversionArgs);
766 let input_fn = parse_macro_input!(item as LightFn);
767 autoversion_impl(input_fn, args)
768}
769
770#[cfg(test)]
775mod tests {
776 use super::*;
777
778 use super::generated::{ALL_CONCRETE_TOKENS, ALL_TRAIT_NAMES};
779 use syn::{ItemFn, ReturnType};
780
781 #[test]
782 fn every_concrete_token_is_in_token_to_features() {
783 for &name in ALL_CONCRETE_TOKENS {
784 assert!(
785 token_to_features(name).is_some(),
786 "Token `{}` exists in runtime crate but is NOT recognized by \
787 token_to_features() in the proc macro. Add it!",
788 name
789 );
790 }
791 }
792
793 #[test]
794 fn every_trait_is_in_trait_to_features() {
795 for &name in ALL_TRAIT_NAMES {
796 assert!(
797 trait_to_features(name).is_some(),
798 "Trait `{}` exists in runtime crate but is NOT recognized by \
799 trait_to_features() in the proc macro. Add it!",
800 name
801 );
802 }
803 }
804
805 #[test]
806 fn token_aliases_map_to_same_features() {
807 assert_eq!(
809 token_to_features("Desktop64"),
810 token_to_features("X64V3Token"),
811 "Desktop64 and X64V3Token should map to identical features"
812 );
813
814 assert_eq!(
816 token_to_features("Server64"),
817 token_to_features("X64V4Token"),
818 "Server64 and X64V4Token should map to identical features"
819 );
820 assert_eq!(
821 token_to_features("X64V4Token"),
822 token_to_features("Avx512Token"),
823 "X64V4Token and Avx512Token should map to identical features"
824 );
825
826 assert_eq!(
828 token_to_features("Arm64"),
829 token_to_features("NeonToken"),
830 "Arm64 and NeonToken should map to identical features"
831 );
832 }
833
834 #[test]
835 fn trait_to_features_includes_tokens_as_bounds() {
836 let tier_tokens = [
840 "X64V2Token",
841 "X64CryptoToken",
842 "X64V3Token",
843 "Desktop64",
844 "Avx2FmaToken",
845 "X64V4Token",
846 "Avx512Token",
847 "Server64",
848 "X64V4xToken",
849 "Avx512Fp16Token",
850 "NeonToken",
851 "Arm64",
852 "NeonAesToken",
853 "NeonSha3Token",
854 "NeonCrcToken",
855 "Arm64V2Token",
856 "Arm64V3Token",
857 ];
858
859 for &name in &tier_tokens {
860 assert!(
861 trait_to_features(name).is_some(),
862 "Tier token `{}` should also be recognized in trait_to_features() \
863 for use as a generic bound. Add it!",
864 name
865 );
866 }
867 }
868
869 #[test]
870 fn trait_features_are_cumulative() {
871 let v2_features = trait_to_features("HasX64V2").unwrap();
873 let v4_features = trait_to_features("HasX64V4").unwrap();
874
875 for &f in v2_features {
876 assert!(
877 v4_features.contains(&f),
878 "HasX64V4 should include v2 feature `{}` but doesn't",
879 f
880 );
881 }
882
883 assert!(
885 v4_features.len() > v2_features.len(),
886 "HasX64V4 should have more features than HasX64V2"
887 );
888 }
889
890 #[test]
891 fn x64v3_trait_features_include_v2() {
892 let v2 = trait_to_features("HasX64V2").unwrap();
894 let v3 = trait_to_features("X64V3Token").unwrap();
895
896 for &f in v2 {
897 assert!(
898 v3.contains(&f),
899 "X64V3Token trait features should include v2 feature `{}` but don't",
900 f
901 );
902 }
903 }
904
905 #[test]
906 fn has_neon_aes_includes_neon() {
907 let neon = trait_to_features("HasNeon").unwrap();
908 let neon_aes = trait_to_features("HasNeonAes").unwrap();
909
910 for &f in neon {
911 assert!(
912 neon_aes.contains(&f),
913 "HasNeonAes should include NEON feature `{}`",
914 f
915 );
916 }
917 }
918
919 #[test]
920 fn no_removed_traits_are_recognized() {
921 let removed = [
923 "HasSse",
924 "HasSse2",
925 "HasSse41",
926 "HasSse42",
927 "HasAvx",
928 "HasAvx2",
929 "HasFma",
930 "HasAvx512f",
931 "HasAvx512bw",
932 "HasAvx512vl",
933 "HasAvx512vbmi2",
934 "HasSve",
935 "HasSve2",
936 ];
937
938 for &name in &removed {
939 assert!(
940 trait_to_features(name).is_none(),
941 "Removed trait `{}` should NOT be in trait_to_features(). \
942 It was removed in 0.3.0 — users should migrate to tier traits.",
943 name
944 );
945 }
946 }
947
948 #[test]
949 fn no_nonexistent_tokens_are_recognized() {
950 let fake = [
952 "SveToken",
953 "Sve2Token",
954 "Avx512VnniToken",
955 "X64V4ModernToken",
956 "NeonFp16Token",
957 ];
958
959 for &name in &fake {
960 assert!(
961 token_to_features(name).is_none(),
962 "Non-existent token `{}` should NOT be in token_to_features()",
963 name
964 );
965 }
966 }
967
968 #[test]
969 fn featureless_traits_are_not_in_registries() {
970 for &name in FEATURELESS_TRAIT_NAMES {
973 assert!(
974 token_to_features(name).is_none(),
975 "`{}` should NOT be in token_to_features() — it has no CPU features",
976 name
977 );
978 assert!(
979 trait_to_features(name).is_none(),
980 "`{}` should NOT be in trait_to_features() — it has no CPU features",
981 name
982 );
983 }
984 }
985
986 #[test]
987 fn find_featureless_trait_detects_simdtoken() {
988 let names = vec!["SimdToken".to_string()];
989 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
990
991 let names = vec!["IntoConcreteToken".to_string()];
992 assert_eq!(find_featureless_trait(&names), Some("IntoConcreteToken"));
993
994 let names = vec!["HasX64V2".to_string()];
996 assert_eq!(find_featureless_trait(&names), None);
997
998 let names = vec!["HasNeon".to_string()];
999 assert_eq!(find_featureless_trait(&names), None);
1000
1001 let names = vec!["SimdToken".to_string(), "HasX64V2".to_string()];
1003 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
1004 }
1005
1006 #[test]
1007 fn arm64_v2_v3_traits_are_cumulative() {
1008 let v2_features = trait_to_features("HasArm64V2").unwrap();
1009 let v3_features = trait_to_features("HasArm64V3").unwrap();
1010
1011 for &f in v2_features {
1012 assert!(
1013 v3_features.contains(&f),
1014 "HasArm64V3 should include v2 feature `{}` but doesn't",
1015 f
1016 );
1017 }
1018
1019 assert!(
1020 v3_features.len() > v2_features.len(),
1021 "HasArm64V3 should have more features than HasArm64V2"
1022 );
1023 }
1024
1025 fn resolve_tier_names(names: &[&str], default_gates: bool) -> Vec<String> {
1030 let names: Vec<String> = names.iter().map(|s| s.to_string()).collect();
1031 resolve_tiers(&names, proc_macro2::Span::call_site(), default_gates)
1032 .unwrap()
1033 .iter()
1034 .map(|rt| {
1035 if let Some(ref gate) = rt.feature_gate {
1036 format!("{}({})", rt.name, gate)
1037 } else {
1038 rt.name.to_string()
1039 }
1040 })
1041 .collect()
1042 }
1043
1044 #[test]
1045 fn resolve_defaults() {
1046 let tiers = resolve_tier_names(&["v4", "v3", "neon", "wasm128", "scalar"], true);
1047 assert!(tiers.contains(&"v3".to_string()));
1048 assert!(tiers.contains(&"scalar".to_string()));
1049 assert!(tiers.contains(&"v4(avx512)".to_string()));
1051 }
1052
1053 #[test]
1054 fn resolve_additive_appends() {
1055 let tiers = resolve_tier_names(&["+v1"], true);
1056 assert!(tiers.contains(&"v1".to_string()));
1057 assert!(tiers.contains(&"v3".to_string())); assert!(tiers.contains(&"scalar".to_string())); }
1060
1061 #[test]
1062 fn resolve_additive_v4_overrides_gate() {
1063 let tiers = resolve_tier_names(&["+v4"], true);
1065 assert!(tiers.contains(&"v4".to_string())); assert!(!tiers.iter().any(|t| t == "v4(avx512)")); }
1068
1069 #[test]
1070 fn resolve_additive_default_replaces_scalar() {
1071 let tiers = resolve_tier_names(&["+default"], true);
1072 assert!(tiers.contains(&"default".to_string()));
1073 assert!(!tiers.iter().any(|t| t == "scalar")); }
1075
1076 #[test]
1077 fn resolve_subtractive_removes() {
1078 let tiers = resolve_tier_names(&["-neon", "-wasm128"], true);
1079 assert!(!tiers.iter().any(|t| t == "neon"));
1080 assert!(!tiers.iter().any(|t| t == "wasm128"));
1081 assert!(tiers.contains(&"v3".to_string())); }
1083
1084 #[test]
1085 fn resolve_subtractive_removes_scalar() {
1086 let tiers = resolve_tier_names(&["-scalar"], true);
1088 assert!(
1089 !tiers.iter().any(|t| t == "scalar"),
1090 "expected scalar to be removed, got {tiers:?}"
1091 );
1092 }
1093
1094 #[test]
1095 fn resolve_subtractive_removes_scalar_with_other_modifiers() {
1096 let tiers = resolve_tier_names(&["-scalar", "+arm_v2"], true);
1098 assert!(
1099 !tiers.iter().any(|t| t == "scalar"),
1100 "expected scalar to be removed, got {tiers:?}"
1101 );
1102 assert!(tiers.contains(&"arm_v2".to_string()));
1103 }
1104
1105 #[test]
1106 fn resolve_mixed_add_remove() {
1107 let tiers = resolve_tier_names(&["-neon", "-wasm128", "+v1"], true);
1108 assert!(tiers.contains(&"v1".to_string()));
1109 assert!(!tiers.iter().any(|t| t == "neon"));
1110 assert!(!tiers.iter().any(|t| t == "wasm128"));
1111 assert!(tiers.contains(&"v3".to_string()));
1112 assert!(tiers.contains(&"scalar".to_string()));
1113 }
1114
1115 #[test]
1116 fn resolve_additive_duplicate_is_noop() {
1117 let tiers = resolve_tier_names(&["+v3"], true);
1119 let v3_count = tiers.iter().filter(|t| t.as_str() == "v3").count();
1120 assert_eq!(v3_count, 1);
1121 }
1122
1123 #[test]
1124 fn resolve_mixing_plus_and_plain_is_error() {
1125 let names: Vec<String> = vec!["+v1".into(), "v3".into()];
1126 let result = resolve_tiers(&names, proc_macro2::Span::call_site(), true);
1127 assert!(result.is_err());
1128 }
1129
1130 #[test]
1131 fn resolve_underscore_tier_name() {
1132 let tiers = resolve_tier_names(&["_v3", "_neon", "_scalar"], false);
1133 assert!(tiers.contains(&"v3".to_string()));
1134 assert!(tiers.contains(&"neon".to_string()));
1135 assert!(tiers.contains(&"scalar".to_string()));
1136 }
1137
1138 #[test]
1143 fn autoversion_args_empty() {
1144 let args: AutoversionArgs = syn::parse_str("").unwrap();
1145 assert!(args.self_type.is_none());
1146 assert!(args.tiers.is_none());
1147 }
1148
1149 #[test]
1150 fn autoversion_args_single_tier() {
1151 let args: AutoversionArgs = syn::parse_str("v3").unwrap();
1152 assert!(args.self_type.is_none());
1153 assert_eq!(args.tiers.as_ref().unwrap(), &["v3"]);
1154 }
1155
1156 #[test]
1157 fn autoversion_args_tiers_only() {
1158 let args: AutoversionArgs = syn::parse_str("v3, v4, neon").unwrap();
1159 assert!(args.self_type.is_none());
1160 let tiers = args.tiers.unwrap();
1161 assert_eq!(tiers, vec!["v3", "v4", "neon"]);
1162 }
1163
1164 #[test]
1165 fn autoversion_args_many_tiers() {
1166 let args: AutoversionArgs =
1167 syn::parse_str("v1, v2, v3, v4, v4x, neon, arm_v2, wasm128").unwrap();
1168 assert_eq!(
1169 args.tiers.unwrap(),
1170 vec!["v1", "v2", "v3", "v4", "v4x", "neon", "arm_v2", "wasm128"]
1171 );
1172 }
1173
1174 #[test]
1175 fn autoversion_args_trailing_comma() {
1176 let args: AutoversionArgs = syn::parse_str("v3, v4,").unwrap();
1177 assert_eq!(args.tiers.as_ref().unwrap(), &["v3", "v4"]);
1178 }
1179
1180 #[test]
1181 fn autoversion_args_self_only() {
1182 let args: AutoversionArgs = syn::parse_str("_self = MyType").unwrap();
1183 assert!(args.self_type.is_some());
1184 assert!(args.tiers.is_none());
1185 }
1186
1187 #[test]
1188 fn autoversion_args_self_and_tiers() {
1189 let args: AutoversionArgs = syn::parse_str("_self = MyType, v3, neon").unwrap();
1190 assert!(args.self_type.is_some());
1191 let tiers = args.tiers.unwrap();
1192 assert_eq!(tiers, vec!["v3", "neon"]);
1193 }
1194
1195 #[test]
1196 fn autoversion_args_tiers_then_self() {
1197 let args: AutoversionArgs = syn::parse_str("v3, neon, _self = MyType").unwrap();
1199 assert!(args.self_type.is_some());
1200 let tiers = args.tiers.unwrap();
1201 assert_eq!(tiers, vec!["v3", "neon"]);
1202 }
1203
1204 #[test]
1205 fn autoversion_args_self_with_path_type() {
1206 let args: AutoversionArgs = syn::parse_str("_self = crate::MyType").unwrap();
1207 assert!(args.self_type.is_some());
1208 assert!(args.tiers.is_none());
1209 }
1210
1211 #[test]
1212 fn autoversion_args_self_with_generic_type() {
1213 let args: AutoversionArgs = syn::parse_str("_self = Vec<u8>").unwrap();
1214 assert!(args.self_type.is_some());
1215 let ty_str = args.self_type.unwrap().to_token_stream().to_string();
1216 assert!(ty_str.contains("Vec"), "Expected Vec<u8>, got: {}", ty_str);
1217 }
1218
1219 #[test]
1220 fn autoversion_args_self_trailing_comma() {
1221 let args: AutoversionArgs = syn::parse_str("_self = MyType,").unwrap();
1222 assert!(args.self_type.is_some());
1223 assert!(args.tiers.is_none());
1224 }
1225
1226 #[test]
1231 fn find_autoversion_token_param_simdtoken_first() {
1232 let f: ItemFn =
1233 syn::parse_str("fn process(token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1234 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1235 assert_eq!(param.index, 0);
1236 assert_eq!(param.ident, "token");
1237 assert_eq!(param.kind, AutoversionTokenKind::SimdToken);
1238 }
1239
1240 #[test]
1241 fn find_autoversion_token_param_simdtoken_second() {
1242 let f: ItemFn =
1243 syn::parse_str("fn process(data: &[f32], token: SimdToken) -> f32 {}").unwrap();
1244 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1245 assert_eq!(param.index, 1);
1246 assert_eq!(param.kind, AutoversionTokenKind::SimdToken);
1247 }
1248
1249 #[test]
1250 fn find_autoversion_token_param_underscore_prefix() {
1251 let f: ItemFn =
1252 syn::parse_str("fn process(_token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1253 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1254 assert_eq!(param.index, 0);
1255 assert_eq!(param.ident, "_token");
1256 }
1257
1258 #[test]
1259 fn find_autoversion_token_param_wildcard() {
1260 let f: ItemFn = syn::parse_str("fn process(_: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1261 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1262 assert_eq!(param.index, 0);
1263 assert_eq!(param.ident, "__autoversion_token");
1264 }
1265
1266 #[test]
1267 fn find_autoversion_token_param_scalar_token() {
1268 let f: ItemFn =
1269 syn::parse_str("fn process_scalar(_: ScalarToken, data: &[f32]) -> f32 {}").unwrap();
1270 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1271 assert_eq!(param.index, 0);
1272 assert_eq!(param.kind, AutoversionTokenKind::ScalarToken);
1273 }
1274
1275 #[test]
1276 fn find_autoversion_token_param_not_found() {
1277 let f: ItemFn = syn::parse_str("fn process(data: &[f32]) -> f32 {}").unwrap();
1278 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1279 }
1280
1281 #[test]
1282 fn find_autoversion_token_param_no_params() {
1283 let f: ItemFn = syn::parse_str("fn process() {}").unwrap();
1284 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1285 }
1286
1287 #[test]
1288 fn find_autoversion_token_param_concrete_token_errors() {
1289 let f: ItemFn =
1290 syn::parse_str("fn process(token: X64V3Token, data: &[f32]) -> f32 {}").unwrap();
1291 let err = find_autoversion_token_param(&f.sig).unwrap_err();
1292 let msg = err.to_string();
1293 assert!(
1294 msg.contains("concrete token"),
1295 "error should mention concrete token: {msg}"
1296 );
1297 assert!(
1298 msg.contains("#[arcane]"),
1299 "error should suggest #[arcane]: {msg}"
1300 );
1301 }
1302
1303 #[test]
1304 fn find_autoversion_token_param_neon_token_errors() {
1305 let f: ItemFn =
1306 syn::parse_str("fn process(token: NeonToken, data: &[f32]) -> f32 {}").unwrap();
1307 assert!(find_autoversion_token_param(&f.sig).is_err());
1308 }
1309
1310 #[test]
1311 fn find_autoversion_token_param_unknown_type_ignored() {
1312 let f: ItemFn = syn::parse_str("fn process(data: &[f32], scale: f32) -> f32 {}").unwrap();
1314 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1315 }
1316
1317 #[test]
1318 fn find_autoversion_token_param_among_many() {
1319 let f: ItemFn = syn::parse_str(
1320 "fn process(a: i32, b: f64, token: SimdToken, c: &str, d: bool) -> f32 {}",
1321 )
1322 .unwrap();
1323 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1324 assert_eq!(param.index, 2);
1325 assert_eq!(param.ident, "token");
1326 }
1327
1328 #[test]
1329 fn find_autoversion_token_param_with_generics() {
1330 let f: ItemFn =
1331 syn::parse_str("fn process<T: Clone>(token: SimdToken, data: &[T]) -> T {}").unwrap();
1332 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1333 assert_eq!(param.index, 0);
1334 }
1335
1336 #[test]
1337 fn find_autoversion_token_param_with_where_clause() {
1338 let f: ItemFn = syn::parse_str(
1339 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default {}",
1340 )
1341 .unwrap();
1342 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1343 assert_eq!(param.index, 0);
1344 }
1345
1346 #[test]
1347 fn find_autoversion_token_param_with_lifetime() {
1348 let f: ItemFn =
1349 syn::parse_str("fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a f32 {}")
1350 .unwrap();
1351 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1352 assert_eq!(param.index, 0);
1353 }
1354
1355 #[test]
1360 fn autoversion_default_tiers_all_resolve() {
1361 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1362 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1363 assert!(!tiers.is_empty());
1364 assert!(tiers.iter().any(|t| t.name == "scalar"));
1366 }
1367
1368 #[test]
1369 fn autoversion_scalar_always_appended() {
1370 let names = vec!["v3".to_string(), "neon".to_string()];
1371 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1372 assert!(
1373 tiers.iter().any(|t| t.name == "scalar"),
1374 "scalar must be auto-appended"
1375 );
1376 }
1377
1378 #[test]
1379 fn autoversion_scalar_not_duplicated() {
1380 let names = vec!["v3".to_string(), "scalar".to_string()];
1381 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1382 let scalar_count = tiers.iter().filter(|t| t.name == "scalar").count();
1383 assert_eq!(scalar_count, 1, "scalar must not be duplicated");
1384 }
1385
1386 #[test]
1387 fn autoversion_tiers_sorted_by_priority() {
1388 let names = vec!["neon".to_string(), "v4".to_string(), "v3".to_string()];
1389 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1390 let priorities: Vec<u32> = tiers.iter().map(|t| t.priority).collect();
1392 for window in priorities.windows(2) {
1393 assert!(
1394 window[0] >= window[1],
1395 "Tiers not sorted by priority: {:?}",
1396 priorities
1397 );
1398 }
1399 }
1400
1401 #[test]
1402 fn autoversion_unknown_tier_errors() {
1403 let names = vec!["v3".to_string(), "avx9000".to_string()];
1404 let result = resolve_tiers(&names, proc_macro2::Span::call_site(), false);
1405 match result {
1406 Ok(_) => panic!("Expected error for unknown tier 'avx9000'"),
1407 Err(e) => {
1408 let err_msg = e.to_string();
1409 assert!(
1410 err_msg.contains("avx9000"),
1411 "Error should mention unknown tier: {}",
1412 err_msg
1413 );
1414 }
1415 }
1416 }
1417
1418 #[test]
1419 fn autoversion_all_known_tiers_resolve() {
1420 for tier in ALL_TIERS {
1422 assert!(
1423 find_tier(tier.name).is_some(),
1424 "Tier '{}' should be findable by name",
1425 tier.name
1426 );
1427 }
1428 }
1429
1430 #[test]
1431 fn autoversion_default_tier_list_is_sensible() {
1432 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1434 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1435
1436 let has_x86 = tiers.iter().any(|t| t.target_arch == Some("x86_64"));
1437 let has_arm = tiers.iter().any(|t| t.target_arch == Some("aarch64"));
1438 let has_wasm = tiers.iter().any(|t| t.target_arch == Some("wasm32"));
1439 let has_scalar = tiers.iter().any(|t| t.name == "scalar");
1440
1441 assert!(has_x86, "Default tiers should include an x86_64 tier");
1442 assert!(has_arm, "Default tiers should include an aarch64 tier");
1443 assert!(has_wasm, "Default tiers should include a wasm32 tier");
1444 assert!(has_scalar, "Default tiers should include scalar");
1445 }
1446
1447 fn do_variant_replacement(func: &str, tier_name: &str, has_self: bool) -> ItemFn {
1455 let mut f: ItemFn = syn::parse_str(func).unwrap();
1456 let fn_name = f.sig.ident.to_string();
1457
1458 let tier = find_tier(tier_name).unwrap();
1459
1460 f.sig.ident = format_ident!("{}_{}", fn_name, tier.suffix);
1462
1463 let token_idx = find_autoversion_token_param(&f.sig)
1465 .expect("should not error on SimdToken")
1466 .unwrap_or_else(|| panic!("No SimdToken param in: {}", func))
1467 .index;
1468 if tier_name == "default" {
1469 let stmts = f.block.stmts.clone();
1471 let mut inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1472 inputs.remove(token_idx);
1473 f.sig.inputs = inputs.into_iter().collect();
1474 f.block.stmts = stmts;
1475 } else {
1476 let concrete_type: Type = syn::parse_str(tier.token_path).unwrap();
1477 if let FnArg::Typed(pt) = &mut f.sig.inputs[token_idx] {
1478 *pt.ty = concrete_type;
1479 }
1480 }
1481
1482 if (tier_name == "scalar" || tier_name == "default") && has_self {
1484 let preamble: syn::Stmt = syn::parse_quote!(let _self = self;);
1485 f.block.stmts.insert(0, preamble);
1486 }
1487
1488 f
1489 }
1490
1491 #[test]
1492 fn variant_replacement_v3_renames_function() {
1493 let f = do_variant_replacement(
1494 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1495 "v3",
1496 false,
1497 );
1498 assert_eq!(f.sig.ident, "process_v3");
1499 }
1500
1501 #[test]
1502 fn variant_replacement_v3_replaces_token_type() {
1503 let f = do_variant_replacement(
1504 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1505 "v3",
1506 false,
1507 );
1508 let first_param_ty = match &f.sig.inputs[0] {
1509 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1510 _ => panic!("Expected typed param"),
1511 };
1512 assert!(
1513 first_param_ty.contains("X64V3Token"),
1514 "Expected X64V3Token, got: {}",
1515 first_param_ty
1516 );
1517 }
1518
1519 #[test]
1520 fn variant_replacement_neon_produces_valid_fn() {
1521 let f = do_variant_replacement(
1522 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1523 "neon",
1524 false,
1525 );
1526 assert_eq!(f.sig.ident, "compute_neon");
1527 let first_param_ty = match &f.sig.inputs[0] {
1528 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1529 _ => panic!("Expected typed param"),
1530 };
1531 assert!(
1532 first_param_ty.contains("NeonToken"),
1533 "Expected NeonToken, got: {}",
1534 first_param_ty
1535 );
1536 }
1537
1538 #[test]
1539 fn variant_replacement_wasm128_produces_valid_fn() {
1540 let f = do_variant_replacement(
1541 "fn compute(_t: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1542 "wasm128",
1543 false,
1544 );
1545 assert_eq!(f.sig.ident, "compute_wasm128");
1546 }
1547
1548 #[test]
1549 fn variant_replacement_scalar_produces_valid_fn() {
1550 let f = do_variant_replacement(
1551 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1552 "scalar",
1553 false,
1554 );
1555 assert_eq!(f.sig.ident, "compute_scalar");
1556 let first_param_ty = match &f.sig.inputs[0] {
1557 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1558 _ => panic!("Expected typed param"),
1559 };
1560 assert!(
1561 first_param_ty.contains("ScalarToken"),
1562 "Expected ScalarToken, got: {}",
1563 first_param_ty
1564 );
1565 }
1566
1567 #[test]
1568 fn variant_replacement_v4_produces_valid_fn() {
1569 let f = do_variant_replacement(
1570 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1571 "v4",
1572 false,
1573 );
1574 assert_eq!(f.sig.ident, "transform_v4");
1575 let first_param_ty = match &f.sig.inputs[0] {
1576 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1577 _ => panic!("Expected typed param"),
1578 };
1579 assert!(
1580 first_param_ty.contains("X64V4Token"),
1581 "Expected X64V4Token, got: {}",
1582 first_param_ty
1583 );
1584 }
1585
1586 #[test]
1587 fn variant_replacement_v4x_produces_valid_fn() {
1588 let f = do_variant_replacement(
1589 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1590 "v4x",
1591 false,
1592 );
1593 assert_eq!(f.sig.ident, "transform_v4x");
1594 }
1595
1596 #[test]
1597 fn variant_replacement_arm_v2_produces_valid_fn() {
1598 let f = do_variant_replacement(
1599 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1600 "arm_v2",
1601 false,
1602 );
1603 assert_eq!(f.sig.ident, "transform_arm_v2");
1604 }
1605
1606 #[test]
1607 fn variant_replacement_preserves_generics() {
1608 let f = do_variant_replacement(
1609 "fn process<T: Copy + Default>(token: SimdToken, data: &[T]) -> T { T::default() }",
1610 "v3",
1611 false,
1612 );
1613 assert_eq!(f.sig.ident, "process_v3");
1614 assert!(
1616 !f.sig.generics.params.is_empty(),
1617 "Generics should be preserved"
1618 );
1619 }
1620
1621 #[test]
1622 fn variant_replacement_preserves_where_clause() {
1623 let f = do_variant_replacement(
1624 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default { T::default() }",
1625 "v3",
1626 false,
1627 );
1628 assert!(
1629 f.sig.generics.where_clause.is_some(),
1630 "Where clause should be preserved"
1631 );
1632 }
1633
1634 #[test]
1635 fn variant_replacement_preserves_return_type() {
1636 let f = do_variant_replacement(
1637 "fn process(token: SimdToken, data: &[f32]) -> Vec<f32> { vec![] }",
1638 "neon",
1639 false,
1640 );
1641 let ret = f.sig.output.to_token_stream().to_string();
1642 assert!(
1643 ret.contains("Vec"),
1644 "Return type should be preserved, got: {}",
1645 ret
1646 );
1647 }
1648
1649 #[test]
1650 fn variant_replacement_preserves_multiple_params() {
1651 let f = do_variant_replacement(
1652 "fn process(token: SimdToken, a: &[f32], b: &[f32], scale: f32) -> f32 { 0.0 }",
1653 "v3",
1654 false,
1655 );
1656 assert_eq!(f.sig.inputs.len(), 4);
1658 }
1659
1660 #[test]
1661 fn variant_replacement_preserves_no_return_type() {
1662 let f = do_variant_replacement(
1663 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1664 "v3",
1665 false,
1666 );
1667 assert!(
1668 matches!(f.sig.output, ReturnType::Default),
1669 "No return type should remain as Default"
1670 );
1671 }
1672
1673 #[test]
1674 fn variant_replacement_preserves_lifetime_params() {
1675 let f = do_variant_replacement(
1676 "fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a [f32] { data }",
1677 "v3",
1678 false,
1679 );
1680 assert!(!f.sig.generics.params.is_empty());
1681 }
1682
1683 #[test]
1684 fn variant_replacement_scalar_self_injects_preamble() {
1685 let f = do_variant_replacement(
1686 "fn method(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1687 "scalar",
1688 true, );
1690 assert_eq!(f.sig.ident, "method_scalar");
1691
1692 let body_str = f.block.to_token_stream().to_string();
1694 assert!(
1695 body_str.contains("let _self = self"),
1696 "Scalar+self variant should have _self preamble, got: {}",
1697 body_str
1698 );
1699 }
1700
1701 #[test]
1702 fn variant_replacement_all_default_tiers_produce_valid_fns() {
1703 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1704 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1705
1706 for tier in &tiers {
1707 let f = do_variant_replacement(
1708 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1709 tier.name,
1710 false,
1711 );
1712 let expected_name = format!("process_{}", tier.suffix);
1713 assert_eq!(
1714 f.sig.ident.to_string(),
1715 expected_name,
1716 "Tier '{}' should produce function '{}'",
1717 tier.name,
1718 expected_name
1719 );
1720 }
1721 }
1722
1723 #[test]
1724 fn variant_replacement_all_known_tiers_produce_valid_fns() {
1725 for tier in ALL_TIERS {
1726 let f = do_variant_replacement(
1727 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1728 tier.name,
1729 false,
1730 );
1731 let expected_name = format!("compute_{}", tier.suffix);
1732 assert_eq!(
1733 f.sig.ident.to_string(),
1734 expected_name,
1735 "Tier '{}' should produce function '{}'",
1736 tier.name,
1737 expected_name
1738 );
1739 }
1740 }
1741
1742 #[test]
1743 fn variant_replacement_no_simdtoken_remains() {
1744 for tier in ALL_TIERS {
1745 let f = do_variant_replacement(
1746 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1747 tier.name,
1748 false,
1749 );
1750 let full_str = f.to_token_stream().to_string();
1751 assert!(
1752 !full_str.contains("SimdToken"),
1753 "Tier '{}' variant still contains 'SimdToken': {}",
1754 tier.name,
1755 full_str
1756 );
1757 }
1758 }
1759
1760 #[test]
1765 fn tier_v3_targets_x86_64() {
1766 let tier = find_tier("v3").unwrap();
1767 assert_eq!(tier.target_arch, Some("x86_64"));
1768 }
1769
1770 #[test]
1771 fn tier_v4_targets_x86_64() {
1772 let tier = find_tier("v4").unwrap();
1773 assert_eq!(tier.target_arch, Some("x86_64"));
1774 }
1775
1776 #[test]
1777 fn tier_v4x_targets_x86_64() {
1778 let tier = find_tier("v4x").unwrap();
1779 assert_eq!(tier.target_arch, Some("x86_64"));
1780 }
1781
1782 #[test]
1783 fn tier_neon_targets_aarch64() {
1784 let tier = find_tier("neon").unwrap();
1785 assert_eq!(tier.target_arch, Some("aarch64"));
1786 }
1787
1788 #[test]
1789 fn tier_wasm128_targets_wasm32() {
1790 let tier = find_tier("wasm128").unwrap();
1791 assert_eq!(tier.target_arch, Some("wasm32"));
1792 }
1793
1794 #[test]
1795 fn tier_scalar_has_no_guards() {
1796 let tier = find_tier("scalar").unwrap();
1797 assert_eq!(tier.target_arch, None);
1798 assert_eq!(tier.priority, 0);
1799 }
1800
1801 #[test]
1802 fn tier_priorities_are_consistent() {
1803 let v2 = find_tier("v2").unwrap();
1805 let v3 = find_tier("v3").unwrap();
1806 let v4 = find_tier("v4").unwrap();
1807 assert!(v4.priority > v3.priority);
1808 assert!(v3.priority > v2.priority);
1809
1810 let neon = find_tier("neon").unwrap();
1811 let arm_v2 = find_tier("arm_v2").unwrap();
1812 let arm_v3 = find_tier("arm_v3").unwrap();
1813 assert!(arm_v3.priority > arm_v2.priority);
1814 assert!(arm_v2.priority > neon.priority);
1815
1816 let scalar = find_tier("scalar").unwrap();
1818 assert!(neon.priority > scalar.priority);
1819 assert!(v2.priority > scalar.priority);
1820 }
1821
1822 #[test]
1827 fn dispatcher_param_removal_free_fn() {
1828 let f: ItemFn =
1830 syn::parse_str("fn process(token: SimdToken, data: &[f32], scale: f32) -> f32 { 0.0 }")
1831 .unwrap();
1832
1833 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1834 assert_eq!(token_param.kind, AutoversionTokenKind::SimdToken);
1836 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1837 dispatcher_inputs.remove(token_param.index);
1838 assert_eq!(dispatcher_inputs.len(), 2);
1839 }
1840
1841 #[test]
1842 fn dispatcher_param_removal_token_only() {
1843 let f: ItemFn = syn::parse_str("fn process(token: SimdToken) -> f32 { 0.0 }").unwrap();
1844 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1845 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1846 dispatcher_inputs.remove(token_param.index);
1847 assert_eq!(dispatcher_inputs.len(), 0);
1848 }
1849
1850 #[test]
1851 fn dispatcher_param_removal_token_last() {
1852 let f: ItemFn =
1853 syn::parse_str("fn process(data: &[f32], scale: f32, token: SimdToken) -> f32 { 0.0 }")
1854 .unwrap();
1855 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1856 assert_eq!(token_param.index, 2);
1857 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1858 dispatcher_inputs.remove(token_param.index);
1859 assert_eq!(dispatcher_inputs.len(), 2);
1860 }
1861
1862 #[test]
1863 fn dispatcher_scalar_token_kept() {
1864 let f: ItemFn =
1866 syn::parse_str("fn process_scalar(_: ScalarToken, data: &[f32]) -> f32 { 0.0 }")
1867 .unwrap();
1868 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1869 assert_eq!(token_param.kind, AutoversionTokenKind::ScalarToken);
1870 assert_eq!(f.sig.inputs.len(), 2);
1872 }
1873
1874 #[test]
1875 fn dispatcher_dispatch_args_extraction() {
1876 let f: ItemFn =
1878 syn::parse_str("fn process(data: &[f32], scale: f32) -> f32 { 0.0 }").unwrap();
1879
1880 let dispatch_args: Vec<String> = f
1881 .sig
1882 .inputs
1883 .iter()
1884 .filter_map(|arg| {
1885 if let FnArg::Typed(PatType { pat, .. }) = arg {
1886 if let syn::Pat::Ident(pi) = pat.as_ref() {
1887 return Some(pi.ident.to_string());
1888 }
1889 }
1890 None
1891 })
1892 .collect();
1893
1894 assert_eq!(dispatch_args, vec!["data", "scale"]);
1895 }
1896
1897 #[test]
1898 fn dispatcher_wildcard_params_get_renamed() {
1899 let f: ItemFn = syn::parse_str("fn process(_: &[f32], _: f32) -> f32 { 0.0 }").unwrap();
1900
1901 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1902
1903 let mut wild_counter = 0u32;
1904 for arg in &mut dispatcher_inputs {
1905 if let FnArg::Typed(pat_type) = arg {
1906 if matches!(pat_type.pat.as_ref(), syn::Pat::Wild(_)) {
1907 let ident = format_ident!("__autoversion_wild_{}", wild_counter);
1908 wild_counter += 1;
1909 *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
1910 attrs: vec![],
1911 by_ref: None,
1912 mutability: None,
1913 ident,
1914 subpat: None,
1915 });
1916 }
1917 }
1918 }
1919
1920 assert_eq!(wild_counter, 2);
1922
1923 let names: Vec<String> = dispatcher_inputs
1924 .iter()
1925 .filter_map(|arg| {
1926 if let FnArg::Typed(PatType { pat, .. }) = arg {
1927 if let syn::Pat::Ident(pi) = pat.as_ref() {
1928 return Some(pi.ident.to_string());
1929 }
1930 }
1931 None
1932 })
1933 .collect();
1934
1935 assert_eq!(names, vec!["__autoversion_wild_0", "__autoversion_wild_1"]);
1936 }
1937
1938 #[test]
1943 fn suffix_path_simple() {
1944 let path: syn::Path = syn::parse_str("process").unwrap();
1945 let suffixed = suffix_path(&path, "v3");
1946 assert_eq!(suffixed.to_token_stream().to_string(), "process_v3");
1947 }
1948
1949 #[test]
1950 fn suffix_path_qualified() {
1951 let path: syn::Path = syn::parse_str("module::process").unwrap();
1952 let suffixed = suffix_path(&path, "neon");
1953 let s = suffixed.to_token_stream().to_string();
1954 assert!(
1955 s.contains("process_neon"),
1956 "Expected process_neon, got: {}",
1957 s
1958 );
1959 }
1960}