1use crate::ast::*;
3use crate::builtins;
4use crate::symbol_table::*;
5use crate::types::*;
6use std::result;
7use thiserror::Error;
8
9#[derive(Debug, Error, PartialEq)]
10pub enum Error {
11 #[error("undefined: {0}")]
12 Undefined(String),
13 #[error("wrong number of arguments: want {want}, got {got}")]
14 Arity { want: usize, got: usize },
15 #[error("type mismatch: want {want:?}, got {got:?}")]
16 TypeMismatch { want: Type, got: Type },
17 #[error("not callable: {0:?}")]
18 InvalidCallable(Type),
19 #[error("unknown type: {0}")]
20 UnknownType(String),
21 #[error("invalid types for operator {op:?}: {lhs:?}, {rhs:?}")]
22 InvalidOpTypes { op: BinaryOp, lhs: Type, rhs: Type },
23 #[error("entrypoint `main` undefined")]
24 NoMain,
25}
26
27type Result<T> = result::Result<T, Error>;
28
29fn check_op(op: BinaryOp, lhs: Type, rhs: Type) -> Result<Type> {
30 use BinaryOp::*;
31 use Type::*;
32 match (op, &lhs, &rhs) {
33 (Add, Int, Int) => Ok(Int),
34 (Add, Str, Str) => Ok(Str),
35 _ => Err(Error::InvalidOpTypes { op, lhs, rhs }),
36 }
37}
38
39fn check((want, got): (&Type, &TypedExpr)) -> Result<()> {
40 let got = got.typ();
41 if want != &got {
42 Err(Error::TypeMismatch { want: want.clone(), got })
43 } else {
44 Ok(())
45 }
46}
47
48fn check_all(want: &[Type], got: &[TypedExpr]) -> Result<()> {
49 if want.len() != got.len() {
50 Err(Error::Arity { want: want.len(), got: got.len() })
51 } else {
52 want.iter().zip(got).try_for_each(check)
53 }
54}
55
56pub struct Analyzer {
57 ctx: SymbolTable,
58}
59
60impl Default for Analyzer {
61 fn default() -> Analyzer {
62 let mut ctx = SymbolTable::default();
63 builtins::install_symbols(&mut ctx);
64 Analyzer::with_context(ctx)
65 }
66}
67
68impl Analyzer {
69 fn with_context(ctx: SymbolTable) -> Analyzer {
70 Analyzer { ctx }
71 }
72
73 fn try_map<T, U, F>(&mut self, ts: Vec<T>, f: F) -> Result<Vec<U>>
74 where
75 F: Fn(&mut Self, T) -> Result<U>,
76 {
77 ts.into_iter().map(|t| f(self, t)).collect::<Result<_>>()
78 }
79
80 fn param(&mut self, param: Param) -> Result<TypedParam> {
81 let (name, typ) = (param.name, param.typ);
82 let resolved_type = self.resolve_type(&typ)?;
83 Ok(Param { name, typ, resolved_type })
84 }
85
86 fn func(&mut self, f: Func) -> Result<TypedFunc> {
87 self.ctx.enter_function();
88 let params = self.try_map(f.params, |a, param| a.param(param))?;
89 self.ctx.push_frame();
90 for param in ¶ms {
91 self.ctx.def_local(¶m.name, param.resolved_type.clone());
92 }
93 let body = self.block(f.body)?;
94 let ret = f.ret.clone().map(|typ| self.resolve_type(&typ)).transpose()?;
95 let resolved_type = FnDef {
97 typ: self.ctx.def_fn(params.iter().map(|p| p.typ()).collect(), ret),
98 locals: self.ctx.num_locals().unwrap() - params.len(),
99 };
100 let func = Func { name: f.name, params, body, ret: f.ret, resolved_type };
101 self.ctx.pop_frame();
102 self.ctx.exit_function();
103 Ok(func)
104 }
105
106 fn call(&mut self, call: Call) -> Result<TypedCall> {
107 let target = Box::new(self.expr(*call.target)?);
108 if let Type::Fn(f) = target.typ() {
109 let args = self.try_map(call.args, |a, arg| a.expr(arg))?;
110 check_all(&f.params, &args)?;
111 Ok(Call { target, args, resolved_type: f.ret.into() })
112 } else {
113 Err(Error::InvalidCallable(target.typ()))
114 }
115 }
116
117 fn ident(&mut self, ident: Ident) -> Result<TypedExpr> {
118 let name = ident.name;
119 let resolution =
120 self.ctx.get(&name).ok_or(Error::Undefined(name.clone()))?;
121 Ok(Expr::Ident(Ident { name, resolution }))
122 }
123
124 fn binary(&mut self, expr: Binary) -> Result<TypedExpr> {
125 let lhs = Box::new(self.expr(*expr.lhs)?);
126 let rhs = Box::new(self.expr(*expr.rhs)?);
127 let op = expr.op;
128 let cargo = check_op(op, lhs.typ(), rhs.typ())?;
129 Ok(Expr::Binary(Binary { op, lhs, rhs, cargo }))
130 }
131
132 pub fn expr(&mut self, expr: Expr) -> Result<TypedExpr> {
133 match expr {
134 Expr::Call(call) => Ok(Expr::Call(self.call(call)?)),
135 Expr::Int(prim) => Ok(Expr::Int(prim)),
136 Expr::Str(prim) => Ok(Expr::Str(prim)),
137 Expr::Ident(prim) => self.ident(prim),
138 Expr::Binary(bin) => self.binary(bin),
139 }
140 }
141
142 fn resolve_type(&self, typ: &TypeSpec) -> Result<Type> {
143 match typ {
145 TypeSpec::Void => Ok(Type::Void),
146 TypeSpec::Simple(typ) if "int" == typ => Ok(Type::Int),
147 TypeSpec::Simple(typ) if "str" == typ => Ok(Type::Str),
148 TypeSpec::Simple(typ) => Err(Error::UnknownType(typ.into())),
149 }
150 }
151
152 fn let_stmt(&mut self, stmt: Binding) -> Result<TypedStmt> {
153 let typ = self.resolve_type(&stmt.typ)?;
154 let expr = self.expr(stmt.expr)?;
155 check((&typ, &expr))?;
156 let idx = self.ctx.def_local(&stmt.name, typ.clone());
157 Ok(Stmt::Let(Binding::new(
158 stmt.name,
159 stmt.typ,
160 expr,
161 LocalBinding { typ, idx },
162 )))
163 }
164
165 fn stmt(&mut self, stmt: Stmt) -> Result<TypedStmt> {
166 match stmt {
167 Stmt::Expr(expr) => Ok(Stmt::Expr(self.expr(expr)?)),
168 Stmt::Let(stmt) => self.let_stmt(stmt),
169 Stmt::Block(block) => Ok(Stmt::Block(self.block(block)?)),
170 }
171 }
172
173 fn block(&mut self, Block(stmts): Block) -> Result<TypedBlock> {
174 self.ctx.push_frame();
175 let stmts = self.try_map(stmts, |a, stmt| a.stmt(stmt))?;
176 self.ctx.pop_frame();
177 Ok(Block(stmts))
178 }
179
180 fn def(&mut self, def: Def) -> Result<TypedDef> {
181 match def {
182 Def::FnDef(f) => {
183 let func = self.func(f)?;
184 let typ = func.resolved_type.clone();
185 self.ctx.def_global(&func.name, Type::Fn(typ.typ));
186 Ok(Def::FnDef(func))
187 }
188 }
189 }
190
191 fn program(&mut self, Program { defs, .. }: Program) -> Result<TypedProgram> {
192 let defs = self.try_map(defs, |a, def| a.def(def))?;
193 if let Some(main_def) =
194 defs.iter().position(|d| matches!(d, Def::FnDef(f) if &f.name == "main"))
195 {
196 Ok(Program { main_def, defs })
197 } else {
198 Err(Error::NoMain)
199 }
200 }
201}
202
203pub fn analyze(prog: Program) -> Result<TypedProgram> {
204 Analyzer::default().program(prog)
205}
206
// Unit tests for the analyzer. Typed trees are compared with
// `pretty_assertions::assert_eq` so that mismatches print a readable diff.
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser;
    use crate::scanner;
    use pretty_assertions::assert_eq;

    // Returns a parser over `input`, tokenized by `scanner::scan`.
    fn parse(input: &[u8]) -> parser::Parser<&[u8]> {
        parser::Parser::new(scanner::scan(input))
    }

    // Builds an analyzer with a fresh symbol table (no builtins). The table has
    // entered a function and opened one frame, and `locals` are declared in
    // that frame in iteration order.
    fn with_locals<S: ToString, T: IntoIterator<Item = (S, Type)>>(
        locals: T,
    ) -> Analyzer {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        locals.into_iter().for_each(|(name, typ)| {
            ctx.def_local(name.to_string(), typ);
        });
        Analyzer::with_context(ctx)
    }

    // Whole-program analysis: `main` calls `greet`, which calls `println`.
    // `println` is registered by hand below (not via builtins), so it is
    // global 0 / fn index 0. `greet` then becomes global 1 / fn index 1, and
    // `main` gets fn index 2. `main_def` is the position of `main` in `defs`.
    #[test]
    fn test_hello() {
        let input = b"
        fn greet(name: str) {
            println(name);
        }

        fn main() {
            greet(\"the pope\");
        }
        ";
        let program = parse(input).program().unwrap();
        let expected = Program {
            main_def: 1,
            defs: vec![
                Def::FnDef(Func {
                    name: String::from("greet"),
                    params: vec![Param {
                        name: String::from("name"),
                        typ: TypeSpec::simple("str"),
                        resolved_type: Type::Str,
                    }],
                    ret: None,
                    body: Block(vec![Stmt::Expr(Expr::Call(Call {
                        target: Box::new(Expr::Ident(Ident {
                            name: String::from("println"),
                            resolution: Resolution {
                                reference: Reference::Global { idx: 0 },
                                typ: Type::Fn(FnType {
                                    index: 0,
                                    params: vec![Type::Str],
                                    ret: None,
                                }),
                            },
                        })),
                        args: vec![Expr::Ident(Ident {
                            name: String::from("name"),
                            resolution: Resolution {
                                // The sole parameter occupies stack slot 0.
                                reference: Reference::Stack { local_idx: 0 },
                                typ: Type::Str,
                            },
                        })],
                        // No declared return type resolves to `Void`.
                        resolved_type: Type::Void,
                    }))]),
                    resolved_type: FnDef {
                        typ: FnType { index: 1, params: vec![Type::Str], ret: None },
                        // The body declares no `let` bindings.
                        locals: 0,
                    },
                }),
                Def::FnDef(Func {
                    name: String::from("main"),
                    params: vec![],
                    ret: None,
                    body: Block(vec![Stmt::Expr(Expr::Call(Call {
                        target: Box::new(Expr::Ident(Ident {
                            name: String::from("greet"),
                            resolution: Resolution {
                                reference: Reference::Global { idx: 1 },
                                typ: Type::Fn(FnType {
                                    index: 1,
                                    params: vec![Type::Str],
                                    ret: None,
                                }),
                            },
                        })),
                        args: vec![Expr::Str(Literal::new("the pope"))],
                        resolved_type: Type::Void,
                    }))]),
                    resolved_type: FnDef {
                        typ: FnType { index: 2, params: vec![], ret: None },
                        locals: 0,
                    },
                }),
            ],
        };
        let mut ctx = SymbolTable::default();
        let println = ctx.def_fn(vec![Type::Str], None);
        ctx.def_global("println", Type::Fn(println));
        let actual = Analyzer::with_context(ctx).program(program).unwrap();
        assert_eq!(expected, actual);
    }

    // Single-function analysis with hand-registered globals: itoa (0),
    // join (1), println (2). The parameters take stack slots 0 and 1, so the
    // two `let` bindings get slots 2 and 3. The function's `locals` counts only
    // those two bindings.
    #[test]
    fn test_func() {
        let mut ctx = SymbolTable::default();
        let itoa = ctx.def_fn(vec![Type::Int], Some(Type::Str));
        ctx.def_global("itoa", Type::Fn(itoa));
        let join = ctx.def_fn(vec![Type::Str, Type::Str], Some(Type::Str));
        ctx.def_global("join", Type::Fn(join));
        let println = ctx.def_fn(vec![Type::Str], None);
        ctx.def_global("println", Type::Fn(println));
        let input = b"
        fn greet(name: str, age: int) {
            let age_str: str = itoa(age);
            let greeting: str = join(name, age_str);
            println(greeting);
        }
        ";
        let expected = Func {
            name: String::from("greet"),
            params: vec![
                Param {
                    name: String::from("name"),
                    typ: TypeSpec::simple("str"),
                    resolved_type: Type::Str,
                },
                Param {
                    name: String::from("age"),
                    typ: TypeSpec::simple("int"),
                    resolved_type: Type::Int,
                },
            ],
            ret: None,
            body: Block(vec![
                Stmt::Let(Binding {
                    name: String::from("age_str"),
                    typ: TypeSpec::simple("str"),
                    expr: Expr::Call(Call {
                        target: Box::new(Expr::Ident(Ident {
                            name: String::from("itoa"),
                            resolution: Resolution {
                                reference: Reference::Global { idx: 0 },
                                typ: Type::Fn(FnType {
                                    index: 0,
                                    params: vec![Type::Int],
                                    ret: Some(Box::new(Type::Str)),
                                }),
                            },
                        })),
                        args: vec![Expr::Ident(Ident {
                            name: String::from("age"),
                            resolution: Resolution {
                                reference: Reference::Stack { local_idx: 1 },
                                typ: Type::Int,
                            },
                        })],
                        resolved_type: Type::Str,
                    }),
                    resolved_type: LocalBinding { typ: Type::Str, idx: 2 },
                }),
                Stmt::Let(Binding {
                    name: String::from("greeting"),
                    typ: TypeSpec::simple("str"),
                    expr: Expr::Call(Call {
                        target: Box::new(Expr::Ident(Ident {
                            name: String::from("join"),
                            resolution: Resolution {
                                reference: Reference::Global { idx: 1 },
                                typ: Type::Fn(FnType {
                                    index: 1,
                                    params: vec![Type::Str, Type::Str],
                                    ret: Some(Box::new(Type::Str)),
                                }),
                            },
                        })),
                        args: vec![
                            Expr::Ident(Ident {
                                name: String::from("name"),
                                resolution: Resolution {
                                    reference: Reference::Stack { local_idx: 0 },
                                    typ: Type::Str,
                                },
                            }),
                            Expr::Ident(Ident {
                                name: String::from("age_str"),
                                resolution: Resolution {
                                    reference: Reference::Stack { local_idx: 2 },
                                    typ: Type::Str,
                                },
                            }),
                        ],
                        resolved_type: Type::Str,
                    }),
                    resolved_type: LocalBinding { typ: Type::Str, idx: 3 },
                }),
                Stmt::Expr(Expr::Call(Call {
                    target: Box::new(Expr::Ident(Ident {
                        name: String::from("println"),
                        resolution: Resolution {
                            reference: Reference::Global { idx: 2 },
                            typ: Type::Fn(FnType {
                                index: 2,
                                params: vec![Type::Str],
                                ret: None,
                            }),
                        },
                    })),
                    args: vec![Expr::Ident(Ident {
                        name: String::from("greeting"),
                        resolution: Resolution {
                            reference: Reference::Stack { local_idx: 3 },
                            typ: Type::Str,
                        },
                    })],
                    resolved_type: Type::Void,
                })),
            ]),
            resolved_type: FnDef {
                // Three functions were registered above, so this one is index 3.
                typ: FnType {
                    index: 3,
                    params: vec![Type::Str, Type::Int],
                    ret: None,
                },
                locals: 2,
            },
        };
        let func = parse(input).fn_expr().unwrap();
        let actual = Analyzer::with_context(ctx).func(func).unwrap();
        assert_eq!(expected, actual);
    }

    // Result types of `+`. Each case gets its own analyzer. The last case
    // checks that `str + int` is rejected with both operand types reported.
    #[test]
    fn test_binary() {
        let input: Vec<(Analyzer, &[u8])> = vec![
            (Analyzer::default(), b"14 + 7"),
            (Analyzer::default(), b"\"a\" + \"b\""),
            (with_locals(vec![("x", Type::Int)]), b"x + 7"),
            (with_locals(vec![("x", Type::Str)]), b"x + \"s\""),
            (with_locals(vec![("x", Type::Str)]), b"x + 7"),
        ];
        let expected = vec![
            Ok(Type::Int),
            Ok(Type::Str),
            Ok(Type::Int),
            Ok(Type::Str),
            Err(Error::InvalidOpTypes {
                op: BinaryOp::Add,
                lhs: Type::Str,
                rhs: Type::Int,
            }),
        ];
        let actual: Vec<Result<Type>> = input
            .into_iter()
            .map(|(mut a, s)| a.expr(parse(s).expr().unwrap()))
            .map(|e| e.map(|te| te.typ()))
            .collect();
        assert_eq!(expected, actual);
    }

    // Nested scopes and shadowing. What the expected tree shows:
    // - Every `let` gets a new, increasing slot, even when it shadows an outer
    //   name: the inner `y` is slot 3 while the outer `y` stays slot 1.
    // - Names declared in a block stop resolving once that block closes: the
    //   trailing `x;` and `y;` resolve to the outer slots 0 and 1.
    // `enter_function` is called by hand because `block` only pushes a frame.
    #[test]
    fn test_block() {
        let input = b"{
            let x: int = 7;
            let y: int = x;
            {
                let z: int = y;
                let y: int = x;
                let w: int = y;
                {
                    let x: int = 7;
                }
                x;
            }
            y;
        }";
        let expected = Block(vec![
            Stmt::Let(Binding {
                name: String::from("x"),
                typ: TypeSpec::Simple(String::from("int")),
                expr: Expr::Int(Literal { value: 7 }),
                resolved_type: LocalBinding { typ: Type::Int, idx: 0 },
            }),
            Stmt::Let(Binding {
                name: String::from("y"),
                typ: TypeSpec::Simple(String::from("int")),
                expr: Expr::Ident(Ident {
                    name: String::from("x"),
                    resolution: Resolution {
                        typ: Type::Int,
                        reference: Reference::Stack { local_idx: 0 },
                    },
                }),
                resolved_type: LocalBinding { typ: Type::Int, idx: 1 },
            }),
            Stmt::Block(Block(vec![
                Stmt::Let(Binding {
                    name: String::from("z"),
                    typ: TypeSpec::Simple(String::from("int")),
                    expr: Expr::Ident(Ident {
                        name: String::from("y"),
                        resolution: Resolution {
                            typ: Type::Int,
                            reference: Reference::Stack { local_idx: 1 },
                        },
                    }),
                    resolved_type: LocalBinding { typ: Type::Int, idx: 2 },
                }),
                // Shadows the outer `y` with a fresh slot.
                Stmt::Let(Binding {
                    name: String::from("y"),
                    typ: TypeSpec::Simple(String::from("int")),
                    expr: Expr::Ident(Ident {
                        name: String::from("x"),
                        resolution: Resolution {
                            typ: Type::Int,
                            reference: Reference::Stack { local_idx: 0 },
                        },
                    }),
                    resolved_type: LocalBinding { typ: Type::Int, idx: 3 },
                }),
                // Reads the shadowing `y` (slot 3), not the outer one.
                Stmt::Let(Binding {
                    name: String::from("w"),
                    typ: TypeSpec::Simple(String::from("int")),
                    expr: Expr::Ident(Ident {
                        name: String::from("y"),
                        resolution: Resolution {
                            typ: Type::Int,
                            reference: Reference::Stack { local_idx: 3 },
                        },
                    }),
                    resolved_type: LocalBinding { typ: Type::Int, idx: 4 },
                }),
                Stmt::Block(Block(vec![Stmt::Let(Binding {
                    name: String::from("x"),
                    typ: TypeSpec::Simple(String::from("int")),
                    expr: Expr::Int(Literal { value: 7 }),
                    resolved_type: LocalBinding { typ: Type::Int, idx: 5 },
                })])),
                // The innermost `x` (slot 5) is out of scope here.
                Stmt::Expr(Expr::Ident(Ident {
                    name: String::from("x"),
                    resolution: Resolution {
                        typ: Type::Int,
                        reference: Reference::Stack { local_idx: 0 },
                    },
                })),
            ])),
            // The shadowing `y` (slot 3) is out of scope here.
            Stmt::Expr(Expr::Ident(Ident {
                name: String::from("y"),
                resolution: Resolution {
                    typ: Type::Int,
                    reference: Reference::Stack { local_idx: 1 },
                },
            })),
        ]);
        let block = parse(input).block().unwrap();
        let mut analyzer = Analyzer::default();
        analyzer.ctx.enter_function();
        let actual = analyzer.block(block).unwrap();
        assert_eq!(expected, actual);
    }

    // Parses and analyzes each input with one shared analyzer over `ctx`.
    // Returns only the resulting types (or errors).
    fn analyze_exprs(inputs: &[&[u8]], ctx: SymbolTable) -> Vec<Result<Type>> {
        let mut analyzer = Analyzer::with_context(ctx);
        inputs
            .iter()
            .map(|b| parse(b).expr().unwrap())
            .map(|e| analyzer.expr(e).map(|e| e.typ()))
            .collect()
    }

    // Call checking against a `(int, str) -> void` function stored in a local.
    // The cases are: wrong argument count, wrong argument type (reported for
    // the first mismatching argument), and a valid call.
    #[test]
    fn test_call() {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        ctx.def_local(
            String::from("println"),
            Type::Fn(FnType {
                index: 0,
                params: vec![Type::Int, Type::Str],
                ret: None,
            }),
        );
        let inputs: &[&[u8]] = &[
            b"println(\"foo\")",
            b"println(27, 34)",
            b"println(27, \"foo\")",
        ];
        let expected = vec![
            Err(Error::Arity { want: 2, got: 1 }),
            Err(Error::TypeMismatch { want: Type::Str, got: Type::Int }),
            Ok(Type::Void),
        ];
        let actual: Vec<Result<Type>> = analyze_exprs(inputs, ctx);
        assert_eq!(expected, actual);
    }

    // Identifiers take the type of their declaration; unknown names are
    // reported as `Undefined`.
    #[test]
    fn test_ident() {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        ctx.def_local(String::from("foo"), Type::Int);
        ctx.def_local(String::from("bar"), Type::Str);
        let inputs: &[&[u8]] = &[b"foo", b"bar", b"baz"];
        let expected = vec![
            Ok(Type::Int),
            Ok(Type::Str),
            Err(Error::Undefined(String::from("baz"))),
        ];
        let actual: Vec<Result<Type>> = analyze_exprs(inputs, ctx);
        assert_eq!(expected, actual);
    }

    // Literals need no scope at all, so an empty symbol table suffices.
    #[test]
    fn test_literal() {
        let inputs: &[&[u8]] = &[b"27", b"\"hello, world\""];
        let expected = vec![Ok(Type::Int), Ok(Type::Str)];
        let actual = analyze_exprs(inputs, SymbolTable::default());
        assert_eq!(expected, actual);
    }

    // A program without a `main` function is rejected by `analyze`.
    #[test]
    fn test_no_main() {
        let input = b"fn foo() {}";
        let ast = parse(input).program().unwrap();
        let actual = analyze(ast);
        assert_eq!(Err(Error::NoMain), actual);
    }
}