plotnik_lib/query/
infer.rs

1//! AST-based type inference for Plotnik queries.
2//!
3//! Analyzes query AST to determine output types.
4//! Rules follow ADR-0009 (Type System).
5//!
6//! # Design
7//!
8//! Unlike graph-based inference which must reconstruct structure from CFG traversal,
9//! AST-based inference directly walks the tree structure:
10//! - Sequences → `SeqExpr`
11//! - Alternations → `AltExpr` with `.kind()` for tagged/untagged
12//! - Quantifiers → `QuantifiedExpr`
13//! - Captures → `CapturedExpr`
14//!
15//! This eliminates dry-run traversal, reconvergence detection, and scope stack management.
16
17use std::collections::{HashMap, HashSet};
18
19use indexmap::IndexMap;
20use rowan::TextRange;
21
22use crate::diagnostics::{DiagnosticKind, Diagnostics};
23use crate::ir::{TYPE_NODE, TYPE_STR, TYPE_VOID, TypeId, TypeKind};
24use crate::parser::ast::{self, AltKind, Expr};
25use crate::parser::token_src;
26
27use super::Query;
28
/// Result of type inference.
///
/// NOTE(review): `Query::infer_types` fills `type_info.{type_defs, entrypoint_types,
/// diagnostics, errors}` from the inference context; presumably `type_info` has this
/// type — confirm.
#[derive(Debug, Default)]
pub struct TypeInferenceResult<'src> {
    /// Non-builtin types (records, enums, optional/array wrappers) in allocation order.
    pub type_defs: Vec<InferredTypeDef<'src>>,
    /// Result type of each definition, keyed by definition name.
    pub entrypoint_types: IndexMap<&'src str, TypeId>,
    /// Diagnostics raised during inference (incompatible capture types).
    pub diagnostics: Diagnostics,
    /// Structured form of the incompatible-type conflicts.
    pub errors: Vec<UnificationError<'src>>,
}
37
/// Error when types cannot be unified in alternation branches.
#[derive(Debug, Clone)]
pub struct UnificationError<'src> {
    /// Capture name whose types conflict.
    pub field: &'src str,
    /// Name of the definition in which the conflict was found.
    pub definition: &'src str,
    /// Descriptions of the conflicting types (this module records two per conflict:
    /// the already merged one, then the incoming one).
    pub types_found: Vec<TypeDescription>,
    /// Source ranges of every capture of `field` involved in the conflicting merge.
    pub spans: Vec<TextRange>,
}
46
/// Human-readable type description for error messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeDescription {
    /// A syntax-tree node.
    Node,
    /// A string value (e.g. a capture annotated `:: string`).
    String,
    /// A struct; the payload entries are rendered comma-separated inside braces.
    Struct(Vec<String>),
}

impl std::fmt::Display for TypeDescription {
    /// Renders `Node`, `String`, or `Struct { a, b }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let entries = match self {
            TypeDescription::Node => return f.write_str("Node"),
            TypeDescription::String => return f.write_str("String"),
            TypeDescription::Struct(entries) => entries,
        };
        f.write_str("Struct { ")?;
        f.write_str(&entries.join(", "))?;
        f.write_str(" }")
    }
}
66
/// An inferred type definition.
///
/// Produced for records and enums (named) and for cardinality wrappers (anonymous).
#[derive(Debug, Clone)]
pub struct InferredTypeDef<'src> {
    /// One of `Record`, `Enum`, `Optional`, `ArrayStar` or `ArrayPlus` in this module.
    pub kind: TypeKind,
    /// Type name for records and enums; `None` for wrapper types.
    pub name: Option<&'src str>,
    /// Fields of a record or variants of an enum; empty for wrapper types.
    pub members: Vec<InferredMember<'src>>,
    /// Wrapped type for `Optional`/array wrappers; `None` for records and enums.
    pub inner_type: Option<TypeId>,
}
75
/// A field (for Record) or variant (for Enum).
#[derive(Debug, Clone)]
pub struct InferredMember<'src> {
    /// Capture name (record field) or branch label (enum variant).
    pub name: &'src str,
    /// The field's cardinality-wrapped type, or the variant payload
    /// (`Void` for a variant without captures).
    pub ty: TypeId,
}
82
/// How often a capture can occur.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum Cardinality {
    /// Exactly once (the default).
    #[default]
    One,
    /// Zero or one time.
    Optional,
    /// Zero or more times.
    Star,
    /// One or more times.
    Plus,
}

