use crate::frontend::error::SyntaxError;
use crate::frontend::span::Span;
use crate::frontend::token::{Token, TokenInfo};
use std::collections::HashMap;
/// Maximum recursion depth allowed for `expand_stream`; expansion fails with a
/// syntax error once the depth argument becomes greater than this value.
/// The error message emitted there spells the limit out literally as "200".
const MAX_EXPANSION_DEPTH: u32 = 200;
/// Per-invocation context handed to a [`Macro`] while it expands.
pub struct MacroCtx<'r> {
/// Borrowed counter used by [`MacroCtx::gensym`] to number generated names.
pub(crate) gensym_counter: &'r mut u64,
/// Registry the invocation is being expanded against, when one is available.
#[allow(dead_code)]
pub(crate) registry: Option<&'r MacroRegistry>,
/// Source line of the macro invocation; built-in macros use it for error reports.
pub line: u32,
/// Source span attributed to the invocation.
pub span: Span,
}
impl<'r> MacroCtx<'r> {
    /// Produces a fresh identifier of the form `__lm_<n>_<prefix>`.
    ///
    /// The shared counter is advanced (wrapping on overflow) before it is
    /// used, so the first name generated from a zero counter carries `1`.
    /// An empty `prefix` is replaced by `"g"`.
    pub fn gensym(&mut self, prefix: &str) -> Box<str> {
        let next = self.gensym_counter.wrapping_add(1);
        *self.gensym_counter = next;
        let stem = match prefix {
            "" => "g",
            given => given,
        };
        format!("__lm_{}_{}", next, stem).into_boxed_str()
    }
}
/// A compile-time macro that rewrites token streams.
pub trait Macro {
/// Expands one invocation.
///
/// `args` holds one token vector per argument of the invocation and `ctx`
/// carries the invocation's line, span and gensym counter. Returns the
/// replacement tokens, or a [`SyntaxError`] if the arguments are unusable.
fn expand(
&self,
args: &[Vec<TokenInfo>],
ctx: &mut MacroCtx<'_>,
) -> Result<Vec<TokenInfo>, SyntaxError>;
}
/// Table of named macros plus the gensym counter carried between expansions.
pub struct MacroRegistry {
/// Registered macro implementations, keyed by macro name.
macros: HashMap<Box<str>, Box<dyn Macro>>,
/// Last gensym number handed out; starts at 0 for a new registry.
pub(crate) gensym_counter: u64,
}
impl Default for MacroRegistry {
fn default() -> Self {
Self::new()
}
}
impl MacroRegistry {
    /// Creates a registry with no macros and a gensym counter of zero.
    pub fn new() -> Self {
        Self {
            macros: HashMap::new(),
            gensym_counter: 0,
        }
    }

    /// Creates a registry pre-populated with the built-in macros
    /// `quote`, `unquote`, `if` and `gensym`.
    pub fn with_builtins() -> Self {
        let mut registry = Self::new();
        registry.register("quote", Box::new(builtins::QuoteMacro));
        registry.register("unquote", Box::new(builtins::UnquoteMacro));
        registry.register("if", Box::new(builtins::IfMacro));
        registry.register("gensym", Box::new(builtins::GensymMacro));
        registry
    }

    /// Registers `m` under `name`, replacing any macro already stored there.
    pub fn register(&mut self, name: &str, m: Box<dyn Macro>) {
        let key: Box<str> = name.into();
        self.macros.insert(key, m);
    }

