june_lang/
translator.rs

1use crate::ast;
2use crate::ast::TypedDef;
3use crate::builtins;
4use crate::types;
5use crate::types::Typed;
6use crate::wasm;
7use std::result;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum Error {}
12
13type Result<T> = result::Result<T, Error>;
14
15struct Translator {
16    module: wasm::Module,
17}
18
19impl Default for Translator {
20    fn default() -> Self {
21        let mut module = wasm::Module::default();
22        builtins::install_imports(&mut module);
23        Self { module }
24    }
25}
26
27fn find_main(defs: &[TypedDef]) -> Option<usize> {
28    defs.iter().find_map(|def| match def {
29        ast::TypedDef::FnDef(func) if func.name == "main" => {
30            Some(func.resolved_type.typ.index)
31        }
32        _ => None,
33    })
34}
35
36impl Translator {
37    fn new() -> Self {
38        Self { module: wasm::Module::default() }
39    }
40
41    fn int_literal(
42        &self,
43        lit: ast::IntLiteral,
44        instrs: &mut Vec<wasm::Instr>,
45    ) -> Result<()> {
46        instrs.push(wasm::Instr::Const(wasm::Const::I64(lit.value)));
47        Ok(())
48    }
49
50    fn binary(
51        &self,
52        binary: ast::TypedBinary,
53        instrs: &mut Vec<wasm::Instr>,
54    ) -> Result<()> {
55        self.expr(*binary.lhs, instrs)?;
56        self.expr(*binary.rhs, instrs)?;
57        // TODO: handle generic operations
58        match binary.op {
59            ast::BinaryOp::Add => instrs.push(wasm::Instr::AddI64),
60            ast::BinaryOp::Sub => instrs.push(wasm::Instr::SubI64),
61            ast::BinaryOp::Mul => todo!(),
62            ast::BinaryOp::Div => todo!(),
63        }
64        Ok(())
65    }
66
67    fn call(
68        &self,
69        call: ast::TypedCall,
70        instrs: &mut Vec<wasm::Instr>,
71    ) -> Result<()> {
72        // TODO: encode these invariants in the call ast node
73        let fn_type = call
74            .target
75            .as_ident()
76            .expect("typed call target must be an ident")
77            .resolution
78            .typ
79            .as_fn()
80            .expect("typed call target must have fn type");
81        assert_eq!(call.args.len(), fn_type.params.len());
82        for arg in call.args {
83            self.expr(arg, instrs)?;
84        }
85        instrs.push(wasm::Instr::Call(fn_type.index as u32));
86        Ok(())
87    }
88
89    fn ident(
90        &self,
91        ident: ast::TypedIdent,
92        instrs: &mut Vec<wasm::Instr>,
93    ) -> Result<()> {
94        match ident.resolution.reference {
95            types::Reference::External => todo!(),
96            types::Reference::Global { .. } => todo!(),
97            types::Reference::Stack { local_idx } => {
98                instrs.push(wasm::Instr::GetLocal(local_idx as u32));
99            }
100        }
101        Ok(())
102    }
103
104    fn expr(
105        &self,
106        expr: ast::TypedExpr,
107        instrs: &mut Vec<wasm::Instr>,
108    ) -> Result<()> {
109        match expr {
110            ast::Expr::Int(int) => self.int_literal(int, instrs),
111            ast::Expr::Binary(binary) => self.binary(binary, instrs),
112            ast::Expr::Call(call) => self.call(call, instrs),
113            ast::Expr::Ident(ident) => self.ident(ident, instrs),
114            _ => todo!(),
115        }
116    }
117
118    fn let_stmt(
119        &mut self,
120        lett: ast::TypedBinding,
121        body: &mut Vec<wasm::Instr>,
122    ) -> Result<()> {
123        self.expr(lett.expr, body)?;
124        body.push(wasm::Instr::SetLocal(lett.resolved_type.idx as u32));
125        Ok(())
126    }
127
128    fn stmt(
129        &mut self,
130        stmt: ast::TypedStmt,
131        body: &mut Vec<wasm::Instr>,
132    ) -> Result<()> {
133        match stmt {
134            ast::Stmt::Block(_) => todo!(),
135            ast::Stmt::Expr(expr) => {
136                if expr.typ() != types::Type::Void {
137                    body.push(wasm::Instr::Drop);
138                }
139                self.expr(expr, body)?;
140                Ok(())
141            }
142            ast::Stmt::Let(lett) => self.let_stmt(lett, body),
143        }
144    }
145
146    fn func(&mut self, func: ast::TypedFunc) -> Result<()> {
147        let typeidx = self.module.types.0.len() as u32;
148        let locals = func.resolved_type.locals;
149        self.module.types.0.push(func.resolved_type.into());
150        self.module.funcs.0.push(wasm::Func { typeidx });
151        let mut body = vec![];
152        for stmt in func.body.0 {
153            self.stmt(stmt, &mut body)?;
154        }
155        body.push(wasm::Instr::End);
156        // TODO: store the types of |locals|
157        self.module.code.0.push(wasm::Code {
158            locals: vec![wasm::ValType::NumType(wasm::NumType::I64); locals],
159            body,
160        });
161        Ok(())
162    }
163
164    fn def(&mut self, def: ast::TypedDef) -> Result<()> {
165        match def {
166            ast::Def::FnDef(func) => self.func(func),
167        }
168    }
169
170    fn program(&mut self, program: ast::TypedProgram) -> Result<()> {
171        // TODO: encode this in the ast
172        self.module.start.0 =
173            find_main(&program.defs).expect("main not found") as u32;
174        for def in program.defs {
175            self.def(def)?;
176        }
177        Ok(())
178    }
179
180    fn _foo(&mut self) {
181        // add
182        self.module.types.0.push(wasm::FuncType {
183            params: vec![
184                wasm::ValType::NumType(wasm::NumType::I64),
185                wasm::ValType::NumType(wasm::NumType::I64),
186            ],
187            results: vec![wasm::ValType::NumType(wasm::NumType::I64)],
188        });
189        self.module.funcs.0.push(wasm::Func { typeidx: 1 });
190        self.module.code.0.push(wasm::Code {
191            body: vec![
192                wasm::Instr::GetLocal(0),
193                wasm::Instr::GetLocal(1),
194                wasm::Instr::AddI64,
195                wasm::Instr::End,
196            ],
197            locals: vec![],
198        });
199
200        // main
201        self.module.types.0.push(wasm::FuncType { params: vec![], results: vec![] });
202        self.module.funcs.0.push(wasm::Func { typeidx: 2 });
203        self.module.code.0.push(wasm::Code {
204            body: vec![
205                wasm::Instr::Const(wasm::Const::I64(7)),
206                wasm::Instr::Const(wasm::Const::I64(14)),
207                wasm::Instr::Call(1),
208                wasm::Instr::Call(0),
209                wasm::Instr::End,
210            ],
211            locals: vec![],
212        });
213
214        self.module.start.0 = 2;
215    }
216
217    fn translate(mut self, program: ast::TypedProgram) -> Result<wasm::Module> {
218        self.program(program)?;
219        Ok(self.module)
220    }
221}
222
223pub fn translate(program: ast::TypedProgram) -> Result<wasm::Module> {
224    Translator::default().translate(program)
225}
226
#[cfg(test)]
mod test {
    use super::*;
    use crate::types;

