plotnik_lib/query/
shapes.rs

1//! Shape cardinality analysis for query expressions.
2//!
3//! Determines whether an expression matches a single node position (`One`)
4//! or multiple sequential positions (`Many`). Used to validate field constraints:
5//! `field: expr` requires `expr` to have `ShapeCardinality::One`.
6//!
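//! `Invalid` marks nodes where cardinality cannot be determined: error nodes,
//! captures/quantifiers with a missing inner expression, undefined refs, and
//! definitions that refer back to themselves without passing through a node.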
9
10use super::Query;
11use crate::diagnostics::DiagnosticKind;
12use crate::parser::{Expr, Ref, SeqExpr};
13
/// How many sequential node positions an expression occupies when matched.
///
/// A field constraint `field: expr` is only valid when `expr` is [`ShapeCardinality::One`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum ShapeCardinality {
    /// The expression matches exactly one node position.
    One,
    /// The expression spans several sequential positions (e.g. a sequence with
    /// two or more items).
    Many,
    /// The cardinality could not be determined (malformed, unresolved, or
    /// self-referential expression).
    Invalid,
}
20
21impl Query<'_> {
22    pub(super) fn infer_shapes(&mut self) {
23        let bodies: Vec<_> = self.ast.defs().filter_map(|d| d.body()).collect();
24
25        for body in &bodies {
26            self.compute_all_cardinalities(body);
27        }
28
29        for body in &bodies {
30            self.validate_shapes(body);
31        }
32    }
33
34    fn compute_all_cardinalities(&mut self, expr: &Expr) {
35        self.get_or_compute(expr);
36
37        for child in expr.children() {
38            self.compute_all_cardinalities(&child);
39        }
40    }
41
42    fn get_or_compute(&mut self, expr: &Expr) -> ShapeCardinality {
43        if let Some(&c) = self.shape_cardinality_table.get(expr) {
44            return c;
45        }
46        // Insert sentinel to break cycles (e.g., `Foo = (Foo)`)
47        self.shape_cardinality_table
48            .insert(expr.clone(), ShapeCardinality::Invalid);
49        let c = self.compute_single(expr);
50        self.shape_cardinality_table.insert(expr.clone(), c);
51        c
52    }
53
54    fn compute_single(&mut self, expr: &Expr) -> ShapeCardinality {
55        match expr {
56            Expr::NamedNode(_) | Expr::AnonymousNode(_) | Expr::FieldExpr(_) | Expr::AltExpr(_) => {
57                ShapeCardinality::One
58            }
59
60            Expr::SeqExpr(seq) => self.seq_cardinality(seq),
61
62            Expr::CapturedExpr(cap) => {
63                let Some(inner) = cap.inner() else {
64                    return ShapeCardinality::Invalid;
65                };
66                self.get_or_compute(&inner)
67            }
68
69            Expr::QuantifiedExpr(q) => {
70                let Some(inner) = q.inner() else {
71                    return ShapeCardinality::Invalid;
72                };
73                self.get_or_compute(&inner)
74            }
75
76            Expr::Ref(r) => self.ref_cardinality(r),
77        }
78    }
79
80    fn seq_cardinality(&mut self, seq: &SeqExpr) -> ShapeCardinality {
81        let children: Vec<_> = seq.children().collect();
82
83        match children.len() {
84            0 => ShapeCardinality::One,
85            1 => self.get_or_compute(&children[0]),
86            _ => ShapeCardinality::Many,
87        }
88    }
89
90    fn ref_cardinality(&mut self, r: &Ref) -> ShapeCardinality {
91        let name_tok = r.name().expect(
92            "shape_cardinalities: Ref without name token \
93             (parser only creates Ref for PascalCase Id)",
94        );
95        let name = name_tok.text();
96
97        let Some(body) = self.symbol_table.get(name).cloned() else {
98            return ShapeCardinality::Invalid;
99        };
100
101        self.get_or_compute(&body)
102    }
103
104    fn validate_shapes(&mut self, expr: &Expr) {
105        if let Expr::FieldExpr(field) = expr
106            && let Some(value) = field.value()
107        {
108            let card = self
109                .shape_cardinality_table
110                .get(&value)
111                .copied()
112                .unwrap_or(ShapeCardinality::One);
113
114            if card == ShapeCardinality::Many {
115                let field_name = field
116                    .name()
117                    .map(|t| t.text().to_string())
118                    .unwrap_or_else(|| "field".to_string());
119
120                self.shapes_diagnostics
121                    .report(DiagnosticKind::FieldSequenceValue, value.text_range())
122                    .message(field_name)
123                    .emit();
124            }
125        }
126
127        for child in expr.children() {
128            self.validate_shapes(&child);
129        }
130    }
131}