Skip to main content

archmage_macros/
lib.rs

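//! Proc-macros for archmage SIMD capability tokens.
//!
//! Provides the `#[arcane]` attribute (with its `#[simd_fn]` alias), which makes raw
//! intrinsics safe to call by requiring a capability token as proof that the needed CPU
//! features are available, and the `#[multiwidth]` attribute, which generates
//! width-specialized copies of a module.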
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    fold::Fold,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12    ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
15/// A Fold implementation that replaces `Self` with a concrete type.
16struct ReplaceSelf<'a> {
17    replacement: &'a Type,
18}
19
20impl Fold for ReplaceSelf<'_> {
21    fn fold_type(&mut self, ty: Type) -> Type {
22        match ty {
23            Type::Path(ref type_path) if type_path.qself.is_none() => {
24                // Check if it's just `Self`
25                if type_path.path.is_ident("Self") {
26                    return self.replacement.clone();
27                }
28                // Otherwise continue folding
29                syn::fold::fold_type(self, ty)
30            }
31            _ => syn::fold::fold_type(self, ty),
32        }
33    }
34}
35
36/// Arguments to the `#[arcane]` macro.
37#[derive(Default)]
38struct ArcaneArgs {
39    /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
40    /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
41    inline_always: bool,
42    /// The concrete type to use for `self` receiver.
43    /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
44    self_type: Option<Type>,
45}
46
47impl Parse for ArcaneArgs {
48    fn parse(input: ParseStream) -> syn::Result<Self> {
49        let mut args = ArcaneArgs::default();
50
51        while !input.is_empty() {
52            let ident: Ident = input.parse()?;
53            match ident.to_string().as_str() {
54                "inline_always" => args.inline_always = true,
55                "_self" => {
56                    let _: Token![=] = input.parse()?;
57                    args.self_type = Some(input.parse()?);
58                }
59                other => {
60                    return Err(syn::Error::new(
61                        ident.span(),
62                        format!("unknown arcane argument: `{}`", other),
63                    ))
64                }
65            }
66            // Consume optional comma
67            if input.peek(Token![,]) {
68                let _: Token![,] = input.parse()?;
69            }
70        }
71
72        Ok(args)
73    }
74}
75
/// Maps a concrete token type name to the target features the token proves.
///
/// Tier tokens are cumulative, following the LLVM x86-64 microarchitecture levels
/// (psABI). Every x86-64-v3 token also enables the v2 features, and every v4-based token
/// also enables the v3 features. The x86 tier sets are identical to the entries for the
/// same token names in `trait_to_features`. As a result, e.g. `X64V4Token` enables the
/// same features whether it is a concrete parameter type or a bound.
/// NOTE(review): this assumes each tier token's runtime detection verifies its whole
/// level; keep in sync with the token definitions.
///
/// Single-feature tokens (`Avx2Token`, `FmaToken`, ...) list only the features they
/// are named after.
///
/// Returns `None` for names that are not known tokens (e.g. a generic parameter `T`).
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    // Cumulative x86-64 levels.
    const X64_V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    const X64_V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    const X64_V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    const AVX512_FP16: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512fp16",
    ];

    match token_name {
        // x86_64 single-feature tokens (kept for backwards compatibility)
        "Sse41Token" => Some(&["sse4.1"]),
        "Sse42Token" => Some(&["sse4.2"]),
        "AvxToken" => Some(&["avx"]),
        "Avx2Token" => Some(&["avx2"]),
        "FmaToken" => Some(&["fma"]),
        "Avx2FmaToken" => Some(&["avx2", "fma"]),
        "Avx512fToken" => Some(&["avx512f"]),
        "Avx512bwToken" => Some(&["avx512bw"]),

        // x86_64 tier tokens (cumulative)
        "X64V2Token" => Some(X64_V2),
        "X64V3Token" | "Desktop64" => Some(X64_V3),
        "X64V4Token" | "Avx512Token" | "Server64" => Some(X64_V4),
        "Avx512ModernToken" => Some(AVX512_MODERN),
        "Avx512Fp16Token" => Some(AVX512_FP16),

        // AArch64 tokens
        "NeonToken" | "Arm64" => Some(&["neon"]),
        "NeonAesToken" => Some(&["neon", "aes"]),
        "NeonSha3Token" => Some(&["neon", "sha3"]),
        "ArmCryptoToken" => Some(&["neon", "aes", "sha2"]),
        "ArmCrypto3Token" => Some(&["neon", "aes", "sha2", "sha3"]),

        // WASM tokens
        "Simd128Token" => Some(&["simd128"]),

        _ => None,
    }
}
136
/// Maps a trait-bound name (or a token name used as a bound) to target features.
///
/// x86 tier entries are cumulative and spell out every feature of the lower levels,
/// rather than relying on implied features. Width traits list only the single feature
/// that provides that vector width. Tier sets follow the LLVM x86-64 microarchitecture
/// levels (psABI).
///
/// Returns `None` for unrecognised names.
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    const V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    const V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    const V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    const MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    const FP16: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512fp16",
    ];

    let features: &'static [&'static str] = match trait_name {
        // x86 tiers: tier traits and tier token names share a feature set.
        "HasX64V2" | "X64V2Token" => V2,
        "X64V3Token" | "Desktop64" | "Avx2FmaToken" => V3,
        "HasX64V4" | "X64V4Token" | "Avx512Token" | "Server64" => V4,
        "Avx512ModernToken" => MODERN,
        "Avx512Fp16Token" => FP16,

        // Vector-width traits.
        "Has128BitSimd" => &["sse2"],
        "Has256BitSimd" => &["avx"],
        "Has512BitSimd" => &["avx512f"],

        // AArch64 traits and token names.
        "HasNeon" | "NeonToken" | "Arm64" => &["neon"],
        "HasNeonAes" | "NeonAesToken" => &["neon", "aes"],
        "HasNeonSha3" | "NeonSha3Token" => &["neon", "sha3"],
        "ArmCryptoToken" => &["neon", "aes", "sha2"],
        "ArmCrypto3Token" => &["neon", "aes", "sha2", "sha3"],

        _ => return None,
    };
    Some(features)
}
234
/// How a parameter's type refers to a capability token, as classified by
/// `extract_token_type_info`.
enum TokenTypeInfo {
    /// `impl A + B`: the trait names from the bound list, in order.
    ImplTrait(Vec<String>),
    /// A path whose last segment is a known token type, e.g. `X64V3Token`.
    Concrete(String),
    /// Any other path; its last segment may name a generic type parameter (e.g. `T`)
    /// whose bounds are looked up separately.
    Generic(String),
}
244
245/// Extract token type information from a type.
246fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
247    match ty {
248        Type::Path(type_path) => {
249            // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
250            type_path.path.segments.last().map(|seg| {
251                let name = seg.ident.to_string();
252                // Check if it's a known concrete token type
253                if token_to_features(&name).is_some() {
254                    TokenTypeInfo::Concrete(name)
255                } else {
256                    // Might be a generic type parameter like `T`
257                    TokenTypeInfo::Generic(name)
258                }
259            })
260        }
261        Type::Reference(type_ref) => {
262            // Handle &Token or &mut Token
263            extract_token_type_info(&type_ref.elem)
264        }
265        Type::ImplTrait(impl_trait) => {
266            // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
267            let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
268            if traits.is_empty() {
269                None
270            } else {
271                Some(TokenTypeInfo::ImplTrait(traits))
272            }
273        }
274        _ => None,
275    }
276}
277
278/// Extract trait names from type param bounds.
279fn extract_trait_names_from_bounds(
280    bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
281) -> Vec<String> {
282    bounds
283        .iter()
284        .filter_map(|bound| {
285            if let TypeParamBound::Trait(trait_bound) = bound {
286                trait_bound
287                    .path
288                    .segments
289                    .last()
290                    .map(|seg| seg.ident.to_string())
291            } else {
292                None
293            }
294        })
295        .collect()
296}
297
298/// Look up a generic type parameter in the function's generics.
299fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
300    // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
301    for param in &sig.generics.params {
302        if let GenericParam::Type(type_param) = param {
303            if type_param.ident == type_name {
304                let traits = extract_trait_names_from_bounds(&type_param.bounds);
305                if !traits.is_empty() {
306                    return Some(traits);
307                }
308            }
309        }
310    }
311
312    // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
313    if let Some(where_clause) = &sig.generics.where_clause {
314        for predicate in &where_clause.predicates {
315            if let syn::WherePredicate::Type(pred_type) = predicate {
316                if let Type::Path(type_path) = &pred_type.bounded_ty {
317                    if let Some(seg) = type_path.path.segments.last() {
318                        if seg.ident == type_name {
319                            let traits = extract_trait_names_from_bounds(&pred_type.bounds);
320                            if !traits.is_empty() {
321                                return Some(traits);
322                            }
323                        }
324                    }
325                }
326            }
327        }
328    }
329
330    None
331}
332
333/// Convert trait names to features, collecting all features from all traits.
334fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
335    let mut all_features = Vec::new();
336
337    for trait_name in trait_names {
338        if let Some(features) = trait_to_features(trait_name) {
339            for &feature in features {
340                if !all_features.contains(&feature) {
341                    all_features.push(feature);
342                }
343            }
344        }
345    }
346
347    if all_features.is_empty() {
348        None
349    } else {
350        Some(all_features)
351    }
352}
353
354/// Find the first token parameter and return its name and features.
355fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
356    for arg in &sig.inputs {
357        match arg {
358            FnArg::Receiver(_) => {
359                // Self receivers (self, &self, &mut self) are not yet supported.
360                // The macro creates an inner function, and Rust's inner functions
361                // cannot have `self` parameters. Supporting this would require
362                // AST rewriting to replace `self` with a regular parameter.
363                // See the module docs for the workaround.
364                continue;
365            }
366            FnArg::Typed(PatType { pat, ty, .. }) => {
367                if let Some(info) = extract_token_type_info(ty) {
368                    let features = match info {
369                        TokenTypeInfo::Concrete(name) => {
370                            token_to_features(&name).map(|f| f.to_vec())
371                        }
372                        TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
373                        TokenTypeInfo::Generic(type_name) => {
374                            // Look up the generic parameter's bounds
375                            find_generic_bounds(sig, &type_name)
376                                .and_then(|traits| traits_to_features(&traits))
377                        }
378                    };
379
380                    if let Some(features) = features {
381                        // Extract parameter name
382                        if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
383                            return Some((pat_ident.ident.clone(), features));
384                        }
385                    }
386                }
387            }
388        }
389    }
390    None
391}
392
/// How a method takes `self`; decides the type of the generated `_self` parameter.
enum SelfReceiver {
    /// `&mut self`, which becomes `_self: &mut Type`.
    RefMut,
    /// `&self`, which becomes `_self: &Type`.
    Ref,
    /// `self` taken by value, which becomes `_self: Type`.
    Owned,
}
402
403/// Shared implementation for arcane/simd_fn macros.
404fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
405    // Check for self receiver
406    let has_self_receiver = input_fn
407        .sig
408        .inputs
409        .first()
410        .map(|arg| matches!(arg, FnArg::Receiver(_)))
411        .unwrap_or(false);
412
413    // If there's a self receiver, we need _self = Type
414    if has_self_receiver && args.self_type.is_none() {
415        let msg = format!(
416            "{} with self receiver requires `_self = Type` argument.\n\
417             Example: #[{}(_self = MyType)]\n\
418             Use `_self` (not `self`) in the function body to refer to self.",
419            macro_name, macro_name
420        );
421        return syn::Error::new_spanned(&input_fn.sig, msg)
422            .to_compile_error()
423            .into();
424    }
425
426    // Find the token parameter and its features
427    let (_token_ident, features) = match find_token_param(&input_fn.sig) {
428        Some(result) => result,
429        None => {
430            let msg = format!(
431                "{} requires a token parameter. Supported forms:\n\
432                 - Concrete: `token: Avx2Token`\n\
433                 - impl Trait: `token: impl HasAvx2`\n\
434                 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
435                 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasAvx2, ...)`",
436                macro_name, macro_name
437            );
438            return syn::Error::new_spanned(&input_fn.sig, msg)
439                .to_compile_error()
440                .into();
441        }
442    };
443
444    // Build target_feature attributes
445    let target_feature_attrs: Vec<Attribute> = features
446        .iter()
447        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
448        .collect();
449
450    // Extract function components
451    let vis = &input_fn.vis;
452    let sig = &input_fn.sig;
453    let fn_name = &sig.ident;
454    let generics = &sig.generics;
455    let where_clause = &generics.where_clause;
456    let inputs = &sig.inputs;
457    let output = &sig.output;
458    let body = &input_fn.block;
459    let attrs = &input_fn.attrs;
460
461    // Determine self receiver type if present
462    let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
463        FnArg::Receiver(receiver) => {
464            if receiver.reference.is_none() {
465                Some(SelfReceiver::Owned)
466            } else if receiver.mutability.is_some() {
467                Some(SelfReceiver::RefMut)
468            } else {
469                Some(SelfReceiver::Ref)
470            }
471        }
472        _ => None,
473    });
474
475    // Build inner function parameters, transforming self if needed
476    let inner_params: Vec<proc_macro2::TokenStream> = inputs
477        .iter()
478        .map(|arg| match arg {
479            FnArg::Receiver(_) => {
480                // Transform self receiver to _self parameter
481                let self_ty = args.self_type.as_ref().unwrap();
482                match self_receiver_kind.as_ref().unwrap() {
483                    SelfReceiver::Owned => quote!(_self: #self_ty),
484                    SelfReceiver::Ref => quote!(_self: &#self_ty),
485                    SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
486                }
487            }
488            FnArg::Typed(pat_type) => quote!(#pat_type),
489        })
490        .collect();
491
492    // Build inner function call arguments
493    let inner_args: Vec<proc_macro2::TokenStream> = inputs
494        .iter()
495        .filter_map(|arg| match arg {
496            FnArg::Typed(pat_type) => {
497                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
498                    let ident = &pat_ident.ident;
499                    Some(quote!(#ident))
500                } else {
501                    None
502                }
503            }
504            FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
505        })
506        .collect();
507
508    let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
509
510    // Choose inline attribute based on args
511    // Note: #[inline(always)] + #[target_feature] requires nightly with
512    // #![feature(target_feature_inline_always)]
513    let inline_attr: Attribute = if args.inline_always {
514        parse_quote!(#[inline(always)])
515    } else {
516        parse_quote!(#[inline])
517    };
518
519    // Transform output and body to replace Self with concrete type if needed
520    let (inner_output, inner_body): (ReturnType, syn::Block) =
521        if let Some(ref self_ty) = args.self_type {
522            let mut replacer = ReplaceSelf {
523                replacement: self_ty,
524            };
525            let transformed_output = replacer.fold_return_type(output.clone());
526            let transformed_body = replacer.fold_block((**body).clone());
527            (transformed_output, transformed_body)
528        } else {
529            (output.clone(), (**body).clone())
530        };
531
532    // Generate the expanded function
533    let expanded = quote! {
534        #(#attrs)*
535        #vis #sig {
536            #(#target_feature_attrs)*
537            #inline_attr
538            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
539            #inner_body
540
541            // SAFETY: The token parameter proves the required CPU features are available.
542            // Tokens can only be constructed when features are verified (via try_new()
543            // runtime check or forge_token_dangerously() in a context where features are guaranteed).
544            unsafe { #inner_fn_name(#(#inner_args),*) }
545        }
546    };
547
548    expanded.into()
549}
550
551/// Mark a function as an arcane SIMD function.
552///
553/// This macro enables safe use of SIMD intrinsics by generating an inner function
554/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
555/// the token parameter type. The outer function calls the inner function unsafely,
556/// which is justified because the token parameter proves the features are available.
557///
558/// **The token is passed through to the inner function**, so you can call other
559/// token-taking functions from inside `#[arcane]`.
560///
561/// # Token Parameter Forms
562///
563/// The macro supports four forms of token parameters:
564///
565/// ## Concrete Token Types
566///
567/// ```ignore
568/// #[arcane]
569/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
570///     // AVX2 intrinsics safe here
571/// }
572/// ```
573///
574/// ## impl Trait Bounds
575///
576/// ```ignore
577/// #[arcane]
578/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
579///     // Accepts any token that provides AVX2
580/// }
581/// ```
582///
583/// ## Generic Type Parameters
584///
585/// ```ignore
586/// #[arcane]
587/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
588///     // Generic over any AVX2-capable token
589/// }
590///
591/// // Also works with where clauses:
592/// #[arcane]
593/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
594/// where
595///     T: HasAvx2
596/// {
597///     // ...
598/// }
599/// ```
600///
601/// ## Methods with Self Receivers
602///
603/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
604/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
605///
606/// ```ignore
607/// use archmage::{HasAvx2, arcane};
608/// use wide::f32x8;
609///
610/// trait Avx2Ops {
611///     fn double(&self, token: impl HasAvx2) -> Self;
612///     fn square(self, token: impl HasAvx2) -> Self;
613///     fn scale(&mut self, token: impl HasAvx2, factor: f32);
614/// }
615///
616/// impl Avx2Ops for f32x8 {
617///     #[arcane(_self = f32x8)]
618///     fn double(&self, _token: impl HasAvx2) -> Self {
619///         // Use _self instead of self in the body
620///         *_self + *_self
621///     }
622///
623///     #[arcane(_self = f32x8)]
624///     fn square(self, _token: impl HasAvx2) -> Self {
625///         _self * _self
626///     }
627///
628///     #[arcane(_self = f32x8)]
629///     fn scale(&mut self, _token: impl HasAvx2, factor: f32) {
630///         *_self = *_self * f32x8::splat(factor);
631///     }
632/// }
633/// ```
634///
635/// **Why `_self`?** The macro generates an inner function where `self` becomes
636/// a regular parameter named `_self`. Using `_self` in your code reminds you
637/// that you're not using the normal `self` keyword.
638///
639/// **All receiver types are supported:**
640/// - `self` (by value/move) → `_self: Type`
641/// - `&self` (shared reference) → `_self: &Type`
642/// - `&mut self` (mutable reference) → `_self: &mut Type`
643///
644/// # Multiple Trait Bounds
645///
646/// When using `impl Trait` or generic bounds with multiple traits,
647/// all required features are enabled:
648///
649/// ```ignore
650/// #[arcane]
651/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
652///     // Both AVX2 and FMA intrinsics are safe here
653/// }
654/// ```
655///
656/// # Expansion
657///
658/// The macro expands to approximately:
659///
660/// ```ignore
661/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
662///     #[target_feature(enable = "avx2")]
663///     #[inline]
664///     unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
665///         let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
666///         let doubled = _mm256_add_ps(v, v);
667///         let mut out = [0.0f32; 8];
668///         unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
669///         out
670///     }
671///     // SAFETY: Token proves the required features are available
672///     unsafe { __simd_inner_process(token, data) }
673/// }
674/// ```
675///
676/// # Profile Tokens
677///
678/// Profile tokens automatically enable all required features:
679///
680/// ```ignore
681/// #[arcane]
682/// fn kernel(token: X64V3Token, data: &mut [f32]) {
683///     // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
684/// }
685/// ```
686///
687/// # Supported Tokens
688///
689/// - **x86_64**: `Sse2Token`, `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
690///   `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
691/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token`, `X64V4Token`
692/// - **ARM**: `NeonToken`, `SveToken`, `Sve2Token`
693/// - **WASM**: `Simd128Token`
694///
695/// # Supported Trait Bounds
696///
697/// - **x86_64**: `HasSse`, `HasSse2`, `HasSse41`, `HasSse42`, `HasAvx`, `HasAvx2`,
698///   `HasAvx512f`, `HasAvx512vl`, `HasAvx512bw`, `HasAvx512vbmi2`, `HasFma`
699/// - **ARM**: `HasNeon`, `HasSve`, `HasSve2`
700/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
701///
702/// # Options
703///
704/// ## `inline_always`
705///
706/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
707/// This can improve performance by ensuring aggressive inlining, but requires
708/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
709/// the crate using the macro.
710///
711/// ```ignore
712/// #![feature(target_feature_inline_always)]
713///
714/// #[arcane(inline_always)]
715/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
716///     // Inner function will use #[inline(always)]
717/// }
718/// ```
719#[proc_macro_attribute]
720pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
721    let args = parse_macro_input!(attr as ArcaneArgs);
722    let input_fn = parse_macro_input!(item as ItemFn);
723    arcane_impl(input_fn, "arcane", args)
724}
725
726/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
727///
728/// See [`arcane`] for full documentation.
729#[proc_macro_attribute]
730pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
731    let args = parse_macro_input!(attr as ArcaneArgs);
732    let input_fn = parse_macro_input!(item as ItemFn);
733    arcane_impl(input_fn, "simd_fn", args)
734}
735
736// ============================================================================
737// Multiwidth macro for width-agnostic SIMD code
738// ============================================================================
739
740use syn::ItemMod;
741
742/// Arguments to the `#[multiwidth]` macro.
743struct MultiwidthArgs {
744    /// Include SSE (128-bit) specialization
745    sse: bool,
746    /// Include AVX2 (256-bit) specialization
747    avx2: bool,
748    /// Include AVX-512 (512-bit) specialization
749    avx512: bool,
750    /// Include WASM SIMD128 (128-bit) specialization
751    wasm: bool,
752    /// Include NEON (128-bit ARM) specialization
753    neon: bool,
754}
755
756impl Default for MultiwidthArgs {
757    fn default() -> Self {
758        Self {
759            sse: true,
760            avx2: true,
761            avx512: true,
762            wasm: true,
763            neon: true,
764        }
765    }
766}
767
768impl Parse for MultiwidthArgs {
769    fn parse(input: ParseStream) -> syn::Result<Self> {
770        let mut args = MultiwidthArgs {
771            sse: false,
772            avx2: false,
773            avx512: false,
774            wasm: false,
775            neon: false,
776        };
777
778        // If no args provided, enable all
779        if input.is_empty() {
780            return Ok(MultiwidthArgs::default());
781        }
782
783        while !input.is_empty() {
784            let ident: Ident = input.parse()?;
785            match ident.to_string().as_str() {
786                "sse" => args.sse = true,
787                "avx2" => args.avx2 = true,
788                "avx512" => args.avx512 = true,
789                "wasm" | "simd128" => args.wasm = true,
790                "neon" | "arm" => args.neon = true,
791                other => {
792                    return Err(syn::Error::new(
793                        ident.span(),
794                        format!(
795                        "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512, wasm, neon",
796                        other
797                    ),
798                    ))
799                }
800            }
801            // Consume optional comma
802            if input.peek(Token![,]) {
803                let _: Token![,] = input.parse()?;
804            }
805        }
806
807        Ok(args)
808    }
809}
810
/// One width specialization emitted by `#[multiwidth]`.
struct WidthConfig {
    /// Name of the generated sub-module (e.g. "sse", "avx2", "avx512").
    name: &'static str,
    /// Path of the width-specific SIMD namespace imported by that sub-module.
    namespace: &'static str,
    /// Fully qualified token type used for this width.
    token: &'static str,
    /// Cargo feature that gates this width, if any.
    feature: Option<&'static str>,
    /// Target features enabled for code of this width.
    target_features: &'static [&'static str],
}

/// Target features shared by the x86 `sse` and `avx2` widths. Both use `X64V3Token`;
/// only the namespace (vector width) differs.
const X86_V3_TARGET_FEATURES: &[&str] = &["avx2", "fma", "bmi1", "bmi2", "f16c", "lzcnt"];

/// Width configurations for x86_64 targets, narrowest first.
const X86_WIDTH_CONFIGS: &[WidthConfig] = &[
    WidthConfig {
        name: "sse",
        token: "archmage::X64V3Token",
        namespace: "magetypes::simd::sse",
        target_features: X86_V3_TARGET_FEATURES,
        feature: None,
    },
    WidthConfig {
        name: "avx2",
        token: "archmage::X64V3Token",
        namespace: "magetypes::simd::avx2",
        target_features: X86_V3_TARGET_FEATURES,
        feature: None,
    },
    WidthConfig {
        name: "avx512",
        token: "archmage::X64V4Token",
        namespace: "magetypes::simd::avx512",
        target_features: &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"],
        feature: Some("avx512"),
    },
];

/// Width configuration for wasm32 targets.
const WASM_WIDTH_CONFIGS: &[WidthConfig] = &[WidthConfig {
    name: "simd128",
    token: "archmage::Simd128Token",
    namespace: "magetypes::simd::simd128",
    target_features: &["simd128"],
    feature: None,
}];

/// Width configuration for aarch64 targets.
const ARM_WIDTH_CONFIGS: &[WidthConfig] = &[WidthConfig {
    name: "neon",
    token: "archmage::NeonToken",
    namespace: "magetypes::simd::neon",
    target_features: &["neon"],
    feature: None,
}];
867
868/// Generate width-specialized SIMD code.
869///
870/// This macro takes a module containing width-agnostic SIMD code and generates
871/// specialized versions for each target width (SSE, AVX2, AVX-512).
872///
873/// # Usage
874///
875/// ```ignore
876/// use archmage::multiwidth;
877///
878/// #[multiwidth]
879/// mod kernels {
880///     // Inside this module, these types are available:
881///     // - f32xN, i32xN, etc. (width-appropriate SIMD types)
882///     // - Token (the token type: X64V3Token for SSE/AVX2, or X64V4Token for AVX-512)
883///     // - LANES_F32, LANES_32, etc. (lane count constants)
884///
885///     use archmage::simd::*;
886///
887///     pub fn normalize(token: Token, data: &mut [f32]) {
888///         for chunk in data.chunks_exact_mut(LANES_F32) {
889///             let v = f32xN::load(token, chunk.try_into().unwrap());
890///             let result = v * f32xN::splat(token, 1.0 / 255.0);
891///             result.store(chunk.try_into().unwrap());
892///         }
893///     }
894/// }
895///
896/// // Generated modules:
897/// // - kernels::sse::normalize(token: X64V3Token, data: &mut [f32])
898/// // - kernels::avx2::normalize(token: X64V3Token, data: &mut [f32])
899/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32])  // if avx512 feature
900/// // - kernels::normalize(data: &mut [f32])  // runtime dispatcher
901/// ```
902///
903/// # Selective Targets
904///
905/// You can specify which targets to generate:
906///
907/// ```ignore
908/// #[multiwidth(avx2, avx512)]  // Only AVX2 and AVX-512, no SSE
909/// mod fast_kernels { ... }
910/// ```
911///
912/// # How It Works
913///
914/// 1. The macro duplicates the module content for each width target
915/// 2. Each copy imports from the appropriate namespace (`archmage::simd::sse`, etc.)
916/// 3. The `use archmage::simd::*` statement is rewritten to the width-specific import
917/// 4. A dispatcher function is generated that picks the best available at runtime
918///
919/// # Requirements
920///
921/// - Functions should use `Token` as their token parameter type
922/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
923/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
924#[proc_macro_attribute]
925pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
926    let args = parse_macro_input!(attr as MultiwidthArgs);
927    let input_mod = parse_macro_input!(item as ItemMod);
928
929    multiwidth_impl(input_mod, args)
930}
931
932/// Configuration with target arch for conditional compilation
933struct ArchConfig<'a> {
934    config: &'a WidthConfig,
935    target_arch: Option<&'static str>,
936}
937
938fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
939    let mod_name = &input_mod.ident;
940    let mod_vis = &input_mod.vis;
941    let mod_attrs = &input_mod.attrs;
942
943    // Get module content
944    let content = match &input_mod.content {
945        Some((_, items)) => items,
946        None => {
947            return syn::Error::new_spanned(
948                &input_mod,
949                "multiwidth requires an inline module (mod name { ... }), not a file module",
950            )
951            .to_compile_error()
952            .into();
953        }
954    };
955
956    // Build list of all enabled configs across architectures
957    let mut all_configs: Vec<ArchConfig> = Vec::new();
958
959    // x86_64 configs
960    for config in X86_WIDTH_CONFIGS {
961        let enabled = match config.name {
962            "sse" => args.sse,
963            "avx2" => args.avx2,
964            "avx512" => args.avx512,
965            _ => false,
966        };
967        if enabled {
968            all_configs.push(ArchConfig {
969                config,
970                target_arch: Some("x86_64"),
971            });
972        }
973    }
974
975    // WASM configs
976    if args.wasm {
977        for config in WASM_WIDTH_CONFIGS {
978            all_configs.push(ArchConfig {
979                config,
980                target_arch: Some("wasm32"),
981            });
982        }
983    }
984
985    // ARM configs
986    if args.neon {
987        for config in ARM_WIDTH_CONFIGS {
988            all_configs.push(ArchConfig {
989                config,
990                target_arch: Some("aarch64"),
991            });
992        }
993    }
994
995    // Build specialized modules
996    let mut specialized_mods = Vec::new();
997    let mut enabled_configs = Vec::new();
998
999    for arch_config in &all_configs {
1000        let config = arch_config.config;
1001        enabled_configs.push(config);
1002
1003        let width_mod_name = format_ident!("{}", config.name);
1004        let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
1005
1006        // Transform the content: replace `use archmage::simd::*` with width-specific import
1007        // and add target_feature attributes for optimization
1008        let transformed_items: Vec<syn::Item> = content
1009            .iter()
1010            .map(|item| transform_item_for_width(item.clone(), &namespace, config))
1011            .collect();
1012
1013        // Build cfg attributes for target arch and optional feature
1014        let arch_attr = arch_config
1015            .target_arch
1016            .map(|arch| quote!(#[cfg(target_arch = #arch)]));
1017
1018        let feature_attr = config.feature.map(|f| {
1019            let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1020            quote!(#[cfg(feature = #f_lit)])
1021        });
1022
1023        specialized_mods.push(quote! {
1024            #arch_attr
1025            #feature_attr
1026            pub mod #width_mod_name {
1027                #(#transformed_items)*
1028            }
1029        });
1030    }
1031
1032    // Generate dispatcher functions for each public function in the module
1033    // The dispatcher is x86_64-specific (runtime feature detection)
1034    // For WASM and ARM, features are compile-time only
1035    let x86_configs: Vec<_> = all_configs
1036        .iter()
1037        .filter(|c| c.target_arch == Some("x86_64"))
1038        .map(|c| c.config)
1039        .collect();
1040
1041    // Only generate dispatcher section if we have x86 configs
1042    let dispatcher_section = if !x86_configs.is_empty() {
1043        let dispatchers = generate_dispatchers(content, &x86_configs);
1044        quote! {
1045            // Runtime dispatcher (x86_64 only - WASM/ARM use compile-time features)
1046            #[cfg(target_arch = "x86_64")]
1047            mod __dispatchers {
1048                use super::*;
1049                #dispatchers
1050            }
1051            #[cfg(target_arch = "x86_64")]
1052            pub use __dispatchers::*;
1053        }
1054    } else {
1055        quote! {}
1056    };
1057
1058    let expanded = quote! {
1059        #(#mod_attrs)*
1060        #mod_vis mod #mod_name {
1061            #(#specialized_mods)*
1062
1063            #dispatcher_section
1064        }
1065    };
1066
1067    expanded.into()
1068}
1069
1070/// Transform a single item for a specific width namespace.
1071fn transform_item_for_width(
1072    item: syn::Item,
1073    namespace: &syn::Path,
1074    config: &WidthConfig,
1075) -> syn::Item {
1076    match item {
1077        syn::Item::Use(mut use_item) => {
1078            // Check if this is `use archmage::simd::*` or similar
1079            if is_simd_wildcard_use(&use_item) {
1080                // Replace with width-specific import
1081                use_item.tree = syn::UseTree::Path(syn::UsePath {
1082                    ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
1083                    colon2_token: Default::default(),
1084                    tree: Box::new(build_use_tree_from_path(namespace, 1)),
1085                });
1086            }
1087            syn::Item::Use(use_item)
1088        }
1089        syn::Item::Fn(func) => {
1090            // Transform function to use inner function pattern with target_feature
1091            // This is the same pattern as #[arcane], enabling SIMD optimization
1092            // without requiring -C target-cpu=native
1093            transform_fn_with_target_feature(func, config)
1094        }
1095        other => other,
1096    }
1097}
1098
1099/// Transform a function to use the inner function pattern with target_feature.
1100/// This generates:
1101/// ```ignore
1102/// pub fn example(token: Token, data: &[f32]) -> f32 {
1103///     #[target_feature(enable = "avx2", enable = "fma")]
1104///     #[inline]
1105///     unsafe fn inner(token: Token, data: &[f32]) -> f32 {
1106///         // original body
1107///     }
1108///     // SAFETY: Token proves CPU support
1109///     unsafe { inner(token, data) }
1110/// }
1111/// ```
1112fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
1113    let vis = &func.vis;
1114    let sig = &func.sig;
1115    let fn_name = &sig.ident;
1116    let generics = &sig.generics;
1117    let where_clause = &generics.where_clause;
1118    let inputs = &sig.inputs;
1119    let output = &sig.output;
1120    let body = &func.block;
1121    let attrs = &func.attrs;
1122
1123    // Build target_feature attributes
1124    let target_feature_attrs: Vec<syn::Attribute> = config
1125        .target_features
1126        .iter()
1127        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
1128        .collect();
1129
1130    // Build parameter list for inner function
1131    let inner_params: Vec<proc_macro2::TokenStream> =
1132        inputs.iter().map(|arg| quote!(#arg)).collect();
1133
1134    // Build argument list for calling inner function
1135    let call_args: Vec<proc_macro2::TokenStream> = inputs
1136        .iter()
1137        .filter_map(|arg| match arg {
1138            syn::FnArg::Typed(pat_type) => {
1139                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1140                    let ident = &pat_ident.ident;
1141                    Some(quote!(#ident))
1142                } else {
1143                    None
1144                }
1145            }
1146            syn::FnArg::Receiver(_) => Some(quote!(self)),
1147        })
1148        .collect();
1149
1150    let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
1151
1152    let expanded = quote! {
1153        #(#attrs)*
1154        #vis #sig {
1155            #(#target_feature_attrs)*
1156            #[inline]
1157            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
1158            #body
1159
1160            // SAFETY: The Token parameter proves the required CPU features are available.
1161            // Tokens can only be constructed via try_new() which checks CPU support.
1162            unsafe { #inner_fn_name(#(#call_args),*) }
1163        }
1164    };
1165
1166    syn::parse2(expanded).expect("Failed to parse transformed function")
1167}
1168
1169/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
1170fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
1171    fn check_tree(tree: &syn::UseTree) -> bool {
1172        match tree {
1173            syn::UseTree::Path(path) => {
1174                let ident = path.ident.to_string();
1175                if ident == "archmage" || ident == "magetypes" || ident == "crate" {
1176                    check_tree_for_simd(&path.tree)
1177                } else {
1178                    false
1179                }
1180            }
1181            _ => false,
1182        }
1183    }
1184
1185    fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
1186        match tree {
1187            syn::UseTree::Path(path) => {
1188                if path.ident == "simd" {
1189                    matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
1190                } else {
1191                    check_tree_for_simd(&path.tree)
1192                }
1193            }
1194            _ => false,
1195        }
1196    }
1197
1198    check_tree(&use_item.tree)
1199}
1200
1201/// Build a UseTree from a path, starting at a given segment index.
1202fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
1203    let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
1204
1205    if segments.is_empty() {
1206        syn::UseTree::Glob(syn::UseGlob {
1207            star_token: Default::default(),
1208        })
1209    } else if segments.len() == 1 {
1210        syn::UseTree::Path(syn::UsePath {
1211            ident: segments[0].ident.clone(),
1212            colon2_token: Default::default(),
1213            tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1214                star_token: Default::default(),
1215            })),
1216        })
1217    } else {
1218        let first = &segments[0];
1219        let rest_path = syn::Path {
1220            leading_colon: None,
1221            segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1222        };
1223        syn::UseTree::Path(syn::UsePath {
1224            ident: first.ident.clone(),
1225            colon2_token: Default::default(),
1226            tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1227        })
1228    }
1229}
1230
/// Names that only make sense inside one width-specialized module: the vector
/// aliases and the `Token` alias. They cannot appear in a dispatcher signature.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    "f32xN", "f64xN", "i8xN", "i16xN", "i32xN", "i64xN", "u8xN", "u16xN", "u32xN", "u64xN", "Token",
];

/// Returns `true` when the stringified type `ty_str` contains any entry of
/// `WIDTH_SPECIFIC_TYPES` as a substring.
fn contains_width_specific_type(ty_str: &str) -> bool {
    for name in WIDTH_SPECIFIC_TYPES {
        if ty_str.contains(name) {
            return true;
        }
    }
    false
}
1240
1241/// Check if a function signature uses width-specific types (can't have a dispatcher).
1242fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1243    // Check return type
1244    if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1245        let ty_str = quote!(#ty).to_string();
1246        if contains_width_specific_type(&ty_str) {
1247            return true;
1248        }
1249    }
1250
1251    // Check parameters (excluding Token which we filter out anyway)
1252    for arg in &func.sig.inputs {
1253        if let syn::FnArg::Typed(pat_type) = arg {
1254            let ty = &pat_type.ty;
1255            let ty_str = quote!(#ty).to_string();
1256            // Skip Token parameters - they're filtered out for dispatchers
1257            if ty_str.contains("Token") {
1258                continue;
1259            }
1260            if contains_width_specific_type(&ty_str) {
1261                return true;
1262            }
1263        }
1264    }
1265
1266    false
1267}
1268
1269/// Generate runtime dispatcher functions for public functions.
1270///
1271/// Note: Dispatchers are only generated for functions that don't use width-specific
1272/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1273/// width-specific types can only be called via the width-specific submodules.
1274fn generate_dispatchers(
1275    content: &[syn::Item],
1276    configs: &[&WidthConfig],
1277) -> proc_macro2::TokenStream {
1278    let mut dispatchers = Vec::new();
1279
1280    for item in content {
1281        if let syn::Item::Fn(func) = item {
1282            // Only generate dispatchers for public functions
1283            if !matches!(func.vis, syn::Visibility::Public(_)) {
1284                continue;
1285            }
1286
1287            // Skip functions that use width-specific types - they can't have dispatchers
1288            if uses_width_specific_types(func) {
1289                continue;
1290            }
1291
1292            let fn_name = &func.sig.ident;
1293            let fn_generics = &func.sig.generics;
1294            let fn_output = &func.sig.output;
1295            let fn_attrs: Vec<_> = func
1296                .attrs
1297                .iter()
1298                .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1299                .collect();
1300
1301            // Filter out the token parameter from the dispatcher signature
1302            let non_token_params: Vec<_> = func
1303                .sig
1304                .inputs
1305                .iter()
1306                .filter(|arg| {
1307                    match arg {
1308                        syn::FnArg::Typed(pat_type) => {
1309                            // Check if type contains "Token"
1310                            let ty = &pat_type.ty;
1311                            let ty_str = quote!(#ty).to_string();
1312                            !ty_str.contains("Token")
1313                        }
1314                        _ => true,
1315                    }
1316                })
1317                .collect();
1318
1319            // Extract just the parameter names for passing to specialized functions
1320            let param_names: Vec<_> = non_token_params
1321                .iter()
1322                .filter_map(|arg| match arg {
1323                    syn::FnArg::Typed(pat_type) => {
1324                        if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1325                            Some(&pat_ident.ident)
1326                        } else {
1327                            None
1328                        }
1329                    }
1330                    _ => None,
1331                })
1332                .collect();
1333
1334            // Generate dispatch branches (highest capability first)
1335            let mut branches = Vec::new();
1336
1337            for config in configs.iter().rev() {
1338                let mod_name = format_ident!("{}", config.name);
1339                let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1340
1341                let feature_check = config.feature.map(|f| {
1342                    let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1343                    quote!(#[cfg(feature = #f_lit)])
1344                });
1345
1346                branches.push(quote! {
1347                    #feature_check
1348                    if let Some(token) = #token_path::try_new() {
1349                        return #mod_name::#fn_name(token, #(#param_names),*);
1350                    }
1351                });
1352            }
1353
1354            // Generate dispatcher
1355            dispatchers.push(quote! {
1356                #(#fn_attrs)*
1357                /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1358                pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1359                    use archmage::SimdToken;
1360
1361                    #(#branches)*
1362
1363                    // Fallback: panic if no SIMD available
1364                    // TODO: Allow user-provided scalar fallback
1365                    panic!("No SIMD support available for {}", stringify!(#fn_name));
1366                }
1367            });
1368        }
1369    }
1370
1371    quote! { #(#dispatchers)* }
1372}