deskc_evalmir/
eval_mir.rs

1use std::collections::HashMap;
2
3use mir::mir::Mir;
4use mir::stmt::Stmt;
5use mir::ty::ConcEffect;
6use mir::BlockId;
7
8use crate::const_stmt;
9
10use crate::value::{Closure, FnRef, Value};
11use mir::{ty::ConcType, StmtBind, VarId};
12
13#[cfg_attr(feature = "withserde", derive(serde::Serialize, serde::Deserialize))]
14#[derive(Debug, Clone, PartialEq)]
15pub struct EvalMir {
16    pub mir: Mir,
17    pub registers: HashMap<VarId, Value>,
18    pub parameters: HashMap<ConcType, Value>,
19    pub captured: HashMap<ConcType, Value>,
20    pub pc_block: BlockId,
21    pub pc_stmt_idx: usize,
22    // Before handling apply stmt, save the var to here, and used when returned.
23    pub return_register: Option<VarId>,
24    pub handlers: HashMap<ConcEffect, Handler>,
25}
26
27#[cfg_attr(feature = "withserde", derive(serde::Serialize, serde::Deserialize))]
28#[derive(Debug, Clone, PartialEq)]
29pub enum Handler {
30    Handler(Closure),
31    Continuation(Vec<EvalMir>),
32}
33
34#[derive(Debug, Clone, PartialEq)]
35pub enum InnerOutput {
36    Return(Value),
37    Perform {
38        input: Value,
39        effect: ConcEffect,
40    },
41    RunOther {
42        fn_ref: FnRef,
43        parameters: HashMap<ConcType, Value>,
44    },
45    Running,
46}
47
48impl EvalMir {
49    pub fn eval_next(&mut self) -> InnerOutput {
50        let block = &self.mir.blocks[self.pc_block.0];
51        // if reach to terminator
52        if block.stmts.len() == self.pc_stmt_idx {
53            match &block.terminator {
54                mir::ATerminator::Return(var) => InnerOutput::Return(
55                    (&mut self.registers)
56                        .remove(var)
57                        .expect("return value should be exists"),
58                ),
59                mir::ATerminator::Match { var, cases } => {
60                    let value = self.load_value(var);
61                    if let Value::Variant { id, value: _ } = value {
62                        let case = cases.iter().find(|c| c.ty == *id).unwrap();
63                        self.pc_block = case.next;
64                        self.pc_stmt_idx = 0;
65                        InnerOutput::Running
66                    } else {
67                        panic!("should be variant")
68                    }
69                }
70                mir::ATerminator::Goto(next) => {
71                    self.pc_block = *next;
72                    self.pc_stmt_idx = 0;
73                    InnerOutput::Running
74                }
75            }
76        } else {
77            let StmtBind {
78                var: bind_var,
79                stmt,
80            } = &block.stmts[self.pc_stmt_idx];
81            let value = match stmt {
82                Stmt::Const(const_value) => const_stmt::eval(const_value),
83                Stmt::Tuple(values) => Value::Tuple(
84                    values
85                        .iter()
86                        .map(|var| self.load_value(var).clone())
87                        .collect(),
88                ),
89                Stmt::Array(_) => todo!(),
90                Stmt::Set(_) => todo!(),
91                Stmt::Fn(fn_ref) => {
92                    let fn_ref = match fn_ref {
93                        mir::stmt::FnRef::Link(_) => todo!(),
94                        mir::stmt::FnRef::Clojure {
95                            amir,
96                            captured,
97                            handlers,
98                        } => FnRef::Closure(Closure {
99                            mir: *amir,
100                            captured: captured
101                                .iter()
102                                .map(|var| {
103                                    (self.get_var_ty(var).clone(), self.load_value(var).clone())
104                                })
105                                .collect(),
106                            handlers: handlers
107                                .iter()
108                                .map(|(effect, handler)| {
109                                    (
110                                        effect.clone(),
111                                        if let Value::FnRef(FnRef::Closure(closure)) =
112                                            self.load_value(handler).clone()
113                                        {
114                                            Handler::Handler(closure)
115                                        } else {
116                                            panic!("handler must be FnRef::Closure")
117                                        },
118                                    )
119                                })
120                                .collect(),
121                        }),
122                    };
123                    Value::FnRef(fn_ref)
124                }
125                Stmt::Perform(var) => {
126                    // Save the return register to get result from continuation.
127                    self.return_register = Some(*bind_var);
128                    if let ConcType::Effectful {
129                        ty: output,
130                        effects: _,
131                    } = self.get_var_ty(bind_var)
132                    {
133                        let effect = ConcEffect {
134                            input: self.get_var_ty(var).clone(),
135                            output: *output.clone(),
136                        };
137                        // Increment pc before perform is important
138                        self.pc_stmt_idx += 1;
139                        return InnerOutput::Perform {
140                            input: self.load_value(var).clone(),
141                            effect,
142                        };
143                    } else {
144                        panic!("type should be effectful")
145                    }
146                }
147                Stmt::Apply {
148                    function,
149                    arguments,
150                } => {
151                    if let Value::FnRef(fn_ref) = self.registers.get(function).cloned().unwrap() {
152                        let mut parameters = HashMap::new();
153                        arguments.iter().for_each(|arg| {
154                            let ty = self.get_var_ty(arg).clone();
155                            let value = self.load_value(arg).clone();
156                            parameters.insert(ty, value);
157                        });
158                        // Save the return register.
159                        self.return_register = Some(*bind_var);
160                        // Increment pc before return output is important
161                        self.pc_stmt_idx += 1;
162                        return InnerOutput::RunOther { fn_ref, parameters };
163                    } else {
164                        panic!("fn_ref");
165                    }
166                }
167                Stmt::Op { op, operands } => self.eval_op(op, operands),
168                Stmt::Ref(_) => todo!(),
169                Stmt::RefMut(_) => todo!(),
170                Stmt::Index { tuple: _, index: _ } => todo!(),
171                // TODO remove old one because move
172                Stmt::Move(x) => self.load_value(x).clone(),
173                Stmt::Variant { id, value } => Value::Variant {
174                    id: *id,
175                    value: Box::new(self.load_value(value).clone()),
176                },
177                Stmt::Parameter => {
178                    let ty = &self.mir.vars.get(bind_var).ty;
179                    self.parameters
180                        .get(ty)
181                        .or_else(|| self.captured.get(ty))
182                        .unwrap_or_else(|| {
183                            panic!("parameter must be exist {:?} in {:?}", ty, self.parameters)
184                        })
185                        .clone()
186                }
187                Stmt::Recursion => Value::FnRef(FnRef::Recursion),
188                Stmt::Link(_link) => {
189                    todo!()
190                }
191            };
192            let var = *bind_var;
193            self.store_value(var, value);
194            self.pc_stmt_idx += 1;
195            InnerOutput::Running
196        }
197    }
198
199    // After perform, continue with this function.
200    pub fn eval_continue(&mut self, _output: Value) {
201        todo!()
202    }
203
204    // After call another mir, continue with this function.
205    pub fn return_or_continue_with_value(&mut self, ret: Value) {
206        let var = self.return_register.take().expect("needs return register");
207        self.store_value(var, ret);
208    }
209
210    pub fn load_value(&self, var: &VarId) -> &Value {
211        self.registers.get(var).unwrap()
212    }
213
214    pub fn store_value(&mut self, var: VarId, value: Value) {
215        self.registers.insert(var, value);
216    }
217
218    pub fn get_var_ty(&self, var: &VarId) -> &ConcType {
219        &self.mir.vars.get(var).ty
220    }
221}
222
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use mir::{
        mir::{BasicBlock, Var},
        stmt::Stmt,
        ATerminator, BlockId, Const, Scope, ScopeId, StmtBind, Vars,
    };

    use super::*;

    /// Builds a single-scope MIR with a `Number` output whose `var_count`
    /// variables are all `Number`s.
    fn number_mir(var_count: usize, blocks: Vec<BasicBlock>) -> Mir {
        Mir {
            parameters: vec![],
            output: ConcType::Number,
            vars: Vars(
                (0..var_count)
                    .map(|_| Var {
                        ty: ConcType::Number,
                        scope: ScopeId(0),
                    })
                    .collect(),
            ),
            scopes: vec![Scope { super_scope: None }],
            blocks,
            captured: vec![],
            links: vec![],
        }
    }

    /// Creates an evaluator positioned at the first statement of block 0,
    /// with no parameters, captured values, or handlers.
    fn start(mir: Mir) -> EvalMir {
        EvalMir {
            mir,
            pc_block: BlockId(0),
            pc_stmt_idx: 0,
            registers: HashMap::new(),
            parameters: HashMap::new(),
            captured: HashMap::new(),
            return_register: None,
            handlers: HashMap::new(),
        }
    }

    #[test]
    fn literal() {
        let mut eval = start(number_mir(
            1,
            vec![BasicBlock {
                stmts: vec![StmtBind {
                    var: VarId(0),
                    stmt: Stmt::Const(Const::Int(1)),
                }],
                terminator: ATerminator::Return(VarId(0)),
            }],
        ));

        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.eval_next(), InnerOutput::Return(Value::Int(1)));
    }

    #[test]
    fn goto_jumps_to_target_block() {
        let mut eval = start(number_mir(
            1,
            vec![
                BasicBlock {
                    stmts: vec![StmtBind {
                        var: VarId(0),
                        stmt: Stmt::Const(Const::Int(1)),
                    }],
                    terminator: ATerminator::Goto(BlockId(1)),
                },
                BasicBlock {
                    stmts: vec![],
                    terminator: ATerminator::Return(VarId(0)),
                },
            ],
        ));

        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!((eval.pc_block, eval.pc_stmt_idx), (BlockId(1), 0));
        assert_eq!(eval.eval_next(), InnerOutput::Return(Value::Int(1)));
    }

    #[test]
    fn move_binds_value_to_new_register() {
        let mut eval = start(number_mir(
            2,
            vec![BasicBlock {
                stmts: vec![
                    StmtBind {
                        var: VarId(0),
                        stmt: Stmt::Const(Const::Int(7)),
                    },
                    StmtBind {
                        var: VarId(1),
                        stmt: Stmt::Move(VarId(0)),
                    },
                ],
                terminator: ATerminator::Return(VarId(1)),
            }],
        ));

        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.eval_next(), InnerOutput::Return(Value::Int(7)));
    }

    // TODO: exercise `Stmt::Op` (builtin operators). Ignored rather than left
    // as an empty passing test so the missing coverage is visible in output.
    #[test]
    #[ignore]
    fn builtin() {}
}