archmage_macros/lib.rs
1//! Proc-macros for archmage SIMD capability tokens.
2//!
3//! Provides `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
4//! safe via token proof.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 fold::Fold,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
12 ReturnType, Signature, Token, Type, TypeParamBound,
13};
14
15/// A Fold implementation that replaces `Self` with a concrete type.
16struct ReplaceSelf<'a> {
17 replacement: &'a Type,
18}
19
20impl Fold for ReplaceSelf<'_> {
21 fn fold_type(&mut self, ty: Type) -> Type {
22 match ty {
23 Type::Path(ref type_path) if type_path.qself.is_none() => {
24 // Check if it's just `Self`
25 if type_path.path.is_ident("Self") {
26 return self.replacement.clone();
27 }
28 // Otherwise continue folding
29 syn::fold::fold_type(self, ty)
30 }
31 _ => syn::fold::fold_type(self, ty),
32 }
33 }
34}
35
36/// Arguments to the `#[arcane]` macro.
37#[derive(Default)]
38struct ArcaneArgs {
39 /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
40 /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
41 inline_always: bool,
42 /// The concrete type to use for `self` receiver.
43 /// When specified, `self`/`&self`/`&mut self` is transformed to `_self: Type`/`&Type`/`&mut Type`.
44 self_type: Option<Type>,
45}
46
47impl Parse for ArcaneArgs {
48 fn parse(input: ParseStream) -> syn::Result<Self> {
49 let mut args = ArcaneArgs::default();
50
51 while !input.is_empty() {
52 let ident: Ident = input.parse()?;
53 match ident.to_string().as_str() {
54 "inline_always" => args.inline_always = true,
55 "_self" => {
56 let _: Token![=] = input.parse()?;
57 args.self_type = Some(input.parse()?);
58 }
59 other => {
60 return Err(syn::Error::new(
61 ident.span(),
62 format!("unknown arcane argument: `{}`", other),
63 ))
64 }
65 }
66 // Consume optional comma
67 if input.peek(Token![,]) {
68 let _: Token![,] = input.parse()?;
69 }
70 }
71
72 Ok(args)
73 }
74}
75
/// Maps a concrete token type name to the target features it enables.
///
/// Returns `None` for names that are not known token types.
///
/// Single-feature tokens (`Sse41Token`, `Avx2Token`, ...) map to exactly the
/// feature they are named after. Tier tokens follow the x86-64
/// microarchitecture levels (psABI) and list every feature of the tier and of
/// all lower tiers, matching the sets `trait_to_features` uses for the same
/// token names. Features such as `popcnt`, `bmi1`, `bmi2` and `lzcnt` are
/// spelled out because they are not covered by the AVX/AVX-512 features alone.
///
/// NOTE(review): this assumes each tier token's runtime check covers the whole
/// tier, as the sibling `trait_to_features` table already does — confirm
/// against the token definitions.
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    // x86-64-v2
    const X64_V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    // x86-64-v3 = v2 + AVX/AVX2/FMA/BMI/F16C/LZCNT
    const X64_V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    // x86-64-v4 = v3 + the baseline AVX-512 subsets
    const X64_V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    // v4 + the additional AVX-512 extensions
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    // v4 + FP16
    const AVX512_FP16: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
        "avx512fp16",
    ];

    match token_name {
        // x86_64 feature tokens (kept for backwards compatibility)
        "Sse41Token" => Some(&["sse4.1"]),
        "Sse42Token" => Some(&["sse4.2"]),
        "AvxToken" => Some(&["avx"]),
        "Avx2Token" => Some(&["avx2"]),
        "FmaToken" => Some(&["fma"]),
        "Avx2FmaToken" => Some(&["avx2", "fma"]),
        "Avx512fToken" => Some(&["avx512f"]),
        "Avx512bwToken" => Some(&["avx512bw"]),

        // x86_64 tier tokens
        "X64V2Token" => Some(X64_V2),
        "X64V3Token" | "Desktop64" => Some(X64_V3),
        "X64V4Token" | "Avx512Token" | "Server64" => Some(X64_V4),
        "Avx512ModernToken" => Some(AVX512_MODERN),
        "Avx512Fp16Token" => Some(AVX512_FP16),

        // AArch64 tokens
        "NeonToken" | "Arm64" => Some(&["neon"]),
        "NeonAesToken" => Some(&["neon", "aes"]),
        "NeonSha3Token" => Some(&["neon", "sha3"]),
        "ArmCryptoToken" => Some(&["neon", "aes", "sha2"]),
        "ArmCrypto3Token" => Some(&["neon", "aes", "sha2", "sha3"]),

        // WASM tokens
        "Simd128Token" => Some(&["simd128"]),

        _ => None,
    }
}
136
/// Maps a trait bound name (or a token name used as a bound) to target features.
///
/// Returns `None` for names that are not in the table.
///
/// IMPORTANT: every entry spells out the features of all lower tiers as well,
/// so the generated `#[target_feature]` attributes name each feature explicitly.
///
/// Based on LLVM x86-64 microarchitecture levels (psABI).
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    // Feature lists shared by several names.
    const V2: &[&str] = &["sse3", "ssse3", "sse4.1", "sse4.2", "popcnt"];
    const V3: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt",
    ];
    const V4: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
    ];
    const AVX512_MODERN: &[&str] = &[
        "sse3",
        "ssse3",
        "sse4.1",
        "sse4.2",
        "popcnt",
        "avx",
        "avx2",
        "fma",
        "bmi1",
        "bmi2",
        "f16c",
        "lzcnt",
        "avx512f",
        "avx512bw",
        "avx512cd",
        "avx512dq",
        "avx512vl",
        "avx512vpopcntdq",
        "avx512ifma",
        "avx512vbmi",
        "avx512vbmi2",
        "avx512bitalg",
        "avx512vnni",
        "avx512bf16",
        "vpclmulqdq",
        "gfni",
        "vaes",
    ];
    const AVX512_FP16: &[&str] = &[
        "sse3", "ssse3", "sse4.1", "sse4.2", "popcnt", "avx", "avx2", "fma", "bmi1", "bmi2",
        "f16c", "lzcnt", "avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl",
        "avx512fp16",
    ];

    match trait_name {
        // x86 tiers: the tier traits and the token types used directly as bounds
        "HasX64V2" | "X64V2Token" => Some(V2),
        "X64V3Token" | "Desktop64" | "Avx2FmaToken" => Some(V3),
        "HasX64V4" | "X64V4Token" | "Avx512Token" | "Server64" => Some(V4),
        "Avx512ModernToken" => Some(AVX512_MODERN),
        "Avx512Fp16Token" => Some(AVX512_FP16),

        // Width traits: the minimal feature that provides the width
        "Has128BitSimd" => Some(&["sse2"]),
        "Has256BitSimd" => Some(&["avx"]),
        "Has512BitSimd" => Some(&["avx512f"]),

        // AArch64 traits and token types
        "HasNeon" | "NeonToken" | "Arm64" => Some(&["neon"]),
        "HasNeonAes" | "NeonAesToken" => Some(&["neon", "aes"]),
        "HasNeonSha3" | "NeonSha3Token" => Some(&["neon", "sha3"]),
        "ArmCryptoToken" => Some(&["neon", "aes", "sha2"]),
        "ArmCrypto3Token" => Some(&["neon", "aes", "sha2", "sha3"]),

        _ => None,
    }
}
234
/// How a parameter's type names its capability token.
enum TokenTypeInfo {
    /// A known token type written out directly (e.g., `Avx2Token`); holds its name.
    Concrete(String),
    /// An `impl Trait` type; holds the names of the listed traits.
    ImplTrait(Vec<String>),
    /// Any other path type, treated as a possible generic parameter (e.g., `T`);
    /// holds its name.
    Generic(String),
}
244
245/// Extract token type information from a type.
246fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
247 match ty {
248 Type::Path(type_path) => {
249 // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
250 type_path.path.segments.last().map(|seg| {
251 let name = seg.ident.to_string();
252 // Check if it's a known concrete token type
253 if token_to_features(&name).is_some() {
254 TokenTypeInfo::Concrete(name)
255 } else {
256 // Might be a generic type parameter like `T`
257 TokenTypeInfo::Generic(name)
258 }
259 })
260 }
261 Type::Reference(type_ref) => {
262 // Handle &Token or &mut Token
263 extract_token_type_info(&type_ref.elem)
264 }
265 Type::ImplTrait(impl_trait) => {
266 // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
267 let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
268 if traits.is_empty() {
269 None
270 } else {
271 Some(TokenTypeInfo::ImplTrait(traits))
272 }
273 }
274 _ => None,
275 }
276}
277
278/// Extract trait names from type param bounds.
279fn extract_trait_names_from_bounds(
280 bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
281) -> Vec<String> {
282 bounds
283 .iter()
284 .filter_map(|bound| {
285 if let TypeParamBound::Trait(trait_bound) = bound {
286 trait_bound
287 .path
288 .segments
289 .last()
290 .map(|seg| seg.ident.to_string())
291 } else {
292 None
293 }
294 })
295 .collect()
296}
297
298/// Look up a generic type parameter in the function's generics.
299fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
300 // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
301 for param in &sig.generics.params {
302 if let GenericParam::Type(type_param) = param {
303 if type_param.ident == type_name {
304 let traits = extract_trait_names_from_bounds(&type_param.bounds);
305 if !traits.is_empty() {
306 return Some(traits);
307 }
308 }
309 }
310 }
311
312 // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
313 if let Some(where_clause) = &sig.generics.where_clause {
314 for predicate in &where_clause.predicates {
315 if let syn::WherePredicate::Type(pred_type) = predicate {
316 if let Type::Path(type_path) = &pred_type.bounded_ty {
317 if let Some(seg) = type_path.path.segments.last() {
318 if seg.ident == type_name {
319 let traits = extract_trait_names_from_bounds(&pred_type.bounds);
320 if !traits.is_empty() {
321 return Some(traits);
322 }
323 }
324 }
325 }
326 }
327 }
328 }
329
330 None
331}
332
333/// Convert trait names to features, collecting all features from all traits.
334fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
335 let mut all_features = Vec::new();
336
337 for trait_name in trait_names {
338 if let Some(features) = trait_to_features(trait_name) {
339 for &feature in features {
340 if !all_features.contains(&feature) {
341 all_features.push(feature);
342 }
343 }
344 }
345 }
346
347 if all_features.is_empty() {
348 None
349 } else {
350 Some(all_features)
351 }
352}
353
354/// Find the first token parameter and return its name and features.
355fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
356 for arg in &sig.inputs {
357 match arg {
358 FnArg::Receiver(_) => {
359 // Self receivers (self, &self, &mut self) are not yet supported.
360 // The macro creates an inner function, and Rust's inner functions
361 // cannot have `self` parameters. Supporting this would require
362 // AST rewriting to replace `self` with a regular parameter.
363 // See the module docs for the workaround.
364 continue;
365 }
366 FnArg::Typed(PatType { pat, ty, .. }) => {
367 if let Some(info) = extract_token_type_info(ty) {
368 let features = match info {
369 TokenTypeInfo::Concrete(name) => {
370 token_to_features(&name).map(|f| f.to_vec())
371 }
372 TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
373 TokenTypeInfo::Generic(type_name) => {
374 // Look up the generic parameter's bounds
375 find_generic_bounds(sig, &type_name)
376 .and_then(|traits| traits_to_features(&traits))
377 }
378 };
379
380 if let Some(features) = features {
381 // Extract parameter name
382 if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
383 return Some((pat_ident.ident.clone(), features));
384 }
385 }
386 }
387 }
388 }
389 }
390 None
391}
392
/// The three shapes a method's `self` receiver can take.
enum SelfReceiver {
    /// `self`, taken by value
    Owned,
    /// `&self`, a shared borrow
    Ref,
    /// `&mut self`, an exclusive borrow
    RefMut,
}
402
403/// Shared implementation for arcane/simd_fn macros.
404fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
405 // Check for self receiver
406 let has_self_receiver = input_fn
407 .sig
408 .inputs
409 .first()
410 .map(|arg| matches!(arg, FnArg::Receiver(_)))
411 .unwrap_or(false);
412
413 // If there's a self receiver, we need _self = Type
414 if has_self_receiver && args.self_type.is_none() {
415 let msg = format!(
416 "{} with self receiver requires `_self = Type` argument.\n\
417 Example: #[{}(_self = MyType)]\n\
418 Use `_self` (not `self`) in the function body to refer to self.",
419 macro_name, macro_name
420 );
421 return syn::Error::new_spanned(&input_fn.sig, msg)
422 .to_compile_error()
423 .into();
424 }
425
426 // Find the token parameter and its features
427 let (_token_ident, features) = match find_token_param(&input_fn.sig) {
428 Some(result) => result,
429 None => {
430 let msg = format!(
431 "{} requires a token parameter. Supported forms:\n\
432 - Concrete: `token: Avx2Token`\n\
433 - impl Trait: `token: impl HasAvx2`\n\
434 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
435 - With self: `#[{}(_self = Type)] fn method(&self, token: impl HasAvx2, ...)`",
436 macro_name, macro_name
437 );
438 return syn::Error::new_spanned(&input_fn.sig, msg)
439 .to_compile_error()
440 .into();
441 }
442 };
443
444 // Build target_feature attributes
445 let target_feature_attrs: Vec<Attribute> = features
446 .iter()
447 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
448 .collect();
449
450 // Extract function components
451 let vis = &input_fn.vis;
452 let sig = &input_fn.sig;
453 let fn_name = &sig.ident;
454 let generics = &sig.generics;
455 let where_clause = &generics.where_clause;
456 let inputs = &sig.inputs;
457 let output = &sig.output;
458 let body = &input_fn.block;
459 let attrs = &input_fn.attrs;
460
461 // Determine self receiver type if present
462 let self_receiver_kind: Option<SelfReceiver> = inputs.first().and_then(|arg| match arg {
463 FnArg::Receiver(receiver) => {
464 if receiver.reference.is_none() {
465 Some(SelfReceiver::Owned)
466 } else if receiver.mutability.is_some() {
467 Some(SelfReceiver::RefMut)
468 } else {
469 Some(SelfReceiver::Ref)
470 }
471 }
472 _ => None,
473 });
474
475 // Build inner function parameters, transforming self if needed
476 let inner_params: Vec<proc_macro2::TokenStream> = inputs
477 .iter()
478 .map(|arg| match arg {
479 FnArg::Receiver(_) => {
480 // Transform self receiver to _self parameter
481 let self_ty = args.self_type.as_ref().unwrap();
482 match self_receiver_kind.as_ref().unwrap() {
483 SelfReceiver::Owned => quote!(_self: #self_ty),
484 SelfReceiver::Ref => quote!(_self: &#self_ty),
485 SelfReceiver::RefMut => quote!(_self: &mut #self_ty),
486 }
487 }
488 FnArg::Typed(pat_type) => quote!(#pat_type),
489 })
490 .collect();
491
492 // Build inner function call arguments
493 let inner_args: Vec<proc_macro2::TokenStream> = inputs
494 .iter()
495 .filter_map(|arg| match arg {
496 FnArg::Typed(pat_type) => {
497 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
498 let ident = &pat_ident.ident;
499 Some(quote!(#ident))
500 } else {
501 None
502 }
503 }
504 FnArg::Receiver(_) => Some(quote!(self)), // Pass self to inner as _self
505 })
506 .collect();
507
508 let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
509
510 // Choose inline attribute based on args
511 // Note: #[inline(always)] + #[target_feature] requires nightly with
512 // #![feature(target_feature_inline_always)]
513 let inline_attr: Attribute = if args.inline_always {
514 parse_quote!(#[inline(always)])
515 } else {
516 parse_quote!(#[inline])
517 };
518
519 // Transform output and body to replace Self with concrete type if needed
520 let (inner_output, inner_body): (ReturnType, syn::Block) =
521 if let Some(ref self_ty) = args.self_type {
522 let mut replacer = ReplaceSelf {
523 replacement: self_ty,
524 };
525 let transformed_output = replacer.fold_return_type(output.clone());
526 let transformed_body = replacer.fold_block((**body).clone());
527 (transformed_output, transformed_body)
528 } else {
529 (output.clone(), (**body).clone())
530 };
531
532 // Generate the expanded function
533 let expanded = quote! {
534 #(#attrs)*
535 #vis #sig {
536 #(#target_feature_attrs)*
537 #inline_attr
538 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #inner_output #where_clause
539 #inner_body
540
541 // SAFETY: The token parameter proves the required CPU features are available.
542 // Tokens can only be constructed when features are verified (via try_new()
543 // runtime check or forge_token_dangerously() in a context where features are guaranteed).
544 unsafe { #inner_fn_name(#(#inner_args),*) }
545 }
546 };
547
548 expanded.into()
549}
550
551/// Mark a function as an arcane SIMD function.
552///
553/// This macro enables safe use of SIMD intrinsics by generating an inner function
554/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
555/// the token parameter type. The outer function calls the inner function unsafely,
556/// which is justified because the token parameter proves the features are available.
557///
558/// **The token is passed through to the inner function**, so you can call other
559/// token-taking functions from inside `#[arcane]`.
560///
561/// # Token Parameter Forms
562///
563/// The macro supports four forms of token parameters:
564///
565/// ## Concrete Token Types
566///
567/// ```ignore
568/// #[arcane]
569/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
570/// // AVX2 intrinsics safe here
571/// }
572/// ```
573///
574/// ## impl Trait Bounds
575///
576/// ```ignore
577/// #[arcane]
578/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
579/// // Accepts any token that provides AVX2
580/// }
581/// ```
582///
583/// ## Generic Type Parameters
584///
585/// ```ignore
586/// #[arcane]
587/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
588/// // Generic over any AVX2-capable token
589/// }
590///
591/// // Also works with where clauses:
592/// #[arcane]
593/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
594/// where
595/// T: HasAvx2
596/// {
597/// // ...
598/// }
599/// ```
600///
601/// ## Methods with Self Receivers
602///
603/// Methods with `self`, `&self`, `&mut self` receivers are supported via the
604/// `_self = Type` argument. Use `_self` in the function body instead of `self`:
605///
606/// ```ignore
607/// use archmage::{HasAvx2, arcane};
608/// use wide::f32x8;
609///
610/// trait Avx2Ops {
611/// fn double(&self, token: impl HasAvx2) -> Self;
612/// fn square(self, token: impl HasAvx2) -> Self;
613/// fn scale(&mut self, token: impl HasAvx2, factor: f32);
614/// }
615///
616/// impl Avx2Ops for f32x8 {
617/// #[arcane(_self = f32x8)]
618/// fn double(&self, _token: impl HasAvx2) -> Self {
619/// // Use _self instead of self in the body
620/// *_self + *_self
621/// }
622///
623/// #[arcane(_self = f32x8)]
624/// fn square(self, _token: impl HasAvx2) -> Self {
625/// _self * _self
626/// }
627///
628/// #[arcane(_self = f32x8)]
629/// fn scale(&mut self, _token: impl HasAvx2, factor: f32) {
630/// *_self = *_self * f32x8::splat(factor);
631/// }
632/// }
633/// ```
634///
635/// **Why `_self`?** The macro generates an inner function where `self` becomes
636/// a regular parameter named `_self`. Using `_self` in your code reminds you
637/// that you're not using the normal `self` keyword.
638///
639/// **All receiver types are supported:**
640/// - `self` (by value/move) → `_self: Type`
641/// - `&self` (shared reference) → `_self: &Type`
642/// - `&mut self` (mutable reference) → `_self: &mut Type`
643///
644/// # Multiple Trait Bounds
645///
646/// When using `impl Trait` or generic bounds with multiple traits,
647/// all required features are enabled:
648///
649/// ```ignore
650/// #[arcane]
651/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
652/// // Both AVX2 and FMA intrinsics are safe here
653/// }
654/// ```
655///
656/// # Expansion
657///
658/// The macro expands to approximately:
659///
660/// ```ignore
661/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
662/// #[target_feature(enable = "avx2")]
663/// #[inline]
664/// unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
665/// let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
666/// let doubled = _mm256_add_ps(v, v);
667/// let mut out = [0.0f32; 8];
668/// unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
669/// out
670/// }
671/// // SAFETY: Token proves the required features are available
672/// unsafe { __simd_inner_process(token, data) }
673/// }
674/// ```
675///
676/// # Profile Tokens
677///
678/// Profile tokens automatically enable all required features:
679///
680/// ```ignore
681/// #[arcane]
682/// fn kernel(token: X64V3Token, data: &mut [f32]) {
683/// // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
684/// }
685/// ```
686///
687/// # Supported Tokens
688///
/// - **x86_64**: `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
///   `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token` (`Desktop64`),
///   `X64V4Token` (`Avx512Token`, `Server64`), `Avx512ModernToken`, `Avx512Fp16Token`
/// - **ARM**: `NeonToken` (`Arm64`), `NeonAesToken`, `NeonSha3Token`,
///   `ArmCryptoToken`, `ArmCrypto3Token`
/// - **WASM**: `Simd128Token`
694///
695/// # Supported Trait Bounds
696///
/// - **x86_64**: `HasX64V2`, `HasX64V4`
/// - **ARM**: `HasNeon`, `HasNeonAes`, `HasNeonSha3`
/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
701///
702/// # Options
703///
704/// ## `inline_always`
705///
706/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
707/// This can improve performance by ensuring aggressive inlining, but requires
708/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
709/// the crate using the macro.
710///
711/// ```ignore
712/// #![feature(target_feature_inline_always)]
713///
714/// #[arcane(inline_always)]
715/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
716/// // Inner function will use #[inline(always)]
717/// }
718/// ```
719#[proc_macro_attribute]
720pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
721 let args = parse_macro_input!(attr as ArcaneArgs);
722 let input_fn = parse_macro_input!(item as ItemFn);
723 arcane_impl(input_fn, "arcane", args)
724}
725
726/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
727///
728/// See [`arcane`] for full documentation.
729#[proc_macro_attribute]
730pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
731 let args = parse_macro_input!(attr as ArcaneArgs);
732 let input_fn = parse_macro_input!(item as ItemFn);
733 arcane_impl(input_fn, "simd_fn", args)
734}
735
736// ============================================================================
737// Multiwidth macro for width-agnostic SIMD code
738// ============================================================================
739
740use syn::ItemMod;
741
742/// Arguments to the `#[multiwidth]` macro.
743struct MultiwidthArgs {
744 /// Include SSE (128-bit) specialization
745 sse: bool,
746 /// Include AVX2 (256-bit) specialization
747 avx2: bool,
748 /// Include AVX-512 (512-bit) specialization
749 avx512: bool,
750}
751
752impl Default for MultiwidthArgs {
753 fn default() -> Self {
754 Self {
755 sse: true,
756 avx2: true,
757 avx512: true,
758 }
759 }
760}
761
762impl Parse for MultiwidthArgs {
763 fn parse(input: ParseStream) -> syn::Result<Self> {
764 let mut args = MultiwidthArgs {
765 sse: false,
766 avx2: false,
767 avx512: false,
768 };
769
770 // If no args provided, enable all
771 if input.is_empty() {
772 return Ok(MultiwidthArgs::default());
773 }
774
775 while !input.is_empty() {
776 let ident: Ident = input.parse()?;
777 match ident.to_string().as_str() {
778 "sse" => args.sse = true,
779 "avx2" => args.avx2 = true,
780 "avx512" => args.avx512 = true,
781 other => {
782 return Err(syn::Error::new(
783 ident.span(),
784 format!(
785 "unknown multiwidth target: `{}`. Expected: sse, avx2, avx512",
786 other
787 ),
788 ))
789 }
790 }
791 // Consume optional comma
792 if input.peek(Token![,]) {
793 let _: Token![,] = input.parse()?;
794 }
795 }
796
797 Ok(args)
798 }
799}
800
/// Description of one width specialization.
struct WidthConfig {
    /// Name of the generated submodule (e.g., "sse", "avx2", "avx512")
    name: &'static str,
    /// Path of the namespace the specialization imports from
    namespace: &'static str,
    /// Path of the token type for this width
    token: &'static str,
    /// Cargo feature that gates this specialization, if any
    feature: Option<&'static str>,
    /// Target features enabled on functions generated for this width
    target_features: &'static [&'static str],
}

/// 128-bit specialization.
const SSE_WIDTH: WidthConfig = WidthConfig {
    name: "sse",
    namespace: "magetypes::simd::sse",
    token: "archmage::Sse41Token",
    feature: None,
    target_features: &["sse4.1"],
};

/// 256-bit specialization.
const AVX2_WIDTH: WidthConfig = WidthConfig {
    name: "avx2",
    namespace: "magetypes::simd::avx2",
    token: "archmage::Avx2FmaToken",
    feature: None,
    target_features: &["avx2", "fma"],
};

/// 512-bit specialization; only emitted under the `avx512` cargo feature.
const AVX512_WIDTH: WidthConfig = WidthConfig {
    name: "avx512",
    namespace: "magetypes::simd::avx512",
    token: "archmage::X64V4Token",
    feature: Some("avx512"),
    target_features: &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"],
};

/// All width specializations, narrowest first.
const WIDTH_CONFIGS: &[WidthConfig] = &[SSE_WIDTH, AVX2_WIDTH, AVX512_WIDTH];
838
839/// Generate width-specialized SIMD code.
840///
841/// This macro takes a module containing width-agnostic SIMD code and generates
842/// specialized versions for each target width (SSE, AVX2, AVX-512).
843///
844/// # Usage
845///
846/// ```ignore
847/// use archmage::multiwidth;
848///
849/// #[multiwidth]
850/// mod kernels {
851/// // Inside this module, these types are available:
852/// // - f32xN, i32xN, etc. (width-appropriate SIMD types)
853/// // - Token (the token type: Sse41Token, Avx2FmaToken, or X64V4Token)
854/// // - LANES_F32, LANES_32, etc. (lane count constants)
855///
856/// use archmage::simd::*;
857///
858/// pub fn normalize(token: Token, data: &mut [f32]) {
859/// for chunk in data.chunks_exact_mut(LANES_F32) {
860/// let v = f32xN::load(token, chunk.try_into().unwrap());
861/// let result = v * f32xN::splat(token, 1.0 / 255.0);
862/// result.store(chunk.try_into().unwrap());
863/// }
864/// }
865/// }
866///
867/// // Generated modules:
868/// // - kernels::sse::normalize(token: Sse41Token, data: &mut [f32])
869/// // - kernels::avx2::normalize(token: Avx2FmaToken, data: &mut [f32])
870/// // - kernels::avx512::normalize(token: X64V4Token, data: &mut [f32]) // if avx512 feature
871/// // - kernels::normalize(data: &mut [f32]) // runtime dispatcher
872/// ```
873///
874/// # Selective Targets
875///
876/// You can specify which targets to generate:
877///
878/// ```ignore
879/// #[multiwidth(avx2, avx512)] // Only AVX2 and AVX-512, no SSE
880/// mod fast_kernels { ... }
881/// ```
882///
883/// # How It Works
884///
885/// 1. The macro duplicates the module content for each width target
/// 2. Each copy imports from the appropriate namespace (`magetypes::simd::sse`, etc.)
/// 3. A `use archmage::simd::*` (or `magetypes::simd::*` / `crate::simd::*`) statement
///    is rewritten to the width-specific import
888/// 4. A dispatcher function is generated that picks the best available at runtime
889///
890/// # Requirements
891///
892/// - Functions should use `Token` as their token parameter type
893/// - Use `f32xN`, `i32xN`, etc. for SIMD types (not concrete types like `f32x8`)
894/// - Use `LANES_F32`, `LANES_32`, etc. for lane counts
895#[proc_macro_attribute]
896pub fn multiwidth(attr: TokenStream, item: TokenStream) -> TokenStream {
897 let args = parse_macro_input!(attr as MultiwidthArgs);
898 let input_mod = parse_macro_input!(item as ItemMod);
899
900 multiwidth_impl(input_mod, args)
901}
902
903fn multiwidth_impl(input_mod: ItemMod, args: MultiwidthArgs) -> TokenStream {
904 let mod_name = &input_mod.ident;
905 let mod_vis = &input_mod.vis;
906 let mod_attrs = &input_mod.attrs;
907
908 // Get module content
909 let content = match &input_mod.content {
910 Some((_, items)) => items,
911 None => {
912 return syn::Error::new_spanned(
913 &input_mod,
914 "multiwidth requires an inline module (mod name { ... }), not a file module",
915 )
916 .to_compile_error()
917 .into();
918 }
919 };
920
921 // Build specialized modules
922 let mut specialized_mods = Vec::new();
923 let mut enabled_configs = Vec::new();
924
925 for config in WIDTH_CONFIGS {
926 let enabled = match config.name {
927 "sse" => args.sse,
928 "avx2" => args.avx2,
929 "avx512" => args.avx512,
930 _ => false,
931 };
932
933 if !enabled {
934 continue;
935 }
936
937 enabled_configs.push(config);
938
939 let width_mod_name = format_ident!("{}", config.name);
940 let namespace: syn::Path = syn::parse_str(config.namespace).unwrap();
941
942 // Transform the content: replace `use archmage::simd::*` with width-specific import
943 // and add target_feature attributes for optimization
944 let transformed_items: Vec<syn::Item> = content
945 .iter()
946 .map(|item| transform_item_for_width(item.clone(), &namespace, config))
947 .collect();
948
949 let feature_attr = config.feature.map(|f| {
950 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
951 quote!(#[cfg(feature = #f_lit)])
952 });
953
954 specialized_mods.push(quote! {
955 #feature_attr
956 pub mod #width_mod_name {
957 #(#transformed_items)*
958 }
959 });
960 }
961
962 // Generate dispatcher functions for each public function in the module
963 let dispatchers = generate_dispatchers(content, &enabled_configs);
964
965 let expanded = quote! {
966 #(#mod_attrs)*
967 #mod_vis mod #mod_name {
968 #(#specialized_mods)*
969
970 #dispatchers
971 }
972 };
973
974 expanded.into()
975}
976
977/// Transform a single item for a specific width namespace.
978fn transform_item_for_width(
979 item: syn::Item,
980 namespace: &syn::Path,
981 config: &WidthConfig,
982) -> syn::Item {
983 match item {
984 syn::Item::Use(mut use_item) => {
985 // Check if this is `use archmage::simd::*` or similar
986 if is_simd_wildcard_use(&use_item) {
987 // Replace with width-specific import
988 use_item.tree = syn::UseTree::Path(syn::UsePath {
989 ident: format_ident!("{}", namespace.segments.first().unwrap().ident),
990 colon2_token: Default::default(),
991 tree: Box::new(build_use_tree_from_path(namespace, 1)),
992 });
993 }
994 syn::Item::Use(use_item)
995 }
996 syn::Item::Fn(func) => {
997 // Transform function to use inner function pattern with target_feature
998 // This is the same pattern as #[arcane], enabling SIMD optimization
999 // without requiring -C target-cpu=native
1000 transform_fn_with_target_feature(func, config)
1001 }
1002 other => other,
1003 }
1004}
1005
1006/// Transform a function to use the inner function pattern with target_feature.
1007/// This generates:
1008/// ```ignore
1009/// pub fn example(token: Token, data: &[f32]) -> f32 {
1010/// #[target_feature(enable = "avx2", enable = "fma")]
1011/// #[inline]
1012/// unsafe fn inner(token: Token, data: &[f32]) -> f32 {
1013/// // original body
1014/// }
1015/// // SAFETY: Token proves CPU support
1016/// unsafe { inner(token, data) }
1017/// }
1018/// ```
1019fn transform_fn_with_target_feature(func: syn::ItemFn, config: &WidthConfig) -> syn::Item {
1020 let vis = &func.vis;
1021 let sig = &func.sig;
1022 let fn_name = &sig.ident;
1023 let generics = &sig.generics;
1024 let where_clause = &generics.where_clause;
1025 let inputs = &sig.inputs;
1026 let output = &sig.output;
1027 let body = &func.block;
1028 let attrs = &func.attrs;
1029
1030 // Build target_feature attributes
1031 let target_feature_attrs: Vec<syn::Attribute> = config
1032 .target_features
1033 .iter()
1034 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
1035 .collect();
1036
1037 // Build parameter list for inner function
1038 let inner_params: Vec<proc_macro2::TokenStream> =
1039 inputs.iter().map(|arg| quote!(#arg)).collect();
1040
1041 // Build argument list for calling inner function
1042 let call_args: Vec<proc_macro2::TokenStream> = inputs
1043 .iter()
1044 .filter_map(|arg| match arg {
1045 syn::FnArg::Typed(pat_type) => {
1046 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1047 let ident = &pat_ident.ident;
1048 Some(quote!(#ident))
1049 } else {
1050 None
1051 }
1052 }
1053 syn::FnArg::Receiver(_) => Some(quote!(self)),
1054 })
1055 .collect();
1056
1057 let inner_fn_name = format_ident!("__multiwidth_inner_{}", fn_name);
1058
1059 let expanded = quote! {
1060 #(#attrs)*
1061 #vis #sig {
1062 #(#target_feature_attrs)*
1063 #[inline]
1064 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
1065 #body
1066
1067 // SAFETY: The Token parameter proves the required CPU features are available.
1068 // Tokens can only be constructed via try_new() which checks CPU support.
1069 unsafe { #inner_fn_name(#(#call_args),*) }
1070 }
1071 };
1072
1073 syn::parse2(expanded).expect("Failed to parse transformed function")
1074}
1075
1076/// Check if a use item is `use archmage::simd::*`, `use magetypes::simd::*`, or `use crate::simd::*`.
1077fn is_simd_wildcard_use(use_item: &syn::ItemUse) -> bool {
1078 fn check_tree(tree: &syn::UseTree) -> bool {
1079 match tree {
1080 syn::UseTree::Path(path) => {
1081 let ident = path.ident.to_string();
1082 if ident == "archmage" || ident == "magetypes" || ident == "crate" {
1083 check_tree_for_simd(&path.tree)
1084 } else {
1085 false
1086 }
1087 }
1088 _ => false,
1089 }
1090 }
1091
1092 fn check_tree_for_simd(tree: &syn::UseTree) -> bool {
1093 match tree {
1094 syn::UseTree::Path(path) => {
1095 if path.ident == "simd" {
1096 matches!(path.tree.as_ref(), syn::UseTree::Glob(_))
1097 } else {
1098 check_tree_for_simd(&path.tree)
1099 }
1100 }
1101 _ => false,
1102 }
1103 }
1104
1105 check_tree(&use_item.tree)
1106}
1107
1108/// Build a UseTree from a path, starting at a given segment index.
1109fn build_use_tree_from_path(path: &syn::Path, start_idx: usize) -> syn::UseTree {
1110 let segments: Vec<_> = path.segments.iter().skip(start_idx).collect();
1111
1112 if segments.is_empty() {
1113 syn::UseTree::Glob(syn::UseGlob {
1114 star_token: Default::default(),
1115 })
1116 } else if segments.len() == 1 {
1117 syn::UseTree::Path(syn::UsePath {
1118 ident: segments[0].ident.clone(),
1119 colon2_token: Default::default(),
1120 tree: Box::new(syn::UseTree::Glob(syn::UseGlob {
1121 star_token: Default::default(),
1122 })),
1123 })
1124 } else {
1125 let first = &segments[0];
1126 let rest_path = syn::Path {
1127 leading_colon: None,
1128 segments: path.segments.iter().skip(start_idx + 1).cloned().collect(),
1129 };
1130 syn::UseTree::Path(syn::UsePath {
1131 ident: first.ident.clone(),
1132 colon2_token: Default::default(),
1133 tree: Box::new(build_use_tree_from_path(&rest_path, 0)),
1134 })
1135 }
1136}
1137
/// Names of width-specific types that cannot appear in a dispatcher signature.
///
/// Entries are matched as plain substrings of a rendered type.
const WIDTH_SPECIFIC_TYPES: &[&str] = &[
    // Floating-point vectors
    "f32xN",
    "f64xN",
    // Signed integer vectors
    "i8xN",
    "i16xN",
    "i32xN",
    "i64xN",
    // Unsigned integer vectors
    "u8xN",
    "u16xN",
    "u32xN",
    "u64xN",
    // Capability token
    "Token",
];

/// Return `true` when `ty_str` contains any entry of [`WIDTH_SPECIFIC_TYPES`]
/// as a substring.
fn contains_width_specific_type(ty_str: &str) -> bool {
    for needle in WIDTH_SPECIFIC_TYPES {
        if ty_str.contains(*needle) {
            return true;
        }
    }
    false
}
1147
1148/// Check if a function signature uses width-specific types (can't have a dispatcher).
1149fn uses_width_specific_types(func: &syn::ItemFn) -> bool {
1150 // Check return type
1151 if let syn::ReturnType::Type(_, ty) = &func.sig.output {
1152 let ty_str = quote!(#ty).to_string();
1153 if contains_width_specific_type(&ty_str) {
1154 return true;
1155 }
1156 }
1157
1158 // Check parameters (excluding Token which we filter out anyway)
1159 for arg in &func.sig.inputs {
1160 if let syn::FnArg::Typed(pat_type) = arg {
1161 let ty = &pat_type.ty;
1162 let ty_str = quote!(#ty).to_string();
1163 // Skip Token parameters - they're filtered out for dispatchers
1164 if ty_str.contains("Token") {
1165 continue;
1166 }
1167 if contains_width_specific_type(&ty_str) {
1168 return true;
1169 }
1170 }
1171 }
1172
1173 false
1174}
1175
1176/// Generate runtime dispatcher functions for public functions.
1177///
1178/// Note: Dispatchers are only generated for functions that don't use width-specific
1179/// types (f32xN, Token, etc.) in their signature. Functions that take/return
1180/// width-specific types can only be called via the width-specific submodules.
1181fn generate_dispatchers(
1182 content: &[syn::Item],
1183 configs: &[&WidthConfig],
1184) -> proc_macro2::TokenStream {
1185 let mut dispatchers = Vec::new();
1186
1187 for item in content {
1188 if let syn::Item::Fn(func) = item {
1189 // Only generate dispatchers for public functions
1190 if !matches!(func.vis, syn::Visibility::Public(_)) {
1191 continue;
1192 }
1193
1194 // Skip functions that use width-specific types - they can't have dispatchers
1195 if uses_width_specific_types(func) {
1196 continue;
1197 }
1198
1199 let fn_name = &func.sig.ident;
1200 let fn_generics = &func.sig.generics;
1201 let fn_output = &func.sig.output;
1202 let fn_attrs: Vec<_> = func
1203 .attrs
1204 .iter()
1205 .filter(|a| !a.path().is_ident("arcane") && !a.path().is_ident("simd_fn"))
1206 .collect();
1207
1208 // Filter out the token parameter from the dispatcher signature
1209 let non_token_params: Vec<_> = func
1210 .sig
1211 .inputs
1212 .iter()
1213 .filter(|arg| {
1214 match arg {
1215 syn::FnArg::Typed(pat_type) => {
1216 // Check if type contains "Token"
1217 let ty = &pat_type.ty;
1218 let ty_str = quote!(#ty).to_string();
1219 !ty_str.contains("Token")
1220 }
1221 _ => true,
1222 }
1223 })
1224 .collect();
1225
1226 // Extract just the parameter names for passing to specialized functions
1227 let param_names: Vec<_> = non_token_params
1228 .iter()
1229 .filter_map(|arg| match arg {
1230 syn::FnArg::Typed(pat_type) => {
1231 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
1232 Some(&pat_ident.ident)
1233 } else {
1234 None
1235 }
1236 }
1237 _ => None,
1238 })
1239 .collect();
1240
1241 // Generate dispatch branches (highest capability first)
1242 let mut branches = Vec::new();
1243
1244 for config in configs.iter().rev() {
1245 let mod_name = format_ident!("{}", config.name);
1246 let token_path: syn::Path = syn::parse_str(config.token).unwrap();
1247
1248 let feature_check = config.feature.map(|f| {
1249 let f_lit = syn::LitStr::new(f, proc_macro2::Span::call_site());
1250 quote!(#[cfg(feature = #f_lit)])
1251 });
1252
1253 branches.push(quote! {
1254 #feature_check
1255 if let Some(token) = #token_path::try_new() {
1256 return #mod_name::#fn_name(token, #(#param_names),*);
1257 }
1258 });
1259 }
1260
1261 // Generate dispatcher
1262 dispatchers.push(quote! {
1263 #(#fn_attrs)*
1264 /// Runtime dispatcher - automatically selects the best available SIMD implementation.
1265 pub fn #fn_name #fn_generics (#(#non_token_params),*) #fn_output {
1266 use archmage::SimdToken;
1267
1268 #(#branches)*
1269
1270 // Fallback: panic if no SIMD available
1271 // TODO: Allow user-provided scalar fallback
1272 panic!("No SIMD support available for {}", stringify!(#fn_name));
1273 }
1274 });
1275 }
1276 }
1277
1278 quote! { #(#dispatchers)* }
1279}