Skip to main content

graphcal_compiler/tir/dim_check/
helpers.rs

1use std::sync::Arc;
2
3use miette::NamedSource;
4
5use crate::registry::error::GraphcalError;
6use crate::registry::types::{Registry, TypeDef};
7use crate::syntax::dimension::Dimension;
8
9use super::{DeclaredType, InferredIndex, InferredStructType, InferredType};
10use crate::tir::typed::{ResolvedIndex, ResolvedTypeExpr};
11
12pub(super) fn is_bool_type(ty: &InferredType) -> bool {
13    match ty {
14        InferredType::Bool => true,
15        InferredType::Indexed { element, .. } => is_bool_type(element),
16        _ => false,
17    }
18}
19
20/// Check if a declared type matches an inferred type.
21///
22/// Under the n-variant-union model, the inferred type of a constructor
23/// expression is *already* the owning union — there is no per-variant
24/// type and therefore no widening/subtyping at the type level. Struct
25/// equality is by name and type-argument list only.
26pub(super) fn types_match(declared: &DeclaredType, inferred: &InferredType) -> bool {
27    match (declared, inferred) {
28        (DeclaredType::Scalar(d), InferredType::Scalar(i)) => d == i,
29        (DeclaredType::Bool, InferredType::Bool) => true,
30        (DeclaredType::Int, inferred) if inferred.is_int_like() => true,
31        (DeclaredType::Datetime(d), InferredType::Datetime(i)) => d == i,
32        (DeclaredType::IndexArg(d), InferredType::NamedIndex(i)) => i.matches_ref(d),
33        (DeclaredType::Struct(d, d_args), InferredType::Struct(i, i_args)) => {
34            i.matches_ref(d)
35                && d_args.len() == i_args.len()
36                && d_args
37                    .iter()
38                    .zip(i_args)
39                    .all(|(da, ia)| types_match(da, ia))
40        }
41        (
42            DeclaredType::Indexed {
43                element: d_elem,
44                index: d_idx,
45            },
46            InferredType::Indexed {
47                element: i_elem,
48                index: i_idx,
49            },
50        ) => i_idx.matches_ref(d_idx) && types_match(d_elem, i_elem),
51        _ => false,
52    }
53}
54
55/// Check if a resolved declaration type matches an inferred expression type,
56/// preserving canonical index identity when both sides carry it.
57pub(super) fn resolved_type_matches_inferred(
58    resolved: &ResolvedTypeExpr,
59    inferred: &InferredType,
60) -> bool {
61    match (resolved, inferred) {
62        (ResolvedTypeExpr::Dimensionless, InferredType::Scalar(d)) => d.is_dimensionless(),
63        (ResolvedTypeExpr::Bool, InferredType::Bool) => true,
64        (ResolvedTypeExpr::Int, inferred) => inferred.is_int_like(),
65        (ResolvedTypeExpr::Datetime(expected), InferredType::Datetime(actual)) => {
66            expected == actual
67        }
68        (ResolvedTypeExpr::Scalar(expected), InferredType::Scalar(actual)) => expected == actual,
69        (ResolvedTypeExpr::IndexArg(expected), InferredType::NamedIndex(actual)) => {
70            resolved_index_matches_inferred(expected, actual)
71        }
72        (ResolvedTypeExpr::Struct(expected, _), InferredType::Struct(actual, args)) => {
73            actual.matches_resolved(expected) && args.is_empty()
74        }
75        (
76            ResolvedTypeExpr::GenericStruct {
77                name, type_args, ..
78            },
79            InferredType::Struct(actual, actual_args),
80        ) => {
81            actual.matches_resolved(name)
82                && type_args.len() == actual_args.len()
83                && type_args
84                    .iter()
85                    .zip(actual_args)
86                    .all(|(expected, actual)| resolved_type_matches_inferred(expected, actual))
87        }
88        (ResolvedTypeExpr::Indexed { base, indexes }, _) => {
89            resolved_indexed_type_matches_inferred(base, indexes, inferred)
90        }
91        _ => false,
92    }
93}
94
95fn resolved_indexed_type_matches_inferred(
96    base: &ResolvedTypeExpr,
97    indexes: &[ResolvedIndex],
98    inferred: &InferredType,
99) -> bool {
100    let mut current = inferred;
101    for index in indexes {
102        let InferredType::Indexed {
103            element,
104            index: actual,
105        } = current
106        else {
107            return false;
108        };
109        if !resolved_index_matches_inferred(index, actual) {
110            return false;
111        }
112        current = element;
113    }
114    resolved_type_matches_inferred(base, current)
115}
116
117fn resolved_index_matches_inferred(index: &ResolvedIndex, actual: &InferredIndex) -> bool {
118    match index {
119        ResolvedIndex::Concrete(expected, _) => actual.matches_resolved(expected),
120        ResolvedIndex::NatExpr(form, _) => actual
121            .nat_range_form()
122            .is_some_and(|actual_form| actual_form == *form),
123        // An unbound generic index parameter never reaches this comparison:
124        // DAG declaration types and inline-DAG param types resolve with no
125        // generic params in scope, and HIR inference only constructs
126        // `InferredIndex` from concrete (resolved or Nat-range) identities —
127        // the syntax engine's leaf-name fallback that could fabricate a
128        // generic-named index is gone (#765). No display-name comparison can
129        // therefore be meaningful here.
130        ResolvedIndex::GenericParam(_, _) => false,
131    }
132}
133
134/// Format a declared type for display in diagnostics.
135pub(super) fn format_declared_type(dt: &DeclaredType, registry: &Registry) -> String {
136    dt.format(&registry.dimensions)
137}
138
139/// Look up the definition for an inferred struct identity.
140///
141/// Prefer canonical semantic TIR type definitions, then consult the leaf-keyed
142/// registry for boundary-created synthetic owners.
143pub(super) fn struct_type_def_for_inferred<'a>(
144    ty: &InferredStructType,
145    dag: Option<&'a crate::tir::typed::DagTIR>,
146    registry: &'a Registry,
147) -> Option<&'a TypeDef> {
148    dag.map(|dag| &dag.semantic.type_defs)
149        .and_then(|defs| defs.struct_types.get(ty.resolved()))
150        .or_else(|| registry.types.get_type(ty.name().as_str()))
151}
152
153/// Format an inferred type for display in diagnostics.
154#[must_use]
155pub fn format_inferred_type(it: &InferredType, registry: &Registry) -> String {
156    if let InferredType::Fin(bound) = it {
157        return format!("Fin({})", bound.format());
158    }
159    DeclaredType::from(it).format(&registry.dimensions)
160}
161
162impl From<&InferredType> for DeclaredType {
163    fn from(it: &InferredType) -> Self {
164        match it {
165            InferredType::Scalar(d) => Self::Scalar(d.clone()),
166            InferredType::Bool => Self::Bool,
167            InferredType::Int | InferredType::Fin(_) => Self::Int,
168            InferredType::Datetime(scale) => Self::Datetime(*scale),
169            InferredType::NamedIndex(index) => Self::IndexArg(index.type_ref().clone()),
170            InferredType::Struct(n, args) => {
171                Self::Struct(n.type_ref().clone(), args.iter().map(Self::from).collect())
172            }
173            InferredType::Indexed { element, index } => Self::Indexed {
174                element: Box::new(Self::from(element.as_ref())),
175                index: index.type_ref().clone(),
176            },
177        }
178    }
179}
180
181impl From<&DeclaredType> for InferredType {
182    fn from(dt: &DeclaredType) -> Self {
183        match dt {
184            DeclaredType::Scalar(d) => Self::Scalar(d.clone()),
185            DeclaredType::Bool => Self::Bool,
186            DeclaredType::Int => Self::Int,
187            DeclaredType::Datetime(scale) => Self::Datetime(*scale),
188            DeclaredType::IndexArg(index) => {
189                Self::NamedIndex(InferredIndex::from_ref(index.clone()))
190            }
191            DeclaredType::Struct(n, args) => Self::Struct(
192                InferredStructType::from_ref(n.clone()),
193                args.iter().map(Self::from).collect(),
194            ),
195            DeclaredType::Indexed { element, index } => Self::Indexed {
196                element: Box::new(Self::from(element.as_ref())),
197                index: InferredIndex::from_ref(index.clone()),
198            },
199        }
200    }
201}
202
203pub fn expect_scalar(
204    inferred: &InferredType,
205    registry: &Registry,
206    src: &NamedSource<Arc<String>>,
207    span: crate::syntax::span::Span,
208) -> Result<Dimension, GraphcalError> {
209    let found_kind = match inferred {
210        InferredType::Scalar(d) => return Ok(d.clone()),
211        InferredType::Bool => "a Bool value",
212        InferredType::Int | InferredType::Fin(_) => "an Int value",
213        InferredType::Datetime(_) => "a Datetime value",
214        InferredType::NamedIndex(_) => "a named-index loop variable",
215        InferredType::Struct(..) => "a struct",
216        InferredType::Indexed { .. } => "an indexed value",
217    };
218    Err(GraphcalError::DimensionMismatch {
219        expected: "scalar dimension".to_string(),
220        found: format_inferred_type(inferred, registry),
221        help: format!("expected a scalar value, not {found_kind}"),
222        src: src.clone(),
223        span: span.into(),
224    })
225}
226
227/// Build the Cartesian product of variant-key slices across multiple axes.
228pub(super) fn cartesian_product<T: Clone + Eq + std::hash::Hash>(
229    axes: &[Vec<T>],
230    current: &mut Vec<T>,
231    result: &mut std::collections::HashSet<Vec<T>>,
232) {
233    if current.len() == axes.len() {
234        result.insert(current.clone());
235        return;
236    }
237    let axis_idx = current.len();
238    for variant in &axes[axis_idx] {
239        current.push(variant.clone());
240        cartesian_product(axes, current, result);
241        current.pop();
242    }
243}