1use crate::{parser::Expr, Pos};
2use alloc::{
3 collections::{btree_map::Entry, BTreeMap, LinkedList},
4 string::String,
5};
6
7#[derive(Debug)]
8pub struct MacroErr {
9 pub pos: Pos,
10 pub msg: &'static str,
11}
12
13pub fn match_pattern(e1: &Expr, e2: &Expr, ctx: &mut BTreeMap<String, LinkedList<Expr>>) -> bool {
15 match (e1, e2) {
16 (Expr::ID(left, _), _) => {
17 if let Some('$') = left.chars().next() {
18 let entry = ctx.entry(left.clone());
20 match entry {
21 Entry::Vacant(ent) => {
22 let mut list = LinkedList::new();
23 list.push_back(e2.clone());
24 ent.insert(list);
25
26 true
27 }
28 Entry::Occupied(ent) => {
29 let exprs = ent.get();
30 if exprs.len() != 1 {
31 false
32 } else {
33 eq_expr(exprs.front().unwrap(), e2)
34 }
35 }
36 }
37 } else if left == "_" {
38 true
39 } else {
40 matches!(e2, Expr::ID(right, _) if left == right)
41 }
42 }
43 (Expr::Bool(left, _), Expr::Bool(right, _)) => left == right,
44 (Expr::Char(left, _), Expr::Char(right, _)) => left == right,
45 (Expr::Num(left, _), Expr::Num(right, _)) => left == right,
46 (Expr::Str(left, _), Expr::Str(right, _)) => left == right,
47 (Expr::Tuple(left, _), Expr::Tuple(right, _)) => match_list(left, right, ctx),
48 (Expr::Apply(left, _), Expr::Apply(right, _)) => match_list(left, right, ctx),
49 (Expr::List(left, _), Expr::List(right, _)) => match_list(left, right, ctx),
50 _ => false,
51 }
52}
53
54pub fn match_list(
55 left: &LinkedList<Expr>,
56 right: &LinkedList<Expr>,
57 ctx: &mut BTreeMap<String, LinkedList<Expr>>,
58) -> bool {
59 let mut prev = None;
60 let mut it_left = left.iter();
61 let mut it_right = right.iter();
62
63 loop {
64 match (it_left.next(), it_right.next()) {
65 (Some(e1), Some(e2)) => {
66 if let Expr::ID(id, _) = e1 {
67 if id == "..." {
68 if let Some(key) = &prev {
69 let Some(exprs) = ctx.get_mut(key) else {
70 return false;
71 };
72 exprs.push_back(e2.clone());
73 break;
74 }
75 } else {
76 prev = Some(id.clone());
77 }
78 }
79
80 if !match_pattern(e1, e2, ctx) {
81 return false;
82 }
83 }
84 (Some(e1), None) => {
85 if let Expr::ID(id, _) = e1 {
86 return id == "...";
87 } else {
88 return false;
89 }
90 }
91 (None, Some(_)) => return false,
92 _ => return true,
93 }
94 }
95
96 let key = prev.unwrap();
97 let exprs = ctx.get_mut(&key).unwrap();
98 for expr in it_right {
99 exprs.push_back(expr.clone());
100 }
101
102 true
103}
104
105fn eq_expr(e1: &Expr, e2: &Expr) -> bool {
106 match (e1, e2) {
107 (Expr::ID(left, _), Expr::ID(right, _)) => left == right,
108 (Expr::Bool(left, _), Expr::Bool(right, _)) => left == right,
109 (Expr::Char(left, _), Expr::Char(right, _)) => left == right,
110 (Expr::Num(left, _), Expr::Num(right, _)) => left == right,
111 (Expr::Str(left, _), Expr::Str(right, _)) => left == right,
112 (Expr::Tuple(left, _), Expr::Tuple(right, _)) => eq_exprs(left, right),
113 (Expr::Apply(left, _), Expr::Apply(right, _)) => eq_exprs(left, right),
114 (Expr::List(left, _), Expr::List(right, _)) => eq_exprs(left, right),
115 _ => false,
116 }
117}
118
119fn eq_exprs(es1: &LinkedList<Expr>, es2: &LinkedList<Expr>) -> bool {
120 if es1.len() != es2.len() {
121 return false;
122 }
123
124 es1.iter().zip(es2.iter()).all(|(e1, e2)| eq_expr(e1, e2))
125}
126
127pub(crate) fn process_macros(exprs: &mut LinkedList<Expr>) -> Result<Macros, MacroErr> {
128 let macros = parse_macros(exprs)?;
129
130 for expr in exprs.iter_mut() {
131 apply_macros(¯os, expr)?;
132 }
133
134 Ok(macros)
135}
136
137pub(crate) fn apply(expr: &mut Expr, macros: &Macros) -> Result<(), MacroErr> {
138 apply_macros(macros, expr)
139}
140
141fn apply_macros(macros: &Macros, expr: &mut Expr) -> Result<(), MacroErr> {
142 if let Expr::Apply(exprs, _) = expr {
143 if let Some(Expr::ID(id, _)) = exprs.front() {
144 if id == "macro" {
145 return Ok(());
146 }
147 }
148 }
149
150 apply_macros_recursively(macros, expr, 0)
151}
152
153fn apply_macros_expr(
154 pos: Pos,
155 macros: &Macros,
156 expr: &Expr,
157 count: u8,
158) -> Result<Option<Expr>, MacroErr> {
159 if count == 0xff {
160 return Err(MacroErr {
161 pos,
162 msg: "too deep macro",
163 });
164 }
165
166 for (_, rules) in macros.iter() {
167 let mut ctx = BTreeMap::new();
168
169 for rule in rules.iter() {
170 if match_pattern(&rule.pattern, expr, &mut ctx) {
171 let expr = expand(pos, &rule.template, &ctx).pop_front().unwrap();
172
173 if let Some(e) = apply_macros_expr(pos, macros, &expr, count + 1)? {
174 return Ok(Some(e));
175 } else {
176 return Ok(Some(expr));
177 }
178 }
179 }
180 }
181
182 Ok(None)
183}
184
185fn apply_macros_recursively(macros: &Macros, expr: &mut Expr, count: u8) -> Result<(), MacroErr> {
186 if count == 0xff {
187 panic!("{}: too deep macro", expr.get_pos());
188 }
189
190 if let Some(e) = apply_macros_expr(expr.get_pos(), macros, expr, count)? {
191 *expr = e;
192 }
193
194 match expr {
195 Expr::Apply(exprs, _) | Expr::List(exprs, _) | Expr::Tuple(exprs, _) => {
196 for expr in exprs {
197 apply_macros_recursively(macros, expr, count + 1)?;
198 }
199 }
200 _ => (),
201 }
202
203 Ok(())
204}
205
206pub(crate) type Macros = BTreeMap<String, LinkedList<MacroRule>>;
207
208#[derive(Debug)]
209pub(crate) struct MacroRule {
210 pattern: Expr,
211 template: Expr,
212}
213
214fn parse_macros(exprs: &LinkedList<Expr>) -> Result<Macros, MacroErr> {
215 let mut result = BTreeMap::new();
216
217 for e in exprs.iter() {
218 if let Expr::Apply(es, _) = e {
219 let mut it = es.iter();
220
221 let Some(front) = it.next() else {
222 continue;
223 };
224
225 if let Expr::ID(id_macro, _) = front {
226 if id_macro == "macro" {
227 let id = it.next();
228 let Some(Expr::ID(id, _)) = id else {
229 return Err(MacroErr {
230 pos: e.get_pos(),
231 msg: "invalid macro",
232 });
233 };
234
235 let mut rules = LinkedList::new();
236 for rule in it {
237 let Expr::Apply(rule_exprs, _) = rule else {
238 return Err(MacroErr {
239 pos: rule.get_pos(),
240 msg: "invalid macro rule",
241 });
242 };
243
244 if rule_exprs.len() != 2 {
245 return Err(MacroErr {
246 pos: rule.get_pos(),
247 msg: "the number of arguments of a macro rule is not 2",
248 });
249 }
250
251 let mut rule_it = rule_exprs.iter();
252
253 let mut pattern = rule_it.next().unwrap().clone();
254 if let Expr::Apply(arguments, _) = &mut pattern {
255 if let Some(Expr::ID(front, _)) = arguments.front_mut() {
256 if front == "_" {
257 *front = id.clone();
258 } else if front != id {
259 return Err(MacroErr {
260 pos: pattern.get_pos(),
261 msg: "invalid macro pattern",
262 });
263 }
264 }
265
266 let template = rule_it.next().unwrap().clone();
267
268 rules.push_back(MacroRule { pattern, template });
269 } else {
270 return Err(MacroErr {
271 pos: pattern.get_pos(),
272 msg: "invalid macro pattern",
273 });
274 };
275 }
276
277 if let Entry::Vacant(entry) = result.entry(id.clone()) {
278 entry.insert(rules);
279 } else {
280 return Err(MacroErr {
281 pos: e.get_pos(),
282 msg: "multiply defined",
283 });
284 }
285 }
286 }
287 }
288 }
289
290 Ok(result)
291}
292
293fn expand(pos: Pos, template: &Expr, ctx: &BTreeMap<String, LinkedList<Expr>>) -> LinkedList<Expr> {
294 match template {
295 Expr::ID(id, _) => {
296 if let Some(exprs) = ctx.get(id) {
297 let expr = exprs.front().unwrap();
298 let mut result = LinkedList::new();
299 result.push_back(expr.clone());
300 result
301 } else {
302 let mut result: LinkedList<Expr> = LinkedList::new();
303 result.push_back(template.clone());
304 result
305 }
306 }
307 Expr::Apply(templates, _) => {
308 let exprs = expand_list(pos, templates, ctx);
309 let mut result = LinkedList::new();
310
311 result.push_back(Expr::Apply(exprs, pos));
314 result
315 }
316 Expr::List(templates, _) => {
317 let exprs = expand_list(pos, templates, ctx);
318 let mut result = LinkedList::new();
319 result.push_back(Expr::List(exprs, pos));
320 result
321 }
322 Expr::Tuple(templates, _) => {
323 let exprs = expand_list(pos, templates, ctx);
324 let mut result = LinkedList::new();
325 result.push_back(Expr::Tuple(exprs, pos));
326 result
327 }
328 expr => {
329 let mut result = LinkedList::new();
330 result.push_back(expr.clone());
331 result
332 }
333 }
334}
335
336fn expand_list(
337 pos: Pos,
338 templates: &LinkedList<Expr>,
339 ctx: &BTreeMap<String, LinkedList<Expr>>,
340) -> LinkedList<Expr> {
341 let mut result = LinkedList::new();
342
343 let mut prev = None;
344
345 for template in templates {
346 if let Expr::ID(id, _) = template {
347 if id == "..." {
348 if let Some(p) = &prev {
349 if let Some(exprs) = ctx.get(p) {
350 let mut it = exprs.iter();
351 let _ = it.next();
352
353 for expr in it {
354 result.push_back(expr.clone());
355 }
356 } else {
357 prev = None;
358 }
359 } else {
360 prev = None;
361 }
362
363 continue;
364 } else {
365 prev = Some(id.clone());
366 }
367 } else {
368 prev = None;
369 }
370
371 let mut exprs = expand(pos, template, ctx);
372 result.append(&mut exprs);
373 }
374
375 result
376}