Skip to main content

archmage_macros/
lib.rs

1//! Proc-macros for archmage SIMD capability tokens.
2//!
3//! Provides `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
4//! safe via token proof.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    fold::Fold,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12    ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
/// A [`Fold`] implementation that replaces `Self` with a concrete type.
///
/// `arcane_impl` uses it when a `_self = Type` argument is supplied, so that
/// code moved into the generated inner function no longer mentions `Self`.
struct ReplaceSelf<'a> {
    /// The concrete type substituted wherever a bare `Self` type is found.
    replacement: &'a Type,
}
19
20impl Fold for ReplaceSelf<'_> {
21    fn fold_type(&mut self, ty: Type) -> Type {
22        match ty {
23            Type::Path(ref type_path) if type_path.qself.is_none() => {
24                // Check if it's just `Self`
25                if type_path.path.is_ident("Self") {
26                    return self.replacement.clone();
27                }
28                // Otherwise continue folding
29                syn::fold::fold_type(self, ty)
30            }
31            _ => syn::fold::fold_type(self, ty),
32        }
33    }
34}
35
/// Arguments to the `#[arcane]` macro.
///
/// Parsed from the attribute's argument list, e.g.
/// `#[arcane(inline_always, _self = MyType)]`. Both fields default to "off".
#[derive(Default)]
struct ArcaneArgs {
    /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
    /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
    inline_always: bool,
    /// The concrete type to use for `self` receiver.
    /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
    /// Set by the `_self = Type` argument; `None` when that argument is absent.
    self_type: Option<Type>,
}
46
47impl Parse for ArcaneArgs {
48    fn parse(input: ParseStream) -> syn::Result<Self> {
49        let mut args = ArcaneArgs::default();
50
51        while !input.is_empty() {
52            let ident: Ident = input.parse()?;
53            match ident.to_string().as_str() {
54                "inline_always" => args.inline_always = true,
55                "_self" => {
56                    let _: Token![=] = input.parse()?;
57                    args.self_type = Some(input.parse()?);
58                }
59                other => {
60                    return Err(syn::Error::new(
61                        ident.span(),
62                        format!("unknown arcane argument: `{}`", other),
63                    ))
64                }
65            }
66            // Consume optional comma
67            if input.peek(Token![,]) {
68                let _: Token![,] = input.parse()?;
69            }
70        }
71
72        Ok(args)
73    }
74}
75
76// Token-to-features and trait-to-features mappings are generated from
77// token-registry.toml by xtask. Regenerate with: cargo run -p xtask -- generate
78mod generated;
79use generated::{token_to_features, trait_to_features};
80
/// Result of extracting token info from a type.
///
/// Produced by `extract_token_type_info` and consumed by `find_token_param`,
/// which maps each variant to a list of target features.
enum TokenTypeInfo {
    /// Concrete token type (e.g., `Avx2Token`); the name is known to `token_to_features`.
    Concrete(String),
    /// impl Trait with the trait names (e.g., `impl Has256BitSimd`)
    ImplTrait(Vec<String>),
    /// Generic type parameter name (e.g., `T`); also used for any path type
    /// whose last segment is not a known token name.
    Generic(String),
}
90
91/// Extract token type information from a type.
92fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
93    match ty {
94        Type::Path(type_path) => {
95            // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
96            type_path.path.segments.last().map(|seg| {
97                let name = seg.ident.to_string();
98                // Check if it's a known concrete token type
99                if token_to_features(&name).is_some() {
100                    TokenTypeInfo::Concrete(name)
101                } else {
102                    // Might be a generic type parameter like `T`
103                    TokenTypeInfo::Generic(name)
104                }
105            })
106        }
107        Type::Reference(type_ref) => {
108            // Handle &Token or &mut Token
109            extract_token_type_info(&type_ref.elem)
110        }
111        Type::ImplTrait(impl_trait) => {
112            // Handle `impl Has256BitSimd` or `impl HasX64V2 + HasNeon`
113            let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
114            if traits.is_empty() {
115                None
116            } else {
117                Some(TokenTypeInfo::ImplTrait(traits))
118            }
119        }
120        _ => None,
121    }
122}
123
124/// Extract trait names from type param bounds.
125fn extract_trait_names_from_bounds(
126    bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
127) -> Vec<String> {
128    bounds
129        .iter()
130        .filter_map(|bound| {
131            if let TypeParamBound::Trait(trait_bound) = bound {
132                trait_bound
133                    .path
134                    .segments
135                    .last()
136                    .map(|seg| seg.ident.to_string())
137            } else {
138                None
139            }
140        })
141        .collect()
142}
143
144/// Look up a generic type parameter in the function's generics.
145fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
146    // Check inline bounds first (e.g., `fn foo<T: HasX64V2>(token: T)`)
147    for param in &sig.generics.params {
148        if let GenericParam::Type(type_param) = param {
149            if type_param.ident == type_name {
150                let traits = extract_trait_names_from_bounds(&type_param.bounds);
151                if !traits.is_empty() {
152                    return Some(traits);
153                }
154            }
155        }
156    }
157
158    // Check where clause (e.g., `fn foo<T>(token: T) where T: HasX64V2`)
159    if let Some(where_clause) = &sig.generics.where_clause {
160        for predicate in &where_clause.predicates {
161            if let syn::WherePredicate::Type(pred_type) = predicate {
162                if let Type::Path(type_path) = &pred_type.bounded_ty {
163                    if let Some(seg) = type_path.path.segments.last() {
164                        if seg.ident == type_name {
165                            let traits = extract_trait_names_from_bounds(&pred_type.bounds);
166                            if !traits.is_empty() {
167                                return Some(traits);
168                            }
169                        }
170                    }
171                }
172            }
173        }
174    }
175
176    None
177}
178
179/// Convert trait names to features, collecting all features from all traits.
180fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
181    let mut all_features = Vec::new();
182
183    for trait_name in trait_names {
184        if let Some(features) = trait_to_features(trait_name) {
185            for &feature in features {
186                if !all_features.contains(&feature) {
187                    all_features.push(feature);
188                }
189            }
190        }
191    }
192
193    if all_features.is_empty() {
194        None
195    } else {
196        Some(all_features)
197    }
198}
199
200/// Find the first token parameter and return its name and features.
201fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
202    for arg in &sig.inputs {
203        match arg {
204            FnArg::Receiver(_) => {
205                // Self receivers (self, &self, &mut self) are not yet supported.
206                // The macro creates an inner function, and Rust's inner functions
207                // cannot have `self` parameters. Supporting this would require
208                // AST rewriting to replace `self` with a regular parameter.
209                // See the module docs for the workaround.
210                continue;
211            }
212            FnArg::Typed(PatType { pat, ty, .. }) => {
213                if let Some(info) = extract_token_type_info(ty) {
214                    let features = match info {
215                        TokenTypeInfo::Concrete(name) => {
216                            token_to_features(&name).map(|f| f.to_vec())
217                        }
218                        TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
219                        TokenTypeInfo::Generic(type_name) => {
220                            // Look up the generic parameter's bounds
221                            find_generic_bounds(sig, &type_name)
222                                .and_then(|traits| traits_to_features(&traits))
223                        }
224                    };
225
226                    if let Some(features) = features {
227                        // Extract parameter name
228                        if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
229                            return Some((pat_ident.ident.clone(), features));
230                        }
231                    }
232                }
233            }
234        }
235    }
236    None
237}
238
/// Represents the kind of self receiver on the annotated method.
///
/// `arcane_impl` uses it to pick the type of the `_self` parameter of the
/// generated inner function.
enum SelfReceiver {
    /// `self` (by value/move)
    Owned,
    /// `&self` (shared reference)
    Ref,
    /// `&mut self` (mutable reference)
    RefMut,
}
248
249/// Shared implementation for arcane/simd_fn macros.
250fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
251    // Check for self receiver
252    let has_self_receiver = input_fn
253        .sig
254        .inputs
255        .first()
256        .map(|arg| matches!(arg, FnArg::Receiver(_)))
257        .unwrap_or(false);
258
259    // If there's a self receiver, we need _self = Type
260    if has_self_receiver && args.self_type.is_none() {
261        let msg = format!(
262            "{} with self receiver requires `_self = Type` argument.\n\
263             Example: #[{}(_self = MyType)]\n\
264             Use `_self` (not `self`) in the function body to refer to self.",
265            macro_name, macro_name
266        );
267        return syn::Error::new_spanned(&input_fn.sig, msg)
268            .to_compile_error()
269            .into();
270    }
271
272    // Find the token parameter and its features
273    let (_token_ident, features) = match find_token_param(&input_fn.sig) {
274        Some(result) => result,
275        None => {
276            let msg = format!(
277                "{} requires a token parameter. Supported forms:\n\
278                 - Concrete: `token: X64V3Token`\n\
279                 - impl Trait: `token: impl Has256BitSimd`\n\
280                 - Generic: `fn foo<T: HasX64V2>(token: T, ...)`\n\
281                 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasNeon, ...)`",
282                macro_name, macro_name
283            );
284            return syn::Error::new_spanned(&input_fn.sig, msg)
285                .to_compile_error()
286                .into();
287        }
288    };
289
290    // Build target_feature attributes
291    let target_feature_attrs: Vec<Attribute> = features
292        .iter()
293        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
294        .collect();
295
296    // Extract function components
297    let vis = &input_fn.vis;
298    let sig = &input_fn.sig;
299    let fn_name = &sig.ident;
300    let generics = &sig.generics;
301    let where_clause = &generics.where_clause;
302    let inputs = &sig.inputs;
303    let output = &sig.output;
304    let body = &input_fn.block;
305    let attrs = &input_fn.attrs;
306
307    // Determine self receiver type if present
308    let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
309        FnArg::Receiver(receiver) => {
310            if receiver.reference.is_none() {
311                Some(SelfReceiver::Owned)
312            } else if receiver.mutability.is_some() {
313                Some(SelfReceiver::RefMut)
314            } else {
315                Some(SelfReceiver::Ref)
316            }
317        }
318        _ => None,
319    });
320
321    // Build inner function parameters, transforming self if needed
322    let inner_params: Vec<proc_macro2::TokenStream> = inputs
323        .iter()
324        .map(|arg| match arg {
325            FnArg::Receiver(_) => {
326                // Transform self receiver to _self parameter
327                let self_ty = args.self_type.as_ref().unwrap();
328                match self_receiver_kind.as_ref().unwrap() {
329                    SelfReceiver::Owned => quote!(_self: #self_ty),
330                    SelfReceiver::Ref => quote!(_self: &#self_ty),
331                    SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
332                }
333            }
334            FnArg::Typed(pat_type) => quote!(#pat_type),
335        })
336        .collect();
337
338    // Build inner function call arguments
339    let inner_args: Vec<proc_macro2::TokenStream> = inputs
340        .iter()
341        .filter_map(|arg| match arg {
342            FnArg::Typed(pat_type) => {
343                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
344                    let ident = &pat_ident.ident;
345                    Some(quote!(#ident))
346                } else {
347                    None
348                }
349            }
350            FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
351        })
352        .collect();
353
354    let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
355
356    // Choose inline attribute based on args
357    // Note: #[inline(always)] + #[target_feature] requires nightly with
358    // #![feature(target_feature_inline_always)]
359    let inline_attr: Attribute = if args.inline_always {
360        parse_quote!(#[inline(always)])
361    } else {
362        parse_quote!(#[inline])
363    };
364
365    // Transform output and body to replace Self with concrete type if needed
366    let (inner_output, inner_body): (ReturnType, syn::Block) =
367        if let Some(ref self_ty) = args.self_type {
368            let mut replacer = ReplaceSelf {
369                replacement: self_ty,
370            };
371            let transformed_output = replacer.fold_return_type(output.clone());
372            let transformed_body = replacer.fold_block((**body).clone());
373            (transformed_output, transformed_body)
374        } else {
375            (output.clone(), (**body).clone())
376        };
377
378    // Generate the expanded function
379    let expanded = quote! {
380        #(#attrs)*
381        #vis #sig {
382            #(#target_feature_attrs)*
383            #inline_attr
384            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
385            #inner_body
386
387            // SAFETY: The token parameter proves the required CPU features are available.
388            // Tokens can only be constructed when features are verified (via try_new()
389            // runtime check or forge_token_dangerously() in a context where features are guaranteed).
390            unsafe { #inner_fn_name(#(#inner_args),*) }
391        }
392    };
393
394    expanded.into()
395}
396
397/// Mark a function as an arcane SIMD function.
398///
399/// This macro enables safe use of SIMD intrinsics by generating an inner function
400/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
401/// the token parameter type. The outer function calls the inner function unsafely,
402/// which is justified because the token parameter proves the features are available.
403///
404/// **The token is passed through to the inner function**, so you can call other
405/// token-taking functions from inside `#[arcane]`.
406///
407/// # Token Parameter Forms
408///
409/// The macro supports four forms of token parameters:
410///
411/// ## Concrete Token Types
412///
413/// ```ignore
414/// #[arcane]
415/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
416///     // AVX2 intrinsics safe here
417/// }
418/// ```
419///
420/// ## impl Trait Bounds
421///
422/// ```ignore
423/// #[arcane]
424/// fn process(token: impl Has256BitSimd, data: &[f32; 8]) -> [f32; 8] {
425///     // Accepts any token that provides 256-bit SIMD
426/// }
427/// ```
428///
429/// ## Generic Type Parameters
430///
431/// ```ignore
432/// #[arcane]
433/// fn process<T: Has256BitSimd>(token: T, data: &[f32; 8]) -> [f32; 8] {
434///     // Generic over any 256-bit-capable token
435/// }
436///
437/// // Also works with where clauses:
438/// #[arcane]
439/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
440/// where
441///     T: Has256BitSimd
442/// {
443///     // ...
444/// }
445/// ```
446///
447/// ## Methods with Self Receivers
448///
449/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
450/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
451///
452/// ```ignore
453/// use archmage::{Has256BitSimd, arcane};
454/// use wide::f32x8;
455///
456/// trait SimdOps {
457///     fn double(&self, token: impl Has256BitSimd) -> Self;
458///     fn square(self, token: impl Has256BitSimd) -> Self;
459///     fn scale(&mut self, token: impl Has256BitSimd, factor: f32);
460/// }
461///
462/// impl SimdOps for f32x8 {
463///     #[arcane(_self = f32x8)]
464///     fn double(&self, _token: impl Has256BitSimd) -> Self {
465///         // Use _self instead of self in the body
466///         *_self + *_self
467///     }
468///
469///     #[arcane(_self = f32x8)]
470///     fn square(self, _token: impl Has256BitSimd) -> Self {
471///         _self * _self
472///     }
473///
474///     #[arcane(_self = f32x8)]
475///     fn scale(&mut self, _token: impl Has256BitSimd, factor: f32) {
476///         *_self = *_self * f32x8::splat(factor);
477///     }
478/// }
479/// ```
480///
481/// **Why `_self`?** The macro generates an inner function where `self` becomes
482/// a regular parameter named `_self`. Using `_self` in your code reminds you
483/// that you're not using the normal `self` keyword.
484///
485/// **All receiver types are supported:**
486/// - `self` (by value/move) → `_self: Type`
487/// - `&self` (shared reference) → `_self: &Type`
488/// - `&mut self` (mutable reference) → `_self: &mut Type`
489///
490/// # Multiple Trait Bounds
491///
492/// When using `impl Trait` or generic bounds with multiple traits,
493/// all required features are enabled:
494///
495/// ```ignore
496/// #[arcane]
497/// fn fma_kernel(token: impl HasX64V2 + Has256BitSimd, data: &[f32; 8]) -> [f32; 8] {
498///     // Both SSE4.2 and AVX features are enabled here
499/// }
500/// ```
501///
502/// # Expansion
503///
504/// The macro expands to approximately:
505///
506/// ```ignore
507/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
508///     #[target_feature(enable = "avx2")]
509///     #[inline]
510///     unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
511///         let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
512///         let doubled = _mm256_add_ps(v, v);
513///         let mut out = [0.0f32; 8];
514///         unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
515///         out
516///     }
517///     // SAFETY: Token proves the required features are available
518///     unsafe { __simd_inner_process(token, data) }
519/// }
520/// ```
521///
522/// # Profile Tokens
523///
524/// Profile tokens automatically enable all required features:
525///
526/// ```ignore
527/// #[arcane]
528/// fn kernel(token: X64V3Token, data: &mut [f32]) {
529///     // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
530/// }
531/// ```
532///
533/// # Supported Tokens
534///
535/// - **x86_64 tiers**: `X64V2Token`, `X64V3Token` / `Desktop64` / `Avx2FmaToken`,
536///   `X64V4Token` / `Avx512Token` / `Server64`, `Avx512ModernToken`, `Avx512Fp16Token`
537/// - **ARM**: `NeonToken` / `Arm64`, `NeonAesToken`, `NeonSha3Token`, `NeonCrcToken`
538/// - **WASM**: `Simd128Token`
539///
540/// # Supported Trait Bounds
541///
542/// - **x86_64 tiers**: `HasX64V2`, `HasX64V4`
543/// - **x86_64 width**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
544/// - **ARM**: `HasNeon`, `HasNeonAes`, `HasNeonSha3`
545///
546/// Concrete token types also work as trait bounds (e.g., `impl X64V3Token`).
547///
548/// # Options
549///
550/// ## `inline_always`
551///
552/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
553/// This can improve performance by ensuring aggressive inlining, but requires
554/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
555/// the crate using the macro.
556///
557/// ```ignore
558/// #![feature(target_feature_inline_always)]
559///
560/// #[arcane(inline_always)]
561/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
562///     // Inner function will use #[inline(always)]
563/// }
564/// ```
565#[proc_macro_attribute]
566pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
567    let args = parse_macro_input!(attr as ArcaneArgs);
568    let input_fn = parse_macro_input!(item as ItemFn);
569    arcane_impl(input_fn, "arcane", args)
570}
571
572/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
573///
574/// See [`arcane`] for full documentation.
575#[proc_macro_attribute]
576pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
577    let args = parse_macro_input!(attr as ArcaneArgs);
578    let input_fn = parse_macro_input!(item as ItemFn);
579    arcane_impl(input_fn, "simd_fn", args)
580}
581
582// ============================================================================
583// Multiwidth macro for width-agnostic SIMD code
584// ============================================================================
585
586use syn::ItemMod;
587
/// Arguments to the `#[multiwidth]` macro.
///
/// Each flag selects one specialization to generate. With no arguments every
/// flag is enabled (see the `Default` impl); with arguments only the named
/// targets are enabled.
struct MultiwidthArgs {
    /// Include SSE (128-bit) specialization
    sse: bool,
    /// Include AVX2 (256-bit) specialization
    avx2: bool,
    /// Include AVX-512 (512-bit) specialization
    avx512: bool,
    /// Include WASM SIMD128 (128-bit) specialization
    wasm: bool,
    /// Include NEON (128-bit ARM) specialization
    neon: bool,
}
601
602impl Default for MultiwidthArgs {
603    fn default() -> Self {
604        Self {
605            sse: true,
606            avx2: true,
607            avx512: true,
608            wasm: true,
609            neon: true,
610        }
611    }
612}
613
614impl Parse for MultiwidthArgs {
615    fn parse(input: ParseStream) -> syn::Result<Self> {
616        let mut args = MultiwidthArgs {
617            sse: false,
618            avx2: false,
619            avx512: false,
620            wasm: false,
621            neon: false,
622        };
623
624        // If no args provided, enable all
625        if input.is_empty() {
626            return Ok(MultiwidthArgs::default());
627        }
628
629        while !input.is_empty() {
630            let ident: Ident = input.parse()?;
631            match ident.to_string().as_str() {
632                "sse" => args.sse = true,
633                "avx2" => args.avx2 = true,
634                "avx512" => args.avx512 = true,
635                "wasm" | "simd128" => args.wasm = true,
636                "neon" | "arm" => args.neon = true,
637                other => {
638                    return Err(syn::Error::new(
639                        ident.span(),
640                        format!(
641                        "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512, wasm, neon",
642                        other
643                    ),
644                    ))
645                }
646            }
647            // Consume optional comma
648            if input.peek(Token![,]) {
649                let _: Token![,] = input.parse()?;
650            }
651        }
652
653        Ok(args)
654    }
655}
656
657use generated::{WidthConfig, ARM_WIDTH_CONFIGS, WASM_WIDTH_CONFIGS, X86_WIDTH_CONFIGS};
658
659/// Generate width-specialized SIMD code.
660///
/// This macro takes a module containing width-agnostic SIMD code and generates
/// specialized versions for each enabled target (SSE, AVX2, AVX-512, WASM SIMD128, NEON).
663///
664/// # Usage
665///
666/// ```ignore
667/// use archmage::multiwidth;
668///
669/// #[multiwidth]
670/// mod kernels {
671///     // Inside this module, these types are available:
672///     // - f32xN, i32xN, etc. (width-appropriate SIMD types)
673///     // - Token (the token type: X64V3Token for SSE/AVX2, or X64V4Token for AVX-512)
674///     // - LANES_F32, LANES_32, etc. (lane count constants)
675///
676///     use archmage::simd::*;
677///
678///     pub fn normalize(token: Token, data: &mut [f32]) {
679///         for chunk in data.chunks_exact_mut(LANES_F32) {
680///             let v = f32xN::load(token, chunk.try_into().unwrap());
681///             let result = v * f32xN::splat(token, 1.0 / 255.0);
682///             result.store(chunk.try_into().unwrap());
683///         }
684///     }
685/// }
686///
687/// // Generated modules:
688/// // - kernels::sse::normalize(token: X64V3Token, data: &mut [f32])
689/// // - kernels::avx2::normalize(token: X64V3Token, data: &mut [f32])
690/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32])  // if avx512 feature
691/// // - kernels::normalize(data: &mut [f32])  // runtime dispatcher
692/// ```
693///
694/// # Selective Targets
695///
696/// You can specify which targets to generate:
697///
698/// ```ignore
699/// #[multiwidth(avx2, avx512)]  // Only AVX2 and AVX-512, no SSE
700/// mod fast_kernels { ... }
701/// ```
702///
703/// # How It Works
704///
705/// 1. The macro duplicates the module content for each width target
706/// 2. Each copy imports from the appropriate namespace (`archmage::simd::sse`, etc.)
707/// 3. The `use archmage::simd::*` statement is rewritten to the width-specific import
708/// 4. A dispatcher function is generated that picks the best available at runtime
709///
710/// # Requirements
711///
712/// - Functions should use `Token` as their token parameter type
713/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
714/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
715#[proc_macro_attribute]
716pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
717    let args = parse_macro_input!(attr as MultiwidthArgs);
718    let input_mod = parse_macro_input!(item as ItemMod);
719
720    multiwidth_impl(input_mod, args)
721}
722
/// Configuration with target arch for conditional compilation
///
/// Pairs a generated `WidthConfig` with the architecture its specialized
/// module is gated on.
struct ArchConfig<'a> {
    /// The width configuration taken from one of the generated config tables.
    config: &'a WidthConfig,
    /// Value for `#[cfg(target_arch = ...)]` on the generated module;
    /// `None` emits no architecture gate.
    target_arch: Option<&'static str>,
}
728
729fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
730    let mod_name = &input_mod.ident;
731    let mod_vis = &input_mod.vis;
732    let mod_attrs = &input_mod.attrs;
733
734    // Get module content
735    let content = match &input_mod.content {
736        Some((_, items)) => items,
737        None => {
738            return syn::Error::new_spanned(
739                &input_mod,
740                "multiwidth requires an inline module (mod name { ... }), not a file module",
741            )
742            .to_compile_error()
743            .into();
744        }
745    };
746
747    // Build list of all enabled configs across architectures
748    let mut all_configs: Vec<ArchConfig> = Vec::new();
749
750    // x86_64 configs
751    for config in X86_WIDTH_CONFIGS {
752        let enabled = match config.name {
753            "sse" => args.sse,
754            "avx2" => args.avx2,
755            "avx512" => args.avx512,
756            _ => false,
757        };
758        if enabled {
759            all_configs.push(ArchConfig {
760                config,
761                target_arch: Some("x86_64"),
762            });
763        }
764    }
765
766    // WASM configs
767    if args.wasm {
768        for config in WASM_WIDTH_CONFIGS {
769            all_configs.push(ArchConfig {
770                config,
771                target_arch: Some("wasm32"),
772            });
773        }
774    }
775
776    // ARM configs
777    if args.neon {
778        for config in ARM_WIDTH_CONFIGS {
779            all_configs.push(ArchConfig {
780                config,
781                target_arch: Some("aarch64"),
782            });
783        }
784    }
785
786    // Build specialized modules
787    let mut specialized_mods = Vec::new();
788    let mut enabled_configs = Vec::new();
789
790    for arch_config in &all_configs {
791        let config = arch_config.config;
792        enabled_configs.push(config);
793
794        let width_mod_name = format_ident!("{}", config.name);
795        let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
796
797        // Transform the content: replace `use archmage::simd::*` with width-specific import
798        // and add target_feature attributes for optimization
799        let transformed_items: Vec<syn::Item> = content
800            .iter()
801            .map(|item| transform_item_for_width(item.clone(), &namespace, config))
802            .collect();
803
804        // Build cfg attributes for target arch and optional feature
805        let arch_attr = arch_config
806            .target_arch
807            .map(|arch| quote!(#[cfg(target_arch = #arch)]));
808
809        let feature_attr = config.feature.map(|f| {
810            let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
811            quote!(#[cfg(feature = #f_lit)])
812        });
813
814        specialized_mods.push(quote! {
815            #arch_attr
816            #feature_attr
817            pub mod #width_mod_name {
818                #(#transformed_items)*
819            }
820        });
821    }
822
823    // Generate dispatcher functions for each public function in the module
824    // The dispatcher is x86_64-specific (runtime feature detection)
825    // For WASM and ARM, features are compile-time only
826    let x86_configs: Vec<_> = all_configs
827        .iter()
828        .filter(|c| c.target_arch == Some("x86_64"))
829        .map(|c| c.config)
830        .collect();
831
832    // Only generate dispatcher section if we have x86 configs
833    let dispatcher_section = if !x86_configs.is_empty() {
834        let dispatchers = generate_dispatchers(content, &x86_configs);
835        quote! {
836            // Runtime dispatcher (x86_64 only - WASM/ARM use compile-time features)
837            #[cfg(target_arch = "x86_64")]
838            mod __dispatchers {
839                use super::*;
840                #dispatchers
841            }
842            #[cfg(target_arch = "x86_64")]
843            pub use __dispatchers::*;
844        }
845    } else {
846        quote! {}
847    };
848
849    let expanded = quote! {
850        #(#mod_attrs)*
851        #mod_vis mod #mod_name {
852            #(#specialized_mods)*
853
854            #dispatcher_section
855        }
856    };
857
858    expanded.into()
859}
860
861/// Transform a single item for a specific width namespace.
862fn transform_item_for_width(
863    item: syn::Item,
864    namespace: &syn::Path,
865    config: &WidthConfig,
866) -> syn::Item {
867    match item {
868        syn::Item::Use(mut use_item) => {
869            // Check if this is `use archmage::simd::*` or similar
870            if is_simd_wildcard_use(&use_item) {
871                // Replace with width-specific import
872                use_item.tree = syn::UseTree::Path(syn::UsePath {
873                    ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
874                    colon2_token: Default::default(),
875                    tree: Box::new(build_use_tree_from_path(namespace, 1)),
876                });
877            }
878            syn::Item::Use(use_item)
879        }
880        syn::Item::Fn(func) => {
881            // Transform function to use inner function pattern with target_feature
882            // This is the same pattern as #[arcane], enabling SIMD optimization
883            // without requiring -C target-cpu=native
884            transform_fn_with_target_feature(func, config)
885        }
886        other => other,
887    }
888}
889
890/// Transform a function to use the inner function pattern with target_feature.
891/// This generates:
892/// ```ignore
893/// pub fn example(token: Token, data: &[f32]) -> f32 {
894///     #[target_feature(enable = "avx2", enable = "fma")]
895///     #[inline]
896///     unsafe fn inner(token: Token, data: &[f32]) -> f32 {
897///         // original body
898///     }
899///     // SAFETY: Token proves CPU support
900///     unsafe { inner(token, data) }
901/// }
902/// ```
903fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
904    let vis = &func.vis;
905    let sig = &func.sig;
906    let fn_name = &sig.ident;
907    let generics = &sig.generics;
908    let where_clause = &generics.where_clause;
909    let inputs = &sig.inputs;
910    let output = &sig.output;
911    let body = &func.block;
912    let attrs = &func.attrs;
913
914    // Build target_feature attributes
915    let target_feature_attrs: Vec<syn::Attribute> = config
916        .target_features
917        .iter()
918        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
919        .collect();
920
921    // Build parameter list for inner function
922    let inner_params: Vec<proc_macro2::TokenStream> =
923        inputs.iter().map(|arg| quote!(#arg)).collect();
924
925    // Build argument list for calling inner function
926    let call_args: Vec<proc_macro2::TokenStream> = inputs
927        .iter()
928        .filter_map(|arg| match arg {
929            syn::FnArg::Typed(pat_type) => {
930                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
931                    let ident = &pat_ident.ident;
932                    Some(quote!(#ident))
933                } else {
934                    None
935                }
936            }
937            syn::FnArg::Receiver(_) => Some(quote!(self)),
938        })
939        .collect();
940
941    let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
942
943    let expanded = quote! {
944        #(#attrs)*
945        #vis #sig {
946            #(#target_feature_attrs)*
947            #[inline]
948            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
949            #body
950
951            // SAFETY: The Token parameter proves the required CPU features are available.
952            // Tokens can only be constructed via try_new() which checks CPU support.
953            unsafe { #inner_fn_name(#(#call_args),*) }
954        }
955    };
956
957    syn::parse2(expanded).expect("Failed to parse transformed function")
958}
959
960/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
961fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
962    fn check_tree(tree: &syn::UseTree) -> bool {
963        match tree {
964            syn::UseTree::Path(path) => {
965                let ident = path.ident.to_string();
966                if ident == "archmage" || ident == "magetypes" || ident == "crate" {
967                    check_tree_for_simd(&path.tree)
968                } else {
969                    false
970                }
971            }
972            _ => false,
973        }
974    }
975
976    fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
977        match tree {
978            syn::UseTree::Path(path) => {
979                if path.ident == "simd" {
980                    matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
981                } else {
982                    check_tree_for_simd(&path.tree)
983                }
984            }
985            _ => false,
986        }
987    }
988
989    check_tree(&use_item.tree)
990}
991
992/// Build a UseTree from a path, starting at a given segment index.
993fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
994    let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
995
996    if segments.is_empty() {
997        syn::UseTree::Glob(syn::UseGlob {
998            star_token: Default::default(),
999        })
1000    } else if segments.len() == 1 {
1001        syn::UseTree::Path(syn::UsePath {
1002            ident: segments[0].ident.clone(),
1003            colon2_token: Default::default(),
1004            tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1005                star_token: Default::default(),
1006            })),
1007        })
1008    } else {
1009        let first = &segments[0];
1010        let rest_path = syn::Path {
1011            leading_colon: None,
1012            segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1013        };
1014        syn::UseTree::Path(syn::UsePath {
1015            ident: first.ident.clone(),
1016            colon2_token: Default::default(),
1017            tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1018        })
1019    }
1020}
1021
/// Type-name fragments that make a signature unsuitable for a runtime dispatcher.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    "f32xN",
    "f64xN",
    "i8xN",
    "i16xN",
    "i32xN",
    "i64xN",
    "u8xN",
    "u16xN",
    "u32xN",
    "u64xN",
    "Token",
];

