use std::collections::HashMap;
use crate::ts_syn::declarative::{Body, BodyToken};
use super::hygiene;
use super::matcher::{Binding, match_invocation_against_arms};
use super::registry::DeclarativeMacroRegistry;
/// Syntactic position that an expansion's output is being produced for.
///
/// `maybe_wrap_iife` consults this value: only `Expression` output that is
/// shaped like a `{ ... }` block is wrapped in an immediately-invoked arrow
/// function; `Statement` and `Type` output is returned unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpansionContext {
/// Expression position; block-shaped output is wrapped in an IIFE.
Expression,
/// Statement position; output is passed through verbatim.
Statement,
/// Type position; output is passed through verbatim.
Type,
}
/// Errors produced while rendering a declarative macro body.
#[derive(Debug, Clone)]
pub enum ExpandError {
    /// A substitution named a metavariable that has no binding.
    UnboundName(String),
    /// A plain substitution named a metavariable that is bound to a sequence.
    WrongBindingShape(String),
    /// Two sequence bindings used in the same repetition have different
    /// lengths (first length seen, conflicting length).
    InconsistentSequenceLength(usize, usize),
    /// A repetition mentions no sequence-bound metavariable, so the number of
    /// iterations cannot be determined.
    UnanchoredRepetition,
    /// Expansion nested deeper than the given limit.
    RecursionLimit(u32),
    /// The body called a macro that could not be resolved (no registry was
    /// available, or the registry has no entry under that name).
    UnknownMacroCall(String),
    /// The rendered argument list of a nested macro call could not be parsed
    /// or matched for a reason other than "no arm matched".
    MalformedMacroCallArgs { callee: String, reason: String },
    /// A nested macro call parsed, but none of the callee's arms matched.
    NestedMatchFailure { callee: String, tried: Vec<String> },
}

impl std::fmt::Display for ExpandError {
    /// Writes the human-readable diagnostic for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnboundName(name) => {
                write!(f, "unbound macro metavariable `${name}`")
            }
            Self::WrongBindingShape(name) => {
                write!(
                    f,
                    "metavariable `${name}` has the wrong binding shape (single vs sequence)"
                )
            }
            Self::InconsistentSequenceLength(first, second) => {
                write!(
                    f,
                    "sequence bindings in the same repetition have different lengths ({first} vs {second})"
                )
            }
            Self::UnanchoredRepetition => f.write_str(
                "repetition in body mentions no sequence-bound metavariable; cannot infer length",
            ),
            Self::RecursionLimit(limit) => {
                write!(
                    f,
                    "macro expansion exceeded the recursion limit of {limit} levels — did a macro call itself?"
                )
            }
            Self::UnknownMacroCall(name) => {
                write!(
                    f,
                    "macro body calls unknown macro `${name}` — not registered or out of scope"
                )
            }
            Self::MalformedMacroCallArgs { callee, reason } => {
                write!(
                    f,
                    "macro body calls `${callee}` but its argument list failed to parse: {reason}"
                )
            }
            Self::NestedMatchFailure { callee, tried } => {
                let attempts = tried.join(" | ");
                write!(
                    f,
                    "nested call to `${callee}` did not match any arm (tried: {attempts})"
                )
            }
        }
    }
}