    /// Translates a single expression with an import-free translator and
    /// returns the emitted instructions.
    fn translate_expr(expr: ast::TypedExpr) -> Vec<wasm::Instr> {
        let mut instrs = Vec::new();
        Translator::new().expr(expr, &mut instrs).unwrap();
        instrs
    }

    #[test]
    fn test_int_literal() {
        let instrs = translate_expr(ast::TypedExpr::Int(ast::Literal::new(247)));
        let expected = vec![wasm::Instr::Const(wasm::Const::I64(247))];
        assert_eq!(instrs, expected);
    }

    #[test]
    fn test_binary() {
        let instrs = translate_expr(ast::TypedExpr::Binary(ast::TypedBinary {
            op: ast::BinaryOp::Add,
            lhs: Box::new(ast::TypedExpr::Int(ast::Literal::new(3))),
            rhs: Box::new(ast::TypedExpr::Int(ast::Literal::new(16))),
            cargo: types::Type::Int,
        }));
        // Operands are emitted left to right, followed by the operator.
        let expected = vec![
            wasm::Instr::Const(wasm::Const::I64(3)),
            wasm::Instr::Const(wasm::Const::I64(16)),
            wasm::Instr::AddI64,
        ];
        assert_eq!(instrs, expected);
    }

    #[test]
    fn test_call() {
        let instrs = translate_expr(ast::TypedExpr::Call(ast::TypedCall {
            target: Box::new(ast::TypedExpr::Ident(ast::TypedIdent {
                name: String::from("add"),
                resolution: types::Resolution {
                    typ: types::Type::Fn(types::FnType {
                        index: 3,
                        params: vec![types::Type::Int, types::Type::Int],
                        ret: Some(Box::new(types::Type::Int)),
                    }),
                    reference: types::Reference::Global { idx: 17 },
                },
            })),
            args: vec![
                ast::TypedExpr::Int(ast::Literal::new(3)),
                ast::TypedExpr::Int(ast::Literal::new(4)),
            ],
            resolved_type: types::Type::Int,
        }));
        // The call uses the fn type's index (3), not the global slot (17).
        let expected = vec![
            wasm::Instr::Const(wasm::Const::I64(3)),
            wasm::Instr::Const(wasm::Const::I64(4)),
            wasm::Instr::Call(3),
        ];
        assert_eq!(instrs, expected);
    }

