1use crate::ast;
2use crate::ast::TypedDef;
3use crate::builtins;
4use crate::types;
5use crate::types::Typed;
6use crate::wasm;
7use std::result;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum Error {}
12
13type Result<T> = result::Result<T, Error>;
14
15struct Translator {
16 module: wasm::Module,
17}
18
19impl Default for Translator {
20 fn default() -> Self {
21 let mut module = wasm::Module::default();
22 builtins::install_imports(&mut module);
23 Self { module }
24 }
25}
26
27fn find_main(defs: &[TypedDef]) -> Option<usize> {
28 defs.iter().find_map(|def| match def {
29 ast::TypedDef::FnDef(func) if func.name == "main" => {
30 Some(func.resolved_type.typ.index)
31 }
32 _ => None,
33 })
34}
35
36impl Translator {
37 fn new() -> Self {
38 Self { module: wasm::Module::default() }
39 }
40
41 fn int_literal(
42 &self,
43 lit: ast::IntLiteral,
44 instrs: &mut Vec<wasm::Instr>,
45 ) -> Result<()> {
46 instrs.push(wasm::Instr::Const(wasm::Const::I64(lit.value)));
47 Ok(())
48 }
49
50 fn binary(
51 &self,
52 binary: ast::TypedBinary,
53 instrs: &mut Vec<wasm::Instr>,
54 ) -> Result<()> {
55 self.expr(*binary.lhs, instrs)?;
56 self.expr(*binary.rhs, instrs)?;
57 match binary.op {
59 ast::BinaryOp::Add => instrs.push(wasm::Instr::AddI64),
60 ast::BinaryOp::Sub => instrs.push(wasm::Instr::SubI64),
61 ast::BinaryOp::Mul => todo!(),
62 ast::BinaryOp::Div => todo!(),
63 }
64 Ok(())
65 }
66
67 fn call(
68 &self,
69 call: ast::TypedCall,
70 instrs: &mut Vec<wasm::Instr>,
71 ) -> Result<()> {
72 let fn_type = call
74 .target
75 .as_ident()
76 .expect("typed call target must be an ident")
77 .resolution
78 .typ
79 .as_fn()
80 .expect("typed call target must have fn type");
81 assert_eq!(call.args.len(), fn_type.params.len());
82 for arg in call.args {
83 self.expr(arg, instrs)?;
84 }
85 instrs.push(wasm::Instr::Call(fn_type.index as u32));
86 Ok(())
87 }
88
89 fn ident(
90 &self,
91 ident: ast::TypedIdent,
92 instrs: &mut Vec<wasm::Instr>,
93 ) -> Result<()> {
94 match ident.resolution.reference {
95 types::Reference::External => todo!(),
96 types::Reference::Global { .. } => todo!(),
97 types::Reference::Stack { local_idx } => {
98 instrs.push(wasm::Instr::GetLocal(local_idx as u32));
99 }
100 }
101 Ok(())
102 }
103
104 fn expr(
105 &self,
106 expr: ast::TypedExpr,
107 instrs: &mut Vec<wasm::Instr>,
108 ) -> Result<()> {
109 match expr {
110 ast::Expr::Int(int) => self.int_literal(int, instrs),
111 ast::Expr::Binary(binary) => self.binary(binary, instrs),
112 ast::Expr::Call(call) => self.call(call, instrs),
113 ast::Expr::Ident(ident) => self.ident(ident, instrs),
114 _ => todo!(),
115 }
116 }
117
118 fn let_stmt(
119 &mut self,
120 lett: ast::TypedBinding,
121 body: &mut Vec<wasm::Instr>,
122 ) -> Result<()> {
123 self.expr(lett.expr, body)?;
124 body.push(wasm::Instr::SetLocal(lett.resolved_type.idx as u32));
125 Ok(())
126 }
127
128 fn stmt(
129 &mut self,
130 stmt: ast::TypedStmt,
131 body: &mut Vec<wasm::Instr>,
132 ) -> Result<()> {
133 match stmt {
134 ast::Stmt::Block(_) => todo!(),
135 ast::Stmt::Expr(expr) => {
136 if expr.typ() != types::Type::Void {
137 body.push(wasm::Instr::Drop);
138 }
139 self.expr(expr, body)?;
140 Ok(())
141 }
142 ast::Stmt::Let(lett) => self.let_stmt(lett, body),
143 }
144 }
145
146 fn func(&mut self, func: ast::TypedFunc) -> Result<()> {
147 let typeidx = self.module.types.0.len() as u32;
148 let locals = func.resolved_type.locals;
149 self.module.types.0.push(func.resolved_type.into());
150 self.module.funcs.0.push(wasm::Func { typeidx });
151 let mut body = vec![];
152 for stmt in func.body.0 {
153 self.stmt(stmt, &mut body)?;
154 }
155 body.push(wasm::Instr::End);
156 self.module.code.0.push(wasm::Code {
158 locals: vec![wasm::ValType::NumType(wasm::NumType::I64); locals],
159 body,
160 });
161 Ok(())
162 }
163
164 fn def(&mut self, def: ast::TypedDef) -> Result<()> {
165 match def {
166 ast::Def::FnDef(func) => self.func(func),
167 }
168 }
169
170 fn program(&mut self, program: ast::TypedProgram) -> Result<()> {
171 self.module.start.0 =
173 find_main(&program.defs).expect("main not found") as u32;
174 for def in program.defs {
175 self.def(def)?;
176 }
177 Ok(())
178 }
179
180 fn _foo(&mut self) {
181 self.module.types.0.push(wasm::FuncType {
183 params: vec![
184 wasm::ValType::NumType(wasm::NumType::I64),
185 wasm::ValType::NumType(wasm::NumType::I64),
186 ],
187 results: vec![wasm::ValType::NumType(wasm::NumType::I64)],
188 });
189 self.module.funcs.0.push(wasm::Func { typeidx: 1 });
190 self.module.code.0.push(wasm::Code {
191 body: vec![
192 wasm::Instr::GetLocal(0),
193 wasm::Instr::GetLocal(1),
194 wasm::Instr::AddI64,
195 wasm::Instr::End,
196 ],
197 locals: vec![],
198 });
199
200 self.module.types.0.push(wasm::FuncType { params: vec![], results: vec![] });
202 self.module.funcs.0.push(wasm::Func { typeidx: 2 });
203 self.module.code.0.push(wasm::Code {
204 body: vec![
205 wasm::Instr::Const(wasm::Const::I64(7)),
206 wasm::Instr::Const(wasm::Const::I64(14)),
207 wasm::Instr::Call(1),
208 wasm::Instr::Call(0),
209 wasm::Instr::End,
210 ],
211 locals: vec![],
212 });
213
214 self.module.start.0 = 2;
215 }
216
217 fn translate(mut self, program: ast::TypedProgram) -> Result<wasm::Module> {
218 self.program(program)?;
219 Ok(self.module)
220 }
221}
222
223pub fn translate(program: ast::TypedProgram) -> Result<wasm::Module> {
224 Translator::default().translate(program)
225}
226
#[cfg(test)]
mod test {
    use super::*;
    use crate::types;

