use std::collections::HashSet;
use oxc::ast::ast::{
BindingPattern, CallExpression, Decorator, Expression, JSXExpressionContainer, Program,
PropertyDefinition, TSTypeReference, VariableDeclarationKind, VariableDeclarator,
};
use oxc::ast_visit::{Visit, walk};
use oxc::span::GetSpan;
use crate::ts_syn::abi::{Diagnostic, DiagnosticLevel, Patch, PatchCode, SpanIR};
use crate::ts_syn::declarative::{MacroArm, MacroDef, MacroMode};
use super::BuildMode;
use super::discovery::{DiscoveredMacro, MACRO_RULES_IDENT, find_macro_rules_import_span};
use super::expander::{ExpansionContext, expand_body_with_registry};
use super::matcher::{MatchError, match_invocation_against_arms};
use super::megamorph::{
self, MegamorphReport, Recommendation, ResolvedCallSite, extract_type_shape,
};
use super::registry::DeclarativeMacroRegistry;
/// Optional hooks for routing `$name(...)` calls that have no declarative
/// macro definition in scope to procedural call macros.
pub struct ProcMacroFallback<'a> {
    /// Dispatcher asked first to run the procedural macro.
    pub(crate) dispatcher: &'a crate::host::MacroDispatcher,
    /// Import map keyed by the local binding name (including its leading `$`);
    /// the mapped value is used as the macro's module path.
    pub(crate) import_sources: &'a std::collections::HashMap<String, String>,
    /// Loader tried when the dispatcher's diagnostics report the macro as not found.
    pub(crate) external_loader: Option<&'a crate::host::expand::ExternalMacroLoader>,
}
/// Everything produced by a declarative-macro rewrite: source patches plus
/// the diagnostics gathered along the way.
#[derive(Debug, Default, Clone)]
pub struct RewriteOutput {
    /// Deletions of macro definitions/imports, runtime-helper insertions,
    /// call-site replacements, and patches forwarded from procedural macros.
    pub patches: Vec<Patch>,
    /// Analyzer notes/warnings, expansion and arm-matching errors, and
    /// diagnostics forwarded from procedural macros.
    pub diagnostics: Vec<Diagnostic>,
}
pub fn rewrite(
program: &Program<'_>,
source: &str,
registry: &DeclarativeMacroRegistry,
discovered: &[DiscoveredMacro],
build_mode: BuildMode,
type_registry: Option<&crate::ts_syn::abi::ir::type_registry::TypeRegistry>,
proc_fallback: Option<ProcMacroFallback<'_>>,
) -> RewriteOutput {
let mut out = RewriteOutput::default();
for dm in discovered {
out.patches.push(Patch::Delete { span: dm.def_span });
}
if let Some(import_span) = find_macro_rules_import_span(program) {
out.patches.push(Patch::Delete { span: import_span });
}
let has_auto = registry.iter().any(|(_, def)| def.mode == MacroMode::Auto);
let run_analyzer =
has_auto && (matches!(build_mode, BuildMode::Prod) || build_mode.force_share());
let megamorph_report = if run_analyzer {
if type_registry.is_none() {
out.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Info,
message: "declarative macro analyzer running without a type registry; structural clustering disabled, falling back to the name-prefix heuristic. Pass `type_registry_json` in ExpandOptions to enable field-level fingerprinting.".to_string(),
span: None,
notes: vec![],
help: None,
});
}
let mut collector = CollectVisitor {
registry,
type_registry,
sites: Vec::new(),
};
collector.visit_program(program);
Some(megamorph::analyze(registry, &collector.sites, 4))
} else {
None
};
if let Some(report) = &megamorph_report {
let emit_telemetry = build_mode.analyzer_telemetry();
for (name, info) in &report.per_macro {
match &info.recommendation {
Recommendation::Cluster(clusters) => {
out.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Warning,
message: format!(
"macro `${}` is called with {} distinct argument shapes; shared runtime would be megamorphic. Partitioned into {} clusters. Use `mode: \"share-anyway\"` to silence.",
name,
info.distinct_shapes,
clusters.len()
),
span: None,
notes: vec![],
help: None,
});
}
Recommendation::ForceExpand => {
out.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Warning,
message: format!(
"macro `${}` has {} distinct argument shapes, clustered too coarsely to share. Falling back to inline expansion at every call site.",
name, info.distinct_shapes
),
span: None,
notes: vec![],
help: None,
});
}
Recommendation::Share => {}
}
if emit_telemetry {
let decision = match &info.recommendation {
Recommendation::Share => "Share (single helper)".to_string(),
Recommendation::Cluster(clusters) => {
format!("Cluster (into {} partitions)", clusters.len())
}
Recommendation::ForceExpand => {
"ForceExpand (inline at every call site)".to_string()
}
};
out.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Info,
message: format!(
"analyzer decision for macro `${}`: {} distinct argument shapes → {}",
name, info.distinct_shapes, decision
),
span: None,
notes: vec![],
help: None,
});
}
}
}
let mut visitor = RewriteVisitor {
registry,
source,
output: &mut out,
counter: 0,
build_mode,
emitted_runtimes: HashSet::new(),
megamorph_report: megamorph_report.as_ref(),
type_rewritten: HashSet::new(),
type_registry,
proc_dispatcher: proc_fallback.as_ref().map(|f| f.dispatcher),
import_sources: proc_fallback.as_ref().map(|f| f.import_sources),
external_loader: proc_fallback.as_ref().and_then(|f| f.external_loader),
};
visitor.visit_program(program);
out
}
/// First-pass visitor that records each call to an `Auto`-mode declarative
/// macro, together with its argument type shapes, for the megamorphism analyzer.
struct CollectVisitor<'a> {
    /// Definitions used to resolve `$name` callees at a given position.
    registry: &'a DeclarativeMacroRegistry,
    /// Optional type information forwarded to `extract_type_shape`.
    type_registry: Option<&'a crate::ts_syn::abi::ir::type_registry::TypeRegistry>,
    /// Call sites collected so far, in visit order.
    sites: Vec<ResolvedCallSite>,
}
impl<'a> Visit<'a> for CollectVisitor<'_> {
fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
if let Expression::Identifier(callee) = &call.callee
&& let Some(name) = callee.name.as_str().strip_prefix('$')
&& let Some(def) = self.registry.lookup_at(name, call.span.start + 1)
&& def.mode == MacroMode::Auto
{
let arg_shapes: Vec<super::megamorph::TypeShape> = call
.arguments
.iter()
.map(|arg| extract_type_shape(arg, self.type_registry))
.collect();
self.sites.push(ResolvedCallSite {
macro_name: name.to_string(),
call_span: SpanIR::new(call.span.start + 1, call.span.end + 1),
arg_shapes,
});
}
walk::walk_call_expression(self, call);
}
fn visit_variable_declarator(&mut self, decl: &VariableDeclarator<'a>) {
if is_macro_definition_declarator(decl) {
return;
}
walk::walk_variable_declarator(self, decl);
}
}
/// Second-pass visitor that replaces macro call sites with their expansions
/// and accumulates patches/diagnostics into a borrowed `RewriteOutput`.
pub(super) struct RewriteVisitor<'a> {
    /// Definitions used to resolve `$name` callees at a given position.
    registry: &'a DeclarativeMacroRegistry,
    /// Full text of the file being rewritten.
    source: &'a str,
    /// Destination for emitted patches and diagnostics.
    output: &'a mut RewriteOutput,
    /// Last expansion id handed out by `next_id` (0 before the first).
    counter: u32,
    build_mode: BuildMode,
    /// `(macro name, cluster id)` pairs whose runtime helper was already
    /// inserted; an empty cluster id denotes the unspecialised helper.
    emitted_runtimes: HashSet<(String, String)>,
    /// Analyzer output, when the analyzer ran.
    megamorph_report: Option<&'a MegamorphReport>,
    /// `(start, end)` pairs recorded through `record_type_rewrite`.
    type_rewritten: HashSet<(u32, u32)>,
    type_registry: Option<&'a crate::ts_syn::abi::ir::type_registry::TypeRegistry>,
    /// Procedural-macro fallback hooks (see `ProcMacroFallback`).
    pub(super) proc_dispatcher: Option<&'a crate::host::MacroDispatcher>,
    pub(super) import_sources: Option<&'a std::collections::HashMap<String, String>>,
    pub(super) external_loader: Option<&'a crate::host::expand::ExternalMacroLoader>,
}
impl RewriteVisitor<'_> {
    /// Allocates the next expansion id; ids start at 1 and increase by one.
    fn next_id(&mut self) -> u32 {
        let id = self.counter + 1;
        self.counter = id;
        id
    }
}
/// Peels off any stack of parentheses and TypeScript expression wrappers
/// (`as`, `satisfies`, non-null `!`, `<T>` assertions) and returns the
/// innermost expression. Any other expression is returned unchanged.
fn unwrap_paren_and_casts<'b, 'a>(expr: &'b Expression<'a>) -> &'b Expression<'a> {
    let mut current = expr;
    loop {
        current = match current {
            Expression::ParenthesizedExpression(paren) => &paren.expression,
            Expression::TSAsExpression(cast) => &cast.expression,
            Expression::TSSatisfiesExpression(cast) => &cast.expression,
            Expression::TSNonNullExpression(cast) => &cast.expression,
            Expression::TSTypeAssertion(cast) => &cast.expression,
            innermost => return innermost,
        };
    }
}
impl<'a> Visit<'a> for RewriteVisitor<'_> {
    fn visit_variable_declarator(&mut self, decl: &VariableDeclarator<'a>) {
        // Macro definitions are deleted wholesale; never rewrite inside them.
        if !is_macro_definition_declarator(decl) {
            walk::walk_variable_declarator(self, decl);
        }
    }
    fn visit_expression_statement(&mut self, es: &oxc::ast::ast::ExpressionStatement<'a>) {
        // A bare (possibly parenthesised / cast) call in statement position is
        // expanded with the `Statement` context; otherwise fall back to a walk.
        let rewritten = match unwrap_paren_and_casts(&es.expression) {
            Expression::CallExpression(call) => {
                try_rewrite_call(call, self, ExpansionContext::Statement)
            }
            _ => false,
        };
        if !rewritten {
            walk::walk_expression_statement(self, es);
        }
    }
    fn visit_call_expression(&mut self, call: &CallExpression<'a>) {
        if !try_rewrite_call(call, self, ExpansionContext::Expression) {
            walk::walk_call_expression(self, call);
        }
    }
    fn visit_ts_type_reference(&mut self, tr: &TSTypeReference<'a>) {
        if !super::type_walker::try_rewrite_type_ref(tr, self) {
            walk::walk_ts_type_reference(self, tr);
        }
    }
    // The overrides below only delegate to the default walk.
    fn visit_jsx_expression_container(&mut self, node: &JSXExpressionContainer<'a>) {
        walk::walk_jsx_expression_container(self, node);
    }
    fn visit_decorator(&mut self, d: &Decorator<'a>) {
        walk::walk_decorator(self, d);
    }
    fn visit_property_definition(&mut self, p: &PropertyDefinition<'a>) {
        walk::walk_property_definition(self, p);
    }
}
/// Returns `true` for `const $name = <MACRO_RULES_IDENT>` declarators, where the
/// initializer is either a tagged template or a call whose tag/callee is the
/// plain identifier `MACRO_RULES_IDENT`.
fn is_macro_definition_declarator(d: &VariableDeclarator<'_>) -> bool {
    let binds_dollar_const = d.kind == VariableDeclarationKind::Const
        && matches!(
            &d.id,
            BindingPattern::BindingIdentifier(ident) if ident.name.as_str().starts_with('$')
        );
    if !binds_dollar_const {
        return false;
    }
    let head = match &d.init {
        Some(Expression::TaggedTemplateExpression(tagged)) => &tagged.tag,
        Some(Expression::CallExpression(call)) => &call.callee,
        _ => return false,
    };
    matches!(head, Expression::Identifier(ident) if ident.name.as_str() == MACRO_RULES_IDENT)
}
/// How a single declarative macro call is emitted (see `resolve_emission_strategy`).
enum EmissionPlan<'a> {
    /// Expand the definition's `arms` in place; no runtime helper is emitted.
    InlineExpand { arms: &'a [MacroArm] },
    /// Emit the definition's `runtime` once (unspecialised) and expand its
    /// call-site arms.
    ShareSingle { arms: &'a [MacroArm] },
    /// Emit a cluster-specialised `runtime` and expand the call-site arms.
    ShareClustered {
        /// The definition's call-site arms.
        arms: &'a [MacroArm],
        /// Analyzer clusters; the call's argument shapes select one of them.
        clusters: &'a [super::megamorph::TypeCluster],
    },
}
/// Decides how a call to `def` is emitted.
///
/// - `ExpandOnly`: always inline.
/// - `ShareOnly` / `ShareAnyway`: one shared helper when the definition has both
///   `call_arms` and a `runtime`, otherwise inline.
/// - `Auto`: inline unless the build is `Prod` or forces sharing. On the share
///   path it follows the analyzer's recommendation for `def.name` (a missing
///   entry is treated like `Share`), still falling back to inline when
///   `call_arms` or `runtime` is missing.
fn resolve_emission_strategy<'a>(
    def: &'a MacroDef,
    build_mode: BuildMode,
    report: Option<&'a MegamorphReport>,
) -> EmissionPlan<'a> {
    let inline = || EmissionPlan::InlineExpand {
        arms: def.arms.as_slice(),
    };
    // Sharing needs both call-site arms and a runtime to share.
    let shareable_arms = match (def.call_arms.as_deref(), &def.runtime) {
        (Some(call_arms), Some(_)) => Some(call_arms),
        _ => None,
    };
    match def.mode {
        MacroMode::ExpandOnly => inline(),
        MacroMode::ShareOnly | MacroMode::ShareAnyway => match shareable_arms {
            Some(arms) => EmissionPlan::ShareSingle { arms },
            None => inline(),
        },
        MacroMode::Auto => {
            let on_share_path =
                matches!(build_mode, BuildMode::Prod) || build_mode.force_share();
            if !on_share_path {
                return inline();
            }
            let recommendation = report
                .and_then(|r| r.lookup(&def.name))
                .map(|info| &info.recommendation);
            match (recommendation, shareable_arms) {
                (Some(Recommendation::ForceExpand), _) | (_, None) => inline(),
                (Some(Recommendation::Cluster(clusters)), Some(arms)) => {
                    EmissionPlan::ShareClustered {
                        arms,
                        clusters: clusters.as_slice(),
                    }
                }
                (Some(Recommendation::Share) | None, Some(arms)) => {
                    EmissionPlan::ShareSingle { arms }
                }
            }
        }
    }
}
/// Returns the id of the first cluster whose recorded shape tuples contain
/// exactly `arg_shapes`, or `None` when no cluster does.
fn resolve_cluster_id<'a>(
    clusters: &'a [super::megamorph::TypeCluster],
    arg_shapes: &[super::megamorph::TypeShape],
) -> Option<&'a str> {
    clusters
        .iter()
        .find(|cluster| {
            cluster
                .shapes
                .iter()
                .any(|tuple| tuple.as_slice() == arg_shapes)
        })
        .map(|cluster| &*cluster.id)
}
/// Derives the name of the cluster-specialised runtime helper.
///
/// With a `runtime_name_template`, every `$__cluster__` placeholder in it is
/// replaced by `cluster_id`. Without one, the name of the first *named*
/// `function` in `runtime_src` is suffixed with `__{cluster_id}`.
///
/// Returns `None` when there is no template and no named function is found.
fn specialize_helper_name(
    runtime_name_template: Option<&str>,
    runtime_src: &str,
    cluster_id: &str,
) -> Option<String> {
    if let Some(template) = runtime_name_template {
        return Some(template.replace("$__cluster__", cluster_id));
    }
    first_function_name(runtime_src).map(|base| format!("{}__{}", base, cluster_id))
}

