mini_lang/
ir.rs

1use crate::{parser, MiniError, MiniResult};
2use std::collections::HashMap;
3
4pub use parser::Operator;
5
6/// List of define functions, variables, and expressions to print.
7#[derive(Clone, Debug, PartialEq, Eq)]
8pub struct Program {
9    pub funcs: Vec<Expr>,
10    pub vars: Vec<Expr>,
11    pub prints: Vec<Expr>,
12}
13
14/// The expression tree.
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub enum Expr {
17    /// The literal value.
18    Value(i32),
19    /// The variable's scope depth (specified by `expr.circulate`), and index (in `program.vars`).
20    Variable(usize, usize),
21    /// The operator, and left side value, and right side value.
22    Operation(Operator, Box<Expr>, Box<Expr>),
23    /// The function index (in `program.funcs`) and list of arguments (number integrity is already verified.)
24    FuncCall(usize, Vec<Expr>),
25    /// The condition, the expression evaluated if condition is true, and false.
26    If(Box<Expr>, Box<Expr>, Box<Expr>),
27}
28
29impl Expr {
30    fn from_ast(
31        e: parser::Expr,
32        ns_vars: &HashMap<String, usize>,
33        ns_funcs: &HashMap<String, (usize, usize)>,
34    ) -> MiniResult<Self> {
35        Ok(match e {
36            parser::Expr::Value(v) => Self::Value(v),
37            parser::Expr::Variable(s) => {
38                Self::Variable(0, *ns_vars.get(&s).ok_or("Using undefined variable.")?)
39            }
40            parser::Expr::Operation(op, lhs, rhs) => Self::Operation(
41                op,
42                Box::new(Self::from_ast(*lhs, ns_vars, ns_funcs)?),
43                Box::new(Self::from_ast(*rhs, ns_vars, ns_funcs)?),
44            ),
45            parser::Expr::FuncCall(s, e) => {
46                let (id, args) = *ns_funcs.get(&s).ok_or("Using undefined function.")?;
47                if e.len() != args {
48                    return Err(MiniError::from("Illegal arguments."));
49                }
50                Self::FuncCall(
51                    id,
52                    e.into_iter()
53                        .map(|e| Self::from_ast(e, ns_vars, ns_funcs))
54                        .collect::<Result<Vec<_>, _>>()?,
55                )
56            }
57            parser::Expr::If(c, t, f) => Self::If(
58                Box::new(Self::from_ast(*c, ns_vars, ns_funcs)?),
59                Box::new(Self::from_ast(*t, ns_vars, ns_funcs)?),
60                Box::new(Self::from_ast(*f, ns_vars, ns_funcs)?),
61            ),
62        })
63    }
64
65    /// Circulate the variable's scope depth recursively.
66    pub fn circulate(self, depth: usize) -> Self {
67        match self {
68            Self::Value(v) => Self::Value(v),
69            Self::Variable(_, id) => Self::Variable(depth, id),
70            Self::Operation(op, lhs, rhs) => Self::Operation(
71                op,
72                Box::new((*lhs).circulate(depth)),
73                Box::new((*rhs).circulate(depth)),
74            ),
75            Self::FuncCall(id, args) => {
76                Self::FuncCall(id, args.into_iter().map(|e| e.circulate(depth)).collect())
77            }
78            Self::If(c, t, f) => Self::If(
79                Box::new((*c).circulate(depth)),
80                Box::new((*t).circulate(depth)),
81                Box::new((*f).circulate(depth)),
82            ),
83        }
84    }
85}
86
87pub fn compile(ast: parser::Ast) -> MiniResult<Program> {
88    let mut vars = Vec::new();
89    let mut ns_vars = HashMap::new();
90    let mut funcs = Vec::new();
91    let mut ns_funcs = HashMap::new();
92    let mut prints = Vec::new();
93    for stmt in ast {
94        match stmt {
95            parser::Stmt::Binding(v, e) => {
96                let id = vars.len();
97                vars.push(Expr::from_ast(e, &ns_vars, &ns_funcs)?);
98                ns_vars.insert(v, id);
99            }
100            parser::Stmt::Print(e) => {
101                prints.push(Expr::from_ast(e, &ns_vars, &ns_funcs)?);
102            }
103            parser::Stmt::Define(f, a, e) => {
104                let id = vars.len();
105                let args = a.len();
106                ns_funcs.insert(f, (id, args));
107                let local_vars = a.into_iter().enumerate().map(|(i, s)| (s, i)).collect();
108                funcs.push(Expr::from_ast(e, &local_vars, &ns_funcs)?);
109            }
110        }
111    }
112
113    Ok(Program {
114        vars,
115        funcs,
116        prints,
117    })
118}