    /// Translates one expression using a `Translator::new()` translator (no
    /// builtin imports installed) and returns the emitted instructions.
    /// Panics if translation returns an error.
    fn translate_expr(expr: ast::TypedExpr) -> Vec<wasm::Instr> {
        let mut instrs = Vec::new();
        Translator::new().expr(expr, &mut instrs).unwrap();
        instrs
    }

    // Groups a function's type, func entry and code. Not referenced by any
    // test below; presumably intended for future function-level assertions.
    #[derive(Debug, PartialEq)]
    struct FuncDefSpec {
        typ: wasm::FuncType,
        func: wasm::Func,
        code: wasm::Code,
    }

    // An int literal lowers to a single i64 constant.
    #[test]
    fn test_int_literal() {
        let instrs = translate_expr(ast::TypedExpr::Int(ast::Literal::new(247)));
        let expected = vec![wasm::Instr::Const(wasm::Const::I64(247))];
        assert_eq!(instrs, expected);
    }

    // Operands are emitted left then right, followed by the operator.
    #[test]
    fn test_binary() {
        let instrs = translate_expr(ast::TypedExpr::Binary(ast::TypedBinary {
            op: ast::BinaryOp::Add,
            lhs: Box::new(ast::TypedExpr::Int(ast::Literal::new(3))),
            rhs: Box::new(ast::TypedExpr::Int(ast::Literal::new(16))),
            cargo: types::Type::Int,
        }));
        let expected = vec![
            wasm::Instr::Const(wasm::Const::I64(3)),
            wasm::Instr::Const(wasm::Const::I64(16)),
            wasm::Instr::AddI64,
        ];
        assert_eq!(instrs, expected);
    }

