1use indexmap::IndexSet;
9use plotnik_langs::{Lang, NodeFieldId, NodeTypeId};
10use rowan::TextRange;
11
12use crate::diagnostics::DiagnosticKind;
13use crate::parser::ast::{self, Expr, NamedNode};
14use crate::parser::cst::{SyntaxKind, SyntaxToken};
15
16use super::Query;
17
/// Levenshtein distance between `a` and `b`, counted in `char`s (Unicode
/// scalar values), not bytes.
///
/// Uses a single DP row: `row[j]` holds the distance between the prefix of `a`
/// processed so far and the first `j` chars of `b`.
fn edit_distance(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    let width = target.len();

    // Distance from the empty prefix of `a` to each prefix of `b`.
    let mut row: Vec<usize> = (0..=width).collect();

    for (i, source_char) in a.chars().enumerate() {
        // `diagonal` is the previous row's value at column `j`.
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &target_char) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diagonal + usize::from(source_char != target_char);
            row[j + 1] = substitute.min(above + 1).min(row[j] + 1);
            diagonal = above;
        }
    }

    row[width]
}
44
45fn find_similar<'a>(name: &str, candidates: &[&'a str], max_distance: usize) -> Option<&'a str> {
47 candidates
48 .iter()
49 .map(|&c| (c, edit_distance(name, c)))
50 .filter(|(_, d)| *d <= max_distance)
51 .min_by_key(|(_, d)| *d)
52 .map(|(c, _)| c)
53}
54
55#[allow(dead_code)]
57fn is_subtype_of(lang: &Lang, child: NodeTypeId, supertype: NodeTypeId) -> bool {
58 let subtypes = lang.subtypes(supertype);
59 for &subtype in subtypes {
60 if subtype == child {
61 return true;
62 }
63 if lang.is_supertype(subtype) && is_subtype_of(lang, child, subtype) {
64 return true;
65 }
66 }
67 false
68}
69
70#[allow(dead_code)]
72fn is_valid_child_expanded(lang: &Lang, parent: NodeTypeId, child: NodeTypeId) -> bool {
73 let valid_types = lang.valid_child_types(parent);
74 for &allowed in valid_types {
75 if allowed == child {
76 return true;
77 }
78 if lang.is_supertype(allowed) && is_subtype_of(lang, child, allowed) {
79 return true;
80 }
81 }
82 false
83}
84
85#[allow(dead_code)]
87fn is_valid_field_type_expanded(
88 lang: &Lang,
89 parent: NodeTypeId,
90 field: NodeFieldId,
91 child: NodeTypeId,
92) -> bool {
93 if lang.is_valid_field_type(parent, field, child) {
94 return true;
95 }
96 let valid_types = lang.valid_field_types(parent, field);
97 for &allowed in valid_types {
98 if lang.is_supertype(allowed) && is_subtype_of(lang, child, allowed) {
99 return true;
100 }
101 }
102 false
103}
104
/// Renders `items` as a comma-separated list of backtick-quoted names,
/// showing at most `max_items` of them.
///
/// Items beyond `max_items` are summarized as `... (N more)`. An empty `items`
/// yields an empty string. With `max_items == 0` only the summary is produced,
/// with no dangling leading separator.
fn format_list(items: &[&str], max_items: usize) -> String {
    let shown = items.len().min(max_items);
    let hidden = items.len() - shown;

    let mut out = items[..shown]
        .iter()
        .map(|s| format!("`{}`", s))
        .collect::<Vec<_>>()
        .join(", ");

    if hidden > 0 {
        // Only separate the summary from a non-empty list.
        if !out.is_empty() {
            out.push_str(", ");
        }
        out.push_str(&format!("... ({} more)", hidden));
    }
    out
}
129
130#[allow(dead_code)]
132#[derive(Clone, Copy)]
133struct ValidationContext<'a> {
134 parent_id: NodeTypeId,
136 parent_name: &'a str,
138 parent_range: TextRange,
140 field: Option<FieldContext<'a>>,
142}
143
144#[allow(dead_code)]
145#[derive(Clone, Copy)]
146struct FieldContext<'a> {
147 name: &'a str,
148 id: NodeFieldId,
149 range: TextRange,
150}
151
152impl<'a> Query<'a> {
153 pub fn link(&mut self, lang: &Lang) {
157 self.resolve_node_types(lang);
158 self.resolve_fields(lang);
159 self.validate_structure(lang);
160 }
161
162 fn resolve_node_types(&mut self, lang: &Lang) {
163 let defs: Vec<_> = self.ast.defs().collect();
164 for def in defs {
165 let Some(body) = def.body() else { continue };
166 self.collect_node_types(&body, lang);
167 }
168 }
169
170 fn collect_node_types(&mut self, expr: &Expr, lang: &Lang) {
171 match expr {
172 Expr::NamedNode(node) => {
173 self.resolve_named_node(node, lang);
174 for child in node.children() {
175 self.collect_node_types(&child, lang);
176 }
177 }
178 Expr::AnonymousNode(anon) => {
179 if anon.is_any() {
180 return;
181 }
182 let Some(value_token) = anon.value() else {
183 return;
184 };
185 let value = value_token.text();
186 if self.node_type_ids.contains_key(value) {
187 return;
188 }
189 let resolved = lang.resolve_anonymous_node(value);
190 self.node_type_ids.insert(
191 &self.source[text_range_to_usize(value_token.text_range())],
192 resolved,
193 );
194 if resolved.is_none() {
195 self.link_diagnostics
196 .report(DiagnosticKind::UnknownNodeType, value_token.text_range())
197 .message(value)
198 .emit();
199 }
200 }
201 Expr::AltExpr(alt) => {
202 for branch in alt.branches() {
203 let Some(body) = branch.body() else { continue };
204 self.collect_node_types(&body, lang);
205 }
206 }
207 Expr::SeqExpr(seq) => {
208 for child in seq.children() {
209 self.collect_node_types(&child, lang);
210 }
211 }
212 Expr::CapturedExpr(cap) => {
213 let Some(inner) = cap.inner() else { return };
214 self.collect_node_types(&inner, lang);
215 }
216 Expr::QuantifiedExpr(q) => {
217 let Some(inner) = q.inner() else { return };
218 self.collect_node_types(&inner, lang);
219 }
220 Expr::FieldExpr(f) => {
221 let Some(value) = f.value() else { return };
222 self.collect_node_types(&value, lang);
223 }
224 Expr::Ref(_) => {}
225 }
226 }
227
228 fn resolve_named_node(&mut self, node: &NamedNode, lang: &Lang) {
229 if node.is_any() {
230 return;
231 }
232 let Some(type_token) = node.node_type() else {
233 return;
234 };
235 if matches!(
236 type_token.kind(),
237 SyntaxKind::KwError | SyntaxKind::KwMissing
238 ) {
239 return;
240 }
241 let type_name = type_token.text();
242 if self.node_type_ids.contains_key(type_name) {
243 return;
244 }
245 let resolved = lang.resolve_named_node(type_name);
246 self.node_type_ids.insert(
247 &self.source[text_range_to_usize(type_token.text_range())],
248 resolved,
249 );
250 if resolved.is_none() {
251 let all_types = lang.all_named_node_kinds();
252 let max_dist = (type_name.len() / 3).clamp(2, 4);
253 let suggestion = find_similar(type_name, &all_types, max_dist);
254
255 let mut builder = self
256 .link_diagnostics
257 .report(DiagnosticKind::UnknownNodeType, type_token.text_range())
258 .message(type_name);
259
260 if let Some(similar) = suggestion {
261 builder = builder.hint(format!("did you mean `{}`?", similar));
262 }
263 builder.emit();
264 }
265 }
266
267 fn resolve_fields(&mut self, lang: &Lang) {
268 let defs: Vec<_> = self.ast.defs().collect();
269 for def in defs {
270 let Some(body) = def.body() else { continue };
271 self.collect_fields(&body, lang);
272 }
273 }
274
275 fn collect_fields(&mut self, expr: &Expr, lang: &Lang) {
276 match expr {
277 Expr::NamedNode(node) => {
278 for child in node.children() {
279 self.collect_fields(&child, lang);
280 }
281 for child in node.as_cst().children() {
282 if let Some(neg) = ast::NegatedField::cast(child) {
283 self.resolve_field_by_token(neg.name(), lang);
284 }
285 }
286 }
287 Expr::AltExpr(alt) => {
288 for branch in alt.branches() {
289 let Some(body) = branch.body() else { continue };
290 self.collect_fields(&body, lang);
291 }
292 }
293 Expr::SeqExpr(seq) => {
294 for child in seq.children() {
295 self.collect_fields(&child, lang);
296 }
297 }
298 Expr::CapturedExpr(cap) => {
299 let Some(inner) = cap.inner() else { return };
300 self.collect_fields(&inner, lang);
301 }
302 Expr::QuantifiedExpr(q) => {
303 let Some(inner) = q.inner() else { return };
304 self.collect_fields(&inner, lang);
305 }
306 Expr::FieldExpr(f) => {
307 self.resolve_field_by_token(f.name(), lang);
308 let Some(value) = f.value() else { return };
309 self.collect_fields(&value, lang);
310 }
311 Expr::AnonymousNode(_) | Expr::Ref(_) => {}
312 }
313 }
314
315 fn resolve_field_by_token(&mut self, name_token: Option<SyntaxToken>, lang: &Lang) {
316 let Some(name_token) = name_token else {
317 return;
318 };
319 let field_name = name_token.text();
320 if self.node_field_ids.contains_key(field_name) {
321 return;
322 }
323 let resolved = lang.resolve_field(field_name);
324 self.node_field_ids.insert(
325 &self.source[text_range_to_usize(name_token.text_range())],
326 resolved,
327 );
328 if resolved.is_some() {
329 return;
330 }
331 let all_fields = lang.all_field_names();
332 let max_dist = (field_name.len() / 3).clamp(2, 4);
333 let suggestion = find_similar(field_name, &all_fields, max_dist);
334
335 let mut builder = self
336 .link_diagnostics
337 .report(DiagnosticKind::UnknownField, name_token.text_range())
338 .message(field_name);
339
340 if let Some(similar) = suggestion {
341 builder = builder.hint(format!("did you mean `{}`?", similar));
342 }
343 builder.emit();
344 }
345
346 fn validate_structure(&mut self, lang: &Lang) {
347 let defs: Vec<_> = self.ast.defs().collect();
348 for def in defs {
349 let Some(body) = def.body() else { continue };
350 let mut visited = IndexSet::new();
351 self.validate_expr_structure(&body, None, lang, &mut visited);
352 }
353 }
354
355 fn validate_expr_structure(
356 &mut self,
357 expr: &Expr,
358 ctx: Option<ValidationContext<'a>>,
359 lang: &Lang,
360 visited: &mut IndexSet<String>,
361 ) {
362 match expr {
363 Expr::NamedNode(node) => {
364 if let Some(ref ctx) = ctx {
366 self.validate_terminal_type(expr, ctx, lang, visited);
367 }
368
369 let child_ctx = self.make_node_context(node, lang);
371
372 for child in node.children() {
373 match &child {
374 Expr::FieldExpr(f) => {
375 self.validate_field_expr(f, child_ctx.as_ref(), lang, visited);
377 }
378 _ => {
379 if let Some(ctx) = child_ctx {
381 self.validate_non_field_children(&child, &ctx, lang, visited);
382 }
383 self.validate_expr_structure(&child, child_ctx, lang, visited);
384 }
385 }
386 }
387
388 if let Some(ctx) = child_ctx {
390 for child in node.as_cst().children() {
391 if let Some(neg) = ast::NegatedField::cast(child) {
392 self.validate_negated_field(&neg, &ctx, lang);
393 }
394 }
395 }
396 }
397 Expr::AnonymousNode(_) => {
398 if let Some(ref ctx) = ctx {
400 self.validate_terminal_type(expr, ctx, lang, visited);
401 }
402 }
403 Expr::FieldExpr(f) => {
404 self.validate_field_expr(f, ctx.as_ref(), lang, visited);
406 }
407 Expr::AltExpr(alt) => {
408 for branch in alt.branches() {
409 let Some(body) = branch.body() else { continue };
410 self.validate_expr_structure(&body, ctx, lang, visited);
411 }
412 }
413 Expr::SeqExpr(seq) => {
414 for child in seq.children() {
415 self.validate_expr_structure(&child, ctx, lang, visited);
416 }
417 }
418 Expr::CapturedExpr(cap) => {
419 let Some(inner) = cap.inner() else { return };
420 self.validate_expr_structure(&inner, ctx, lang, visited);
421 }
422 Expr::QuantifiedExpr(q) => {
423 let Some(inner) = q.inner() else { return };
424 self.validate_expr_structure(&inner, ctx, lang, visited);
425 }
426 Expr::Ref(r) => {
427 let Some(name_token) = r.name() else { return };
428 let name = name_token.text();
429 if !visited.insert(name.to_string()) {
430 return;
431 }
432 let Some(body) = self.symbol_table.get(name).cloned() else {
433 visited.swap_remove(name);
434 return;
435 };
436 self.validate_expr_structure(&body, ctx, lang, visited);
437 visited.swap_remove(name);
438 }
439 }
440 }
441
442 fn make_node_context(&self, node: &NamedNode, lang: &Lang) -> Option<ValidationContext<'a>> {
444 if node.is_any() {
445 return None;
446 }
447 let type_token = node.node_type()?;
448 if matches!(
449 type_token.kind(),
450 SyntaxKind::KwError | SyntaxKind::KwMissing
451 ) {
452 return None;
453 }
454 let type_name = type_token.text();
455 let parent_id = self.node_type_ids.get(type_name).copied().flatten()?;
456 let parent_name = lang.node_type_name(parent_id)?;
457 Some(ValidationContext {
458 parent_id,
459 parent_name,
460 parent_range: type_token.text_range(),
461 field: None,
462 })
463 }
464
465 fn validate_field_expr(
467 &mut self,
468 field: &ast::FieldExpr,
469 ctx: Option<&ValidationContext<'a>>,
470 lang: &Lang,
471 visited: &mut IndexSet<String>,
472 ) {
473 let Some(name_token) = field.name() else {
474 return;
475 };
476 let field_name = name_token.text();
477
478 let Some(field_id) = self.node_field_ids.get(field_name).copied().flatten() else {
479 return;
480 };
481
482 let Some(ctx) = ctx else {
483 return;
484 };
485
486 if !lang.has_field(ctx.parent_id, field_id) {
488 self.emit_field_not_on_node(
489 name_token.text_range(),
490 field_name,
491 ctx.parent_id,
492 ctx.parent_range,
493 lang,
494 );
495 return;
496 }
497
498 let Some(value) = field.value() else {
499 return;
500 };
501
502 let field_ctx = ValidationContext {
504 parent_id: ctx.parent_id,
505 parent_name: ctx.parent_name,
506 parent_range: ctx.parent_range,
507 field: Some(FieldContext {
508 name: &self.source[text_range_to_usize(name_token.text_range())],
509 id: field_id,
510 range: name_token.text_range(),
511 }),
512 };
513
514 self.validate_expr_structure(&value, Some(field_ctx), lang, visited);
517 }
518
519 #[cfg(feature = "unstable-child-type-validation")]
521 fn validate_non_field_children(
522 &mut self,
523 expr: &Expr,
524 ctx: &ValidationContext<'a>,
525 lang: &Lang,
526 visited: &mut IndexSet<String>,
527 ) {
528 let terminals = self.collect_terminal_types(expr, visited);
530
531 let valid_types = lang.valid_child_types(ctx.parent_id);
533 let parent_only_fields = valid_types.is_empty();
534
535 for (child_id, child_name, child_range) in terminals {
536 if parent_only_fields {
537 self.link_diagnostics
538 .report(DiagnosticKind::InvalidChildType, child_range)
539 .message(child_name)
540 .related_to(
541 format!("`{}` only accepts children via fields", ctx.parent_name),
542 ctx.parent_range,
543 )
544 .emit();
545 continue;
546 }
547
548 if is_valid_child_expanded(lang, ctx.parent_id, child_id) {
549 continue;
550 }
551
552 let valid_names: Vec<&str> = valid_types
553 .iter()
554 .filter_map(|&id| lang.node_type_name(id))
555 .collect();
556
557 let mut builder = self
558 .link_diagnostics
559 .report(DiagnosticKind::InvalidChildType, child_range)
560 .message(child_name)
561 .related_to(format!("inside `{}`", ctx.parent_name), ctx.parent_range);
562
563 if !valid_names.is_empty() {
564 builder = builder.hint(format!(
565 "valid children for `{}`: {}",
566 ctx.parent_name,
567 format_list(&valid_names, 5)
568 ));
569 }
570 builder.emit();
571 }
572 }
573
574 #[cfg(not(feature = "unstable-child-type-validation"))]
575 fn validate_non_field_children(
576 &mut self,
577 _expr: &Expr,
578 _ctx: &ValidationContext<'a>,
579 _lang: &Lang,
580 _visited: &mut IndexSet<String>,
581 ) {
582 }
583
584 #[cfg(feature = "unstable-child-type-validation")]
586 fn validate_terminal_type(
587 &mut self,
588 expr: &Expr,
589 ctx: &ValidationContext<'a>,
590 lang: &Lang,
591 visited: &mut IndexSet<String>,
592 ) {
593 if let Expr::Ref(r) = expr {
595 let Some(name_token) = r.name() else { return };
596 let name = name_token.text();
597 if !visited.insert(name.to_string()) {
598 return;
599 }
600 let Some(body) = self.symbol_table.get(name).cloned() else {
601 visited.swap_remove(name);
602 return;
603 };
604 self.validate_terminal_type(&body, ctx, lang, visited);
605 visited.swap_remove(name);
606 return;
607 }
608
609 let Some((child_id, child_name, child_range)) = self.get_terminal_type_info(expr) else {
610 return;
611 };
612
613 if let Some(ref field) = ctx.field {
614 if is_valid_field_type_expanded(lang, ctx.parent_id, field.id, child_id) {
616 return;
617 }
618
619 let valid_types = lang.valid_field_types(ctx.parent_id, field.id);
620 let valid_names: Vec<&str> = valid_types
621 .iter()
622 .filter_map(|&id| lang.node_type_name(id))
623 .collect();
624
625 let mut builder = self
626 .link_diagnostics
627 .report(DiagnosticKind::InvalidFieldChildType, child_range)
628 .message(child_name)
629 .related_to(
630 format!("field `{}` on `{}`", field.name, ctx.parent_name),
631 field.range,
632 );
633
634 if !valid_names.is_empty() {
635 builder = builder.hint(format!(
636 "valid types for `{}`: {}",
637 field.name,
638 format_list(&valid_names, 5)
639 ));
640 }
641 builder.emit();
642 }
643 }
645
646 #[cfg(not(feature = "unstable-child-type-validation"))]
647 fn validate_terminal_type(
648 &mut self,
649 _expr: &Expr,
650 _ctx: &ValidationContext<'a>,
651 _lang: &Lang,
652 _visited: &mut IndexSet<String>,
653 ) {
654 }
655
656 #[allow(dead_code)]
658 fn collect_terminal_types(
659 &self,
660 expr: &Expr,
661 visited: &mut IndexSet<String>,
662 ) -> Vec<(NodeTypeId, &'a str, TextRange)> {
663 let mut result = Vec::new();
664 self.collect_terminal_types_impl(expr, &mut result, visited);
665 result
666 }
667
668 #[allow(dead_code)]
669 fn collect_terminal_types_impl(
670 &self,
671 expr: &Expr,
672 result: &mut Vec<(NodeTypeId, &'a str, TextRange)>,
673 visited: &mut IndexSet<String>,
674 ) {
675 match expr {
676 Expr::NamedNode(_) | Expr::AnonymousNode(_) => {
677 if let Some(info) = self.get_terminal_type_info(expr) {
678 result.push(info);
679 }
680 }
681 Expr::AltExpr(alt) => {
682 for branch in alt.branches() {
683 if let Some(body) = branch.body() {
684 self.collect_terminal_types_impl(&body, result, visited);
685 }
686 }
687 }
688 Expr::SeqExpr(seq) => {
689 for child in seq.children() {
690 self.collect_terminal_types_impl(&child, result, visited);
691 }
692 }
693 Expr::CapturedExpr(cap) => {
694 if let Some(inner) = cap.inner() {
695 self.collect_terminal_types_impl(&inner, result, visited);
696 }
697 }
698 Expr::QuantifiedExpr(q) => {
699 if let Some(inner) = q.inner() {
700 self.collect_terminal_types_impl(&inner, result, visited);
701 }
702 }
703 Expr::Ref(r) => {
704 let Some(name_token) = r.name() else { return };
705 let name = name_token.text();
706 if !visited.insert(name.to_string()) {
707 return;
708 }
709 let Some(body) = self.symbol_table.get(name) else {
710 visited.swap_remove(name);
711 return;
712 };
713 self.collect_terminal_types_impl(body, result, visited);
714 visited.swap_remove(name);
715 }
716 Expr::FieldExpr(_) => {
717 }
719 }
720 }
721
722 #[allow(dead_code)]
724 fn get_terminal_type_info(&self, expr: &Expr) -> Option<(NodeTypeId, &'a str, TextRange)> {
725 match expr {
726 Expr::NamedNode(node) => {
727 if node.is_any() {
728 return None;
729 }
730 let type_token = node.node_type()?;
731 if matches!(
732 type_token.kind(),
733 SyntaxKind::KwError | SyntaxKind::KwMissing
734 ) {
735 return None;
736 }
737 let type_name = type_token.text();
738 let type_id = self.node_type_ids.get(type_name).copied().flatten()?;
739 let name = &self.source[text_range_to_usize(type_token.text_range())];
740 Some((type_id, name, type_token.text_range()))
741 }
742 Expr::AnonymousNode(anon) => {
743 if anon.is_any() {
744 return None;
745 }
746 let value_token = anon.value()?;
747 let value = &self.source[text_range_to_usize(value_token.text_range())];
748 let type_id = self.node_type_ids.get(value).copied().flatten()?;
749 Some((type_id, value, value_token.text_range()))
750 }
751 _ => None,
752 }
753 }
754
755 fn validate_negated_field(
756 &mut self,
757 neg: &ast::NegatedField,
758 ctx: &ValidationContext<'a>,
759 lang: &Lang,
760 ) {
761 let Some(name_token) = neg.name() else {
762 return;
763 };
764 let field_name = name_token.text();
765
766 let Some(field_id) = self.node_field_ids.get(field_name).copied().flatten() else {
767 return;
768 };
769
770 if lang.has_field(ctx.parent_id, field_id) {
771 return;
772 }
773 self.emit_field_not_on_node(
774 name_token.text_range(),
775 field_name,
776 ctx.parent_id,
777 ctx.parent_range,
778 lang,
779 );
780 }
781
782 fn emit_field_not_on_node(
783 &mut self,
784 range: TextRange,
785 field_name: &str,
786 parent_id: NodeTypeId,
787 parent_range: TextRange,
788 lang: &Lang,
789 ) {
790 let valid_fields = lang.fields_for_node_type(parent_id);
791 let parent_name = lang.node_type_name(parent_id).unwrap_or("(unknown)");
792
793 let mut builder = self
794 .link_diagnostics
795 .report(DiagnosticKind::FieldNotOnNodeType, range)
796 .message(field_name)
797 .related_to(format!("on `{}`", parent_name), parent_range);
798
799 if valid_fields.is_empty() {
800 builder = builder.hint(format!("`{}` has no fields", parent_name));
801 } else {
802 let max_dist = (field_name.len() / 3).clamp(2, 4);
803 if let Some(similar) = find_similar(field_name, &valid_fields, max_dist) {
804 builder = builder.hint(format!("did you mean `{}`?", similar));
805 }
806 builder = builder.hint(format!(
807 "valid fields for `{}`: {}",
808 parent_name,
809 format_list(&valid_fields, 5)
810 ));
811 }
812 builder.emit();
813 }
814}
815
816fn text_range_to_usize(range: TextRange) -> std::ops::Range<usize> {
817 let start: usize = range.start().into();
818 let end: usize = range.end().into();
819 start..end
820}