june_lang/
analyzer.rs

1// TODO: write failure test cases
2use crate::ast::*;
3use crate::builtins;
4use crate::symbol_table::*;
5use crate::types::*;
6use std::result;
7use thiserror::Error;
8
9#[derive(Debug, Error, PartialEq)]
10pub enum Error {
11    #[error("undefined: {0}")]
12    Undefined(String),
13    #[error("wrong number of arguments: want {want}, got {got}")]
14    Arity { want: usize, got: usize },
15    #[error("type mismatch: want {want:?}, got {got:?}")]
16    TypeMismatch { want: Type, got: Type },
17    #[error("not callable: {0:?}")]
18    InvalidCallable(Type),
19    #[error("unknown type: {0}")]
20    UnknownType(String),
21    #[error("invalid types for operator {op:?}: {lhs:?}, {rhs:?}")]
22    InvalidOpTypes { op: BinaryOp, lhs: Type, rhs: Type },
23    #[error("entrypoint `main` undefined")]
24    NoMain,
25}
26
27type Result<T> = result::Result<T, Error>;
28
29fn check_op(op: BinaryOp, lhs: Type, rhs: Type) -> Result<Type> {
30    use BinaryOp::*;
31    use Type::*;
32    match (op, &lhs, &rhs) {
33        (Add, Int, Int) => Ok(Int),
34        (Add, Str, Str) => Ok(Str),
35        _ => Err(Error::InvalidOpTypes { op, lhs, rhs }),
36    }
37}
38
39fn check((want, got): (&Type, &TypedExpr)) -> Result<()> {
40    let got = got.typ();
41    if want != &got {
42        Err(Error::TypeMismatch { want: want.clone(), got })
43    } else {
44        Ok(())
45    }
46}
47
48fn check_all(want: &[Type], got: &[TypedExpr]) -> Result<()> {
49    if want.len() != got.len() {
50        Err(Error::Arity { want: want.len(), got: got.len() })
51    } else {
52        want.iter().zip(got).try_for_each(check)
53    }
54}
55
56pub struct Analyzer {
57    ctx: SymbolTable,
58}
59
60impl Default for Analyzer {
61    fn default() -> Analyzer {
62        let mut ctx = SymbolTable::default();
63        builtins::install_symbols(&mut ctx);
64        Analyzer::with_context(ctx)
65    }
66}
67
68impl Analyzer {
69    fn with_context(ctx: SymbolTable) -> Analyzer {
70        Analyzer { ctx }
71    }
72
73    fn try_map<T, U, F>(&mut self, ts: Vec<T>, f: F) -> Result<Vec<U>>
74    where
75        F: Fn(&mut Self, T) -> Result<U>,
76    {
77        ts.into_iter().map(|t| f(self, t)).collect::<Result<_>>()
78    }
79
80    fn param(&mut self, param: Param) -> Result<TypedParam> {
81        let (name, typ) = (param.name, param.typ);
82        let resolved_type = self.resolve_type(&typ)?;
83        Ok(Param { name, typ, resolved_type })
84    }
85
86    fn func(&mut self, f: Func) -> Result<TypedFunc> {
87        self.ctx.enter_function();
88        let params = self.try_map(f.params, |a, param| a.param(param))?;
89        self.ctx.push_frame();
90        for param in &params {
91            self.ctx.def_local(&param.name, param.resolved_type.clone());
92        }
93        let body = self.block(f.body)?;
94        let ret = f.ret.clone().map(|typ| self.resolve_type(&typ)).transpose()?;
95        // TODO: look for return statements when we handle return types
96        let resolved_type = FnDef {
97            typ: self.ctx.def_fn(params.iter().map(|p| p.typ()).collect(), ret),
98            locals: self.ctx.num_locals().unwrap() - params.len(),
99        };
100        let func = Func { name: f.name, params, body, ret: f.ret, resolved_type };
101        self.ctx.pop_frame();
102        self.ctx.exit_function();
103        Ok(func)
104    }
105
106    fn call(&mut self, call: Call) -> Result<TypedCall> {
107        let target = Box::new(self.expr(*call.target)?);
108        if let Type::Fn(f) = target.typ() {
109            let args = self.try_map(call.args, |a, arg| a.expr(arg))?;
110            check_all(&f.params, &args)?;
111            Ok(Call { target, args, resolved_type: f.ret.into() })
112        } else {
113            Err(Error::InvalidCallable(target.typ()))
114        }
115    }
116
117    fn ident(&mut self, ident: Ident) -> Result<TypedExpr> {
118        let name = ident.name;
119        let resolution =
120            self.ctx.get(&name).ok_or(Error::Undefined(name.clone()))?;
121        Ok(Expr::Ident(Ident { name, resolution }))
122    }
123
124    fn binary(&mut self, expr: Binary) -> Result<TypedExpr> {
125        let lhs = Box::new(self.expr(*expr.lhs)?);
126        let rhs = Box::new(self.expr(*expr.rhs)?);
127        let op = expr.op;
128        let cargo = check_op(op, lhs.typ(), rhs.typ())?;
129        Ok(Expr::Binary(Binary { op, lhs, rhs, cargo }))
130    }
131
132    pub fn expr(&mut self, expr: Expr) -> Result<TypedExpr> {
133        match expr {
134            Expr::Call(call) => Ok(Expr::Call(self.call(call)?)),
135            Expr::Int(prim) => Ok(Expr::Int(prim)),
136            Expr::Str(prim) => Ok(Expr::Str(prim)),
137            Expr::Ident(prim) => self.ident(prim),
138            Expr::Binary(bin) => self.binary(bin),
139        }
140    }
141
142    fn resolve_type(&self, typ: &TypeSpec) -> Result<Type> {
143        // TODO: handle more complex types
144        match typ {
145            TypeSpec::Void => Ok(Type::Void),
146            TypeSpec::Simple(typ) if "int" == typ => Ok(Type::Int),
147            TypeSpec::Simple(typ) if "str" == typ => Ok(Type::Str),
148            TypeSpec::Simple(typ) => Err(Error::UnknownType(typ.into())),
149        }
150    }
151
152    fn let_stmt(&mut self, stmt: Binding) -> Result<TypedStmt> {
153        let typ = self.resolve_type(&stmt.typ)?;
154        let expr = self.expr(stmt.expr)?;
155        check((&typ, &expr))?;
156        let idx = self.ctx.def_local(&stmt.name, typ.clone());
157        Ok(Stmt::Let(Binding::new(
158            stmt.name,
159            stmt.typ,
160            expr,
161            LocalBinding { typ, idx },
162        )))
163    }
164
165    fn stmt(&mut self, stmt: Stmt) -> Result<TypedStmt> {
166        match stmt {
167            Stmt::Expr(expr) => Ok(Stmt::Expr(self.expr(expr)?)),
168            Stmt::Let(stmt) => self.let_stmt(stmt),
169            Stmt::Block(block) => Ok(Stmt::Block(self.block(block)?)),
170        }
171    }
172
173    fn block(&mut self, Block(stmts): Block) -> Result<TypedBlock> {
174        self.ctx.push_frame();
175        let stmts = self.try_map(stmts, |a, stmt| a.stmt(stmt))?;
176        self.ctx.pop_frame();
177        Ok(Block(stmts))
178    }
179
180    fn def(&mut self, def: Def) -> Result<TypedDef> {
181        match def {
182            Def::FnDef(f) => {
183                let func = self.func(f)?;
184                let typ = func.resolved_type.clone();
185                self.ctx.def_global(&func.name, Type::Fn(typ.typ));
186                Ok(Def::FnDef(func))
187            }
188        }
189    }
190
191    fn program(&mut self, Program { defs, .. }: Program) -> Result<TypedProgram> {
192        let defs = self.try_map(defs, |a, def| a.def(def))?;
193        if let Some(main_def) =
194            defs.iter().position(|d| matches!(d, Def::FnDef(f) if &f.name == "main"))
195        {
196            Ok(Program { main_def, defs })
197        } else {
198            Err(Error::NoMain)
199        }
200    }
201}
202
203pub fn analyze(prog: Program) -> Result<TypedProgram> {
204    Analyzer::default().program(prog)
205}
206
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser;
    use crate::scanner;
    use pretty_assertions::assert_eq;

    // Scans `input` and wraps the resulting tokens in a parser.
    fn parse(input: &[u8]) -> parser::Parser<&[u8]> {
        let tokens = scanner::scan(input);
        parser::Parser::new(tokens)
    }

    // Builds an analyzer positioned inside a function whose single frame
    // already contains the given locals, defined in iteration order.
    fn with_locals<S: ToString, T: IntoIterator<Item = (S, Type)>>(
        locals: T,
    ) -> Analyzer {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        for (name, typ) in locals {
            ctx.def_local(name.to_string(), typ);
        }
        Analyzer::with_context(ctx)
    }

    #[test]
    fn test_hello() {
        let input = b"
            fn greet(name: str) {
                println(name);
            }

            fn main() {
                greet(\"the pope\");
            }
        ";
        let program = parse(input).program().unwrap();

        // Expected function signatures, by expected index.
        let println_ty = FnType { index: 0, params: vec![Type::Str], ret: None };
        let greet_ty = FnType { index: 1, params: vec![Type::Str], ret: None };
        let main_ty = FnType { index: 2, params: vec![], ret: None };

        // `greet` forwards its only parameter to the global `println`.
        let call_println = Expr::Call(Call {
            target: Box::new(Expr::Ident(Ident {
                name: "println".to_string(),
                resolution: Resolution {
                    reference: Reference::Global { idx: 0 },
                    typ: Type::Fn(println_ty),
                },
            })),
            args: vec![Expr::Ident(Ident {
                name: "name".to_string(),
                resolution: Resolution {
                    reference: Reference::Stack { local_idx: 0 },
                    typ: Type::Str,
                },
            })],
            resolved_type: Type::Void,
        });
        let greet = Func {
            name: "greet".to_string(),
            params: vec![Param {
                name: "name".to_string(),
                typ: TypeSpec::simple("str"),
                resolved_type: Type::Str,
            }],
            ret: None,
            body: Block(vec![Stmt::Expr(call_println)]),
            resolved_type: FnDef { typ: greet_ty.clone(), locals: 0 },
        };

        // `main` calls the global `greet` with a string literal.
        let call_greet = Expr::Call(Call {
            target: Box::new(Expr::Ident(Ident {
                name: "greet".to_string(),
                resolution: Resolution {
                    reference: Reference::Global { idx: 1 },
                    typ: Type::Fn(greet_ty),
                },
            })),
            args: vec![Expr::Str(Literal::new("the pope"))],
            resolved_type: Type::Void,
        });
        let main = Func {
            name: "main".to_string(),
            params: vec![],
            ret: None,
            body: Block(vec![Stmt::Expr(call_greet)]),
            resolved_type: FnDef { typ: main_ty, locals: 0 },
        };

        let expected = Program {
            main_def: 1,
            defs: vec![Def::FnDef(greet), Def::FnDef(main)],
        };

        let mut ctx = SymbolTable::default();
        let println_fn = ctx.def_fn(vec![Type::Str], None);
        ctx.def_global("println", Type::Fn(println_fn));
        let actual = Analyzer::with_context(ctx).program(program).unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_func() {
        // Globals are registered in the order their expected indices assume.
        let mut ctx = SymbolTable::default();
        let itoa_fn = ctx.def_fn(vec![Type::Int], Some(Type::Str));
        ctx.def_global("itoa", Type::Fn(itoa_fn));
        let join_fn = ctx.def_fn(vec![Type::Str, Type::Str], Some(Type::Str));
        ctx.def_global("join", Type::Fn(join_fn));
        let println_fn = ctx.def_fn(vec![Type::Str], None);
        ctx.def_global("println", Type::Fn(println_fn));
        let input = b"
            fn greet(name: str, age: int) {
                let age_str: str = itoa(age);
                let greeting: str = join(name, age_str);
                println(greeting);
            }
        ";

        // let age_str: str = itoa(age);
        let let_age_str = Stmt::Let(Binding {
            name: "age_str".to_string(),
            typ: TypeSpec::simple("str"),
            expr: Expr::Call(Call {
                target: Box::new(Expr::Ident(Ident {
                    name: "itoa".to_string(),
                    resolution: Resolution {
                        reference: Reference::Global { idx: 0 },
                        typ: Type::Fn(FnType {
                            index: 0,
                            params: vec![Type::Int],
                            ret: Some(Box::new(Type::Str)),
                        }),
                    },
                })),
                args: vec![Expr::Ident(Ident {
                    name: "age".to_string(),
                    resolution: Resolution {
                        reference: Reference::Stack { local_idx: 1 },
                        typ: Type::Int,
                    },
                })],
                resolved_type: Type::Str,
            }),
            resolved_type: LocalBinding { typ: Type::Str, idx: 2 },
        });

        // let greeting: str = join(name, age_str);
        let let_greeting = Stmt::Let(Binding {
            name: "greeting".to_string(),
            typ: TypeSpec::simple("str"),
            expr: Expr::Call(Call {
                target: Box::new(Expr::Ident(Ident {
                    name: "join".to_string(),
                    resolution: Resolution {
                        reference: Reference::Global { idx: 1 },
                        typ: Type::Fn(FnType {
                            index: 1,
                            params: vec![Type::Str, Type::Str],
                            ret: Some(Box::new(Type::Str)),
                        }),
                    },
                })),
                args: vec![
                    Expr::Ident(Ident {
                        name: "name".to_string(),
                        resolution: Resolution {
                            reference: Reference::Stack { local_idx: 0 },
                            typ: Type::Str,
                        },
                    }),
                    Expr::Ident(Ident {
                        name: "age_str".to_string(),
                        resolution: Resolution {
                            reference: Reference::Stack { local_idx: 2 },
                            typ: Type::Str,
                        },
                    }),
                ],
                resolved_type: Type::Str,
            }),
            resolved_type: LocalBinding { typ: Type::Str, idx: 3 },
        });

        // println(greeting);
        let print_greeting = Stmt::Expr(Expr::Call(Call {
            target: Box::new(Expr::Ident(Ident {
                name: "println".to_string(),
                resolution: Resolution {
                    reference: Reference::Global { idx: 2 },
                    typ: Type::Fn(FnType {
                        index: 2,
                        params: vec![Type::Str],
                        ret: None,
                    }),
                },
            })),
            args: vec![Expr::Ident(Ident {
                name: "greeting".to_string(),
                resolution: Resolution {
                    reference: Reference::Stack { local_idx: 3 },
                    typ: Type::Str,
                },
            })],
            resolved_type: Type::Void,
        }));

        let expected = Func {
            name: "greet".to_string(),
            params: vec![
                Param {
                    name: "name".to_string(),
                    typ: TypeSpec::simple("str"),
                    resolved_type: Type::Str,
                },
                Param {
                    name: "age".to_string(),
                    typ: TypeSpec::simple("int"),
                    resolved_type: Type::Int,
                },
            ],
            ret: None,
            body: Block(vec![let_age_str, let_greeting, print_greeting]),
            resolved_type: FnDef {
                typ: FnType {
                    index: 3,
                    params: vec![Type::Str, Type::Int],
                    ret: None,
                },
                locals: 2,
            },
        };
        let func = parse(input).fn_expr().unwrap();
        let actual = Analyzer::with_context(ctx).func(func).unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_binary() {
        // Each case gets its own analyzer so locals do not leak between them.
        let cases: Vec<(Analyzer, &[u8])> = vec![
            (Analyzer::default(), b"14 + 7"),
            (Analyzer::default(), b"\"a\" + \"b\""),
            (with_locals(vec![("x", Type::Int)]), b"x + 7"),
            (with_locals(vec![("x", Type::Str)]), b"x + \"s\""),
            (with_locals(vec![("x", Type::Str)]), b"x + 7"),
        ];
        let expected = vec![
            Ok(Type::Int),
            Ok(Type::Str),
            Ok(Type::Int),
            Ok(Type::Str),
            Err(Error::InvalidOpTypes {
                op: BinaryOp::Add,
                lhs: Type::Str,
                rhs: Type::Int,
            }),
        ];
        let mut actual: Vec<Result<Type>> = Vec::with_capacity(cases.len());
        for (mut analyzer, src) in cases {
            let ast = parse(src).expr().unwrap();
            actual.push(analyzer.expr(ast).map(|typed| typed.typ()));
        }
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_block() {
        let input = b"{
            let x: int = 7;
            let y: int = x;
            {
                let z: int = y;
                let y: int = x;
                let w: int = y;
                {
                    let x: int = 7;
                }
                x;
            }
            y;
        }";

        // Reference to an `int` local living at stack slot `local_idx`.
        let int_local = |name: &str, local_idx| -> TypedExpr {
            Expr::Ident(Ident {
                name: name.to_string(),
                resolution: Resolution {
                    typ: Type::Int,
                    reference: Reference::Stack { local_idx },
                },
            })
        };
        // `let <name>: int = <expr>;` bound to local slot `idx`.
        let let_int = |name: &str, expr: TypedExpr, idx| -> TypedStmt {
            Stmt::Let(Binding {
                name: name.to_string(),
                typ: TypeSpec::Simple("int".to_string()),
                expr,
                resolved_type: LocalBinding { typ: Type::Int, idx },
            })
        };

        let innermost = Stmt::Block(Block(vec![let_int(
            "x",
            Expr::Int(Literal { value: 7 }),
            5,
        )]));
        let inner = Stmt::Block(Block(vec![
            let_int("z", int_local("y", 1), 2),
            let_int("y", int_local("x", 0), 3),
            let_int("w", int_local("y", 3), 4),
            innermost,
            Stmt::Expr(int_local("x", 0)),
        ]));
        let expected = Block(vec![
            let_int("x", Expr::Int(Literal { value: 7 }), 0),
            let_int("y", int_local("x", 0), 1),
            inner,
            Stmt::Expr(int_local("y", 1)),
        ]);

        let block = parse(input).block().unwrap();
        let mut analyzer = Analyzer::default();
        analyzer.ctx.enter_function();
        let actual = analyzer.block(block).unwrap();
        assert_eq!(expected, actual);
    }

    // Analyzes each input as an expression with one shared analyzer and
    // returns the resulting types (or errors) in input order.
    fn analyze_exprs(inputs: &[&[u8]], ctx: SymbolTable) -> Vec<Result<Type>> {
        let mut analyzer = Analyzer::with_context(ctx);
        let mut types = Vec::with_capacity(inputs.len());
        for src in inputs {
            let ast = parse(src).expr().unwrap();
            types.push(analyzer.expr(ast).map(|typed| typed.typ()));
        }
        types
    }

    #[test]
    fn test_call() {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        // A local named `println` taking (int, str).
        let println_ty = Type::Fn(FnType {
            index: 0,
            params: vec![Type::Int, Type::Str],
            ret: None,
        });
        ctx.def_local("println".to_string(), println_ty);
        let inputs: &[&[u8]] = &[
            b"println(\"foo\")",
            b"println(27, 34)",
            b"println(27, \"foo\")",
        ];
        let expected = vec![
            Err(Error::Arity { want: 2, got: 1 }),
            Err(Error::TypeMismatch { want: Type::Str, got: Type::Int }),
            Ok(Type::Void),
        ];
        let actual: Vec<Result<Type>> = analyze_exprs(inputs, ctx);
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_ident() {
        let mut ctx = SymbolTable::default();
        ctx.enter_function();
        ctx.push_frame();
        ctx.def_local("foo".to_string(), Type::Int);
        ctx.def_local("bar".to_string(), Type::Str);
        let inputs: &[&[u8]] = &[b"foo", b"bar", b"baz"];
        let expected = vec![
            Ok(Type::Int),
            Ok(Type::Str),
            Err(Error::Undefined("baz".to_string())),
        ];
        let actual: Vec<Result<Type>> = analyze_exprs(inputs, ctx);
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_literal() {
        let inputs: &[&[u8]] = &[b"27", b"\"hello, world\""];
        let expected = vec![Ok(Type::Int), Ok(Type::Str)];
        let actual = analyze_exprs(inputs, SymbolTable::default());
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_no_main() {
        let ast = parse(b"fn foo() {}").program().unwrap();
        let actual = analyze(ast);
        assert_eq!(Err(Error::NoMain), actual);
    }
}
636        assert_eq!(Err(Error::NoMain), actual);
637    }
638}