    // Args are pushed in order, then `Call` uses the callee's FnType.index (3).
    // `Translator::call` never evaluates the target, so its `Global` reference
    // (whose translation is a `todo!`) is not reached.
    #[test]
    fn test_call() {
        let instrs = translate_expr(ast::TypedExpr::Call(ast::TypedCall {
            target: Box::new(ast::TypedExpr::Ident(ast::TypedIdent {
                name: String::from("add"),
                resolution: types::Resolution {
                    typ: types::Type::Fn(types::FnType {
                        index: 3,
                        params: vec![types::Type::Int, types::Type::Int],
                        ret: Some(Box::new(types::Type::Int)),
                    }),
                    reference: types::Reference::Global { idx: 17 },
                },
            })),
            args: vec![
                ast::TypedExpr::Int(ast::Literal::new(3)),
                ast::TypedExpr::Int(ast::Literal::new(4)),
            ],
            resolved_type: types::Type::Int,
        }));
        let expected = vec![
            wasm::Instr::Const(wasm::Const::I64(3)),
            wasm::Instr::Const(wasm::Const::I64(4)),
            wasm::Instr::Call(3),
        ];
        assert_eq!(instrs, expected);
    }

    // End-to-end translation of:
    //     fn print_sum(a: int, b: int) { let sum: int = a + b; println(sum); }
    // Params occupy locals 0 and 1, and `sum` is local 2 (one extra local).
    // The `println` call is Void, so no `Drop` is expected after it.
    #[test]
    fn test_func() {
        let mut translator = Translator::new();
        translator
            .def(ast::TypedDef::FnDef(ast::TypedFunc {
                name: String::from("print_sum"),
                params: vec![
                    ast::Param {
                        name: String::from("a"),
                        typ: ast::TypeSpec::simple("int"),
                        resolved_type: types::Type::Int,
                    },
                    ast::Param {
                        name: String::from("b"),
                        typ: ast::TypeSpec::simple("int"),
                        resolved_type: types::Type::Int,
                    },
                ],
                ret: None,
                resolved_type: types::FnDef {
                    typ: types::FnType {
                        index: 1,
                        params: vec![types::Type::Int, types::Type::Int],
                        ret: None,
                    },
                    locals: 1,
                },
                body: ast::Block(vec![
                    ast::Stmt::Let(ast::Binding::new(
                        String::from("sum"),
                        ast::TypeSpec::simple("int"),
                        ast::Expr::Binary(ast::Binary {
                            op: ast::BinaryOp::Add,
                            lhs: Box::new(ast::Expr::Ident(ast::Ident {
                                name: String::from("a"),
                                resolution: types::Resolution {
                                    typ: types::Type::Int,
                                    reference: types::Reference::Stack {
                                        local_idx: 0,
                                    },
                                },
                            })),
                            rhs: Box::new(ast::Expr::Ident(ast::Ident {
                                name: String::from("b"),
                                resolution: types::Resolution {
                                    typ: types::Type::Int,
                                    reference: types::Reference::Stack {
                                        local_idx: 1,
                                    },
                                },
                            })),
                            cargo: types::Type::Int,
                        }),
                        types::LocalBinding { typ: types::Type::Int, idx: 2 },
                    )),
                    ast::Stmt::Expr(ast::Expr::Call(ast::Call {
                        target: Box::new(ast::Expr::Ident(ast::Ident {
                            name: String::from("println"),
                            resolution: types::Resolution {
                                reference: types::Reference::External,
                                typ: types::Type::Fn(types::FnType {
                                    index: 0,
                                    params: vec![types::Type::Int],
                                    ret: None,
                                }),
                            },
                        })),
                        args: vec![ast::Expr::Ident(ast::Ident {
                            name: String::from("sum"),
                            resolution: types::Resolution {
                                typ: types::Type::Int,
                                reference: types::Reference::Stack { local_idx: 2 },
                            },
                        })],
                        resolved_type: types::Type::Void,
                    })),
                ]),
            }))
            .unwrap();
        // `Translator::new()` starts with an empty type section, so this
        // signature is the only entry.
        assert_eq!(
            vec![wasm::FuncType {
                params: vec![
                    wasm::ValType::NumType(wasm::NumType::I64),
                    wasm::ValType::NumType(wasm::NumType::I64)
                ],
                results: vec![]
            }],
            translator.module.types.0,
        );
        // `typeidx` is the type-section position (0), not FnType.index (1).
        assert_eq!(vec![wasm::Func { typeidx: 0 }], translator.module.funcs.0);
        assert_eq!(
            &wasm::Code {
                locals: vec![wasm::ValType::NumType(wasm::NumType::I64),],
                body: vec![
                    wasm::Instr::GetLocal(0),
                    wasm::Instr::GetLocal(1),
                    wasm::Instr::AddI64,
                    wasm::Instr::SetLocal(2),
                    wasm::Instr::GetLocal(2),
                    wasm::Instr::Call(0),
                    wasm::Instr::End,
                ]
            },
            &translator.module.code.0[0]
        );
    }
}