archmage_macros/lib.rs
1//! Proc-macros for archmage SIMD capability tokens.
2//!
3//! Provides `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
4//! safe via token proof.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 fold::Fold,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12 ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
15/// A Fold implementation that replaces `Self` with a concrete type.
16struct ReplaceSelf<'a> {
17 replacement: &'a Type,
18}
19
20impl Fold for ReplaceSelf<'_> {
21 fn fold_type(&mut self, ty: Type) -> Type {
22 match ty {
23 Type::Path(ref type_path) if type_path.qself.is_none() => {
24 // Check if it's just `Self`
25 if type_path.path.is_ident("Self") {
26 return self.replacement.clone();
27 }
28 // Otherwise continue folding
29 syn::fold::fold_type(self, ty)
30 }
31 _ => syn::fold::fold_type(self, ty),
32 }
33 }
34}
35
36/// Arguments to the `#[arcane]` macro.
37#[derive(Default)]
38struct ArcaneArgs {
39 /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
40 /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
41 inline_always: bool,
42 /// The concrete type to use for `self` receiver.
43 /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
44 self_type: Option<Type>,
45}
46
47impl Parse for ArcaneArgs {
48 fn parse(input: ParseStream) -> syn::Result<Self> {
49 let mut args = ArcaneArgs::default();
50
51 while !input.is_empty() {
52 let ident: Ident = input.parse()?;
53 match ident.to_string().as_str() {
54 "inline_always" => args.inline_always = true,
55 "_self" => {
56 let _: Token![=] = input.parse()?;
57 args.self_type = Some(input.parse()?);
58 }
59 other => {
60 return Err(syn::Error::new(
61 ident.span(),
62 format!("unknown arcane argument: `{}`", other),
63 ))
64 }
65 }
66 // Consume optional comma
67 if input.peek(Token![,]) {
68 let _: Token![,] = input.parse()?;
69 }
70 }
71
72 Ok(args)
73 }
74}
75
/// Maps a concrete token type name to the target features it enables.
///
/// Returns `None` for names that are not known token types.
///
/// Tier/profile tokens (`X64V2Token`, `X64V3Token`, `X64V4Token`, the AVX-512
/// profiles and their aliases) list every feature of their level *and* of the
/// levels below it. This keeps a concrete `token: X64V3Token` parameter
/// equivalent to the same name resolved through `trait_to_features`, and
/// ensures a higher tier never enables fewer features than a lower one.
///
/// Based on LLVM x86-64 microarchitecture levels (psABI).
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    // Cumulative x86-64 level feature lists (each includes the lower levels).
    const X64_V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    const X64_V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    const X64_V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    const AVX512_FP16: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512fp16",
    ];

    match token_name {
        // x86_64 feature tokens (kept for backwards compatibility)
        "Sse41Token" => Some(&["sse4.1"]),
        "Sse42Token" => Some(&["sse4.2"]),
        "AvxToken" => Some(&["avx"]),
        "Avx2Token" => Some(&["avx2"]),
        "FmaToken" => Some(&["fma"]),
        "Avx2FmaToken" => Some(&["avx2", "fma"]),
        "Avx512fToken" => Some(&["avx512f"]),
        "Avx512bwToken" => Some(&["avx512bw"]),

        // x86_64 tier tokens
        "X64V2Token" => Some(X64_V2),
        "X64V3Token" | "Desktop64" => Some(X64_V3),
        "X64V4Token" | "Avx512Token" | "Server64" => Some(X64_V4),
        "Avx512ModernToken" => Some(AVX512_MODERN),
        "Avx512Fp16Token" => Some(AVX512_FP16),

        // AArch64 tokens
        "NeonToken" | "Arm64" => Some(&["neon"]),
        "NeonAesToken" => Some(&["neon", "aes"]),
        "NeonSha3Token" => Some(&["neon", "sha3"]),
        "ArmCryptoToken" => Some(&["neon", "aes", "sha2"]),
        "ArmCrypto3Token" => Some(&["neon", "aes", "sha2", "sha3"]),

        // WASM tokens
        "Simd128Token" => Some(&["simd128"]),

        _ => None,
    }
}
136
/// Maps the name of a trait bound (or of a token type used as a bound) to the
/// target features it requires. Unknown names yield `None`.
///
/// IMPORTANT: every entry spells out ALL implied features, not only the
/// defining ones, so that each one gets its own `#[target_feature]` attribute.
///
/// Based on LLVM x86-64 microarchitecture levels (psABI).
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    // Cumulative x86-64 level lists: each one repeats the levels below it.
    const LEVEL_V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    const LEVEL_V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    const LEVEL_V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    const AVX512_FP16: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512fp16",
    ];
    const NEON: &[&str] = &["neon"];
    const NEON_AES: &[&str] = &["neon", "aes"];
    const NEON_SHA3: &[&str] = &["neon", "sha3"];

    let features = match trait_name {
        // x86 tier traits and the tier token types used directly as bounds
        "HasX64V2" | "X64V2Token" => LEVEL_V2,
        "X64V3Token" | "Desktop64" | "Avx2FmaToken" => LEVEL_V3,
        "HasX64V4" | "X64V4Token" | "Avx512Token" | "Server64" => LEVEL_V4,
        "Avx512ModernToken" => AVX512_MODERN,
        "Avx512Fp16Token" => AVX512_FP16,

        // Width traits - minimal features to satisfy width
        "Has128BitSimd" => &["sse2"],
        "Has256BitSimd" => &["avx"],
        "Has512BitSimd" => &["avx512f"],

        // AArch64 traits and token types
        "HasNeon" | "NeonToken" | "Arm64" => NEON,
        "HasNeonAes" | "NeonAesToken" => NEON_AES,
        "HasNeonSha3" | "NeonSha3Token" => NEON_SHA3,
        "ArmCryptoToken" => &["neon", "aes", "sha2"],
        "ArmCrypto3Token" => &["neon", "aes", "sha2", "sha3"],

        _ => return None,
    };

    Some(features)
}
234
/// How a parameter's type identifies a capability token.
enum TokenTypeInfo {
    /// A known token type named directly, e.g. `Avx2Token`; carries the type name.
    Concrete(String),
    /// An `impl Trait` parameter; carries the names of its trait bounds.
    ImplTrait(Vec<String>),
    /// Any other path type, treated as a possible generic parameter such as `T`;
    /// carries the name to look up in the function's generics.
    Generic(String),
}
244
245/// Extract token type information from a type.
246fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
247 match ty {
248 Type::Path(type_path) => {
249 // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
250 type_path.path.segments.last().map(|seg| {
251 let name = seg.ident.to_string();
252 // Check if it's a known concrete token type
253 if token_to_features(&name).is_some() {
254 TokenTypeInfo::Concrete(name)
255 } else {
256 // Might be a generic type parameter like `T`
257 TokenTypeInfo::Generic(name)
258 }
259 })
260 }
261 Type::Reference(type_ref) => {
262 // Handle &Token or &mut Token
263 extract_token_type_info(&type_ref.elem)
264 }
265 Type::ImplTrait(impl_trait) => {
266 // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
267 let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
268 if traits.is_empty() {
269 None
270 } else {
271 Some(TokenTypeInfo::ImplTrait(traits))
272 }
273 }
274 _ => None,
275 }
276}
277
278/// Extract trait names from type param bounds.
279fn extract_trait_names_from_bounds(
280 bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
281) -> Vec<String> {
282 bounds
283 .iter()
284 .filter_map(|bound| {
285 if let TypeParamBound::Trait(trait_bound) = bound {
286 trait_bound
287 .path
288 .segments
289 .last()
290 .map(|seg| seg.ident.to_string())
291 } else {
292 None
293 }
294 })
295 .collect()
296}
297
298/// Look up a generic type parameter in the function's generics.
299fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
300 // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
301 for param in &sig.generics.params {
302 if let GenericParam::Type(type_param) = param {
303 if type_param.ident == type_name {
304 let traits = extract_trait_names_from_bounds(&type_param.bounds);
305 if !traits.is_empty() {
306 return Some(traits);
307 }
308 }
309 }
310 }
311
312 // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
313 if let Some(where_clause) = &sig.generics.where_clause {
314 for predicate in &where_clause.predicates {
315 if let syn::WherePredicate::Type(pred_type) = predicate {
316 if let Type::Path(type_path) = &pred_type.bounded_ty {
317 if let Some(seg) = type_path.path.segments.last() {
318 if seg.ident == type_name {
319 let traits = extract_trait_names_from_bounds(&pred_type.bounds);
320 if !traits.is_empty() {
321 return Some(traits);
322 }
323 }
324 }
325 }
326 }
327 }
328 }
329
330 None
331}
332
333/// Convert trait names to features, collecting all features from all traits.
334fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
335 let mut all_features = Vec::new();
336
337 for trait_name in trait_names {
338 if let Some(features) = trait_to_features(trait_name) {
339 for &feature in features {
340 if !all_features.contains(&feature) {
341 all_features.push(feature);
342 }
343 }
344 }
345 }
346
347 if all_features.is_empty() {
348 None
349 } else {
350 Some(all_features)
351 }
352}
353
354/// Find the first token parameter and return its name and features.
355fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
356 for arg in &sig.inputs {
357 match arg {
358 FnArg::Receiver(_) => {
359 // Self receivers (self, &self, &mut self) are not yet supported.
360 // The macro creates an inner function, and Rust's inner functions
361 // cannot have `self` parameters. Supporting this would require
362 // AST rewriting to replace `self` with a regular parameter.
363 // See the module docs for the workaround.
364 continue;
365 }
366 FnArg::Typed(PatType { pat, ty, .. }) => {
367 if let Some(info) = extract_token_type_info(ty) {
368 let features = match info {
369 TokenTypeInfo::Concrete(name) => {
370 token_to_features(&name).map(|f| f.to_vec())
371 }
372 TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
373 TokenTypeInfo::Generic(type_name) => {
374 // Look up the generic parameter's bounds
375 find_generic_bounds(sig, &type_name)
376 .and_then(|traits| traits_to_features(&traits))
377 }
378 };
379
380 if let Some(features) = features {
381 // Extract parameter name
382 if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
383 return Some((pat_ident.ident.clone(), features));
384 }
385 }
386 }
387 }
388 }
389 }
390 None
391}
392
/// The three receiver forms a method can declare.
enum SelfReceiver {
    /// `self`, taken by value.
    Owned,
    /// `&self`, a shared borrow.
    Ref,
    /// `&mut self`, an exclusive borrow.
    RefMut,
}
402
403/// Shared implementation for arcane/simd_fn macros.
404fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
405 // Check for self receiver
406 let has_self_receiver = input_fn
407 .sig
408 .inputs
409 .first()
410 .map(|arg| matches!(arg, FnArg::Receiver(_)))
411 .unwrap_or(false);
412
413 // If there's a self receiver, we need _self = Type
414 if has_self_receiver && args.self_type.is_none() {
415 let msg = format!(
416 "{} with self receiver requires `_self = Type` argument.\n\
417 Example: #[{}(_self = MyType)]\n\
418 Use `_self` (not `self`) in the function body to refer to self.",
419 macro_name, macro_name
420 );
421 return syn::Error::new_spanned(&input_fn.sig, msg)
422 .to_compile_error()
423 .into();
424 }
425
426 // Find the token parameter and its features
427 let (_token_ident, features) = match find_token_param(&input_fn.sig) {
428 Some(result) => result,
429 None => {
430 let msg = format!(
431 "{} requires a token parameter. Supported forms:\n\
432 - Concrete: `token: Avx2Token`\n\
433 - impl Trait: `token: impl HasAvx2`\n\
434 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
435 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasAvx2, ...)`",
436 macro_name, macro_name
437 );
438 return syn::Error::new_spanned(&input_fn.sig, msg)
439 .to_compile_error()
440 .into();
441 }
442 };
443
444 // Build target_feature attributes
445 let target_feature_attrs: Vec<Attribute> = features
446 .iter()
447 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
448 .collect();
449
450 // Extract function components
451 let vis = &input_fn.vis;
452 let sig = &input_fn.sig;
453 let fn_name = &sig.ident;
454 let generics = &sig.generics;
455 let where_clause = &generics.where_clause;
456 let inputs = &sig.inputs;
457 let output = &sig.output;
458 let body = &input_fn.block;
459 let attrs = &input_fn.attrs;
460
461 // Determine self receiver type if present
462 let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
463 FnArg::Receiver(receiver) => {
464 if receiver.reference.is_none() {
465 Some(SelfReceiver::Owned)
466 } else if receiver.mutability.is_some() {
467 Some(SelfReceiver::RefMut)
468 } else {
469 Some(SelfReceiver::Ref)
470 }
471 }
472 _ => None,
473 });
474
475 // Build inner function parameters, transforming self if needed
476 let inner_params: Vec<proc_macro2::TokenStream> = inputs
477 .iter()
478 .map(|arg| match arg {
479 FnArg::Receiver(_) => {
480 // Transform self receiver to _self parameter
481 let self_ty = args.self_type.as_ref().unwrap();
482 match self_receiver_kind.as_ref().unwrap() {
483 SelfReceiver::Owned => quote!(_self: #self_ty),
484 SelfReceiver::Ref => quote!(_self: &#self_ty),
485 SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
486 }
487 }
488 FnArg::Typed(pat_type) => quote!(#pat_type),
489 })
490 .collect();
491
492 // Build inner function call arguments
493 let inner_args: Vec<proc_macro2::TokenStream> = inputs
494 .iter()
495 .filter_map(|arg| match arg {
496 FnArg::Typed(pat_type) => {
497 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
498 let ident = &pat_ident.ident;
499 Some(quote!(#ident))
500 } else {
501 None
502 }
503 }
504 FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
505 })
506 .collect();
507
508 let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
509
510 // Choose inline attribute based on args
511 // Note: #[inline(always)] + #[target_feature] requires nightly with
512 // #![feature(target_feature_inline_always)]
513 let inline_attr: Attribute = if args.inline_always {
514 parse_quote!(#[inline(always)])
515 } else {
516 parse_quote!(#[inline])
517 };
518
519 // Transform output and body to replace Self with concrete type if needed
520 let (inner_output, inner_body): (ReturnType, syn::Block) =
521 if let Some(ref self_ty) = args.self_type {
522 let mut replacer = ReplaceSelf {
523 replacement: self_ty,
524 };
525 let transformed_output = replacer.fold_return_type(output.clone());
526 let transformed_body = replacer.fold_block((**body).clone());
527 (transformed_output, transformed_body)
528 } else {
529 (output.clone(), (**body).clone())
530 };
531
532 // Generate the expanded function
533 let expanded = quote! {
534 #(#attrs)*
535 #vis #sig {
536 #(#target_feature_attrs)*
537 #inline_attr
538 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
539 #inner_body
540
541 // SAFETY: The token parameter proves the required CPU features are available.
542 // Tokens can only be constructed when features are verified (via try_new()
543 // runtime check or forge_token_dangerously() in a context where features are guaranteed).
544 unsafe { #inner_fn_name(#(#inner_args),*) }
545 }
546 };
547
548 expanded.into()
549}
550
551/// Mark a function as an arcane SIMD function.
552///
553/// This macro enables safe use of SIMD intrinsics by generating an inner function
554/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
555/// the token parameter type. The outer function calls the inner function unsafely,
556/// which is justified because the token parameter proves the features are available.
557///
558/// **The token is passed through to the inner function**, so you can call other
559/// token-taking functions from inside `#[arcane]`.
560///
561/// # Token Parameter Forms
562///
563/// The macro supports four forms of token parameters:
564///
565/// ## Concrete Token Types
566///
567/// ```ignore
568/// #[arcane]
569/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
570/// // AVX2 intrinsics safe here
571/// }
572/// ```
573///
574/// ## impl Trait Bounds
575///
576/// ```ignore
577/// #[arcane]
578/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
579/// // Accepts any token that provides AVX2
580/// }
581/// ```
582///
583/// ## Generic Type Parameters
584///
585/// ```ignore
586/// #[arcane]
587/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
588/// // Generic over any AVX2-capable token
589/// }
590///
591/// // Also works with where clauses:
592/// #[arcane]
593/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
594/// where
595/// T: HasAvx2
596/// {
597/// // ...
598/// }
599/// ```
600///
601/// ## Methods with Self Receivers
602///
603/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
604/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
605///
606/// ```ignore
607/// use archmage::{HasAvx2, arcane};
608/// use wide::f32x8;
609///
610/// trait Avx2Ops {
611/// fn double(&self, token: impl HasAvx2) -> Self;
612/// fn square(self, token: impl HasAvx2) -> Self;
613/// fn scale(&mut self, token: impl HasAvx2, factor: f32);
614/// }
615///
616/// impl Avx2Ops for f32x8 {
617/// #[arcane(_self = f32x8)]
618/// fn double(&self, _token: impl HasAvx2) -> Self {
619/// // Use _self instead of self in the body
620/// *_self + *_self
621/// }
622///
623/// #[arcane(_self = f32x8)]
624/// fn square(self, _token: impl HasAvx2) -> Self {
625/// _self * _self
626/// }
627///
628/// #[arcane(_self = f32x8)]
629/// fn scale(&mut self, _token: impl HasAvx2, factor: f32) {
630/// *_self = *_self * f32x8::splat(factor);
631/// }
632/// }
633/// ```
634///
635/// **Why `_self`?** The macro generates an inner function where `self` becomes
636/// a regular parameter named `_self`. Using `_self` in your code reminds you
637/// that you're not using the normal `self` keyword.
638///
639/// **All receiver types are supported:**
640/// - `self` (by value/move) → `_self: Type`
641/// - `&self` (shared reference) → `_self: &Type`
642/// - `&mut self` (mutable reference) → `_self: &mut Type`
643///
644/// # Multiple Trait Bounds
645///
646/// When using `impl Trait` or generic bounds with multiple traits,
647/// all required features are enabled:
648///
649/// ```ignore
650/// #[arcane]
651/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
652/// // Both AVX2 and FMA intrinsics are safe here
653/// }
654/// ```
655///
656/// # Expansion
657///
658/// The macro expands to approximately:
659///
660/// ```ignore
661/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
662/// #[target_feature(enable = "avx2")]
663/// #[inline]
664/// unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
665/// let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
666/// let doubled = _mm256_add_ps(v, v);
667/// let mut out = [0.0f32; 8];
668/// unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
669/// out
670/// }
671/// // SAFETY: Token proves the required features are available
672/// unsafe { __simd_inner_process(token, data) }
673/// }
674/// ```
675///
676/// # Profile Tokens
677///
678/// Profile tokens automatically enable all required features:
679///
680/// ```ignore
681/// #[arcane]
682/// fn kernel(token: X64V3Token, data: &mut [f32]) {
683/// // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
684/// }
685/// ```
686///
687/// # Supported Tokens
688///
/// - **x86_64**: `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
///   `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token` (`Desktop64`),
///   `X64V4Token` (`Avx512Token`, `Server64`), `Avx512ModernToken`, `Avx512Fp16Token`
/// - **ARM**: `NeonToken` (`Arm64`), `NeonAesToken`, `NeonSha3Token`,
///   `ArmCryptoToken`, `ArmCrypto3Token`
/// - **WASM**: `Simd128Token`
694///
695/// # Supported Trait Bounds
696///
/// - **x86_64**: `HasX64V2`, `HasX64V4`
/// - **ARM**: `HasNeon`, `HasNeonAes`, `HasNeonSha3`
/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
701///
702/// # Options
703///
704/// ## `inline_always`
705///
706/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
707/// This can improve performance by ensuring aggressive inlining, but requires
708/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
709/// the crate using the macro.
710///
711/// ```ignore
712/// #![feature(target_feature_inline_always)]
713///
714/// #[arcane(inline_always)]
715/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
716/// // Inner function will use #[inline(always)]
717/// }
718/// ```
719#[proc_macro_attribute]
720pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
721 let args = parse_macro_input!(attr as ArcaneArgs);
722 let input_fn = parse_macro_input!(item as ItemFn);
723 arcane_impl(input_fn, "arcane", args)
724}
725
726/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
727///
728/// See [`arcane`] for full documentation.
729#[proc_macro_attribute]
730pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
731 let args = parse_macro_input!(attr as ArcaneArgs);
732 let input_fn = parse_macro_input!(item as ItemFn);
733 arcane_impl(input_fn, "simd_fn", args)
734}
735
736// ============================================================================
737// Multiwidth macro for width-agnostic SIMD code
738// ============================================================================
739
740use syn::ItemMod;
741
/// Target selection for the `#[multiwidth]` macro; one flag per specialization.
struct MultiwidthArgs {
    /// Generate the SSE (128-bit) specialization.
    sse: bool,
    /// Generate the AVX2 (256-bit) specialization.
    avx2: bool,
    /// Generate the AVX-512 (512-bit) specialization.
    avx512: bool,
    /// Generate the WASM SIMD128 (128-bit) specialization.
    wasm: bool,
    /// Generate the NEON (128-bit ARM) specialization.
    neon: bool,
}

