archmage_macros/lib.rs
1//! Proc-macros for archmage SIMD capability tokens.
2//!
//! Provides the `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
//! safe via token proof, and `#[multiwidth]` to generate width-specialized modules.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 fold::Fold,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12 ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
15/// A Fold implementation that replaces `Self` with a concrete type.
16struct ReplaceSelf<'a> {
17 replacement: &'a Type,
18}
19
20impl Fold for ReplaceSelf<'_> {
21 fn fold_type(&mut self, ty: Type) -> Type {
22 match ty {
23 Type::Path(ref type_path) if type_path.qself.is_none() => {
24 // Check if it's just `Self`
25 if type_path.path.is_ident("Self") {
26 return self.replacement.clone();
27 }
28 // Otherwise continue folding
29 syn::fold::fold_type(self, ty)
30 }
31 _ => syn::fold::fold_type(self, ty),
32 }
33 }
34}
35
36/// Arguments to the `#[arcane]` macro.
37#[derive(Default)]
38struct ArcaneArgs {
39 /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
40 /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
41 inline_always: bool,
42 /// The concrete type to use for `self` receiver.
43 /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
44 self_type: Option<Type>,
45}
46
47impl Parse for ArcaneArgs {
48 fn parse(input: ParseStream) -> syn::Result<Self> {
49 let mut args = ArcaneArgs::default();
50
51 while !input.is_empty() {
52 let ident: Ident = input.parse()?;
53 match ident.to_string().as_str() {
54 "inline_always" => args.inline_always = true,
55 "_self" => {
56 let _: Token![=] = input.parse()?;
57 args.self_type = Some(input.parse()?);
58 }
59 other => {
60 return Err(syn::Error::new(
61 ident.span(),
62 format!("unknown arcane argument: `{}`", other),
63 ))
64 }
65 }
66 // Consume optional comma
67 if input.peek(Token![,]) {
68 let _: Token![,] = input.parse()?;
69 }
70 }
71
72 Ok(args)
73 }
74}
75
76// Token-to-features and trait-to-features mappings are generated from
77// token-registry.toml by xtask. Regenerate with: cargo run -p xtask -- generate
78mod generated;
79use generated::{token_to_features, trait_to_features};
80
/// Result of extracting token info from a type.
///
/// Produced by `extract_token_type_info`; each variant carries the name(s)
/// needed to look up the CPU features the parameter stands for.
enum TokenTypeInfo {
    /// Concrete token type (e.g., `Avx2Token`); holds the last path segment,
    /// which `token_to_features` recognized.
    Concrete(String),
    /// impl Trait with the trait names (e.g., `impl Has256BitSimd`); holds the
    /// last path segment of every trait bound.
    ImplTrait(Vec<String>),
    /// Generic type parameter name (e.g., `T`); any path type that is not a
    /// known concrete token ends up here and is resolved against the
    /// function's generics later.
    Generic(String),
}
90
91/// Extract token type information from a type.
92fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
93 match ty {
94 Type::Path(type_path) => {
95 // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
96 type_path.path.segments.last().map(|seg| {
97 let name = seg.ident.to_string();
98 // Check if it's a known concrete token type
99 if token_to_features(&name).is_some() {
100 TokenTypeInfo::Concrete(name)
101 } else {
102 // Might be a generic type parameter like `T`
103 TokenTypeInfo::Generic(name)
104 }
105 })
106 }
107 Type::Reference(type_ref) => {
108 // Handle &Token or &mut Token
109 extract_token_type_info(&type_ref.elem)
110 }
111 Type::ImplTrait(impl_trait) => {
112 // Handle `impl Has256BitSimd` or `impl HasX64V2 + HasNeon`
113 let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
114 if traits.is_empty() {
115 None
116 } else {
117 Some(TokenTypeInfo::ImplTrait(traits))
118 }
119 }
120 _ => None,
121 }
122}
123
124/// Extract trait names from type param bounds.
125fn extract_trait_names_from_bounds(
126 bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
127) -> Vec<String> {
128 bounds
129 .iter()
130 .filter_map(|bound| {
131 if let TypeParamBound::Trait(trait_bound) = bound {
132 trait_bound
133 .path
134 .segments
135 .last()
136 .map(|seg| seg.ident.to_string())
137 } else {
138 None
139 }
140 })
141 .collect()
142}
143
144/// Look up a generic type parameter in the function's generics.
145fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
146 // Check inline bounds first (e.g., `fn foo<T: HasX64V2>(token: T)`)
147 for param in &sig.generics.params {
148 if let GenericParam::Type(type_param) = param {
149 if type_param.ident == type_name {
150 let traits = extract_trait_names_from_bounds(&type_param.bounds);
151 if !traits.is_empty() {
152 return Some(traits);
153 }
154 }
155 }
156 }
157
158 // Check where clause (e.g., `fn foo<T>(token: T) where T: HasX64V2`)
159 if let Some(where_clause) = &sig.generics.where_clause {
160 for predicate in &where_clause.predicates {
161 if let syn::WherePredicate::Type(pred_type) = predicate {
162 if let Type::Path(type_path) = &pred_type.bounded_ty {
163 if let Some(seg) = type_path.path.segments.last() {
164 if seg.ident == type_name {
165 let traits = extract_trait_names_from_bounds(&pred_type.bounds);
166 if !traits.is_empty() {
167 return Some(traits);
168 }
169 }
170 }
171 }
172 }
173 }
174 }
175
176 None
177}
178
179/// Convert trait names to features, collecting all features from all traits.
180fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
181 let mut all_features = Vec::new();
182
183 for trait_name in trait_names {
184 if let Some(features) = trait_to_features(trait_name) {
185 for &feature in features {
186 if !all_features.contains(&feature) {
187 all_features.push(feature);
188 }
189 }
190 }
191 }
192
193 if all_features.is_empty() {
194 None
195 } else {
196 Some(all_features)
197 }
198}
199
200/// Find the first token parameter and return its name and features.
201fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
202 for arg in &sig.inputs {
203 match arg {
204 FnArg::Receiver(_) => {
205 // Self receivers (self, &self, &mut self) are not yet supported.
206 // The macro creates an inner function, and Rust's inner functions
207 // cannot have `self` parameters. Supporting this would require
208 // AST rewriting to replace `self` with a regular parameter.
209 // See the module docs for the workaround.
210 continue;
211 }
212 FnArg::Typed(PatType { pat, ty, .. }) => {
213 if let Some(info) = extract_token_type_info(ty) {
214 let features = match info {
215 TokenTypeInfo::Concrete(name) => {
216 token_to_features(&name).map(|f| f.to_vec())
217 }
218 TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
219 TokenTypeInfo::Generic(type_name) => {
220 // Look up the generic parameter's bounds
221 find_generic_bounds(sig, &type_name)
222 .and_then(|traits| traits_to_features(&traits))
223 }
224 };
225
226 if let Some(features) = features {
227 // Extract parameter name
228 if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
229 return Some((pat_ident.ident.clone(), features));
230 }
231 }
232 }
233 }
234 }
235 }
236 None
237}
238
/// Represents the kind of self receiver and the transformed parameter.
///
/// `arcane_impl` derives this from the method's receiver and uses it to
/// choose the type of the `_self` parameter on the generated inner function.
enum SelfReceiver {
    /// `self` (by value/move); becomes `_self: Type`
    Owned,
    /// `&self` (shared reference); becomes `_self: &Type`
    Ref,
    /// `&mut self` (mutable reference); becomes `_self: &mut Type`
    RefMut,
}
248
249/// Shared implementation for arcane/simd_fn macros.
250fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
251 // Check for self receiver
252 let has_self_receiver = input_fn
253 .sig
254 .inputs
255 .first()
256 .map(|arg| matches!(arg, FnArg::Receiver(_)))
257 .unwrap_or(false);
258
259 // If there's a self receiver, we need _self = Type
260 if has_self_receiver && args.self_type.is_none() {
261 let msg = format!(
262 "{} with self receiver requires `_self = Type` argument.\n\
263 Example: #[{}(_self = MyType)]\n\
264 Use `_self` (not `self`) in the function body to refer to self.",
265 macro_name, macro_name
266 );
267 return syn::Error::new_spanned(&input_fn.sig, msg)
268 .to_compile_error()
269 .into();
270 }
271
272 // Find the token parameter and its features
273 let (_token_ident, features) = match find_token_param(&input_fn.sig) {
274 Some(result) => result,
275 None => {
276 let msg = format!(
277 "{} requires a token parameter. Supported forms:\n\
278 - Concrete: `token: X64V3Token`\n\
279 - impl Trait: `token: impl Has256BitSimd`\n\
280 - Generic: `fn foo<T: HasX64V2>(token: T, ...)`\n\
281 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasNeon, ...)`",
282 macro_name, macro_name
283 );
284 return syn::Error::new_spanned(&input_fn.sig, msg)
285 .to_compile_error()
286 .into();
287 }
288 };
289
290 // Build target_feature attributes
291 let target_feature_attrs: Vec<Attribute> = features
292 .iter()
293 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
294 .collect();
295
296 // Extract function components
297 let vis = &input_fn.vis;
298 let sig = &input_fn.sig;
299 let fn_name = &sig.ident;
300 let generics = &sig.generics;
301 let where_clause = &generics.where_clause;
302 let inputs = &sig.inputs;
303 let output = &sig.output;
304 let body = &input_fn.block;
305 let attrs = &input_fn.attrs;
306
307 // Determine self receiver type if present
308 let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
309 FnArg::Receiver(receiver) => {
310 if receiver.reference.is_none() {
311 Some(SelfReceiver::Owned)
312 } else if receiver.mutability.is_some() {
313 Some(SelfReceiver::RefMut)
314 } else {
315 Some(SelfReceiver::Ref)
316 }
317 }
318 _ => None,
319 });
320
321 // Build inner function parameters, transforming self if needed
322 let inner_params: Vec<proc_macro2::TokenStream> = inputs
323 .iter()
324 .map(|arg| match arg {
325 FnArg::Receiver(_) => {
326 // Transform self receiver to _self parameter
327 let self_ty = args.self_type.as_ref().unwrap();
328 match self_receiver_kind.as_ref().unwrap() {
329 SelfReceiver::Owned => quote!(_self: #self_ty),
330 SelfReceiver::Ref => quote!(_self: &#self_ty),
331 SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
332 }
333 }
334 FnArg::Typed(pat_type) => quote!(#pat_type),
335 })
336 .collect();
337
338 // Build inner function call arguments
339 let inner_args: Vec<proc_macro2::TokenStream> = inputs
340 .iter()
341 .filter_map(|arg| match arg {
342 FnArg::Typed(pat_type) => {
343 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
344 let ident = &pat_ident.ident;
345 Some(quote!(#ident))
346 } else {
347 None
348 }
349 }
350 FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
351 })
352 .collect();
353
354 let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
355
356 // Choose inline attribute based on args
357 // Note: #[inline(always)] + #[target_feature] requires nightly with
358 // #![feature(target_feature_inline_always)]
359 let inline_attr: Attribute = if args.inline_always {
360 parse_quote!(#[inline(always)])
361 } else {
362 parse_quote!(#[inline])
363 };
364
365 // Transform output and body to replace Self with concrete type if needed
366 let (inner_output, inner_body): (ReturnType, syn::Block) =
367 if let Some(ref self_ty) = args.self_type {
368 let mut replacer = ReplaceSelf {
369 replacement: self_ty,
370 };
371 let transformed_output = replacer.fold_return_type(output.clone());
372 let transformed_body = replacer.fold_block((**body).clone());
373 (transformed_output, transformed_body)
374 } else {
375 (output.clone(), (**body).clone())
376 };
377
378 // Generate the expanded function
379 let expanded = quote! {
380 #(#attrs)*
381 #vis #sig {
382 #(#target_feature_attrs)*
383 #inline_attr
384 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
385 #inner_body
386
387 // SAFETY: The token parameter proves the required CPU features are available.
388 // Tokens can only be constructed when features are verified (via try_new()
389 // runtime check or forge_token_dangerously() in a context where features are guaranteed).
390 unsafe { #inner_fn_name(#(#inner_args),*) }
391 }
392 };
393
394 expanded.into()
395}
396
397/// Mark a function as an arcane SIMD function.
398///
399/// This macro enables safe use of SIMD intrinsics by generating an inner function
400/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
401/// the token parameter type. The outer function calls the inner function unsafely,
402/// which is justified because the token parameter proves the features are available.
403///
404/// **The token is passed through to the inner function**, so you can call other
405/// token-taking functions from inside `#[arcane]`.
406///
407/// # Token Parameter Forms
408///
409/// The macro supports four forms of token parameters:
410///
411/// ## Concrete Token Types
412///
413/// ```ignore
414/// #[arcane]
415/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
416/// // AVX2 intrinsics safe here
417/// }
418/// ```
419///
420/// ## impl Trait Bounds
421///
422/// ```ignore
423/// #[arcane]
424/// fn process(token: impl Has256BitSimd, data: &[f32; 8]) -> [f32; 8] {
425/// // Accepts any token that provides 256-bit SIMD
426/// }
427/// ```
428///
429/// ## Generic Type Parameters
430///
431/// ```ignore
432/// #[arcane]
433/// fn process<T: Has256BitSimd>(token: T, data: &[f32; 8]) -> [f32; 8] {
434/// // Generic over any 256-bit-capable token
435/// }
436///
437/// // Also works with where clauses:
438/// #[arcane]
439/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
440/// where
441/// T: Has256BitSimd
442/// {
443/// // ...
444/// }
445/// ```
446///
447/// ## Methods with Self Receivers
448///
449/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
450/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
451///
452/// ```ignore
453/// use archmage::{Has256BitSimd, arcane};
454/// use wide::f32x8;
455///
456/// trait SimdOps {
457/// fn double(&self, token: impl Has256BitSimd) -> Self;
458/// fn square(self, token: impl Has256BitSimd) -> Self;
459/// fn scale(&mut self, token: impl Has256BitSimd, factor: f32);
460/// }
461///
462/// impl SimdOps for f32x8 {
463/// #[arcane(_self = f32x8)]
464/// fn double(&self, _token: impl Has256BitSimd) -> Self {
465/// // Use _self instead of self in the body
466/// *_self + *_self
467/// }
468///
469/// #[arcane(_self = f32x8)]
470/// fn square(self, _token: impl Has256BitSimd) -> Self {
471/// _self * _self
472/// }
473///
474/// #[arcane(_self = f32x8)]
475/// fn scale(&mut self, _token: impl Has256BitSimd, factor: f32) {
476/// *_self = *_self * f32x8::splat(factor);
477/// }
478/// }
479/// ```
480///
481/// **Why `_self`?** The macro generates an inner function where `self` becomes
482/// a regular parameter named `_self`. Using `_self` in your code reminds you
483/// that you're not using the normal `self` keyword.
484///
485/// **All receiver types are supported:**
486/// - `self` (by value/move) → `_self: Type`
487/// - `&self` (shared reference) → `_self: &Type`
488/// - `&mut self` (mutable reference) → `_self: &mut Type`
489///
490/// # Multiple Trait Bounds
491///
492/// When using `impl Trait` or generic bounds with multiple traits,
493/// all required features are enabled:
494///
495/// ```ignore
496/// #[arcane]
497/// fn fma_kernel(token: impl HasX64V2 + Has256BitSimd, data: &[f32; 8]) -> [f32; 8] {
498/// // Both SSE4.2 and AVX features are enabled here
499/// }
500/// ```
501///
502/// # Expansion
503///
504/// The macro expands to approximately:
505///
506/// ```ignore
507/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
508/// #[target_feature(enable = "avx2")]
509/// #[inline]
510/// unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
511/// let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
512/// let doubled = _mm256_add_ps(v, v);
513/// let mut out = [0.0f32; 8];
514/// unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
515/// out
516/// }
517/// // SAFETY: Token proves the required features are available
518/// unsafe { __simd_inner_process(token, data) }
519/// }
520/// ```
521///
522/// # Profile Tokens
523///
524/// Profile tokens automatically enable all required features:
525///
526/// ```ignore
527/// #[arcane]
528/// fn kernel(token: X64V3Token, data: &mut [f32]) {
529/// // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
530/// }
531/// ```
532///
533/// # Supported Tokens
534///
535/// - **x86_64 tiers**: `X64V2Token`, `X64V3Token` / `Desktop64` / `Avx2FmaToken`,
536/// `X64V4Token` / `Avx512Token` / `Server64`, `Avx512ModernToken`, `Avx512Fp16Token`
537/// - **ARM**: `NeonToken` / `Arm64`, `NeonAesToken`, `NeonSha3Token`, `NeonCrcToken`
538/// - **WASM**: `Simd128Token`
539///
540/// # Supported Trait Bounds
541///
542/// - **x86_64 tiers**: `HasX64V2`, `HasX64V4`
543/// - **x86_64 width**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
544/// - **ARM**: `HasNeon`, `HasNeonAes`, `HasNeonSha3`
545///
546/// Concrete token types also work as trait bounds (e.g., `impl X64V3Token`).
547///
548/// # Options
549///
550/// ## `inline_always`
551///
552/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
553/// This can improve performance by ensuring aggressive inlining, but requires
554/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
555/// the crate using the macro.
556///
557/// ```ignore
558/// #![feature(target_feature_inline_always)]
559///
560/// #[arcane(inline_always)]
561/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
562/// // Inner function will use #[inline(always)]
563/// }
564/// ```
565#[proc_macro_attribute]
566pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
567 let args = parse_macro_input!(attr as ArcaneArgs);
568 let input_fn = parse_macro_input!(item as ItemFn);
569 arcane_impl(input_fn, "arcane", args)
570}
571
572/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
573///
574/// See [`arcane`] for full documentation.
575#[proc_macro_attribute]
576pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
577 let args = parse_macro_input!(attr as ArcaneArgs);
578 let input_fn = parse_macro_input!(item as ItemFn);
579 arcane_impl(input_fn, "simd_fn", args)
580}
581
582// ============================================================================
583// Multiwidth macro for width-agnostic SIMD code
584// ============================================================================
585
586use syn::ItemMod;
587
588/// Arguments to the `#[multiwidth]` macro.
589struct MultiwidthArgs {
590 /// Include SSE (128-bit) specialization
591 sse: bool,
592 /// Include AVX2 (256-bit) specialization
593 avx2: bool,
594 /// Include AVX-512 (512-bit) specialization
595 avx512: bool,
596 /// Include WASM SIMD128 (128-bit) specialization
597 wasm: bool,
598 /// Include NEON (128-bit ARM) specialization
599 neon: bool,
600}
601
602impl Default for MultiwidthArgs {
603 fn default() -> Self {
604 Self {
605 sse: true,
606 avx2: true,
607 avx512: true,
608 wasm: true,
609 neon: true,
610 }
611 }
612}
613
614impl Parse for MultiwidthArgs {
615 fn parse(input: ParseStream) -> syn::Result<Self> {
616 let mut args = MultiwidthArgs {
617 sse: false,
618 avx2: false,
619 avx512: false,
620 wasm: false,
621 neon: false,
622 };
623
624 // If no args provided, enable all
625 if input.is_empty() {
626 return Ok(MultiwidthArgs::default());
627 }
628
629 while !input.is_empty() {
630 let ident: Ident = input.parse()?;
631 match ident.to_string().as_str() {
632 "sse" => args.sse = true,
633 "avx2" => args.avx2 = true,
634 "avx512" => args.avx512 = true,
635 "wasm" | "simd128" => args.wasm = true,
636 "neon" | "arm" => args.neon = true,
637 other => {
638 return Err(syn::Error::new(
639 ident.span(),
640 format!(
641 "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512, wasm, neon",
642 other
643 ),
644 ))
645 }
646 }
647 // Consume optional comma
648 if input.peek(Token![,]) {
649 let _: Token![,] = input.parse()?;
650 }
651 }
652
653 Ok(args)
654 }
655}
656
657use generated::{WidthConfig, ARM_WIDTH_CONFIGS, WASM_WIDTH_CONFIGS, X86_WIDTH_CONFIGS};
658
659/// Generate width-specialized SIMD code.
660///
/// This macro takes a module containing width-agnostic SIMD code and generates
/// specialized versions for each enabled target (SSE, AVX2, AVX-512, WASM SIMD128, NEON).
663///
664/// # Usage
665///
666/// ```ignore
667/// use archmage::multiwidth;
668///
669/// #[multiwidth]
670/// mod kernels {
671/// // Inside this module, these types are available:
672/// // - f32xN, i32xN, etc. (width-appropriate SIMD types)
673/// // - Token (the token type: X64V3Token for SSE/AVX2, or X64V4Token for AVX-512)
674/// // - LANES_F32, LANES_32, etc. (lane count constants)
675///
676/// use archmage::simd::*;
677///
678/// pub fn normalize(token: Token, data: &mut [f32]) {
679/// for chunk in data.chunks_exact_mut(LANES_F32) {
680/// let v = f32xN::load(token, chunk.try_into().unwrap());
681/// let result = v * f32xN::splat(token, 1.0 / 255.0);
682/// result.store(chunk.try_into().unwrap());
683/// }
684/// }
685/// }
686///
687/// // Generated modules:
688/// // - kernels::sse::normalize(token: X64V3Token, data: &mut [f32])
689/// // - kernels::avx2::normalize(token: X64V3Token, data: &mut [f32])
690/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32]) // if avx512 feature
691/// // - kernels::normalize(data: &mut [f32]) // runtime dispatcher
692/// ```
693///
694/// # Selective Targets
695///
696/// You can specify which targets to generate:
697///
698/// ```ignore
699/// #[multiwidth(avx2, avx512)] // Only AVX2 and AVX-512, no SSE
700/// mod fast_kernels { ... }
701/// ```
702///
703/// # How It Works
704///
705/// 1. The macro duplicates the module content for each width target
706/// 2. Each copy imports from the appropriate namespace (`archmage::simd::sse`, etc.)
707/// 3. The `use archmage::simd::*` statement is rewritten to the width-specific import
708/// 4. A dispatcher function is generated that picks the best available at runtime
709///
710/// # Requirements
711///
712/// - Functions should use `Token` as their token parameter type
713/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
714/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
715#[proc_macro_attribute]
716pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
717 let args = parse_macro_input!(attr as MultiwidthArgs);
718 let input_mod = parse_macro_input!(item as ItemMod);
719
720 multiwidth_impl(input_mod, args)
721}
722
/// Configuration with target arch for conditional compilation
///
/// Pairs a generated width configuration with the architecture its
/// specialized module is gated on.
struct ArchConfig<'a> {
    /// The width configuration (name, namespace, features) to specialize for.
    config: &'a WidthConfig,
    /// Value used in `#[cfg(target_arch = "...")]` on the generated module;
    /// `None` emits no architecture gate.
    target_arch: Option<&'static str>,
}
728
729fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
730 let mod_name = &input_mod.ident;
731 let mod_vis = &input_mod.vis;
732 let mod_attrs = &input_mod.attrs;
733
734 // Get module content
735 let content = match &input_mod.content {
736 Some((_, items)) => items,
737 None => {
738 return syn::Error::new_spanned(
739 &input_mod,
740 "multiwidth requires an inline module (mod name { ... }), not a file module",
741 )
742 .to_compile_error()
743 .into();
744 }
745 };
746
747 // Build list of all enabled configs across architectures
748 let mut all_configs: Vec<ArchConfig> = Vec::new();
749
750 // x86_64 configs
751 for config in X86_WIDTH_CONFIGS {
752 let enabled = match config.name {
753 "sse" => args.sse,
754 "avx2" => args.avx2,
755 "avx512" => args.avx512,
756 _ => false,
757 };
758 if enabled {
759 all_configs.push(ArchConfig {
760 config,
761 target_arch: Some("x86_64"),
762 });
763 }
764 }
765
766 // WASM configs
767 if args.wasm {
768 for config in WASM_WIDTH_CONFIGS {
769 all_configs.push(ArchConfig {
770 config,
771 target_arch: Some("wasm32"),
772 });
773 }
774 }
775
776 // ARM configs
777 if args.neon {
778 for config in ARM_WIDTH_CONFIGS {
779 all_configs.push(ArchConfig {
780 config,
781 target_arch: Some("aarch64"),
782 });
783 }
784 }
785
786 // Build specialized modules
787 let mut specialized_mods = Vec::new();
788 let mut enabled_configs = Vec::new();
789
790 for arch_config in &all_configs {
791 let config = arch_config.config;
792 enabled_configs.push(config);
793
794 let width_mod_name = format_ident!("{}", config.name);
795 let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
796
797 // Transform the content: replace `use archmage::simd::*` with width-specific import
798 // and add target_feature attributes for optimization
799 let transformed_items: Vec<syn::Item> = content
800 .iter()
801 .map(|item| transform_item_for_width(item.clone(), &namespace, config))
802 .collect();
803
804 // Build cfg attributes for target arch and optional feature
805 let arch_attr = arch_config
806 .target_arch
807 .map(|arch| quote!(#[cfg(target_arch = #arch)]));
808
809 let feature_attr = config.feature.map(|f| {
810 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
811 quote!(#[cfg(feature = #f_lit)])
812 });
813
814 specialized_mods.push(quote! {
815 #arch_attr
816 #feature_attr
817 pub mod #width_mod_name {
818 #(#transformed_items)*
819 }
820 });
821 }
822
823 // Generate dispatcher functions for each public function in the module
824 // The dispatcher is x86_64-specific (runtime feature detection)
825 // For WASM and ARM, features are compile-time only
826 let x86_configs: Vec<_> = all_configs
827 .iter()
828 .filter(|c| c.target_arch == Some("x86_64"))
829 .map(|c| c.config)
830 .collect();
831
832 // Only generate dispatcher section if we have x86 configs
833 let dispatcher_section = if !x86_configs.is_empty() {
834 let dispatchers = generate_dispatchers(content, &x86_configs);
835 quote! {
836 // Runtime dispatcher (x86_64 only - WASM/ARM use compile-time features)
837 #[cfg(target_arch = "x86_64")]
838 mod __dispatchers {
839 use super::*;
840 #dispatchers
841 }
842 #[cfg(target_arch = "x86_64")]
843 pub use __dispatchers::*;
844 }
845 } else {
846 quote! {}
847 };
848
849 let expanded = quote! {
850 #(#mod_attrs)*
851 #mod_vis mod #mod_name {
852 #(#specialized_mods)*
853
854 #dispatcher_section
855 }
856 };
857
858 expanded.into()
859}
860
861/// Transform a single item for a specific width namespace.
862fn transform_item_for_width(
863 item: syn::Item,
864 namespace: &syn::Path,
865 config: &WidthConfig,
866) -> syn::Item {
867 match item {
868 syn::Item::Use(mut use_item) => {
869 // Check if this is `use archmage::simd::*` or similar
870 if is_simd_wildcard_use(&use_item) {
871 // Replace with width-specific import
872 use_item.tree = syn::UseTree::Path(syn::UsePath {
873 ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
874 colon2_token: Default::default(),
875 tree: Box::new(build_use_tree_from_path(namespace, 1)),
876 });
877 }
878 syn::Item::Use(use_item)
879 }
880 syn::Item::Fn(func) => {
881 // Transform function to use inner function pattern with target_feature
882 // This is the same pattern as #[arcane], enabling SIMD optimization
883 // without requiring -C target-cpu=native
884 transform_fn_with_target_feature(func, config)
885 }
886 other => other,
887 }
888}
889
890/// Transform a function to use the inner function pattern with target_feature.
891/// This generates:
892/// ```ignore
893/// pub fn example(token: Token, data: &[f32]) -> f32 {
894/// #[target_feature(enable = "avx2", enable = "fma")]
895/// #[inline]
896/// unsafe fn inner(token: Token, data: &[f32]) -> f32 {
897/// // original body
898/// }
899/// // SAFETY: Token proves CPU support
900/// unsafe { inner(token, data) }
901/// }
902/// ```
903fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
904 let vis = &func.vis;
905 let sig = &func.sig;
906 let fn_name = &sig.ident;
907 let generics = &sig.generics;
908 let where_clause = &generics.where_clause;
909 let inputs = &sig.inputs;
910 let output = &sig.output;
911 let body = &func.block;
912 let attrs = &func.attrs;
913
914 // Build target_feature attributes
915 let target_feature_attrs: Vec<syn::Attribute> = config
916 .target_features
917 .iter()
918 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
919 .collect();
920
921 // Build parameter list for inner function
922 let inner_params: Vec<proc_macro2::TokenStream> =
923 inputs.iter().map(|arg| quote!(#arg)).collect();
924
925 // Build argument list for calling inner function
926 let call_args: Vec<proc_macro2::TokenStream> = inputs
927 .iter()
928 .filter_map(|arg| match arg {
929 syn::FnArg::Typed(pat_type) => {
930 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
931 let ident = &pat_ident.ident;
932 Some(quote!(#ident))
933 } else {
934 None
935 }
936 }
937 syn::FnArg::Receiver(_) => Some(quote!(self)),
938 })
939 .collect();
940
941 let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
942
943 let expanded = quote! {
944 #(#attrs)*
945 #vis #sig {
946 #(#target_feature_attrs)*
947 #[inline]
948 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
949 #body
950
951 // SAFETY: The Token parameter proves the required CPU features are available.
952 // Tokens can only be constructed via try_new() which checks CPU support.
953 unsafe { #inner_fn_name(#(#call_args),*) }
954 }
955 };
956
957 syn::parse2(expanded).expect("Failed to parse transformed function")
958}
959
960/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
961fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
962 fn check_tree(tree: &syn::UseTree) -> bool {
963 match tree {
964 syn::UseTree::Path(path) => {
965 let ident = path.ident.to_string();
966 if ident == "archmage" || ident == "magetypes" || ident == "crate" {
967 check_tree_for_simd(&path.tree)
968 } else {
969 false
970 }
971 }
972 _ => false,
973 }
974 }
975
976 fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
977 match tree {
978 syn::UseTree::Path(path) => {
979 if path.ident == "simd" {
980 matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
981 } else {
982 check_tree_for_simd(&path.tree)
983 }
984 }
985 _ => false,
986 }
987 }
988
989 check_tree(&use_item.tree)
990}
991
992/// Build a UseTree from a path, starting at a given segment index.
993fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
994 let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
995
996 if segments.is_empty() {
997 syn::UseTree::Glob(syn::UseGlob {
998 star_token: Default::default(),
999 })
1000 } else if segments.len() == 1 {
1001 syn::UseTree::Path(syn::UsePath {
1002 ident: segments[0].ident.clone(),
1003 colon2_token: Default::default(),
1004 tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1005 star_token: Default::default(),
1006 })),
1007 })
1008 } else {
1009 let first = &segments[0];
1010 let rest_path = syn::Path {
1011 leading_colon: None,
1012 segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1013 };
1014 syn::UseTree::Path(syn::UsePath {
1015 ident: first.ident.clone(),
1016 colon2_token: Default::default(),
1017 tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1018 })
1019 }
1020}
1021
/// Names of the width-dependent types; a signature mentioning any of them
/// cannot be used for a dispatcher.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    "f32xN",
    "f64xN",
    "i8xN",
    "i16xN",
    "i32xN",
    "i64xN",
    "u8xN",
    "u16xN",
    "u32xN",
    "u64xN",
    "Token",
];