    /// Looks up the macro registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&dyn Macro> {
        let boxed = self.macros.get(name)?;
        Some(boxed.as_ref())
    }

    /// Removes every registered macro. The gensym counter is left untouched.
    pub fn clear(&mut self) {
        self.macros.clear();
    }

    /// Expands every macro invocation in `input`.
    ///
    /// The gensym counter is worked on as a local copy and stored back only
    /// when expansion succeeds; on error the registry's counter is unchanged.
    pub fn expand(&mut self, input: Vec<TokenInfo>) -> Result<Vec<TokenInfo>, SyntaxError> {
        let mut next_id = self.gensym_counter;
        match expand_stream(input, &*self, &mut next_id, 0) {
            Ok(tokens) => {
                self.gensym_counter = next_id;
                Ok(tokens)
            }
            Err(err) => Err(err),
        }
    }
}
/// Returns the source spelling of a keyword token, or `None` when `t` is not
/// one of the keywords listed here.
///
/// This lets a keyword such as `if` serve as a macro name after `@`.
fn keyword_name(t: &Token) -> Option<&'static str> {
    match t {
        Token::And => Some("and"),
        Token::Break => Some("break"),
        Token::Do => Some("do"),
        Token::Else => Some("else"),
        Token::Elseif => Some("elseif"),
        Token::End => Some("end"),
        Token::False => Some("false"),
        Token::For => Some("for"),
        Token::Function => Some("function"),
        Token::Global => Some("global"),
        Token::Goto => Some("goto"),
        Token::If => Some("if"),
        Token::In => Some("in"),
        Token::Local => Some("local"),
        Token::Nil => Some("nil"),
        Token::Not => Some("not"),
        Token::Or => Some("or"),
        Token::Repeat => Some("repeat"),
        Token::Return => Some("return"),
        Token::Then => Some("then"),
        Token::True => Some("true"),
        Token::Until => Some("until"),
        Token::While => Some("while"),
        _ => None,
    }
}
/// Recursively expands macro invocations and quote blocks in `input`.
///
/// * `registry` supplies the macro implementations.
/// * `gensym_counter` is threaded through to every macro via `MacroCtx`.
/// * `depth` is the current recursion depth; exceeding
///   `MAX_EXPANSION_DEPTH` is reported as a syntax error.
///
/// Returns the rewritten token stream, or the first `SyntaxError` hit.
fn expand_stream(
input: Vec<TokenInfo>,
registry: &MacroRegistry,
gensym_counter: &mut u64,
depth: u32,
) -> Result<Vec<TokenInfo>, SyntaxError> {
// Guard against runaway (e.g. self-referential) expansion.
if depth > MAX_EXPANSION_DEPTH {
// Report at the first token's line, falling back to line 1 for an empty stream.
let line = input.first().map(|t| t.line).unwrap_or(1);
return Err(SyntaxError::new(
line,
b"macro expansion depth exceeded (200) near '@'".to_vec(),
));
}
let mut out: Vec<TokenInfo> = Vec::with_capacity(input.len());
let mut i = 0;
while i < input.len() {
match &input[i].tok {
// `@name ...`: a macro invocation.
Token::At => {
let inv_line = input[i].line;
let inv_start = input[i].span;
let name_idx = i + 1;
// The macro name is the token right after `@`: either a plain name
// or a keyword (so that e.g. `@if` is accepted).
let name = match input.get(name_idx).map(|t| &t.tok) {
Some(Token::Name(n)) => n.clone(),
Some(other) => {
if let Some(kw) = keyword_name(other) {
kw.into()
} else {
return Err(SyntaxError::new(
inv_line,
b"macro name expected after '@'".to_vec(),
));
}
}
None => {
return Err(SyntaxError::new(
inv_line,
b"macro name expected after '@'".to_vec(),
));
}
};
// `cursor` ends up one past the last token consumed by the invocation.
let mut cursor = name_idx + 1;
let (raw_args, after) = collect_macro_args(&input, cursor, inv_line)?;
cursor = after;
// Arguments are expanded first, before the macro is even looked up,
// so an error inside an argument wins over "unknown macro".
let mut expanded_args: Vec<Vec<TokenInfo>> = Vec::with_capacity(raw_args.len());
for a in raw_args {
expanded_args.push(expand_stream(a, registry, gensym_counter, depth + 1)?);
}
let macro_impl = registry.get(&name).ok_or_else(|| {
SyntaxError::new(inv_line, format!("unknown macro '@{name}'").into_bytes())
})?;
// The invocation span runs from the `@` to the end of the last consumed token.
let end_span = if cursor > 0 && cursor <= input.len() {
input[cursor - 1].span
} else {
inv_start
};
let full_span = Span::new(inv_start.start as usize, end_span.end as usize);
let mut ctx = MacroCtx {
gensym_counter,
registry: Some(registry),
line: inv_line,
span: full_span,
};
let mut expanded = macro_impl.expand(&expanded_args, &mut ctx)?;
// The macro's output may itself contain invocations; expand it again.
expanded = expand_stream(expanded, registry, gensym_counter, depth + 1)?;
out.extend(expanded);
i = cursor;
}
// `@{ ... }@`: a quote block; its expanded body is packed into a single
// `MacroQuote` token.
Token::MacroBraceOpen => {
let block_line = input[i].line;
let (body, after) = collect_quote_block(&input, i, block_line)?;
// `after` is one past the closing `}@`, so `after - 1` is the closer itself.
let span = Span::new(
input[i].span.start as usize,
input[after - 1].span.end as usize,
);
let body_expanded = expand_stream(body, registry, gensym_counter, depth + 1)?;
out.push(TokenInfo {
tok: Token::MacroQuote(body_expanded.into_boxed_slice()),
span,
line: block_line,
});
i = after;
}
// A closer reached here was not consumed by any quote block.
Token::MacroBraceClose => {
return Err(SyntaxError::new(
input[i].line,
b"unexpected '}@' (no matching '@{')".to_vec(),
));
}
// An already-packed quote appearing in the input stream is rejected.
Token::MacroQuote(_) => {
return Err(SyntaxError::new(
input[i].line,
b"stray macro-quote token left in stream (forgot '@unquote'?)".to_vec(),
));
}
// Every other token is copied through unchanged.
_ => {
out.push(input[i].clone());
i += 1;
}
}
}
Ok(out)
}
/// Collects the arguments of a macro invocation whose first argument token
/// would sit at `input[start]`.
///
/// * `(`  — a comma-separated parenthesised list (see `collect_paren_args`).
/// * `{`  — a single brace-delimited body.
/// * `@{` — a single quote-block body.
/// * anything else, or end of input — no arguments, nothing consumed.
///
/// Returns the argument list and the index just past the consumed tokens.
/// `inv_line` is forwarded to the collectors for error reporting.
fn collect_macro_args(
    input: &[TokenInfo],
    start: usize,
    inv_line: u32,
) -> Result<(Vec<Vec<TokenInfo>>, usize), SyntaxError> {
    // Wraps a single collected body as a one-element argument list.
    let single = |(body, after): (Vec<TokenInfo>, usize)| (vec![body], after);
    match input.get(start).map(|info| &info.tok) {
        Some(Token::LParen) => collect_paren_args(input, start, inv_line),
        Some(Token::LBrace) => collect_brace_body(input, start, inv_line).map(single),
        Some(Token::MacroBraceOpen) => collect_quote_block(input, start, inv_line).map(single),
        _ => Ok((Vec::new(), start)),
    }
}
/// Splits the parenthesised argument list opening at `input[lparen_idx]`
/// into one token vector per top-level, comma-separated argument.
///
/// Commas nested inside `()`, `{}`, `[]` or `@{ }@` do not split arguments.
/// `()` yields zero arguments; an empty slot next to a comma yields an empty
/// argument (so `(a,)` is two arguments).
///
/// Returns the arguments and the index just past the closing `)`.
///
/// # Errors
///
/// * a `}` or `}@` with no matching opener inside the list;
/// * a `)` that would close the list while a `{`, `[` or `@{` is still open;
/// * end of input before the closing `)` (reported at `inv_line`).
fn collect_paren_args(
    input: &[TokenInfo],
    lparen_idx: usize,
    inv_line: u32,
) -> Result<(Vec<Vec<TokenInfo>>, usize), SyntaxError> {
    debug_assert!(matches!(input[lparen_idx].tok, Token::LParen));
    let mut depth_paren = 1u32;
    let mut depth_brace = 0u32;
    let mut depth_bracket = 0u32;
    let mut depth_quote = 0u32;
    let mut args: Vec<Vec<TokenInfo>> = Vec::new();
    let mut cur: Vec<TokenInfo> = Vec::new();

    for (i, info) in input.iter().enumerate().skip(lparen_idx + 1) {
        match &info.tok {
            Token::LParen => depth_paren += 1,
            Token::RParen => {
                // `depth_paren` is always >= 1 here: it only reaches 0 in this
                // arm, and every path through that case returns below.
                depth_paren -= 1;
                if depth_paren == 0 {
                    if depth_brace != 0 || depth_bracket != 0 || depth_quote != 0 {
                        // Previously scanning continued with depth_paren == 0,
                        // so the next ')' underflowed the counter (a panic in
                        // debug builds). Report the mismatch instead.
                        return Err(SyntaxError::new(
                            info.line,
                            b"unexpected ')' inside macro arg list (unclosed '{', '[' or '@{')"
                                .to_vec(),
                        ));
                    }
                    // `()` means zero arguments; otherwise flush the last one,
                    // even if it is empty (trailing comma).
                    if !cur.is_empty() || !args.is_empty() {
                        args.push(std::mem::take(&mut cur));
                    }
                    return Ok((args, i + 1));
                }
            }
            Token::LBrace => depth_brace += 1,
            Token::RBrace => {
                if depth_brace == 0 {
                    return Err(SyntaxError::new(
                        info.line,
                        b"unexpected '}' inside macro arg list".to_vec(),
                    ));
                }
                depth_brace -= 1;
            }
            Token::LBracket => depth_bracket += 1,
            // A stray ']' is tolerated here and left for later stages to reject.
            Token::RBracket => depth_bracket = depth_bracket.saturating_sub(1),
            Token::MacroBraceOpen => depth_quote += 1,
            Token::MacroBraceClose => {
                if depth_quote == 0 {
                    return Err(SyntaxError::new(
                        info.line,
                        b"unexpected '}@' inside macro arg list".to_vec(),
                    ));
                }
                depth_quote -= 1;
            }
            Token::Comma
                if depth_paren == 1
                    && depth_brace == 0
                    && depth_bracket == 0
                    && depth_quote == 0 =>
            {
                // Top-level comma: finish the current argument; the comma
                // itself is not part of any argument.
                args.push(std::mem::take(&mut cur));
                continue;
            }
            _ => {}
        }
        cur.push(info.clone());
    }

    Err(SyntaxError::new(
        inv_line,
        b"unterminated macro arg list (missing ')')".to_vec(),
    ))
}
/// Collects the tokens between the `{` at `input[lbrace_idx]` and its
/// matching `}`, honouring nested braces.
///
/// Returns the body (without the outer braces) and the index just past the
/// closing `}`. If the input ends first, an error is reported at `inv_line`.
fn collect_brace_body(
    input: &[TokenInfo],
    lbrace_idx: usize,
    inv_line: u32,
) -> Result<(Vec<TokenInfo>, usize), SyntaxError> {
    debug_assert!(matches!(input[lbrace_idx].tok, Token::LBrace));
    let mut nesting = 1u32;
    let mut collected: Vec<TokenInfo> = Vec::new();
    for (pos, info) in input.iter().enumerate().skip(lbrace_idx + 1) {
        match &info.tok {
            Token::LBrace => nesting += 1,
            Token::RBrace => {
                nesting -= 1;
                if nesting == 0 {
                    // The outermost closer ends the body and is not included.
                    return Ok((collected, pos + 1));
                }
            }
            _ => {}
        }
        collected.push(info.clone());
    }
    Err(SyntaxError::new(
        inv_line,
        b"unterminated macro brace body (missing '}')".to_vec(),
    ))
}
/// Collects the tokens between the `@{` at `input[open_idx]` and its
/// matching `}@`, honouring nested quote blocks.
///
/// Returns the body (without the outer delimiters) and the index just past
/// the closing `}@`. If the input ends first, an error is reported at
/// `inv_line`.
fn collect_quote_block(
    input: &[TokenInfo],
    open_idx: usize,
    inv_line: u32,
) -> Result<(Vec<TokenInfo>, usize), SyntaxError> {
    debug_assert!(matches!(input[open_idx].tok, Token::MacroBraceOpen));
    let mut nesting = 1u32;
    let mut collected: Vec<TokenInfo> = Vec::new();
    for (pos, info) in input.iter().enumerate().skip(open_idx + 1) {
        match &info.tok {
            Token::MacroBraceOpen => nesting += 1,
            Token::MacroBraceClose => {
                nesting -= 1;
                if nesting == 0 {
                    // The outermost closer ends the block and is not included.
                    return Ok((collected, pos + 1));
                }
            }
            _ => {}
        }
        collected.push(info.clone());
    }
    Err(SyntaxError::new(
        inv_line,
        b"unterminated quote block (missing '}@')".to_vec(),
    ))
}
/// Implementations of the built-in macros `quote`, `unquote`, `if` and
/// `gensym`.
mod builtins {
    use super::*;

