1use indexmap::{IndexMap, IndexSet};
8use rowan::TextRange;
9
10use super::Query;
11use crate::diagnostics::DiagnosticKind;
12use crate::parser::{Def, Expr};
13
14impl Query<'_> {
15 pub(super) fn validate_recursion(&mut self) {
16 let sccs = SccFinder::find(self);
17
18 for scc in sccs {
19 self.validate_scc(scc);
20 }
21 }
22
23 fn validate_scc(&mut self, scc: Vec<String>) {
24 let scc_set: IndexSet<&str> = scc.iter().map(|s| s.as_str()).collect();
25
26 let has_escape = scc.iter().any(|name| {
30 self.symbol_table
31 .get(name.as_str())
32 .map(|body| expr_has_escape(body, &scc_set))
33 .unwrap_or(true)
34 });
35
36 if !has_escape {
37 if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| {
40 q.find_ref_range(expr, target)
41 }) {
42 let chain = self.format_chain(raw_chain, false);
43 self.report_cycle(DiagnosticKind::RecursionNoEscape, &scc, chain);
44 }
45 return;
46 }
47
48 if let Some(raw_chain) = self.find_cycle(&scc, &scc_set, |q, expr, target| {
52 q.find_unguarded_ref_range(expr, target)
53 }) {
54 let chain = self.format_chain(raw_chain, true);
55 self.report_cycle(DiagnosticKind::DirectRecursion, &scc, chain);
56 }
57 }
58
59 fn find_cycle(
62 &self,
63 nodes: &[String],
64 domain: &IndexSet<&str>,
65 get_edge_location: impl Fn(&Query, &Expr, &str) -> Option<TextRange>,
66 ) -> Option<Vec<(TextRange, String)>> {
67 let mut adj = IndexMap::new();
68 for name in nodes {
69 if let Some(body) = self.symbol_table.get(name.as_str()) {
70 let neighbors = domain
71 .iter()
72 .filter_map(|target| {
73 get_edge_location(self, body, target)
74 .map(|range| (target.to_string(), range))
75 })
76 .collect::<Vec<_>>();
77 adj.insert(name.clone(), neighbors);
78 }
79 }
80
81 CycleFinder::find(nodes, &adj)
82 }
83
84 fn format_chain(
85 &self,
86 chain: Vec<(TextRange, String)>,
87 is_unguarded: bool,
88 ) -> Vec<(TextRange, String)> {
89 if chain.len() == 1 {
90 let (range, target) = &chain[0];
91 let msg = if is_unguarded {
92 "references itself".to_string()
93 } else {
94 format!("{} references itself", target)
95 };
96 return vec![(*range, msg)];
97 }
98
99 let len = chain.len();
100 chain
101 .into_iter()
102 .enumerate()
103 .map(|(i, (range, target))| {
104 let msg = if i == len - 1 {
105 format!("references {} (completing cycle)", target)
106 } else {
107 format!("references {}", target)
108 };
109 (range, msg)
110 })
111 .collect()
112 }
113
114 fn report_cycle(
115 &mut self,
116 kind: DiagnosticKind,
117 scc: &[String],
118 chain: Vec<(TextRange, String)>,
119 ) {
120 let primary_loc = chain
121 .first()
122 .map(|(r, _)| *r)
123 .unwrap_or_else(|| TextRange::empty(0.into()));
124
125 let related_def = if scc.len() > 1 {
126 self.find_def_info_containing(scc, primary_loc)
127 } else {
128 None
129 };
130
131 let mut builder = self.recursion_diagnostics.report(kind, primary_loc);
132
133 for (range, msg) in chain {
134 builder = builder.related_to(msg, range);
135 }
136
137 if let Some((msg, range)) = related_def {
138 builder = builder.related_to(msg, range);
139 }
140
141 builder.emit();
142 }
143
144 fn find_def_info_containing(
145 &self,
146 scc: &[String],
147 range: TextRange,
148 ) -> Option<(String, TextRange)> {
149 scc.iter()
150 .find(|name| {
151 self.symbol_table
152 .get(name.as_str())
153 .map(|body| body.text_range().contains_range(range))
154 .unwrap_or(false)
155 })
156 .and_then(|name| {
157 self.find_def_by_name(name).and_then(|def| {
158 def.name()
159 .map(|n| (format!("{} is defined here", name), n.text_range()))
160 })
161 })
162 }
163
164 fn find_def_by_name(&self, name: &str) -> Option<Def> {
165 self.ast
166 .defs()
167 .find(|d| d.name().map(|n| n.text() == name).unwrap_or(false))
168 }
169
170 fn find_ref_range(&self, expr: &Expr, target: &str) -> Option<TextRange> {
171 find_ref_in_expr(expr, target)
172 }
173
174 fn find_unguarded_ref_range(&self, expr: &Expr, target: &str) -> Option<TextRange> {
175 find_unguarded_ref_in_expr(expr, target)
176 }
177}
178
179struct CycleFinder<'a> {
180 adj: &'a IndexMap<String, Vec<(String, TextRange)>>,
181 visited: IndexSet<String>,
182 on_path: IndexMap<String, usize>,
183 path: Vec<String>,
184 edges: Vec<TextRange>,
185}
186
187impl<'a> CycleFinder<'a> {
188 fn find(
189 nodes: &[String],
190 adj: &'a IndexMap<String, Vec<(String, TextRange)>>,
191 ) -> Option<Vec<(TextRange, String)>> {
192 let mut finder = Self {
193 adj,
194 visited: IndexSet::new(),
195 on_path: IndexMap::new(),
196 path: Vec::new(),
197 edges: Vec::new(),
198 };
199
200 for start in nodes {
201 if let Some(chain) = finder.dfs(start) {
202 return Some(chain);
203 }
204 }
205 None
206 }
207
208 fn dfs(&mut self, current: &String) -> Option<Vec<(TextRange, String)>> {
209 if self.on_path.contains_key(current) {
210 return None;
211 }
212
213 if self.visited.contains(current) {
214 return None;
215 }
216
217 self.visited.insert(current.clone());
218 self.on_path.insert(current.clone(), self.path.len());
219 self.path.push(current.clone());
220
221 if let Some(neighbors) = self.adj.get(current) {
222 for (target, range) in neighbors {
223 if let Some(&start_index) = self.on_path.get(target) {
224 let mut chain = Vec::new();
229 for i in start_index..self.path.len() - 1 {
230 chain.push((self.edges[i], self.path[i + 1].clone()));
231 }
232 chain.push((*range, target.clone()));
233 return Some(chain);
234 }
235
236 self.edges.push(*range);
237 if let Some(chain) = self.dfs(target) {
238 return Some(chain);
239 }
240 self.edges.pop();
241 }
242 }
243
244 self.path.pop();
245 self.on_path.swap_remove(current);
246 None
247 }
248}
249
250struct SccFinder<'a, 'src> {
251 query: &'a Query<'src>,
252 index: usize,
253 stack: Vec<String>,
254 on_stack: IndexSet<String>,
255 indices: IndexMap<String, usize>,
256 lowlinks: IndexMap<String, usize>,
257 sccs: Vec<Vec<String>>,
258}
259
260impl<'a, 'src> SccFinder<'a, 'src> {
261 fn find(query: &'a Query<'src>) -> Vec<Vec<String>> {
262 let mut finder = Self {
263 query,
264 index: 0,
265 stack: Vec::new(),
266 on_stack: IndexSet::new(),
267 indices: IndexMap::new(),
268 lowlinks: IndexMap::new(),
269 sccs: Vec::new(),
270 };
271
272 for name in query.symbol_table.keys() {
273 if !finder.indices.contains_key(*name) {
274 finder.strongconnect(name);
275 }
276 }
277
278 finder
279 .sccs
280 .into_iter()
281 .filter(|scc| {
282 scc.len() > 1
283 || query
284 .symbol_table
285 .get(scc[0].as_str())
286 .map(|body| collect_refs(body).contains(scc[0].as_str()))
287 .unwrap_or(false)
288 })
289 .collect()
290 }
291
292 fn strongconnect(&mut self, name: &str) {
293 self.indices.insert(name.to_string(), self.index);
294 self.lowlinks.insert(name.to_string(), self.index);
295 self.index += 1;
296 self.stack.push(name.to_string());
297 self.on_stack.insert(name.to_string());
298
299 if let Some(body) = self.query.symbol_table.get(name) {
300 let refs = collect_refs(body);
301 for ref_name in refs {
302 if !self.query.symbol_table.contains_key(ref_name.as_str()) {
303 continue;
304 }
305
306 if !self.indices.contains_key(&ref_name) {
307 self.strongconnect(&ref_name);
308 let ref_lowlink = self.lowlinks[&ref_name];
309 let my_lowlink = self.lowlinks.get_mut(name).unwrap();
310 *my_lowlink = (*my_lowlink).min(ref_lowlink);
311 } else if self.on_stack.contains(&ref_name) {
312 let ref_index = self.indices[&ref_name];
313 let my_lowlink = self.lowlinks.get_mut(name).unwrap();
314 *my_lowlink = (*my_lowlink).min(ref_index);
315 }
316 }
317 }
318
319 if self.lowlinks[name] == self.indices[name] {
320 let mut scc = Vec::new();
321 loop {
322 let w = self.stack.pop().unwrap();
323 self.on_stack.swap_remove(&w);
324 scc.push(w.clone());
325 if w == name {
326 break;
327 }
328 }
329 self.sccs.push(scc);
330 }
331 }
332}
333
334fn expr_has_escape(expr: &Expr, scc: &IndexSet<&str>) -> bool {
335 match expr {
336 Expr::Ref(r) => {
337 let Some(name_token) = r.name() else {
338 return true;
339 };
340 !scc.contains(name_token.text())
341 }
342 Expr::NamedNode(node) => {
343 let children: Vec<_> = node.children().collect();
344 children.is_empty() || children.iter().all(|c| expr_has_escape(c, scc))
345 }
346 Expr::AltExpr(_) => expr.children().iter().any(|c| expr_has_escape(c, scc)),
347 Expr::SeqExpr(_) => expr.children().iter().all(|c| expr_has_escape(c, scc)),
348 Expr::QuantifiedExpr(q) => {
349 if q.is_optional() {
350 return true;
351 }
352 q.inner()
353 .map(|inner| expr_has_escape(&inner, scc))
354 .unwrap_or(true)
355 }
356 Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
357 expr.children().iter().all(|c| expr_has_escape(c, scc))
358 }
359 Expr::AnonymousNode(_) => true,
360 }
361}
362
363fn expr_guarantees_consumption(expr: &Expr) -> bool {
364 match expr {
365 Expr::NamedNode(_) | Expr::AnonymousNode(_) => true,
366 Expr::Ref(_) => false,
367 Expr::AltExpr(_) => expr.children().iter().all(expr_guarantees_consumption),
368 Expr::SeqExpr(_) => expr.children().iter().any(expr_guarantees_consumption),
369 Expr::QuantifiedExpr(q) => {
370 !q.is_optional()
371 && q.inner()
372 .map(|i| expr_guarantees_consumption(&i))
373 .unwrap_or(false)
374 }
375 Expr::CapturedExpr(_) | Expr::FieldExpr(_) => {
376 expr.children().iter().all(expr_guarantees_consumption)
377 }
378 }
379}
380
381fn collect_refs(expr: &Expr) -> IndexSet<String> {
382 let mut refs = IndexSet::new();
383 collect_refs_into(expr, &mut refs);
384 refs
385}
386
387fn collect_refs_into(expr: &Expr, refs: &mut IndexSet<String>) {
388 if let Expr::Ref(r) = expr
389 && let Some(name_token) = r.name()
390 {
391 refs.insert(name_token.text().to_string());
392 }
393
394 for child in expr.children() {
395 collect_refs_into(&child, refs);
396 }
397}
398
399fn find_ref_in_expr(expr: &Expr, target: &str) -> Option<TextRange> {
400 if let Expr::Ref(r) = expr {
401 let name_token = r.name()?;
402 if name_token.text() == target {
403 return Some(name_token.text_range());
404 }
405 }
406
407 expr.children()
408 .iter()
409 .find_map(|child| find_ref_in_expr(child, target))
410}
411
412fn find_unguarded_ref_in_expr(expr: &Expr, target: &str) -> Option<TextRange> {
413 match expr {
414 Expr::Ref(r) => r
415 .name()
416 .filter(|n| n.text() == target)
417 .map(|n| n.text_range()),
418 Expr::NamedNode(_) | Expr::AnonymousNode(_) => None,
419 Expr::AltExpr(_) => expr
420 .children()
421 .iter()
422 .find_map(|c| find_unguarded_ref_in_expr(c, target)),
423 Expr::SeqExpr(_) => {
424 for c in expr.children() {
425 if let Some(range) = find_unguarded_ref_in_expr(&c, target) {
426 return Some(range);
427 }
428 if expr_guarantees_consumption(&c) {
429 return None;
430 }
431 }
432 None
433 }
434 Expr::QuantifiedExpr(q) => q
435 .inner()
436 .and_then(|i| find_unguarded_ref_in_expr(&i, target)),
437 Expr::CapturedExpr(_) | Expr::FieldExpr(_) => expr
438 .children()
439 .iter()
440 .find_map(|c| find_unguarded_ref_in_expr(c, target)),
441 }
442}