plotnik_lib/query/
recursion.rs

1//! Escape path analysis for recursive definitions.
2//!
3//! Detects patterns that can never match because they require
4//! infinitely nested structures (recursion with no escape path),
5//! or infinite runtime loops where the cursor never advances (left recursion).
6
7use indexmap::{IndexMap, IndexSet};
8use rowan::TextRange;
9
10use super::Query;
11use crate::diagnostics::DiagnosticKind;
12use crate::parser::{Def, Expr};
13
14impl Query<'_> {
15    pub(super) fn validate_recursion(&mut self) {
16        let sccs = SccFinder::find(self);
17
18        for scc in sccs {
19            self.validate_scc(scc);
20        }
21    }
22
23    fn validate_scc(&mut self, scc: Vec<String>) {
24        let scc_set: IndexSet<&str> = scc.iter().map(|s| s.as_str()).collect();
25
26        // 1. Check for infinite tree structure (Escape Analysis)
27        // A valid recursive definition must have a non-recursive path.
28        // If NO definition in the SCC has an escape path, the whole group is invalid.
29        let has_escape = scc.iter().any(|name| {
30            self.symbol_table
31                .get(name.as_str())
32                .map(|body| expr_has_escape(body, &scc_set))
33                .unwrap_or(true)
34        });
35
36        if !has_escape {
37            // Find a cycle to report. Any cycle within the SCC is an infinite recursion loop
38            // because there are no escape paths.
39            if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| {
40                q.find_ref_range(expr, target)
41            }) {
42                let chain = self.format_chain(raw_chain, false);
43                self.report_cycle(DiagnosticKind::RecursionNoEscape, &scc, chain);
44            }
45            return;
46        }
47
48        // 2. Check for infinite loops (Guarded Recursion Analysis)
49        // Even if there is an escape, every recursive cycle must consume input (be guarded).
50        // We look for a cycle composed entirely of unguarded references.
51        if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| {
52            q.find_unguarded_ref_range(expr, target)
53        }) {
54            let chain = self.format_chain(raw_chain, true);
55            self.report_cycle(DiagnosticKind::DirectRecursion, &scc, chain);
56        }
57    }
58
59    /// Finds a cycle within the given set of nodes (SCC).
60    /// `get_edge_location` returns the location of a reference from `expr` to `target`.
61    fn find_cycle(
62        &self,
63        nodes: &[String],
64        domain: &IndexSet<&str>,
65        get_edge_location: impl Fn(&Query, &Expr, &str) -> Option<TextRange>,
66    ) -> Option<Vec<(TextRange, String)>> {
67        let mut adj = IndexMap::new();
68        for name in nodes {
69            if let Some(body) = self.symbol_table.get(name.as_str()) {
70                let neighbors = domain
71                    .iter()
72                    .filter_map(|target| {
73                        get_edge_location(self, body, target)
74                            .map(|range| (target.to_string(), range))
75                    })
76                    .collect::<Vec<_>>();
77                adj.insert(name.clone(), neighbors);
78            }
79        }
80
81        CycleFinder::find(nodes, &adj)
82    }
83
84    fn format_chain(
85        &self,
86        chain: Vec<(TextRange, String)>,
87        is_unguarded: bool,
88    ) -> Vec<(TextRange, String)> {
89        if chain.len() == 1 {
90            let (range, target) = &chain[0];
91            let msg = if is_unguarded {
92                "references itself".to_string()
93            } else {
94                format!("{} references itself", target)
95            };
96            return vec![(*range, msg)];
97        }
98
99        let len = chain.len();
100        chain
101            .into_iter()
102            .enumerate()
103            .map(|(i, (range, target))| {
104                let msg = if i == len - 1 {
105                    format!("references {} (completing cycle)", target)
106                } else {
107                    format!("references {}", target)
108                };
109                (range, msg)
110            })
111            .collect()
112    }
113
114    fn report_cycle(
115        &mut self,
116        kind: DiagnosticKind,
117        scc: &[String],
118        chain: Vec<(TextRange, String)>,
119    ) {
120        let primary_loc = chain
121            .first()
122            .map(|(r, _)| *r)
123            .unwrap_or_else(|| TextRange::empty(0.into()));
124
125        let related_def = if scc.len() > 1 {
126            self.find_def_info_containing(scc, primary_loc)
127        } else {
128            None
129        };
130
131        let mut builder = self.recursion_diagnostics.report(kind, primary_loc);
132
133        for (range, msg) in chain {
134            builder = builder.related_to(msg, range);
135        }
136
137        if let Some((msg, range)) = related_def {
138            builder = builder.related_to(msg, range);
139        }
140
141        builder.emit();
142    }
143
144    fn find_def_info_containing(
145        &self,
146        scc: &[String],
147        range: TextRange,
148    ) -> Option<(String, TextRange)> {
149        scc.iter()
150            .find(|name| {
151                self.symbol_table
152                    .get(name.as_str())
153                    .map(|body| body.text_range().contains_range(range))
154                    .unwrap_or(false)
155            })
156            .and_then(|name| {
157                self.find_def_by_name(name).and_then(|def| {
158                    def.name()
159                        .map(|n| (format!("{} is defined here", name), n.text_range()))
160                })
161            })
162    }
163
164    fn find_def_by_name(&self, name: &str) -> Option<Def> {
165        self.ast
166            .defs()
167            .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false))
168    }
169
    /// Returns the range of a reference to `target` anywhere inside `expr`.
    ///
    /// Method-shaped wrapper over `find_ref_in_expr` (`self` is unused), so it
    /// can serve as the edge callback passed to `find_cycle`.
    fn find_ref_range(&self, expr: &Expr, target: &str) -> Option<TextRange> {
        find_ref_in_expr(expr, target)
    }
