claw_resolver/
expression.rs

1use ast::ExpressionId;
2use claw_ast as ast;
3
4use crate::types::{ResolvedType, RESOLVED_BOOL};
5use crate::{FunctionResolver, ItemId, ResolverError};
6
7pub(crate) trait ResolveExpression {
8    /// Setup must
9    /// * Call `define_name` when introducing new names
10    /// * Call `use_name` when using a name
11    /// * Call `setup_child` on each expression that is a child of this one.
12    ///
13    /// Setup may
14    /// * Call `set_implied_type` if the type of an expression is known.
15    fn setup_resolve(
16        &self,
17        expression: ExpressionId,
18        resolver: &mut FunctionResolver,
19    ) -> Result<(), ResolverError> {
20        _ = (expression, resolver);
21        Ok(())
22    }
23
24    /// In a successful type resolution, this function will be called
25    /// exactly once when the type of this expression is known.
26    fn on_resolved(
27        &self,
28        rtype: ResolvedType,
29        expression: ExpressionId,
30        resolver: &mut FunctionResolver,
31    ) -> Result<(), ResolverError> {
32        _ = (rtype, expression, resolver);
33        Ok(())
34    }
35
36    /// In a successful type resolution, this function will be called
37    /// once for each child of this expression.
38    fn on_child_resolved(
39        &self,
40        rtype: ResolvedType,
41        expression: ExpressionId,
42        resolver: &mut FunctionResolver,
43    ) -> Result<(), ResolverError> {
44        _ = (rtype, expression, resolver);
45        Ok(())
46    }
47}
48
49macro_rules! gen_resolve_expression {
50    ([$( $expr_type:ident ),*]) => {
51        impl ResolveExpression for ast::Expression {
52            fn setup_resolve(
53                &self,
54                expression: ExpressionId,
55                resolver: &mut FunctionResolver,
56            ) -> Result<(), ResolverError> {
57                match self {
58                    $(ast::Expression::$expr_type(inner) => {
59                        let inner: &dyn ResolveExpression = inner;
60                        inner.setup_resolve(expression, resolver)
61                    },)*
62                }
63            }
64
65            fn on_resolved(&self,
66                rtype: ResolvedType,
67                expression: ExpressionId,
68                resolver: &mut FunctionResolver,
69            ) -> Result<(), ResolverError> {
70                match self {
71                    $(ast::Expression::$expr_type(inner) => inner.on_resolved(rtype, expression, resolver),)*
72                }
73            }
74
75            fn on_child_resolved(&self,
76                rtype: ResolvedType,
77                expression: ExpressionId,
78                resolver: &mut FunctionResolver,
79            ) -> Result<(), ResolverError> {
80                match self {
81                    $(ast::Expression::$expr_type(inner) => inner.on_child_resolved(rtype, expression, resolver),)*
82                }
83            }
84        }
85    }
86}
87
88gen_resolve_expression!([Identifier, Literal, Enum, Call, Unary, Binary]);
89
90impl ResolveExpression for ast::Identifier {
91    fn setup_resolve(
92        &self,
93        expression: ExpressionId,
94        resolver: &mut FunctionResolver,
95    ) -> Result<(), ResolverError> {
96        let item = resolver.use_name(self.ident)?;
97        match item {
98            ItemId::Global(global) => {
99                let global = resolver.component.globals.get(global).unwrap();
100                resolver.set_expr_type(expression, ResolvedType::Defined(global.type_id));
101            }
102            ItemId::Param(param) => {
103                let param_type = *resolver.params.get(param).unwrap();
104                resolver.set_expr_type(expression, ResolvedType::Defined(param_type));
105            }
106            ItemId::Local(local) => resolver.use_local(local, expression),
107            _ => {}
108        }
109        Ok(())
110    }
111
112    fn on_resolved(
113        &self,
114        rtype: ResolvedType,
115        _expression: ExpressionId,
116        resolver: &mut FunctionResolver,
117    ) -> Result<(), ResolverError> {
118        let item = resolver.lookup_name(self.ident)?;
119        match item {
120            ItemId::Local(local) => resolver.set_local_type(local, rtype),
121            _ => {}
122        }
123        Ok(())
124    }
125}
126
127impl ResolveExpression for ast::Literal {
128    fn setup_resolve(
129        &self,
130        expression: ExpressionId,
131        resolver: &mut FunctionResolver,
132    ) -> Result<(), ResolverError> {
133        match self {
134            ast::Literal::String(_) => {
135                resolver.set_expr_type(
136                    expression,
137                    ResolvedType::Primitive(ast::PrimitiveType::String),
138                );
139            }
140            _ => {}
141        }
142        Ok(())
143    }
144}
145
146impl ResolveExpression for ast::EnumLiteral {
147    fn setup_resolve(
148        &self,
149        expression: ExpressionId,
150        resolver: &mut FunctionResolver,
151    ) -> Result<(), ResolverError> {
152        let item = resolver.use_name(self.enum_name)?;
153        match item {
154            ItemId::Type(rtype) => {
155                resolver.set_expr_type(expression, rtype);
156            }
157            _ => panic!("Can only use literals for enums"),
158        };
159        Ok(())
160    }
161}
162
163impl ResolveExpression for ast::Call {
164    fn setup_resolve(
165        &self,
166        expression: ExpressionId,
167        resolver: &mut FunctionResolver,
168    ) -> Result<(), ResolverError> {
169        let item = resolver.use_name(self.ident)?;
170        let (params, results): (Vec<_>, _) = match item {
171            ItemId::ImportFunc(import_func) => {
172                let import_func = &resolver.imports.funcs[import_func];
173                let params = import_func.params.iter().map(|(_name, rtype)| *rtype);
174                let results = import_func.results.unwrap();
175                (params.collect(), results)
176            }
177            ItemId::Function(func) => {
178                let func = &resolver.component.functions[func];
179                let params = func
180                    .params
181                    .iter()
182                    .map(|(_name, type_id)| ResolvedType::Defined(*type_id));
183                let results = ResolvedType::Defined(*func.results.as_ref().unwrap());
184                (params.collect(), results)
185            }
186            _ => panic!("Can only call functions"),
187        };
188        assert_eq!(params.len(), self.args.len());
189        for (arg, rtype) in self.args.iter().copied().zip(params.into_iter()) {
190            resolver.setup_child_expression(expression, arg)?;
191            resolver.set_expr_type(arg, rtype);
192        }
193
194        resolver.set_expr_type(expression, results);
195
196        Ok(())
197    }
198}
199
200impl ResolveExpression for ast::UnaryExpression {
201    fn setup_resolve(
202        &self,
203        expression: ExpressionId,
204        resolver: &mut FunctionResolver,
205    ) -> Result<(), ResolverError> {
206        resolver.setup_child_expression(expression, self.inner)
207    }
208
209    fn on_resolved(
210        &self,
211        rtype: ResolvedType,
212        _expression: ExpressionId,
213        resolver: &mut FunctionResolver,
214    ) -> Result<(), ResolverError> {
215        resolver.set_expr_type(self.inner, rtype);
216        Ok(())
217    }
218
219    fn on_child_resolved(
220        &self,
221        rtype: ResolvedType,
222        expression: ExpressionId,
223        resolver: &mut FunctionResolver,
224    ) -> Result<(), ResolverError> {
225        resolver.set_expr_type(expression, rtype);
226        Ok(())
227    }
228}
229
230// Binary Operators
231
232impl ResolveExpression for ast::BinaryExpression {
233    fn setup_resolve(
234        &self,
235        expression: ExpressionId,
236        resolver: &mut FunctionResolver,
237    ) -> Result<(), ResolverError> {
238        if self.is_relation() {
239            resolver.set_expr_type(expression, RESOLVED_BOOL);
240        }
241        resolver.setup_child_expression(expression, self.left)?;
242        resolver.setup_child_expression(expression, self.right)?;
243        Ok(())
244    }
245
246    fn on_resolved(
247        &self,
248        rtype: ResolvedType,
249        _expression: ExpressionId,
250        resolver: &mut FunctionResolver,
251    ) -> Result<(), ResolverError> {
252        if !self.is_relation() {
253            resolver.set_expr_type(self.left, rtype);
254            resolver.set_expr_type(self.right, rtype);
255        }
256        Ok(())
257    }
258
259    fn on_child_resolved(
260        &self,
261        rtype: ResolvedType,
262        expression: ExpressionId,
263        resolver: &mut FunctionResolver,
264    ) -> Result<(), ResolverError> {
265        if !self.is_relation() {
266            resolver.set_expr_type(expression, rtype);
267        }
268
269        let left = resolver.expression_types.get(&self.left).copied();
270        let right = resolver.expression_types.get(&self.right).copied();
271
272        match (left, right) {
273            (Some(_left), Some(_right)) => {
274                // Both types known, do nothing
275            }
276            (Some(left), None) => {
277                resolver.set_expr_type(self.right, left);
278            }
279            (None, Some(right)) => {
280                resolver.set_expr_type(self.left, right);
281            }
282            (None, None) => {
283                // Neither types known... how did we get here?
284                unreachable!("If a child has been resolved, at least one child shouldn't be None")
285            }
286        }
287
288        Ok(())
289    }
290}