/// Report whether a rendered type string mentions a width-specific type.
///
/// This is a plain substring test against [`WIDTH_SPECIFIC_TYPES`], so any
/// text containing one of the names (for example `X64V3Token`) matches.
fn contains_width_specific_type(ty_str: &str) -> bool {
    for name in WIDTH_SPECIFIC_TYPES {
        if ty_str.contains(*name) {
            return true;
        }
    }
    false
}
1031
1032/// Check if a function signature uses width-specific types (can't have a dispatcher).
1033fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1034 // Check return type
1035 if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1036 let ty_str = quote!(#ty).to_string();
1037 if contains_width_specific_type(&ty_str) {
1038 return true;
1039 }
1040 }
1041
1042 // Check parameters (excluding Token which we filter out anyway)
1043 for arg in &func.sig.inputs {
1044 if let syn::FnArg::Typed(pat_type) = arg {
1045 let ty = &pat_type.ty;
1046 let ty_str = quote!(#ty).to_string();
1047 // Skip Token parameters - they're filtered out for dispatchers
1048 if ty_str.contains("Token") {
1049 continue;
1050 }
1051 if contains_width_specific_type(&ty_str) {
1052 return true;
1053 }
1054 }
1055 }
1056
1057 false
1058}
1059
1060/// Generate runtime dispatcher functions for public functions.
1061///
1062/// Note: Dispatchers are only generated for functions that don't use width-specific
1063/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1064/// width-specific types can only be called via the width-specific submodules.
1065fn generate_dispatchers(
1066 content: &[syn::Item],
1067 configs: &[&WidthConfig],
1068) -> proc_macro2::TokenStream {
1069 let mut dispatchers = Vec::new();
1070
1071 for item in content {
1072 if let syn::Item::Fn(func) = item {
1073 // Only generate dispatchers for public functions
1074 if !matches!(func.vis, syn::Visibility::Public(_)) {
1075 continue;
1076 }
1077
1078 // Skip functions that use width-specific types - they can't have dispatchers
1079 if uses_width_specific_types(func) {
1080 continue;
1081 }
1082
1083 let fn_name = &func.sig.ident;
1084 let fn_generics = &func.sig.generics;
1085 let fn_output = &func.sig.output;
1086 let fn_attrs: Vec<_> = func
1087 .attrs
1088 .iter()
1089 .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1090 .collect();
1091
1092 // Filter out the token parameter from the dispatcher signature
1093 let non_token_params: Vec<_> = func
1094 .sig
1095 .inputs
1096 .iter()
1097 .filter(|arg| {
1098 match arg {
1099 syn::FnArg::Typed(pat_type) => {
1100 // Check if type contains "Token"
1101 let ty = &pat_type.ty;
1102 let ty_str = quote!(#ty).to_string();
1103 !ty_str.contains("Token")
1104 }
1105 _ => true,
1106 }
1107 })
1108 .collect();
1109
1110 // Extract just the parameter names for passing to specialized functions
1111 let param_names: Vec<_> = non_token_params
1112 .iter()
1113 .filter_map(|arg| match arg {
1114 syn::FnArg::Typed(pat_type) => {
1115 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1116 Some(&pat_ident.ident)
1117 } else {
1118 None
1119 }
1120 }
1121 _ => None,
1122 })
1123 .collect();
1124
1125 // Generate dispatch branches (highest capability first)
1126 let mut branches = Vec::new();
1127
1128 for config in configs.iter().rev() {
1129 let mod_name = format_ident!("{}", config.name);
1130 let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1131
1132 let feature_check = config.feature.map(|f| {
1133 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1134 quote!(#[cfg(feature = #f_lit)])
1135 });
1136
1137 branches.push(quote! {
1138 #feature_check
1139 if let Some(token) = #token_path::try_new() {
1140 return #mod_name::#fn_name(token, #(#param_names),*);
1141 }
1142 });
1143 }
1144
1145 // Generate dispatcher
1146 dispatchers.push(quote! {
1147 #(#fn_attrs)*
1148 /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1149 pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1150 use archmage::SimdToken;
1151
1152 #(#branches)*
1153
1154 // Fallback: panic if no SIMD available
1155 // TODO: Allow user-provided scalar fallback
1156 panic!("No SIMD support available for {}", stringify!(#fn_name));
1157 }
1158 });
1159 }
1160 }
1161
1162 quote! { #(#dispatchers)* }
1163}
1164
1165// =============================================================================
1166// Unit tests for token/trait recognition maps
1167// =============================================================================
1168
#[cfg(test)]
mod tests {
    use super::*;

    use super::generated::{ALL_CONCRETE_TOKENS, ALL_TRAIT_NAMES};

    /// Every name in `ALL_CONCRETE_TOKENS` must resolve through `token_to_features`.
    #[test]
    fn every_concrete_token_is_in_token_to_features() {
        ALL_CONCRETE_TOKENS.iter().for_each(|&token| {
            assert!(
                token_to_features(token).is_some(),
                "Token `{}` exists in runtime crate but is NOT recognized by \
                 token_to_features() in the proc macro. Add it!",
                token
            );
        });
    }

    /// Every name in `ALL_TRAIT_NAMES` must resolve through `trait_to_features`.
    #[test]
    fn every_trait_is_in_trait_to_features() {
        ALL_TRAIT_NAMES.iter().for_each(|&trait_name| {
            assert!(
                trait_to_features(trait_name).is_some(),
                "Trait `{}` exists in runtime crate but is NOT recognized by \
                 trait_to_features() in the proc macro. Add it!",
                trait_name
            );
        });
    }

    /// Alias names must yield exactly the same feature set as the token they alias.
    #[test]
    fn token_aliases_map_to_same_features() {
        let alias_pairs: &[(&str, &str)] = &[
            // Desktop64 = X64V3Token
            ("Desktop64", "X64V3Token"),
            // Server64 = X64V4Token = Avx512Token
            ("Server64", "X64V4Token"),
            ("X64V4Token", "Avx512Token"),
            // Arm64 = NeonToken
            ("Arm64", "NeonToken"),
        ];

        for &(left, right) in alias_pairs {
            assert_eq!(
                token_to_features(left),
                token_to_features(right),
                "{} and {} should map to identical features",
                left,
                right
            );
        }
    }

    /// Tier tokens must also be accepted by `trait_to_features`.
    #[test]
    fn trait_to_features_includes_tokens_as_bounds() {
        // Tier tokens should also work as trait bounds
        // (for `impl X64V3Token` patterns, even though Rust won't allow it,
        // the macro processes AST before type checking)
        let tier_tokens: &[&str] = &[
            "X64V2Token",
            "X64V3Token",
            "Desktop64",
            "Avx2FmaToken",
            "X64V4Token",
            "Avx512Token",
            "Server64",
            "Avx512ModernToken",
            "Avx512Fp16Token",
            "NeonToken",
            "Arm64",
            "NeonAesToken",
            "NeonSha3Token",
            "NeonCrcToken",
        ];

        for token in tier_tokens.iter().copied() {
            assert!(
                trait_to_features(token).is_some(),
                "Tier token `{}` should also be recognized in trait_to_features() \
                 for use as a generic bound. Add it!",
                token
            );
        }
    }

    /// The v4 trait's feature list must be a strict superset of the v2 list.
    #[test]
    fn trait_features_are_cumulative() {
        let v2 = trait_to_features("HasX64V2").unwrap();
        let v4 = trait_to_features("HasX64V4").unwrap();

        // HasX64V4 should include all HasX64V2 features plus more
        v2.iter().for_each(|feature| {
            assert!(
                v4.contains(feature),
                "HasX64V4 should include v2 feature `{}` but doesn't",
                feature
            );
        });

        // v4 should have more features than v2
        assert!(
            v4.len() > v2.len(),
            "HasX64V4 should have more features than HasX64V2"
        );
    }

    /// `X64V3Token` used as a trait bound must cover every v2 feature.
    #[test]
    fn x64v3_trait_features_include_v2() {
        let baseline = trait_to_features("HasX64V2").unwrap();
        let v3_as_bound = trait_to_features("X64V3Token").unwrap();

        baseline.iter().for_each(|feature| {
            assert!(
                v3_as_bound.contains(feature),
                "X64V3Token trait features should include v2 feature `{}` but don't",
                feature
            );
        });
    }

    /// `HasNeonAes` must cover every feature of `HasNeon`.
    #[test]
    fn has_neon_aes_includes_neon() {
        let base = trait_to_features("HasNeon").unwrap();
        let with_aes = trait_to_features("HasNeonAes").unwrap();

        base.iter().for_each(|feature| {
            assert!(
                with_aes.contains(feature),
                "HasNeonAes should include NEON feature `{}`",
                feature
            );
        });
    }

    /// Traits removed in 0.3.0 must not be resolvable any more.
    #[test]
    fn no_removed_traits_are_recognized() {
        // These traits were removed in 0.3.0 and should NOT be recognized
        let removed: &[&str] = &[
            "HasSse",
            "HasSse2",
            "HasSse41",
            "HasSse42",
            "HasAvx",
            "HasAvx2",
            "HasFma",
            "HasAvx512f",
            "HasAvx512bw",
            "HasAvx512vl",
            "HasAvx512vbmi2",
            "HasSve",
            "HasSve2",
        ];

        for trait_name in removed.iter().copied() {
            assert!(
                trait_to_features(trait_name).is_none(),
                "Removed trait `{}` should NOT be in trait_to_features(). \
                 It was removed in 0.3.0 — users should migrate to tier traits.",
                trait_name
            );
        }
    }

    /// Token names that do not exist must not be resolvable.
    #[test]
    fn no_nonexistent_tokens_are_recognized() {
        // These tokens don't exist and should NOT be recognized
        let fake: &[&str] = &[
            "Sse2Token",
            "SveToken",
            "Sve2Token",
            "Avx512VnniToken",
            "X64V4ModernToken",
            "NeonFp16Token",
        ];

        for token in fake.iter().copied() {
            assert!(
                token_to_features(token).is_none(),
                "Non-existent token `{}` should NOT be in token_to_features()",
                token
            );
        }
    }
}