1use indexmap::{IndexMap, IndexSet};
8use rowan::TextRange;
9
10use super::dependencies::{DependencyAnalysis, collect_refs};
11use super::symbol_table::SymbolTable;
12use super::visitor::{Visitor, walk_expr, walk_named_node};
13use crate::Diagnostics;
14use crate::diagnostics::DiagnosticKind;
15use crate::parser::{AnonymousNode, Def, Expr, NamedNode, Ref, Root, SeqExpr};
16use crate::query::SourceId;
17
18pub fn validate_recursion(
20 analysis: &DependencyAnalysis,
21 ast_map: &IndexMap<SourceId, Root>,
22 symbol_table: &SymbolTable,
23 diag: &mut Diagnostics,
24) {
25 let mut validator = RecursionValidator {
26 ast_map,
27 symbol_table,
28 diag,
29 };
30 validator.validate(&analysis.sccs);
31}
32
33struct RecursionValidator<'a, 'd> {
34 ast_map: &'a IndexMap<SourceId, Root>,
35 symbol_table: &'a SymbolTable,
36 diag: &'d mut Diagnostics,
37}
38
39impl<'a, 'd> RecursionValidator<'a, 'd> {
40 fn validate(&mut self, sccs: &[Vec<String>]) {
41 for scc in sccs {
42 self.validate_scc(scc);
43 }
44 }
45
46 fn validate_scc(&mut self, scc: &[String]) {
47 if scc.len() == 1 {
50 let name = &scc[0];
51 let body = self
52 .symbol_table
53 .get(name)
54 .expect("node in SCC must exist in symbol table");
55 if !collect_refs(body, self.symbol_table).contains(name.as_str()) {
56 return;
57 }
58 }
59
60 let scc_set: IndexSet<&str> = scc.iter().map(String::as_str).collect();
61
62 let has_escape = scc
66 .iter()
67 .filter_map(|name| self.symbol_table.get(name))
68 .any(|body| expr_has_escape(body, &scc_set));
69
70 if !has_escape {
71 if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| {
74 find_ref_range(expr, target)
75 }) {
76 let chain = self.format_chain(raw_chain, false);
77 self.report_cycle(DiagnosticKind::RecursionNoEscape, scc, chain);
78 }
79 return;
80 }
81
82 if let Some(raw_chain) = self.find_cycle(scc, &scc_set, |_, _, expr, target| {
86 find_unguarded_ref_range(expr, target)
87 }) {
88 let chain = self.format_chain(raw_chain, true);
89 self.report_cycle(DiagnosticKind::DirectRecursion, scc, chain);
90 }
91 }
92
93 fn find_cycle<'b>(
96 &self,
97 nodes: &'b [String],
98 domain: &IndexSet<&'b str>,
99 get_edge_location: impl Fn(&Self, SourceId, &Expr, &str) -> Option<TextRange>,
100 ) -> Option<Vec<(SourceId, TextRange, &'b str)>> {
101 let mut adj = IndexMap::new();
102 for name in nodes {
103 if let Some((source_id, body)) = self.symbol_table.get_full(name) {
104 let neighbors = domain
105 .iter()
106 .filter_map(|target| {
107 get_edge_location(self, source_id, body, target)
108 .map(|range| (*target, source_id, range))
109 })
110 .collect::<Vec<_>>();
111 adj.insert(name.as_str(), neighbors);
112 }
113 }
114
115 let node_strs: Vec<&str> = nodes.iter().map(String::as_str).collect();
116 CycleFinder::find(&node_strs, &adj)
117 }
118
119 fn format_chain(
120 &self,
121 raw_chain: Vec<(SourceId, TextRange, &str)>,
122 is_unguarded: bool,
123 ) -> Vec<(SourceId, TextRange, String)> {
124 if raw_chain.len() == 1 {
125 let (source_id, range, target) = &raw_chain[0];
126 let msg = if is_unguarded {
127 "references itself".to_string()
128 } else {
129 format!("{} references itself", target)
130 };
131 return vec![(*source_id, *range, msg)];
132 }
133
134 let len = raw_chain.len();
135 raw_chain
136 .into_iter()
137 .enumerate()
138 .map(|(i, (source_id, range, target))| {
139 let msg = if i == len - 1 {
140 format!("references {} (completing cycle)", target)
141 } else {
142 format!("references {}", target)
143 };
144 (source_id, range, msg)
145 })
146 .collect()
147 }
148
149 fn report_cycle(
150 &mut self,
151 kind: DiagnosticKind,
152 scc: &[String],
153 chain: Vec<(SourceId, TextRange, String)>,
154 ) {
155 let (primary_source, primary_loc) = chain
156 .first()
157 .map(|(s, r, _)| (*s, *r))
158 .unwrap_or_else(|| (SourceId::default(), TextRange::empty(0.into())));
159
160 let related_def = if scc.len() > 1 {
161 self.find_def_info_containing(scc, primary_loc)
162 } else {
163 None
164 };
165
166 let mut builder = self.diag.report(primary_source, kind, primary_loc);
167
168 for (source_id, range, msg) in chain {
169 builder = builder.related_to(source_id, range, msg);
170 }
171
172 if let Some((source_id, msg, range)) = related_def {
173 builder = builder.related_to(source_id, range, msg);
174 }
175
176 builder.emit();
177 }
178
179 fn find_def_info_containing(
180 &self,
181 scc: &[String],
182 range: TextRange,
183 ) -> Option<(SourceId, String, TextRange)> {
184 let name = scc.iter().find(|name| {
185 self.symbol_table
186 .get(name.as_str())
187 .is_some_and(|body| body.text_range().contains_range(range))
188 })?;
189 let (source_id, def) = self.find_def_by_name(name)?;
190 let n = def.name()?;
191 Some((
192 source_id,
193 format!("{} is defined here", name),
194 n.text_range(),
195 ))
196 }
197
198 fn find_def_by_name(&self, name: &str) -> Option<(SourceId, Def)> {
199 self.ast_map.iter().find_map(|(source_id, ast)| {
200 ast.defs()
201 .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false))
202 .map(|def| (*source_id, def))
203 })
204 }
205}
206
207struct CycleFinder<'a, 'q> {
208 adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>,
209 visited: IndexSet<&'q str>,
210 on_path: IndexMap<&'q str, usize>,
211 path: Vec<&'q str>,
212 edges: Vec<(SourceId, TextRange)>,
213}
214
215impl<'a, 'q> CycleFinder<'a, 'q> {
216 fn find(
217 nodes: &[&'q str],
218 adj: &'a IndexMap<&'q str, Vec<(&'q str, SourceId, TextRange)>>,
219 ) -> Option<Vec<(SourceId, TextRange, &'q str)>> {
220 let mut finder = Self {
221 adj,
222 visited: IndexSet::new(),
223 on_path: IndexMap::new(),
224 path: Vec::new(),
225 edges: Vec::new(),
226 };
227
228 for start in nodes {
229 if let Some(chain) = finder.dfs(start) {
230 return Some(chain);
231 }
232 }
233 None
234 }
235
236 fn dfs(&mut self, current: &'q str) -> Option<Vec<(SourceId, TextRange, &'q str)>> {
237 if self.on_path.contains_key(current) {
238 return None;
239 }
240
241 if self.visited.contains(current) {
242 return None;
243 }
244
245 self.visited.insert(current);
246 self.on_path.insert(current, self.path.len());
247 self.path.push(current);
248
249 if let Some(neighbors) = self.adj.get(current) {
250 for (target, source_id, range) in neighbors {
251 if let Some(&start_index) = self.on_path.get(target) {
252 let mut chain = Vec::new();
254 for i in start_index..self.path.len() - 1 {
255 let (src, rng) = self.edges[i];
256 chain.push((src, rng, self.path[i + 1]));
257 }
258 chain.push((*source_id, *range, *target));
259 return Some(chain);
260 }
261
262 self.edges.push((*source_id, *range));
263 if let Some(chain) = self.dfs(target) {
264 return Some(chain);
265 }
266 self.edges.pop();
267 }
268 }
269
270 self.path.pop();
271 self.on_path.swap_remove(current);
272 None
273 }
274}
275
276fn expr_has_escape(expr: &Expr, scc_names: &IndexSet<&str>) -> bool {
277 match expr {
278 Expr::Ref(r) => {
279 let Some(name_token) = r.name() else {
280 return true;
281 };
282 !scc_names.contains(name_token.text())
283 }
284 Expr::NamedNode(node) => {
285 let children: Vec<_> = node.children().collect();
286 children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc_names))
287 }
288 Expr::AltExpr(_) => expr
289 .children()
290 .iter()
291 .any(|c| expr_has_escape(c, scc_names)),
292 Expr::SeqExpr(_) => expr
293 .children()
294 .iter()
295 .all(|c| expr_has_escape(c, scc_names)),
296 Expr::QuantifiedExpr(q) => {
297 if q.is_optional() {
298 return true;
299 }
300 q.inner()
301 .map(|inner| expr_has_escape(&inner, scc_names))
302 .unwrap_or(true)
303 }
304 Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr
305 .children()
306 .iter()
307 .all(|c| expr_has_escape(c, scc_names)),
308 Expr::AnonymousNode(_) => true,
309 }
310}
311
312fn expr_guarantees_consumption(expr: &Expr) -> bool {
313 match expr {
314 Expr::NamedNode(_) | Expr::AnonymousNode(_) => true,
315 Expr::Ref(_) => false,
316 Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption),
317 Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption),
318 Expr::QuantifiedExpr(q) => {
319 !q.is_optional()
320 && q.inner()
321 .map(|i| expr_guarantees_consumption(&i))
322 .unwrap_or(false)
323 }
324 Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
325 expr.children().iter().all(expr_guarantees_consumption)
326 }
327 }
328}
329
/// Selects which references a [`RefFinder`] may report.
#[derive(PartialEq, Eq, Clone, Copy)]
enum RefSearchMode {
    /// Report the first matching reference found anywhere in the expression.
    Any,
    /// Report a matching reference only if it is reached without descending
    /// into a named node and without passing a sequence element that
    /// guarantees consumption.
    Unguarded,
}
338
339struct RefFinder<'a> {
340 target: &'a str,
341 found: Option<TextRange>,
342 mode: RefSearchMode,
343}
344
345impl Visitor for RefFinder<'_> {
346 fn visit_expr(&mut self, expr: &Expr) {
347 if self.found.is_some() {
348 return;
349 }
350 walk_expr(self, expr);
351 }
352
353 fn visit_named_node(&mut self, node: &NamedNode) {
354 if self.mode == RefSearchMode::Unguarded {
355 return; }
357 walk_named_node(self, node);
358 }
359
360 fn visit_anonymous_node(&mut self, _node: &AnonymousNode) {
361 }
364
365 fn visit_ref(&mut self, r: &Ref) {
366 if self.found.is_some() {
367 return;
368 }
369 if let Some(name) = r.name()
370 && name.text() == self.target
371 {
372 self.found = Some(name.text_range());
373 }
374 }
375
376 fn visit_seq_expr(&mut self, seq: &SeqExpr) {
377 for child in seq.children() {
378 self.visit_expr(&child);
379 if self.found.is_some() {
380 return;
381 }
382 if self.mode == RefSearchMode::Unguarded && expr_guarantees_consumption(&child) {
383 return;
384 }
385 }
386 }
387}
388
389fn find_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
390 let mut visitor = RefFinder {
391 target,
392 found: None,
393 mode: RefSearchMode::Any,
394 };
395 visitor.visit_expr(expr);
396 visitor.found
397}
398
399fn find_unguarded_ref_range(expr: &Expr, target: &str) -> Option<TextRange> {
400 let mut visitor = RefFinder {
401 target,
402 found: None,
403 mode: RefSearchMode::Unguarded,
404 };
405 visitor.visit_expr(expr);
406 visitor.found
407}