impl std::error::Error for ExpandError {}
/// Deepest nesting level the expander accepts. `expand_body_with_registry`,
/// `render_tokens` and `expand_repetition` each return
/// `ExpandError::RecursionLimit` once their `depth` argument exceeds this value.
pub const MAX_EXPANSION_DEPTH: u32 = 256;
/// Expands `body` using `bindings`, without a macro registry or cluster id.
///
/// This is [`expand_body_with_registry`] with both optional arguments left
/// out. Because no registry is supplied, a nested macro call inside `body`
/// fails with [`ExpandError::UnknownMacroCall`].
///
/// # Errors
///
/// Propagates every error of [`expand_body_with_registry`].
pub fn expand_body(
    body: &Body,
    bindings: &HashMap<String, Binding>,
    expansion_id: u32,
    context: ExpansionContext,
    depth: u32,
) -> Result<String, ExpandError> {
    let no_registry = None;
    let no_cluster = None;
    expand_body_with_registry(
        body,
        bindings,
        expansion_id,
        context,
        depth,
        no_registry,
        no_cluster,
    )
}
/// Expands `body` into source text.
///
/// Steps, in order:
/// 1. Reject the call if `depth` exceeds [`MAX_EXPANSION_DEPTH`].
/// 2. If `cluster_id` is given, render against a copy of `bindings` that also
///    binds `__cluster__` to that id as a single `Ident` fragment with a
///    `(0, 0)` span.
/// 3. Render the body tokens, resolving nested macro calls through `registry`.
/// 4. Run the hygiene rewrite for `expansion_id`, then apply the IIFE wrapping
///    rule for `context`.
///
/// # Errors
///
/// Returns [`ExpandError::RecursionLimit`] when too deep, or any error raised
/// while rendering the body tokens.
pub fn expand_body_with_registry(
    body: &Body,
    bindings: &HashMap<String, Binding>,
    expansion_id: u32,
    context: ExpansionContext,
    depth: u32,
    registry: Option<&DeclarativeMacroRegistry>,
    cluster_id: Option<&str>,
) -> Result<String, ExpandError> {
    if depth > MAX_EXPANSION_DEPTH {
        return Err(ExpandError::RecursionLimit(MAX_EXPANSION_DEPTH));
    }

    // The caller's map is only copied when a cluster id has to be injected.
    let with_cluster: Option<HashMap<String, Binding>> = cluster_id.map(|id| {
        let mut augmented = bindings.clone();
        augmented.insert(
            String::from("__cluster__"),
            Binding::Single(super::matcher::BoundFragment {
                kind: crate::ts_syn::declarative::FragmentKind::Ident,
                source: id.to_owned(),
                span: crate::ts_syn::abi::SpanIR::new(0, 0),
            }),
        );
        augmented
    });
    let active_bindings = with_cluster.as_ref().unwrap_or(bindings);

    let mut rendered = String::new();
    render_tokens(
        &body.0,
        active_bindings,
        expansion_id,
        depth,
        registry,
        &mut rendered,
    )?;

    let hygienic = rewrite_hygiene(rendered, expansion_id);
    Ok(maybe_wrap_iife(hygienic, context))
}
/// Appends the rendering of `tokens` to `out`.
///
/// * Literals are copied verbatim.
/// * A substitution emits the source text of its single binding.
/// * Macro calls and repetitions are delegated to `expand_macro_call` and
///   `expand_repetition`, which append to the same buffer.
///
/// # Errors
///
/// * [`ExpandError::RecursionLimit`] if `depth` exceeds [`MAX_EXPANSION_DEPTH`].
/// * [`ExpandError::UnboundName`] if a substitution has no binding.
/// * [`ExpandError::WrongBindingShape`] if a substitution names a sequence
///   binding.
/// * Any error from a nested macro call or repetition.
fn render_tokens(
    tokens: &[BodyToken],
    bindings: &HashMap<String, Binding>,
    expansion_id: u32,
    depth: u32,
    registry: Option<&DeclarativeMacroRegistry>,
    out: &mut String,
) -> Result<(), ExpandError> {
    if depth > MAX_EXPANSION_DEPTH {
        return Err(ExpandError::RecursionLimit(MAX_EXPANSION_DEPTH));
    }

    for token in tokens {
        match token {
            BodyToken::Literal(text) => out.push_str(text),
            BodyToken::Substitution(name) => match bindings.get(name) {
                Some(Binding::Single(fragment)) => out.push_str(&fragment.source),
                Some(Binding::Sequence(_)) => {
                    return Err(ExpandError::WrongBindingShape(name.clone()));
                }
                None => return Err(ExpandError::UnboundName(name.clone())),
            },
            BodyToken::MacroCall { name, args } => {
                expand_macro_call(name, args, bindings, expansion_id, depth, registry, out)?
            }
            BodyToken::Repetition {
                body, separator, ..
            } => expand_repetition(
                body,
                separator.as_deref(),
                bindings,
                expansion_id,
                depth,
                registry,
                out,
            )?,
        }
    }

    Ok(())
}
/// Expands a macro call that appears inside another macro's body and appends
/// the result to `out`.
///
/// The argument tokens are first rendered with the caller's `bindings`; the
/// resulting text is then handed to the value- or type-position expander
/// according to the callee's `kind`.
///
/// # Errors
///
/// Returns [`ExpandError::UnknownMacroCall`] when no registry was supplied or
/// the registry has no definition for `callee_name`; otherwise propagates
/// errors from rendering the arguments or expanding the callee.
fn expand_macro_call(
    callee_name: &str,
    args: &[BodyToken],
    bindings: &HashMap<String, Binding>,
    expansion_id: u32,
    depth: u32,
    registry: Option<&DeclarativeMacroRegistry>,
    out: &mut String,
) -> Result<(), ExpandError> {
    // Both resolution failures surface as the same diagnostic.
    let unknown = || ExpandError::UnknownMacroCall(callee_name.to_string());
    let registry = registry.ok_or_else(unknown)?;
    let callee_def = registry.lookup(callee_name).cloned().ok_or_else(unknown)?;

    let mut rendered_args = String::new();
    render_tokens(
        args,
        bindings,
        expansion_id,
        depth,
        Some(registry),
        &mut rendered_args,
    )?;

    let expansion = match callee_def.kind {
        crate::ts_syn::declarative::MacroKind::Value => expand_value_macro_call(
            callee_name,
            &callee_def,
            &rendered_args,
            expansion_id,
            depth,
            registry,
        ),
        crate::ts_syn::declarative::MacroKind::Type => expand_type_macro_call(
            callee_name,
            &callee_def,
            &rendered_args,
            expansion_id,
            depth,
            registry,
        ),
    }?;

    out.push_str(&expansion);
    Ok(())
}
/// Expands a nested call to a value-position macro.
///
/// `rendered_args` (trimmed) is wrapped as the argument list of a dummy call
/// statement and parsed as TypeScript so the callee's arms can be matched
/// against real call arguments. The matching arm's body is then expanded in
/// `Statement` context, one level deeper, with the expansion id
/// `expansion_id.wrapping_add(depth + 1)` and no cluster id.
///
/// # Errors
///
/// * [`ExpandError::MalformedMacroCallArgs`] if the wrapper fails to parse,
///   contains no call expression, or matching fails for a reason other than
///   "no arm matched".
/// * [`ExpandError::NestedMatchFailure`] if no arm of the callee matches.
/// * Any error from expanding the selected arm's body.
fn expand_value_macro_call(
    callee_name: &str,
    callee_def: &crate::ts_syn::declarative::MacroDef,
    rendered_args: &str,
    expansion_id: u32,
    depth: u32,
    registry: &DeclarativeMacroRegistry,
) -> Result<String, ExpandError> {
    use oxc::allocator::Allocator;
    use oxc::ast::ast::{Expression, Statement};
    use oxc::parser::Parser;
    use oxc::span::SourceType;

    let malformed = |reason: String| ExpandError::MalformedMacroCallArgs {
        callee: callee_name.to_string(),
        reason,
    };

    let wrapper_source = format!("__m4cr0f0rg3_dummy__({});", rendered_args.trim());
    let allocator = Allocator::default();
    let parsed = Parser::new(&allocator, &wrapper_source, SourceType::ts()).parse();
    if !parsed.errors.is_empty() {
        let messages: Vec<String> = parsed.errors.iter().map(|e| e.to_string()).collect();
        return Err(malformed(messages.join("; ")));
    }

    // The first expression statement that is a call is the dummy wrapper.
    let call = parsed
        .program
        .body
        .iter()
        .find_map(|stmt| match stmt {
            Statement::ExpressionStatement(es) => match &es.expression {
                Expression::CallExpression(call) => Some(call),
                _ => None,
            },
            _ => None,
        })
        .ok_or_else(|| malformed("wrapper did not produce a call expression".to_string()))?;

    let matched =
        match_invocation_against_arms(&callee_def.arms, &call.arguments, &wrapper_source);
    let (arm_index, callee_bindings) = match matched {
        Ok(hit) => hit,
        Err(super::matcher::MatchError::NoArmMatched { tried }) => {
            return Err(ExpandError::NestedMatchFailure {
                callee: callee_name.to_string(),
                tried,
            });
        }
        Err(other) => return Err(malformed(other.to_string())),
    };

    let next_depth = depth + 1;
    expand_body_with_registry(
        &callee_def.arms[arm_index].body,
        &callee_bindings,
        expansion_id.wrapping_add(next_depth),
        ExpansionContext::Statement,
        next_depth,
        Some(registry),
        None,
    )
}
/// Expands a nested call to a type-position macro.
///
/// `rendered_args` (trimmed) is wrapped as the type-argument list of a dummy
/// type alias and parsed as TypeScript. If the parsed reference carries type
/// arguments they are matched against the callee's arms; if it carries none,
/// the first arm whose pattern is `Pattern::Empty` is used with no bindings.
/// The chosen arm's body is expanded in `Type` context, one level deeper,
/// with the expansion id `expansion_id.wrapping_add(depth + 1)` and no
/// cluster id.
///
/// # Errors
///
/// * [`ExpandError::MalformedMacroCallArgs`] if the wrapper fails to parse,
///   yields no type reference, or matching fails for a reason other than
///   "no arm matched".
/// * [`ExpandError::NestedMatchFailure`] if no arm matches (reported with
///   `tried: ["()"]` when there are no type arguments and no empty arm).
/// * Any error from expanding the selected arm's body.
fn expand_type_macro_call(
    callee_name: &str,
    callee_def: &crate::ts_syn::declarative::MacroDef,
    rendered_args: &str,
    expansion_id: u32,
    depth: u32,
    registry: &DeclarativeMacroRegistry,
) -> Result<String, ExpandError> {
    use oxc::allocator::Allocator;
    use oxc::ast::ast::{Statement, TSType};
    use oxc::parser::Parser;
    use oxc::span::SourceType;

    let malformed = |reason: String| ExpandError::MalformedMacroCallArgs {
        callee: callee_name.to_string(),
        reason,
    };

    let wrapper_source = format!(
        "type __m4cr0f0rg3_dummy__ = __m4cr0f0rg3_helper__<{}>;",
        rendered_args.trim()
    );
    let allocator = Allocator::default();
    let parsed = Parser::new(&allocator, &wrapper_source, SourceType::ts()).parse();
    if !parsed.errors.is_empty() {
        let messages: Vec<String> = parsed.errors.iter().map(|e| e.to_string()).collect();
        return Err(malformed(messages.join("; ")));
    }

    // The first type alias whose right-hand side is a type reference is the
    // dummy wrapper.
    let type_ref = parsed
        .program
        .body
        .iter()
        .find_map(|stmt| match stmt {
            Statement::TSTypeAliasDeclaration(alias) => match &alias.type_annotation {
                TSType::TSTypeReference(tr) => Some(tr),
                _ => None,
            },
            _ => None,
        })
        .ok_or_else(|| {
            malformed("type-position wrapper did not produce a type reference".to_string())
        })?;

    let (arm_index, callee_bindings) = match type_ref.type_arguments.as_ref() {
        // NOTE(review): the wrapper always emits `<...>`, so this branch is
        // only reached if the parser reports no type arguments for that form
        // (presumably an empty list) without raising an error — confirm
        // against the oxc parser's handling of `<>`.
        None => {
            let empty_arm = callee_def
                .arms
                .iter()
                .position(|a| matches!(a.pattern, crate::ts_syn::declarative::Pattern::Empty))
                .ok_or_else(|| ExpandError::NestedMatchFailure {
                    callee: callee_name.to_string(),
                    tried: vec!["()".to_string()],
                })?;
            (empty_arm, HashMap::new())
        }
        Some(type_args) => {
            let matched = super::matcher::match_type_invocation_against_arms(
                &callee_def.arms,
                &type_args.params,
                &wrapper_source,
            );
            match matched {
                Ok(hit) => hit,
                Err(super::matcher::MatchError::NoArmMatched { tried }) => {
                    return Err(ExpandError::NestedMatchFailure {
                        callee: callee_name.to_string(),
                        tried,
                    });
                }
                Err(other) => return Err(malformed(other.to_string())),
            }
        }
    };

    let next_depth = depth + 1;
    expand_body_with_registry(
        &callee_def.arms[arm_index].body,
        &callee_bindings,
        expansion_id.wrapping_add(next_depth),
        ExpansionContext::Type,
        next_depth,
        Some(registry),
        None,
    )
}
/// Renders a repetition group once per element of the sequence bindings it
/// mentions, appending to `out` and inserting `separator` between iterations.
///
/// The iteration count is the common length of every sequence-bound
/// metavariable referenced anywhere inside `inner` (including nested macro
/// call arguments and nested repetitions). Within iteration `i` the scope
/// contains every single binding from `outer_bindings` plus, for each
/// referenced sequence, its `i`-th element as a single binding. Sequence
/// bindings that `inner` does not reference are not visible inside the group.
///
/// # Errors
///
/// * [`ExpandError::RecursionLimit`] if `depth` exceeds [`MAX_EXPANSION_DEPTH`].
/// * [`ExpandError::InconsistentSequenceLength`] if two referenced sequences
///   differ in length.
/// * [`ExpandError::UnanchoredRepetition`] if no referenced name is bound to a
///   sequence.
/// * Any error from rendering `inner`.
fn expand_repetition(
    inner: &[BodyToken],
    separator: Option<&str>,
    outer_bindings: &HashMap<String, Binding>,
    expansion_id: u32,
    depth: u32,
    registry: Option<&DeclarativeMacroRegistry>,
    out: &mut String,
) -> Result<(), ExpandError> {
    if depth > MAX_EXPANSION_DEPTH {
        return Err(ExpandError::RecursionLimit(MAX_EXPANSION_DEPTH));
    }

    // Resolve each distinct sequence-bound name once, keeping a direct
    // reference to its fragments so the per-iteration work needs no lookups.
    let mut seen: Vec<&String> = Vec::new();
    let mut sequences = Vec::new();
    let mut length: Option<usize> = None;
    for name in collect_substitutions(inner) {
        let Some(Binding::Sequence(frags)) = outer_bindings.get(name) else {
            continue;
        };
        if seen.contains(&name) {
            // Same binding mentioned again: nothing new to check or record.
            continue;
        }
        seen.push(name);
        match length {
            None => length = Some(frags.len()),
            Some(prev) if prev != frags.len() => {
                return Err(ExpandError::InconsistentSequenceLength(prev, frags.len()));
            }
            Some(_) => {}
        }
        sequences.push((name, frags));
    }

    let Some(length) = length else {
        return Err(ExpandError::UnanchoredRepetition);
    };
    if length == 0 {
        return Ok(());
    }

    // Single bindings are the same in every iteration, so they are copied
    // once here instead of once per iteration.
    let mut scope: HashMap<String, Binding> = HashMap::new();
    for (name, binding) in outer_bindings {
        if let Binding::Single(_) = binding {
            scope.insert(name.clone(), binding.clone());
        }
    }

    for i in 0..length {
        match separator {
            Some(sep) if i > 0 => out.push_str(sep),
            _ => {}
        }
        // Every recorded sequence has exactly `length` elements, so indexing
        // is in bounds; inserting overwrites the previous iteration's entry.
        for (name, frags) in &sequences {
            scope.insert((*name).clone(), Binding::Single(frags[i].clone()));
        }
        render_tokens(inner, &scope, expansion_id, depth, registry, out)?;
    }

    Ok(())
}
/// Returns every metavariable name substituted anywhere in `tokens`, in
/// source order, descending into macro-call arguments and repetition bodies.
/// Names that occur more than once are returned once per occurrence.
fn collect_substitutions(tokens: &[BodyToken]) -> Vec<&String> {
    /// Depth-first walk that appends substitution names to `found`.
    fn walk<'a>(tokens: &'a [BodyToken], found: &mut Vec<&'a String>) {
        for token in tokens {
            match token {
                BodyToken::Literal(_) => {}
                BodyToken::Substitution(name) => found.push(name),
                BodyToken::MacroCall { args, .. } => walk(args, found),
                BodyToken::Repetition { body, .. } => walk(body, found),
            }
        }
    }

    let mut found = Vec::new();
    walk(tokens, &mut found);
    found
}
/// Applies the hygiene rewrite for one expansion.
///
/// Of the names reported by `hygiene::collect_declared_underscore_names`,
/// only those that start with `__` and contain no `$` are selected. If none
/// qualify, `source` is returned untouched; otherwise the selection is passed
/// to `hygiene::rewrite_identifiers` together with the suffix
/// `$<expansion_id>`.
fn rewrite_hygiene(source: String, expansion_id: u32) -> String {
    let renameable: std::collections::HashSet<String> =
        hygiene::collect_declared_underscore_names(&source)
            .into_iter()
            .filter(|name| !name.contains('$') && name.starts_with("__"))
            .collect();

    if renameable.is_empty() {
        return source;
    }

    hygiene::rewrite_identifiers(&source, &renameable, &format!("${}", expansion_id))
}
/// Wraps block-shaped output in an immediately-invoked arrow function when it
/// is used in expression position.
///
/// `Statement` and `Type` output is returned unchanged, as is `Expression`
/// output whose trimmed text does not both start with `{` and end with `}`.
/// Otherwise the trimmed block becomes `(() => <block>)()`, using the version
/// with an inserted `return` when `rewrite_block_with_return` provides one
/// and the trimmed block verbatim when it does not.
fn maybe_wrap_iife(source: String, context: ExpansionContext) -> String {
    match context {
        ExpansionContext::Statement | ExpansionContext::Type => return source,
        ExpansionContext::Expression => {}
    }

    let block = source.trim();
    let looks_like_block = block.starts_with('{') && block.ends_with('}');
    if !looks_like_block {
        return source;
    }

    let arrow_body = rewrite_block_with_return(block).unwrap_or_else(|| block.to_string());
    format!("(() => {})()", arrow_body)
}
/// Inserts `return ` in front of a block's trailing value expression.
///
/// `block_source` is parsed as TypeScript, falling back to TSX if that
/// reports errors. A rewritten copy is returned only when the program is a
/// single block statement whose last statement is an expression statement
/// whose source text does not end in `;`. In every other case (parse failure
/// in both modes, more or fewer than one top-level statement, empty block,
/// non-expression tail, or a `;`-terminated tail) the result is `None`.
fn rewrite_block_with_return(block_source: &str) -> Option<String> {
    use oxc::allocator::Allocator;
    use oxc::ast::ast::Statement;
    use oxc::parser::Parser;
    use oxc::span::SourceType;

    let allocator = Allocator::default();
    let mut parsed = Parser::new(&allocator, block_source, SourceType::ts()).parse();
    if !parsed.errors.is_empty() {
        // Retry as TSX before giving up.
        parsed = Parser::new(&allocator, block_source, SourceType::tsx()).parse();
        if !parsed.errors.is_empty() {
            return None;
        }
    }

    let top_level = &parsed.program.body;
    if top_level.len() != 1 {
        return None;
    }
    let Statement::BlockStatement(block) = &top_level[0] else {
        return None;
    };
    let Statement::ExpressionStatement(tail) = block.body.last()? else {
        return None;
    };

    let tail_start = tail.span.start as usize;
    let tail_end = tail.span.end as usize;
    // An explicit `;` marks the tail as a plain statement, not a value.
    if block_source[tail_start..tail_end].trim_end().ends_with(';') {
        return None;
    }

    let (head, rest) = block_source.split_at(tail_start);
    Some([head, "return ", rest].concat())
}
// Unit tests for the expression-position IIFE wrapping (`maybe_wrap_iife`)
// and the trailing-expression `return` insertion (`rewrite_block_with_return`).
#[cfg(test)]
mod iife_wrap_tests {
use super::{ExpansionContext, maybe_wrap_iife, rewrite_block_with_return};
// Runs `maybe_wrap_iife` in expression context on `src`.
fn wrap(src: &str) -> String {
maybe_wrap_iife(src.to_string(), ExpansionContext::Expression)
}
// Output that is not brace-delimited is returned as-is.
#[test]
fn non_block_source_passes_through() {
assert_eq!(wrap("1 + 2"), "1 + 2");
assert_eq!(wrap("foo(x)"), "foo(x)");
}
// Statement context leaves even block-shaped output alone.
#[test]
fn statement_context_never_wraps() {
let out = maybe_wrap_iife("{ x + 1 }".into(), ExpansionContext::Statement);
assert_eq!(out, "{ x + 1 }");
}
// Type context leaves brace-delimited output (here an object type) alone.
#[test]
fn type_context_never_wraps() {
let out = maybe_wrap_iife("{ a: number }".into(), ExpansionContext::Type);
assert_eq!(out, "{ a: number }");
}
// A trailing expression without `;` gets `return` inserted before it.
#[test]
fn trailing_expression_gets_return() {
let rewritten = rewrite_block_with_return("{ const __a = 10; __a + 1 }").unwrap();
assert_eq!(rewritten, "{ const __a = 10; return __a + 1 }");
}
// A trailing expression terminated by `;` is treated as a statement.
#[test]
fn trailing_expression_with_semicolon_is_untouched() {
assert!(rewrite_block_with_return("{ const __a = 10; __a + 1; }").is_none());
}
// An explicit trailing `return` yields no rewrite.
#[test]
fn trailing_return_statement_is_untouched() {
assert!(rewrite_block_with_return("{ const __a = 10; return __a + 1 }").is_none());
}
// A trailing `throw` is not an expression statement.
#[test]
fn trailing_throw_is_untouched() {
assert!(rewrite_block_with_return("{ throw new Error('x') }").is_none());
}
// An empty block has no last statement to rewrite.
#[test]
fn empty_block_is_untouched() {
assert!(rewrite_block_with_return("{}").is_none());
}
// A block ending in a declaration has no trailing value.
#[test]
fn block_with_only_declaration_is_untouched() {
assert!(rewrite_block_with_return("{ const x = 1; }").is_none());
}
// Semicolons inside a template literal must not affect where `return` goes.
#[test]
fn template_literal_with_semicolon_does_not_split_wrongly() {
let src = "{ const s = `a;${1};b`; s.length }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const s = `a;${1};b`; return s.length }");
}
// Semicolons inside a regex literal must not affect where `return` goes.
#[test]
fn regex_literal_with_semicolon_is_respected() {
let src = "{ const r = /a;b/; r.source.length }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const r = /a;b/; return r.source.length }");
}
// Semicolons inside a comment are ignored; `return` lands after the comment.
#[test]
fn comment_with_semicolon_is_ignored() {
let src = "{ const x = 1; /* ; ignored ; */ x + 1 }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const x = 1; /* ; ignored ; */ return x + 1 }");
}
// A parenthesized object literal counts as a trailing value expression.
#[test]
fn parenthesized_object_literal_as_trailing_value() {
let src = "{ const k = 'a'; ({ [k]: 1 }) }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const k = 'a'; return ({ [k]: 1 }) }");
}
// An arrow function can be the trailing value.
#[test]
fn trailing_arrow_expression() {
let src = "{ const y = 2; () => y }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const y = 2; return () => y }");
}
// A TypeScript `satisfies` expression can be the trailing value.
#[test]
fn trailing_satisfies_expression() {
let src = "{ const x = 1; x satisfies number }";
let rewritten = rewrite_block_with_return(src).unwrap();
assert_eq!(rewritten, "{ const x = 1; return x satisfies number }");
}
// A trailing `if` is a statement, so nothing is rewritten.
#[test]
fn trailing_if_statement_is_untouched() {
assert!(rewrite_block_with_return("{ const x = 1; if (x > 0) { f(); } }").is_none());
}
// Only the outer block's last statement is inspected; a nested block there
// is not an expression statement, so nothing is rewritten.
#[test]
fn nested_block_with_trailing_expression() {
assert!(rewrite_block_with_return("{ { inner } }").is_none());
}
// Unparseable block-shaped input is still wrapped, just without a rewrite.
#[test]
fn malformed_block_falls_back_to_verbatim_wrap() {
let out = wrap("{ const x = ;; }");
assert!(out.starts_with("(() => {"));
assert!(out.ends_with(")()"));
}
}