1use std::collections::{HashMap, HashSet};
2
3use crate::{
4 lexer::Token,
5 parser::{AstErrors, Construct, Delimited, MacroBody, MacroBodyContent, Node, SelectorNode},
6 range_from_span::RangeFromSpan,
7};
8
9use crate::typechecker::{ReportTypeError, Typechecker, type_error::*};
10
/// The syntactic category a macro expands to, i.e. the kind of place it may be called from.
///
/// `Construct` is the fallback used when a macro declares no (recognised) return type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MacroReturnContext {
    /// Expands to one or more constructs (assignments, rules, ...).
    Construct,
    /// Expands to a single value on the right-hand side of an assignment.
    Datatype,
    /// Expands to selector nodes.
    Selector,
}

impl MacroReturnContext {
    /// Human-readable name of the context, as used in diagnostics.
    pub fn name(&self) -> &'static str {
        match *self {
            MacroReturnContext::Construct => "Construct",
            MacroReturnContext::Datatype => "Datatype",
            MacroReturnContext::Selector => "Selector",
        }
    }
}
27
28#[derive(Debug, Clone)]
29pub struct MacroDefinition<'a> {
30 pub arg_names: Vec<&'a str>,
31 pub body: Option<&'a MacroBodyContent<'a>>,
32 pub return_context: MacroReturnContext,
33}
34
/// Registry key identifying a macro overload: macros are resolved by name *and* by the
/// number of arguments at the call site.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MacroKey<'a> {
    /// Macro name as written at the call / definition site.
    pub name: &'a str,
    /// Number of arguments of this overload.
    pub arity: usize,
}
40
41pub type MacroRegistry<'a> = HashMap<MacroKey<'a>, MacroDefinition<'a>>;
42
43pub fn collect_macro_def_arg_names<'a>(args: &Option<Delimited<'a>>) -> Vec<&'a str> {
44 let Some(args) = args else { return Vec::new() };
45 let Some(content) = &args.content else {
46 return Vec::new();
47 };
48 content
49 .iter()
50 .filter_map(|construct| {
51 if let Construct::Node { node } = construct {
52 if let Token::MacroArgIdentifier(Some(name)) = node.token.value() {
53 return Some(*name);
54 }
55 }
56 None
57 })
58 .collect()
59}
60
61pub(super) fn count_macro_call_args(body: &Option<Delimited>) -> usize {
62 let Some(body) = body else { return 0 };
63 let Some(content) = &body.content else {
64 return 0;
65 };
66 if content.is_empty() {
67 return 0;
68 }
69 content
70 .iter()
71 .filter(|construct| {
72 matches!(
73 construct,
74 Construct::Node { node } if matches!(node.token.value(), Token::Comma)
75 )
76 })
77 .count()
78 + 1
79}
80
81pub fn macro_return_context(return_type: &Option<(Node, Option<Node>)>) -> MacroReturnContext {
82 if let Some((_, Some(ident))) = return_type {
83 match ident.token.value() {
84 Token::Identifier("Datatype") => MacroReturnContext::Datatype,
85 Token::Identifier("Selector") => MacroReturnContext::Selector,
86 _ => MacroReturnContext::Construct,
87 }
88 } else {
89 MacroReturnContext::Construct
90 }
91}
92
93impl<'a> Typechecker<'a> {
94 pub(super) fn typecheck_macro(
95 &self,
96 args: &Option<Delimited<'a>>,
97 body: &Option<MacroBody<'a>>,
98 ast_errors: &mut AstErrors,
99 ) {
100 let macro_args = collect_macro_arg_names(args);
101 let Some(body) = body else { return };
102
103 match &body.content {
104 MacroBodyContent::Construct(Some(content)) => {
105 self.typecheck_macro_body_content(content, ¯o_args, ast_errors);
106 }
107 MacroBodyContent::Datatype(Some(content)) => {
108 self.validate_macro_arg_refs(content, Some(¯o_args), ast_errors);
109 self.validate_annotation(content, ast_errors);
110 if let Construct::MacroCall { name, body, .. } = content.as_ref() {
111 self.validate_macro_call(name, body, MacroReturnContext::Datatype, ast_errors);
112 }
113 }
114 MacroBodyContent::Selector(Some(selectors)) => {
115 for selector in selectors {
116 if let SelectorNode::MacroCall { name, body } = selector {
117 self.validate_macro_call(
118 name,
119 body,
120 MacroReturnContext::Selector,
121 ast_errors,
122 );
123 }
124 }
125 }
126 _ => {}
127 }
128 }
129
130 fn typecheck_macro_body_content(
131 &self,
132 content: &Vec<Construct<'a>>,
133 macro_args: &HashSet<&str>,
134 ast_errors: &mut AstErrors,
135 ) {
136 for construct in content {
137 match construct {
138 Construct::Assignment { right, .. } => {
139 if let Some(right) = right {
140 self.validate_macro_arg_refs(right, Some(macro_args), ast_errors);
141 self.validate_annotation(right, ast_errors);
142 if let Construct::MacroCall { name, body, .. } = right.as_ref() {
143 self.validate_macro_call(
144 name,
145 body,
146 MacroReturnContext::Datatype,
147 ast_errors,
148 );
149 }
150 }
151 }
152
153 Construct::Rule { body, .. } => {
154 if let Some(body) = body {
155 if let Some(content) = &body.content {
156 self.typecheck_macro_body_content(content, macro_args, ast_errors);
157 }
158 }
159 }
160
161 Construct::Tween { body, .. } => {
162 if let Some(body) = body {
163 self.validate_macro_arg_refs(body, Some(macro_args), ast_errors);
164 }
165 }
166
167 Construct::MacroCall { name, body, .. } => {
168 self.validate_macro_call(name, body, MacroReturnContext::Construct, ast_errors);
169 }
170
171 Construct::Macro { .. } => {
172 ast_errors.report(
173 TypeError::NotAllowedInContext {
174 name: construct.name_plural(),
175 context: "other macros",
176 },
177 self.range_from_span(construct.span()),
178 );
179 }
180
181 Construct::Derive { .. } => {
182 ast_errors.report(
183 TypeError::NotAllowedInContext {
184 name: construct.name_plural(),
185 context: "non-global scopes",
186 },
187 self.range_from_span(construct.span()),
188 );
189 }
190
191 _ => (),
192 }
193 }
194 }
195
196 pub(super) fn validate_macro_call(
197 &self,
198 name: &Node<'a>,
199 body: &Option<Delimited<'a>>,
200 expected_context: MacroReturnContext,
201 ast_errors: &mut AstErrors,
202 ) {
203 let Token::MacroCallIdentifier(Some(macro_name)) = name.token.value() else {
204 return;
205 };
206
207 let local_arities = self
208 .macro_registry
209 .keys()
210 .filter(|k| k.name == *macro_name)
211 .map(|k| k.arity);
212 let builtin_arities = crate::builtins::BUILTINS
213 .registry
214 .keys()
215 .filter(|k| k.name == *macro_name)
216 .map(|k| k.arity);
217
218 let mut expected_counts: Vec<usize> = local_arities.chain(builtin_arities).collect();
219
220 if expected_counts.is_empty() {
221 ast_errors.report(
222 TypeError::UndefinedMacro { name: macro_name },
223 self.range_from_span(name.token.span()),
224 );
225 return;
226 }
227
228 let call_arg_count = count_macro_call_args(body);
229 let key = MacroKey {
230 name: *macro_name,
231 arity: call_arg_count,
232 };
233
234 let matching_context = self
235 .macro_registry
236 .get(&key)
237 .map(|def| def.return_context)
238 .or_else(|| {
239 crate::builtins::BUILTINS
240 .registry
241 .get(&key)
242 .map(|def| def.return_context)
243 });
244
245 let Some(matching_context) = matching_context else {
246 expected_counts.sort();
247 expected_counts.dedup();
248
249 ast_errors.report(
250 TypeError::WrongMacroArgCount {
251 name: macro_name,
252 expected: expected_counts,
253 got: call_arg_count,
254 },
255 self.range_from_span(name.token.span()),
256 );
257 return;
258 };
259
260 if matching_context != expected_context {
261 ast_errors.report(
262 TypeError::WrongMacroContext {
263 name: macro_name,
264 expected: matching_context.name(),
265 got: expected_context.name(),
266 },
267 self.range_from_span(name.token.span()),
268 );
269 }
270 }
271
272 pub(super) fn validate_macro_arg_refs(
273 &self,
274 construct: &Construct<'a>,
275 macro_args: Option<&HashSet<&str>>,
276 ast_errors: &mut AstErrors,
277 ) {
278 match construct {
279 Construct::Node { node } => {
280 if let Token::MacroArgIdentifier(name) = node.token.value() {
281 let is_valid = match macro_args {
282 Some(args) => name.is_some_and(|arg_name| args.contains(arg_name)),
283 None => false,
284 };
285
286 if !is_valid {
287 if let Some(arg_name) = name {
288 ast_errors.report(
289 TypeError::InvalidMacroArg {
290 msg: &format!(
291 "No macro argument named \"{}\" exists.",
292 arg_name
293 ),
294 },
295 self.range_from_span(node.token.span()),
296 );
297 } else {
298 ast_errors.report(
299 TypeError::InvalidMacroArg {
300 msg: "Missing macro argument name.",
301 },
302 self.range_from_span(node.token.span()),
303 );
304 }
305 }
306 }
307 }
308
309 Construct::MathOperation { left, right, .. } => {
310 self.validate_macro_arg_refs(left, macro_args, ast_errors);
311 if let Some(right) = right {
312 self.validate_macro_arg_refs(right, macro_args, ast_errors);
313 }
314 }
315
316 Construct::UnaryMinus { operand, .. } => {
317 self.validate_macro_arg_refs(operand, macro_args, ast_errors);
318 }
319
320 Construct::Table { body } => {
321 let Some(content) = &body.content else { return };
322 for item in content {
323 self.validate_macro_arg_refs(item, macro_args, ast_errors);
324 }
325 }
326
327 Construct::AnnotatedTable { body, .. } => {
328 let Some(body) = body else { return };
329 let Some(content) = &body.content else { return };
330 for item in content {
331 self.validate_macro_arg_refs(item, macro_args, ast_errors);
332 }
333 }
334
335 _ => (),
336 }
337 }
338
339 fn range_from_span(&self, span: (usize, usize)) -> crate::types::Range {
340 crate::types::Range::from_span(&self.parsed.rope, span)
341 }
342}
343
344fn collect_macro_arg_names<'a>(args: &Option<Delimited<'a>>) -> HashSet<&'a str> {
345 let mut names = HashSet::new();
346 if let Some(args) = args {
347 if let Some(content) = &args.content {
348 for construct in content {
349 if let Construct::Node { node } = construct {
350 if let Token::MacroArgIdentifier(Some(name)) = node.token.value() {
351 names.insert(*name);
352 }
353 }
354 }
355 }
356 }
357 names
358}
359
360fn for_each_macro_call_in_body<'a, F>(body: &MacroBodyContent<'a>, cb: &mut F)
361where
362 F: FnMut(&'a str, usize, (usize, usize)),
363{
364 match body {
365 MacroBodyContent::Construct(Some(content)) => {
366 for construct in content {
367 visit_construct_for_calls(construct, cb);
368 }
369 }
370
371 MacroBodyContent::Datatype(Some(content)) => {
372 visit_construct_for_calls(content, cb);
373 }
374
375 MacroBodyContent::Selector(Some(selectors)) => {
376 visit_selectors_for_calls(selectors, cb);
377 }
378
379 _ => {}
380 }
381}
382
383fn visit_construct_for_calls<'a, F>(construct: &Construct<'a>, cb: &mut F)
384where
385 F: FnMut(&'a str, usize, (usize, usize)),
386{
387 match construct {
388 Construct::MacroCall { name, body, .. } => {
389 if let Token::MacroCallIdentifier(Some(n)) = name.token.value() {
390 cb(*n, count_macro_call_args(body), name.token.span());
391 }
392 }
393
394 Construct::Assignment { right, .. } => {
395 if let Some(right) = right {
396 visit_construct_for_calls(right, cb);
397 }
398 }
399
400 Construct::Rule { selectors, body } => {
401 if let Some(selectors) = selectors {
402 visit_selectors_for_calls(selectors, cb);
403 }
404
405 if let Some(body) = body {
406 if let Some(content) = &body.content {
407 for inner in content {
408 visit_construct_for_calls(inner, cb);
409 }
410 }
411 }
412 }
413
414 _ => {}
415 }
416}
417
418fn visit_selectors_for_calls<'a, F>(selectors: &[SelectorNode<'a>], cb: &mut F)
419where
420 F: FnMut(&'a str, usize, (usize, usize)),
421{
422 for selector in selectors {
423 if let SelectorNode::MacroCall { name, body } = selector {
424 if let Token::MacroCallIdentifier(Some(n)) = name.token.value() {
425 cb(*n, count_macro_call_args(body), name.token.span());
426 }
427 }
428 }
429}
430
/// Visitation state of a macro during call-graph cycle detection.
///
/// A macro absent from the color map has not been visited yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DfsColor {
    /// On the current DFS path; reaching it again means a recursive call.
    Gray,
    /// Fully explored; no cycle can be found through it anymore.
    Black,
}
435
436impl<'a> Typechecker<'a> {
437 pub(super) fn detect_recursive_macro_calls(&self, ast_errors: &mut AstErrors) {
438 let mut color: HashMap<MacroKey<'a>, DfsColor> = HashMap::new();
439
440 let roots: Vec<MacroKey<'a>> = self.macro_registry.keys().copied().collect();
441 for root in roots {
442 if color.contains_key(&root) {
443 continue;
444 }
445
446 self.dfs_macro_cycle(root, &mut color, ast_errors);
447 }
448 }
449
450 fn dfs_macro_cycle(
451 &self,
452 key: MacroKey<'a>,
453 color: &mut HashMap<MacroKey<'a>, DfsColor>,
454 ast_errors: &mut AstErrors,
455 ) {
456 color.insert(key, DfsColor::Gray);
457
458 let Some(def) = self.macro_registry.get(&key) else {
459 color.insert(key, DfsColor::Black);
460 return;
461 };
462 let Some(body) = def.body else {
463 color.insert(key, DfsColor::Black);
464 return;
465 };
466
467 let mut calls: Vec<(&'a str, usize, (usize, usize))> = Vec::new();
468 for_each_macro_call_in_body(body, &mut |name, arity, span| {
469 calls.push((name, arity, span));
470 });
471
472 for (name, arity, span) in calls {
473 let callee = MacroKey { name, arity };
474
475 if !self.macro_registry.contains_key(&callee) {
476 continue;
477 }
478
479 match color.get(&callee) {
480 Some(DfsColor::Gray) => {
481 ast_errors.report(TypeError::RecursiveMacroCall, self.range_from_span(span))
482 }
483 Some(DfsColor::Black) => {}
484 None => self.dfs_macro_cycle(callee, color, ast_errors),
485 }
486 }
487
488 color.insert(key, DfsColor::Black);
489 }
490}