    #[test]
    fn test_func() {
        // `new` installs no imports, so this function gets type index 0.
        let mut translator = Translator::new();
        translator
            .def(ast::TypedDef::FnDef(ast::TypedFunc {
                name: String::from("print_sum"),
                params: vec![
                    ast::Param {
                        name: String::from("a"),
                        typ: ast::TypeSpec::simple("int"),
                        resolved_type: types::Type::Int,
                    },
                    ast::Param {
                        name: String::from("b"),
                        typ: ast::TypeSpec::simple("int"),
                        resolved_type: types::Type::Int,
                    },
                ],
                ret: None,
                resolved_type: types::FnDef {
                    typ: types::FnType {
                        index: 1,
                        params: vec![types::Type::Int, types::Type::Int],
                        ret: None,
                    },
                    locals: 1,
                },
                body: ast::Block(vec![
                    // let sum: int = a + b;
                    ast::Stmt::Let(ast::Binding::new(
                        String::from("sum"),
                        ast::TypeSpec::simple("int"),
                        ast::Expr::Binary(ast::Binary {
                            op: ast::BinaryOp::Add,
                            lhs: Box::new(ast::Expr::Ident(ast::Ident {
                                name: String::from("a"),
                                resolution: types::Resolution {
                                    typ: types::Type::Int,
                                    reference: types::Reference::Stack {
                                        local_idx: 0,
                                    },
                                },
                            })),
                            rhs: Box::new(ast::Expr::Ident(ast::Ident {
                                name: String::from("b"),
                                resolution: types::Resolution {
                                    typ: types::Type::Int,
                                    reference: types::Reference::Stack {
                                        local_idx: 1,
                                    },
                                },
                            })),
                            cargo: types::Type::Int,
                        }),
                        types::LocalBinding { typ: types::Type::Int, idx: 2 },
                    )),
                    // println(sum); -- void, so no `drop` is expected.
                    ast::Stmt::Expr(ast::Expr::Call(ast::Call {
                        target: Box::new(ast::Expr::Ident(ast::Ident {
                            name: String::from("println"),
                            resolution: types::Resolution {
                                reference: types::Reference::External,
                                typ: types::Type::Fn(types::FnType {
                                    index: 0,
                                    params: vec![types::Type::Int],
                                    ret: None,
                                }),
                            },
                        })),
                        args: vec![ast::Expr::Ident(ast::Ident {
                            name: String::from("sum"),
                            resolution: types::Resolution {
                                typ: types::Type::Int,
                                reference: types::Reference::Stack { local_idx: 2 },
                            },
                        })],
                        resolved_type: types::Type::Void,
                    })),
                ]),
            }))
            .unwrap();
        assert_eq!(
            vec![wasm::FuncType {
                params: vec![
                    wasm::ValType::NumType(wasm::NumType::I64),
                    wasm::ValType::NumType(wasm::NumType::I64)
                ],
                results: vec![]
            }],
            translator.module.types.0,
        );
        assert_eq!(vec![wasm::Func { typeidx: 0 }], translator.module.funcs.0);
        assert_eq!(
            &wasm::Code {
                locals: vec![wasm::ValType::NumType(wasm::NumType::I64)],
                body: vec![
                    wasm::Instr::GetLocal(0),
                    wasm::Instr::GetLocal(1),
                    wasm::Instr::AddI64,
                    wasm::Instr::SetLocal(2),
                    wasm::Instr::GetLocal(2),
                    wasm::Instr::Call(0),
                    wasm::Instr::End,
                ]
            },
            &translator.module.code.0[0]
        );
    }
}