archmage_macros/
lib.rs

1//! Proc-macros for archmage SIMD capability tokens.
2//!
3//! Provides `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
4//! safe via token proof.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
11    Signature, Token, Type, TypeParamBound,
12};
13
/// Options accepted inside the parentheses of `#[arcane(...)]`.
///
/// Every option is off by default, so a bare `#[arcane]` yields
/// `ArcaneArgs::default()`.
#[derive(Default)]
struct ArcaneArgs {
    /// When set, the generated inner function is marked `#[inline(always)]`
    /// rather than `#[inline]`. This needs nightly Rust with
    /// `#![feature(target_feature_inline_always)]`.
    inline_always: bool,
}
21
22impl Parse for ArcaneArgs {
23    fn parse(input: ParseStream) -> syn::Result<Self> {
24        let mut args = ArcaneArgs::default();
25
26        while !input.is_empty() {
27            let ident: Ident = input.parse()?;
28            match ident.to_string().as_str() {
29                "inline_always" => args.inline_always = true,
30                other => {
31                    return Err(syn::Error::new(
32                        ident.span(),
33                        format!("unknown arcane argument: `{}`", other),
34                    ))
35                }
36            }
37            // Consume optional comma
38            if input.peek(Token![,]) {
39                let _: Token![,] = input.parse()?;
40            }
41        }
42
43        Ok(args)
44    }
45}
46
/// Looks up the target features enabled by a concrete token type.
///
/// `token_name` is the bare type name (no module path). Some entries are
/// reachable under more than one name. Returns `None` for names that are not
/// in the table.
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    // Each row pairs every accepted spelling of a token with its feature list.
    const TOKEN_FEATURES: &[(&[&str], &[&str])] = &[
        // x86_64 granular tokens
        (&["Sse2Token"], &["sse2"]),
        (&["Sse41Token"], &["sse4.1"]),
        (&["Sse42Token"], &["sse4.2"]),
        (&["AvxToken"], &["avx"]),
        (&["Avx2Token"], &["avx2"]),
        (&["FmaToken"], &["fma"]),
        (&["Avx2FmaToken"], &["avx2", "fma"]),
        (&["Avx512fToken"], &["avx512f"]),
        (&["Avx512bwToken"], &["avx512bw"]),
        // x86_64 profile tokens
        (&["X64V2Token"], &["sse4.2", "popcnt"]),
        (&["X64V3Token", "Desktop64"], &["avx2", "fma", "bmi1", "bmi2"]),
        (
            &["X64V4Token", "Server64"],
            &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"],
        ),
        // ARM tokens
        (&["NeonToken", "Arm64"], &["neon"]),
        (&["SveToken"], &["sve"]),
        (&["Sve2Token"], &["sve2"]),
        // WASM tokens
        (&["Simd128Token"], &["simd128"]),
    ];

    TOKEN_FEATURES
        .iter()
        .find(|(names, _)| names.iter().any(|&candidate| candidate == token_name))
        .map(|&(_, features)| features)
}
79
/// Looks up the target features implied by a marker-trait bound.
///
/// This backs `impl HasAvx2` and `T: HasAvx2` style token parameters.
/// `trait_name` is the bare trait name; unknown names yield `None`.
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    const TRAIT_FEATURES: &[(&str, &[&str])] = &[
        // x86 feature marker traits
        ("HasSse", &["sse"]),
        ("HasSse2", &["sse2"]),
        ("HasSse41", &["sse4.1"]),
        ("HasSse42", &["sse4.2"]),
        ("HasAvx", &["avx"]),
        ("HasAvx2", &["avx2"]),
        ("HasAvx512f", &["avx512f"]),
        ("HasAvx512vl", &["avx512f", "avx512vl"]),
        ("HasAvx512bw", &["avx512bw"]),
        ("HasAvx512vbmi2", &["avx512vbmi2"]),
        // Capability marker traits, mapped to a single feature each
        ("HasFma", &["fma"]),
        ("Has128BitSimd", &["sse2"]),
        ("Has256BitSimd", &["avx"]),
        ("Has512BitSimd", &["avx512f"]),
        // ARM feature marker traits
        ("HasNeon", &["neon"]),
        ("HasSve", &["sve"]),
        ("HasSve2", &["sve2"]),
    ];

    TRAIT_FEATURES
        .iter()
        .find(|(name, _)| *name == trait_name)
        .map(|&(_, features)| features)
}
110
/// How a parameter's type refers to a capability token.
enum TokenTypeInfo {
    /// A known concrete token type, stored by its bare name (e.g. `Avx2Token`).
    Concrete(String),
    /// An `impl Trait` type, stored as the names of its trait bounds
    /// (e.g. `impl HasAvx2 + HasFma`).
    ImplTrait(Vec<String>),
    /// Any other path type, assumed to name a generic parameter (e.g. `T`)
    /// whose bounds still have to be looked up.
    Generic(String),
}
120
121/// Extract token type information from a type.
122fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
123    match ty {
124        Type::Path(type_path) => {
125            // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
126            type_path.path.segments.last().map(|seg| {
127                let name = seg.ident.to_string();
128                // Check if it's a known concrete token type
129                if token_to_features(&name).is_some() {
130                    TokenTypeInfo::Concrete(name)
131                } else {
132                    // Might be a generic type parameter like `T`
133                    TokenTypeInfo::Generic(name)
134                }
135            })
136        }
137        Type::Reference(type_ref) => {
138            // Handle &Token or &mut Token
139            extract_token_type_info(&type_ref.elem)
140        }
141        Type::ImplTrait(impl_trait) => {
142            // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
143            let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
144            if traits.is_empty() {
145                None
146            } else {
147                Some(TokenTypeInfo::ImplTrait(traits))
148            }
149        }
150        _ => None,
151    }
152}
153
154/// Extract trait names from type param bounds.
155fn extract_trait_names_from_bounds(
156    bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
157) -> Vec<String> {
158    bounds
159        .iter()
160        .filter_map(|bound| {
161            if let TypeParamBound::Trait(trait_bound) = bound {
162                trait_bound
163                    .path
164                    .segments
165                    .last()
166                    .map(|seg| seg.ident.to_string())
167            } else {
168                None
169            }
170        })
171        .collect()
172}
173
174/// Look up a generic type parameter in the function's generics.
175fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
176    // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
177    for param in &sig.generics.params {
178        if let GenericParam::Type(type_param) = param {
179            if type_param.ident == type_name {
180                let traits = extract_trait_names_from_bounds(&type_param.bounds);
181                if !traits.is_empty() {
182                    return Some(traits);
183                }
184            }
185        }
186    }
187
188    // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
189    if let Some(where_clause) = &sig.generics.where_clause {
190        for predicate in &where_clause.predicates {
191            if let syn::WherePredicate::Type(pred_type) = predicate {
192                if let Type::Path(type_path) = &pred_type.bounded_ty {
193                    if let Some(seg) = type_path.path.segments.last() {
194                        if seg.ident == type_name {
195                            let traits = extract_trait_names_from_bounds(&pred_type.bounds);
196                            if !traits.is_empty() {
197                                return Some(traits);
198                            }
199                        }
200                    }
201                }
202            }
203        }
204    }
205
206    None
207}
208
209/// Convert trait names to features, collecting all features from all traits.
210fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
211    let mut all_features = Vec::new();
212
213    for trait_name in trait_names {
214        if let Some(features) = trait_to_features(trait_name) {
215            for &feature in features {
216                if !all_features.contains(&feature) {
217                    all_features.push(feature);
218                }
219            }
220        }
221    }
222
223    if all_features.is_empty() {
224        None
225    } else {
226        Some(all_features)
227    }
228}
229
230/// Find the first token parameter and return its name and features.
231fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
232    for arg in &sig.inputs {
233        match arg {
234            FnArg::Receiver(_) => {
235                // Self receivers (self, &self, &mut self) are not yet supported.
236                // The macro creates an inner function, and Rust's inner functions
237                // cannot have `self` parameters. Supporting this would require
238                // AST rewriting to replace `self` with a regular parameter.
239                // See the module docs for the workaround.
240                continue;
241            }
242            FnArg::Typed(PatType { pat, ty, .. }) => {
243                if let Some(info) = extract_token_type_info(ty) {
244                    let features = match info {
245                        TokenTypeInfo::Concrete(name) => {
246                            token_to_features(&name).map(|f| f.to_vec())
247                        }
248                        TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
249                        TokenTypeInfo::Generic(type_name) => {
250                            // Look up the generic parameter's bounds
251                            find_generic_bounds(sig, &type_name)
252                                .and_then(|traits| traits_to_features(&traits))
253                        }
254                    };
255
256                    if let Some(features) = features {
257                        // Extract parameter name
258                        if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
259                            return Some((pat_ident.ident.clone(), features));
260                        }
261                    }
262                }
263            }
264        }
265    }
266    None
267}
268
269/// Shared implementation for arcane/simd_fn macros.
270fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
271    // Find the token parameter and its features
272    let (_token_ident, features) = match find_token_param(&input_fn.sig) {
273        Some(result) => result,
274        None => {
275            let msg = format!(
276                "{} requires a token parameter. Supported forms:\n\
277                 - Concrete: `token: Avx2Token`\n\
278                 - impl Trait: `token: impl HasAvx2`\n\
279                 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
280                 Note: self receivers (&self, &mut self) are not yet supported.",
281                macro_name
282            );
283            return syn::Error::new_spanned(&input_fn.sig, msg)
284                .to_compile_error()
285                .into();
286        }
287    };
288
289    // Build target_feature attributes
290    let target_feature_attrs: Vec<Attribute> = features
291        .iter()
292        .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
293        .collect();
294
295    // Extract function components
296    let vis = &input_fn.vis;
297    let sig = &input_fn.sig;
298    let fn_name = &sig.ident;
299    let generics = &sig.generics;
300    let where_clause = &generics.where_clause;
301    let inputs = &sig.inputs;
302    let output = &sig.output;
303    let body = &input_fn.block;
304    let attrs = &input_fn.attrs;
305
306    // Build inner function parameters (ALL parameters including token)
307    let inner_params: Vec<_> = inputs.iter().cloned().collect();
308
309    // Build inner function call arguments (ALL arguments including token)
310    let inner_args: Vec<_> = inputs
311        .iter()
312        .filter_map(|arg| match arg {
313            FnArg::Typed(pat_type) => {
314                if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
315                    let ident = &pat_ident.ident;
316                    Some(quote!(#ident))
317                } else {
318                    None
319                }
320            }
321            FnArg::Receiver(_) => Some(quote!(self)),
322        })
323        .collect();
324
325    let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
326
327    // Choose inline attribute based on args
328    // Note: #[inline(always)] + #[target_feature] requires nightly with
329    // #![feature(target_feature_inline_always)]
330    let inline_attr: Attribute = if args.inline_always {
331        parse_quote!(#[inline(always)])
332    } else {
333        parse_quote!(#[inline])
334    };
335
336    // Generate the expanded function
337    let expanded = quote! {
338        #(#attrs)*
339        #vis #sig {
340            #(#target_feature_attrs)*
341            #inline_attr
342            unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
343            #body
344
345            // SAFETY: The token parameter proves the required CPU features are available.
346            // Tokens can only be constructed when features are verified (via try_new()
347            // runtime check or forge_token_dangerously() in a context where features are guaranteed).
348            unsafe { #inner_fn_name(#(#inner_args),*) }
349        }
350    };
351
352    expanded.into()
353}
354
355/// Mark a function as an arcane SIMD function.
356///
357/// This macro enables safe use of SIMD intrinsics by generating an inner function
358/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
359/// the token parameter type. The outer function calls the inner function unsafely,
360/// which is justified because the token parameter proves the features are available.
361///
362/// **The token is passed through to the inner function**, so you can call other
363/// token-taking functions from inside `#[arcane]`.
364///
365/// # Token Parameter Forms
366///
/// The macro supports three forms of token parameters; a fourth (methods with
/// `self` receivers) is not yet supported:
368///
369/// ## Concrete Token Types
370///
371/// ```ignore
372/// #[arcane]
373/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
374///     // AVX2 intrinsics safe here
375/// }
376/// ```
377///
378/// ## impl Trait Bounds
379///
380/// ```ignore
381/// #[arcane]
382/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
383///     // Accepts any token that provides AVX2
384/// }
385/// ```
386///
387/// ## Generic Type Parameters
388///
389/// ```ignore
390/// #[arcane]
391/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
392///     // Generic over any AVX2-capable token
393/// }
394///
395/// // Also works with where clauses:
396/// #[arcane]
397/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
398/// where
399///     T: HasAvx2
400/// {
401///     // ...
402/// }
403/// ```
404///
405/// ## Methods with Self Receivers (NOT YET SUPPORTED)
406///
407/// Methods with `self`, `&self`, `&mut self` receivers are **not currently supported**.
408///
409/// **Why:** The macro works by creating an inner function with `#[target_feature]`.
410/// Rust's inner functions cannot have `self` parameters—`self` only works in
411/// associated functions. Supporting this would require rewriting the function body
412/// to replace `self` with a regular parameter, which adds significant complexity.
413///
414/// **Workaround:** Use a free function with the token as an explicit parameter:
415///
416/// ```ignore
417/// impl MyProcessor {
418///     fn process(&mut self, data: &[f32; 8]) -> [f32; 8] {
419///         // Delegate to a free function
420///         process_impl(self.token, data)
421///     }
422/// }
423///
424/// #[arcane]
425/// fn process_impl(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
426///     // SIMD intrinsics safe here
427/// }
428/// ```
429///
430/// **Future work:** Supporting `self` receivers would require:
431/// 1. Adding a type parameter `__Self` to the inner function
432/// 2. Converting the receiver to a regular parameter (`&self` → `__self: &__Self`)
433/// 3. Walking the AST to replace all `self` with `__self` and `Self` with `__Self`
434/// 4. Copying where clauses with the type substitution
435///
436/// # Multiple Trait Bounds
437///
438/// When using `impl Trait` or generic bounds with multiple traits,
439/// all required features are enabled:
440///
441/// ```ignore
442/// #[arcane]
443/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
444///     // Both AVX2 and FMA intrinsics are safe here
445/// }
446/// ```
447///
448/// # Expansion
449///
450/// The macro expands to approximately:
451///
452/// ```ignore
453/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
454///     #[target_feature(enable = "avx2")]
455///     #[inline]
456///     unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
457///         let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
458///         let doubled = _mm256_add_ps(v, v);
459///         let mut out = [0.0f32; 8];
460///         unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
461///         out
462///     }
463///     // SAFETY: Token proves the required features are available
464///     unsafe { __simd_inner_process(token, data) }
465/// }
466/// ```
467///
468/// # Profile Tokens
469///
470/// Profile tokens automatically enable all required features:
471///
472/// ```ignore
473/// #[arcane]
474/// fn kernel(token: X64V3Token, data: &mut [f32]) {
475///     // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
476/// }
477/// ```
478///
479/// # Supported Tokens
480///
481/// - **x86_64**: `Sse2Token`, `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
482///   `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token` (also recognized as `Desktop64`),
///   `X64V4Token` (also recognized as `Server64`)
/// - **ARM**: `NeonToken` (also recognized as `Arm64`), `SveToken`, `Sve2Token`
485/// - **WASM**: `Simd128Token`
486///
487/// # Supported Trait Bounds
488///
489/// - **x86_64**: `HasSse`, `HasSse2`, `HasSse41`, `HasSse42`, `HasAvx`, `HasAvx2`,
490///   `HasAvx512f`, `HasAvx512vl`, `HasAvx512bw`, `HasAvx512vbmi2`, `HasFma`
491/// - **ARM**: `HasNeon`, `HasSve`, `HasSve2`
492/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
493///
494/// # Options
495///
496/// ## `inline_always`
497///
498/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
499/// This can improve performance by ensuring aggressive inlining, but requires
500/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
501/// the crate using the macro.
502///
503/// ```ignore
504/// #![feature(target_feature_inline_always)]
505///
506/// #[arcane(inline_always)]
507/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
508///     // Inner function will use #[inline(always)]
509/// }
510/// ```
511#[proc_macro_attribute]
512pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
513    let args = parse_macro_input!(attr as ArcaneArgs);
514    let input_fn = parse_macro_input!(item as ItemFn);
515    arcane_impl(input_fn, "arcane", args)
516}
517
518/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
519///
520/// See [`arcane`] for full documentation.
521#[proc_macro_attribute]
522pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
523    let args = parse_macro_input!(attr as ArcaneArgs);
524    let input_fn = parse_macro_input!(item as ItemFn);
525    arcane_impl(input_fn, "simd_fn", args)
526}