    /// If `arm` is exactly one `MacroQuote` token, returns the quoted body;
    /// otherwise returns a copy of `arm` as-is.
    fn splice_arm(arm: &[TokenInfo]) -> Vec<TokenInfo> {
        if let [only] = arm {
            if let Token::MacroQuote(body) = &only.tok {
                return body.to_vec();
            }
        }
        arm.to_vec()
    }

    /// `@quote{ ... }`: yields its single argument's tokens unchanged.
    pub(super) struct QuoteMacro;
    impl Macro for QuoteMacro {
        /// Fails unless exactly one argument is supplied.
        fn expand(
            &self,
            args: &[Vec<TokenInfo>],
            ctx: &mut MacroCtx<'_>,
        ) -> Result<Vec<TokenInfo>, SyntaxError> {
            if args.len() != 1 {
                return Err(SyntaxError::new(
                    ctx.line,
                    format!(
                        "@quote expects exactly one brace body, got {} args",
                        args.len()
                    )
                    .into_bytes(),
                ));
            }
            Ok(args[0].clone())
        }
    }

    /// `@unquote(x)`: if `x` is a single quote token, yields the quoted
    /// body; otherwise yields `x` unchanged.
    pub(super) struct UnquoteMacro;
    impl Macro for UnquoteMacro {
        /// Fails unless exactly one argument is supplied.
        fn expand(
            &self,
            args: &[Vec<TokenInfo>],
            ctx: &mut MacroCtx<'_>,
        ) -> Result<Vec<TokenInfo>, SyntaxError> {
            if args.len() != 1 {
                return Err(SyntaxError::new(
                    ctx.line,
                    format!("@unquote expects 1 arg, got {}", args.len()).into_bytes(),
                ));
            }
            Ok(splice_arm(&args[0]))
        }
    }