173
    /// Returns the range of an unguarded reference to `target` inside `expr`.
    ///
    /// Method-shaped wrapper over `find_unguarded_ref_in_expr` (`self` is
    /// unused), so it can serve as the edge callback passed to `find_cycle`.
    fn find_unguarded_ref_range(&self, expr: &Expr, target: &str) -> Option<TextRange> {
        find_unguarded_ref_in_expr(expr, target)
    }
177}
178
/// Depth-first search state used to extract one concrete cycle from a graph
/// of definition references.
struct CycleFinder<'a> {
    /// Adjacency list: node name -> `(target name, range of the reference)`.
    adj: &'a IndexMap<String, Vec<(String, TextRange)>>,
    /// Nodes whose exploration has started; entries are never removed.
    visited: IndexSet<String>,
    /// Nodes on the current DFS path, mapped to their index in `path`.
    on_path: IndexMap<String, usize>,
    /// Current DFS path, from the start node to the node being explored.
    path: Vec<String>,
    /// Ranges of the edges taken along `path`; `edges[i]` is the edge from
    /// `path[i]` to `path[i + 1]`.
    edges: Vec<TextRange>,
}
186
187impl<'a> CycleFinder<'a> {
188    fn find(
189        nodes: &[String],
190        adj: &'a IndexMap<String, Vec<(String, TextRange)>>,
191    ) -> Option<Vec<(TextRange, String)>> {
192        let mut finder = Self {
193            adj,
194            visited: IndexSet::new(),
195            on_path: IndexMap::new(),
196            path: Vec::new(),
197            edges: Vec::new(),
198        };
199
200        for start in nodes {
201            if let Some(chain) = finder.dfs(start) {
202                return Some(chain);
203            }
204        }
205        None
206    }
207
208    fn dfs(&mut self, current: &String) -> Option<Vec<(TextRange, String)>> {
209        if self.on_path.contains_key(current) {
210            return None;
211        }
212
213        if self.visited.contains(current) {
214            return None;
215        }
216
217        self.visited.insert(current.clone());
218        self.on_path.insert(current.clone(), self.path.len());
219        self.path.push(current.clone());
220
221        if let Some(neighbors) = self.adj.get(current) {
222            for (target, range) in neighbors {
223                if let Some(&start_index) = self.on_path.get(target) {
224                    // Cycle detected!
225                    // Path: path[start_index] ... path[last] (current)
226                    // Edges: edges[start_index] ... edges[last-1]
227                    // Closing edge: range
228                    let mut chain = Vec::new();
229                    for i in start_index..self.path.len() - 1 {
230                        chain.push((self.edges[i], self.path[i + 1].clone()));
231                    }
232                    chain.push((*range, target.clone()));
233                    return Some(chain);
234                }
235
236                self.edges.push(*range);
237                if let Some(chain) = self.dfs(target) {
238                    return Some(chain);
239                }
240                self.edges.pop();
241            }
242        }
243
244        self.path.pop();
245        self.on_path.swap_remove(current);
246        None
247    }
248}
249
/// State for computing strongly connected components of the definition
/// reference graph (Tarjan-style: index, lowlink and an explicit stack).
struct SccFinder<'a, 'src> {
    /// Query whose symbol table supplies the definitions and their bodies.
    query: &'a Query<'src>,
    /// Next discovery index to assign.
    index: usize,
    /// Names visited but not yet assigned to a component.
    stack: Vec<String>,
    /// Set view of `stack` for membership checks.
    on_stack: IndexSet<String>,
    /// Discovery index of each visited name.
    indices: IndexMap<String, usize>,
    /// Smallest discovery index known to be reachable from each name.
    lowlinks: IndexMap<String, usize>,
    /// Completed components, in the order they were closed.
    sccs: Vec<Vec<String>>,
}
259
260impl<'a, 'src> SccFinder<'a, 'src> {
261    fn find(query: &'a Query<'src>) -> Vec<Vec<String>> {
262        let mut finder = Self {
263            query,
264            index: 0,
265            stack: Vec::new(),
266            on_stack: IndexSet::new(),
267            indices: IndexMap::new(),
268            lowlinks: IndexMap::new(),
269            sccs: Vec::new(),
270        };
271
272        for name in query.symbol_table.keys() {
273            if !finder.indices.contains_key(*name) {
274                finder.strongconnect(name);
275            }
276        }
277
278        finder
279            .sccs
280            .into_iter()
281            .filter(|scc| {
282                scc.len() > 1
283                    || query
284                        .symbol_table
285                        .get(scc[0].as_str())
286                        .map(|body| collect_refs(body).contains(scc[0].as_str()))
287                        .unwrap_or(false)
288            })
289            .collect()
290    }
291
292    fn strongconnect(&mut self, name: &str) {
293        self.indices.insert(name.to_string(), self.index);
294        self.lowlinks.insert(name.to_string(), self.index);
295        self.index += 1;
296        self.stack.push(name.to_string());
297        self.on_stack.insert(name.to_string());
298
299        if let Some(body) = self.query.symbol_table.get(name) {
300            let refs = collect_refs(body);
301            for ref_name in refs {
302                if !self.query.symbol_table.contains_key(ref_name.as_str()) {
303                    continue;
304                }
305
306                if !self.indices.contains_key(&ref_name) {
307                    self.strongconnect(&ref_name);
308                    let ref_lowlink = self.lowlinks[&ref_name];
309                    let my_lowlink = self.lowlinks.get_mut(name).unwrap();
310                    *my_lowlink = (*my_lowlink).min(ref_lowlink);
311                } else if self.on_stack.contains(&ref_name) {
312                    let ref_index = self.indices[&ref_name];
313                    let my_lowlink = self.lowlinks.get_mut(name).unwrap();
314                    *my_lowlink = (*my_lowlink).min(ref_index);
315                }
316            }
317        }
318
319        if self.lowlinks[name] == self.indices[name] {
320            let mut scc = Vec::new();
321            loop {
322                let w = self.stack.pop().unwrap();
323                self.on_stack.swap_remove(&w);
324                scc.push(w.clone());
325                if w == name {
326                    break;
327                }
328            }
329            self.sccs.push(scc);
330        }
331    }
332}
333
334fn expr_has_escape(expr: &Expr, scc: &IndexSet<&str>) -> bool {
335    match expr {
336        Expr::Ref(r) => {
337            let Some(name_token) = r.name() else {
338                return true;
339            };
340            !scc.contains(name_token.text())
341        }
342        Expr::NamedNode(node) => {
343            let children: Vec<_> = node.children().collect();
344            children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc))
345        }
346        Expr::AltExpr(_) => expr.children().iter().any(|c| expr_has_escape(c, scc)),
347        Expr::SeqExpr(_) => expr.children().iter().all(|c| expr_has_escape(c, scc)),
348        Expr::QuantifiedExpr(q) => {
349            if q.is_optional() {
350                return true;
351            }
352            q.inner()
353                .map(|inner| expr_has_escape(&inner, scc))
354                .unwrap_or(true)
355        }
356        Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
357            expr.children().iter().all(|c| expr_has_escape(c, scc))
358        }
359        Expr::AnonymousNode(_) => true,
360    }
361}
362
363fn expr_guarantees_consumption(expr: &Expr) -> bool {
364    match expr {
365        Expr::NamedNode(_) | Expr::AnonymousNode(_) => true,
366        Expr::Ref(_) => false,
367        Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption),
368        Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption),
369        Expr::QuantifiedExpr(q) => {
370            !q.is_optional()
371                && q.inner()
372                    .map(|i| expr_guarantees_consumption(&i))
373                    .unwrap_or(false)
374        }
375        Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
376            expr.children().iter().all(expr_guarantees_consumption)
377        }
378    }
379}
380
381fn collect_refs(expr: &Expr) -> IndexSet<String> {
382    let mut refs = IndexSet::new();
383    collect_refs_into(expr, &mut refs);
384    refs
385}
386
387fn collect_refs_into(expr: &Expr, refs: &mut IndexSet<String>) {
388    if let Expr::Ref(r) = expr
389        && let Some(name_token) = r.name()
390    {
391        refs.insert(name_token.text().to_string());
392    }
393
394    for child in expr.children() {
395        collect_refs_into(&child, refs);
396    }
397}
398
399fn find_ref_in_expr(expr: &Expr, target: &str) -> Option<TextRange> {
400    if let Expr::Ref(r) = expr {
401        let name_token = r.name()?;
402        if name_token.text() == target {
403            return Some(name_token.text_range());
404        }
405    }
406
407    expr.children()
408        .iter()
409        .find_map(|child| find_ref_in_expr(child, target))
410}
411
412fn find_unguarded_ref_in_expr(expr: &Expr, target: &str) -> Option<TextRange> {
413    match expr {
414        Expr::Ref(r) => r
415            .name()
416            .filter(|n| n.text() == target)
417            .map(|n| n.text_range()),
418        Expr::NamedNode(_) | Expr::AnonymousNode(_) => None,
419        Expr::AltExpr(_) => expr
420            .children()
421            .iter()
422            .find_map(|c| find_unguarded_ref_in_expr(c, target)),
423        Expr::SeqExpr(_) => {
424            for c in expr.children() {
425                if let Some(range) = find_unguarded_ref_in_expr(&c, target) {
426                    return Some(range);
427                }
428                if expr_guarantees_consumption(&c) {
429                    return None;
430                }
431            }
432            None
433        }
434        Expr::QuantifiedExpr(q) => q
435            .inner()
436            .and_then(|i| find_unguarded_ref_in_expr(&i, target)),
437        Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr
438            .children()
439            .iter()
440            .find_map(|c| find_unguarded_ref_in_expr(c, target)),
441    }
442}