//! `blisp/macro.rs`: parsing and expansion of `macro` definitions.

1use crate::{parser::Expr, Pos};
2use alloc::{
3    collections::{btree_map::Entry, BTreeMap, LinkedList},
4    string::String,
5};
6
7#[derive(Debug)]
8pub struct MacroErr {
9    pub pos: Pos,
10    pub msg: &'static str,
11}
12
13/// `e1` is a pattern and `e2` is an expression to be matched.
14pub fn match_pattern(e1: &Expr, e2: &Expr, ctx: &mut BTreeMap<String, LinkedList<Expr>>) -> bool {
15    match (e1, e2) {
16        (Expr::ID(left, _), _) => {
17            if let Some('$') = left.chars().next() {
18                // If `e1` is `$id`, then a map from `$id` to `e1` is added to `ctx`.
19                let entry = ctx.entry(left.clone());
20                match entry {
21                    Entry::Vacant(ent) => {
22                        let mut list = LinkedList::new();
23                        list.push_back(e2.clone());
24                        ent.insert(list);
25
26                        true
27                    }
28                    Entry::Occupied(ent) => {
29                        let exprs = ent.get();
30                        if exprs.len() != 1 {
31                            false
32                        } else {
33                            eq_expr(exprs.front().unwrap(), e2)
34                        }
35                    }
36                }
37            } else if left == "_" {
38                true
39            } else {
40                matches!(e2, Expr::ID(right, _) if left == right)
41            }
42        }
43        (Expr::Bool(left, _), Expr::Bool(right, _)) => left == right,
44        (Expr::Char(left, _), Expr::Char(right, _)) => left == right,
45        (Expr::Num(left, _), Expr::Num(right, _)) => left == right,
46        (Expr::Str(left, _), Expr::Str(right, _)) => left == right,
47        (Expr::Tuple(left, _), Expr::Tuple(right, _)) => match_list(left, right, ctx),
48        (Expr::Apply(left, _), Expr::Apply(right, _)) => match_list(left, right, ctx),
49        (Expr::List(left, _), Expr::List(right, _)) => match_list(left, right, ctx),
50        _ => false,
51    }
52}
53
54pub fn match_list(
55    left: &LinkedList<Expr>,
56    right: &LinkedList<Expr>,
57    ctx: &mut BTreeMap<String, LinkedList<Expr>>,
58) -> bool {
59    let mut prev = None;
60    let mut it_left = left.iter();
61    let mut it_right = right.iter();
62
63    loop {
64        match (it_left.next(), it_right.next()) {
65            (Some(e1), Some(e2)) => {
66                if let Expr::ID(id, _) = e1 {
67                    if id == "..." {
68                        if let Some(key) = &prev {
69                            let Some(exprs) = ctx.get_mut(key) else {
70                                return false;
71                            };
72                            exprs.push_back(e2.clone());
73                            break;
74                        }
75                    } else {
76                        prev = Some(id.clone());
77                    }
78                }
79
80                if !match_pattern(e1, e2, ctx) {
81                    return false;
82                }
83            }
84            (Some(e1), None) => {
85                if let Expr::ID(id, _) = e1 {
86                    return id == "...";
87                } else {
88                    return false;
89                }
90            }
91            (None, Some(_)) => return false,
92            _ => return true,
93        }
94    }
95
96    let key = prev.unwrap();
97    let exprs = ctx.get_mut(&key).unwrap();
98    for expr in it_right {
99        exprs.push_back(expr.clone());
100    }
101
102    true
103}
104
105fn eq_expr(e1: &Expr, e2: &Expr) -> bool {
106    match (e1, e2) {
107        (Expr::ID(left, _), Expr::ID(right, _)) => left == right,
108        (Expr::Bool(left, _), Expr::Bool(right, _)) => left == right,
109        (Expr::Char(left, _), Expr::Char(right, _)) => left == right,
110        (Expr::Num(left, _), Expr::Num(right, _)) => left == right,
111        (Expr::Str(left, _), Expr::Str(right, _)) => left == right,
112        (Expr::Tuple(left, _), Expr::Tuple(right, _)) => eq_exprs(left, right),
113        (Expr::Apply(left, _), Expr::Apply(right, _)) => eq_exprs(left, right),
114        (Expr::List(left, _), Expr::List(right, _)) => eq_exprs(left, right),
115        _ => false,
116    }
117}
118
119fn eq_exprs(es1: &LinkedList<Expr>, es2: &LinkedList<Expr>) -> bool {
120    if es1.len() != es2.len() {
121        return false;
122    }
123
124    es1.iter().zip(es2.iter()).all(|(e1, e2)| eq_expr(e1, e2))
125}
126
127pub(crate) fn process_macros(exprs: &mut LinkedList<Expr>) -> Result<Macros, MacroErr> {
128    let macros = parse_macros(exprs)?;
129
130    for expr in exprs.iter_mut() {
131        apply_macros(&macros, expr)?;
132    }
133
134    Ok(macros)
135}
136
137pub(crate) fn apply(expr: &mut Expr, macros: &Macros) -> Result<(), MacroErr> {
138    apply_macros(macros, expr)
139}
140
141fn apply_macros(macros: &Macros, expr: &mut Expr) -> Result<(), MacroErr> {
142    if let Expr::Apply(exprs, _) = expr {
143        if let Some(Expr::ID(id, _)) = exprs.front() {
144            if id == "macro" {
145                return Ok(());
146            }
147        }
148    }
149
150    apply_macros_recursively(macros, expr, 0)
151}
152
153fn apply_macros_expr(
154    pos: Pos,
155    macros: &Macros,
156    expr: &Expr,
157    count: u8,
158) -> Result<Option<Expr>, MacroErr> {
159    if count == 0xff {
160        return Err(MacroErr {
161            pos,
162            msg: "too deep macro",
163        });
164    }
165
166    for (_, rules) in macros.iter() {
167        let mut ctx = BTreeMap::new();
168
169        for rule in rules.iter() {
170            if match_pattern(&rule.pattern, expr, &mut ctx) {
171                let expr = expand(pos, &rule.template, &ctx).pop_front().unwrap();
172
173                if let Some(e) = apply_macros_expr(pos, macros, &expr, count + 1)? {
174                    return Ok(Some(e));
175                } else {
176                    return Ok(Some(expr));
177                }
178            }
179        }
180    }
181
182    Ok(None)
183}
184
185fn apply_macros_recursively(macros: &Macros, expr: &mut Expr, count: u8) -> Result<(), MacroErr> {
186    if count == 0xff {
187        panic!("{}: too deep macro", expr.get_pos());
188    }
189
190    if let Some(e) = apply_macros_expr(expr.get_pos(), macros, expr, count)? {
191        *expr = e;
192    }
193
194    match expr {
195        Expr::Apply(exprs, _) | Expr::List(exprs, _) | Expr::Tuple(exprs, _) => {
196            for expr in exprs {
197                apply_macros_recursively(macros, expr, count + 1)?;
198            }
199        }
200        _ => (),
201    }
202
203    Ok(())
204}
205
206pub(crate) type Macros = BTreeMap<String, LinkedList<MacroRule>>;
207
208#[derive(Debug)]
209pub(crate) struct MacroRule {
210    pattern: Expr,
211    template: Expr,
212}
213
214fn parse_macros(exprs: &LinkedList<Expr>) -> Result<Macros, MacroErr> {
215    let mut result = BTreeMap::new();
216
217    for e in exprs.iter() {
218        if let Expr::Apply(es, _) = e {
219            let mut it = es.iter();
220
221            let Some(front) = it.next() else {
222                continue;
223            };
224
225            if let Expr::ID(id_macro, _) = front {
226                if id_macro == "macro" {
227                    let id = it.next();
228                    let Some(Expr::ID(id, _)) = id else {
229                        return Err(MacroErr {
230                            pos: e.get_pos(),
231                            msg: "invalid macro",
232                        });
233                    };
234
235                    let mut rules = LinkedList::new();
236                    for rule in it {
237                        let Expr::Apply(rule_exprs, _) = rule else {
238                            return Err(MacroErr {
239                                pos: rule.get_pos(),
240                                msg: "invalid macro rule",
241                            });
242                        };
243
244                        if rule_exprs.len() != 2 {
245                            return Err(MacroErr {
246                                pos: rule.get_pos(),
247                                msg: "the number of arguments of a macro rule is not 2",
248                            });
249                        }
250
251                        let mut rule_it = rule_exprs.iter();
252
253                        let mut pattern = rule_it.next().unwrap().clone();
254                        if let Expr::Apply(arguments, _) = &mut pattern {
255                            if let Some(Expr::ID(front, _)) = arguments.front_mut() {
256                                if front == "_" {
257                                    *front = id.clone();
258                                } else if front != id {
259                                    return Err(MacroErr {
260                                        pos: pattern.get_pos(),
261                                        msg: "invalid macro pattern",
262                                    });
263                                }
264                            }
265
266                            let template = rule_it.next().unwrap().clone();
267
268                            rules.push_back(MacroRule { pattern, template });
269                        } else {
270                            return Err(MacroErr {
271                                pos: pattern.get_pos(),
272                                msg: "invalid macro pattern",
273                            });
274                        };
275                    }
276
277                    if let Entry::Vacant(entry) = result.entry(id.clone()) {
278                        entry.insert(rules);
279                    } else {
280                        return Err(MacroErr {
281                            pos: e.get_pos(),
282                            msg: "multiply defined",
283                        });
284                    }
285                }
286            }
287        }
288    }
289
290    Ok(result)
291}
292
293fn expand(pos: Pos, template: &Expr, ctx: &BTreeMap<String, LinkedList<Expr>>) -> LinkedList<Expr> {
294    match template {
295        Expr::ID(id, _) => {
296            if let Some(exprs) = ctx.get(id) {
297                let expr = exprs.front().unwrap();
298                let mut result = LinkedList::new();
299                result.push_back(expr.clone());
300                result
301            } else {
302                let mut result: LinkedList<Expr> = LinkedList::new();
303                result.push_back(template.clone());
304                result
305            }
306        }
307        Expr::Apply(templates, _) => {
308            let exprs = expand_list(pos, templates, ctx);
309            let mut result = LinkedList::new();
310
311            // TODO: rename variables
312
313            result.push_back(Expr::Apply(exprs, pos));
314            result
315        }
316        Expr::List(templates, _) => {
317            let exprs = expand_list(pos, templates, ctx);
318            let mut result = LinkedList::new();
319            result.push_back(Expr::List(exprs, pos));
320            result
321        }
322        Expr::Tuple(templates, _) => {
323            let exprs = expand_list(pos, templates, ctx);
324            let mut result = LinkedList::new();
325            result.push_back(Expr::Tuple(exprs, pos));
326            result
327        }
328        expr => {
329            let mut result = LinkedList::new();
330            result.push_back(expr.clone());
331            result
332        }
333    }
334}
335
336fn expand_list(
337    pos: Pos,
338    templates: &LinkedList<Expr>,
339    ctx: &BTreeMap<String, LinkedList<Expr>>,
340) -> LinkedList<Expr> {
341    let mut result = LinkedList::new();
342
343    let mut prev = None;
344
345    for template in templates {
346        if let Expr::ID(id, _) = template {
347            if id == "..." {
348                if let Some(p) = &prev {
349                    if let Some(exprs) = ctx.get(p) {
350                        let mut it = exprs.iter();
351                        let _ = it.next();
352
353                        for expr in it {
354                            result.push_back(expr.clone());
355                        }
356                    } else {
357                        prev = None;
358                    }
359                } else {
360                    prev = None;
361                }
362
363                continue;
364            } else {
365                prev = Some(id.clone());
366            }
367        } else {
368            prev = None;
369        }
370
371        let mut exprs = expand(pos, template, ctx);
372        result.append(&mut exprs);
373    }
374
375    result
376}