    /// `@if(cond, then[, else])`: picks an arm from a constant condition.
    pub(super) struct IfMacro;
    impl Macro for IfMacro {
        /// Evaluates `args[0]` with `eval_const_cond` and yields the chosen
        /// arm (unwrapping a single quote token). A false condition with no
        /// `else` arm yields no tokens.
        fn expand(
            &self,
            args: &[Vec<TokenInfo>],
            ctx: &mut MacroCtx<'_>,
        ) -> Result<Vec<TokenInfo>, SyntaxError> {
            if args.len() < 2 || args.len() > 3 {
                return Err(SyntaxError::new(
                    ctx.line,
                    format!("@if expects (cond, then[, else]) — got {} args", args.len())
                        .into_bytes(),
                ));
            }
            let cond_truthy = eval_const_cond(&args[0], ctx.line)?;
            let chosen: &[TokenInfo] = if cond_truthy {
                args[1].as_slice()
            } else if let Some(else_arm) = args.get(2) {
                else_arm.as_slice()
            } else {
                // No else arm: the invocation expands to nothing.
                return Ok(Vec::new());
            };
            Ok(splice_arm(chosen))
        }
    }

    /// Evaluates an `@if` condition made only of literal tokens.
    ///
    /// Accepted shapes:
    /// * one token: `true`, `false`, `nil` (false) or an integer (non-zero
    ///   is true);
    /// * three tokens: `lit == lit` or `lit ~= lit` (tokens `Eq` / `Ne`).
    ///
    /// Anything else is a syntax error reported at `line`.
    fn eval_const_cond(tokens: &[TokenInfo], line: u32) -> Result<bool, SyntaxError> {
        match tokens {
            [] => Err(SyntaxError::new(line, b"@if: empty condition".to_vec())),
            [single] => match &single.tok {
                Token::True => Ok(true),
                Token::False => Ok(false),
                Token::Int(i) => Ok(*i != 0),
                Token::Nil => Ok(false),
                _ => Err(SyntaxError::new(
                    line,
                    b"@if: cond must be true/false/integer/literal-eq".to_vec(),
                )),
            },
            [lhs, op, rhs] if matches!(op.tok, Token::Eq | Token::Ne) => {
                let equal = literal_eq(&lhs.tok, &rhs.tok, line)?;
                Ok(if matches!(op.tok, Token::Eq) {
                    equal
                } else {
                    !equal
                })
            }
            _ => Err(SyntaxError::new(
                line,
                b"@if: unsupported condition shape (use true/false/int/lit==lit)".to_vec(),
            )),
        }
    }

