1mod arcane;
6mod autoversion;
7mod common;
8mod generated;
9mod incant;
10mod magetypes;
11mod rewrite;
12mod rite;
13mod tiers;
14mod token_discovery;
15
16use proc_macro::TokenStream;
17use syn::parse_macro_input;
18
19use arcane::*;
20use autoversion::*;
21use common::*;
22use incant::*;
23use magetypes::*;
24use rite::*;
25use tiers::*;
26
27#[cfg(test)]
29use generated::{token_to_features, trait_to_features};
30#[cfg(test)]
31use quote::{ToTokens, format_ident};
32#[cfg(test)]
33use syn::{FnArg, PatType, Type};
34#[cfg(test)]
35use token_discovery::*;
36
37#[proc_macro_attribute]
202pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
203 let args = parse_macro_input!(attr as ArcaneArgs);
204 let input_fn = parse_macro_input!(item as LightFn);
205 arcane_impl(input_fn, "arcane", args)
206}
207
208#[proc_macro_attribute]
212#[doc(hidden)]
213pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
214 let args = parse_macro_input!(attr as ArcaneArgs);
215 let input_fn = parse_macro_input!(item as LightFn);
216 arcane_impl(input_fn, "simd_fn", args)
217}
218
219#[proc_macro_attribute]
232pub fn token_target_features_boundary(attr: TokenStream, item: TokenStream) -> TokenStream {
233 let args = parse_macro_input!(attr as ArcaneArgs);
234 let input_fn = parse_macro_input!(item as LightFn);
235 arcane_impl(input_fn, "token_target_features_boundary", args)
236}
237
238#[proc_macro_attribute]
317pub fn rite(attr: TokenStream, item: TokenStream) -> TokenStream {
318 let args = parse_macro_input!(attr as RiteArgs);
319 let input_fn = parse_macro_input!(item as LightFn);
320 rite_impl(input_fn, args)
321}
322
323#[proc_macro_attribute]
334pub fn token_target_features(attr: TokenStream, item: TokenStream) -> TokenStream {
335 let args = parse_macro_input!(attr as RiteArgs);
336 let input_fn = parse_macro_input!(item as LightFn);
337 rite_impl(input_fn, args)
338}
339
340#[proc_macro_attribute]
401pub fn magetypes(attr: TokenStream, item: TokenStream) -> TokenStream {
402 let input_fn = parse_macro_input!(item as LightFn);
403
404 let (rite_flag, defines, tier_names) =
421 match syn::parse::Parser::parse(parse_magetypes_attr, attr) {
422 Ok(parsed) => parsed,
423 Err(e) => return e.to_compile_error().into(),
424 };
425
426 let tier_names = if tier_names.is_empty() {
427 DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect()
428 } else {
429 tier_names
430 };
431
432 let tiers = match resolve_tiers(
434 &tier_names,
435 input_fn.sig.ident.span(),
436 true, ) {
438 Ok(t) => t,
439 Err(e) => return e.to_compile_error().into(),
440 };
441
442 magetypes_impl(input_fn, &tiers, rite_flag, &defines)
443}
444
445fn parse_magetypes_attr(
450 input: syn::parse::ParseStream,
451) -> syn::Result<(bool, Vec<String>, Vec<String>)> {
452 use syn::Token;
453 let mut rite_flag = false;
454 let mut defines = Vec::new();
455 let mut tier_names = Vec::new();
456
457 while !input.is_empty() {
458 let peek_rite = input.peek(syn::Ident) && {
461 let fork = input.fork();
462 fork.parse::<syn::Ident>()
463 .is_ok_and(|i| i == "rite" && !fork.peek(syn::token::Paren))
464 };
465 let peek_define = input.peek(syn::Ident) && {
466 let fork = input.fork();
467 fork.parse::<syn::Ident>()
468 .is_ok_and(|i| i == "define" && fork.peek(syn::token::Paren))
469 };
470
471 if peek_rite {
472 let _: syn::Ident = input.parse()?;
473 rite_flag = true;
474 } else if peek_define {
475 let _: syn::Ident = input.parse()?;
476 let content;
477 syn::parenthesized!(content in input);
478 while !content.is_empty() {
479 let ty: syn::Ident = content.parse()?;
480 defines.push(ty.to_string());
481 if content.peek(Token![,]) {
482 let _: Token![,] = content.parse()?;
483 }
484 }
485 } else {
486 tier_names.push(parse_one_tier(input)?);
488 }
489
490 if input.peek(Token![,]) {
491 let _: Token![,] = input.parse()?;
492 }
493 }
494
495 Ok((rite_flag, defines, tier_names))
496}
497
498#[proc_macro]
608pub fn incant(input: TokenStream) -> TokenStream {
609 let input = parse_macro_input!(input as IncantInput);
610 incant_impl(input)
611}
612
613#[proc_macro]
615pub fn simd_route(input: TokenStream) -> TokenStream {
616 let input = parse_macro_input!(input as IncantInput);
617 incant_impl(input)
618}
619
620#[proc_macro]
628pub fn dispatch_variant(input: TokenStream) -> TokenStream {
629 let input = parse_macro_input!(input as IncantInput);
630 incant_impl(input)
631}
632
633#[proc_macro_attribute]
764pub fn autoversion(attr: TokenStream, item: TokenStream) -> TokenStream {
765 let args = parse_macro_input!(attr as AutoversionArgs);
766 let input_fn = parse_macro_input!(item as LightFn);
767 autoversion_impl(input_fn, args)
768}
769
770#[cfg(test)]
775mod tests {
776 use super::*;
777
778 use super::generated::{ALL_CONCRETE_TOKENS, ALL_TRAIT_NAMES};
779 use syn::{ItemFn, ReturnType};
780
781 #[test]
782 fn every_concrete_token_is_in_token_to_features() {
783 for &name in ALL_CONCRETE_TOKENS {
784 assert!(
785 token_to_features(name).is_some(),
786 "Token `{}` exists in runtime crate but is NOT recognized by \
787 token_to_features() in the proc macro. Add it!",
788 name
789 );
790 }
791 }
792
793 #[test]
794 fn every_trait_is_in_trait_to_features() {
795 for &name in ALL_TRAIT_NAMES {
796 assert!(
797 trait_to_features(name).is_some(),
798 "Trait `{}` exists in runtime crate but is NOT recognized by \
799 trait_to_features() in the proc macro. Add it!",
800 name
801 );
802 }
803 }
804
805 #[test]
806 fn token_aliases_map_to_same_features() {
807 assert_eq!(
809 token_to_features("Desktop64"),
810 token_to_features("X64V3Token"),
811 "Desktop64 and X64V3Token should map to identical features"
812 );
813
814 assert_eq!(
816 token_to_features("Server64"),
817 token_to_features("X64V4Token"),
818 "Server64 and X64V4Token should map to identical features"
819 );
820 assert_eq!(
821 token_to_features("X64V4Token"),
822 token_to_features("Avx512Token"),
823 "X64V4Token and Avx512Token should map to identical features"
824 );
825
826 assert_eq!(
828 token_to_features("Arm64"),
829 token_to_features("NeonToken"),
830 "Arm64 and NeonToken should map to identical features"
831 );
832 }
833
834 #[test]
835 fn trait_to_features_includes_tokens_as_bounds() {
836 let tier_tokens = [
840 "X64V2Token",
841 "X64CryptoToken",
842 "X64V3Token",
843 "Desktop64",
844 "Avx2FmaToken",
845 "X64V4Token",
846 "Avx512Token",
847 "Server64",
848 "X64V4xToken",
849 "Avx512Fp16Token",
850 "NeonToken",
851 "Arm64",
852 "NeonAesToken",
853 "NeonSha3Token",
854 "NeonCrcToken",
855 "Arm64V2Token",
856 "Arm64V3Token",
857 ];
858
859 for &name in &tier_tokens {
860 assert!(
861 trait_to_features(name).is_some(),
862 "Tier token `{}` should also be recognized in trait_to_features() \
863 for use as a generic bound. Add it!",
864 name
865 );
866 }
867 }
868
869 #[test]
870 fn trait_features_are_cumulative() {
871 let v2_features = trait_to_features("HasX64V2").unwrap();
873 let v4_features = trait_to_features("HasX64V4").unwrap();
874
875 for &f in v2_features {
876 assert!(
877 v4_features.contains(&f),
878 "HasX64V4 should include v2 feature `{}` but doesn't",
879 f
880 );
881 }
882
883 assert!(
885 v4_features.len() > v2_features.len(),
886 "HasX64V4 should have more features than HasX64V2"
887 );
888 }
889
890 #[test]
891 fn x64v3_trait_features_include_v2() {
892 let v2 = trait_to_features("HasX64V2").unwrap();
894 let v3 = trait_to_features("X64V3Token").unwrap();
895
896 for &f in v2 {
897 assert!(
898 v3.contains(&f),
899 "X64V3Token trait features should include v2 feature `{}` but don't",
900 f
901 );
902 }
903 }
904
905 #[test]
906 fn has_neon_aes_includes_neon() {
907 let neon = trait_to_features("HasNeon").unwrap();
908 let neon_aes = trait_to_features("HasNeonAes").unwrap();
909
910 for &f in neon {
911 assert!(
912 neon_aes.contains(&f),
913 "HasNeonAes should include NEON feature `{}`",
914 f
915 );
916 }
917 }
918
919 #[test]
920 fn no_removed_traits_are_recognized() {
921 let removed = [
923 "HasSse",
924 "HasSse2",
925 "HasSse41",
926 "HasSse42",
927 "HasAvx",
928 "HasAvx2",
929 "HasFma",
930 "HasAvx512f",
931 "HasAvx512bw",
932 "HasAvx512vl",
933 "HasAvx512vbmi2",
934 "HasSve",
935 "HasSve2",
936 ];
937
938 for &name in &removed {
939 assert!(
940 trait_to_features(name).is_none(),
941 "Removed trait `{}` should NOT be in trait_to_features(). \
942 It was removed in 0.3.0 — users should migrate to tier traits.",
943 name
944 );
945 }
946 }
947
948 #[test]
949 fn no_nonexistent_tokens_are_recognized() {
950 let fake = [
952 "SveToken",
953 "Sve2Token",
954 "Avx512VnniToken",
955 "X64V4ModernToken",
956 "NeonFp16Token",
957 ];
958
959 for &name in &fake {
960 assert!(
961 token_to_features(name).is_none(),
962 "Non-existent token `{}` should NOT be in token_to_features()",
963 name
964 );
965 }
966 }
967
968 #[test]
969 fn featureless_traits_are_not_in_registries() {
970 for &name in FEATURELESS_TRAIT_NAMES {
973 assert!(
974 token_to_features(name).is_none(),
975 "`{}` should NOT be in token_to_features() — it has no CPU features",
976 name
977 );
978 assert!(
979 trait_to_features(name).is_none(),
980 "`{}` should NOT be in trait_to_features() — it has no CPU features",
981 name
982 );
983 }
984 }
985
986 #[test]
987 fn find_featureless_trait_detects_simdtoken() {
988 let names = vec!["SimdToken".to_string()];
989 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
990
991 let names = vec!["IntoConcreteToken".to_string()];
992 assert_eq!(find_featureless_trait(&names), Some("IntoConcreteToken"));
993
994 let names = vec!["HasX64V2".to_string()];
996 assert_eq!(find_featureless_trait(&names), None);
997
998 let names = vec!["HasNeon".to_string()];
999 assert_eq!(find_featureless_trait(&names), None);
1000
1001 let names = vec!["SimdToken".to_string(), "HasX64V2".to_string()];
1003 assert_eq!(find_featureless_trait(&names), Some("SimdToken"));
1004 }
1005
1006 #[test]
1007 fn arm64_v2_v3_traits_are_cumulative() {
1008 let v2_features = trait_to_features("HasArm64V2").unwrap();
1009 let v3_features = trait_to_features("HasArm64V3").unwrap();
1010
1011 for &f in v2_features {
1012 assert!(
1013 v3_features.contains(&f),
1014 "HasArm64V3 should include v2 feature `{}` but doesn't",
1015 f
1016 );
1017 }
1018
1019 assert!(
1020 v3_features.len() > v2_features.len(),
1021 "HasArm64V3 should have more features than HasArm64V2"
1022 );
1023 }
1024
1025 fn resolve_tier_names(names: &[&str], default_gates: bool) -> Vec<String> {
1030 let names: Vec<String> = names.iter().map(|s| s.to_string()).collect();
1031 resolve_tiers(&names, proc_macro2::Span::call_site(), default_gates)
1032 .unwrap()
1033 .iter()
1034 .map(|rt| {
1035 if let Some(ref gate) = rt.feature_gate {
1036 format!("{}({})", rt.name, gate)
1037 } else {
1038 rt.name.to_string()
1039 }
1040 })
1041 .collect()
1042 }
1043
1044 #[test]
1045 fn resolve_defaults() {
1046 let tiers = resolve_tier_names(&["v4", "v3", "neon", "wasm128", "scalar"], true);
1047 assert!(tiers.contains(&"v3".to_string()));
1048 assert!(tiers.contains(&"scalar".to_string()));
1049 assert!(tiers.contains(&"v4(avx512)".to_string()));
1051 }
1052
1053 #[test]
1054 fn resolve_additive_appends() {
1055 let tiers = resolve_tier_names(&["+v1"], true);
1056 assert!(tiers.contains(&"v1".to_string()));
1057 assert!(tiers.contains(&"v3".to_string())); assert!(tiers.contains(&"scalar".to_string())); }
1060
1061 #[test]
1062 fn resolve_additive_v4_overrides_gate() {
1063 let tiers = resolve_tier_names(&["+v4"], true);
1065 assert!(tiers.contains(&"v4".to_string())); assert!(!tiers.iter().any(|t| t == "v4(avx512)")); }
1068
1069 #[test]
1070 fn resolve_additive_default_replaces_scalar() {
1071 let tiers = resolve_tier_names(&["+default"], true);
1072 assert!(tiers.contains(&"default".to_string()));
1073 assert!(!tiers.iter().any(|t| t == "scalar")); }
1075
1076 #[test]
1077 fn resolve_subtractive_removes() {
1078 let tiers = resolve_tier_names(&["-neon", "-wasm128"], true);
1079 assert!(!tiers.iter().any(|t| t == "neon"));
1080 assert!(!tiers.iter().any(|t| t == "wasm128"));
1081 assert!(tiers.contains(&"v3".to_string())); }
1083
1084 #[test]
1085 fn resolve_subtractive_removes_scalar() {
1086 let tiers = resolve_tier_names(&["-scalar"], true);
1088 assert!(
1089 !tiers.iter().any(|t| t == "scalar"),
1090 "expected scalar to be removed, got {tiers:?}"
1091 );
1092 }
1093
1094 #[test]
1095 fn resolve_subtractive_removes_scalar_with_other_modifiers() {
1096 let tiers = resolve_tier_names(&["-scalar", "+arm_v2"], true);
1098 assert!(
1099 !tiers.iter().any(|t| t == "scalar"),
1100 "expected scalar to be removed, got {tiers:?}"
1101 );
1102 assert!(tiers.contains(&"arm_v2".to_string()));
1103 }
1104
1105 #[test]
1106 fn resolve_mixed_add_remove() {
1107 let tiers = resolve_tier_names(&["-neon", "-wasm128", "+v1"], true);
1108 assert!(tiers.contains(&"v1".to_string()));
1109 assert!(!tiers.iter().any(|t| t == "neon"));
1110 assert!(!tiers.iter().any(|t| t == "wasm128"));
1111 assert!(tiers.contains(&"v3".to_string()));
1112 assert!(tiers.contains(&"scalar".to_string()));
1113 }
1114
1115 #[test]
1116 fn resolve_additive_duplicate_is_noop() {
1117 let tiers = resolve_tier_names(&["+v3"], true);
1119 let v3_count = tiers.iter().filter(|t| t.as_str() == "v3").count();
1120 assert_eq!(v3_count, 1);
1121 }
1122
1123 #[test]
1124 fn resolve_mixing_plus_and_plain_is_additive() {
1125 let names: Vec<String> = vec!["+v1".into(), "v3".into()];
1129 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), true).unwrap();
1130 let suffixes: Vec<&str> = tiers.iter().map(|t| t.tier.suffix).collect();
1131 assert!(suffixes.contains(&"v1"));
1132 assert!(suffixes.contains(&"v3"));
1133 }
1134
1135 #[test]
1136 fn resolve_underscore_tier_name() {
1137 let tiers = resolve_tier_names(&["_v3", "_neon", "_scalar"], false);
1138 assert!(tiers.contains(&"v3".to_string()));
1139 assert!(tiers.contains(&"neon".to_string()));
1140 assert!(tiers.contains(&"scalar".to_string()));
1141 }
1142
1143 #[test]
1148 fn autoversion_args_empty() {
1149 let args: AutoversionArgs = syn::parse_str("").unwrap();
1150 assert!(args.self_type.is_none());
1151 assert!(args.tiers.is_none());
1152 }
1153
1154 #[test]
1155 fn autoversion_args_single_tier() {
1156 let args: AutoversionArgs = syn::parse_str("v3").unwrap();
1157 assert!(args.self_type.is_none());
1158 assert_eq!(args.tiers.as_ref().unwrap(), &["v3"]);
1159 }
1160
1161 #[test]
1162 fn autoversion_args_tiers_only() {
1163 let args: AutoversionArgs = syn::parse_str("v3, v4, neon").unwrap();
1164 assert!(args.self_type.is_none());
1165 let tiers = args.tiers.unwrap();
1166 assert_eq!(tiers, vec!["v3", "v4", "neon"]);
1167 }
1168
1169 #[test]
1170 fn autoversion_args_many_tiers() {
1171 let args: AutoversionArgs =
1172 syn::parse_str("v1, v2, v3, v4, v4x, neon, arm_v2, wasm128").unwrap();
1173 assert_eq!(
1174 args.tiers.unwrap(),
1175 vec!["v1", "v2", "v3", "v4", "v4x", "neon", "arm_v2", "wasm128"]
1176 );
1177 }
1178
1179 #[test]
1180 fn autoversion_args_trailing_comma() {
1181 let args: AutoversionArgs = syn::parse_str("v3, v4,").unwrap();
1182 assert_eq!(args.tiers.as_ref().unwrap(), &["v3", "v4"]);
1183 }
1184
1185 #[test]
1186 fn autoversion_args_self_only() {
1187 let args: AutoversionArgs = syn::parse_str("_self = MyType").unwrap();
1188 assert!(args.self_type.is_some());
1189 assert!(args.tiers.is_none());
1190 }
1191
1192 #[test]
1193 fn autoversion_args_self_and_tiers() {
1194 let args: AutoversionArgs = syn::parse_str("_self = MyType, v3, neon").unwrap();
1195 assert!(args.self_type.is_some());
1196 let tiers = args.tiers.unwrap();
1197 assert_eq!(tiers, vec!["v3", "neon"]);
1198 }
1199
1200 #[test]
1201 fn autoversion_args_tiers_then_self() {
1202 let args: AutoversionArgs = syn::parse_str("v3, neon, _self = MyType").unwrap();
1204 assert!(args.self_type.is_some());
1205 let tiers = args.tiers.unwrap();
1206 assert_eq!(tiers, vec!["v3", "neon"]);
1207 }
1208
1209 #[test]
1210 fn autoversion_args_self_with_path_type() {
1211 let args: AutoversionArgs = syn::parse_str("_self = crate::MyType").unwrap();
1212 assert!(args.self_type.is_some());
1213 assert!(args.tiers.is_none());
1214 }
1215
1216 #[test]
1217 fn autoversion_args_self_with_generic_type() {
1218 let args: AutoversionArgs = syn::parse_str("_self = Vec<u8>").unwrap();
1219 assert!(args.self_type.is_some());
1220 let ty_str = args.self_type.unwrap().to_token_stream().to_string();
1221 assert!(ty_str.contains("Vec"), "Expected Vec<u8>, got: {}", ty_str);
1222 }
1223
1224 #[test]
1225 fn autoversion_args_self_trailing_comma() {
1226 let args: AutoversionArgs = syn::parse_str("_self = MyType,").unwrap();
1227 assert!(args.self_type.is_some());
1228 assert!(args.tiers.is_none());
1229 }
1230
1231 #[test]
1236 fn find_autoversion_token_param_simdtoken_first() {
1237 let f: ItemFn =
1238 syn::parse_str("fn process(token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1239 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1240 assert_eq!(param.index, 0);
1241 assert_eq!(param.ident, "token");
1242 assert_eq!(param.kind, AutoversionTokenKind::SimdToken);
1243 }
1244
1245 #[test]
1246 fn find_autoversion_token_param_simdtoken_second() {
1247 let f: ItemFn =
1248 syn::parse_str("fn process(data: &[f32], token: SimdToken) -> f32 {}").unwrap();
1249 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1250 assert_eq!(param.index, 1);
1251 assert_eq!(param.kind, AutoversionTokenKind::SimdToken);
1252 }
1253
1254 #[test]
1255 fn find_autoversion_token_param_underscore_prefix() {
1256 let f: ItemFn =
1257 syn::parse_str("fn process(_token: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1258 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1259 assert_eq!(param.index, 0);
1260 assert_eq!(param.ident, "_token");
1261 }
1262
1263 #[test]
1264 fn find_autoversion_token_param_wildcard() {
1265 let f: ItemFn = syn::parse_str("fn process(_: SimdToken, data: &[f32]) -> f32 {}").unwrap();
1266 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1267 assert_eq!(param.index, 0);
1268 assert_eq!(param.ident, "__autoversion_token");
1269 }
1270
1271 #[test]
1272 fn find_autoversion_token_param_scalar_token() {
1273 let f: ItemFn =
1274 syn::parse_str("fn process_scalar(_: ScalarToken, data: &[f32]) -> f32 {}").unwrap();
1275 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1276 assert_eq!(param.index, 0);
1277 assert_eq!(param.kind, AutoversionTokenKind::ScalarToken);
1278 }
1279
1280 #[test]
1281 fn find_autoversion_token_param_not_found() {
1282 let f: ItemFn = syn::parse_str("fn process(data: &[f32]) -> f32 {}").unwrap();
1283 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1284 }
1285
1286 #[test]
1287 fn find_autoversion_token_param_no_params() {
1288 let f: ItemFn = syn::parse_str("fn process() {}").unwrap();
1289 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1290 }
1291
1292 #[test]
1293 fn find_autoversion_token_param_concrete_token_errors() {
1294 let f: ItemFn =
1295 syn::parse_str("fn process(token: X64V3Token, data: &[f32]) -> f32 {}").unwrap();
1296 let err = find_autoversion_token_param(&f.sig).unwrap_err();
1297 let msg = err.to_string();
1298 assert!(
1299 msg.contains("concrete token"),
1300 "error should mention concrete token: {msg}"
1301 );
1302 assert!(
1303 msg.contains("#[arcane]"),
1304 "error should suggest #[arcane]: {msg}"
1305 );
1306 }
1307
1308 #[test]
1309 fn find_autoversion_token_param_neon_token_errors() {
1310 let f: ItemFn =
1311 syn::parse_str("fn process(token: NeonToken, data: &[f32]) -> f32 {}").unwrap();
1312 assert!(find_autoversion_token_param(&f.sig).is_err());
1313 }
1314
1315 #[test]
1316 fn find_autoversion_token_param_unknown_type_ignored() {
1317 let f: ItemFn = syn::parse_str("fn process(data: &[f32], scale: f32) -> f32 {}").unwrap();
1319 assert!(find_autoversion_token_param(&f.sig).unwrap().is_none());
1320 }
1321
1322 #[test]
1323 fn find_autoversion_token_param_among_many() {
1324 let f: ItemFn = syn::parse_str(
1325 "fn process(a: i32, b: f64, token: SimdToken, c: &str, d: bool) -> f32 {}",
1326 )
1327 .unwrap();
1328 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1329 assert_eq!(param.index, 2);
1330 assert_eq!(param.ident, "token");
1331 }
1332
1333 #[test]
1334 fn find_autoversion_token_param_with_generics() {
1335 let f: ItemFn =
1336 syn::parse_str("fn process<T: Clone>(token: SimdToken, data: &[T]) -> T {}").unwrap();
1337 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1338 assert_eq!(param.index, 0);
1339 }
1340
1341 #[test]
1342 fn find_autoversion_token_param_with_where_clause() {
1343 let f: ItemFn = syn::parse_str(
1344 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default {}",
1345 )
1346 .unwrap();
1347 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1348 assert_eq!(param.index, 0);
1349 }
1350
1351 #[test]
1352 fn find_autoversion_token_param_with_lifetime() {
1353 let f: ItemFn =
1354 syn::parse_str("fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a f32 {}")
1355 .unwrap();
1356 let param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1357 assert_eq!(param.index, 0);
1358 }
1359
1360 #[test]
1365 fn autoversion_default_tiers_all_resolve() {
1366 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1367 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1368 assert!(!tiers.is_empty());
1369 assert!(tiers.iter().any(|t| t.name == "scalar"));
1371 }
1372
1373 #[test]
1374 fn autoversion_scalar_always_appended() {
1375 let names = vec!["v3".to_string(), "neon".to_string()];
1376 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1377 assert!(
1378 tiers.iter().any(|t| t.name == "scalar"),
1379 "scalar must be auto-appended"
1380 );
1381 }
1382
1383 #[test]
1384 fn autoversion_scalar_not_duplicated() {
1385 let names = vec!["v3".to_string(), "scalar".to_string()];
1386 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1387 let scalar_count = tiers.iter().filter(|t| t.name == "scalar").count();
1388 assert_eq!(scalar_count, 1, "scalar must not be duplicated");
1389 }
1390
1391 #[test]
1392 fn autoversion_tiers_sorted_by_priority() {
1393 let names = vec!["neon".to_string(), "v4".to_string(), "v3".to_string()];
1394 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1395 let priorities: Vec<u32> = tiers.iter().map(|t| t.priority).collect();
1397 for window in priorities.windows(2) {
1398 assert!(
1399 window[0] >= window[1],
1400 "Tiers not sorted by priority: {:?}",
1401 priorities
1402 );
1403 }
1404 }
1405
1406 #[test]
1407 fn autoversion_unknown_tier_errors() {
1408 let names = vec!["v3".to_string(), "avx9000".to_string()];
1409 let result = resolve_tiers(&names, proc_macro2::Span::call_site(), false);
1410 match result {
1411 Ok(_) => panic!("Expected error for unknown tier 'avx9000'"),
1412 Err(e) => {
1413 let err_msg = e.to_string();
1414 assert!(
1415 err_msg.contains("avx9000"),
1416 "Error should mention unknown tier: {}",
1417 err_msg
1418 );
1419 }
1420 }
1421 }
1422
1423 #[test]
1424 fn autoversion_all_known_tiers_resolve() {
1425 for tier in ALL_TIERS {
1427 assert!(
1428 find_tier(tier.name).is_some(),
1429 "Tier '{}' should be findable by name",
1430 tier.name
1431 );
1432 }
1433 }
1434
1435 #[test]
1436 fn autoversion_default_tier_list_is_sensible() {
1437 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1439 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1440
1441 let has_x86 = tiers.iter().any(|t| t.target_arch == Some("x86_64"));
1442 let has_arm = tiers.iter().any(|t| t.target_arch == Some("aarch64"));
1443 let has_wasm = tiers.iter().any(|t| t.target_arch == Some("wasm32"));
1444 let has_scalar = tiers.iter().any(|t| t.name == "scalar");
1445
1446 assert!(has_x86, "Default tiers should include an x86_64 tier");
1447 assert!(has_arm, "Default tiers should include an aarch64 tier");
1448 assert!(has_wasm, "Default tiers should include a wasm32 tier");
1449 assert!(has_scalar, "Default tiers should include scalar");
1450 }
1451
1452 fn do_variant_replacement(func: &str, tier_name: &str, has_self: bool) -> ItemFn {
1460 let mut f: ItemFn = syn::parse_str(func).unwrap();
1461 let fn_name = f.sig.ident.to_string();
1462
1463 let tier = find_tier(tier_name).unwrap();
1464
1465 f.sig.ident = format_ident!("{}_{}", fn_name, tier.suffix);
1467
1468 let token_idx = find_autoversion_token_param(&f.sig)
1470 .expect("should not error on SimdToken")
1471 .unwrap_or_else(|| panic!("No SimdToken param in: {}", func))
1472 .index;
1473 if tier_name == "default" {
1474 let stmts = f.block.stmts.clone();
1476 let mut inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1477 inputs.remove(token_idx);
1478 f.sig.inputs = inputs.into_iter().collect();
1479 f.block.stmts = stmts;
1480 } else {
1481 let concrete_type: Type = syn::parse_str(tier.token_path).unwrap();
1482 if let FnArg::Typed(pt) = &mut f.sig.inputs[token_idx] {
1483 *pt.ty = concrete_type;
1484 }
1485 }
1486
1487 if (tier_name == "scalar" || tier_name == "default") && has_self {
1489 let preamble: syn::Stmt = syn::parse_quote!(let _self = self;);
1490 f.block.stmts.insert(0, preamble);
1491 }
1492
1493 f
1494 }
1495
1496 #[test]
1497 fn variant_replacement_v3_renames_function() {
1498 let f = do_variant_replacement(
1499 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1500 "v3",
1501 false,
1502 );
1503 assert_eq!(f.sig.ident, "process_v3");
1504 }
1505
1506 #[test]
1507 fn variant_replacement_v3_replaces_token_type() {
1508 let f = do_variant_replacement(
1509 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1510 "v3",
1511 false,
1512 );
1513 let first_param_ty = match &f.sig.inputs[0] {
1514 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1515 _ => panic!("Expected typed param"),
1516 };
1517 assert!(
1518 first_param_ty.contains("X64V3Token"),
1519 "Expected X64V3Token, got: {}",
1520 first_param_ty
1521 );
1522 }
1523
1524 #[test]
1525 fn variant_replacement_neon_produces_valid_fn() {
1526 let f = do_variant_replacement(
1527 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1528 "neon",
1529 false,
1530 );
1531 assert_eq!(f.sig.ident, "compute_neon");
1532 let first_param_ty = match &f.sig.inputs[0] {
1533 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1534 _ => panic!("Expected typed param"),
1535 };
1536 assert!(
1537 first_param_ty.contains("NeonToken"),
1538 "Expected NeonToken, got: {}",
1539 first_param_ty
1540 );
1541 }
1542
1543 #[test]
1544 fn variant_replacement_wasm128_produces_valid_fn() {
1545 let f = do_variant_replacement(
1546 "fn compute(_t: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1547 "wasm128",
1548 false,
1549 );
1550 assert_eq!(f.sig.ident, "compute_wasm128");
1551 }
1552
1553 #[test]
1554 fn variant_replacement_scalar_produces_valid_fn() {
1555 let f = do_variant_replacement(
1556 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1557 "scalar",
1558 false,
1559 );
1560 assert_eq!(f.sig.ident, "compute_scalar");
1561 let first_param_ty = match &f.sig.inputs[0] {
1562 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1563 _ => panic!("Expected typed param"),
1564 };
1565 assert!(
1566 first_param_ty.contains("ScalarToken"),
1567 "Expected ScalarToken, got: {}",
1568 first_param_ty
1569 );
1570 }
1571
1572 #[test]
1573 fn variant_replacement_v4_produces_valid_fn() {
1574 let f = do_variant_replacement(
1575 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1576 "v4",
1577 false,
1578 );
1579 assert_eq!(f.sig.ident, "transform_v4");
1580 let first_param_ty = match &f.sig.inputs[0] {
1581 FnArg::Typed(pt) => pt.ty.to_token_stream().to_string(),
1582 _ => panic!("Expected typed param"),
1583 };
1584 assert!(
1585 first_param_ty.contains("X64V4Token"),
1586 "Expected X64V4Token, got: {}",
1587 first_param_ty
1588 );
1589 }
1590
1591 #[test]
1592 fn variant_replacement_v4x_produces_valid_fn() {
1593 let f = do_variant_replacement(
1594 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1595 "v4x",
1596 false,
1597 );
1598 assert_eq!(f.sig.ident, "transform_v4x");
1599 }
1600
1601 #[test]
1602 fn variant_replacement_arm_v2_produces_valid_fn() {
1603 let f = do_variant_replacement(
1604 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1605 "arm_v2",
1606 false,
1607 );
1608 assert_eq!(f.sig.ident, "transform_arm_v2");
1609 }
1610
1611 #[test]
1612 fn variant_replacement_preserves_generics() {
1613 let f = do_variant_replacement(
1614 "fn process<T: Copy + Default>(token: SimdToken, data: &[T]) -> T { T::default() }",
1615 "v3",
1616 false,
1617 );
1618 assert_eq!(f.sig.ident, "process_v3");
1619 assert!(
1621 !f.sig.generics.params.is_empty(),
1622 "Generics should be preserved"
1623 );
1624 }
1625
1626 #[test]
1627 fn variant_replacement_preserves_where_clause() {
1628 let f = do_variant_replacement(
1629 "fn process<T>(token: SimdToken, data: &[T]) -> T where T: Copy + Default { T::default() }",
1630 "v3",
1631 false,
1632 );
1633 assert!(
1634 f.sig.generics.where_clause.is_some(),
1635 "Where clause should be preserved"
1636 );
1637 }
1638
1639 #[test]
1640 fn variant_replacement_preserves_return_type() {
1641 let f = do_variant_replacement(
1642 "fn process(token: SimdToken, data: &[f32]) -> Vec<f32> { vec![] }",
1643 "neon",
1644 false,
1645 );
1646 let ret = f.sig.output.to_token_stream().to_string();
1647 assert!(
1648 ret.contains("Vec"),
1649 "Return type should be preserved, got: {}",
1650 ret
1651 );
1652 }
1653
1654 #[test]
1655 fn variant_replacement_preserves_multiple_params() {
1656 let f = do_variant_replacement(
1657 "fn process(token: SimdToken, a: &[f32], b: &[f32], scale: f32) -> f32 { 0.0 }",
1658 "v3",
1659 false,
1660 );
1661 assert_eq!(f.sig.inputs.len(), 4);
1663 }
1664
1665 #[test]
1666 fn variant_replacement_preserves_no_return_type() {
1667 let f = do_variant_replacement(
1668 "fn transform(token: SimdToken, data: &mut [f32]) { }",
1669 "v3",
1670 false,
1671 );
1672 assert!(
1673 matches!(f.sig.output, ReturnType::Default),
1674 "No return type should remain as Default"
1675 );
1676 }
1677
1678 #[test]
1679 fn variant_replacement_preserves_lifetime_params() {
1680 let f = do_variant_replacement(
1681 "fn process<'a>(token: SimdToken, data: &'a [f32]) -> &'a [f32] { data }",
1682 "v3",
1683 false,
1684 );
1685 assert!(!f.sig.generics.params.is_empty());
1686 }
1687
1688 #[test]
1689 fn variant_replacement_scalar_self_injects_preamble() {
1690 let f = do_variant_replacement(
1691 "fn method(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1692 "scalar",
1693 true, );
1695 assert_eq!(f.sig.ident, "method_scalar");
1696
1697 let body_str = f.block.to_token_stream().to_string();
1699 assert!(
1700 body_str.contains("let _self = self"),
1701 "Scalar+self variant should have _self preamble, got: {}",
1702 body_str
1703 );
1704 }
1705
1706 #[test]
1707 fn variant_replacement_all_default_tiers_produce_valid_fns() {
1708 let names: Vec<String> = DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect();
1709 let tiers = resolve_tiers(&names, proc_macro2::Span::call_site(), false).unwrap();
1710
1711 for tier in &tiers {
1712 let f = do_variant_replacement(
1713 "fn process(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1714 tier.name,
1715 false,
1716 );
1717 let expected_name = format!("process_{}", tier.suffix);
1718 assert_eq!(
1719 f.sig.ident.to_string(),
1720 expected_name,
1721 "Tier '{}' should produce function '{}'",
1722 tier.name,
1723 expected_name
1724 );
1725 }
1726 }
1727
1728 #[test]
1729 fn variant_replacement_all_known_tiers_produce_valid_fns() {
1730 for tier in ALL_TIERS {
1731 let f = do_variant_replacement(
1732 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1733 tier.name,
1734 false,
1735 );
1736 let expected_name = format!("compute_{}", tier.suffix);
1737 assert_eq!(
1738 f.sig.ident.to_string(),
1739 expected_name,
1740 "Tier '{}' should produce function '{}'",
1741 tier.name,
1742 expected_name
1743 );
1744 }
1745 }
1746
1747 #[test]
1748 fn variant_replacement_no_simdtoken_remains() {
1749 for tier in ALL_TIERS {
1750 let f = do_variant_replacement(
1751 "fn compute(token: SimdToken, data: &[f32]) -> f32 { 0.0 }",
1752 tier.name,
1753 false,
1754 );
1755 let full_str = f.to_token_stream().to_string();
1756 assert!(
1757 !full_str.contains("SimdToken"),
1758 "Tier '{}' variant still contains 'SimdToken': {}",
1759 tier.name,
1760 full_str
1761 );
1762 }
1763 }
1764
1765 #[test]
1770 fn tier_v3_targets_x86_64() {
1771 let tier = find_tier("v3").unwrap();
1772 assert_eq!(tier.target_arch, Some("x86_64"));
1773 }
1774
1775 #[test]
1776 fn tier_v4_targets_x86_64() {
1777 let tier = find_tier("v4").unwrap();
1778 assert_eq!(tier.target_arch, Some("x86_64"));
1779 }
1780
1781 #[test]
1782 fn tier_v4x_targets_x86_64() {
1783 let tier = find_tier("v4x").unwrap();
1784 assert_eq!(tier.target_arch, Some("x86_64"));
1785 }
1786
1787 #[test]
1788 fn tier_neon_targets_aarch64() {
1789 let tier = find_tier("neon").unwrap();
1790 assert_eq!(tier.target_arch, Some("aarch64"));
1791 }
1792
1793 #[test]
1794 fn tier_wasm128_targets_wasm32() {
1795 let tier = find_tier("wasm128").unwrap();
1796 assert_eq!(tier.target_arch, Some("wasm32"));
1797 }
1798
1799 #[test]
1800 fn tier_scalar_has_no_guards() {
1801 let tier = find_tier("scalar").unwrap();
1802 assert_eq!(tier.target_arch, None);
1803 assert_eq!(tier.priority, 0);
1804 }
1805
1806 #[test]
1807 fn tier_priorities_are_consistent() {
1808 let v2 = find_tier("v2").unwrap();
1810 let v3 = find_tier("v3").unwrap();
1811 let v4 = find_tier("v4").unwrap();
1812 assert!(v4.priority > v3.priority);
1813 assert!(v3.priority > v2.priority);
1814
1815 let neon = find_tier("neon").unwrap();
1816 let arm_v2 = find_tier("arm_v2").unwrap();
1817 let arm_v3 = find_tier("arm_v3").unwrap();
1818 assert!(arm_v3.priority > arm_v2.priority);
1819 assert!(arm_v2.priority > neon.priority);
1820
1821 let scalar = find_tier("scalar").unwrap();
1823 assert!(neon.priority > scalar.priority);
1824 assert!(v2.priority > scalar.priority);
1825 }
1826
1827 #[test]
1832 fn dispatcher_param_removal_free_fn() {
1833 let f: ItemFn =
1835 syn::parse_str("fn process(token: SimdToken, data: &[f32], scale: f32) -> f32 { 0.0 }")
1836 .unwrap();
1837
1838 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1839 assert_eq!(token_param.kind, AutoversionTokenKind::SimdToken);
1841 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1842 dispatcher_inputs.remove(token_param.index);
1843 assert_eq!(dispatcher_inputs.len(), 2);
1844 }
1845
1846 #[test]
1847 fn dispatcher_param_removal_token_only() {
1848 let f: ItemFn = syn::parse_str("fn process(token: SimdToken) -> f32 { 0.0 }").unwrap();
1849 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1850 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1851 dispatcher_inputs.remove(token_param.index);
1852 assert_eq!(dispatcher_inputs.len(), 0);
1853 }
1854
1855 #[test]
1856 fn dispatcher_param_removal_token_last() {
1857 let f: ItemFn =
1858 syn::parse_str("fn process(data: &[f32], scale: f32, token: SimdToken) -> f32 { 0.0 }")
1859 .unwrap();
1860 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1861 assert_eq!(token_param.index, 2);
1862 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1863 dispatcher_inputs.remove(token_param.index);
1864 assert_eq!(dispatcher_inputs.len(), 2);
1865 }
1866
1867 #[test]
1868 fn dispatcher_scalar_token_kept() {
1869 let f: ItemFn =
1871 syn::parse_str("fn process_scalar(_: ScalarToken, data: &[f32]) -> f32 { 0.0 }")
1872 .unwrap();
1873 let token_param = find_autoversion_token_param(&f.sig).unwrap().unwrap();
1874 assert_eq!(token_param.kind, AutoversionTokenKind::ScalarToken);
1875 assert_eq!(f.sig.inputs.len(), 2);
1877 }
1878
1879 #[test]
1880 fn dispatcher_dispatch_args_extraction() {
1881 let f: ItemFn =
1883 syn::parse_str("fn process(data: &[f32], scale: f32) -> f32 { 0.0 }").unwrap();
1884
1885 let dispatch_args: Vec<String> = f
1886 .sig
1887 .inputs
1888 .iter()
1889 .filter_map(|arg| {
1890 if let FnArg::Typed(PatType { pat, .. }) = arg {
1891 if let syn::Pat::Ident(pi) = pat.as_ref() {
1892 return Some(pi.ident.to_string());
1893 }
1894 }
1895 None
1896 })
1897 .collect();
1898
1899 assert_eq!(dispatch_args, vec!["data", "scale"]);
1900 }
1901
1902 #[test]
1903 fn dispatcher_wildcard_params_get_renamed() {
1904 let f: ItemFn = syn::parse_str("fn process(_: &[f32], _: f32) -> f32 { 0.0 }").unwrap();
1905
1906 let mut dispatcher_inputs: Vec<FnArg> = f.sig.inputs.iter().cloned().collect();
1907
1908 let mut wild_counter = 0u32;
1909 for arg in &mut dispatcher_inputs {
1910 if let FnArg::Typed(pat_type) = arg {
1911 if matches!(pat_type.pat.as_ref(), syn::Pat::Wild(_)) {
1912 let ident = format_ident!("__autoversion_wild_{}", wild_counter);
1913 wild_counter += 1;
1914 *pat_type.pat = syn::Pat::Ident(syn::PatIdent {
1915 attrs: vec![],
1916 by_ref: None,
1917 mutability: None,
1918 ident,
1919 subpat: None,
1920 });
1921 }
1922 }
1923 }
1924
1925 assert_eq!(wild_counter, 2);
1927
1928 let names: Vec<String> = dispatcher_inputs
1929 .iter()
1930 .filter_map(|arg| {
1931 if let FnArg::Typed(PatType { pat, .. }) = arg {
1932 if let syn::Pat::Ident(pi) = pat.as_ref() {
1933 return Some(pi.ident.to_string());
1934 }
1935 }
1936 None
1937 })
1938 .collect();
1939
1940 assert_eq!(names, vec!["__autoversion_wild_0", "__autoversion_wild_1"]);
1941 }
1942
1943 #[test]
1948 fn suffix_path_simple() {
1949 let path: syn::Path = syn::parse_str("process").unwrap();
1950 let suffixed = suffix_path(&path, "v3");
1951 assert_eq!(suffixed.to_token_stream().to_string(), "process_v3");
1952 }
1953
1954 #[test]
1955 fn suffix_path_qualified() {
1956 let path: syn::Path = syn::parse_str("module::process").unwrap();
1957 let suffixed = suffix_path(&path, "neon");
1958 let s = suffixed.to_token_stream().to_string();
1959 assert!(
1960 s.contains("process_neon"),
1961 "Expected process_neon, got: {}",
1962 s
1963 );
1964 }
1965}