impl Cardinality {
    /// Decomposes the cardinality into `(may_be_absent, may_repeat)`.
    fn flags(self) -> (bool, bool) {
        match self {
            Cardinality::One => (false, false),
            Cardinality::Optional => (true, false),
            Cardinality::Plus => (false, true),
            Cardinality::Star => (true, true),
        }
    }

    /// Inverse of [`Cardinality::flags`].
    fn from_flags(may_be_absent: bool, may_repeat: bool) -> Cardinality {
        match (may_be_absent, may_repeat) {
            (false, false) => Cardinality::One,
            (true, false) => Cardinality::Optional,
            (false, true) => Cardinality::Plus,
            (true, true) => Cardinality::Star,
        }
    }

    /// Join cardinalities when merging alternation branches.
    ///
    /// The result may be absent if either side may be, and may repeat if either side
    /// may (`One` ⊔ `Plus` = `Plus`, `Optional` ⊔ `Plus` = `Star`).
    fn join(self, other: Cardinality) -> Cardinality {
        let (absent_a, repeat_a) = self.flags();
        let (absent_b, repeat_b) = other.flags();
        Cardinality::from_flags(absent_a || absent_b, repeat_a || repeat_b)
    }

    /// Allows absence while keeping repetition (`One` → `Optional`, `Plus` → `Star`).
    fn make_optional(self) -> Cardinality {
        let (_, may_repeat) = self.flags();
        Cardinality::from_flags(true, may_repeat)
    }

