archmage_macros/
lib.rs

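//! Proc-macros for archmage SIMD capability tokens.
//!
//! Provides the `#[arcane]` attribute (with its `#[simd_fn]` alias), which makes raw
//! intrinsics safe to call by requiring a capability token as proof that the needed
//! CPU features are available, and the `#[multiwidth]` attribute, which generates
//! width-specialized copies of a module.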
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    fold::Fold,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12    ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
15/// A Fold implementation that replaces `Self` with a concrete type.
16struct ReplaceSelf<'a> {
17    replacement: &'a Type,
18}
19
20impl Fold for ReplaceSelf<'_> {
21    fn fold_type(&mut self, ty: Type) -> Type {
22        match ty {
23            Type::Path(ref type_path) if type_path.qself.is_none() => {
24                // Check if it's just `Self`
25                if type_path.path.is_ident("Self") {
26                    return self.replacement.clone();
27                }
28                // Otherwise continue folding
29                syn::fold::fold_type(self, ty)
30            }
31            _ => syn::fold::fold_type(self, ty),
32        }
33    }
34}
35
36/// Arguments to the `#[arcane]` macro.
37#[derive(Default)]
38struct ArcaneArgs {
39    /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
40    /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
41    inline_always: bool,
42    /// The concrete type to use for `self` receiver.
43    /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
44    self_type: Option<Type>,
45}
46
47impl Parse for ArcaneArgs {
48    fn parse(input: ParseStream) -> syn::Result<Self> {
49        let mut args = ArcaneArgs::default();
50
51        while !input.is_empty() {
52            let ident: Ident = input.parse()?;
53            match ident.to_string().as_str() {
54                "inline_always" => args.inline_always = true,
55                "_self" => {
56                    let _: Token![=] = input.parse()?;
57                    args.self_type = Some(input.parse()?);
58                }
59                other => {
60                    return Err(syn::Error::new(
61                        ident.span(),
62                        format!("unknown arcane argument: `{}`", other),
63                    ))
64                }
65            }
66            // Consume optional comma
67            if input.peek(Token![,]) {
68                let _: Token![,] = input.parse()?;
69            }
70        }
71
72        Ok(args)
73    }
74}
75
/// Maps a concrete token type name to the target features it proves.
///
/// The x86 tiers follow the LLVM x86-64 microarchitecture levels (psABI).
/// Returns `None` for names that are not known token types.
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    // Shared by the three spellings of the x86-64-v4 tier.
    const X64_V4: &[&str] = &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"];

    let features: &'static [&'static str] = match token_name {
        // Single-feature x86_64 tokens (kept for backwards compatibility).
        "Sse41Token" => &["sse4.1"],
        "Sse42Token" => &["sse4.2"],
        "AvxToken" => &["avx"],
        "Avx2Token" => &["avx2"],
        "FmaToken" => &["fma"],
        "Avx2FmaToken" => &["avx2", "fma"],
        "Avx512fToken" => &["avx512f"],
        "Avx512bwToken" => &["avx512bw"],

        // x86_64 tier tokens.
        "X64V2Token" => &["sse4.2", "popcnt"],
        // NOTE(review): `trait_to_features` additionally lists "f16c" and "lzcnt" for this
        // tier; confirm which set the token's runtime detection actually verifies before
        // changing either table.
        "X64V3Token" | "Desktop64" => &["avx2", "fma", "bmi1", "bmi2"],
        "X64V4Token" | "Avx512Token" | "Server64" => X64_V4,
        "Avx512ModernToken" => &[
            "avx512f",
            "avx512bw",
            "avx512cd",
            "avx512dq",
            "avx512vl",
            "avx512vpopcntdq",
            "avx512ifma",
            "avx512vbmi",
            "avx512vbmi2",
            "avx512bitalg",
            "avx512vnni",
            "avx512bf16",
            "vpclmulqdq",
            "gfni",
            "vaes",
        ],
        "Avx512Fp16Token" => &[
            "avx512f",
            "avx512bw",
            "avx512cd",
            "avx512dq",
            "avx512vl",
            "avx512fp16",
        ],

        // AArch64 tokens.
        "NeonToken" | "Arm64" => &["neon"],
        "NeonAesToken" => &["neon", "aes"],
        "NeonSha3Token" => &["neon", "sha3"],
        "ArmCryptoToken" => &["neon", "aes", "sha2"],
        "ArmCrypto3Token" => &["neon", "aes", "sha2", "sha3"],

        // WASM tokens.
        "Simd128Token" => &["simd128"],

        _ => return None,
    };

    Some(features)
}
136
/// Maps a trait bound (or a token type name used in bound position) to the target
/// features it guarantees.
///
/// IMPORTANT: Each entry must include ALL implied features, not just the defining ones.
/// The compiler needs explicit #[target_feature] for each feature used.
///
/// Entries must also never list a feature the bound does not actually prove: every
/// feature returned here is enabled on the generated inner function, and the compiler
/// is then free to emit those instructions anywhere inside it.
///
/// Based on LLVM x86-64 microarchitecture levels (psABI). Returns `None` for names the
/// macro does not recognize.
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    // x86-64-v2.
    const X64_V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    // AVX2 + FMA and the SSE/AVX chain they imply. Deliberately NOT the full v3 set:
    // `Avx2FmaToken` proves only AVX2 and FMA (see `token_to_features` and the
    // multiwidth AVX2 configuration), so BMI1/BMI2/F16C/LZCNT must not be enabled.
    const AVX2_FMA: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "avx", "avx2", "fma"];
    // x86-64-v3 (superset of v2).
    const X64_V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    // x86-64-v4 (superset of v3).
    const X64_V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    // v4 plus the Ice Lake-era AVX-512 extensions.
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    // v4 plus AVX512-FP16.
    const AVX512_FP16: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512fp16",
    ];

    let features: &'static [&'static str] = match trait_name {
        // x86 tier traits and the equivalent token types used directly as bounds.
        "HasX64V2" | "X64V2Token" => X64_V2,
        "X64V3Token" | "Desktop64" => X64_V3,
        "Avx2FmaToken" => AVX2_FMA,
        "HasX64V4" | "X64V4Token" | "Avx512Token" | "Server64" => X64_V4,
        "Avx512ModernToken" => AVX512_MODERN,
        "Avx512Fp16Token" => AVX512_FP16,

        // Width traits - minimal features to satisfy width.
        "Has128BitSimd" => &["sse2"],
        "Has256BitSimd" => &["avx"],
        "Has512BitSimd" => &["avx512f"],

        // AArch64 traits and token types.
        "HasNeon" | "NeonToken" | "Arm64" => &["neon"],
        "HasNeonAes" | "NeonAesToken" => &["neon", "aes"],
        "HasNeonSha3" | "NeonSha3Token" => &["neon", "sha3"],
        "ArmCryptoToken" => &["neon", "aes", "sha2"],
        "ArmCrypto3Token" => &["neon", "aes", "sha2", "sha3"],

        _ => return None,
    };

    Some(features)
}
234
/// Classification of a parameter type that may carry a capability token.
enum TokenTypeInfo {
    /// A bare type name that might be a generic type parameter (e.g. `T`); its
    /// bounds still have to be looked up in the function signature.
    Generic(String),
    /// A known concrete token type, by bare type name (e.g. `Avx2Token`).
    Concrete(String),
    /// An `impl Trait` type, holding the bare names of its trait bounds
    /// (e.g. `impl HasNeonAes + HasNeonSha3`).
    ImplTrait(Vec<String>),
}
244
245/// Extract token type information from a type.
246fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
247    match ty {
248        Type::Path(type_path) => {
249            // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
250            type_path.path.segments.last().map(|seg| {
251                let name = seg.ident.to_string();
252                // Check if it's a known concrete token type
253                if token_to_features(&name).is_some() {
254                    TokenTypeInfo::Concrete(name)
255                } else {
256                    // Might be a generic type parameter like `T`
257                    TokenTypeInfo::Generic(name)
258                }
259            })
260        }
261        Type::Reference(type_ref) => {
262            // Handle &Token or &mut Token
263            extract_token_type_info(&type_ref.elem)
264        }
265        Type::ImplTrait(impl_trait) => {
266            // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
267            let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
268            if traits.is_empty() {
269                None
270            } else {
271                Some(TokenTypeInfo::ImplTrait(traits))
272            }
273        }
274        _ => None,
275    }
276}
277
278/// Extract trait names from type param bounds.
279fn extract_trait_names_from_bounds(
280    bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
281) -> Vec<String> {
282    bounds
283        .iter()
284        .filter_map(|bound| {
285            if let TypeParamBound::Trait(trait_bound) = bound {
286                trait_bound
287                    .path
288                    .segments
289                    .last()
290                    .map(|seg| seg.ident.to_string())
291            } else {
292                None
293            }
294        })
295        .collect()
296}
297
298/// Look up a generic type parameter in the function's generics.
299fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
300    // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
301    for param in &sig.generics.params {
302        if let GenericParam::Type(type_param) = param {
303            if type_param.ident == type_name {
304                let traits = extract_trait_names_from_bounds(&type_param.bounds);
305                if !traits.is_empty() {
306                    return Some(traits);
307                }
308            }
309        }
310    }
311
312    // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
313    if let Some(where_clause) = &sig.generics.where_clause {
314        for predicate in &where_clause.predicates {
315            if let syn::WherePredicate::Type(pred_type) = predicate {
316                if let Type::Path(type_path) = &pred_type.bounded_ty {
317                    if let Some(seg) = type_path.path.segments.last() {
318                        if seg.ident == type_name {
319                            let traits = extract_trait_names_from_bounds(&pred_type.bounds);
320                            if !traits.is_empty() {
321                                return Some(traits);
322                            }
323                        }
324                    }
325                }
326            }
327        }
328    }
329
330    None
331}
332
333/// Convert trait names to features, collecting all features from all traits.
334fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
335    let mut all_features = Vec::new();
336
337    for trait_name in trait_names {
338        if let Some(features) = trait_to_features(trait_name) {
339            for &feature in features {
340                if !all_features.contains(&feature) {
341                    all_features.push(feature);
342                }
343            }
344        }
345    }
346
347    if all_features.is_empty() {
348        None
349    } else {
350        Some(all_features)
351    }
352}
353
354/// Find the first token parameter and return its name and features.
355fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
356    for arg in &sig.inputs {
357        match arg {
358            FnArg::Receiver(_) => {
359                // Self receivers (self, &self, &mut self) are not yet supported.
360                // The macro creates an inner function, and Rust's inner functions
361                // cannot have `self` parameters. Supporting this would require
362                // AST rewriting to replace `self` with a regular parameter.
363                // See the module docs for the workaround.
364                continue;
365            }
366            FnArg::Typed(PatType { pat, ty, .. }) => {
367                if let Some(info) = extract_token_type_info(ty) {
368                    let features = match info {
369                        TokenTypeInfo::Concrete(name) => {
370                            token_to_features(&name).map(|f| f.to_vec())
371                        }
372                        TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
373                        TokenTypeInfo::Generic(type_name) => {
374                            // Look up the generic parameter's bounds
375                            find_generic_bounds(sig, &type_name)
376                                .and_then(|traits| traits_to_features(&traits))
377                        }
378                    };
379
380                    if let Some(features) = features {
381                        // Extract parameter name
382                        if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
383                            return Some((pat_ident.ident.clone(), features));
384                        }
385                    }
386                }
387            }
388        }
389    }
390    None
391}
392
393/// Represents the kind of self receiver and the transformed parameter.
394enum SelfReceiver {
395    /// `self` (by value/move)
396    Owned,
397    /// `&self` (shared reference)
398    Ref,
399    /// `&mut self` (mutable reference)
400    RefMut,
401}
402
403/// Shared implementation for arcane/simd_fn macros.
404fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
405    // Check for self receiver
406    let has_self_receiver = input_fn
407        .sig
408        .inputs
409        .first()
410        .map(|arg| matches!(arg, FnArg::Receiver(_)))
411        .unwrap_or(false);
412
413    // If there's a self receiver, we need _self = Type
414    if has_self_receiver && args.self_type.is_none() {
415        let msg = format!(
416            "{} with self receiver requires `_self = Type` argument.\n\
417             Example: #[{}(_self = MyType)]\n\
418             Use `_self` (not `self`) in the function body to refer to self.",
419            macro_name, macro_name
420        );
421        return syn::Error::new_spanned(&input_fn.sig, msg)
422            .to_compile_error()
423            .into();
424    }
425
426    // Find the token parameter and its features
427    let (_token_ident, features) = match find_token_param(&input_fn.sig) {
428        Some(result) => result,
429        None => {
430            let msg = format!(
431                "{} requires a token parameter. Supported forms:\n\
432                 - Concrete: `token: Avx2Token`\n\
433                 - impl Trait: `token: impl HasAvx2`\n\
434                 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
435                 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasAvx2, ...)`",
436                macro_name, macro_name
437            );
438            return syn::Error::new_spanned(&input_fn.sig, msg)
439                .to_compile_error()
440                .into();
441        }
442    };
443
444    // Build target_feature attributes
445    let target_feature_attrs: Vec<Attribute> = features
446        .iter()
447        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
448        .collect();
449
450    // Extract function components
451    let vis = &input_fn.vis;
452    let sig = &input_fn.sig;
453    let fn_name = &sig.ident;
454    let generics = &sig.generics;
455    let where_clause = &generics.where_clause;
456    let inputs = &sig.inputs;
457    let output = &sig.output;
458    let body = &input_fn.block;
459    let attrs = &input_fn.attrs;
460
461    // Determine self receiver type if present
462    let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
463        FnArg::Receiver(receiver) => {
464            if receiver.reference.is_none() {
465                Some(SelfReceiver::Owned)
466            } else if receiver.mutability.is_some() {
467                Some(SelfReceiver::RefMut)
468            } else {
469                Some(SelfReceiver::Ref)
470            }
471        }
472        _ => None,
473    });
474
475    // Build inner function parameters, transforming self if needed
476    let inner_params: Vec<proc_macro2::TokenStream> = inputs
477        .iter()
478        .map(|arg| match arg {
479            FnArg::Receiver(_) => {
480                // Transform self receiver to _self parameter
481                let self_ty = args.self_type.as_ref().unwrap();
482                match self_receiver_kind.as_ref().unwrap() {
483                    SelfReceiver::Owned => quote!(_self: #self_ty),
484                    SelfReceiver::Ref => quote!(_self: &#self_ty),
485                    SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
486                }
487            }
488            FnArg::Typed(pat_type) => quote!(#pat_type),
489        })
490        .collect();
491
492    // Build inner function call arguments
493    let inner_args: Vec<proc_macro2::TokenStream> = inputs
494        .iter()
495        .filter_map(|arg| match arg {
496            FnArg::Typed(pat_type) => {
497                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
498                    let ident = &pat_ident.ident;
499                    Some(quote!(#ident))
500                } else {
501                    None
502                }
503            }
504            FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
505        })
506        .collect();
507
508    let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
509
510    // Choose inline attribute based on args
511    // Note: #[inline(always)] + #[target_feature] requires nightly with
512    // #![feature(target_feature_inline_always)]
513    let inline_attr: Attribute = if args.inline_always {
514        parse_quote!(#[inline(always)])
515    } else {
516        parse_quote!(#[inline])
517    };
518
519    // Transform output and body to replace Self with concrete type if needed
520    let (inner_output, inner_body): (ReturnType, syn::Block) =
521        if let Some(ref self_ty) = args.self_type {
522            let mut replacer = ReplaceSelf {
523                replacement: self_ty,
524            };
525            let transformed_output = replacer.fold_return_type(output.clone());
526            let transformed_body = replacer.fold_block((**body).clone());
527            (transformed_output, transformed_body)
528        } else {
529            (output.clone(), (**body).clone())
530        };
531
532    // Generate the expanded function
533    let expanded = quote! {
534        #(#attrs)*
535        #vis #sig {
536            #(#target_feature_attrs)*
537            #inline_attr
538            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
539            #inner_body
540
541            // SAFETY: The token parameter proves the required CPU features are available.
542            // Tokens can only be constructed when features are verified (via try_new()
543            // runtime check or forge_token_dangerously() in a context where features are guaranteed).
544            unsafe { #inner_fn_name(#(#inner_args),*) }
545        }
546    };
547
548    expanded.into()
549}
550
551/// Mark a function as an arcane SIMD function.
552///
553/// This macro enables safe use of SIMD intrinsics by generating an inner function
554/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
555/// the token parameter type. The outer function calls the inner function unsafely,
556/// which is justified because the token parameter proves the features are available.
557///
558/// **The token is passed through to the inner function**, so you can call other
559/// token-taking functions from inside `#[arcane]`.
560///
561/// # Token Parameter Forms
562///
563/// The macro supports four forms of token parameters:
564///
565/// ## Concrete Token Types
566///
567/// ```ignore
568/// #[arcane]
569/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
570///     // AVX2 intrinsics safe here
571/// }
572/// ```
573///
574/// ## impl Trait Bounds
575///
576/// ```ignore
577/// #[arcane]
578/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
579///     // Accepts any token that provides AVX2
580/// }
581/// ```
582///
583/// ## Generic Type Parameters
584///
585/// ```ignore
586/// #[arcane]
587/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
588///     // Generic over any AVX2-capable token
589/// }
590///
591/// // Also works with where clauses:
592/// #[arcane]
593/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
594/// where
595///     T: HasAvx2
596/// {
597///     // ...
598/// }
599/// ```
600///
601/// ## Methods with Self Receivers
602///
603/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
604/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
605///
606/// ```ignore
607/// use archmage::{HasAvx2, arcane};
608/// use wide::f32x8;
609///
610/// trait Avx2Ops {
611///     fn double(&self, token: impl HasAvx2) -> Self;
612///     fn square(self, token: impl HasAvx2) -> Self;
613///     fn scale(&mut self, token: impl HasAvx2, factor: f32);
614/// }
615///
616/// impl Avx2Ops for f32x8 {
617///     #[arcane(_self = f32x8)]
618///     fn double(&self, _token: impl HasAvx2) -> Self {
619///         // Use _self instead of self in the body
620///         *_self + *_self
621///     }
622///
623///     #[arcane(_self = f32x8)]
624///     fn square(self, _token: impl HasAvx2) -> Self {
625///         _self * _self
626///     }
627///
628///     #[arcane(_self = f32x8)]
629///     fn scale(&mut self, _token: impl HasAvx2, factor: f32) {
630///         *_self = *_self * f32x8::splat(factor);
631///     }
632/// }
633/// ```
634///
635/// **Why `_self`?** The macro generates an inner function where `self` becomes
636/// a regular parameter named `_self`. Using `_self` in your code reminds you
637/// that you're not using the normal `self` keyword.
638///
639/// **All receiver types are supported:**
640/// - `self` (by value/move) → `_self: Type`
641/// - `&self` (shared reference) → `_self: &Type`
642/// - `&mut self` (mutable reference) → `_self: &mut Type`
643///
644/// # Multiple Trait Bounds
645///
646/// When using `impl Trait` or generic bounds with multiple traits,
647/// all required features are enabled:
648///
649/// ```ignore
650/// #[arcane]
651/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
652///     // Both AVX2 and FMA intrinsics are safe here
653/// }
654/// ```
655///
656/// # Expansion
657///
658/// The macro expands to approximately:
659///
660/// ```ignore
661/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
662///     #[target_feature(enable = "avx2")]
663///     #[inline]
664///     unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
665///         let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
666///         let doubled = _mm256_add_ps(v, v);
667///         let mut out = [0.0f32; 8];
668///         unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
669///         out
670///     }
671///     // SAFETY: Token proves the required features are available
672///     unsafe { __simd_inner_process(token, data) }
673/// }
674/// ```
675///
676/// # Profile Tokens
677///
678/// Profile tokens automatically enable all required features:
679///
680/// ```ignore
681/// #[arcane]
682/// fn kernel(token: X64V3Token, data: &mut [f32]) {
683///     // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
684/// }
685/// ```
686///
687/// # Supported Tokens
688///
689/// - **x86_64**: `Sse2Token`, `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
690///   `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
691/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token`, `X64V4Token`
692/// - **ARM**: `NeonToken`, `SveToken`, `Sve2Token`
693/// - **WASM**: `Simd128Token`
694///
695/// # Supported Trait Bounds
696///
697/// - **x86_64**: `HasSse`, `HasSse2`, `HasSse41`, `HasSse42`, `HasAvx`, `HasAvx2`,
698///   `HasAvx512f`, `HasAvx512vl`, `HasAvx512bw`, `HasAvx512vbmi2`, `HasFma`
699/// - **ARM**: `HasNeon`, `HasSve`, `HasSve2`
700/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
701///
702/// # Options
703///
704/// ## `inline_always`
705///
706/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
707/// This can improve performance by ensuring aggressive inlining, but requires
708/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
709/// the crate using the macro.
710///
711/// ```ignore
712/// #![feature(target_feature_inline_always)]
713///
714/// #[arcane(inline_always)]
715/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
716///     // Inner function will use #[inline(always)]
717/// }
718/// ```
719#[proc_macro_attribute]
720pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
721    let args = parse_macro_input!(attr as ArcaneArgs);
722    let input_fn = parse_macro_input!(item as ItemFn);
723    arcane_impl(input_fn, "arcane", args)
724}
725
726/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
727///
728/// See [`arcane`] for full documentation.
729#[proc_macro_attribute]
730pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
731    let args = parse_macro_input!(attr as ArcaneArgs);
732    let input_fn = parse_macro_input!(item as ItemFn);
733    arcane_impl(input_fn, "simd_fn", args)
734}
735
736// ============================================================================
737// Multiwidth macro for width-agnostic SIMD code
738// ============================================================================
739
740use syn::ItemMod;
741
/// Parsed arguments of `#[multiwidth(...)]`: which width specializations to emit.
///
/// An empty argument list selects every width. An explicit list selects only the
/// widths it names.
struct MultiwidthArgs {
    /// Emit the 128-bit SSE specialization (selected by `sse`).
    sse: bool,
    /// Emit the 256-bit AVX2 specialization (selected by `avx2`).
    avx2: bool,
    /// Emit the 512-bit AVX-512 specialization (selected by `avx512`).
    avx512: bool,
}
751
752impl Default for MultiwidthArgs {
753    fn default() -> Self {
754        Self {
755            sse: true,
756            avx2: true,
757            avx512: true,
758        }
759    }
760}
761
762impl Parse for MultiwidthArgs {
763    fn parse(input: ParseStream) -> syn::Result<Self> {
764        let mut args = MultiwidthArgs {
765            sse: false,
766            avx2: false,
767            avx512: false,
768        };
769
770        // If no args provided, enable all
771        if input.is_empty() {
772            return Ok(MultiwidthArgs::default());
773        }
774
775        while !input.is_empty() {
776            let ident: Ident = input.parse()?;
777            match ident.to_string().as_str() {
778                "sse" => args.sse = true,
779                "avx2" => args.avx2 = true,
780                "avx512" => args.avx512 = true,
781                other => {
782                    return Err(syn::Error::new(
783                        ident.span(),
784                        format!(
785                            "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512",
786                            other
787                        ),
788                    ))
789                }
790            }
791            // Consume optional comma
792            if input.peek(Token![,]) {
793                let _: Token![,] = input.parse()?;
794            }
795        }
796
797        Ok(args)
798    }
799}
800
/// Static description of one width specialization that `#[multiwidth]` can emit.
struct WidthConfig {
    /// Short name, e.g. `"avx2"`. It is also the `#[multiwidth(...)]` argument that
    /// selects this width and the name of the generated submodule.
    name: &'static str,
    /// Import path of the width-specific SIMD namespace (e.g. `"magetypes::simd::avx2"`).
    namespace: &'static str,
    /// Fully qualified token type for this width (e.g. `"archmage::Avx2FmaToken"`).
    token: &'static str,
    /// Feature flag this width requires, if any (e.g. `Some("avx512")`).
    feature: Option<&'static str>,
    /// Target features to enable for code in this width.
    target_features: &'static [&'static str],
}
814
815const WIDTH_CONFIGS: &[WidthConfig] = &[
816    WidthConfig {
817        name: "sse",
818        namespace: "magetypes::simd::sse",
819        token: "archmage::Sse41Token",
820        feature: None,
821        target_features: &["sse4.1"],
822    },
823    WidthConfig {
824        name: "avx2",
825        namespace: "magetypes::simd::avx2",
826        token: "archmage::Avx2FmaToken",
827        feature: None,
828        target_features: &["avx2", "fma"],
829    },
830    WidthConfig {
831        name: "avx512",
832        namespace: "magetypes::simd::avx512",
833        token: "archmage::X64V4Token",
834        feature: Some("avx512"),
835        target_features: &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"],
836    },
837];
838
839/// Generate width-specialized SIMD code.
840///
841/// This macro takes a module containing width-agnostic SIMD code and generates
842/// specialized versions for each target width (SSE, AVX2, AVX-512).
843///
844/// # Usage
845///
846/// ```ignore
847/// use archmage::multiwidth;
848///
849/// #[multiwidth]
850/// mod kernels {
851///     // Inside this module, these types are available:
852///     // - f32xN, i32xN, etc. (width-appropriate SIMD types)
853///     // - Token (the token type: Sse41Token, Avx2FmaToken, or X64V4Token)
854///     // - LANES_F32, LANES_32, etc. (lane count constants)
855///
856///     use archmage::simd::*;
857///
858///     pub fn normalize(token: Token, data: &mut [f32]) {
859///         for chunk in data.chunks_exact_mut(LANES_F32) {
860///             let v = f32xN::load(token, chunk.try_into().unwrap());
861///             let result = v * f32xN::splat(token, 1.0 / 255.0);
862///             result.store(chunk.try_into().unwrap());
863///         }
864///     }
865/// }
866///
867/// // Generated modules:
868/// // - kernels::sse::normalize(token: Sse41Token, data: &mut [f32])
869/// // - kernels::avx2::normalize(token: Avx2FmaToken, data: &mut [f32])
870/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32])  // if avx512 feature
871/// // - kernels::normalize(data: &mut [f32])  // runtime dispatcher
872/// ```
873///
874/// # Selective Targets
875///
876/// You can specify which targets to generate:
877///
878/// ```ignore
879/// #[multiwidth(avx2, avx512)]  // Only AVX2 and AVX-512, no SSE
880/// mod fast_kernels { ... }
881/// ```
882///
883/// # How It Works
884///
885/// 1. The macro duplicates the module content for each width target
886/// 2. Each copy imports from the appropriate namespace (`archmage::simd::sse`, etc.)
887/// 3. The `use archmage::simd::*` statement is rewritten to the width-specific import
888/// 4. A dispatcher function is generated that picks the best available at runtime
889///
890/// # Requirements
891///
892/// - Functions should use `Token` as their token parameter type
893/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
894/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
895#[proc_macro_attribute]
896pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
897    let args = parse_macro_input!(attr as MultiwidthArgs);
898    let input_mod = parse_macro_input!(item as ItemMod);
899
900    multiwidth_impl(input_mod, args)
901}
902
903fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
904    let mod_name = &input_mod.ident;
905    let mod_vis = &input_mod.vis;
906    let mod_attrs = &input_mod.attrs;
907
908    // Get module content
909    let content = match &input_mod.content {
910        Some((_, items)) => items,
911        None => {
912            return syn::Error::new_spanned(
913                &input_mod,
914                "multiwidth requires an inline module (mod name { ... }), not a file module",
915            )
916            .to_compile_error()
917            .into();
918        }
919    };
920
921    // Build specialized modules
922    let mut specialized_mods = Vec::new();
923    let mut enabled_configs = Vec::new();
924
925    for config in WIDTH_CONFIGS {
926        let enabled = match config.name {
927            "sse" => args.sse,
928            "avx2" => args.avx2,
929            "avx512" => args.avx512,
930            _ => false,
931        };
932
933        if !enabled {
934            continue;
935        }
936
937        enabled_configs.push(config);
938
939        let width_mod_name = format_ident!("{}", config.name);
940        let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
941
942        // Transform the content: replace `use archmage::simd::*` with width-specific import
943        // and add target_feature attributes for optimization
944        let transformed_items: Vec<syn::Item> = content
945            .iter()
946            .map(|item| transform_item_for_width(item.clone(), &namespace, config))
947            .collect();
948
949        let feature_attr = config.feature.map(|f| {
950            let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
951            quote!(#[cfg(feature = #f_lit)])
952        });
953
954        specialized_mods.push(quote! {
955            #feature_attr
956            pub mod #width_mod_name {
957                #(#transformed_items)*
958            }
959        });
960    }
961
962    // Generate dispatcher functions for each public function in the module
963    let dispatchers = generate_dispatchers(content, &enabled_configs);
964
965    let expanded = quote! {
966        #(#mod_attrs)*
967        #mod_vis mod #mod_name {
968            #(#specialized_mods)*
969
970            #dispatchers
971        }
972    };
973
974    expanded.into()
975}
976
977/// Transform a single item for a specific width namespace.
978fn transform_item_for_width(
979    item: syn::Item,
980    namespace: &syn::Path,
981    config: &WidthConfig,
982) -> syn::Item {
983    match item {
984        syn::Item::Use(mut use_item) => {
985            // Check if this is `use archmage::simd::*` or similar
986            if is_simd_wildcard_use(&use_item) {
987                // Replace with width-specific import
988                use_item.tree = syn::UseTree::Path(syn::UsePath {
989                    ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
990                    colon2_token: Default::default(),
991                    tree: Box::new(build_use_tree_from_path(namespace, 1)),
992                });
993            }
994            syn::Item::Use(use_item)
995        }
996        syn::Item::Fn(func) => {
997            // Transform function to use inner function pattern with target_feature
998            // This is the same pattern as #[arcane], enabling SIMD optimization
999            // without requiring -C target-cpu=native
1000            transform_fn_with_target_feature(func, config)
1001        }
1002        other => other,
1003    }
1004}
1005
1006/// Transform a function to use the inner function pattern with target_feature.
1007/// This generates:
1008/// ```ignore
1009/// pub fn example(token: Token, data: &[f32]) -> f32 {
1010///     #[target_feature(enable = "avx2", enable = "fma")]
1011///     #[inline]
1012///     unsafe fn inner(token: Token, data: &[f32]) -> f32 {
1013///         // original body
1014///     }
1015///     // SAFETY: Token proves CPU support
1016///     unsafe { inner(token, data) }
1017/// }
1018/// ```
1019fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
1020    let vis = &func.vis;
1021    let sig = &func.sig;
1022    let fn_name = &sig.ident;
1023    let generics = &sig.generics;
1024    let where_clause = &generics.where_clause;
1025    let inputs = &sig.inputs;
1026    let output = &sig.output;
1027    let body = &func.block;
1028    let attrs = &func.attrs;
1029
1030    // Build target_feature attributes
1031    let target_feature_attrs: Vec<syn::Attribute> = config
1032        .target_features
1033        .iter()
1034        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
1035        .collect();
1036
1037    // Build parameter list for inner function
1038    let inner_params: Vec<proc_macro2::TokenStream> =
1039        inputs.iter().map(|arg| quote!(#arg)).collect();
1040
1041    // Build argument list for calling inner function
1042    let call_args: Vec<proc_macro2::TokenStream> = inputs
1043        .iter()
1044        .filter_map(|arg| match arg {
1045            syn::FnArg::Typed(pat_type) => {
1046                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1047                    let ident = &pat_ident.ident;
1048                    Some(quote!(#ident))
1049                } else {
1050                    None
1051                }
1052            }
1053            syn::FnArg::Receiver(_) => Some(quote!(self)),
1054        })
1055        .collect();
1056
1057    let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
1058
1059    let expanded = quote! {
1060        #(#attrs)*
1061        #vis #sig {
1062            #(#target_feature_attrs)*
1063            #[inline]
1064            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
1065            #body
1066
1067            // SAFETY: The Token parameter proves the required CPU features are available.
1068            // Tokens can only be constructed via try_new() which checks CPU support.
1069            unsafe { #inner_fn_name(#(#call_args),*) }
1070        }
1071    };
1072
1073    syn::parse2(expanded).expect("Failed to parse transformed function")
1074}
1075
1076/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
1077fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
1078    fn check_tree(tree: &syn::UseTree) -> bool {
1079        match tree {
1080            syn::UseTree::Path(path) => {
1081                let ident = path.ident.to_string();
1082                if ident == "archmage" || ident == "magetypes" || ident == "crate" {
1083                    check_tree_for_simd(&path.tree)
1084                } else {
1085                    false
1086                }
1087            }
1088            _ => false,
1089        }
1090    }
1091
1092    fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
1093        match tree {
1094            syn::UseTree::Path(path) => {
1095                if path.ident == "simd" {
1096                    matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
1097                } else {
1098                    check_tree_for_simd(&path.tree)
1099                }
1100            }
1101            _ => false,
1102        }
1103    }
1104
1105    check_tree(&use_item.tree)
1106}
1107
1108/// Build a UseTree from a path, starting at a given segment index.
1109fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
1110    let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
1111
1112    if segments.is_empty() {
1113        syn::UseTree::Glob(syn::UseGlob {
1114            star_token: Default::default(),
1115        })
1116    } else if segments.len() == 1 {
1117        syn::UseTree::Path(syn::UsePath {
1118            ident: segments[0].ident.clone(),
1119            colon2_token: Default::default(),
1120            tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1121                star_token: Default::default(),
1122            })),
1123        })
1124    } else {
1125        let first = &segments[0];
1126        let rest_path = syn::Path {
1127            leading_colon: None,
1128            segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1129        };
1130        syn::UseTree::Path(syn::UsePath {
1131            ident: first.ident.clone(),
1132            colon2_token: Default::default(),
1133            tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1134        })
1135    }
1136}
1137
/// Type names whose meaning changes per width submodule.
///
/// A public function whose signature (other than its token parameter) mentions one of
/// these cannot get a width-agnostic runtime dispatcher.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    // Floating-point vectors
    "f32xN",
    "f64xN",
    // Signed integer vectors
    "i8xN",
    "i16xN",
    "i32xN",
    "i64xN",
    // Unsigned integer vectors
    "u8xN",
    "u16xN",
    "u32xN",
    "u64xN",
    // Capability token alias
    "Token",
];
1142
1143/// Check if a type string contains width-specific types.
1144fn contains_width_specific_type(ty_str: &str) -> bool {
1145    WIDTH_SPECIFIC_TYPES.iter().any(|t| ty_str.contains(t))
1146}
1147
1148/// Check if a function signature uses width-specific types (can't have a dispatcher).
1149fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1150    // Check return type
1151    if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1152        let ty_str = quote!(#ty).to_string();
1153        if contains_width_specific_type(&ty_str) {
1154            return true;
1155        }
1156    }
1157
1158    // Check parameters (excluding Token which we filter out anyway)
1159    for arg in &func.sig.inputs {
1160        if let syn::FnArg::Typed(pat_type) = arg {
1161            let ty = &pat_type.ty;
1162            let ty_str = quote!(#ty).to_string();
1163            // Skip Token parameters - they're filtered out for dispatchers
1164            if ty_str.contains("Token") {
1165                continue;
1166            }
1167            if contains_width_specific_type(&ty_str) {
1168                return true;
1169            }
1170        }
1171    }
1172
1173    false
1174}
1175
1176/// Generate runtime dispatcher functions for public functions.
1177///
1178/// Note: Dispatchers are only generated for functions that don't use width-specific
1179/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1180/// width-specific types can only be called via the width-specific submodules.
1181fn generate_dispatchers(
1182    content: &[syn::Item],
1183    configs: &[&WidthConfig],
1184) -> proc_macro2::TokenStream {
1185    let mut dispatchers = Vec::new();
1186
1187    for item in content {
1188        if let syn::Item::Fn(func) = item {
1189            // Only generate dispatchers for public functions
1190            if !matches!(func.vis, syn::Visibility::Public(_)) {
1191                continue;
1192            }
1193
1194            // Skip functions that use width-specific types - they can't have dispatchers
1195            if uses_width_specific_types(func) {
1196                continue;
1197            }
1198
1199            let fn_name = &func.sig.ident;
1200            let fn_generics = &func.sig.generics;
1201            let fn_output = &func.sig.output;
1202            let fn_attrs: Vec<_> = func
1203                .attrs
1204                .iter()
1205                .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1206                .collect();
1207
1208            // Filter out the token parameter from the dispatcher signature
1209            let non_token_params: Vec<_> = func
1210                .sig
1211                .inputs
1212                .iter()
1213                .filter(|arg| {
1214                    match arg {
1215                        syn::FnArg::Typed(pat_type) => {
1216                            // Check if type contains "Token"
1217                            let ty = &pat_type.ty;
1218                            let ty_str = quote!(#ty).to_string();
1219                            !ty_str.contains("Token")
1220                        }
1221                        _ => true,
1222                    }
1223                })
1224                .collect();
1225
1226            // Extract just the parameter names for passing to specialized functions
1227            let param_names: Vec<_> = non_token_params
1228                .iter()
1229                .filter_map(|arg| match arg {
1230                    syn::FnArg::Typed(pat_type) => {
1231                        if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1232                            Some(&pat_ident.ident)
1233                        } else {
1234                            None
1235                        }
1236                    }
1237                    _ => None,
1238                })
1239                .collect();
1240
1241            // Generate dispatch branches (highest capability first)
1242            let mut branches = Vec::new();
1243
1244            for config in configs.iter().rev() {
1245                let mod_name = format_ident!("{}", config.name);
1246                let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1247
1248                let feature_check = config.feature.map(|f| {
1249                    let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1250                    quote!(#[cfg(feature = #f_lit)])
1251                });
1252
1253                branches.push(quote! {
1254                    #feature_check
1255                    if let Some(token) = #token_path::try_new() {
1256                        return #mod_name::#fn_name(token, #(#param_names),*);
1257                    }
1258                });
1259            }
1260
1261            // Generate dispatcher
1262            dispatchers.push(quote! {
1263                #(#fn_attrs)*
1264                /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1265                pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1266                    use archmage::SimdToken;
1267
1268                    #(#branches)*
1269
1270                    // Fallback: panic if no SIMD available
1271                    // TODO: Allow user-provided scalar fallback
1272                    panic!("No SIMD support available for {}", stringify!(#fn_name));
1273                }
1274            });
1275        }
1276    }
1277
1278    quote! { #(#dispatchers)* }
1279}