plotnik_compiler/analyze/
recursion.rs

1//! Recursion validation for definitions.
2//!
3//! Validates that recursive definitions are well-formed:
4//! - Escapable: at least one non-recursive path exists
5//! - Guarded: every recursive cycle consumes input
6
7use indexmap::{IndexMap, IndexSet};
8use rowan::TextRange;
9
10use super::dependencies::{DependencyAnalysis, collect_refs};
11use super::symbol_table::SymbolTable;
12use super::visitor::{Visitor, walk_expr, walk_named_node};
13use crate::Diagnostics;
14use crate::diagnostics::DiagnosticKind;
15use crate::parser::{AnonymousNode, Def, Expr, NamedNode, Ref, Root, SeqExpr};
16use crate::query::SourceId;
17
18/// Validate recursion using the pre-computed dependency analysis.
19pub fn validate_recursion(
20    analysis: &DependencyAnalysis,
21    ast_map: &IndexMap<SourceId, Root>,
22    symbol_table: &SymbolTable,
23    diag: &mut Diagnostics,
24) {
25    let mut validator = RecursionValidator {
26        ast_map,
27        symbol_table,
28        diag,
29    };
30    validator.validate(&analysis.sccs);
31}
32
33struct RecursionValidator<'a, 'd> {
34    ast_map: &'a IndexMap<SourceId, Root>,
35    symbol_table: &'a SymbolTable,
36    diag: &'d mut Diagnostics,
37}
38
39impl<'a, 'd> RecursionValidator<'a, 'd> {
40    fn validate(&mut self, sccs: &[Vec<String>]) {
41        for scc in sccs {
42            self.validate_scc(scc);
43        }
44    }
45
46    fn validate_scc(&mut self, scc: &[String]) {
47        // Filter out trivial non-recursive components.
48        // A component is recursive if it has >1 node, or 1 node that references itself.
49        if scc.len() == 1 {
50            let name = &scc[0];
51            let body = self
52                .symbol_table
53                .get(name)
54                .expect("node in SCC must exist in symbol table");
55            if !collect_refs(body, self.symbol_table).contains(name.as_str()) {
56                return;
57            }
58        }
59
60        let scc_set: IndexSet<&str> = scc.iter().map(String::as_str).collect();
61
62        // 1. Check for infinite tree structure (Escape Analysis)
63        // A valid recursive definition must have a non-recursive path.
64        // If NO definition in the SCC has an escape path, the whole group is invalid.
65        let has_escape = scc
66            .iter()
67            .filter_map(|name| self.symbol_table.get(name))
68            .any(|body| expr_has_escape(body, &scc_set));
69
70        if !has_escape {
71            // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop
72            // because there are no escape paths.
73            if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| {
74                find_ref_range(expr, target)
75            }) {
76                let chain = self.format_chain(raw_chain, false);
77                self.report_cycle(DiagnosticKind::RecursionNoEscape, scc, chain);
78            }
79            return;
80        }
81
82        // 2. Check for infinite loops (Guarded Recursion Analysis)
83        // Even if there is an escape, every recursive cycle must consume input (be guarded).
84        // We look for a cycle composed entirely of unguarded references.
85        if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| {
86            find_unguarded_ref_range(expr, target)
87        }) {
88            let chain = self.format_chain(raw_chain, true);
89            self.report_cycle(DiagnosticKind::DirectRecursion, scc, chain);
90        }
91    }
92
93    /// Finds a cycle within the given set of nodes (SCC).
94    /// `get_edge_location` returns the location of a reference from `expr` to `target`.
95    fn find_cycle<'b>(
96        &self,
97        nodes: &'b [String],
98        domain: &IndexSet<&'b str>,
99        get_edge_location: impl Fn(&Self, SourceId, &Expr, &str) -> Option<TextRange>,
100    ) -> Option<Vec<(SourceId, TextRange, &'b str)>> {
101        let mut adj = IndexMap::new();
102        for name in nodes {
103            if let Some((source_id, body)) = self.symbol_table.get_full(name) {
104                let neighbors = domain
105                    .iter()
106                    .filter_map(|target| {
107                        get_edge_location(self, source_id, body, target)
108                            .map(|range| (*target, source_id, range))
109                    })
110                    .collect::<Vec<_>>();
111                adj.insert(name.as_str(), neighbors);
112            }
113        }
114
115        let node_strs: Vec<&str> = nodes.iter().map(String::as_str).collect();
116        CycleFinder::find(&node_strs, &adj)
117    }
118
119    fn format_chain(
120        &self,
121        raw_chain: Vec<(SourceId, TextRange, &str)>,
122        is_unguarded: bool,
123    ) -> Vec<(SourceId, TextRange, String)> {
124        if raw_chain.len() == 1 {
125            let (source_id, range, target) = &raw_chain[0];
126            let msg = if is_unguarded {
127                "references itself".to_string()
128            } else {
129                format!("{} references itself", target)
130            };
131            return vec![(*source_id, *range, msg)];
132        }
133
134        let len = raw_chain.len();
135        raw_chain
136            .into_iter()
137            .enumerate()
138            .map(|(i, (source_id, range, target))| {
139                let msg = if i == len - 1 {
140                    format!("references {} (completing cycle)", target)
141                } else {
142                    format!("references {}", target)
143                };
144                (source_id, range, msg)
145            })
146            .collect()
147    }
148
149    fn report_cycle(
150        &mut self,
151        kind: DiagnosticKind,
152        scc: &[String],
153        chain: Vec<(SourceId, TextRange, String)>,
154    ) {
155        let (primary_source, primary_loc) = chain
156            .first()
157            .map(|(s, r, _)| (*s, *r))
158            .unwrap_or_else(|| (SourceId::default(), TextRange::empty(0.into())));
159
160        let related_def = if scc.len() > 1 {
161            self.find_def_info_containing(scc, primary_loc)
162        } else {
163            None
164        };
165
166        let mut builder = self.diag.report(primary_source, kind, primary_loc);
167
168        for (source_id, range, msg) in chain {
169            builder = builder.related_to(source_id, range, msg);
170        }
171
172        if let Some((source_id, msg, range)) = related_def {
173            builder = builder.related_to(source_id, range, msg);
174        }
175
176        builder.emit();
177    }
178
179    fn find_def_info_containing(
180        &self,
181        scc: &[String],
182        range: TextRange,
183    ) -> Option<(SourceId, String, TextRange)> {
184        let name = scc.iter().find(|name| {
185            self.symbol_table
186                .get(name.as_str())
187                .is_some_and(|body| body.text_range().contains_range(range))
188        })?;
189        let (source_id, def) = self.find_def_by_name(name)?;
190        let n = def.name()?;
191        Some((
192            source_id,
193            format!("{} is defined here", name),
194            n.text_range(),
195        ))
196    }
197
198    fn find_def_by_name(&self, name: &str) -> Option<(SourceId, Def)> {
199        self.ast_map.iter().find_map(|(source_id, ast)| {
200            ast.defs()
201                .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false))
202                .map(|def| (*source_id, def))
203        })
204    }
205}
206
207struct CycleFinder<'a, 'q> {
208    adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>,
209    visited: IndexSet<&'q str>,
210    on_path: IndexMap<&'q str, usize>,
211    path: Vec<&'q str>,
212    edges: Vec<(SourceId, TextRange)>,
213}
214
215impl<'a, 'q> CycleFinder<'a, 'q> {
216    fn find(
217        nodes: &[&'q str],
218        adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>,
219    ) -> Option<Vec<(SourceId, TextRange, &'q str)>> {
220        let mut finder = Self {
221            adj,
222            visited: IndexSet::new(),
223            on_path: IndexMap::new(),
224            path: Vec::new(),
225            edges: Vec::new(),
226        };
227
228        for start in nodes {
229            if let Some(chain) = finder.dfs(start) {
230                return Some(chain);
231            }
232        }
233        None
234    }
235
236    fn dfs(&mut self, current: &'q str) -> Option<Vec<(SourceId, TextRange, &'q str)>> {
237        if self.on_path.contains_key(current) {
238            return None;
239        }
240
241        if self.visited.contains(current) {
242            return None;
243        }
244
245        self.visited.insert(current);
246        self.on_path.insert(current, self.path.len());
247        self.path.push(current);
248
249        if let Some(neighbors) = self.adj.get(current) {
250            for (target, source_id, range) in neighbors {
251                if let Some(&start_index) = self.on_path.get(target) {
252                    // Cycle detected!
253                    let mut chain = Vec::new();
254                    for i in start_index..self.path.len() - 1 {
255                        let (src, rng) = self.edges[i];
256                        chain.push((src, rng, self.path[i + 1]));
257                    }
258                    chain.push((*source_id, *range, *target));
259                    return Some(chain);
260                }
261
262                self.edges.push((*source_id, *range));
263                if let Some(chain) = self.dfs(target) {
264                    return Some(chain);
265                }
266                self.edges.pop();
267            }
268        }
269
270        self.path.pop();
271        self.on_path.swap_remove(current);
272        None
273    }
274}
275
276fn expr_has_escape(expr: &Expr, scc_names: &IndexSet<&str>) -> bool {
277    match expr {
278        Expr::Ref(r) => {
279            let Some(name_token) = r.name() else {
280                return true;
281            };
282            !scc_names.contains(name_token.text())
283        }
284        Expr::NamedNode(node) => {
285            let children: Vec<_> = node.children().collect();
286            children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc_names))
287        }
288        Expr::AltExpr(_) => expr
289            .children()
290            .iter()
291            .any(|c| expr_has_escape(c, scc_names)),
292        Expr::SeqExpr(_) => expr
293            .children()
294            .iter()
295            .all(|c| expr_has_escape(c, scc_names)),
296        Expr::QuantifiedExpr(q) => {
297            if q.is_optional() {
298                return true;
299            }
300            q.inner()
301                .map(|inner| expr_has_escape(&inner, scc_names))
302                .unwrap_or(true)
303        }
304        Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr
305            .children()
306            .iter()
307            .all(|c| expr_has_escape(c, scc_names)),
308        Expr::AnonymousNode(_) => true,
309    }
310}
311
312fn expr_guarantees_consumption(expr: &Expr) -> bool {
313    match expr {
314        Expr::NamedNode(_) | Expr::AnonymousNode(_) => true,
315        Expr::Ref(_) => false,
316        Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption),
317        Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption),
318        Expr::QuantifiedExpr(q) => {
319            !q.is_optional()
320                && q.inner()
321                    .map(|i| expr_guarantees_consumption(&i))
322                    .unwrap_or(false)
323        }
324        Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
325            expr.children().iter().all(expr_guarantees_consumption)
326        }
327    }
328}
329
/// Controls which references a [`RefFinder`] is allowed to report.
#[derive(Clone, Copy, PartialEq, Eq)]
enum RefSearchMode {
    /// Report a reference to the target wherever it appears.
    Any,
    /// Report only references not guarded by consumed input. The search does
    /// not enter named or anonymous nodes, and it stops after a sequence item
    /// that always consumes.
    Unguarded,
}
338
339struct RefFinder<'a> {
340    target: &'a str,
341    found: Option<TextRange>,
342    mode: RefSearchMode,
343}
344
345impl Visitor for RefFinder<'_> {
346    fn visit_expr(&mut self, expr: &Expr) {
347        if self.found.is_some() {
348            return;
349        }
350        walk_expr(self, expr);
351    }
352
353    fn visit_named_node(&mut self, node: &NamedNode) {
354        if self.mode == RefSearchMode::Unguarded {
355            return; // Guarded: stop recursion
356        }
357        walk_named_node(self, node);
358    }
359
360    fn visit_anonymous_node(&mut self, _node: &AnonymousNode) {
361        // AnonymousNode has no child expressions, so nothing to walk.
362        // In Unguarded mode this also acts as a guard (stops recursion).
363    }
364
365    fn visit_ref(&mut self, r: &Ref) {
366        if self.found.is_some() {
367            return;
368        }
369        if let Some(name) = r.name()
370            && name.text() == self.target
371        {
372            self.found = Some(name.text_range());
373        }
374    }
375
376    fn visit_seq_expr(&mut self, seq: &SeqExpr) {
377        for child in seq.children() {
378            self.visit_expr(&child);
379            if self.found.is_some() {
380                return;
381            }
382            if self.mode == RefSearchMode::Unguarded && expr_guarantees_consumption(&child) {
383                return;
384            }
385        }
386    }
387}
388
389fn find_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
390    let mut visitor = RefFinder {
391        target,
392        found: None,
393        mode: RefSearchMode::Any,
394    };
395    visitor.visit_expr(expr);
396    visitor.found
397}
398
399fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
400    let mut visitor = RefFinder {
401        target,
402        found: None,
403        mode: RefSearchMode::Unguarded,
404    };
405    visitor.visit_expr(expr);
406    visitor.found
407}