    /// Multiply cardinalities (outer * inner).
    ///
    /// The nested occurrence can be absent if either level can, and repeats if either
    /// level repeats. For these four cardinalities this is the same flag union as
    /// [`Cardinality::join`].
    fn multiply(self, inner: Cardinality) -> Cardinality {
        let (absent_outer, repeat_outer) = self.flags();
        let (absent_inner, repeat_inner) = inner.flags();
        Cardinality::from_flags(absent_outer || absent_inner, repeat_outer || repeat_inner)
    }
}
126
127#[derive(Debug, Clone, PartialEq, Eq)]
128enum TypeShape {
129    Primitive(TypeId),
130}
131
132impl TypeShape {
133    fn to_description(&self) -> TypeDescription {
134        match self {
135            TypeShape::Primitive(TYPE_NODE) => TypeDescription::Node,
136            TypeShape::Primitive(TYPE_STR) => TypeDescription::String,
137            TypeShape::Primitive(_) => TypeDescription::Node,
138        }
139    }
140}
141
/// Accumulated information about one capture name within a scope.
#[derive(Debug, Clone)]
struct FieldInfo {
    /// Captured type before cardinality wrapping; the first occurrence's type is kept
    /// when the same name is seen again.
    base_type: TypeId,
    /// Shape compared when merging scopes; created as `Primitive(base_type)` by
    /// `ScopeInfo::add_field`.
    shape: TypeShape,
    /// Cardinality joined over every occurrence merged into this entry.
    cardinality: Cardinality,
    /// Number of occurrences/branches that contributed this field; compared against the
    /// branch total by `ScopeInfo::apply_optionality`.
    branch_count: usize,
    /// Source ranges of every capture merged into this entry.
    spans: Vec<TextRange>,
}
150
/// Captures collected for one scope: a definition body, a captured sequence or
/// alternation, an alternation branch, a QIS element, or an enum variant.
#[derive(Debug, Clone, Default)]
struct ScopeInfo<'src> {
    /// Captures keyed by name, in first-seen order.
    fields: IndexMap<&'src str, FieldInfo>,
    // NOTE(review): never populated in this file; only read by `is_empty`.
    #[allow(dead_code)] // May be used for future enum variant tracking
    variants: IndexMap<&'src str, ScopeInfo<'src>>,
    // NOTE(review): never read or written in this file.
    #[allow(dead_code)]
    has_variants: bool,
}
159
160impl<'src> ScopeInfo<'src> {
161    fn add_field(
162        &mut self,
163        name: &'src str,
164        base_type: TypeId,
165        cardinality: Cardinality,
166        span: TextRange,
167    ) {
168        let shape = TypeShape::Primitive(base_type);
169        if let Some(existing) = self.fields.get_mut(name) {
170            existing.cardinality = existing.cardinality.join(cardinality);
171            existing.branch_count += 1;
172            existing.spans.push(span);
173        } else {
174            self.fields.insert(
175                name,
176                FieldInfo {
177                    base_type,
178                    shape,
179                    cardinality,
180                    branch_count: 1,
181                    spans: vec![span],
182                },
183            );
184        }
185    }
186
187    fn merge_from(&mut self, other: ScopeInfo<'src>) -> Vec<MergeError<'src>> {
188        let mut errors = Vec::new();
189
190        for (name, other_info) in other.fields {
191            if let Some(existing) = self.fields.get_mut(name) {
192                if existing.shape != other_info.shape {
193                    errors.push(MergeError {
194                        field: name,
195                        shapes: vec![existing.shape.clone(), other_info.shape.clone()],
196                        spans: existing
197                            .spans
198                            .iter()
199                            .chain(&other_info.spans)
200                            .cloned()
201                            .collect(),
202                    });
203                }
204                existing.cardinality = existing.cardinality.join(other_info.cardinality);
205                existing.branch_count += other_info.branch_count;
206                existing.spans.extend(other_info.spans);
207            } else {
208                self.fields.insert(name, other_info);
209            }
210        }
211
212        errors
213    }
214
215    fn apply_optionality(&mut self, total_branches: usize) {
216        for info in self.fields.values_mut() {
217            if info.branch_count < total_branches {
218                info.cardinality = info.cardinality.make_optional();
219            }
220        }
221    }
222
223    #[allow(dead_code)] // May be useful for future scope analysis
224    fn is_empty(&self) -> bool {
225        self.fields.is_empty() && self.variants.is_empty()
226    }
227}
228
/// A shape conflict detected while merging two scopes (see `ScopeInfo::merge_from`).
#[derive(Debug)]
struct MergeError<'src> {
    /// Capture name whose shapes disagree.
    field: &'src str,
    /// The already merged shape followed by the incoming one.
    shapes: Vec<TypeShape>,
    /// All capture sites of `field` from both sides of the merge.
    spans: Vec<TextRange>,
}
235
/// What an expression produces when evaluated.
///
/// Captures found inside the expression are recorded in the scope passed to the
/// `infer_*` functions; this value only describes the expression itself.
#[derive(Debug, Clone)]
struct ExprResult {
    /// Base type (before cardinality wrapping).
    base_type: TypeId,
    /// Cardinality modifier.
    cardinality: Cardinality,
    /// True if this result represents a meaningful type (not just default Node).
    /// Used to distinguish QIS array results from simple uncaptured expressions.
    is_meaningful: bool,
}
247
248impl ExprResult {
249    fn node() -> Self {
250        Self {
251            base_type: TYPE_NODE,
252            cardinality: Cardinality::One,
253            is_meaningful: false,
254        }
255    }
256
257    fn void() -> Self {
258        Self {
259            base_type: TYPE_VOID,
260            cardinality: Cardinality::One,
261            is_meaningful: false,
262        }
263    }
264
265    fn meaningful(type_id: TypeId) -> Self {
266        Self {
267            base_type: type_id,
268            cardinality: Cardinality::One,
269            is_meaningful: true,
270        }
271    }
272
273    /// Type is known but doesn't contribute to definition result (e.g., opaque references).
274    fn opaque(type_id: TypeId) -> Self {
275        Self {
276            base_type: type_id,
277            cardinality: Cardinality::One,
278            is_meaningful: false,
279        }
280    }
281
282    fn with_cardinality(mut self, card: Cardinality) -> Self {
283        self.cardinality = card;
284        self
285    }
286}
287
/// Mutable state shared while inferring all definitions of one query.
struct InferenceContext<'src> {
    /// Query source text, used with `token_src` to read token text.
    source: &'src str,
    /// Quantifiers that open an implicit scope (QIS) instead of propagating captures.
    qis_triggers: HashSet<ast::QuantifiedExpr>,
    /// Every non-builtin type created so far; each is pushed right after its id is
    /// allocated, so the vector is in id order.
    type_defs: Vec<InferredTypeDef<'src>>,
    /// Id that the next `alloc_type_id` call hands out.
    next_type_id: TypeId,
    /// Diagnostics reported for merge conflicts.
    diagnostics: Diagnostics,
    /// Structured copies of the reported merge conflicts.
    errors: Vec<UnificationError<'src>>,
    /// Definition currently being inferred (used in generated names and errors).
    current_def_name: &'src str,
    /// Map from definition name to its computed type.
    definition_types: HashMap<&'src str, TypeId>,
}
299
300impl<'src> InferenceContext<'src> {
301    fn new(source: &'src str, qis_triggers: HashSet<ast::QuantifiedExpr>) -> Self {
302        Self {
303            source,
304            qis_triggers,
305            type_defs: Vec::new(),
306            next_type_id: 3, // 0=void, 1=node, 2=str
307            diagnostics: Diagnostics::default(),
308            errors: Vec::new(),
309            current_def_name: "",
310            definition_types: HashMap::new(),
311        }
312    }
313
314    fn alloc_type_id(&mut self) -> TypeId {
315        let id = self.next_type_id;
316        self.next_type_id += 1;
317        id
318    }
319
320    fn infer_definition(&mut self, def_name: &'src str, body: &Expr) -> TypeId {
321        self.current_def_name = def_name;
322
323        let mut scope = ScopeInfo::default();
324        let mut merge_errors = Vec::new();
325
326        // Special case: tagged alternation at definition root creates enum
327        if let Expr::AltExpr(alt) = body
328            && alt.kind() == AltKind::Tagged
329        {
330            return self.infer_tagged_alternation_as_enum(def_name, alt, &mut merge_errors);
331        }
332
333        // General case: infer expression and collect captures into scope
334        let result = self.infer_expr(body, &mut scope, Cardinality::One, &mut merge_errors);
335
336        self.report_merge_errors(&merge_errors);
337
338        // Build result type from scope (Payload Rule from ADR-0009)
339        match scope.fields.len() {
340            0 => {
341                if result.is_meaningful {
342                    // QIS or other expressions that produce a meaningful type without populating scope
343                    result.base_type
344                } else {
345                    TYPE_VOID
346                }
347            }
348            1 => {
349                // Single capture at definition root: unwrap to capture's type
350                let (_, info) = scope.fields.iter().next().unwrap();
351                self.wrap_with_cardinality(info.base_type, info.cardinality)
352            }
353            _ => {
354                // Multiple captures: create struct
355                self.create_struct_type(def_name, &scope)
356            }
357        }
358    }
359
360    fn infer_expr(
361        &mut self,
362        expr: &Expr,
363        scope: &mut ScopeInfo<'src>,
364        outer_card: Cardinality,
365        errors: &mut Vec<MergeError<'src>>,
366    ) -> ExprResult {
367        match expr {
368            Expr::CapturedExpr(c) => self.infer_captured(c, scope, outer_card, errors),
369            Expr::QuantifiedExpr(q) => self.infer_quantified(q, scope, outer_card, errors),
370            Expr::SeqExpr(s) => self.infer_sequence(s, scope, outer_card, errors),
371            Expr::AltExpr(a) => self.infer_alternation(a, scope, outer_card, errors),
372            Expr::NamedNode(n) => self.infer_named_node(n, scope, outer_card, errors),
373            Expr::FieldExpr(f) => self.infer_field_expr(f, scope, outer_card, errors),
374            Expr::Ref(r) => self.infer_ref(r),
375            Expr::AnonymousNode(_) => ExprResult::node(),
376        }
377    }
378
379    fn infer_captured(
380        &mut self,
381        c: &ast::CapturedExpr,
382        scope: &mut ScopeInfo<'src>,
383        outer_card: Cardinality,
384        errors: &mut Vec<MergeError<'src>>,
385    ) -> ExprResult {
386        let capture_name = c.name().map(|t| token_src(&t, self.source)).unwrap_or("_");
387        let span = c.text_range();
388        let has_string_annotation = c
389            .type_annotation()
390            .and_then(|t| t.name())
391            .is_some_and(|n| n.text() == "string");
392
393        let Some(inner) = c.inner() else {
394            return ExprResult::node();
395        };
396
397        // Check if inner is a scope container (seq/alt)
398        let is_scope_container = matches!(inner, Expr::SeqExpr(_) | Expr::AltExpr(_));
399
400        if is_scope_container {
401            // Captured scope container: creates nested type
402            let nested_type = self.infer_captured_container(capture_name, &inner, errors);
403            let result = ExprResult::meaningful(nested_type);
404            let effective_card = outer_card.multiply(result.cardinality);
405            scope.add_field(capture_name, result.base_type, effective_card, span);
406            result
407        } else {
408            // Simple capture: just capture the result
409            let result = self.infer_expr(&inner, scope, outer_card, errors);
410            let base_type = if has_string_annotation {
411                TYPE_STR
412            } else {
413                result.base_type
414            };
415            let effective_card = outer_card.multiply(result.cardinality);
416            scope.add_field(capture_name, base_type, effective_card, span);
417            ExprResult::meaningful(base_type).with_cardinality(result.cardinality)
418        }
419    }
420
421    fn infer_captured_container(
422        &mut self,
423        _capture_name: &'src str,
424        inner: &Expr,
425        errors: &mut Vec<MergeError<'src>>,
426    ) -> TypeId {
427        match inner {
428            Expr::SeqExpr(s) => {
429                let mut nested_scope = ScopeInfo::default();
430                for child in s.children() {
431                    self.infer_expr(&child, &mut nested_scope, Cardinality::One, errors);
432                }
433                // Per ADR-0009 Payload Rule: 0 captures → Void
434                if nested_scope.is_empty() {
435                    return TYPE_VOID;
436                }
437                let type_name = self.generate_scope_name();
438                self.create_struct_type(type_name, &nested_scope)
439            }
440            Expr::AltExpr(a) => {
441                if a.kind() == AltKind::Tagged {
442                    // Captured tagged alternation → Enum
443                    let type_name = self.generate_scope_name();
444                    self.infer_tagged_alternation_as_enum(type_name, a, errors)
445                } else {
446                    // Captured untagged alternation → Struct with merged fields
447                    let mut nested_scope = ScopeInfo::default();
448                    self.infer_untagged_alternation(a, &mut nested_scope, Cardinality::One, errors);
449                    // Per ADR-0009 Payload Rule: 0 captures → Void
450                    if nested_scope.is_empty() {
451                        return TYPE_VOID;
452                    }
453                    let type_name = self.generate_scope_name();
454                    self.create_struct_type(type_name, &nested_scope)
455                }
456            }
457            _ => {
458                // Not a container - shouldn't reach here
459                TYPE_NODE
460            }
461        }
462    }
463
464    fn infer_quantified(
465        &mut self,
466        q: &ast::QuantifiedExpr,
467        scope: &mut ScopeInfo<'src>,
468        outer_card: Cardinality,
469        errors: &mut Vec<MergeError<'src>>,
470    ) -> ExprResult {
471        let Some(inner) = q.inner() else {
472            return ExprResult::node();
473        };
474
475        let quant_card = self.quantifier_cardinality(q);
476        let is_qis = self.qis_triggers.contains(q);
477
478        if is_qis {
479            // QIS: create implicit scope for multiple captures
480            let mut nested_scope = ScopeInfo::default();
481            self.infer_expr(&inner, &mut nested_scope, Cardinality::One, errors);
482
483            let element_type = if !nested_scope.fields.is_empty() {
484                let type_name = self.generate_scope_name();
485                self.create_struct_type(type_name, &nested_scope)
486            } else {
487                TYPE_NODE
488            };
489
490            // Wrap with array type - this is a meaningful result
491            let array_type = self.wrap_with_cardinality(element_type, quant_card);
492            ExprResult::meaningful(array_type)
493        } else {
494            // No QIS: captures propagate with multiplied cardinality
495            let combined_card = outer_card.multiply(quant_card);
496            let result = self.infer_expr(&inner, scope, combined_card, errors);
497            // Return result with quantifier's cardinality so captured quantifiers work correctly
498            ExprResult {
499                base_type: result.base_type,
500                cardinality: quant_card.multiply(result.cardinality),
501                is_meaningful: result.is_meaningful,
502            }
503        }
504    }
505
506    fn infer_sequence(
507        &mut self,
508        s: &ast::SeqExpr,
509        scope: &mut ScopeInfo<'src>,
510        outer_card: Cardinality,
511        errors: &mut Vec<MergeError<'src>>,
512    ) -> ExprResult {
513        // Uncaptured sequence: captures propagate to parent scope
514        let mut last_result = ExprResult::void();
515        for child in s.children() {
516            last_result = self.infer_expr(&child, scope, outer_card, errors);
517        }
518        last_result
519    }
520
    /// Infers an uncaptured alternation.
    ///
    /// Tagged and untagged alternations are treated alike here. Their branch captures
    /// are merged into `scope`, becoming optional where some branch lacks them. Only a
    /// captured tagged alternation, or one at a definition root, becomes an enum; those
    /// cases are handled in `infer_captured_container` and `infer_definition`.
    fn infer_alternation(
        &mut self,
        a: &ast::AltExpr,
        scope: &mut ScopeInfo<'src>,
        outer_card: Cardinality,
        errors: &mut Vec<MergeError<'src>>,
    ) -> ExprResult {
        // Uncaptured alternation (tagged or untagged): captures propagate with optionality
        self.infer_untagged_alternation(a, scope, outer_card, errors)
    }
531
532    fn infer_untagged_alternation(
533        &mut self,
534        a: &ast::AltExpr,
535        scope: &mut ScopeInfo<'src>,
536        outer_card: Cardinality,
537        errors: &mut Vec<MergeError<'src>>,
538    ) -> ExprResult {
539        let branches: Vec<_> = a.branches().collect();
540        let total_branches = branches.len();
541
542        if total_branches == 0 {
543            return ExprResult::void();
544        }
545
546        let mut merged_scope = ScopeInfo::default();
547
548        for branch in &branches {
549            let Some(body) = branch.body() else {
550                continue;
551            };
552            let mut branch_scope = ScopeInfo::default();
553            self.infer_expr(&body, &mut branch_scope, outer_card, errors);
554            errors.extend(merged_scope.merge_from(branch_scope));
555        }
556
557        // Apply optionality for fields not present in all branches
558        merged_scope.apply_optionality(total_branches);
559
560        // Merge into parent scope
561        errors.extend(scope.merge_from(merged_scope));
562
563        ExprResult::node()
564    }
565
566    fn infer_tagged_alternation_as_enum(
567        &mut self,
568        type_name: &'src str,
569        a: &ast::AltExpr,
570        errors: &mut Vec<MergeError<'src>>,
571    ) -> TypeId {
572        let mut variants = IndexMap::new();
573
574        for branch in a.branches() {
575            let tag = branch
576                .label()
577                .map(|t| token_src(&t, self.source))
578                .unwrap_or("_");
579            let Some(body) = branch.body() else {
580                variants.insert(tag, ScopeInfo::default());
581                continue;
582            };
583
584            let mut variant_scope = ScopeInfo::default();
585            self.infer_expr(&body, &mut variant_scope, Cardinality::One, errors);
586            variants.insert(tag, variant_scope);
587        }
588
589        self.create_enum_type_from_variants(type_name, &variants)
590    }
591
592    fn infer_named_node(
593        &mut self,
594        n: &ast::NamedNode,
595        scope: &mut ScopeInfo<'src>,
596        outer_card: Cardinality,
597        errors: &mut Vec<MergeError<'src>>,
598    ) -> ExprResult {
599        // Named nodes have children - recurse into them
600        for child in n.children() {
601            self.infer_expr(&child, scope, outer_card, errors);
602        }
603        ExprResult::node()
604    }
605
606    fn infer_field_expr(
607        &mut self,
608        f: &ast::FieldExpr,
609        scope: &mut ScopeInfo<'src>,
610        outer_card: Cardinality,
611        errors: &mut Vec<MergeError<'src>>,
612    ) -> ExprResult {
613        // Field constraint (name: expr) - just recurse
614        if let Some(value) = f.value() {
615            return self.infer_expr(&value, scope, outer_card, errors);
616        }
617        ExprResult::node()
618    }
619
620    fn infer_ref(&self, r: &ast::Ref) -> ExprResult {
621        // References are opaque - captures don't propagate from referenced definition.
622        // Return the type (for use when captured) but mark as not meaningful
623        // so uncaptured refs don't affect definition's result type.
624        let ref_name = r.name().map(|t| t.text().to_string());
625        if let Some(name) = ref_name
626            && let Some(&type_id) = self.definition_types.get(name.as_str())
627        {
628            return ExprResult::opaque(type_id);
629        }
630        ExprResult::node()
631    }
632
633    fn quantifier_cardinality(&self, q: &ast::QuantifiedExpr) -> Cardinality {
634        let Some(op) = q.operator() else {
635            return Cardinality::One;
636        };
637        use crate::parser::cst::SyntaxKind;
638        match op.kind() {
639            SyntaxKind::Star | SyntaxKind::StarQuestion => Cardinality::Star,
640            SyntaxKind::Plus | SyntaxKind::PlusQuestion => Cardinality::Plus,
641            SyntaxKind::Question | SyntaxKind::QuestionQuestion => Cardinality::Optional,
642            _ => Cardinality::One,
643        }
644    }
645
    /// Builds a name for an anonymous nested type: `<definition>Scope<N>`.
    ///
    /// `N` is the id the *next* allocation would receive. That is not necessarily the id of
    /// the type that ends up carrying this name, because member wrapper types may be
    /// allocated first.
    ///
    /// NOTE(review): the string is leaked via `Box::leak` to obtain a `&'src str`. Each call
    /// leaks one allocation for the remaining life of the process.
    /// NOTE(review): names are only unique if some id is allocated between two calls.
    /// A captured tagged alternation (`infer_captured_container`) takes its enum name
    /// before the variants are built. If nothing is allocated before the first
    /// multi-capture variant is named (e.g. variants made only of plain captures), that
    /// variant struct gets the same name as the enum. Consider a dedicated name counter.
    fn generate_scope_name(&self) -> &'src str {
        let name = format!("{}Scope{}", self.current_def_name, self.next_type_id);
        Box::leak(name.into_boxed_str())
    }
