Skip to main content

rbx_rsml/typechecker/
macro_check.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::{
4    lexer::Token,
5    parser::{AstErrors, Construct, Delimited, MacroBody, MacroBodyContent, Node, SelectorNode},
6    range_from_span::RangeFromSpan,
7};
8
9use crate::typechecker::{ReportTypeError, Typechecker, type_error::*};
10
/// The kind of value a macro expands to, which determines where a call to it
/// may legally appear.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MacroReturnContext {
    Construct,
    Datatype,
    Selector,
}

impl MacroReturnContext {
    /// Returns the display name of this context, as used in diagnostics.
    pub fn name(&self) -> &'static str {
        match *self {
            MacroReturnContext::Construct => "Construct",
            MacroReturnContext::Datatype => "Datatype",
            MacroReturnContext::Selector => "Selector",
        }
    }
}
27
28#[derive(Debug, Clone)]
29pub struct MacroDefinition<'a> {
30    pub arg_names: Vec<&'a str>,
31    pub body: Option<&'a MacroBodyContent<'a>>,
32    pub return_context: MacroReturnContext,
33}
34
/// Identifies one macro overload. Macros are looked up by both name and
/// arity, so calls with different argument counts may resolve to different
/// definitions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MacroKey<'a> {
    /// The macro's name.
    pub name: &'a str,
    /// The number of arguments the macro takes.
    pub arity: usize,
}
40
41pub type MacroRegistry<'a> = HashMap<MacroKey<'a>, MacroDefinition<'a>>;
42
43pub fn collect_macro_def_arg_names<'a>(args: &Option<Delimited<'a>>) -> Vec<&'a str> {
44    let Some(args) = args else { return Vec::new() };
45    let Some(content) = &args.content else {
46        return Vec::new();
47    };
48    content
49        .iter()
50        .filter_map(|construct| {
51            if let Construct::Node { node } = construct {
52                if let Token::MacroArgIdentifier(Some(name)) = node.token.value() {
53                    return Some(*name);
54                }
55            }
56            None
57        })
58        .collect()
59}
60
61pub(super) fn count_macro_call_args(body: &Option<Delimited>) -> usize {
62    let Some(body) = body else { return 0 };
63    let Some(content) = &body.content else {
64        return 0;
65    };
66    if content.is_empty() {
67        return 0;
68    }
69    content
70        .iter()
71        .filter(|construct| {
72            matches!(
73                construct,
74                Construct::Node { node } if matches!(node.token.value(), Token::Comma)
75            )
76        })
77        .count()
78        + 1
79}
80
81pub fn macro_return_context(return_type: &Option<(Node, Option<Node>)>) -> MacroReturnContext {
82    if let Some((_, Some(ident))) = return_type {
83        match ident.token.value() {
84            Token::Identifier("Datatype") => MacroReturnContext::Datatype,
85            Token::Identifier("Selector") => MacroReturnContext::Selector,
86            _ => MacroReturnContext::Construct,
87        }
88    } else {
89        MacroReturnContext::Construct
90    }
91}
92
93impl<'a> Typechecker<'a> {
94    pub(super) fn typecheck_macro(
95        &self,
96        args: &Option<Delimited<'a>>,
97        body: &Option<MacroBody<'a>>,
98        ast_errors: &mut AstErrors,
99    ) {
100        let macro_args = collect_macro_arg_names(args);
101        let Some(body) = body else { return };
102
103        match &body.content {
104            MacroBodyContent::Construct(Some(content)) => {
105                self.typecheck_macro_body_content(content, &macro_args, ast_errors);
106            }
107            MacroBodyContent::Datatype(Some(content)) => {
108                self.validate_macro_arg_refs(content, Some(&macro_args), ast_errors);
109                self.validate_annotation(content, ast_errors);
110                if let Construct::MacroCall { name, body, .. } = content.as_ref() {
111                    self.validate_macro_call(name, body, MacroReturnContext::Datatype, ast_errors);
112                }
113            }
114            MacroBodyContent::Selector(Some(selectors)) => {
115                for selector in selectors {
116                    if let SelectorNode::MacroCall { name, body } = selector {
117                        self.validate_macro_call(
118                            name,
119                            body,
120                            MacroReturnContext::Selector,
121                            ast_errors,
122                        );
123                    }
124                }
125            }
126            _ => {}
127        }
128    }
129
130    fn typecheck_macro_body_content(
131        &self,
132        content: &Vec<Construct<'a>>,
133        macro_args: &HashSet<&str>,
134        ast_errors: &mut AstErrors,
135    ) {
136        for construct in content {
137            match construct {
138                Construct::Assignment { right, .. } => {
139                    if let Some(right) = right {
140                        self.validate_macro_arg_refs(right, Some(macro_args), ast_errors);
141                        self.validate_annotation(right, ast_errors);
142                        if let Construct::MacroCall { name, body, .. } = right.as_ref() {
143                            self.validate_macro_call(
144                                name,
145                                body,
146                                MacroReturnContext::Datatype,
147                                ast_errors,
148                            );
149                        }
150                    }
151                }
152
153                Construct::Rule { body, .. } => {
154                    if let Some(body) = body {
155                        if let Some(content) = &body.content {
156                            self.typecheck_macro_body_content(content, macro_args, ast_errors);
157                        }
158                    }
159                }
160
161                Construct::Tween { body, .. } => {
162                    if let Some(body) = body {
163                        self.validate_macro_arg_refs(body, Some(macro_args), ast_errors);
164                    }
165                }
166
167                Construct::MacroCall { name, body, .. } => {
168                    self.validate_macro_call(name, body, MacroReturnContext::Construct, ast_errors);
169                }
170
171                Construct::Macro { .. } => {
172                    ast_errors.report(
173                        TypeError::NotAllowedInContext {
174                            name: construct.name_plural(),
175                            context: "other macros",
176                        },
177                        self.range_from_span(construct.span()),
178                    );
179                }
180
181                Construct::Derive { .. } => {
182                    ast_errors.report(
183                        TypeError::NotAllowedInContext {
184                            name: construct.name_plural(),
185                            context: "non-global scopes",
186                        },
187                        self.range_from_span(construct.span()),
188                    );
189                }
190
191                _ => (),
192            }
193        }
194    }
195
196    pub(super) fn validate_macro_call(
197        &self,
198        name: &Node<'a>,
199        body: &Option<Delimited<'a>>,
200        expected_context: MacroReturnContext,
201        ast_errors: &mut AstErrors,
202    ) {
203        let Token::MacroCallIdentifier(Some(macro_name)) = name.token.value() else {
204            return;
205        };
206
207        let local_arities = self
208            .macro_registry
209            .keys()
210            .filter(|k| k.name == *macro_name)
211            .map(|k| k.arity);
212        let builtin_arities = crate::builtins::BUILTINS
213            .registry
214            .keys()
215            .filter(|k| k.name == *macro_name)
216            .map(|k| k.arity);
217
218        let mut expected_counts: Vec<usize> = local_arities.chain(builtin_arities).collect();
219
220        if expected_counts.is_empty() {
221            ast_errors.report(
222                TypeError::UndefinedMacro { name: macro_name },
223                self.range_from_span(name.token.span()),
224            );
225            return;
226        }
227
228        let call_arg_count = count_macro_call_args(body);
229        let key = MacroKey {
230            name: *macro_name,
231            arity: call_arg_count,
232        };
233
234        let matching_context = self
235            .macro_registry
236            .get(&key)
237            .map(|def| def.return_context)
238            .or_else(|| {
239                crate::builtins::BUILTINS
240                    .registry
241                    .get(&key)
242                    .map(|def| def.return_context)
243            });
244
245        let Some(matching_context) = matching_context else {
246            expected_counts.sort();
247            expected_counts.dedup();
248
249            ast_errors.report(
250                TypeError::WrongMacroArgCount {
251                    name: macro_name,
252                    expected: expected_counts,
253                    got: call_arg_count,
254                },
255                self.range_from_span(name.token.span()),
256            );
257            return;
258        };
259
260        if matching_context != expected_context {
261            ast_errors.report(
262                TypeError::WrongMacroContext {
263                    name: macro_name,
264                    expected: matching_context.name(),
265                    got: expected_context.name(),
266                },
267                self.range_from_span(name.token.span()),
268            );
269        }
270    }
271
272    pub(super) fn validate_macro_arg_refs(
273        &self,
274        construct: &Construct<'a>,
275        macro_args: Option<&HashSet<&str>>,
276        ast_errors: &mut AstErrors,
277    ) {
278        match construct {
279            Construct::Node { node } => {
280                if let Token::MacroArgIdentifier(name) = node.token.value() {
281                    let is_valid = match macro_args {
282                        Some(args) => name.is_some_and(|arg_name| args.contains(arg_name)),
283                        None => false,
284                    };
285
286                    if !is_valid {
287                        if let Some(arg_name) = name {
288                            ast_errors.report(
289                                TypeError::InvalidMacroArg {
290                                    msg: &format!(
291                                        "No macro argument named \"{}\" exists.",
292                                        arg_name
293                                    ),
294                                },
295                                self.range_from_span(node.token.span()),
296                            );
297                        } else {
298                            ast_errors.report(
299                                TypeError::InvalidMacroArg {
300                                    msg: "Missing macro argument name.",
301                                },
302                                self.range_from_span(node.token.span()),
303                            );
304                        }
305                    }
306                }
307            }
308
309            Construct::MathOperation { left, right, .. } => {
310                self.validate_macro_arg_refs(left, macro_args, ast_errors);
311                if let Some(right) = right {
312                    self.validate_macro_arg_refs(right, macro_args, ast_errors);
313                }
314            }
315
316            Construct::UnaryMinus { operand, .. } => {
317                self.validate_macro_arg_refs(operand, macro_args, ast_errors);
318            }
319
320            Construct::Table { body } => {
321                let Some(content) = &body.content else { return };
322                for item in content {
323                    self.validate_macro_arg_refs(item, macro_args, ast_errors);
324                }
325            }
326
327            Construct::AnnotatedTable { body, .. } => {
328                let Some(body) = body else { return };
329                let Some(content) = &body.content else { return };
330                for item in content {
331                    self.validate_macro_arg_refs(item, macro_args, ast_errors);
332                }
333            }
334
335            _ => (),
336        }
337    }
338
339    fn range_from_span(&self, span: (usize, usize)) -> crate::types::Range {
340        crate::types::Range::from_span(&self.parsed.rope, span)
341    }
342}
343
344fn collect_macro_arg_names<'a>(args: &Option<Delimited<'a>>) -> HashSet<&'a str> {
345    let mut names = HashSet::new();
346    if let Some(args) = args {
347        if let Some(content) = &args.content {
348            for construct in content {
349                if let Construct::Node { node } = construct {
350                    if let Token::MacroArgIdentifier(Some(name)) = node.token.value() {
351                        names.insert(*name);
352                    }
353                }
354            }
355        }
356    }
357    names
358}
359
360fn for_each_macro_call_in_body<'a, F>(body: &MacroBodyContent<'a>, cb: &mut F)
361where
362    F: FnMut(&'a str, usize, (usize, usize)),
363{
364    match body {
365        MacroBodyContent::Construct(Some(content)) => {
366            for construct in content {
367                visit_construct_for_calls(construct, cb);
368            }
369        }
370
371        MacroBodyContent::Datatype(Some(content)) => {
372            visit_construct_for_calls(content, cb);
373        }
374
375        MacroBodyContent::Selector(Some(selectors)) => {
376            visit_selectors_for_calls(selectors, cb);
377        }
378
379        _ => {}
380    }
381}
382
383fn visit_construct_for_calls<'a, F>(construct: &Construct<'a>, cb: &mut F)
384where
385    F: FnMut(&'a str, usize, (usize, usize)),
386{
387    match construct {
388        Construct::MacroCall { name, body, .. } => {
389            if let Token::MacroCallIdentifier(Some(n)) = name.token.value() {
390                cb(*n, count_macro_call_args(body), name.token.span());
391            }
392        }
393
394        Construct::Assignment { right, .. } => {
395            if let Some(right) = right {
396                visit_construct_for_calls(right, cb);
397            }
398        }
399
400        Construct::Rule { selectors, body } => {
401            if let Some(selectors) = selectors {
402                visit_selectors_for_calls(selectors, cb);
403            }
404
405            if let Some(body) = body {
406                if let Some(content) = &body.content {
407                    for inner in content {
408                        visit_construct_for_calls(inner, cb);
409                    }
410                }
411            }
412        }
413
414        _ => {}
415    }
416}
417
418fn visit_selectors_for_calls<'a, F>(selectors: &[SelectorNode<'a>], cb: &mut F)
419where
420    F: FnMut(&'a str, usize, (usize, usize)),
421{
422    for selector in selectors {
423        if let SelectorNode::MacroCall { name, body } = selector {
424            if let Token::MacroCallIdentifier(Some(n)) = name.token.value() {
425                cb(*n, count_macro_call_args(body), name.token.span());
426            }
427        }
428    }
429}
430
/// Visit state of a macro during recursion detection.
///
/// A macro that has not been visited yet has no entry in the colour map.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DfsColor {
    /// Currently on the DFS stack; reaching it again means a cycle.
    Gray,
    /// Fully explored; reaching it again is not a cycle.
    Black,
}
435
436impl<'a> Typechecker<'a> {
437    pub(super) fn detect_recursive_macro_calls(&self, ast_errors: &mut AstErrors) {
438        let mut color: HashMap<MacroKey<'a>, DfsColor> = HashMap::new();
439
440        let roots: Vec<MacroKey<'a>> = self.macro_registry.keys().copied().collect();
441        for root in roots {
442            if color.contains_key(&root) {
443                continue;
444            }
445
446            self.dfs_macro_cycle(root, &mut color, ast_errors);
447        }
448    }
449
450    fn dfs_macro_cycle(
451        &self,
452        key: MacroKey<'a>,
453        color: &mut HashMap<MacroKey<'a>, DfsColor>,
454        ast_errors: &mut AstErrors,
455    ) {
456        color.insert(key, DfsColor::Gray);
457
458        let Some(def) = self.macro_registry.get(&key) else {
459            color.insert(key, DfsColor::Black);
460            return;
461        };
462        let Some(body) = def.body else {
463            color.insert(key, DfsColor::Black);
464            return;
465        };
466
467        let mut calls: Vec<(&'a str, usize, (usize, usize))> = Vec::new();
468        for_each_macro_call_in_body(body, &mut |name, arity, span| {
469            calls.push((name, arity, span));
470        });
471
472        for (name, arity, span) in calls {
473            let callee = MacroKey { name, arity };
474
475            if !self.macro_registry.contains_key(&callee) {
476                continue;
477            }
478
479            match color.get(&callee) {
480                Some(DfsColor::Gray) => {
481                    ast_errors.report(TypeError::RecursiveMacroCall, self.range_from_span(span))
482                }
483                Some(DfsColor::Black) => {}
484                None => self.dfs_macro_cycle(callee, color, ast_errors),
485            }
486        }
487
488        color.insert(key, DfsColor::Black);
489    }
490}