/// ASCII bytes that may appear in a JavaScript identifier.
fn is_ascii_ident_byte(b: u8) -> bool {
    b.is_ascii_alphanumeric() || b == b'_' || b == b'$'
}

/// Finds the name of the first named `function` (declaration or expression,
/// including `function*` generators) in `src`.
///
/// This is a lightweight heuristic scan, not a parser. It skips `//` and
/// `/* */` comments and `'…'`, `"…"` and `` `…` `` literals (honouring
/// backslash escapes), so prose such as `/** helper function for … */` is not
/// mistaken for a declaration. It also requires `function` to be a whole word.
/// Template-literal `${…}` nesting and regex literals are not understood.
/// Only ASCII names are recognised.
fn first_function_name(src: &str) -> Option<&str> {
    const KEYWORD: &[u8] = b"function";
    // Non-ASCII bytes are treated as identifier characters for word boundaries.
    let in_word = |b: u8| !b.is_ascii() || is_ascii_ident_byte(b);
    let bytes = src.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        let rest = &bytes[i..];
        if rest.starts_with(b"//") {
            i = rest
                .iter()
                .position(|&b| b == b'\n')
                .map_or(bytes.len(), |nl| i + nl + 1);
            continue;
        }
        if rest.starts_with(b"/*") {
            i = rest[2..]
                .windows(2)
                .position(|w| w == b"*/")
                .map_or(bytes.len(), |end| i + 2 + end + 2);
            continue;
        }
        if matches!(rest[0], b'"' | b'\'' | b'`') {
            i = skip_string_literal(bytes, i);
            continue;
        }
        let starts_word = i == 0 || !in_word(bytes[i - 1]);
        if starts_word && rest.starts_with(KEYWORD) {
            let keyword_end = i + KEYWORD.len();
            let ends_word = !bytes.get(keyword_end).is_some_and(|&b| in_word(b));
            if ends_word {
                let mut j = keyword_end;
                while j < bytes.len() && (bytes[j].is_ascii_whitespace() || bytes[j] == b'*') {
                    j += 1;
                }
                let start = j;
                while j < bytes.len() && is_ascii_ident_byte(bytes[j]) {
                    j += 1;
                }
                if j > start {
                    // `start` and `j` both follow ASCII bytes, so they are char boundaries.
                    return Some(&src[start..j]);
                }
            }
        }
        i += 1;
    }
    None
}

