graphcal_compiler/tir/dim_check/
helpers.rs1use std::sync::Arc;
2
3use miette::NamedSource;
4
5use crate::registry::error::GraphcalError;
6use crate::registry::types::{Registry, TypeDef};
7use crate::syntax::dimension::Dimension;
8
9use super::{DeclaredType, InferredIndex, InferredStructType, InferredType};
10use crate::tir::typed::{ResolvedIndex, ResolvedTypeExpr};
11
12pub(super) fn is_bool_type(ty: &InferredType) -> bool {
13 match ty {
14 InferredType::Bool => true,
15 InferredType::Indexed { element, .. } => is_bool_type(element),
16 _ => false,
17 }
18}
19
20pub(super) fn types_match(declared: &DeclaredType, inferred: &InferredType) -> bool {
27 match (declared, inferred) {
28 (DeclaredType::Scalar(d), InferredType::Scalar(i)) => d == i,
29 (DeclaredType::Bool, InferredType::Bool) => true,
30 (DeclaredType::Int, inferred) if inferred.is_int_like() => true,
31 (DeclaredType::Datetime(d), InferredType::Datetime(i)) => d == i,
32 (DeclaredType::IndexArg(d), InferredType::NamedIndex(i)) => i.matches_ref(d),
33 (DeclaredType::Struct(d, d_args), InferredType::Struct(i, i_args)) => {
34 i.matches_ref(d)
35 && d_args.len() == i_args.len()
36 && d_args
37 .iter()
38 .zip(i_args)
39 .all(|(da, ia)| types_match(da, ia))
40 }
41 (
42 DeclaredType::Indexed {
43 element: d_elem,
44 index: d_idx,
45 },
46 InferredType::Indexed {
47 element: i_elem,
48 index: i_idx,
49 },
50 ) => i_idx.matches_ref(d_idx) && types_match(d_elem, i_elem),
51 _ => false,
52 }
53}
54
55pub(super) fn resolved_type_matches_inferred(
58 resolved: &ResolvedTypeExpr,
59 inferred: &InferredType,
60) -> bool {
61 match (resolved, inferred) {
62 (ResolvedTypeExpr::Dimensionless, InferredType::Scalar(d)) => d.is_dimensionless(),
63 (ResolvedTypeExpr::Bool, InferredType::Bool) => true,
64 (ResolvedTypeExpr::Int, inferred) => inferred.is_int_like(),
65 (ResolvedTypeExpr::Datetime(expected), InferredType::Datetime(actual)) => {
66 expected == actual
67 }
68 (ResolvedTypeExpr::Scalar(expected), InferredType::Scalar(actual)) => expected == actual,
69 (ResolvedTypeExpr::IndexArg(expected), InferredType::NamedIndex(actual)) => {
70 resolved_index_matches_inferred(expected, actual)
71 }
72 (ResolvedTypeExpr::Struct(expected, _), InferredType::Struct(actual, args)) => {
73 actual.matches_resolved(expected) && args.is_empty()
74 }
75 (
76 ResolvedTypeExpr::GenericStruct {
77 name, type_args, ..
78 },
79 InferredType::Struct(actual, actual_args),
80 ) => {
81 actual.matches_resolved(name)
82 && type_args.len() == actual_args.len()
83 && type_args
84 .iter()
85 .zip(actual_args)
86 .all(|(expected, actual)| resolved_type_matches_inferred(expected, actual))
87 }
88 (ResolvedTypeExpr::Indexed { base, indexes }, _) => {
89 resolved_indexed_type_matches_inferred(base, indexes, inferred)
90 }
91 _ => false,
92 }
93}
94
95fn resolved_indexed_type_matches_inferred(
96 base: &ResolvedTypeExpr,
97 indexes: &[ResolvedIndex],
98 inferred: &InferredType,
99) -> bool {
100 let mut current = inferred;
101 for index in indexes {
102 let InferredType::Indexed {
103 element,
104 index: actual,
105 } = current
106 else {
107 return false;
108 };
109 if !resolved_index_matches_inferred(index, actual) {
110 return false;
111 }
112 current = element;
113 }
114 resolved_type_matches_inferred(base, current)
115}
116
117fn resolved_index_matches_inferred(index: &ResolvedIndex, actual: &InferredIndex) -> bool {
118 match index {
119 ResolvedIndex::Concrete(expected, _) => actual.matches_resolved(expected),
120 ResolvedIndex::NatExpr(form, _) => actual
121 .nat_range_form()
122 .is_some_and(|actual_form| actual_form == *form),
123 ResolvedIndex::GenericParam(_, _) => false,
131 }
132}
133
134pub(super) fn format_declared_type(dt: &DeclaredType, registry: &Registry) -> String {
136 dt.format(®istry.dimensions)
137}
138
139pub(super) fn struct_type_def_for_inferred<'a>(
144 ty: &InferredStructType,
145 dag: Option<&'a crate::tir::typed::DagTIR>,
146 registry: &'a Registry,
147) -> Option<&'a TypeDef> {
148 dag.map(|dag| &dag.semantic.type_defs)
149 .and_then(|defs| defs.struct_types.get(ty.resolved()))
150 .or_else(|| registry.types.get_type(ty.name().as_str()))
151}
152
153#[must_use]
155pub fn format_inferred_type(it: &InferredType, registry: &Registry) -> String {
156 if let InferredType::Fin(bound) = it {
157 return format!("Fin({})", bound.format());
158 }
159 DeclaredType::from(it).format(®istry.dimensions)
160}
161
162impl From<&InferredType> for DeclaredType {
163 fn from(it: &InferredType) -> Self {
164 match it {
165 InferredType::Scalar(d) => Self::Scalar(d.clone()),
166 InferredType::Bool => Self::Bool,
167 InferredType::Int | InferredType::Fin(_) => Self::Int,
168 InferredType::Datetime(scale) => Self::Datetime(*scale),
169 InferredType::NamedIndex(index) => Self::IndexArg(index.type_ref().clone()),
170 InferredType::Struct(n, args) => {
171 Self::Struct(n.type_ref().clone(), args.iter().map(Self::from).collect())
172 }
173 InferredType::Indexed { element, index } => Self::Indexed {
174 element: Box::new(Self::from(element.as_ref())),
175 index: index.type_ref().clone(),
176 },
177 }
178 }
179}
180
181impl From<&DeclaredType> for InferredType {
182 fn from(dt: &DeclaredType) -> Self {
183 match dt {
184 DeclaredType::Scalar(d) => Self::Scalar(d.clone()),
185 DeclaredType::Bool => Self::Bool,
186 DeclaredType::Int => Self::Int,
187 DeclaredType::Datetime(scale) => Self::Datetime(*scale),
188 DeclaredType::IndexArg(index) => {
189 Self::NamedIndex(InferredIndex::from_ref(index.clone()))
190 }
191 DeclaredType::Struct(n, args) => Self::Struct(
192 InferredStructType::from_ref(n.clone()),
193 args.iter().map(Self::from).collect(),
194 ),
195 DeclaredType::Indexed { element, index } => Self::Indexed {
196 element: Box::new(Self::from(element.as_ref())),
197 index: InferredIndex::from_ref(index.clone()),
198 },
199 }
200 }
201}
202
203pub fn expect_scalar(
204 inferred: &InferredType,
205 registry: &Registry,
206 src: &NamedSource<Arc<String>>,
207 span: crate::syntax::span::Span,
208) -> Result<Dimension, GraphcalError> {
209 let found_kind = match inferred {
210 InferredType::Scalar(d) => return Ok(d.clone()),
211 InferredType::Bool => "a Bool value",
212 InferredType::Int | InferredType::Fin(_) => "an Int value",
213 InferredType::Datetime(_) => "a Datetime value",
214 InferredType::NamedIndex(_) => "a named-index loop variable",
215 InferredType::Struct(..) => "a struct",
216 InferredType::Indexed { .. } => "an indexed value",
217 };
218 Err(GraphcalError::DimensionMismatch {
219 expected: "scalar dimension".to_string(),
220 found: format_inferred_type(inferred, registry),
221 help: format!("expected a scalar value, not {found_kind}"),
222 src: src.clone(),
223 span: span.into(),
224 })
225}
226
227pub(super) fn cartesian_product<T: Clone + Eq + std::hash::Hash>(
229 axes: &[Vec<T>],
230 current: &mut Vec<T>,
231 result: &mut std::collections::HashSet<Vec<T>>,
232) {
233 if current.len() == axes.len() {
234 result.insert(current.clone());
235 return;
236 }
237 let axis_idx = current.len();
238 for variant in &axes[axis_idx] {
239 current.push(variant.clone());
240 cartesian_product(axes, current, result);
241 current.pop();
242 }
243}