650
651    fn create_struct_type(&mut self, name: &'src str, scope: &ScopeInfo<'src>) -> TypeId {
652        let members: Vec<_> = scope
653            .fields
654            .iter()
655            .map(|(field_name, info)| {
656                let member_type = self.wrap_with_cardinality(info.base_type, info.cardinality);
657                InferredMember {
658                    name: field_name,
659                    ty: member_type,
660                }
661            })
662            .collect();
663
664        let type_id = self.alloc_type_id();
665
666        self.type_defs.push(InferredTypeDef {
667            kind: TypeKind::Record,
668            name: Some(name),
669            members,
670            inner_type: None,
671        });
672
673        type_id
674    }
675
676    fn create_enum_type_from_variants(
677        &mut self,
678        name: &'src str,
679        variants: &IndexMap<&'src str, ScopeInfo<'src>>,
680    ) -> TypeId {
681        let mut members = Vec::new();
682
683        for (tag, variant_scope) in variants {
684            let variant_type = if variant_scope.fields.is_empty() {
685                TYPE_VOID
686            } else if variant_scope.fields.len() == 1 {
687                // Single-capture variant: flatten (ADR-0007)
688                let (_, info) = variant_scope.fields.iter().next().unwrap();
689                self.wrap_with_cardinality(info.base_type, info.cardinality)
690            } else {
691                let variant_name = self.generate_scope_name();
692                self.create_struct_type(variant_name, variant_scope)
693            };
694            members.push(InferredMember {
695                name: tag,
696                ty: variant_type,
697            });
698        }
699
700        let type_id = self.alloc_type_id();
701
702        self.type_defs.push(InferredTypeDef {
703            kind: TypeKind::Enum,
704            name: Some(name),
705            members,
706            inner_type: None,
707        });
708
709        type_id
710    }
711
712    fn wrap_with_cardinality(&mut self, base: TypeId, card: Cardinality) -> TypeId {
713        match card {
714            Cardinality::One => base,
715            Cardinality::Optional => {
716                let type_id = self.alloc_type_id();
717                self.type_defs.push(InferredTypeDef {
718                    kind: TypeKind::Optional,
719                    name: None,
720                    members: Vec::new(),
721                    inner_type: Some(base),
722                });
723                type_id
724            }
725            Cardinality::Star => {
726                let type_id = self.alloc_type_id();
727                self.type_defs.push(InferredTypeDef {
728                    kind: TypeKind::ArrayStar,
729                    name: None,
730                    members: Vec::new(),
731                    inner_type: Some(base),
732                });
733                type_id
734            }
735            Cardinality::Plus => {
736                let type_id = self.alloc_type_id();
737                self.type_defs.push(InferredTypeDef {
738                    kind: TypeKind::ArrayPlus,
739                    name: None,
740                    members: Vec::new(),
741                    inner_type: Some(base),
742                });
743                type_id
744            }
745        }
746    }
747
748    fn report_merge_errors(&mut self, merge_errors: &[MergeError<'src>]) {
749        for err in merge_errors {
750            let types_str = err
751                .shapes
752                .iter()
753                .map(|s| s.to_description().to_string())
754                .collect::<Vec<_>>()
755                .join(" vs ");
756
757            let primary_span = err.spans.first().copied().unwrap_or_default();
758            let mut builder = self
759                .diagnostics
760                .report(DiagnosticKind::IncompatibleTypes, primary_span)
761                .message(types_str);
762
763            for span in err.spans.iter().skip(1) {
764                builder = builder.related_to("also captured here", *span);
765            }
766            builder
767                .hint(format!(
768                    "capture `{}` has incompatible types across branches",
769                    err.field
770                ))
771                .emit();
772
773            self.errors.push(UnificationError {
774                field: err.field,
775                definition: self.current_def_name,
776                types_found: err.shapes.iter().map(|s| s.to_description()).collect(),
777                spans: err.spans.clone(),
778            });
779        }
780    }
781}
782
783impl<'a> Query<'a> {
784    /// Run type inference on the query AST.
785    pub(super) fn infer_types(&mut self) {
786        // Collect QIS triggers upfront to avoid borrowing issues
787        let qis_triggers: HashSet<_> = self.qis_triggers.keys().cloned().collect();
788        let sorted = self.topological_sort_definitions_ast();
789
790        let mut ctx = InferenceContext::new(self.source, qis_triggers);
791
792        // Process definitions in dependency order
793        for (name, body) in &sorted {
794            let type_id = ctx.infer_definition(name, body);
795            ctx.definition_types.insert(name, type_id);
796        }
797
798        // Preserve symbol table order for entrypoints
799        for (name, _) in &sorted {
800            if let Some(&type_id) = ctx.definition_types.get(name) {
801                self.type_info.entrypoint_types.insert(*name, type_id);
802            }
803        }
804        self.type_info.type_defs = ctx.type_defs;
805        self.type_info.diagnostics = ctx.diagnostics;
806        self.type_info.errors = ctx.errors;
807    }
808
809    /// Topologically sort definitions for processing order.
810    fn topological_sort_definitions_ast(&self) -> Vec<(&'a str, ast::Expr)> {
811        use std::collections::{HashSet, VecDeque};
812
813        let definitions: Vec<_> = self
814            .symbol_table
815            .iter()
816            .map(|(&name, body)| (name, body.clone()))
817            .collect();
818        let def_names: HashSet<&str> = definitions.iter().map(|(name, _)| *name).collect();
819
820        // Build dependency graph from AST references
821        let mut deps: HashMap<&str, Vec<&str>> = HashMap::new();
822        for (name, body) in &definitions {
823            let refs = Self::collect_ast_references(body, &def_names);
824            deps.insert(name, refs);
825        }
826
827        // Kahn's algorithm
828        let mut in_degree: HashMap<&str, usize> = HashMap::new();
829        for (name, _) in &definitions {
830            in_degree.insert(name, 0);
831        }
832        for refs in deps.values() {
833            for &dep in refs {
834                *in_degree.entry(dep).or_insert(0) += 1;
835            }
836        }
837
838        let mut zero_degree: Vec<&str> = in_degree
839            .iter()
840            .filter(|(_, deg)| **deg == 0)
841            .map(|(&name, _)| name)
842            .collect();
843        zero_degree.sort();
844        let mut queue: VecDeque<&str> = zero_degree.into_iter().collect();
845
846        let mut sorted_names = Vec::new();
847        while let Some(name) = queue.pop_front() {
848            sorted_names.push(name);
849            if let Some(refs) = deps.get(name) {
850                for &dep in refs {
851                    if let Some(deg) = in_degree.get_mut(dep) {
852                        *deg = deg.saturating_sub(1);
853                        if *deg == 0 {
854                            queue.push_back(dep);
855                        }
856                    }
857                }
858            }
859        }
860
861        // Reverse so dependencies come first
862        sorted_names.reverse();
863
864        // Add any remaining (cyclic) definitions
865        for (name, _) in &definitions {
866            if !sorted_names.contains(name) {
867                sorted_names.push(name);
868            }
869        }
870
871        // Build result with bodies
872        sorted_names
873            .into_iter()
874            .filter_map(|name| self.symbol_table.get(name).map(|body| (name, body.clone())))
875            .collect()
876    }
877
    /// Collect references from an AST expression.
    ///
    /// Returns the names of known definitions (members of `def_names`) referenced anywhere
    /// in `expr`. Each name appears once, in first-seen order. The returned strings are
    /// the ones stored in `def_names`, so they carry its lifetime rather than the AST's.
    fn collect_ast_references<'b>(expr: &Expr, def_names: &HashSet<&'b str>) -> Vec<&'b str> {
        let mut refs = Vec::new();
        Self::collect_ast_references_impl(expr, def_names, &mut refs);
        refs
    }
884
885    fn collect_ast_references_impl<'b>(
886        expr: &Expr,
887        def_names: &HashSet<&'b str>,
888        refs: &mut Vec<&'b str>,
889    ) {
890        match expr {
891            Expr::Ref(r) => {
892                if let Some(name_token) = r.name() {
893                    let name = name_token.text();
894                    if def_names.contains(name) && !refs.contains(&name) {
895                        // Find the actual &'b str from the set
896                        if let Some(&found) = def_names.iter().find(|&&n| n == name) {
897                            refs.push(found);
898                        }
899                    }
900                }
901            }
902            _ => {
903                for child in expr.children() {
904                    Self::collect_ast_references_impl(&child, def_names, refs);
905                }
906            }
907        }
908    }
909}