/// Given that `bytes[open]` is a quote byte, returns the index just past the
/// matching closing quote, honouring backslash escapes. Returns `bytes.len()`
/// when the literal is unterminated.
fn skip_string_literal(bytes: &[u8], open: usize) -> usize {
    let quote = bytes[open];
    let mut k = open + 1;
    while k < bytes.len() {
        match bytes[k] {
            b'\\' => k += 2,
            b if b == quote => return k + 1,
            _ => k += 1,
        }
    }
    bytes.len()
}
pub(super) fn try_rewrite_call(
call: &oxc::ast::ast::CallExpression<'_>,
visitor: &mut RewriteVisitor<'_>,
context: ExpansionContext,
) -> bool {
let Expression::Identifier(callee) = &call.callee else {
return false;
};
let callee_name = callee.name.as_str();
if !callee_name.starts_with('$') {
return false;
}
let name = &callee_name[1..];
let call_pos = call.span.start + 1;
let def_arc = visitor.registry.lookup_at(name, call_pos);
if def_arc.is_none() {
return try_dispatch_proc_call(call, callee_name, visitor, context);
}
let def_arc = def_arc.unwrap();
let def = def_arc.as_ref();
let plan = resolve_emission_strategy(def, visitor.build_mode, visitor.megamorph_report);
let (arms, cluster_id): (&[MacroArm], Option<String>) = match &plan {
EmissionPlan::InlineExpand { arms } => (*arms, None),
EmissionPlan::ShareSingle { arms } => (*arms, Some(String::new())),
EmissionPlan::ShareClustered { arms, clusters } => {
let arg_shapes: Vec<super::megamorph::TypeShape> = call
.arguments
.iter()
.map(|arg| extract_type_shape(arg, visitor.type_registry))
.collect();
let resolved = resolve_cluster_id(clusters, &arg_shapes).map(|s| s.to_string());
if resolved.is_none() {
visitor.output.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Warning,
message: format!(
"macro `${}` call site's argument shape did not match any cluster; falling back to a single shared helper. This is usually a sign of stale analysis data.",
name
),
span: Some(SpanIR::new(call.span.start + 1, call.span.end + 1)),
notes: vec![],
help: None,
});
(*arms, Some(String::new()))
} else {
(*arms, resolved)
}
}
};
if let Some(cluster_id_str) = cluster_id.as_deref()
&& let Some(runtime_src) = def.runtime.as_deref()
&& !matches!(plan, EmissionPlan::InlineExpand { .. })
{
let dedup_key = (name.to_string(), cluster_id_str.to_string());
if !visitor.emitted_runtimes.contains(&dedup_key) {
let runtime_emit = if cluster_id_str.is_empty() {
runtime_src.to_string()
} else {
let specialized = specialize_helper_name(
def.runtime_name_template.as_deref(),
runtime_src,
cluster_id_str,
);
if let (Some(template), Some(specialized_name)) =
(def.runtime_name_template.as_deref(), specialized.as_deref())
{
let base = template.replace("$__cluster__", "");
let base_trimmed = base.trim_matches('_');
if !base_trimmed.is_empty() {
runtime_src
.replace(base_trimmed, specialized_name)
.replace("$__cluster__", cluster_id_str)
} else {
runtime_src.replace("$__cluster__", cluster_id_str)
}
} else {
runtime_src.replace("$__cluster__", cluster_id_str)
}
};
visitor.output.patches.push(Patch::Insert {
at: SpanIR::new(1, 1),
code: PatchCode::Text(format!("{}\n", runtime_emit.trim())),
source_macro: Some(format_attribution(name, cluster_id_str)),
});
visitor.emitted_runtimes.insert(dedup_key);
}
}
let expander_cluster_id: Option<&str> = cluster_id
.as_deref()
.and_then(|s| if s.is_empty() { None } else { Some(s) });
match match_invocation_against_arms(arms, &call.arguments, visitor.source) {
Ok((arm_index, bindings)) => {
let arm = &arms[arm_index];
let expansion_id = visitor.next_id();
match expand_body_with_registry(
&arm.body,
&bindings,
expansion_id,
context,
0,
Some(visitor.registry),
expander_cluster_id,
) {
Ok(expanded) => {
let span = call.span;
let span_ir = SpanIR::new(span.start + 1, span.end + 1);
let cluster_attr = cluster_id.as_deref().unwrap_or("");
visitor.output.patches.push(Patch::Replace {
span: span_ir,
code: PatchCode::Text(expanded),
source_macro: Some(format_attribution(name, cluster_attr)),
});
true
}
Err(e) => {
let span = call.span();
visitor.output.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Error,
message: format!("error expanding macro `${}`: {}", name, e),
span: Some(SpanIR::new(span.start + 1, span.end + 1)),
notes: vec![],
help: None,
});
true
}
}
}
Err(match_err) => {
let span = call.span();
let help = match &match_err {
MatchError::NoArmMatched { tried } => {
if tried.is_empty() {
None
} else {
Some(format!("tried patterns: {}", tried.join(" | ")))
}
}
_ => None,
};
visitor.output.diagnostics.push(Diagnostic {
level: DiagnosticLevel::Error,
message: format!(
"macro `${}` invocation did not match any arm: {}",
name, match_err
),
span: Some(SpanIR::new(span.start + 1, span.end + 1)),
notes: vec![],
help,
});
true
}
}
}
/// Formats the `source_macro` attribution attached to emitted patches:
/// `$name`, or `$name@cluster` when `cluster_id` is non-empty.
pub(super) fn format_attribution(name: &str, cluster_id: &str) -> String {
    match cluster_id {
        "" => format!("${name}"),
        cluster => format!("${name}@{cluster}"),
    }
}
/// Fallback for `$name(...)` calls with no declarative macro in scope.
///
/// The call is routed to a procedural call macro through the configured
/// dispatcher. If the dispatcher's diagnostics report the macro as not found,
/// the external loader is tried instead.
///
/// The module path comes from `import_sources`, keyed by the `$`-prefixed
/// callee name. Failing that, `@macro/derive` is used for a built-in macro
/// whose descriptor kind is `Call`.
///
/// Returns `false`, leaving the call untouched and emitting nothing, when any of
/// these hold:
/// - no dispatcher is configured;
/// - the module path cannot be resolved;
/// - the call's span does not index `visitor.source`;
/// - the macro is unknown to both the dispatcher and the loader.
///
/// Otherwise any returned tokens replace the call, runtime patches and
/// diagnostics are forwarded, and `true` is returned.
fn try_dispatch_proc_call(
    call: &CallExpression<'_>,
    callee_name: &str,
    visitor: &mut RewriteVisitor<'_>,
    _context: ExpansionContext,
) -> bool {
    let Some(dispatcher) = visitor.proc_dispatcher else {
        return false;
    };
    let name_without_dollar = callee_name.strip_prefix('$').unwrap_or(callee_name);
    let module_path = if let Some(sources) = visitor.import_sources
        && let Some(mp) = sources.get(callee_name)
    {
        mp.clone()
    } else if let Some(desc) = crate::host::derived::lookup_by_name(name_without_dollar)
        && desc.kind == crate::ts_syn::abi::MacroKind::Call
    {
        "@macro/derive".to_string()
    } else {
        return false;
    };
    // oxc spans are byte offsets into `source`. `get` declines the call instead
    // of panicking if the span does not fit (or split a UTF-8 character).
    let Some(call_source) = visitor
        .source
        .get(call.span.start as usize..call.span.end as usize)
    else {
        return false;
    };
    let args_source = call_arguments_source(call_source);
    let call_span = SpanIR::new(call.span.start + 1, call.span.end + 1);
    let ctx = crate::ts_syn::abi::MacroContextIR {
        abi_version: 1,
        macro_kind: crate::ts_syn::abi::MacroKind::Call,
        macro_name: name_without_dollar.to_string(),
        module_path,
        decorator_span: call_span,
        macro_name_span: None,
        target_span: call_span,
        file_name: String::new(),
        target: crate::ts_syn::abi::TargetIR::Other,
        target_source: args_source,
        import_registry: crate::ts_syn::ImportRegistry::new(),
        config: None,
        type_registry: None,
        resolved_fields: None,
    };
    let mut result = dispatcher.dispatch(ctx.clone());
    // NOTE(review): "not found" is detected by matching diagnostic wording. This
    // silently breaks if the dispatcher's messages change; a structured error
    // kind would be more robust.
    let is_not_found = result.diagnostics.iter().any(|d| {
        d.message.contains("Macro")
            && (d.message.contains("not found") || d.message.contains("is not a Macroforge"))
    });
    if is_not_found {
        let Some(loader) = visitor.external_loader else {
            return false;
        };
        match loader.run_macro(&ctx) {
            Ok(external_result) => result = external_result,
            Err(_) => return false,
        }
    }
    if let Some(tokens) = result.tokens {
        visitor.output.patches.push(Patch::Replace {
            span: call_span,
            code: PatchCode::Text(tokens),
            source_macro: Some(format!("${}", name_without_dollar)),
        });
    }
    visitor.output.patches.extend(result.runtime_patches);
    visitor.output.diagnostics.extend(result.diagnostics);
    true
}

