archmage_macros/lib.rs
1//! Proc-macros for archmage SIMD capability tokens.
2//!
3//! Provides `#[arcane]` attribute (with `#[simd_fn]` alias) to make raw intrinsics
4//! safe via token proof.
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input, parse_quote, Attribute, FnArg, GenericParam, Ident, ItemFn, PatType,
11 Signature, Token, Type, TypeParamBound,
12};
13
/// Arguments to the `#[arcane]` macro.
///
/// Derives `Default`, so every option starts out disabled (`false`) until the
/// attribute's argument list turns it on.
#[derive(Default)]
struct ArcaneArgs {
    /// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
    /// Requires nightly Rust with `#![feature(target_feature_inline_always)]`.
    inline_always: bool,
}
21
22impl Parse for ArcaneArgs {
23 fn parse(input: ParseStream) -> syn::Result<Self> {
24 let mut args = ArcaneArgs::default();
25
26 while !input.is_empty() {
27 let ident: Ident = input.parse()?;
28 match ident.to_string().as_str() {
29 "inline_always" => args.inline_always = true,
30 other => {
31 return Err(syn::Error::new(
32 ident.span(),
33 format!("unknown arcane argument: `{}`", other),
34 ))
35 }
36 }
37 // Consume optional comma
38 if input.peek(Token![,]) {
39 let _: Token![,] = input.parse()?;
40 }
41 }
42
43 Ok(args)
44 }
45}
46
/// Looks up the target features that a concrete token type stands for.
///
/// `token_name` is the bare type name (no module path). Returns `None` when
/// the name is not one of the token types listed here.
fn token_to_features(token_name: &str) -> Option<&'static [&'static str]> {
    let features: &'static [&'static str] = match token_name {
        // x86_64 tokens, one per individual feature (or feature pair)
        "Sse2Token" => &["sse2"],
        "Sse41Token" => &["sse4.1"],
        "Sse42Token" => &["sse4.2"],
        "AvxToken" => &["avx"],
        "Avx2Token" => &["avx2"],
        "FmaToken" => &["fma"],
        "Avx2FmaToken" => &["avx2", "fma"],
        "Avx512fToken" => &["avx512f"],
        "Avx512bwToken" => &["avx512bw"],

        // x86_64 profile tokens and their aliases
        "X64V2Token" => &["sse4.2", "popcnt"],
        "X64V3Token" | "Desktop64" => &["avx2", "fma", "bmi1", "bmi2"],
        "X64V4Token" | "Server64" => {
            &["avx512f", "avx512bw", "avx512cd", "avx512dq", "avx512vl"]
        }

        // ARM tokens
        "NeonToken" | "Arm64" => &["neon"],
        "SveToken" => &["sve"],
        "Sve2Token" => &["sve2"],

        // WASM tokens
        "Simd128Token" => &["simd128"],

        // Anything else is not a token type this table knows about.
        _ => return None,
    };

    Some(features)
}
79
/// Looks up the target features implied by a marker-trait bound.
///
/// Used for parameters written as `impl HasAvx2` or `T: HasAvx2`.
/// `trait_name` is the bare trait name (no module path); names that are not in
/// the table yield `None`.
fn trait_to_features(trait_name: &str) -> Option<&'static [&'static str]> {
    const TRAIT_FEATURES: &[(&str, &[&str])] = &[
        // x86 feature marker traits
        ("HasSse", &["sse"]),
        ("HasSse2", &["sse2"]),
        ("HasSse41", &["sse4.1"]),
        ("HasSse42", &["sse4.2"]),
        ("HasAvx", &["avx"]),
        ("HasAvx2", &["avx2"]),
        ("HasAvx512f", &["avx512f"]),
        ("HasAvx512vl", &["avx512f", "avx512vl"]),
        ("HasAvx512bw", &["avx512bw"]),
        ("HasAvx512vbmi2", &["avx512vbmi2"]),
        // Capability marker traits, mapped to the features chosen to satisfy them
        ("HasFma", &["fma"]),
        ("Has128BitSimd", &["sse2"]),
        ("Has256BitSimd", &["avx"]),
        ("Has512BitSimd", &["avx512f"]),
        // ARM feature marker traits
        ("HasNeon", &["neon"]),
        ("HasSve", &["sve"]),
        ("HasSve2", &["sve2"]),
    ];

    TRAIT_FEATURES
        .iter()
        .find(|(known, _)| *known == trait_name)
        .map(|&(_, features)| features)
}
110
/// Result of extracting token info from a type.
///
/// Every payload holds bare identifiers: only the last segment of a path is
/// recorded, so `archmage::Avx2Token` is stored as `"Avx2Token"`.
enum TokenTypeInfo {
    /// Concrete token type (e.g., `Avx2Token`)
    Concrete(String),
    /// impl Trait with the trait names (e.g., `impl HasAvx2`)
    ImplTrait(Vec<String>),
    /// Generic type parameter name (e.g., `T`)
    Generic(String),
}
120
121/// Extract token type information from a type.
122fn extract_token_type_info(ty: &Type) -> Option<TokenTypeInfo> {
123 match ty {
124 Type::Path(type_path) => {
125 // Get the last segment of the path (e.g., "Avx2Token" from "archmage::Avx2Token")
126 type_path.path.segments.last().map(|seg| {
127 let name = seg.ident.to_string();
128 // Check if it's a known concrete token type
129 if token_to_features(&name).is_some() {
130 TokenTypeInfo::Concrete(name)
131 } else {
132 // Might be a generic type parameter like `T`
133 TokenTypeInfo::Generic(name)
134 }
135 })
136 }
137 Type::Reference(type_ref) => {
138 // Handle &Token or &mut Token
139 extract_token_type_info(&type_ref.elem)
140 }
141 Type::ImplTrait(impl_trait) => {
142 // Handle `impl HasAvx2` or `impl HasAvx2 + HasFma`
143 let traits: Vec<String> = extract_trait_names_from_bounds(&impl_trait.bounds);
144 if traits.is_empty() {
145 None
146 } else {
147 Some(TokenTypeInfo::ImplTrait(traits))
148 }
149 }
150 _ => None,
151 }
152}
153
154/// Extract trait names from type param bounds.
155fn extract_trait_names_from_bounds(
156 bounds: &syn::punctuated::Punctuated<TypeParamBound, Token![+]>,
157) -> Vec<String> {
158 bounds
159 .iter()
160 .filter_map(|bound| {
161 if let TypeParamBound::Trait(trait_bound) = bound {
162 trait_bound
163 .path
164 .segments
165 .last()
166 .map(|seg| seg.ident.to_string())
167 } else {
168 None
169 }
170 })
171 .collect()
172}
173
174/// Look up a generic type parameter in the function's generics.
175fn find_generic_bounds(sig: &Signature, type_name: &str) -> Option<Vec<String>> {
176 // Check inline bounds first (e.g., `fn foo<T: HasAvx2>(token: T)`)
177 for param in &sig.generics.params {
178 if let GenericParam::Type(type_param) = param {
179 if type_param.ident == type_name {
180 let traits = extract_trait_names_from_bounds(&type_param.bounds);
181 if !traits.is_empty() {
182 return Some(traits);
183 }
184 }
185 }
186 }
187
188 // Check where clause (e.g., `fn foo<T>(token: T) where T: HasAvx2`)
189 if let Some(where_clause) = &sig.generics.where_clause {
190 for predicate in &where_clause.predicates {
191 if let syn::WherePredicate::Type(pred_type) = predicate {
192 if let Type::Path(type_path) = &pred_type.bounded_ty {
193 if let Some(seg) = type_path.path.segments.last() {
194 if seg.ident == type_name {
195 let traits = extract_trait_names_from_bounds(&pred_type.bounds);
196 if !traits.is_empty() {
197 return Some(traits);
198 }
199 }
200 }
201 }
202 }
203 }
204 }
205
206 None
207}
208
209/// Convert trait names to features, collecting all features from all traits.
210fn traits_to_features(trait_names: &[String]) -> Option<Vec<&'static str>> {
211 let mut all_features = Vec::new();
212
213 for trait_name in trait_names {
214 if let Some(features) = trait_to_features(trait_name) {
215 for &feature in features {
216 if !all_features.contains(&feature) {
217 all_features.push(feature);
218 }
219 }
220 }
221 }
222
223 if all_features.is_empty() {
224 None
225 } else {
226 Some(all_features)
227 }
228}
229
230/// Find the first token parameter and return its name and features.
231fn find_token_param(sig: &Signature) -> Option<(Ident, Vec<&'static str>)> {
232 for arg in &sig.inputs {
233 match arg {
234 FnArg::Receiver(_) => {
235 // Self receivers (self, &self, &mut self) are not yet supported.
236 // The macro creates an inner function, and Rust's inner functions
237 // cannot have `self` parameters. Supporting this would require
238 // AST rewriting to replace `self` with a regular parameter.
239 // See the module docs for the workaround.
240 continue;
241 }
242 FnArg::Typed(PatType { pat, ty, .. }) => {
243 if let Some(info) = extract_token_type_info(ty) {
244 let features = match info {
245 TokenTypeInfo::Concrete(name) => {
246 token_to_features(&name).map(|f| f.to_vec())
247 }
248 TokenTypeInfo::ImplTrait(trait_names) => traits_to_features(&trait_names),
249 TokenTypeInfo::Generic(type_name) => {
250 // Look up the generic parameter's bounds
251 find_generic_bounds(sig, &type_name)
252 .and_then(|traits| traits_to_features(&traits))
253 }
254 };
255
256 if let Some(features) = features {
257 // Extract parameter name
258 if let syn::Pat::Ident(pat_ident) = pat.as_ref() {
259 return Some((pat_ident.ident.clone(), features));
260 }
261 }
262 }
263 }
264 }
265 }
266 None
267}
268
269/// Shared implementation for arcane/simd_fn macros.
270fn arcane_impl(input_fn: ItemFn, macro_name: &str, args: ArcaneArgs) -> TokenStream {
271 // Find the token parameter and its features
272 let (_token_ident, features) = match find_token_param(&input_fn.sig) {
273 Some(result) => result,
274 None => {
275 let msg = format!(
276 "{} requires a token parameter. Supported forms:\n\
277 - Concrete: `token: Avx2Token`\n\
278 - impl Trait: `token: impl HasAvx2`\n\
279 - Generic: `fn foo<T: HasAvx2>(token: T, ...)`\n\
280 Note: self receivers (&self, &mut self) are not yet supported.",
281 macro_name
282 );
283 return syn::Error::new_spanned(&input_fn.sig, msg)
284 .to_compile_error()
285 .into();
286 }
287 };
288
289 // Build target_feature attributes
290 let target_feature_attrs: Vec<Attribute> = features
291 .iter()
292 .map(|feature| parse_quote!(#[target_feature(enable = #feature)]))
293 .collect();
294
295 // Extract function components
296 let vis = &input_fn.vis;
297 let sig = &input_fn.sig;
298 let fn_name = &sig.ident;
299 let generics = &sig.generics;
300 let where_clause = &generics.where_clause;
301 let inputs = &sig.inputs;
302 let output = &sig.output;
303 let body = &input_fn.block;
304 let attrs = &input_fn.attrs;
305
306 // Build inner function parameters (ALL parameters including token)
307 let inner_params: Vec<_> = inputs.iter().cloned().collect();
308
309 // Build inner function call arguments (ALL arguments including token)
310 let inner_args: Vec<_> = inputs
311 .iter()
312 .filter_map(|arg| match arg {
313 FnArg::Typed(pat_type) => {
314 if let syn::Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
315 let ident = &pat_ident.ident;
316 Some(quote!(#ident))
317 } else {
318 None
319 }
320 }
321 FnArg::Receiver(_) => Some(quote!(self)),
322 })
323 .collect();
324
325 let inner_fn_name = format_ident!("__simd_inner_{}", fn_name);
326
327 // Choose inline attribute based on args
328 // Note: #[inline(always)] + #[target_feature] requires nightly with
329 // #![feature(target_feature_inline_always)]
330 let inline_attr: Attribute = if args.inline_always {
331 parse_quote!(#[inline(always)])
332 } else {
333 parse_quote!(#[inline])
334 };
335
336 // Generate the expanded function
337 let expanded = quote! {
338 #(#attrs)*
339 #vis #sig {
340 #(#target_feature_attrs)*
341 #inline_attr
342 unsafe fn #inner_fn_name #generics (#(#inner_params),*) #output #where_clause
343 #body
344
345 // SAFETY: The token parameter proves the required CPU features are available.
346 // Tokens can only be constructed when features are verified (via try_new()
347 // runtime check or forge_token_dangerously() in a context where features are guaranteed).
348 unsafe { #inner_fn_name(#(#inner_args),*) }
349 }
350 };
351
352 expanded.into()
353}
354
355/// Mark a function as an arcane SIMD function.
356///
357/// This macro enables safe use of SIMD intrinsics by generating an inner function
358/// with the appropriate `#[target_feature(enable = "...")]` attributes based on
359/// the token parameter type. The outer function calls the inner function unsafely,
360/// which is justified because the token parameter proves the features are available.
361///
362/// **The token is passed through to the inner function**, so you can call other
363/// token-taking functions from inside `#[arcane]`.
364///
365/// # Token Parameter Forms
366///
/// The macro supports three forms of token parameters (a fourth, methods with
/// `self` receivers, is described below but is not yet supported):
368///
369/// ## Concrete Token Types
370///
371/// ```ignore
372/// #[arcane]
373/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
374/// // AVX2 intrinsics safe here
375/// }
376/// ```
377///
378/// ## impl Trait Bounds
379///
380/// ```ignore
381/// #[arcane]
382/// fn process(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
383/// // Accepts any token that provides AVX2
384/// }
385/// ```
386///
387/// ## Generic Type Parameters
388///
389/// ```ignore
390/// #[arcane]
391/// fn process<T: HasAvx2>(token: T, data: &[f32; 8]) -> [f32; 8] {
392/// // Generic over any AVX2-capable token
393/// }
394///
395/// // Also works with where clauses:
396/// #[arcane]
397/// fn process<T>(token: T, data: &[f32; 8]) -> [f32; 8]
398/// where
399/// T: HasAvx2
400/// {
401/// // ...
402/// }
403/// ```
404///
405/// ## Methods with Self Receivers (NOT YET SUPPORTED)
406///
407/// Methods with `self`, `&self`, `&mut self` receivers are **not currently supported**.
408///
409/// **Why:** The macro works by creating an inner function with `#[target_feature]`.
410/// Rust's inner functions cannot have `self` parameters—`self` only works in
411/// associated functions. Supporting this would require rewriting the function body
412/// to replace `self` with a regular parameter, which adds significant complexity.
413///
414/// **Workaround:** Use a free function with the token as an explicit parameter:
415///
416/// ```ignore
417/// impl MyProcessor {
418/// fn process(&mut self, data: &[f32; 8]) -> [f32; 8] {
419/// // Delegate to a free function
420/// process_impl(self.token, data)
421/// }
422/// }
423///
424/// #[arcane]
425/// fn process_impl(token: impl HasAvx2, data: &[f32; 8]) -> [f32; 8] {
426/// // SIMD intrinsics safe here
427/// }
428/// ```
429///
430/// **Future work:** Supporting `self` receivers would require:
431/// 1. Adding a type parameter `__Self` to the inner function
432/// 2. Converting the receiver to a regular parameter (`&self` → `__self: &__Self`)
433/// 3. Walking the AST to replace all `self` with `__self` and `Self` with `__Self`
434/// 4. Copying where clauses with the type substitution
435///
436/// # Multiple Trait Bounds
437///
438/// When using `impl Trait` or generic bounds with multiple traits,
439/// all required features are enabled:
440///
441/// ```ignore
442/// #[arcane]
443/// fn fma_kernel(token: impl HasAvx2 + HasFma, data: &[f32; 8]) -> [f32; 8] {
444/// // Both AVX2 and FMA intrinsics are safe here
445/// }
446/// ```
447///
448/// # Expansion
449///
450/// The macro expands to approximately:
451///
452/// ```ignore
453/// fn process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
454/// #[target_feature(enable = "avx2")]
455/// #[inline]
456/// unsafe fn __simd_inner_process(token: Avx2Token, data: &[f32; 8]) -> [f32; 8] {
457/// let v = unsafe { _mm256_loadu_ps(data.as_ptr()) };
458/// let doubled = _mm256_add_ps(v, v);
459/// let mut out = [0.0f32; 8];
460/// unsafe { _mm256_storeu_ps(out.as_mut_ptr(), doubled) };
461/// out
462/// }
463/// // SAFETY: Token proves the required features are available
464/// unsafe { __simd_inner_process(token, data) }
465/// }
466/// ```
467///
468/// # Profile Tokens
469///
470/// Profile tokens automatically enable all required features:
471///
472/// ```ignore
473/// #[arcane]
474/// fn kernel(token: X64V3Token, data: &mut [f32]) {
475/// // AVX2 + FMA + BMI1 + BMI2 intrinsics all safe here!
476/// }
477/// ```
478///
479/// # Supported Tokens
480///
481/// - **x86_64**: `Sse2Token`, `Sse41Token`, `Sse42Token`, `AvxToken`, `Avx2Token`,
482/// `FmaToken`, `Avx2FmaToken`, `Avx512fToken`, `Avx512bwToken`
483/// - **x86_64 profiles**: `X64V2Token`, `X64V3Token`, `X64V4Token`
484/// - **ARM**: `NeonToken`, `SveToken`, `Sve2Token`
485/// - **WASM**: `Simd128Token`
486///
487/// # Supported Trait Bounds
488///
489/// - **x86_64**: `HasSse`, `HasSse2`, `HasSse41`, `HasSse42`, `HasAvx`, `HasAvx2`,
490/// `HasAvx512f`, `HasAvx512vl`, `HasAvx512bw`, `HasAvx512vbmi2`, `HasFma`
491/// - **ARM**: `HasNeon`, `HasSve`, `HasSve2`
492/// - **Generic**: `Has128BitSimd`, `Has256BitSimd`, `Has512BitSimd`
493///
494/// # Options
495///
496/// ## `inline_always`
497///
498/// Use `#[inline(always)]` instead of `#[inline]` for the inner function.
499/// This can improve performance by ensuring aggressive inlining, but requires
500/// nightly Rust with `#![feature(target_feature_inline_always)]` enabled in
501/// the crate using the macro.
502///
503/// ```ignore
504/// #![feature(target_feature_inline_always)]
505///
506/// #[arcane(inline_always)]
507/// fn fast_kernel(token: Avx2Token, data: &mut [f32]) {
508/// // Inner function will use #[inline(always)]
509/// }
510/// ```
511#[proc_macro_attribute]
512pub fn arcane(attr: TokenStream, item: TokenStream) -> TokenStream {
513 let args = parse_macro_input!(attr as ArcaneArgs);
514 let input_fn = parse_macro_input!(item as ItemFn);
515 arcane_impl(input_fn, "arcane", args)
516}
517
518/// Alias for [`arcane`] - mark a function as an arcane SIMD function.
519///
520/// See [`arcane`] for full documentation.
521#[proc_macro_attribute]
522pub fn simd_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
523 let args = parse_macro_input!(attr as ArcaneArgs);
524 let input_fn = parse_macro_input!(item as ItemFn);
525 arcane_impl(input_fn, "simd_fn", args)
526}