    /// Reports whether `t` is one of the literal tokens `@if` can compare.
    fn is_comparable_literal(t: &Token) -> bool {
        matches!(
            t,
            Token::Int(_)
                | Token::Float(_)
                | Token::Str(_)
                | Token::True
                | Token::False
                | Token::Nil
        )
    }

    /// Compares two literal tokens for equality.
    ///
    /// Integers and floats compare numerically with each other. Two literals
    /// of otherwise different kinds (e.g. `nil` and `false`, or a number and
    /// a string) are simply unequal. A non-literal operand is a syntax error
    /// reported at `line`.
    fn literal_eq(a: &Token, b: &Token, line: u32) -> Result<bool, SyntaxError> {
        Ok(match (a, b) {
            (Token::Int(x), Token::Int(y)) => x == y,
            (Token::Float(x), Token::Float(y)) => x == y,
            (Token::Int(x), Token::Float(y)) | (Token::Float(y), Token::Int(x)) => {
                (*x as f64) == *y
            }
            (Token::Str(x), Token::Str(y)) => x == y,
            (Token::True, Token::True)
            | (Token::False, Token::False)
            | (Token::Nil, Token::Nil) => true,
            // Both sides are literals but of different kinds. These used to
            // be rejected with a message claiming they were not literals
            // (only true-vs-false was handled); they are just unequal.
            _ if is_comparable_literal(a) && is_comparable_literal(b) => false,
            _ => {
                return Err(SyntaxError::new(
                    line,
                    b"@if: only int/float/string/bool/nil literals comparable".to_vec(),
                ));
            }
        })
    }