/// Report whether the rendered type `ty_str` mentions any width-specific name.
///
/// This is a plain substring search, so a name such as `X64V3Token` matches through
/// the `"Token"` entry.
fn contains_width_specific_type(ty_str: &str) -> bool {
    for needle in WIDTH_SPECIFIC_TYPES {
        if ty_str.contains(*needle) {
            return true;
        }
    }
    false
}
1031
1032/// Check if a function signature uses width-specific types (can't have a dispatcher).
1033fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1034    // Check return type
1035    if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1036        let ty_str = quote!(#ty).to_string();
1037        if contains_width_specific_type(&ty_str) {
1038            return true;
1039        }
1040    }
1041
1042    // Check parameters (excluding Token which we filter out anyway)
1043    for arg in &func.sig.inputs {
1044        if let syn::FnArg::Typed(pat_type) = arg {
1045            let ty = &pat_type.ty;
1046            let ty_str = quote!(#ty).to_string();
1047            // Skip Token parameters - they're filtered out for dispatchers
1048            if ty_str.contains("Token") {
1049                continue;
1050            }
1051            if contains_width_specific_type(&ty_str) {
1052                return true;
1053            }
1054        }
1055    }
1056
1057    false
1058}
1059
1060/// Generate runtime dispatcher functions for public functions.
1061///
1062/// Note: Dispatchers are only generated for functions that don't use width-specific
1063/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1064/// width-specific types can only be called via the width-specific submodules.
1065fn generate_dispatchers(
1066    content: &[syn::Item],
1067    configs: &[&WidthConfig],
1068) -> proc_macro2::TokenStream {
1069    let mut dispatchers = Vec::new();
1070
1071    for item in content {
1072        if let syn::Item::Fn(func) = item {
1073            // Only generate dispatchers for public functions
1074            if !matches!(func.vis, syn::Visibility::Public(_)) {
1075                continue;
1076            }
1077
1078            // Skip functions that use width-specific types - they can't have dispatchers
1079            if uses_width_specific_types(func) {
1080                continue;
1081            }
1082
1083            let fn_name = &func.sig.ident;
1084            let fn_generics = &func.sig.generics;
1085            let fn_output = &func.sig.output;
1086            let fn_attrs: Vec<_> = func
1087                .attrs
1088                .iter()
1089                .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1090                .collect();
1091
1092            // Filter out the token parameter from the dispatcher signature
1093            let non_token_params: Vec<_> = func
1094                .sig
1095                .inputs
1096                .iter()
1097                .filter(|arg| {
1098                    match arg {
1099                        syn::FnArg::Typed(pat_type) => {
1100                            // Check if type contains "Token"
1101                            let ty = &pat_type.ty;
1102                            let ty_str = quote!(#ty).to_string();
1103                            !ty_str.contains("Token")
1104                        }
1105                        _ => true,
1106                    }
1107                })
1108                .collect();
1109
1110            // Extract just the parameter names for passing to specialized functions
1111            let param_names: Vec<_> = non_token_params
1112                .iter()
1113                .filter_map(|arg| match arg {
1114                    syn::FnArg::Typed(pat_type) => {
1115                        if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1116                            Some(&pat_ident.ident)
1117                        } else {
1118                            None
1119                        }
1120                    }
1121                    _ => None,
1122                })
1123                .collect();
1124
1125            // Generate dispatch branches (highest capability first)
1126            let mut branches = Vec::new();
1127
1128            for config in configs.iter().rev() {
1129                let mod_name = format_ident!("{}", config.name);
1130                let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1131
1132                let feature_check = config.feature.map(|f| {
1133                    let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1134                    quote!(#[cfg(feature = #f_lit)])
1135                });
1136
1137                branches.push(quote! {
1138                    #feature_check
1139                    if let Some(token) = #token_path::try_new() {
1140                        return #mod_name::#fn_name(token, #(#param_names),*);
1141                    }
1142                });
1143            }
1144
1145            // Generate dispatcher
1146            dispatchers.push(quote! {
1147                #(#fn_attrs)*
1148                /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1149                pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1150                    use archmage::SimdToken;
1151
1152                    #(#branches)*
1153
1154                    // Fallback: panic if no SIMD available
1155                    // TODO: Allow user-provided scalar fallback
1156                    panic!("No SIMD support available for {}", stringify!(#fn_name));
1157                }
1158            });
1159        }
1160    }
1161
1162    quote! { #(#dispatchers)* }
1163}
1164
1165// =============================================================================
1166// Unit tests for token/trait recognition maps
1167// =============================================================================
1168
#[cfg(test)]
mod tests {
    use super::generated::{ALL_CONCRETE_TOKENS, ALL_TRAIT_NAMES};
    use super::*;

    /// Every concrete token listed in the generated table must resolve to a feature set.
    #[test]
    fn every_concrete_token_is_in_token_to_features() {
        for token_name in ALL_CONCRETE_TOKENS.iter().copied() {
            let features = token_to_features(token_name);
            assert!(
                features.is_some(),
                "Token `{}` exists in runtime crate but is NOT recognized by \
                 token_to_features() in the proc macro. Add it!",
                token_name
            );
        }
    }

    /// Every trait listed in the generated table must resolve to a feature set.
    #[test]
    fn every_trait_is_in_trait_to_features() {
        for trait_name in ALL_TRAIT_NAMES.iter().copied() {
            let features = trait_to_features(trait_name);
            assert!(
                features.is_some(),
                "Trait `{}` exists in runtime crate but is NOT recognized by \
                 trait_to_features() in the proc macro. Add it!",
                trait_name
            );
        }
    }

    /// Alias names must resolve to exactly the same features as the token they alias.
    #[test]
    fn token_aliases_map_to_same_features() {
        // Desktop64 = X64V3Token
        let desktop = token_to_features("Desktop64");
        let x64v3 = token_to_features("X64V3Token");
        assert_eq!(
            desktop, x64v3,
            "Desktop64 and X64V3Token should map to identical features"
        );

        // Server64 = X64V4Token = Avx512Token
        let server = token_to_features("Server64");
        let x64v4 = token_to_features("X64V4Token");
        assert_eq!(
            server, x64v4,
            "Server64 and X64V4Token should map to identical features"
        );
        let avx512 = token_to_features("Avx512Token");
        assert_eq!(
            token_to_features("X64V4Token"),
            avx512,
            "X64V4Token and Avx512Token should map to identical features"
        );

        // Arm64 = NeonToken
        let arm64 = token_to_features("Arm64");
        let neon = token_to_features("NeonToken");
        assert_eq!(
            arm64, neon,
            "Arm64 and NeonToken should map to identical features"
        );
    }

    /// Tier tokens are accepted where a trait bound is expected.
    ///
    /// Rust itself rejects `impl X64V3Token`, but the macro sees the AST before type
    /// checking, so the lookup has to know these names as well.
    #[test]
    fn trait_to_features_includes_tokens_as_bounds() {
        let tier_tokens = [
            "X64V2Token",
            "X64V3Token",
            "Desktop64",
            "Avx2FmaToken",
            "X64V4Token",
            "Avx512Token",
            "Server64",
            "Avx512ModernToken",
            "Avx512Fp16Token",
            "NeonToken",
            "Arm64",
            "NeonAesToken",
            "NeonSha3Token",
            "NeonCrcToken",
        ];

        for token_name in tier_tokens.iter().copied() {
            let features = trait_to_features(token_name);
            assert!(
                features.is_some(),
                "Tier token `{}` should also be recognized in trait_to_features() \
                 for use as a generic bound. Add it!",
                token_name
            );
        }
    }

    /// HasX64V4 must be a strict superset of HasX64V2.
    #[test]
    fn trait_features_are_cumulative() {
        let v2_features = trait_to_features("HasX64V2").unwrap();
        let v4_features = trait_to_features("HasX64V4").unwrap();

        // Every v2 feature has to reappear in v4 ...
        for feature in v2_features.iter() {
            assert!(
                v4_features.contains(feature),
                "HasX64V4 should include v2 feature `{}` but doesn't",
                feature
            );
        }

        // ... and v4 has to add something on top.
        let v2_count = v2_features.len();
        let v4_count = v4_features.len();
        assert!(
            v4_count > v2_count,
            "HasX64V4 should have more features than HasX64V2"
        );
    }

    /// X64V3Token used as a trait bound must carry all v2 features.
    #[test]
    fn x64v3_trait_features_include_v2() {
        let v2_features = trait_to_features("HasX64V2").unwrap();
        let v3_features = trait_to_features("X64V3Token").unwrap();

        for feature in v2_features.iter() {
            assert!(
                v3_features.contains(feature),
                "X64V3Token trait features should include v2 feature `{}` but don't",
                feature
            );
        }
    }

    /// HasNeonAes must carry every HasNeon feature.
    #[test]
    fn has_neon_aes_includes_neon() {
        let neon_features = trait_to_features("HasNeon").unwrap();
        let neon_aes_features = trait_to_features("HasNeonAes").unwrap();

        for feature in neon_features.iter() {
            assert!(
                neon_aes_features.contains(feature),
                "HasNeonAes should include NEON feature `{}`",
                feature
            );
        }
    }

    /// Traits removed in 0.3.0 must no longer be recognized.
    #[test]
    fn no_removed_traits_are_recognized() {
        let removed_traits = [
            "HasSse",
            "HasSse2",
            "HasSse41",
            "HasSse42",
            "HasAvx",
            "HasAvx2",
            "HasFma",
            "HasAvx512f",
            "HasAvx512bw",
            "HasAvx512vl",
            "HasAvx512vbmi2",
            "HasSve",
            "HasSve2",
        ];

        for trait_name in removed_traits.iter().copied() {
            let features = trait_to_features(trait_name);
            assert!(
                features.is_none(),
                "Removed trait `{}` should NOT be in trait_to_features(). \
                 It was removed in 0.3.0 — users should migrate to tier traits.",
                trait_name
            );
        }
    }

    /// Token names that do not exist must not be recognized.
    #[test]
    fn no_nonexistent_tokens_are_recognized() {
        let nonexistent_tokens = [
            "Sse2Token",
            "SveToken",
            "Sve2Token",
            "Avx512VnniToken",
            "X64V4ModernToken",
            "NeonFp16Token",
        ];

        for token_name in nonexistent_tokens.iter().copied() {
            let features = token_to_features(token_name);
            assert!(
                features.is_none(),
                "Non-existent token `{}` should NOT be in token_to_features()",
                token_name
            );
        }
    }
}