/// Returns the raw text between the first `(` and the last `)` of `call_source`.
///
/// If there is no `)`, returns everything after the first `(`; if there is no
/// `(`, returns an empty string.
///
/// NOTE(review): the first `(` can fall inside explicit type arguments, e.g.
/// `$m<(a: T) => U>(x)`; such calls would be sliced incorrectly.
fn call_arguments_source(call_source: &str) -> String {
    let Some(open) = call_source.find('(') else {
        return String::new();
    };
    let inner = &call_source[open + 1..];
    match inner.rfind(')') {
        Some(close) => inner[..close].to_string(),
        None => inner.to_string(),
    }
}
/// Narrow accessors used by sibling modules (e.g. the type walker) that
/// cannot touch the visitor's private fields directly.
impl<'a> RewriteVisitor<'a> {
    /// Registry used to resolve `$name` macros, borrowed for the visitor's lifetime.
    pub(super) fn registry(&self) -> &'a DeclarativeMacroRegistry {
        self.registry
    }

    /// Full text of the file being rewritten.
    pub(super) fn source(&self) -> &'a str {
        self.source
    }

    /// Mutable access to the accumulated patches and diagnostics.
    pub(super) fn output_mut(&mut self) -> &mut RewriteOutput {
        &mut *self.output
    }

    /// Allocates an expansion id from the same counter used for call sites.
    pub(super) fn next_expansion_id(&mut self) -> u32 {
        self.next_id()
    }

    /// Adds `(start, end)` to the set of rewritten type spans; returns `true`
    /// only the first time a given pair is recorded.
    pub(super) fn record_type_rewrite(&mut self, start: u32, end: u32) -> bool {
        self.type_rewritten.insert((start, end))
    }
}