use proc_macro2::{Delimiter, Ident, Spacing, TokenStream, TokenTree};
use quote::{format_ident, quote};
use crate::common::suffix_path;
use crate::incant::IncantInput;
use crate::tiers::{self, DEFAULT_TIER_NAMES, ResolvedTier};
/// What the rewriter knows about the function whose body is being rewritten:
/// which tier it was compiled for and which token variable it holds.
#[derive(Clone)]
pub(crate) struct CallerContext {
    /// Identifier of the caller's token variable. It is passed as-is to an
    /// exact-tier call, or through a `.<suffix>()` method for a downgraded call.
    pub token_ident: Ident,
    /// Suffix of the caller's own tier (e.g. `"v3"`), compared against the
    /// suffix of every resolved tier.
    pub tier_suffix: String,
    /// Architecture of the caller's tier (e.g. `Some("x86_64")`). Tiers whose
    /// `target_arch` differs are ignored; `None` only matches tiers whose
    /// `target_arch` is also `None`.
    pub target_arch: Option<&'static str>,
}
pub(crate) fn rewrite_incant_in_body(body: TokenStream, ctx: &CallerContext) -> TokenStream {
let tokens: Vec<TokenTree> = body.into_iter().collect();
let mut result = Vec::new();
let mut i = 0;
while i < tokens.len() {
if is_ident(&tokens[i], "fn") {
result.push(tokens[i].clone());
i += 1;
while i < tokens.len() {
let is_body =
matches!(&tokens[i], TokenTree::Group(g) if g.delimiter() == Delimiter::Brace);
result.push(tokens[i].clone());
i += 1;
if is_body {
break;
}
}
continue;
}
if is_ident(&tokens[i], "incant")
&& i + 2 < tokens.len()
&& is_punct(&tokens[i + 1], '!')
&& let Some(TokenTree::Group(group)) = tokens.get(i + 2)
&& group.delimiter() == Delimiter::Parenthesis
{
let inner = group.stream();
if let Ok(input) = syn::parse2::<IncantInput>(inner) {
let rewritten = rewrite_single_incant(&input, ctx);
result.extend(rewritten);
i += 3; continue;
}
}
if let TokenTree::Group(group) = &tokens[i] {
let inner = rewrite_incant_in_body(group.stream(), ctx);
let mut new_group = proc_macro2::Group::new(group.delimiter(), inner);
new_group.set_span(group.span());
result.push(TokenTree::Group(new_group));
i += 1;
continue;
}
result.push(tokens[i].clone());
i += 1;
}
result.into_iter().collect()
}
/// Expands one `incant!` invocation for a caller that already holds the token
/// named by `ctx.token_ident` for tier `ctx.tier_suffix`.
///
/// The expansion consists of:
///
/// 1. one `summon()` attempt per tier that is neither the caller's tier nor
///    reachable by downgrading the caller's token, highest priority first. Each
///    attempt breaks out of a labeled block with `func_<suffix>(..)`.
///    Feature-gated tiers are wrapped in `#[cfg(feature = "...")]`.
/// 2. a fallback call needing no runtime check, chosen in this order:
///    - the tier matching the caller exactly, with the token passed as-is;
///    - the highest-priority tier the token can be downgraded to
///      (`token.<suffix>()`);
///    - `func_default(..)` with token arguments dropped;
///    - `func_scalar(..)`.
///
/// Only tiers for `ctx.target_arch` take part in step 1 and the direct fallback.
/// Invocations that already name a token (`with ...`), or whose tier list fails
/// to resolve, are re-emitted unchanged via `reconstruct_incant`. The real
/// `incant!` then handles them or reports the error.
fn rewrite_single_incant(input: &IncantInput, ctx: &CallerContext) -> TokenStream {
    if input.with_token.is_some() {
        return reconstruct_incant(input);
    }
    let func_path = &input.func_path;
    let args = &input.args;
    let tier_names: Vec<String> = match &input.tiers {
        Some((names, _)) => names.clone(),
        None => DEFAULT_TIER_NAMES.iter().map(|s| s.to_string()).collect(),
    };
    let error_span = proc_macro2::Span::call_site();
    let resolved = match tiers::resolve_tiers(&tier_names, error_span, true) {
        Ok(t) => t,
        Err(_) => return reconstruct_incant(input),
    };

    // Classify each SIMD tier for the caller's architecture relative to the caller.
    let mut exact: Option<&ResolvedTier> = None;
    let mut best_downgrade: Option<&ResolvedTier> = None;
    let mut upgrade_tiers: Vec<&ResolvedTier> = Vec::new();
    for rt in &resolved {
        // `scalar`/`default` are handled by the fallback; other arches never apply.
        if rt.name == "scalar" || rt.name == "default" || rt.target_arch != ctx.target_arch {
            continue;
        }
        if rt.suffix == ctx.tier_suffix {
            if exact.is_none() {
                exact = Some(rt);
            }
        } else if crate::generated::can_downgrade_tier(&ctx.tier_suffix, rt.suffix) {
            // Keep the most capable reachable tier regardless of list order;
            // on equal priority the first-listed tier wins.
            if best_downgrade.is_none_or(|best| rt.priority > best.priority) {
                best_downgrade = Some(rt);
            }
        } else {
            upgrade_tiers.push(rt);
        }
    }
    // An exact match needs no token conversion, so it always beats a downgrade.
    let direct_tier = exact.or(best_downgrade);
    upgrade_tiers.sort_by_key(|rt| core::cmp::Reverse(rt.priority));

    let token_ident = &ctx.token_ident;
    let caller_ident = token_ident.to_string();
    // Binding introduced by each `if let Some(__t) = ...::summon()` below.
    let summoned = quote! { __t };
    let upgrade_arms: Vec<TokenStream> = upgrade_tiers
        .iter()
        .map(|rt| {
            let fn_suffixed = suffix_path(func_path, rt.suffix);
            let token_path: syn::Path = syn::parse_str(rt.token_path)
                .expect("tier table token_path must be a valid Rust path");
            let call_args =
                crate::common::build_call_args_with_ident(args, &summoned, Some(&caller_ident));
            let check = quote! {
                if let Some(__t) = #token_path::summon() {
                    break '__incant_rewrite #fn_suffixed(#call_args);
                }
            };
            match &rt.feature_gate {
                Some(feat) => {
                    let allow_attr = if rt.allow_unexpected_cfg {
                        quote! { #[allow(unexpected_cfgs)] }
                    } else {
                        quote! {}
                    };
                    quote! {
                        #allow_attr
                        #[cfg(feature = #feat)]
                        { #check }
                    }
                }
                None => check,
            }
        })
        .collect();

    let fallback_call = if let Some(rt) = direct_tier {
        let fn_suffixed = suffix_path(func_path, rt.suffix);
        let token_expr = if rt.suffix == ctx.tier_suffix {
            quote! { #token_ident }
        } else {
            // Downgrade through the method named after the target tier, e.g. `t.v3()`.
            let downgrade_method = format_ident!("{}", rt.suffix);
            quote! { #token_ident.#downgrade_method() }
        };
        let call_args =
            crate::common::build_call_args_with_ident(args, &token_expr, Some(&caller_ident));
        quote! { #fn_suffixed(#call_args) }
    } else if resolved.iter().any(|t| t.name == "default") {
        let fn_default = suffix_path(func_path, "default");
        // Drop bare `Token` / caller-token arguments from the `_default` call.
        let default_args: Vec<&syn::Expr> = args
            .iter()
            .filter(|a| {
                !crate::common::is_bare_ident_pub(a, "Token")
                    && !crate::common::is_bare_ident_pub(a, &caller_ident)
            })
            .collect();
        quote! { #fn_default(#(#default_args),*) }
    } else {
        let fn_scalar = suffix_path(func_path, "scalar");
        let scalar_args = crate::common::build_scalar_call_args(args);
        quote! { #fn_scalar(#scalar_args) }
    };

    if upgrade_arms.is_empty() {
        fallback_call
    } else {
        quote! {
            '__incant_rewrite: {
                use archmage::SimdToken;
                #(#upgrade_arms)*
                #fallback_call
            }
        }
    }
}
/// Re-emits an `incant!` invocation as `archmage::incant!(...)`, preserving the
/// function call, any `with` token and the explicit tier list. The real macro
/// then handles (or reports errors for) cases the rewriter does not.
fn reconstruct_incant(input: &IncantInput) -> TokenStream {
    let func_path = &input.func_path;
    let args = &input.args;
    let with_part = input.with_token.as_ref().map(|t| quote! { with #t });
    let tier_part = input.tiers.as_ref().map(|(names, _)| {
        // `resolve_tiers` derives feature gates from these strings alone, so an
        // entry may hold more than a bare identifier (presumably something like
        // `v4(cfg(avx512))` — confirm against the IncantInput parser).
        // `format_ident!` would panic on such text, so re-lex it as tokens. Fall
        // back to the identifier form only when the text does not lex at all.
        let tier_tokens: Vec<TokenStream> = names
            .iter()
            .map(|name| {
                name.parse::<TokenStream>().unwrap_or_else(|_| {
                    let ident = format_ident!("{}", name);
                    quote! { #ident }
                })
            })
            .collect();
        quote! { , [#(#tier_tokens),*] }
    });
    quote! {
        archmage::incant!(#func_path(#(#args),*) #with_part #tier_part)
    }
}
/// Returns `true` if `tt` is an identifier (keywords included) spelled `name`.
fn is_ident(tt: &TokenTree, name: &str) -> bool {
    if let TokenTree::Ident(ident) = tt {
        *ident == name
    } else {
        false
    }
}
/// Returns `true` if `tt` is the single punctuation character `ch` with
/// `Spacing::Alone`, i.e. not the first half of a multi-character operator.
fn is_punct(tt: &TokenTree, ch: char) -> bool {
    match tt {
        TokenTree::Punct(punct) => punct.spacing() == Spacing::Alone && punct.as_char() == ch,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Builds the context of a caller compiled for `tier` on `arch`, whose token
    /// variable is named `__token`.
    fn make_ctx(tier: &str, arch: Option<&'static str>) -> CallerContext {
        CallerContext {
            token_ident: format_ident!("__token"),
            target_arch: arch,
            tier_suffix: String::from(tier),
        }
    }

    /// Rewrites `body` for an x86_64 caller at `tier` and renders the output as text.
    fn rewrite_for(tier: &str, body: TokenStream) -> String {
        rewrite_incant_in_body(body, &make_ctx(tier, Some("x86_64"))).to_string()
    }

    #[test]
    fn skips_inner_fn_items() {
        let result = rewrite_for(
            "v3",
            quote! {
                fn inner() {
                    incant!(process(data))
                }
                incant!(outer_call(data))
            },
        );
        assert!(
            result.contains("incant ! (process (data))"),
            "inner fn incant! should be preserved, got: {result}"
        );
        assert!(
            result.contains("outer_call_v3"),
            "outer incant! should be rewritten, got: {result}"
        );
    }

    #[test]
    fn exact_tier_match_no_summon() {
        let result = rewrite_for(
            "v3",
            quote! {
                let x = incant!(process(data), [v3, scalar]);
            },
        );
        assert!(
            result.contains("process_v3"),
            "should call process_v3, got: {result}"
        );
        assert!(
            result.contains("__token"),
            "should pass token, got: {result}"
        );
        assert!(
            !result.contains("summon"),
            "should not summon for exact match, got: {result}"
        );
    }

    #[test]
    fn upgrade_attempt_with_summon() {
        let result = rewrite_for(
            "v3",
            quote! {
                let x = incant!(process(data), [v4, v3, scalar]);
            },
        );
        assert!(
            result.contains("summon"),
            "should summon for v4 upgrade, got: {result}"
        );
        assert!(
            result.contains("process_v4"),
            "should try process_v4, got: {result}"
        );
        assert!(
            result.contains("process_v3"),
            "should fall back to process_v3, got: {result}"
        );
        assert!(
            result.contains("__token"),
            "should pass token for v3, got: {result}"
        );
    }

    #[test]
    fn upgrade_with_feature_gate() {
        let result = rewrite_for(
            "v3",
            quote! {
                let x = incant!(process(data), [v4(cfg(avx512)), v3, scalar]);
            },
        );
        assert!(
            result.contains("avx512"),
            "v4 upgrade should be gated on avx512, got: {result}"
        );
        assert!(
            result.contains("summon"),
            "should summon for v4 upgrade, got: {result}"
        );
        assert!(
            result.contains("process_v3"),
            "should fall back to process_v3, got: {result}"
        );
    }

    #[test]
    fn scalar_fallback_when_no_matching_tier() {
        let result = rewrite_for(
            "v3",
            quote! {
                let x = incant!(process(data), [neon, scalar]);
            },
        );
        assert!(
            result.contains("process_scalar"),
            "should fall through to scalar, got: {result}"
        );
        assert!(
            result.contains("ScalarToken"),
            "should use ScalarToken, got: {result}"
        );
    }

    #[test]
    fn passthrough_not_rewritten() {
        let result = rewrite_for(
            "v3",
            quote! {
                let x = incant!(process(data) with token);
            },
        );
        assert!(
            result.contains("incant"),
            "passthrough should be preserved, got: {result}"
        );
    }

    #[test]
    fn downgrade_uses_method() {
        let result = rewrite_for(
            "v4",
            quote! {
                let x = incant!(process(data), [v3, scalar]);
            },
        );
        assert!(
            result.contains("process_v3"),
            "should call process_v3, got: {result}"
        );
        assert!(
            result.contains("__token . v3 ()"),
            "should downgrade token, got: {result}"
        );
        assert!(
            !result.contains("summon"),
            "should not summon for downgrade, got: {result}"
        );
    }
}