1use std::collections::HashMap;
2
3use mir::mir::Mir;
4use mir::stmt::Stmt;
5use mir::ty::ConcEffect;
6use mir::BlockId;
7
8use crate::const_stmt;
9
10use crate::value::{Closure, FnRef, Value};
11use mir::{ty::ConcType, StmtBind, VarId};
12
/// Step-by-step evaluation state of one MIR function body.
///
/// The program counter is the pair (`pc_block`, `pc_stmt_idx`), and the values of
/// MIR variables live in `registers`. Evaluation is driven one step at a time by
/// [`EvalMir::eval_next`].
#[cfg_attr(feature = "withserde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct EvalMir {
    /// The MIR body being evaluated.
    pub mir: Mir,
    /// Current value of each MIR variable, keyed by variable id.
    pub registers: HashMap<VarId, Value>,
    /// Values returned by `Stmt::Parameter`, looked up by the bound variable's type.
    pub parameters: HashMap<ConcType, Value>,
    /// Fallback for `Stmt::Parameter` when the type is not in `parameters`
    /// (presumably the values captured by a closure — TODO confirm).
    pub captured: HashMap<ConcType, Value>,
    /// Index of the basic block currently executing.
    pub pc_block: BlockId,
    /// Index of the next statement to run in `pc_block`; when it equals the
    /// block's statement count, the terminator runs next.
    pub pc_stmt_idx: usize,
    /// Variable that receives the result of a pending `Perform` or `Apply` step.
    pub return_register: Option<VarId>,
    /// Handlers keyed by the effect they are registered for.
    pub handlers: HashMap<ConcEffect, Handler>,
}
26
/// An entry in [`EvalMir::handlers`].
#[cfg_attr(feature = "withserde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum Handler {
    /// A closure that handles the effect.
    Handler(Closure),
    /// A saved stack of evaluation frames (presumably a continuation to resume
    /// after a handler finishes — TODO confirm against the driver).
    Continuation(Vec<EvalMir>),
}
33
/// Outcome of one evaluation step, as returned by [`EvalMir::eval_next`].
#[derive(Debug, Clone, PartialEq)]
pub enum InnerOutput {
    /// A `Return` terminator was reached with this value.
    Return(Value),
    /// A `Perform` statement requested `effect` with `input`; the result is meant
    /// to be written to the register recorded in `EvalMir::return_register`.
    Perform {
        input: Value,
        effect: ConcEffect,
    },
    /// An `Apply` statement requested evaluation of `fn_ref` with `parameters`
    /// keyed by argument type; the result is meant to be written to the register
    /// recorded in `EvalMir::return_register`.
    RunOther {
        fn_ref: FnRef,
        parameters: HashMap<ConcType, Value>,
    },
    /// The step completed locally; `eval_next` can be called again.
    Running,
}
47
48impl EvalMir {
49 pub fn eval_next(&mut self) -> InnerOutput {
50 let block = &self.mir.blocks[self.pc_block.0];
51 if block.stmts.len() == self.pc_stmt_idx {
53 match &block.terminator {
54 mir::ATerminator::Return(var) => InnerOutput::Return(
55 (&mut self.registers)
56 .remove(var)
57 .expect("return value should be exists"),
58 ),
59 mir::ATerminator::Match { var, cases } => {
60 let value = self.load_value(var);
61 if let Value::Variant { id, value: _ } = value {
62 let case = cases.iter().find(|c| c.ty == *id).unwrap();
63 self.pc_block = case.next;
64 self.pc_stmt_idx = 0;
65 InnerOutput::Running
66 } else {
67 panic!("should be variant")
68 }
69 }
70 mir::ATerminator::Goto(next) => {
71 self.pc_block = *next;
72 self.pc_stmt_idx = 0;
73 InnerOutput::Running
74 }
75 }
76 } else {
77 let StmtBind {
78 var: bind_var,
79 stmt,
80 } = &block.stmts[self.pc_stmt_idx];
81 let value = match stmt {
82 Stmt::Const(const_value) => const_stmt::eval(const_value),
83 Stmt::Tuple(values) => Value::Tuple(
84 values
85 .iter()
86 .map(|var| self.load_value(var).clone())
87 .collect(),
88 ),
89 Stmt::Array(_) => todo!(),
90 Stmt::Set(_) => todo!(),
91 Stmt::Fn(fn_ref) => {
92 let fn_ref = match fn_ref {
93 mir::stmt::FnRef::Link(_) => todo!(),
94 mir::stmt::FnRef::Clojure {
95 amir,
96 captured,
97 handlers,
98 } => FnRef::Closure(Closure {
99 mir: *amir,
100 captured: captured
101 .iter()
102 .map(|var| {
103 (self.get_var_ty(var).clone(), self.load_value(var).clone())
104 })
105 .collect(),
106 handlers: handlers
107 .iter()
108 .map(|(effect, handler)| {
109 (
110 effect.clone(),
111 if let Value::FnRef(FnRef::Closure(closure)) =
112 self.load_value(handler).clone()
113 {
114 Handler::Handler(closure)
115 } else {
116 panic!("handler must be FnRef::Closure")
117 },
118 )
119 })
120 .collect(),
121 }),
122 };
123 Value::FnRef(fn_ref)
124 }
125 Stmt::Perform(var) => {
126 self.return_register = Some(*bind_var);
128 if let ConcType::Effectful {
129 ty: output,
130 effects: _,
131 } = self.get_var_ty(bind_var)
132 {
133 let effect = ConcEffect {
134 input: self.get_var_ty(var).clone(),
135 output: *output.clone(),
136 };
137 self.pc_stmt_idx += 1;
139 return InnerOutput::Perform {
140 input: self.load_value(var).clone(),
141 effect,
142 };
143 } else {
144 panic!("type should be effectful")
145 }
146 }
147 Stmt::Apply {
148 function,
149 arguments,
150 } => {
151 if let Value::FnRef(fn_ref) = self.registers.get(function).cloned().unwrap() {
152 let mut parameters = HashMap::new();
153 arguments.iter().for_each(|arg| {
154 let ty = self.get_var_ty(arg).clone();
155 let value = self.load_value(arg).clone();
156 parameters.insert(ty, value);
157 });
158 self.return_register = Some(*bind_var);
160 self.pc_stmt_idx += 1;
162 return InnerOutput::RunOther { fn_ref, parameters };
163 } else {
164 panic!("fn_ref");
165 }
166 }
167 Stmt::Op { op, operands } => self.eval_op(op, operands),
168 Stmt::Ref(_) => todo!(),
169 Stmt::RefMut(_) => todo!(),
170 Stmt::Index { tuple: _, index: _ } => todo!(),
171 Stmt::Move(x) => self.load_value(x).clone(),
173 Stmt::Variant { id, value } => Value::Variant {
174 id: *id,
175 value: Box::new(self.load_value(value).clone()),
176 },
177 Stmt::Parameter => {
178 let ty = &self.mir.vars.get(bind_var).ty;
179 self.parameters
180 .get(ty)
181 .or_else(|| self.captured.get(ty))
182 .unwrap_or_else(|| {
183 panic!("parameter must be exist {:?} in {:?}", ty, self.parameters)
184 })
185 .clone()
186 }
187 Stmt::Recursion => Value::FnRef(FnRef::Recursion),
188 Stmt::Link(_link) => {
189 todo!()
190 }
191 };
192 let var = *bind_var;
193 self.store_value(var, value);
194 self.pc_stmt_idx += 1;
195 InnerOutput::Running
196 }
197 }
198
199 pub fn eval_continue(&mut self, _output: Value) {
201 todo!()
202 }
203
204 pub fn return_or_continue_with_value(&mut self, ret: Value) {
206 let var = self.return_register.take().expect("needs return register");
207 self.store_value(var, ret);
208 }
209
210 pub fn load_value(&self, var: &VarId) -> &Value {
211 self.registers.get(var).unwrap()
212 }
213
214 pub fn store_value(&mut self, var: VarId, value: Value) {
215 self.registers.insert(var, value);
216 }
217
218 pub fn get_var_ty(&self, var: &VarId) -> &ConcType {
219 &self.mir.vars.get(var).ty
220 }
221}
222
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use mir::{
        mir::{BasicBlock, Var},
        stmt::Stmt,
        ATerminator, BlockId, Const, Scope, ScopeId, StmtBind, Vars,
    };

    use super::*;

    /// Builds a fresh evaluator for `mir`, positioned at the first statement of
    /// block 0, with no registers, parameters, captures or handlers.
    fn eval_of(mir: Mir) -> EvalMir {
        EvalMir {
            mir,
            pc_block: BlockId(0),
            pc_stmt_idx: 0,
            registers: HashMap::new(),
            parameters: HashMap::new(),
            captured: HashMap::new(),
            return_register: None,
            handlers: HashMap::new(),
        }
    }

    #[test]
    fn literal() {
        let mir = Mir {
            parameters: vec![],
            output: ConcType::Number,
            vars: Vars(vec![Var {
                ty: ConcType::Number,
                scope: ScopeId(0),
            }]),
            scopes: vec![Scope { super_scope: None }],
            blocks: vec![BasicBlock {
                stmts: vec![StmtBind {
                    var: VarId(0),
                    stmt: Stmt::Const(Const::Int(1)),
                }],
                terminator: ATerminator::Return(VarId(0)),
            }],
            captured: vec![],
            links: vec![],
        };

        let mut eval = eval_of(mir);

        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.eval_next(), InnerOutput::Return(Value::Int(1)));
    }

    /// A `Goto` terminator switches blocks, and `Move` copies a register into a new binding.
    #[test]
    fn goto_and_move() {
        let number_var = || Var {
            ty: ConcType::Number,
            scope: ScopeId(0),
        };
        let mir = Mir {
            parameters: vec![],
            output: ConcType::Number,
            vars: Vars(vec![number_var(), number_var()]),
            scopes: vec![Scope { super_scope: None }],
            blocks: vec![
                BasicBlock {
                    stmts: vec![StmtBind {
                        var: VarId(0),
                        stmt: Stmt::Const(Const::Int(1)),
                    }],
                    terminator: ATerminator::Goto(BlockId(1)),
                },
                BasicBlock {
                    stmts: vec![StmtBind {
                        var: VarId(1),
                        stmt: Stmt::Move(VarId(0)),
                    }],
                    terminator: ATerminator::Return(VarId(1)),
                },
            ],
            captured: vec![],
            links: vec![],
        };

        let mut eval = eval_of(mir);

        // x = 1
        assert_eq!(eval.eval_next(), InnerOutput::Running);
        // goto block 1
        assert_eq!(eval.eval_next(), InnerOutput::Running);
        assert_eq!(eval.pc_block, BlockId(1));
        assert_eq!(eval.pc_stmt_idx, 0);
        // y = move x
        assert_eq!(eval.eval_next(), InnerOutput::Running);
        // return y
        assert_eq!(eval.eval_next(), InnerOutput::Return(Value::Int(1)));
    }
}
273}