1use ast::ExpressionId;
2use claw_ast as ast;
3
4use crate::types::{ResolvedType, RESOLVED_BOOL};
5use crate::{FunctionResolver, ItemId, ResolverError};
6
7pub(crate) trait ResolveExpression {
8 fn setup_resolve(
16 &self,
17 expression: ExpressionId,
18 resolver: &mut FunctionResolver,
19 ) -> Result<(), ResolverError> {
20 _ = (expression, resolver);
21 Ok(())
22 }
23
24 fn on_resolved(
27 &self,
28 rtype: ResolvedType,
29 expression: ExpressionId,
30 resolver: &mut FunctionResolver,
31 ) -> Result<(), ResolverError> {
32 _ = (rtype, expression, resolver);
33 Ok(())
34 }
35
36 fn on_child_resolved(
39 &self,
40 rtype: ResolvedType,
41 expression: ExpressionId,
42 resolver: &mut FunctionResolver,
43 ) -> Result<(), ResolverError> {
44 _ = (rtype, expression, resolver);
45 Ok(())
46 }
47}
48
49macro_rules! gen_resolve_expression {
50 ([$( $expr_type:ident ),*]) => {
51 impl ResolveExpression for ast::Expression {
52 fn setup_resolve(
53 &self,
54 expression: ExpressionId,
55 resolver: &mut FunctionResolver,
56 ) -> Result<(), ResolverError> {
57 match self {
58 $(ast::Expression::$expr_type(inner) => {
59 let inner: &dyn ResolveExpression = inner;
60 inner.setup_resolve(expression, resolver)
61 },)*
62 }
63 }
64
65 fn on_resolved(&self,
66 rtype: ResolvedType,
67 expression: ExpressionId,
68 resolver: &mut FunctionResolver,
69 ) -> Result<(), ResolverError> {
70 match self {
71 $(ast::Expression::$expr_type(inner) => inner.on_resolved(rtype, expression, resolver),)*
72 }
73 }
74
75 fn on_child_resolved(&self,
76 rtype: ResolvedType,
77 expression: ExpressionId,
78 resolver: &mut FunctionResolver,
79 ) -> Result<(), ResolverError> {
80 match self {
81 $(ast::Expression::$expr_type(inner) => inner.on_child_resolved(rtype, expression, resolver),)*
82 }
83 }
84 }
85 }
86}
87
88gen_resolve_expression!([Identifier, Literal, Enum, Call, Unary, Binary]);
89
90impl ResolveExpression for ast::Identifier {
91 fn setup_resolve(
92 &self,
93 expression: ExpressionId,
94 resolver: &mut FunctionResolver,
95 ) -> Result<(), ResolverError> {
96 let item = resolver.use_name(self.ident)?;
97 match item {
98 ItemId::Global(global) => {
99 let global = resolver.component.globals.get(global).unwrap();
100 resolver.set_expr_type(expression, ResolvedType::Defined(global.type_id));
101 }
102 ItemId::Param(param) => {
103 let param_type = *resolver.params.get(param).unwrap();
104 resolver.set_expr_type(expression, ResolvedType::Defined(param_type));
105 }
106 ItemId::Local(local) => resolver.use_local(local, expression),
107 _ => {}
108 }
109 Ok(())
110 }
111
112 fn on_resolved(
113 &self,
114 rtype: ResolvedType,
115 _expression: ExpressionId,
116 resolver: &mut FunctionResolver,
117 ) -> Result<(), ResolverError> {
118 let item = resolver.lookup_name(self.ident)?;
119 match item {
120 ItemId::Local(local) => resolver.set_local_type(local, rtype),
121 _ => {}
122 }
123 Ok(())
124 }
125}
126
127impl ResolveExpression for ast::Literal {
128 fn setup_resolve(
129 &self,
130 expression: ExpressionId,
131 resolver: &mut FunctionResolver,
132 ) -> Result<(), ResolverError> {
133 match self {
134 ast::Literal::String(_) => {
135 resolver.set_expr_type(
136 expression,
137 ResolvedType::Primitive(ast::PrimitiveType::String),
138 );
139 }
140 _ => {}
141 }
142 Ok(())
143 }
144}
145
146impl ResolveExpression for ast::EnumLiteral {
147 fn setup_resolve(
148 &self,
149 expression: ExpressionId,
150 resolver: &mut FunctionResolver,
151 ) -> Result<(), ResolverError> {
152 let item = resolver.use_name(self.enum_name)?;
153 match item {
154 ItemId::Type(rtype) => {
155 resolver.set_expr_type(expression, rtype);
156 }
157 _ => panic!("Can only use literals for enums"),
158 };
159 Ok(())
160 }
161}
162
163impl ResolveExpression for ast::Call {
164 fn setup_resolve(
165 &self,
166 expression: ExpressionId,
167 resolver: &mut FunctionResolver,
168 ) -> Result<(), ResolverError> {
169 let item = resolver.use_name(self.ident)?;
170 let (params, results): (Vec<_>, _) = match item {
171 ItemId::ImportFunc(import_func) => {
172 let import_func = &resolver.imports.funcs[import_func];
173 let params = import_func.params.iter().map(|(_name, rtype)| *rtype);
174 let results = import_func.results.unwrap();
175 (params.collect(), results)
176 }
177 ItemId::Function(func) => {
178 let func = &resolver.component.functions[func];
179 let params = func
180 .params
181 .iter()
182 .map(|(_name, type_id)| ResolvedType::Defined(*type_id));
183 let results = ResolvedType::Defined(*func.results.as_ref().unwrap());
184 (params.collect(), results)
185 }
186 _ => panic!("Can only call functions"),
187 };
188 assert_eq!(params.len(), self.args.len());
189 for (arg, rtype) in self.args.iter().copied().zip(params.into_iter()) {
190 resolver.setup_child_expression(expression, arg)?;
191 resolver.set_expr_type(arg, rtype);
192 }
193
194 resolver.set_expr_type(expression, results);
195
196 Ok(())
197 }
198}
199
200impl ResolveExpression for ast::UnaryExpression {
201 fn setup_resolve(
202 &self,
203 expression: ExpressionId,
204 resolver: &mut FunctionResolver,
205 ) -> Result<(), ResolverError> {
206 resolver.setup_child_expression(expression, self.inner)
207 }
208
209 fn on_resolved(
210 &self,
211 rtype: ResolvedType,
212 _expression: ExpressionId,
213 resolver: &mut FunctionResolver,
214 ) -> Result<(), ResolverError> {
215 resolver.set_expr_type(self.inner, rtype);
216 Ok(())
217 }
218
219 fn on_child_resolved(
220 &self,
221 rtype: ResolvedType,
222 expression: ExpressionId,
223 resolver: &mut FunctionResolver,
224 ) -> Result<(), ResolverError> {
225 resolver.set_expr_type(expression, rtype);
226 Ok(())
227 }
228}
229
230impl ResolveExpression for ast::BinaryExpression {
233 fn setup_resolve(
234 &self,
235 expression: ExpressionId,
236 resolver: &mut FunctionResolver,
237 ) -> Result<(), ResolverError> {
238 if self.is_relation() {
239 resolver.set_expr_type(expression, RESOLVED_BOOL);
240 }
241 resolver.setup_child_expression(expression, self.left)?;
242 resolver.setup_child_expression(expression, self.right)?;
243 Ok(())
244 }
245
246 fn on_resolved(
247 &self,
248 rtype: ResolvedType,
249 _expression: ExpressionId,
250 resolver: &mut FunctionResolver,
251 ) -> Result<(), ResolverError> {
252 if !self.is_relation() {
253 resolver.set_expr_type(self.left, rtype);
254 resolver.set_expr_type(self.right, rtype);
255 }
256 Ok(())
257 }
258
259 fn on_child_resolved(
260 &self,
261 rtype: ResolvedType,
262 expression: ExpressionId,
263 resolver: &mut FunctionResolver,
264 ) -> Result<(), ResolverError> {
265 if !self.is_relation() {
266 resolver.set_expr_type(expression, rtype);
267 }
268
269 let left = resolver.expression_types.get(&self.left).copied();
270 let right = resolver.expression_types.get(&self.right).copied();
271
272 match (left, right) {
273 (Some(_left), Some(_right)) => {
274 }
276 (Some(left), None) => {
277 resolver.set_expr_type(self.right, left);
278 }
279 (None, Some(right)) => {
280 resolver.set_expr_type(self.left, right);
281 }
282 (None, None) => {
283 unreachable!("If a child has been resolved, at least one child shouldn't be None")
285 }
286 }
287
288 Ok(())
289 }
290}