impl Default for MultiwidthArgs {
    /// Every target is enabled by default.
    fn default() -> Self {
        let enabled = true;
        MultiwidthArgs {
            sse: enabled,
            avx2: enabled,
            avx512: enabled,
            wasm: enabled,
            neon: enabled,
        }
    }
}
767
768impl Parse for MultiwidthArgs {
769 fn parse(input: ParseStream) -> syn::Result<Self> {
770 let mut args = MultiwidthArgs {
771 sse: false,
772 avx2: false,
773 avx512: false,
774 wasm: false,
775 neon: false,
776 };
777
778 // If no args provided, enable all
779 if input.is_empty() {
780 return Ok(MultiwidthArgs::default());
781 }
782
783 while !input.is_empty() {
784 let ident: Ident = input.parse()?;
785 match ident.to_string().as_str() {
786 "sse" => args.sse = true,
787 "avx2" => args.avx2 = true,
788 "avx512" => args.avx512 = true,
789 "wasm" | "simd128" => args.wasm = true,
790 "neon" | "arm" => args.neon = true,
791 other => {
792 return Err(syn::Error::new(
793 ident.span(),
794 format!(
795 "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512, wasm, neon",
796 other
797 ),
798 ))
799 }
800 }
801 // Consume optional comma
802 if input.peek(Token![,]) {
803 let _: Token![,] = input.parse()?;
804 }
805 }
806
807 Ok(args)
808 }
809}
810
/// Static description of one width specialization.
struct WidthConfig {
    /// Name of the generated submodule (e.g., "sse", "avx2", "avx512").
    name: &'static str,
    /// Path of the namespace the specialization imports from.
    namespace: &'static str,
    /// Path of the token type associated with this width.
    token: &'static str,
    /// Cargo feature gating this specialization, if any.
    feature: Option<&'static str>,
    /// Target features enabled for functions of this specialization.
    target_features: &'static [&'static str],
}
824
825/// Width configuration for x86_64 targets
826const X86_WIDTH_CONFIGS: &[WidthConfig] = &[
827 WidthConfig {
828 name: "sse",
829 namespace: "magetypes::simd::sse",
830 token: "archmage::X64V3Token",
831 feature: None,
832 target_features: &["avx2", "fma", "bmi1", "bmi2", "f16c", "lzcnt"],
833 },
834 WidthConfig {
835 name: "avx2",
836 namespace: "magetypes::simd::avx2",
837 token: "archmage::X64V3Token",
838 feature: None,
839 target_features: &["avx2", "fma", "bmi1", "bmi2", "f16c", "lzcnt"],
840 },
841 WidthConfig {
842 name: "avx512",
843 namespace: "magetypes::simd::avx512",
844 token: "archmage::X64V4Token",
845 feature: Some("avx512"),
846 target_features: &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"],
847 },
848];
849
850/// Width configuration for wasm32 targets
851const WASM_WIDTH_CONFIGS: &[WidthConfig] = &[WidthConfig {
852 name: "simd128",
853 namespace: "magetypes::simd::simd128",
854 token: "archmage::Simd128Token",
855 feature: None,
856 target_features: &["simd128"],
857}];
858
859/// Width configuration for aarch64 targets
860const ARM_WIDTH_CONFIGS: &[WidthConfig] = &[WidthConfig {
861 name: "neon",
862 namespace: "magetypes::simd::neon",
863 token: "archmage::NeonToken",
864 feature: None,
865 target_features: &["neon"],
866}];
867
868/// Generate width-specialized SIMD code.
869///
/// This macro takes a module containing width-agnostic SIMD code and generates
/// specialized versions for each enabled target (SSE, AVX2 and AVX-512 on
/// x86_64, SIMD128 on wasm32, NEON on aarch64).
872///
873/// # Usage
874///
875/// ```ignore
876/// use archmage::multiwidth;
877///
878/// #[multiwidth]
879/// mod kernels {
880/// // Inside this module, these types are available:
881/// // - f32xN, i32xN, etc. (width-appropriate SIMD types)
882/// // - Token (the token type: X64V3Token for SSE/AVX2, or X64V4Token for AVX-512)
883/// // - LANES_F32, LANES_32, etc. (lane count constants)
884///
885/// use archmage::simd::*;
886///
887/// pub fn normalize(token: Token, data: &mut [f32]) {
888/// for chunk in data.chunks_exact_mut(LANES_F32) {
889/// let v = f32xN::load(token, chunk.try_into().unwrap());
890/// let result = v * f32xN::splat(token, 1.0 / 255.0);
891/// result.store(chunk.try_into().unwrap());
892/// }
893/// }
894/// }
895///
896/// // Generated modules:
897/// // - kernels::sse::normalize(token: X64V3Token, data: &mut [f32])
898/// // - kernels::avx2::normalize(token: X64V3Token, data: &mut [f32])
899/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32]) // if avx512 feature
900/// // - kernels::normalize(data: &mut [f32]) // runtime dispatcher
901/// ```
902///
903/// # Selective Targets
904///
905/// You can specify which targets to generate:
906///
907/// ```ignore
908/// #[multiwidth(avx2, avx512)] // Only AVX2 and AVX-512, no SSE
909/// mod fast_kernels { ... }
910/// ```
911///
912/// # How It Works
913///
914/// 1. The macro duplicates the module content for each width target
/// 2. Each copy imports from the appropriate namespace (`magetypes::simd::sse`, etc.)
916/// 3. The `use archmage::simd::*` statement is rewritten to the width-specific import
917/// 4. A dispatcher function is generated that picks the best available at runtime
918///
919/// # Requirements
920///
921/// - Functions should use `Token` as their token parameter type
922/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
923/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
924#[proc_macro_attribute]
925pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
926 let args = parse_macro_input!(attr as MultiwidthArgs);
927 let input_mod = parse_macro_input!(item as ItemMod);
928
929 multiwidth_impl(input_mod, args)
930}
931
932/// Configuration with target arch for conditional compilation
933struct ArchConfig<'a> {
934 config: &'a WidthConfig,
935 target_arch: Option<&'static str>,
936}
937
938fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
939 let mod_name = &input_mod.ident;
940 let mod_vis = &input_mod.vis;
941 let mod_attrs = &input_mod.attrs;
942
943 // Get module content
944 let content = match &input_mod.content {
945 Some((_, items)) => items,
946 None => {
947 return syn::Error::new_spanned(
948 &input_mod,
949 "multiwidth requires an inline module (mod name { ... }), not a file module",
950 )
951 .to_compile_error()
952 .into();
953 }
954 };
955
956 // Build list of all enabled configs across architectures
957 let mut all_configs: Vec<ArchConfig> = Vec::new();
958
959 // x86_64 configs
960 for config in X86_WIDTH_CONFIGS {
961 let enabled = match config.name {
962 "sse" => args.sse,
963 "avx2" => args.avx2,
964 "avx512" => args.avx512,
965 _ => false,
966 };
967 if enabled {
968 all_configs.push(ArchConfig {
969 config,
970 target_arch: Some("x86_64"),
971 });
972 }
973 }
974
975 // WASM configs
976 if args.wasm {
977 for config in WASM_WIDTH_CONFIGS {
978 all_configs.push(ArchConfig {
979 config,
980 target_arch: Some("wasm32"),
981 });
982 }
983 }
984
985 // ARM configs
986 if args.neon {
987 for config in ARM_WIDTH_CONFIGS {
988 all_configs.push(ArchConfig {
989 config,
990 target_arch: Some("aarch64"),
991 });
992 }
993 }
994
995 // Build specialized modules
996 let mut specialized_mods = Vec::new();
997 let mut enabled_configs = Vec::new();
998
999 for arch_config in &all_configs {
1000 let config = arch_config.config;
1001 enabled_configs.push(config);
1002
1003 let width_mod_name = format_ident!("{}", config.name);
1004 let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
1005
1006 // Transform the content: replace `use archmage::simd::*` with width-specific import
1007 // and add target_feature attributes for optimization
1008 let transformed_items: Vec<syn::Item> = content
1009 .iter()
1010 .map(|item| transform_item_for_width(item.clone(), &namespace, config))
1011 .collect();
1012
1013 // Build cfg attributes for target arch and optional feature
1014 let arch_attr = arch_config
1015 .target_arch
1016 .map(|arch| quote!(#[cfg(target_arch = #arch)]));
1017
1018 let feature_attr = config.feature.map(|f| {
1019 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1020 quote!(#[cfg(feature = #f_lit)])
1021 });
1022
1023 specialized_mods.push(quote! {
1024 #arch_attr
1025 #feature_attr
1026 pub mod #width_mod_name {
1027 #(#transformed_items)*
1028 }
1029 });
1030 }
1031
1032 // Generate dispatcher functions for each public function in the module
1033 // The dispatcher is x86_64-specific (runtime feature detection)
1034 // For WASM and ARM, features are compile-time only
1035 let x86_configs: Vec<_> = all_configs
1036 .iter()
1037 .filter(|c| c.target_arch == Some("x86_64"))
1038 .map(|c| c.config)
1039 .collect();
1040
1041 // Only generate dispatcher section if we have x86 configs
1042 let dispatcher_section = if !x86_configs.is_empty() {
1043 let dispatchers = generate_dispatchers(content, &x86_configs);
1044 quote! {
1045 // Runtime dispatcher (x86_64 only - WASM/ARM use compile-time features)
1046 #[cfg(target_arch = "x86_64")]
1047 mod __dispatchers {
1048 use super::*;
1049 #dispatchers
1050 }
1051 #[cfg(target_arch = "x86_64")]
1052 pub use __dispatchers::*;
1053 }
1054 } else {
1055 quote! {}
1056 };
1057
1058 let expanded = quote! {
1059 #(#mod_attrs)*
1060 #mod_vis mod #mod_name {
1061 #(#specialized_mods)*
1062
1063 #dispatcher_section
1064 }
1065 };
1066
1067 expanded.into()
1068}
1069
1070/// Transform a single item for a specific width namespace.
1071fn transform_item_for_width(
1072 item: syn::Item,
1073 namespace: &syn::Path,
1074 config: &WidthConfig,
1075) -> syn::Item {
1076 match item {
1077 syn::Item::Use(mut use_item) => {
1078 // Check if this is `use archmage::simd::*` or similar
1079 if is_simd_wildcard_use(&use_item) {
1080 // Replace with width-specific import
1081 use_item.tree = syn::UseTree::Path(syn::UsePath {
1082 ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
1083 colon2_token: Default::default(),
1084 tree: Box::new(build_use_tree_from_path(namespace, 1)),
1085 });
1086 }
1087 syn::Item::Use(use_item)
1088 }
1089 syn::Item::Fn(func) => {
1090 // Transform function to use inner function pattern with target_feature
1091 // This is the same pattern as #[arcane], enabling SIMD optimization
1092 // without requiring -C target-cpu=native
1093 transform_fn_with_target_feature(func, config)
1094 }
1095 other => other,
1096 }
1097}
1098
1099/// Transform a function to use the inner function pattern with target_feature.
1100/// This generates:
1101/// ```ignore
1102/// pub fn example(token: Token, data: &[f32]) -> f32 {
1103/// #[target_feature(enable = "avx2", enable = "fma")]
1104/// #[inline]
1105/// unsafe fn inner(token: Token, data: &[f32]) -> f32 {
1106/// // original body
1107/// }
1108/// // SAFETY: Token proves CPU support
1109/// unsafe { inner(token, data) }
1110/// }
1111/// ```
1112fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
1113 let vis = &func.vis;
1114 let sig = &func.sig;
1115 let fn_name = &sig.ident;
1116 let generics = &sig.generics;
1117 let where_clause = &generics.where_clause;
1118 let inputs = &sig.inputs;
1119 let output = &sig.output;
1120 let body = &func.block;
1121 let attrs = &func.attrs;
1122
1123 // Build target_feature attributes
1124 let target_feature_attrs: Vec<syn::Attribute> = config
1125 .target_features
1126 .iter()
1127 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
1128 .collect();
1129
1130 // Build parameter list for inner function
1131 let inner_params: Vec<proc_macro2::TokenStream> =
1132 inputs.iter().map(|arg| quote!(#arg)).collect();
1133
1134 // Build argument list for calling inner function
1135 let call_args: Vec<proc_macro2::TokenStream> = inputs
1136 .iter()
1137 .filter_map(|arg| match arg {
1138 syn::FnArg::Typed(pat_type) => {
1139 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1140 let ident = &pat_ident.ident;
1141 Some(quote!(#ident))
1142 } else {
1143 None
1144 }
1145 }
1146 syn::FnArg::Receiver(_) => Some(quote!(self)),
1147 })
1148 .collect();
1149
1150 let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
1151
1152 let expanded = quote! {
1153 #(#attrs)*
1154 #vis #sig {
1155 #(#target_feature_attrs)*
1156 #[inline]
1157 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
1158 #body
1159
1160 // SAFETY: The Token parameter proves the required CPU features are available.
1161 // Tokens can only be constructed via try_new() which checks CPU support.
1162 unsafe { #inner_fn_name(#(#call_args),*) }
1163 }
1164 };
1165
1166 syn::parse2(expanded).expect("Failed to parse transformed function")
1167}
1168
1169/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
1170fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
1171 fn check_tree(tree: &syn::UseTree) -> bool {
1172 match tree {
1173 syn::UseTree::Path(path) => {
1174 let ident = path.ident.to_string();
1175 if ident == "archmage" || ident == "magetypes" || ident == "crate" {
1176 check_tree_for_simd(&path.tree)
1177 } else {
1178 false
1179 }
1180 }
1181 _ => false,
1182 }
1183 }
1184
1185 fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
1186 match tree {
1187 syn::UseTree::Path(path) => {
1188 if path.ident == "simd" {
1189 matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
1190 } else {
1191 check_tree_for_simd(&path.tree)
1192 }
1193 }
1194 _ => false,
1195 }
1196 }
1197
1198 check_tree(&use_item.tree)
1199}
1200
1201/// Build a UseTree from a path, starting at a given segment index.
1202fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
1203 let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
1204
1205 if segments.is_empty() {
1206 syn::UseTree::Glob(syn::UseGlob {
1207 star_token: Default::default(),
1208 })
1209 } else if segments.len() == 1 {
1210 syn::UseTree::Path(syn::UsePath {
1211 ident: segments[0].ident.clone(),
1212 colon2_token: Default::default(),
1213 tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1214 star_token: Default::default(),
1215 })),
1216 })
1217 } else {
1218 let first = &segments[0];
1219 let rest_path = syn::Path {
1220 leading_colon: None,
1221 segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1222 };
1223 syn::UseTree::Path(syn::UsePath {
1224 ident: first.ident.clone(),
1225 colon2_token: Default::default(),
1226 tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1227 })
1228 }
1229}
1230
/// Type names whose meaning depends on the selected width and which therefore can't
/// appear in a dispatcher signature.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    // Floating-point vectors
    "f32xN",
    "f64xN",
    // Signed integer vectors
    "i8xN",
    "i16xN",
    "i32xN",
    "i64xN",
    // Unsigned integer vectors
    "u8xN",
    "u16xN",
    "u32xN",
    "u64xN",
    // Capability token
    "Token",
];

