use crate::ast::Expr;
use crate::error::{LambdustError, Result};
use std::collections::HashMap;
/// Signature of a macro transformer: it receives the operands of a macro call
/// (everything after the macro's name) and returns the rewritten expression.
pub type MacroTransformer = fn(&[Expr]) -> Result<Expr>;

/// A named macro registered with a [`MacroExpander`].
#[derive(Clone, Debug)]
pub struct Macro {
    /// Symbol under which the macro is invoked.
    pub name: String,
    /// Function that rewrites a call's operands into the expansion.
    pub transformer: MacroTransformer,
    /// Marks `syntax-rules` macros; every registration visible in this module
    /// sets it to `false`.
    pub is_syntax_rules: bool,
}
/// Left-hand side of a rewrite rule.
///
/// NOTE(review): nothing in this section constructs or matches patterns yet;
/// the variant descriptions below follow the names and are presumably for a
/// future `syntax-rules` implementation.
#[derive(Clone, Debug, PartialEq)]
pub enum Pattern {
    /// An identifier expected to appear literally.
    Literal(String),
    /// A pattern variable that binds whatever it matches.
    Variable(String),
    /// A proper list of sub-patterns.
    List(Vec<Pattern>),
    /// A sub-pattern followed by `...` (zero or more repetitions).
    Ellipsis(Box<Pattern>),
    /// An improper list: leading sub-patterns plus a tail pattern.
    Dotted(Vec<Pattern>, Box<Pattern>),
}
/// Right-hand side of a rewrite rule.
///
/// NOTE(review): nothing in this section instantiates templates yet; the
/// variants mirror [`Pattern`] and are presumably for a future `syntax-rules`
/// implementation.
#[derive(Clone, Debug, PartialEq)]
pub enum Template {
    /// An identifier copied into the output as-is.
    Literal(String),
    /// A reference to a pattern variable's binding.
    Variable(String),
    /// A proper list of sub-templates.
    List(Vec<Template>),
    /// A sub-template followed by `...`, repeated per match.
    Ellipsis(Box<Template>),
    /// An improper list: leading sub-templates plus a tail template.
    Dotted(Vec<Template>, Box<Template>),
}
/// A single `pattern -> template` rewrite rule.
///
/// NOTE(review): presumably one clause of a `syntax-rules` macro; nothing in
/// this section builds or applies rules yet.
#[derive(Clone, Debug)]
pub struct SyntaxRule {
    /// What a macro call is matched against.
    pub pattern: Pattern,
    /// What the call is rewritten into.
    pub template: Template,
}
/// Registry of macros keyed by name, together with the expansion logic.
#[derive(Clone, Debug)]
pub struct MacroExpander {
    /// Registered macros, keyed by the symbol that invokes them.
    macros: HashMap<String, Macro>,
}
impl MacroExpander {
pub fn new() -> Self {
let mut expander = MacroExpander {
macros: HashMap::new(),
};
expander.add_builtin_macros();
expander
}
fn add_builtin_macros(&mut self) {
self.macros.insert("let".to_string(), Macro {
name: "let".to_string(),
transformer: expand_let,
is_syntax_rules: false,
});
self.macros.insert("let*".to_string(), Macro {
name: "let*".to_string(),
transformer: expand_let_star,
is_syntax_rules: false,
});
self.macros.insert("letrec".to_string(), Macro {
name: "letrec".to_string(),
transformer: expand_letrec,
is_syntax_rules: false,
});
self.macros.insert("cond".to_string(), Macro {
name: "cond".to_string(),
transformer: expand_cond,
is_syntax_rules: false,
});
self.macros.insert("case".to_string(), Macro {
name: "case".to_string(),
transformer: expand_case,
is_syntax_rules: false,
});
self.macros.insert("when".to_string(), Macro {
name: "when".to_string(),
transformer: expand_when,
is_syntax_rules: false,
});
self.macros.insert("unless".to_string(), Macro {
name: "unless".to_string(),
transformer: expand_unless,
is_syntax_rules: false,
});
}
pub fn is_macro_call(&self, expr: &Expr) -> bool {
match expr {
Expr::List(exprs) if !exprs.is_empty() => {
match &exprs[0] {
Expr::Variable(name) => self.macros.contains_key(name),
_ => false,
}
}
_ => false,
}
}
pub fn expand_macro(&self, expr: Expr) -> Result<Expr> {
match &expr {
Expr::List(exprs) if !exprs.is_empty() => {
match &exprs[0] {
Expr::Variable(name) => {
if let Some(macro_def) = self.macros.get(name) {
let args = &exprs[1..];
(macro_def.transformer)(args)
} else {
Ok(expr) }
}
_ => Ok(expr), }
}
_ => Ok(expr), }
}
pub fn expand_all(&self, expr: Expr) -> Result<Expr> {
match expr {
Expr::List(exprs) => {
let expanded = if self.is_macro_call(&Expr::List(exprs.clone())) {
self.expand_macro(Expr::List(exprs))?
} else {
Expr::List(exprs)
};
match expanded {
Expr::List(exprs) => {
let mut expanded_exprs = Vec::new();
for expr in exprs {
expanded_exprs.push(self.expand_all(expr)?);
}
Ok(Expr::List(expanded_exprs))
}
other => self.expand_all(other),
}
}
Expr::Quote(expr) => Ok(Expr::Quote(expr)), Expr::Quasiquote(expr) => {
Ok(Expr::Quasiquote(Box::new(self.expand_all(*expr)?)))
}
Expr::Unquote(expr) => Ok(Expr::Unquote(Box::new(self.expand_all(*expr)?))),
Expr::UnquoteSplicing(expr) => Ok(Expr::UnquoteSplicing(Box::new(self.expand_all(*expr)?))),
Expr::DottedList(exprs, tail) => {
let mut expanded_exprs = Vec::new();
for expr in exprs {
expanded_exprs.push(self.expand_all(expr)?);
}
let expanded_tail = self.expand_all(*tail)?;
Ok(Expr::DottedList(expanded_exprs, Box::new(expanded_tail)))
}
other => Ok(other), }
}
pub fn define_macro(&mut self, name: String, transformer: MacroTransformer) {
self.macros.insert(name.clone(), Macro {
name,
transformer,
is_syntax_rules: false,
});
}
}
impl Default for MacroExpander {
fn default() -> Self {
Self::new()
}
}
/// Expands `let` into an immediately-applied `lambda`.
///
/// * `(let ((v1 e1) ...) body...)` becomes `((lambda (v1 ...) body...) e1 ...)`.
/// * Named let, `(let name ((v1 e1) ...) body...)`, becomes
///   `((letrec ((name (lambda (v1 ...) body...))) name) e1 ...)`, so `body` can
///   call `name` recursively. The inner `letrec` sits in operand position of
///   the result and is lowered by a later expansion pass.
///
/// # Errors
/// `SyntaxError` when the body is missing, the bindings are not a list, or a
/// binding is not a `(symbol value)` pair.
fn expand_let(args: &[Expr]) -> Result<Expr> {
    // Named let: the first operand is the loop name instead of the binding list.
    let (name, rest) = match args.first() {
        Some(Expr::Variable(name)) => (Some(name), &args[1..]),
        _ => (None, args),
    };
    if rest.len() < 2 {
        return Err(LambdustError::SyntaxError("let: too few arguments".to_string()));
    }
    let binding_list = match &rest[0] {
        Expr::List(bindings) => bindings,
        _ => return Err(LambdustError::SyntaxError("let: bindings must be a list".to_string())),
    };
    let body = &rest[1..];

    let mut vars = Vec::with_capacity(binding_list.len());
    let mut vals = Vec::with_capacity(binding_list.len());
    for binding in binding_list {
        match binding {
            Expr::List(parts) if parts.len() == 2 => match &parts[0] {
                Expr::Variable(var) => {
                    vars.push(Expr::Variable(var.clone()));
                    vals.push(parts[1].clone());
                }
                _ => {
                    return Err(LambdustError::SyntaxError(
                        "let: binding variable must be a symbol".to_string(),
                    ))
                }
            },
            _ => {
                return Err(LambdustError::SyntaxError(
                    "let: each binding must be (var val)".to_string(),
                ))
            }
        }
    }

    let mut lambda = Vec::with_capacity(body.len() + 2);
    lambda.push(Expr::Variable("lambda".to_string()));
    lambda.push(Expr::List(vars));
    lambda.extend(body.iter().cloned());
    let lambda = Expr::List(lambda);

    // For a named let, bind the procedure recursively and apply it to the initial values.
    let operator = match name {
        None => lambda,
        Some(name) => Expr::List(vec![
            Expr::Variable("letrec".to_string()),
            Expr::List(vec![Expr::List(vec![Expr::Variable(name.clone()), lambda])]),
            Expr::Variable(name.clone()),
        ]),
    };

    let mut application = Vec::with_capacity(vals.len() + 1);
    application.push(operator);
    application.extend(vals);
    Ok(Expr::List(application))
}
/// Expands `let*` into nested single-binding `let` forms.
///
/// `(let* ((v1 e1) (v2 e2) ...) body...)` becomes
/// `(let ((v1 e1)) (let ((v2 e2)) ... body...))`, so each initializer sees the
/// variables bound before it. `(let* () body...)` becomes `(let () body...)`.
/// The body is placed directly in the innermost `let` rather than wrapped in
/// `begin`, so definitions inside it stay local to the `let*` instead of
/// leaking into the enclosing scope. The `let` transformer checks each
/// binding's shape when the nested forms are expanded.
///
/// # Errors
/// `SyntaxError` when the body is missing or the bindings are not a list.
fn expand_let_star(args: &[Expr]) -> Result<Expr> {
    /// Builds `(let (bindings...) forms...)`.
    fn make_let(bindings: Vec<Expr>, forms: Vec<Expr>) -> Expr {
        let mut list = Vec::with_capacity(forms.len() + 2);
        list.push(Expr::Variable("let".to_string()));
        list.push(Expr::List(bindings));
        list.extend(forms);
        Expr::List(list)
    }

    if args.len() < 2 {
        return Err(LambdustError::SyntaxError("let*: too few arguments".to_string()));
    }
    let binding_list = match &args[0] {
        Expr::List(bindings) => bindings,
        _ => return Err(LambdustError::SyntaxError("let*: bindings must be a list".to_string())),
    };
    let body = &args[1..];

    let (innermost, outer) = match binding_list.split_last() {
        Some((last, outer)) => (make_let(vec![last.clone()], body.to_vec()), outer),
        None => return Ok(make_let(Vec::new(), body.to_vec())),
    };
    // Wrap outward so the first binding ends up outermost.
    Ok(outer
        .iter()
        .rev()
        .fold(innermost, |inner, binding| make_let(vec![binding.clone()], vec![inner])))
}
/// Expands `(letrec ((v1 e1) ...) body...)` into
/// `((lambda (v1 ...) (set! v1 e1) ... body...) #f ...)`.
///
/// Every variable is first bound to a placeholder and then assigned in order,
/// so the initializers can refer to one another.
///
/// NOTE(review): the placeholder is emitted as `Expr::Variable("#f")`; confirm
/// the evaluator resolves that name to false.
///
/// # Errors
/// `SyntaxError` when the body is missing, the bindings are not a list, or a
/// binding is not a `(symbol value)` pair.
fn expand_letrec(args: &[Expr]) -> Result<Expr> {
    let (bindings, body) = match args {
        [bindings, body @ ..] if !body.is_empty() => (bindings, body),
        _ => return Err(LambdustError::SyntaxError("letrec: too few arguments".to_string())),
    };
    let binding_list = match bindings {
        Expr::List(items) => items,
        _ => return Err(LambdustError::SyntaxError("letrec: bindings must be a list".to_string())),
    };

    let mut params = Vec::with_capacity(binding_list.len());
    let mut inits = Vec::with_capacity(binding_list.len());
    for binding in binding_list {
        let (name, init) = match binding {
            Expr::List(parts) if parts.len() == 2 => match &parts[0] {
                Expr::Variable(name) => (name, &parts[1]),
                _ => {
                    return Err(LambdustError::SyntaxError(
                        "letrec: binding variable must be a symbol".to_string(),
                    ))
                }
            },
            _ => {
                return Err(LambdustError::SyntaxError(
                    "letrec: each binding must be (var val)".to_string(),
                ))
            }
        };
        params.push(Expr::Variable(name.clone()));
        inits.push(Expr::List(vec![
            Expr::Variable("set!".to_string()),
            Expr::Variable(name.clone()),
            init.clone(),
        ]));
    }
    let count = params.len();

    let mut lambda = Vec::with_capacity(2 + inits.len() + body.len());
    lambda.push(Expr::Variable("lambda".to_string()));
    lambda.push(Expr::List(params));
    lambda.extend(inits);
    lambda.extend(body.iter().cloned());

    let mut application = Vec::with_capacity(count + 1);
    application.push(Expr::List(lambda));
    application.extend(std::iter::repeat(Expr::Variable("#f".to_string())).take(count));
    Ok(Expr::List(application))
}
/// Expands `(cond clause...)` into nested conditionals.
///
/// All clause handling, including the empty `(cond)` case (which yields the
/// expression `#f`), is delegated to `expand_cond_clauses`.
fn expand_cond(args: &[Expr]) -> Result<Expr> {
    expand_cond_clauses(args)
}
/// Lowers `cond` clauses, left to right, into nested conditionals.
///
/// Clause shapes (`<rest>` is the expansion of the remaining clauses):
/// * `(test expr...)` → `(if test (begin expr...) <rest>)`
/// * `(test)` → `((lambda (t) (if t t <rest>)) test)`: yields the test's own
///   value while evaluating `test` only once.
/// * `(test => receiver)` → `((lambda (t) (if t (receiver t) <rest>)) test)`
/// * `(else expr...)` → `(begin expr...)`; must be the last clause. A bare
///   `(else)` yields `#t`.
///
/// With no clauses left, the result is `#f`.
///
/// NOTE(review): `#t`/`#f` are emitted as `Expr::Variable`; confirm the
/// evaluator resolves those names to booleans. The temporary `t` is the
/// non-hygienic name `__cond_tmp__`.
///
/// # Errors
/// `SyntaxError` if a clause is not a non-empty list or `else` is not last.
fn expand_cond_clauses(clauses: &[Expr]) -> Result<Expr> {
    const COND_TMP: &str = "__cond_tmp__";

    /// Builds `(begin form...)`.
    fn begin(forms: &[Expr]) -> Expr {
        let mut list = Vec::with_capacity(forms.len() + 1);
        list.push(Expr::Variable("begin".to_string()));
        list.extend_from_slice(forms);
        Expr::List(list)
    }

    let (clause, rest) = match clauses.split_first() {
        Some(split) => split,
        None => return Ok(Expr::Variable("#f".to_string())),
    };
    let (test, exprs) = match clause {
        Expr::List(parts) => match parts.split_first() {
            Some(split) => split,
            None => {
                return Err(LambdustError::SyntaxError(
                    "cond: clause must not be empty".to_string(),
                ))
            }
        },
        _ => return Err(LambdustError::SyntaxError("cond: clause must be a list".to_string())),
    };

    if matches!(test, Expr::Variable(name) if name == "else") {
        if !rest.is_empty() {
            return Err(LambdustError::SyntaxError(
                "cond: else clause must be last".to_string(),
            ));
        }
        return Ok(if exprs.is_empty() {
            Expr::Variable("#t".to_string())
        } else {
            begin(exprs)
        });
    }

    let else_expr = expand_cond_clauses(rest)?;

    // Evaluates `test` once into a temporary, then branches on it. This is an
    // already-lowered lambda application, so no further `let` expansion is needed.
    let tmp = || Expr::Variable(COND_TMP.to_string());
    let bind_test = |consequent: Expr, alternative: Expr| {
        Expr::List(vec![
            Expr::List(vec![
                Expr::Variable("lambda".to_string()),
                Expr::List(vec![tmp()]),
                Expr::List(vec![
                    Expr::Variable("if".to_string()),
                    tmp(),
                    consequent,
                    alternative,
                ]),
            ]),
            test.clone(),
        ])
    };

    Ok(match exprs {
        [] => bind_test(tmp(), else_expr),
        [arrow, receiver] if matches!(arrow, Expr::Variable(name) if name == "=>") => {
            bind_test(Expr::List(vec![receiver.clone(), tmp()]), else_expr)
        }
        _ => Expr::List(vec![
            Expr::Variable("if".to_string()),
            test.clone(),
            begin(exprs),
            else_expr,
        ]),
    })
}
/// Expands `(case key clause...)` into
/// `(let ((__case_key__ key)) <clauses>)`.
///
/// `key` is evaluated once, and each clause then compares it with `eqv?` (see
/// `expand_case_clauses`).
///
/// NOTE(review): `__case_key__` is not hygienic; a clause body that references
/// that exact name sees the key instead of any outer binding.
///
/// # Errors
/// `SyntaxError` when there is no key or no clause, or when a clause is malformed.
fn expand_case(args: &[Expr]) -> Result<Expr> {
    const KEY_VAR: &str = "__case_key__";

    let (key, clauses) = match args {
        [key, clauses @ ..] if !clauses.is_empty() => (key, clauses),
        _ => return Err(LambdustError::SyntaxError("case: too few arguments".to_string())),
    };

    let dispatch = expand_case_clauses(KEY_VAR, clauses)?;
    let key_binding = Expr::List(vec![Expr::Variable(KEY_VAR.to_string()), key.clone()]);
    Ok(Expr::List(vec![
        Expr::Variable("let".to_string()),
        Expr::List(vec![key_binding]),
        dispatch,
    ]))
}
/// Lowers `case` clauses into nested `if` forms that test the variable named
/// `key_var`.
///
/// * `((d1 d2 ...) expr...)` → `(if (or (eqv? key 'd1) (eqv? key 'd2) ...) (begin expr...) <rest>)`
/// * `(d expr...)` with a single non-list datum → `(if (eqv? key 'd) (begin expr...) <rest>)`
/// * `(else expr...)` → `(begin expr...)`; must be the last clause.
///
/// With no clauses left, the result is `#f` (emitted as `Expr::Variable("#f")`).
///
/// # Errors
/// `SyntaxError` if a clause is not a list, has no body expression, or `else`
/// is not last.
fn expand_case_clauses(key_var: &str, clauses: &[Expr]) -> Result<Expr> {
    /// Builds `(begin form...)`.
    fn begin(forms: &[Expr]) -> Expr {
        let mut list = Vec::with_capacity(forms.len() + 1);
        list.push(Expr::Variable("begin".to_string()));
        list.extend_from_slice(forms);
        Expr::List(list)
    }

    let (clause, rest) = match clauses.split_first() {
        Some(split) => split,
        None => return Ok(Expr::Variable("#f".to_string())),
    };
    let (datum_list, exprs) = match clause {
        Expr::List(parts) if parts.len() >= 2 => (&parts[0], &parts[1..]),
        // Previously reported as "clause must be a list", which is misleading
        // for e.g. `(else)`.
        Expr::List(_) => {
            return Err(LambdustError::SyntaxError(
                "case: clause must have at least one expression".to_string(),
            ))
        }
        _ => return Err(LambdustError::SyntaxError("case: clause must be a list".to_string())),
    };

    if matches!(datum_list, Expr::Variable(name) if name == "else") {
        if !rest.is_empty() {
            return Err(LambdustError::SyntaxError(
                "case: else clause must be last".to_string(),
            ));
        }
        return Ok(begin(exprs));
    }

    // `(eqv? key 'datum)`
    let key_matches = |datum: &Expr| {
        Expr::List(vec![
            Expr::Variable("eqv?".to_string()),
            Expr::Variable(key_var.to_string()),
            Expr::Quote(Box::new(datum.clone())),
        ])
    };
    let test = match datum_list {
        Expr::List(datums) => {
            let mut or_expr = Vec::with_capacity(datums.len() + 1);
            or_expr.push(Expr::Variable("or".to_string()));
            or_expr.extend(datums.iter().map(&key_matches));
            Expr::List(or_expr)
        }
        single_datum => key_matches(single_datum),
    };

    let else_expr = expand_case_clauses(key_var, rest)?;
    Ok(Expr::List(vec![
        Expr::Variable("if".to_string()),
        test,
        begin(exprs),
        else_expr,
    ]))
}
/// Expands `(when test body...)` into `(if test (begin body...))`.
///
/// With no body forms, the result is `(if test #f)`.
///
/// # Errors
/// `SyntaxError` when the test expression is missing.
fn expand_when(args: &[Expr]) -> Result<Expr> {
    let (test, body) = match args.split_first() {
        Some(split) => split,
        None => return Err(LambdustError::SyntaxError("when: too few arguments".to_string())),
    };
    let consequent = match body {
        [] => Expr::Variable("#f".to_string()),
        forms => {
            let mut sequence = Vec::with_capacity(forms.len() + 1);
            sequence.push(Expr::Variable("begin".to_string()));
            sequence.extend_from_slice(forms);
            Expr::List(sequence)
        }
    };
    Ok(Expr::List(vec![
        Expr::Variable("if".to_string()),
        test.clone(),
        consequent,
    ]))
}
/// Expands `(unless test body...)` into `(if (not test) (begin body...))`.
///
/// With no body forms, the result is `(if (not test) #f)`.
///
/// # Errors
/// `SyntaxError` when the test expression is missing.
fn expand_unless(args: &[Expr]) -> Result<Expr> {
    let (test, body) = match args.split_first() {
        Some(split) => split,
        None => return Err(LambdustError::SyntaxError("unless: too few arguments".to_string())),
    };
    let condition = Expr::List(vec![Expr::Variable("not".to_string()), test.clone()]);
    let consequent = match body {
        [] => Expr::Variable("#f".to_string()),
        forms => {
            let mut sequence = Vec::with_capacity(forms.len() + 1);
            sequence.push(Expr::Variable("begin".to_string()));
            sequence.extend_from_slice(forms);
            Expr::List(sequence)
        }
    };
    Ok(Expr::List(vec![
        Expr::Variable("if".to_string()),
        condition,
        consequent,
    ]))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;
    use crate::parser::parse;

    /// Parses a single expression from `input`, panicking on lexer/parser errors.
    fn parse_expr(input: &str) -> Expr {
        let tokens = tokenize(input).unwrap();
        parse(tokens).unwrap()
    }

    #[test]
    fn test_expand_let() {
        let expander = MacroExpander::new();
        let expr = parse_expr("(let ((x 1) (y 2)) (+ x y))");
        let expanded = expander.expand_macro(expr).unwrap();
        match expanded {
            Expr::List(exprs) => {
                // ((lambda (x y) (+ x y)) 1 2)
                assert_eq!(exprs.len(), 3);
                assert!(matches!(exprs[0], Expr::List(_)));
            }
            _ => panic!("Expected list expression"),
        }
    }

    #[test]
    fn test_let_without_body_is_error() {
        let expander = MacroExpander::new();
        let expr = parse_expr("(let ((x 1)))");
        assert!(expander.expand_macro(expr).is_err());
    }

    #[test]
    fn test_expand_cond() {
        let expander = MacroExpander::new();
        let expr = parse_expr("(cond ((< x 0) 'negative) ((> x 0) 'positive) (else 'zero))");
        let expanded = expander.expand_macro(expr).unwrap();
        match expanded {
            Expr::List(exprs) => {
                assert_eq!(exprs[0], Expr::Variable("if".to_string()));
            }
            _ => panic!("Expected if expression"),
        }
    }

    #[test]
    fn test_expand_when() {
        let expander = MacroExpander::new();
        let expr = parse_expr("(when (> x 0) (display x) (newline))");
        let expanded = expander.expand_macro(expr).unwrap();
        match expanded {
            Expr::List(exprs) => {
                assert_eq!(exprs[0], Expr::Variable("if".to_string()));
                assert_eq!(exprs.len(), 3);
            }
            _ => panic!("Expected if expression"),
        }
    }

    #[test]
    fn test_expand_unless() {
        let expander = MacroExpander::new();
        let expr = parse_expr("(unless (> x 0) (display x))");
        let expanded = expander.expand_macro(expr).unwrap();
        match expanded {
            Expr::List(exprs) => {
                assert_eq!(exprs[0], Expr::Variable("if".to_string()));
                assert_eq!(exprs.len(), 3);
            }
            _ => panic!("Expected if expression"),
        }
    }

    #[test]
    fn test_is_macro_call() {
        let expander = MacroExpander::new();
        let let_expr = parse_expr("(let ((x 1)) x)");
        let regular_expr = parse_expr("(+ 1 2)");
        assert!(expander.is_macro_call(&let_expr));
        assert!(!expander.is_macro_call(&regular_expr));
    }
}