    /// `@gensym` / `@gensym(prefix)`: yields one fresh `Name` token.
    pub(super) struct GensymMacro;
    impl Macro for GensymMacro {
        /// The optional prefix must be a single string-literal or name token.
        /// The generated token carries the invocation's span and line.
        fn expand(
            &self,
            args: &[Vec<TokenInfo>],
            ctx: &mut MacroCtx<'_>,
        ) -> Result<Vec<TokenInfo>, SyntaxError> {
            let prefix = match args {
                [] => String::new(),
                [arg] if arg.len() == 1 => match &arg[0].tok {
                    Token::Str(bytes) => String::from_utf8_lossy(bytes).into_owned(),
                    Token::Name(n) => n.to_string(),
                    _ => {
                        return Err(SyntaxError::new(
                            ctx.line,
                            b"@gensym: prefix must be a string literal or name".to_vec(),
                        ));
                    }
                },
                _ => {
                    return Err(SyntaxError::new(
                        ctx.line,
                        b"@gensym: expected 0 or 1 args".to_vec(),
                    ));
                }
            };
            let name = ctx.gensym(&prefix);
            Ok(vec![TokenInfo {
                tok: Token::Name(name),
                span: ctx.span,
                line: ctx.line,
            }])
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::lexer::Lexer;
    use crate::version::LuaVersion;

    /// Lexes `src` for language version `v`, returning every token that
    /// precedes `Eof`. Panics if the lexer reports an error.
    fn lex(src: &str, v: LuaVersion) -> Vec<TokenInfo> {
        let mut lexer = Lexer::new(src.as_bytes(), v);
        let mut tokens = Vec::new();
        loop {
            let info = lexer.next_token().expect("lex");
            if matches!(info.tok, Token::Eof) {
                return tokens;
            }
            tokens.push(info);
        }
    }

    /// Two `@gensym` invocations in one stream must produce distinct names.
    #[test]
    fn gensym_is_unique() {
        let mut registry = MacroRegistry::with_builtins();
        let toks = lex("local a = @gensym local b = @gensym", LuaVersion::MacroLua);
        let out = registry.expand(toks).unwrap();
        let gensyms: Vec<String> = out
            .iter()
            .filter_map(|t| match &t.tok {
                Token::Name(n) if n.starts_with("__lm_") => Some(n.to_string()),
                _ => None,
            })
            .collect();
        assert_eq!(gensyms.len(), 2, "expected 2 gensyms, got {gensyms:?}");
        assert_ne!(gensyms[0], gensyms[1], "gensyms must be unique");
    }

    /// Invoking a macro that is not registered is a syntax error.
    #[test]
    fn unknown_macro_errors() {
        let mut registry = MacroRegistry::with_builtins();
        let toks = lex("@nope(1)", LuaVersion::MacroLua);
        let err = registry.expand(toks).unwrap_err();
        let mentions_unknown = String::from_utf8_lossy(&err.msg).contains("unknown macro");
        assert!(mentions_unknown, "got: {}", err.msg_str());
    }

    /// `@quote{ ... }` splices its body and leaves no macro tokens behind.
    #[test]
    fn quote_splices_body() {
        let mut registry = MacroRegistry::with_builtins();
        let toks = lex("local x = @quote{ 42 }", LuaVersion::MacroLua);
        let out = registry.expand(toks).unwrap();
        let has_42 = out.iter().any(|t| matches!(t.tok, Token::Int(42)));
        assert!(has_42, "@quote{{42}} should splice Int(42); got {out:?}");
        let leftover = out.iter().any(|t| {
            matches!(
                t.tok,
                Token::At | Token::MacroBraceOpen | Token::MacroBraceClose
            )
        });
        assert!(!leftover, "expander left @-tokens: {out:?}");
    }
}