/// Report whether a rendered type string mentions any width-specific type name.
///
/// This is a plain substring search over `ty_str`, so a name embedded inside a longer
/// identifier also counts as a match.
fn contains_width_specific_type(ty_str: &str) -> bool {
    for name in WIDTH_SPECIFIC_TYPES {
        if ty_str.contains(*name) {
            return true;
        }
    }
    false
}
1240
1241/// Check if a function signature uses width-specific types (can't have a dispatcher).
1242fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1243 // Check return type
1244 if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1245 let ty_str = quote!(#ty).to_string();
1246 if contains_width_specific_type(&ty_str) {
1247 return true;
1248 }
1249 }
1250
1251 // Check parameters (excluding Token which we filter out anyway)
1252 for arg in &func.sig.inputs {
1253 if let syn::FnArg::Typed(pat_type) = arg {
1254 let ty = &pat_type.ty;
1255 let ty_str = quote!(#ty).to_string();
1256 // Skip Token parameters - they're filtered out for dispatchers
1257 if ty_str.contains("Token") {
1258 continue;
1259 }
1260 if contains_width_specific_type(&ty_str) {
1261 return true;
1262 }
1263 }
1264 }
1265
1266 false
1267}
1268
1269/// Generate runtime dispatcher functions for public functions.
1270///
1271/// Note: Dispatchers are only generated for functions that don't use width-specific
1272/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1273/// width-specific types can only be called via the width-specific submodules.
1274fn generate_dispatchers(
1275 content: &[syn::Item],
1276 configs: &[&WidthConfig],
1277) -> proc_macro2::TokenStream {
1278 let mut dispatchers = Vec::new();
1279
1280 for item in content {
1281 if let syn::Item::Fn(func) = item {
1282 // Only generate dispatchers for public functions
1283 if !matches!(func.vis, syn::Visibility::Public(_)) {
1284 continue;
1285 }
1286
1287 // Skip functions that use width-specific types - they can't have dispatchers
1288 if uses_width_specific_types(func) {
1289 continue;
1290 }
1291
1292 let fn_name = &func.sig.ident;
1293 let fn_generics = &func.sig.generics;
1294 let fn_output = &func.sig.output;
1295 let fn_attrs: Vec<_> = func
1296 .attrs
1297 .iter()
1298 .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1299 .collect();
1300
1301 // Filter out the token parameter from the dispatcher signature
1302 let non_token_params: Vec<_> = func
1303 .sig
1304 .inputs
1305 .iter()
1306 .filter(|arg| {
1307 match arg {
1308 syn::FnArg::Typed(pat_type) => {
1309 // Check if type contains "Token"
1310 let ty = &pat_type.ty;
1311 let ty_str = quote!(#ty).to_string();
1312 !ty_str.contains("Token")
1313 }
1314 _ => true,
1315 }
1316 })
1317 .collect();
1318
1319 // Extract just the parameter names for passing to specialized functions
1320 let param_names: Vec<_> = non_token_params
1321 .iter()
1322 .filter_map(|arg| match arg {
1323 syn::FnArg::Typed(pat_type) => {
1324 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1325 Some(&pat_ident.ident)
1326 } else {
1327 None
1328 }
1329 }
1330 _ => None,
1331 })
1332 .collect();
1333
1334 // Generate dispatch branches (highest capability first)
1335 let mut branches = Vec::new();
1336
1337 for config in configs.iter().rev() {
1338 let mod_name = format_ident!("{}", config.name);
1339 let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1340
1341 let feature_check = config.feature.map(|f| {
1342 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1343 quote!(#[cfg(feature = #f_lit)])
1344 });
1345
1346 branches.push(quote! {
1347 #feature_check
1348 if let Some(token) = #token_path::try_new() {
1349 return #mod_name::#fn_name(token, #(#param_names),*);
1350 }
1351 });
1352 }
1353
1354 // Generate dispatcher
1355 dispatchers.push(quote! {
1356 #(#fn_attrs)*
1357 /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1358 pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1359 use archmage::SimdToken;
1360
1361 #(#branches)*
1362
1363 // Fallback: panic if no SIMD available
1364 // TODO: Allow user-provided scalar fallback
1365 panic!("No SIMD support available for {}", stringify!(#fn_name));
1366 }
1367 });
1368 }
1369 }
1370
1371 quote! { #(#dispatchers)* }
1372}