claw_codegen/
expression.rs

1use ast::ExpressionId;
2use claw_ast as ast;
3use claw_resolver::{ItemId, ResolvedType};
4
5use crate::code::{CodeGenerator, ExpressionAllocator};
6use crate::types::{
7    Signedness, STRING_CONTENTS_ALIGNMENT, STRING_LENGTH_FIELD, STRING_OFFSET_FIELD,
8};
9use crate::GenerationError;
10
11use cranelift_entity::EntityRef;
12use wasm_encoder as enc;
13use wasm_encoder::Instruction;
14
15pub trait EncodeExpression {
16    fn alloc_expr_locals(
17        &self,
18        expression: ExpressionId,
19        allocator: &mut ExpressionAllocator,
20    ) -> Result<(), GenerationError>;
21
22    fn encode(
23        &self,
24        expression: ExpressionId,
25        code_gen: &mut CodeGenerator,
26    ) -> Result<(), GenerationError>;
27}
28
29impl EncodeExpression for ast::Expression {
30    fn alloc_expr_locals(
31        &self,
32        expression: ExpressionId,
33        allocator: &mut ExpressionAllocator,
34    ) -> Result<(), GenerationError> {
35        let expr: &dyn EncodeExpression = match self {
36            ast::Expression::Identifier(expr) => expr,
37            ast::Expression::Enum(expr) => expr,
38            ast::Expression::Literal(expr) => expr,
39            ast::Expression::Call(expr) => expr,
40            ast::Expression::Unary(expr) => expr,
41            ast::Expression::Binary(expr) => expr,
42        };
43        expr.alloc_expr_locals(expression, allocator)
44    }
45
46    fn encode(
47        &self,
48        expression: ExpressionId,
49        code_gen: &mut CodeGenerator,
50    ) -> Result<(), GenerationError> {
51        let expr: &dyn EncodeExpression = match self {
52            ast::Expression::Identifier(expr) => expr,
53            ast::Expression::Enum(expr) => expr,
54            ast::Expression::Literal(expr) => expr,
55            ast::Expression::Call(expr) => expr,
56            ast::Expression::Unary(expr) => expr,
57            ast::Expression::Binary(expr) => expr,
58        };
59        expr.encode(expression, code_gen)?;
60        Ok(())
61    }
62}
63
64impl EncodeExpression for ast::Identifier {
65    fn alloc_expr_locals(
66        &self,
67        expression: ExpressionId,
68        allocator: &mut ExpressionAllocator,
69    ) -> Result<(), GenerationError> {
70        allocator.alloc(expression)
71    }
72
73    fn encode(
74        &self,
75        expression: ExpressionId,
76        code_gen: &mut CodeGenerator,
77    ) -> Result<(), GenerationError> {
78        let fields = code_gen.fields(expression)?;
79        match code_gen.lookup_name(self.ident) {
80            ItemId::ImportFunc(_) => panic!("Cannot use imported function as value!!"),
81            ItemId::Type(_) => panic!("Cannot use type as value!!"),
82            ItemId::Global(global) => {
83                // TODO handle composite globals
84                let field = code_gen.one_field(expression)?;
85                code_gen.instruction(&Instruction::GlobalGet(global.index() as u32));
86                code_gen.write_expr_field(expression, &field);
87            }
88            ItemId::Param(param) => {
89                for field in fields.iter() {
90                    code_gen.read_param_field(param, field);
91                    code_gen.write_expr_field(expression, field);
92                }
93            }
94            ItemId::Local(local) => {
95                for field in fields.iter() {
96                    code_gen.read_local_field(local, field);
97                    code_gen.write_expr_field(expression, field);
98                }
99            }
100            ItemId::Function(_) => panic!("Cannot use function as value!!"),
101        }
102        Ok(())
103    }
104}
105
106impl EncodeExpression for ast::EnumLiteral {
107    fn alloc_expr_locals(
108        &self,
109        expression: ExpressionId,
110        allocator: &mut ExpressionAllocator,
111    ) -> Result<(), GenerationError> {
112        allocator.alloc(expression)
113    }
114
115    fn encode(
116        &self,
117        expression: ExpressionId,
118        code_gen: &mut CodeGenerator,
119    ) -> Result<(), GenerationError> {
120        match code_gen.lookup_name(self.enum_name) {
121            ItemId::Type(ResolvedType::Import(import_type)) => {
122                let import_type = code_gen.lookup_import_type(import_type);
123                match import_type {
124                    claw_resolver::ImportType::Enum(enum_type) => {
125                        let case_name = code_gen.lookup_name_str(self.case_name);
126                        // TODO nice error instead of unwrap
127                        let case_index =
128                            enum_type.cases.iter().position(|c| c == case_name).unwrap();
129                        code_gen.const_i32(case_index as i32);
130                        let field = code_gen.one_field(expression)?;
131                        code_gen.write_expr_field(expression, &field);
132                    }
133                }
134            }
135            _ => unreachable!(),
136        }
137        Ok(())
138    }
139}
140
141impl EncodeExpression for ast::Literal {
142    fn alloc_expr_locals(
143        &self,
144        expression: ExpressionId,
145        allocator: &mut ExpressionAllocator,
146    ) -> Result<(), GenerationError> {
147        allocator.alloc(expression)
148    }
149
150    fn encode(
151        &self,
152        expression: ExpressionId,
153        code_gen: &mut CodeGenerator,
154    ) -> Result<(), GenerationError> {
155        match self {
156            ast::Literal::String(string) => {
157                // Allocate string pointer
158                code_gen.const_i32(0);
159                code_gen.const_i32(0);
160                code_gen.const_i32(2i32.pow(STRING_CONTENTS_ALIGNMENT));
161                code_gen.const_i32(string.len() as i32);
162                code_gen.allocate();
163                code_gen.write_expr_field(expression, &STRING_OFFSET_FIELD);
164                // Store the string length
165                code_gen.const_i32(string.len() as i32);
166                code_gen.write_expr_field(expression, &STRING_LENGTH_FIELD);
167                // Copy in the data segment
168                let index = code_gen.encode_const_bytes(string.as_bytes());
169                code_gen.read_expr_field(expression, &STRING_OFFSET_FIELD);
170                code_gen.const_i32(0);
171                code_gen.read_expr_field(expression, &STRING_LENGTH_FIELD);
172                code_gen.instruction(&enc::Instruction::MemoryInit {
173                    mem: 0,
174                    data_index: index.into(),
175                })
176            }
177            ast::Literal::Integer(int) => {
178                let field = code_gen.one_field(expression)?;
179                code_gen.encode_const_int(*int, &field);
180                code_gen.write_expr_field(expression, &field);
181            }
182            ast::Literal::Float(float) => {
183                let field = code_gen.one_field(expression)?;
184                code_gen.encode_const_float(*float, &field);
185                code_gen.write_expr_field(expression, &field);
186            }
187        }
188        Ok(())
189    }
190}
191
192impl EncodeExpression for ast::Call {
193    fn alloc_expr_locals(
194        &self,
195        expression: ExpressionId,
196        allocator: &mut ExpressionAllocator,
197    ) -> Result<(), GenerationError> {
198        allocator.alloc(expression)?;
199        for arg in self.args.iter() {
200            allocator.alloc_child(*arg)?;
201        }
202        Ok(())
203    }
204
205    fn encode(
206        &self,
207        expression: ExpressionId,
208        code_gen: &mut CodeGenerator,
209    ) -> Result<(), GenerationError> {
210        for arg in self.args.iter() {
211            code_gen.encode_child(*arg)?;
212        }
213        let item = code_gen.lookup_name(self.ident);
214        code_gen.encode_call(item, &self.args, Some(expression))
215    }
216}
217
218impl EncodeExpression for ast::UnaryExpression {
219    fn alloc_expr_locals(
220        &self,
221        expression: ExpressionId,
222        allocator: &mut ExpressionAllocator,
223    ) -> Result<(), GenerationError> {
224        allocator.alloc(expression)?;
225        allocator.alloc_child(self.inner)
226    }
227
228    fn encode(
229        &self,
230        expression: ExpressionId,
231        code_gen: &mut CodeGenerator,
232    ) -> Result<(), GenerationError> {
233        code_gen.const_i32(0); // TODO support 64 bit ints
234        code_gen.encode_child(self.inner)?;
235        for field in code_gen.fields(self.inner)?.iter() {
236            code_gen.read_expr_field(self.inner, field);
237        }
238        code_gen.instruction(&enc::Instruction::I32Sub);
239        for field in code_gen.fields(expression)?.iter() {
240            code_gen.write_expr_field(expression, field);
241        }
242        Ok(())
243    }
244}
245
246impl EncodeExpression for ast::BinaryExpression {
247    fn alloc_expr_locals(
248        &self,
249        expression: ExpressionId,
250        allocator: &mut ExpressionAllocator,
251    ) -> Result<(), GenerationError> {
252        allocator.alloc(expression)?;
253        allocator.alloc_child(self.left)?;
254        allocator.alloc_child(self.right)?;
255        Ok(())
256    }
257
258    fn encode(
259        &self,
260        expression: ExpressionId,
261        code_gen: &mut CodeGenerator,
262    ) -> Result<(), GenerationError> {
263        code_gen.encode_child(self.left)?;
264        code_gen.encode_child(self.right)?;
265
266        let ptype = code_gen.get_ptype(expression)?;
267        if ptype == Some(ast::PrimitiveType::String) {
268            if self.op == ast::BinaryOp::Add {
269                encode_string_concatenation(expression, self.left, self.right, code_gen)
270            } else {
271                panic!("Strings can only be concatenated with '+'");
272            }
273        } else {
274            encode_binary_arithmetic(self.op, expression, self.left, self.right, code_gen)
275        }
276    }
277}
278
279fn encode_string_concatenation(
280    expression: ExpressionId,
281    left: ExpressionId,
282    right: ExpressionId,
283    code_gen: &mut CodeGenerator,
284) -> Result<(), GenerationError> {
285    // Compute new length
286    code_gen.read_expr_field(left, &STRING_LENGTH_FIELD);
287    code_gen.read_expr_field(right, &STRING_LENGTH_FIELD);
288    code_gen.instruction(&enc::Instruction::I32Add);
289    code_gen.write_expr_field(expression, &STRING_LENGTH_FIELD);
290    // Allocate new string
291    code_gen.const_i32(0);
292    code_gen.const_i32(0);
293    code_gen.const_i32(2i32.pow(STRING_CONTENTS_ALIGNMENT));
294    code_gen.read_expr_field(expression, &STRING_LENGTH_FIELD);
295    code_gen.allocate();
296    code_gen.write_expr_field(expression, &STRING_OFFSET_FIELD);
297    // Copy in the left string
298    code_gen.read_expr_field(expression, &STRING_OFFSET_FIELD);
299    code_gen.read_expr_field(left, &STRING_OFFSET_FIELD);
300    code_gen.read_expr_field(left, &STRING_LENGTH_FIELD);
301    code_gen.instruction(&enc::Instruction::MemoryCopy {
302        src_mem: 0,
303        dst_mem: 0,
304    });
305    // Copy in the right string
306    code_gen.read_expr_field(expression, &STRING_OFFSET_FIELD);
307    code_gen.read_expr_field(left, &STRING_LENGTH_FIELD);
308    code_gen.instruction(&enc::Instruction::I32Add);
309    code_gen.read_expr_field(right, &STRING_OFFSET_FIELD);
310    code_gen.read_expr_field(right, &STRING_LENGTH_FIELD);
311    code_gen.instruction(&enc::Instruction::MemoryCopy {
312        src_mem: 0,
313        dst_mem: 0,
314    });
315    Ok(())
316}
317
318const S: Signedness = Signedness::Signed;
319const U: Signedness = Signedness::Unsigned;
320
321fn encode_binary_arithmetic(
322    op: ast::BinaryOp,
323    expression: ExpressionId,
324    left: ExpressionId,
325    right: ExpressionId,
326    code_gen: &mut CodeGenerator,
327) -> Result<(), GenerationError> {
328    let left_field = code_gen.one_field(left)?;
329    let right_field = code_gen.one_field(right)?;
330    let field = code_gen.one_field(expression)?;
331
332    let valtype = left_field.stack_type;
333    let signedness = left_field.signedness;
334    let mask = left_field.arith_mask;
335
336    code_gen.read_expr_field(left, &left_field);
337    code_gen.read_expr_field(right, &right_field);
338
339    let instruction = match (op, valtype, signedness) {
340        // Multiply
341        (ast::BinaryOp::Multiply, enc::ValType::I32, _) => enc::Instruction::I32Mul,
342        (ast::BinaryOp::Multiply, enc::ValType::I64, _) => enc::Instruction::I64Mul,
343        (ast::BinaryOp::Multiply, enc::ValType::F32, _) => enc::Instruction::F32Mul,
344        (ast::BinaryOp::Multiply, enc::ValType::F64, _) => enc::Instruction::F64Mul,
345        // Divide
346        (ast::BinaryOp::Divide, enc::ValType::I32, S) => enc::Instruction::I32DivS,
347        (ast::BinaryOp::Divide, enc::ValType::I32, U) => enc::Instruction::I32DivU,
348        (ast::BinaryOp::Divide, enc::ValType::I64, S) => enc::Instruction::I64DivS,
349        (ast::BinaryOp::Divide, enc::ValType::I64, U) => enc::Instruction::I64DivU,
350        (ast::BinaryOp::Divide, enc::ValType::F32, _) => enc::Instruction::F32Div,
351        (ast::BinaryOp::Divide, enc::ValType::F64, _) => enc::Instruction::F64Div,
352        // Modulo
353        (ast::BinaryOp::Modulo, enc::ValType::I32, S) => enc::Instruction::I32RemS,
354        (ast::BinaryOp::Modulo, enc::ValType::I32, U) => enc::Instruction::I32RemU,
355        (ast::BinaryOp::Modulo, enc::ValType::I64, S) => enc::Instruction::I64RemS,
356        (ast::BinaryOp::Modulo, enc::ValType::I64, U) => enc::Instruction::I64RemU,
357        // Addition
358        (ast::BinaryOp::Add, enc::ValType::I32, _) => enc::Instruction::I32Add,
359        (ast::BinaryOp::Add, enc::ValType::I64, _) => enc::Instruction::I64Add,
360        (ast::BinaryOp::Add, enc::ValType::F32, _) => enc::Instruction::F32Add,
361        (ast::BinaryOp::Add, enc::ValType::F64, _) => enc::Instruction::F64Add,
362        // Subtraction
363        (ast::BinaryOp::Subtract, enc::ValType::I32, _) => enc::Instruction::I32Sub,
364        (ast::BinaryOp::Subtract, enc::ValType::I64, _) => enc::Instruction::I64Sub,
365        (ast::BinaryOp::Subtract, enc::ValType::F32, _) => enc::Instruction::F32Sub,
366        (ast::BinaryOp::Subtract, enc::ValType::F64, _) => enc::Instruction::F64Sub,
367        // Logical Bit Shifting
368        (ast::BinaryOp::BitShiftL, enc::ValType::I32, _) => enc::Instruction::I32Shl,
369        (ast::BinaryOp::BitShiftL, enc::ValType::I64, _) => enc::Instruction::I64Shl,
370        (ast::BinaryOp::BitShiftR, enc::ValType::I32, _) => enc::Instruction::I32ShrU,
371        (ast::BinaryOp::BitShiftR, enc::ValType::I64, _) => enc::Instruction::I64ShrU,
372        // Arithmetic Bit Shifting
373        (ast::BinaryOp::ArithShiftR, enc::ValType::I32, S) => enc::Instruction::I32ShrS,
374        (ast::BinaryOp::ArithShiftR, enc::ValType::I32, U) => enc::Instruction::I32ShrU,
375        (ast::BinaryOp::ArithShiftR, enc::ValType::I64, S) => enc::Instruction::I64ShrS,
376        (ast::BinaryOp::ArithShiftR, enc::ValType::I64, U) => enc::Instruction::I64ShrU,
377        // Less than
378        (ast::BinaryOp::LessThan, enc::ValType::I32, S) => enc::Instruction::I32LtS,
379        (ast::BinaryOp::LessThan, enc::ValType::I32, U) => enc::Instruction::I32LtU,
380        (ast::BinaryOp::LessThan, enc::ValType::I64, S) => enc::Instruction::I64LtS,
381        (ast::BinaryOp::LessThan, enc::ValType::I64, U) => enc::Instruction::I64LtU,
382        (ast::BinaryOp::LessThan, enc::ValType::F32, _) => enc::Instruction::F32Lt,
383        (ast::BinaryOp::LessThan, enc::ValType::F64, _) => enc::Instruction::F64Lt,
384        // Less than equal
385        (ast::BinaryOp::LessThanEqual, enc::ValType::I32, S) => enc::Instruction::I32LeS,
386        (ast::BinaryOp::LessThanEqual, enc::ValType::I32, U) => enc::Instruction::I32LeU,
387        (ast::BinaryOp::LessThanEqual, enc::ValType::I64, S) => enc::Instruction::I64LeS,
388        (ast::BinaryOp::LessThanEqual, enc::ValType::I64, U) => enc::Instruction::I64LeU,
389        (ast::BinaryOp::LessThanEqual, enc::ValType::F32, _) => enc::Instruction::F32Le,
390        (ast::BinaryOp::LessThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Le,
391        // Greater than
392        (ast::BinaryOp::GreaterThan, enc::ValType::I32, S) => enc::Instruction::I32GtS,
393        (ast::BinaryOp::GreaterThan, enc::ValType::I32, U) => enc::Instruction::I32GtU,
394        (ast::BinaryOp::GreaterThan, enc::ValType::I64, S) => enc::Instruction::I64GtS,
395        (ast::BinaryOp::GreaterThan, enc::ValType::I64, U) => enc::Instruction::I64GtU,
396        (ast::BinaryOp::GreaterThan, enc::ValType::F32, _) => enc::Instruction::F32Gt,
397        (ast::BinaryOp::GreaterThan, enc::ValType::F64, _) => enc::Instruction::F64Gt,
398        // Greater than or equal
399        (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, S) => enc::Instruction::I32GeS,
400        (ast::BinaryOp::GreaterThanEqual, enc::ValType::I32, U) => enc::Instruction::I32GeU,
401        (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, S) => enc::Instruction::I64GeS,
402        (ast::BinaryOp::GreaterThanEqual, enc::ValType::I64, U) => enc::Instruction::I64GeU,
403        (ast::BinaryOp::GreaterThanEqual, enc::ValType::F32, _) => enc::Instruction::F32Ge,
404        (ast::BinaryOp::GreaterThanEqual, enc::ValType::F64, _) => enc::Instruction::F64Ge,
405        // Equal
406        (ast::BinaryOp::Equals, enc::ValType::I32, _) => enc::Instruction::I32Eq,
407        (ast::BinaryOp::Equals, enc::ValType::I64, _) => enc::Instruction::I64Eq,
408        (ast::BinaryOp::Equals, enc::ValType::F32, _) => enc::Instruction::F32Eq,
409        (ast::BinaryOp::Equals, enc::ValType::F64, _) => enc::Instruction::F64Eq,
410        // Not equal
411        (ast::BinaryOp::NotEquals, enc::ValType::I32, _) => enc::Instruction::I32Ne,
412        (ast::BinaryOp::NotEquals, enc::ValType::I64, _) => enc::Instruction::I64Ne,
413        (ast::BinaryOp::NotEquals, enc::ValType::F32, _) => enc::Instruction::F32Ne,
414        (ast::BinaryOp::NotEquals, enc::ValType::F64, _) => enc::Instruction::F64Ne,
415        // Bitwise and
416        (ast::BinaryOp::BitAnd, enc::ValType::I32, _) => enc::Instruction::I32And,
417        (ast::BinaryOp::BitAnd, enc::ValType::I64, _) => enc::Instruction::I64And,
418        // Bitwise xor
419        (ast::BinaryOp::BitXor, enc::ValType::I32, _) => enc::Instruction::I32Xor,
420        (ast::BinaryOp::BitXor, enc::ValType::I64, _) => enc::Instruction::I64Xor,
421        // Bitwise or
422        (ast::BinaryOp::BitOr, enc::ValType::I32, _) => enc::Instruction::I32Or,
423        (ast::BinaryOp::BitOr, enc::ValType::I64, _) => enc::Instruction::I64Or,
424        // Logical and/or
425        (ast::BinaryOp::LogicalAnd, enc::ValType::I32, _) => enc::Instruction::I32And,
426        (ast::BinaryOp::LogicalOr, enc::ValType::I32, _) => enc::Instruction::I32Or,
427        // Fallback
428        (operator, valtype, _) => panic!(
429            "Cannot apply binary operator {:?} to type {:?}",
430            operator, valtype
431        ),
432    };
433    code_gen.instruction(&instruction);
434
435    if let Some(mask) = mask {
436        code_gen.const_i32(mask);
437        code_gen.instruction(&enc::Instruction::I32And);
438    }
439
440    code_gen.write_expr_field(expression, &field);
441    Ok(())
442}