1use hashbrown::HashMap;
4
5use crate::{
6 alloc::{vec, Box, Rc, String, ToOwned, Vec},
7 arith::OrdArithmetic,
8 error::{Backtrace, CodeInModule, EvalResult, TupleLenMismatchContext},
9 executable::command::{Atom, Command, CompiledExpr, FieldName, SpannedAtom, SpannedCommand},
10 CallContext, Environment, Error, ErrorKind, Function, InterpretedFn, ModuleId, SpannedValue,
11 Value,
12};
13use arithmetic_parser::{BinaryOp, LvalueLen, MaybeSpanned, StripCode, UnaryOp};
14
15#[derive(Debug)]
17pub(crate) struct Executable<'a, T> {
18 id: Box<dyn ModuleId>,
19 commands: Vec<SpannedCommand<'a, T>>,
20 child_fns: Vec<Rc<ExecutableFn<'a, T>>>,
21 register_capacity: usize,
23}
24
25impl<'a, T: Clone> Clone for Executable<'a, T> {
26 fn clone(&self) -> Self {
27 Self {
28 id: self.id.clone_boxed(),
29 commands: self.commands.clone(),
30 child_fns: self.child_fns.clone(),
31 register_capacity: self.register_capacity,
32 }
33 }
34}
35
36impl<T: 'static + Clone> StripCode for Executable<'_, T> {
37 type Stripped = Executable<'static, T>;
38
39 fn strip_code(self) -> Self::Stripped {
40 Executable {
41 id: self.id,
42 commands: self
43 .commands
44 .into_iter()
45 .map(|command| command.map_extra(StripCode::strip_code).strip_code())
46 .collect(),
47 child_fns: self
48 .child_fns
49 .into_iter()
50 .map(|function| Rc::new(function.to_stripped_code()))
51 .collect(),
52 register_capacity: self.register_capacity,
53 }
54 }
55}
56
57impl<'a, T> Executable<'a, T> {
58 pub fn new(id: Box<dyn ModuleId>) -> Self {
59 Self {
60 id,
61 commands: vec![],
62 child_fns: vec![],
63 register_capacity: 0,
64 }
65 }
66
67 pub fn id(&self) -> &dyn ModuleId {
68 self.id.as_ref()
69 }
70
71 fn create_error<U>(&self, span: &MaybeSpanned<'a, U>, err: ErrorKind) -> Error<'a> {
72 Error::new(self.id.as_ref(), span, err)
73 }
74
75 pub fn push_command(&mut self, command: impl Into<SpannedCommand<'a, T>>) {
76 self.commands.push(command.into());
77 }
78
79 pub fn push_child_fn(&mut self, child_fn: ExecutableFn<'a, T>) -> usize {
80 let fn_ptr = self.child_fns.len();
81 self.child_fns.push(Rc::new(child_fn));
82 fn_ptr
83 }
84
85 pub fn finalize_function(&mut self, register_count: usize) {
86 match &mut self.commands[0].extra {
89 Command::Destructure { unchecked, .. } => {
90 *unchecked = true;
91 }
92 _ => unreachable!(),
93 }
94 self.register_capacity = register_count;
95 }
96
97 pub fn finalize_block(&mut self, register_count: usize) {
98 self.register_capacity = register_count;
99 }
100}
101
102impl<'a, T: Clone> Executable<'a, T> {
103 pub fn call_function(
104 &self,
105 captures: Vec<Value<'a, T>>,
106 args: Vec<Value<'a, T>>,
107 ctx: &mut CallContext<'_, 'a, T>,
108 ) -> EvalResult<'a, T> {
109 let mut registers = captures;
110 registers.push(Value::Tuple(args));
111 let mut env = Registers {
112 registers,
113 ..Registers::new()
114 };
115 env.execute(self, ctx.arithmetic(), ctx.backtrace())
116 }
117}
118
119#[derive(Debug)]
121pub(crate) struct ExecutableFn<'a, T> {
122 pub inner: Executable<'a, T>,
123 pub def_span: MaybeSpanned<'a>,
124 pub arg_count: LvalueLen,
125}
126
127impl<T: 'static + Clone> ExecutableFn<'_, T> {
128 pub fn to_stripped_code(&self) -> ExecutableFn<'static, T> {
129 ExecutableFn {
130 inner: self.inner.clone().strip_code(),
131 def_span: self.def_span.strip_code(),
132 arg_count: self.arg_count,
133 }
134 }
135}
136
137impl<T: 'static + Clone> StripCode for ExecutableFn<'_, T> {
138 type Stripped = ExecutableFn<'static, T>;
139
140 fn strip_code(self) -> Self::Stripped {
141 ExecutableFn {
142 inner: self.inner.strip_code(),
143 def_span: self.def_span.strip_code(),
144 arg_count: self.arg_count,
145 }
146 }
147}
148
149#[derive(Debug)]
150pub(crate) struct Registers<'a, T> {
151 registers: Vec<Value<'a, T>>,
153 vars: HashMap<String, usize>,
157 inner_scope_start: Option<usize>,
160}
161
162impl<T: Clone> Clone for Registers<'_, T> {
163 fn clone(&self) -> Self {
164 Self {
165 registers: self.registers.clone(),
166 vars: self.vars.clone(),
167 inner_scope_start: self.inner_scope_start,
168 }
169 }
170}
171
172impl<T: 'static + Clone> StripCode for Registers<'_, T> {
173 type Stripped = Registers<'static, T>;
174
175 fn strip_code(self) -> Self::Stripped {
176 Registers {
177 registers: self
178 .registers
179 .into_iter()
180 .map(StripCode::strip_code)
181 .collect(),
182 vars: self.vars,
183 inner_scope_start: self.inner_scope_start,
184 }
185 }
186}
187
188impl<'a, T> Registers<'a, T> {
189 pub fn new() -> Self {
190 Self {
191 registers: vec![],
192 vars: HashMap::new(),
193 inner_scope_start: None,
194 }
195 }
196
197 pub fn get_var(&self, name: &str) -> Option<&Value<'a, T>> {
198 let register = *self.vars.get(name)?;
199 Some(&self.registers[register])
200 }
201
202 pub fn variables(&self) -> impl Iterator<Item = (&str, &Value<'a, T>)> + '_ {
203 self.vars
204 .iter()
205 .map(move |(name, register)| (name.as_str(), &self.registers[*register]))
206 }
207
208 pub fn variables_map(&self) -> &HashMap<String, usize> {
209 &self.vars
210 }
211
212 pub fn register_count(&self) -> usize {
213 self.registers.len()
214 }
215
216 pub fn set_var(&mut self, name: &str, value: Value<'a, T>) {
217 let register = *self.vars.get(name).unwrap_or_else(|| {
218 panic!("Variable `{}` is not defined", name);
219 });
220 self.registers[register] = value;
221 }
222
223 pub fn insert_var(&mut self, name: &str, value: Value<'a, T>) -> bool {
225 if self.vars.contains_key(name) {
226 false
227 } else {
228 let register = self.registers.len();
229 self.registers.push(value);
230 self.vars.insert(name.to_owned(), register);
231
232 true
233 }
234 }
235}
236
237impl<'a, T: Clone> Registers<'a, T> {
238 pub fn update_from_env(&mut self, env: &Environment<'a, T>) {
240 for (var_name, register) in &self.vars {
241 if let Some(value) = env.get(var_name) {
242 self.registers[*register] = value.clone();
243 }
244 }
245 }
246
247 pub fn update_env(&self, env: &mut Environment<'a, T>) {
249 for (var_name, register) in &self.vars {
250 let value = self.registers[*register].clone();
251 env.insert(var_name, value);
255 }
256 }
257
258 pub fn into_variables(self) -> impl Iterator<Item = (String, Value<'a, T>)> {
259 let registers = self.registers;
260 self.vars
262 .into_iter()
263 .map(move |(name, register)| (name, registers[register].clone()))
264 }
265}
266
267impl<'a, T: Clone> Registers<'a, T> {
268 pub fn execute(
269 &mut self,
270 executable: &Executable<'a, T>,
271 arithmetic: &dyn OrdArithmetic<T>,
272 backtrace: Option<&mut Backtrace<'a>>,
273 ) -> EvalResult<'a, T> {
274 self.execute_inner(executable, arithmetic, backtrace)
275 .map_err(|err| {
276 if let Some(scope_start) = self.inner_scope_start.take() {
277 self.registers.truncate(scope_start);
278 }
279 err
280 })
281 }
282
283 fn execute_inner(
284 &mut self,
285 executable: &Executable<'a, T>,
286 arithmetic: &dyn OrdArithmetic<T>,
287 mut backtrace: Option<&mut Backtrace<'a>>,
288 ) -> EvalResult<'a, T> {
289 if let Some(additional_capacity) = executable
290 .register_capacity
291 .checked_sub(self.registers.len())
292 {
293 self.registers.reserve(additional_capacity);
294 }
295
296 for command in &executable.commands {
297 match &command.extra {
298 Command::Push(expr) => {
299 let expr_span = command.with_no_extra();
300 let expr_value = self.execute_expr(
301 expr_span,
302 expr,
303 executable,
304 arithmetic,
305 backtrace.as_deref_mut(),
306 )?;
307 self.registers.push(expr_value);
308 }
309
310 Command::Copy {
311 source,
312 destination,
313 } => {
314 self.registers[*destination] = self.registers[*source].clone();
315 }
316
317 Command::TruncateRegisters(size) => {
318 self.registers.truncate(*size);
319 }
320
321 Command::Destructure {
322 source,
323 start_len,
324 end_len,
325 lvalue_len,
326 unchecked,
327 } => {
328 let source = self.registers[*source].clone();
329 if let Value::Tuple(mut elements) = source {
330 if !*unchecked && !lvalue_len.matches(elements.len()) {
331 let err = ErrorKind::TupleLenMismatch {
332 lhs: *lvalue_len,
333 rhs: elements.len(),
334 context: TupleLenMismatchContext::Assignment,
335 };
336 return Err(executable.create_error(command, err));
337 }
338
339 let mut tail = elements.split_off(*start_len);
340 self.registers.extend(elements);
341 let end = tail.split_off(tail.len() - *end_len);
342 self.registers.push(Value::Tuple(tail));
343 self.registers.extend(end);
344 } else {
345 let err = ErrorKind::CannotDestructure;
346 return Err(executable.create_error(command, err));
347 }
348 }
349
350 Command::Annotate { register, name } => {
351 self.vars.insert(name.clone(), *register);
352 }
353
354 Command::StartInnerScope => {
355 debug_assert!(self.inner_scope_start.is_none());
356 self.inner_scope_start = Some(self.registers.len());
357 }
358 Command::EndInnerScope => {
359 debug_assert!(self.inner_scope_start.is_some());
360 self.inner_scope_start = None;
361 }
362 }
363 }
364
365 Ok(self.registers.pop().unwrap_or_else(Value::void))
366 }
367
368 fn execute_expr(
369 &self,
370 span: MaybeSpanned<'a>,
371 expr: &CompiledExpr<'a, T>,
372 executable: &Executable<'a, T>,
373 arithmetic: &dyn OrdArithmetic<T>,
374 backtrace: Option<&mut Backtrace<'a>>,
375 ) -> EvalResult<'a, T> {
376 match expr {
377 CompiledExpr::Atom(atom) => Ok(self.resolve_atom(atom)),
378
379 CompiledExpr::Tuple(atoms) => {
380 let values = atoms.iter().map(|atom| self.resolve_atom(atom)).collect();
381 Ok(Value::Tuple(values))
382 }
383 CompiledExpr::Object(fields) => {
384 let fields = fields
385 .iter()
386 .map(|(name, atom)| (name.clone(), self.resolve_atom(atom)));
387 Ok(Value::Object(fields.collect()))
388 }
389
390 CompiledExpr::Unary { op, inner } => {
391 let inner_value = self.resolve_atom(&inner.extra);
392 match op {
393 UnaryOp::Neg => inner_value.try_neg(arithmetic),
394 UnaryOp::Not => inner_value.try_not(),
395 _ => unreachable!("Checked during compilation"),
396 }
397 .map_err(|err| executable.create_error(&span, err))
398 }
399
400 CompiledExpr::Binary { op, lhs, rhs } => {
401 self.execute_binary_expr(executable.id(), span, *op, lhs, rhs, arithmetic)
402 }
403
404 CompiledExpr::FieldAccess {
405 receiver,
406 field: FieldName::Index(index),
407 } => {
408 if let Value::Tuple(mut tuple) = self.resolve_atom(&receiver.extra) {
409 let len = tuple.len();
410 if *index >= len {
411 Err(executable.create_error(
412 &span,
413 ErrorKind::IndexOutOfBounds { index: *index, len },
414 ))
415 } else {
416 Ok(tuple.swap_remove(*index))
417 }
418 } else {
419 Err(executable.create_error(&span, ErrorKind::CannotIndex))
420 }
421 }
422
423 CompiledExpr::FieldAccess {
424 receiver,
425 field: FieldName::Name(name),
426 } => {
427 if let Value::Object(mut obj) = self.resolve_atom(&receiver.extra) {
428 obj.remove(name).ok_or_else(|| {
429 let err = ErrorKind::NoField {
430 field: name.clone(),
431 available_fields: obj.keys().cloned().collect(),
432 };
433 executable.create_error(&span, err)
434 })
435 } else {
436 Err(executable.create_error(&span, ErrorKind::CannotAccessFields))
437 }
438 }
439
440 CompiledExpr::Function {
441 name,
442 original_name,
443 args,
444 } => {
445 if let Value::Function(function) = self.resolve_atom(&name.extra) {
446 let fn_name = original_name.as_deref().unwrap_or("(anonymous function)");
447 let arg_values = args
448 .iter()
449 .map(|arg| arg.copy_with_extra(self.resolve_atom(&arg.extra)))
450 .collect();
451 Self::eval_function(
452 &function,
453 fn_name,
454 executable.id.as_ref(),
455 span,
456 arg_values,
457 arithmetic,
458 backtrace,
459 )
460 } else {
461 Err(executable.create_error(&span, ErrorKind::CannotCall))
462 }
463 }
464
465 CompiledExpr::DefineFunction {
466 ptr,
467 captures,
468 capture_names,
469 } => {
470 let fn_executable = Rc::clone(&executable.child_fns[*ptr]);
471 let captured_values = captures
472 .iter()
473 .map(|capture| self.resolve_atom(&capture.extra))
474 .collect();
475
476 let function =
477 InterpretedFn::new(fn_executable, captured_values, capture_names.clone());
478 Ok(Value::interpreted_fn(function))
479 }
480 }
481 }
482
483 fn execute_binary_expr(
484 &self,
485 module_id: &dyn ModuleId,
486 span: MaybeSpanned<'a>,
487 op: BinaryOp,
488 lhs: &SpannedAtom<'a, T>,
489 rhs: &SpannedAtom<'a, T>,
490 arithmetic: &dyn OrdArithmetic<T>,
491 ) -> EvalResult<'a, T> {
492 let lhs_value = lhs.copy_with_extra(self.resolve_atom(&lhs.extra));
493 let rhs_value = rhs.copy_with_extra(self.resolve_atom(&rhs.extra));
494
495 match op {
496 BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
497 Value::try_binary_op(module_id, span, lhs_value, rhs_value, op, arithmetic)
498 }
499
500 BinaryOp::Eq | BinaryOp::NotEq => {
501 let is_eq = lhs_value
502 .extra
503 .eq_by_arithmetic(&rhs_value.extra, arithmetic);
504 Ok(Value::Bool(if op == BinaryOp::Eq { is_eq } else { !is_eq }))
505 }
506
507 BinaryOp::And => Value::try_and(module_id, &lhs_value, &rhs_value),
508 BinaryOp::Or => Value::try_or(module_id, &lhs_value, &rhs_value),
509
510 BinaryOp::Gt | BinaryOp::Lt | BinaryOp::Ge | BinaryOp::Le => {
511 Value::compare(module_id, &lhs_value, &rhs_value, op, arithmetic)
512 }
513
514 _ => unreachable!("Checked during compilation"),
515 }
516 }
517
518 fn eval_function(
519 function: &Function<'a, T>,
520 fn_name: &str,
521 module_id: &dyn ModuleId,
522 call_span: MaybeSpanned<'a>,
523 arg_values: Vec<SpannedValue<'a, T>>,
524 arithmetic: &dyn OrdArithmetic<T>,
525 mut backtrace: Option<&mut Backtrace<'a>>,
526 ) -> EvalResult<'a, T> {
527 let full_call_span = CodeInModule::new(module_id, call_span);
528 if let Some(backtrace) = backtrace.as_deref_mut() {
529 backtrace.push_call(fn_name, function.def_span(), full_call_span.clone());
530 }
531 let mut context = CallContext::new(full_call_span, backtrace.as_deref_mut(), arithmetic);
532
533 function.evaluate(arg_values, &mut context).map(|value| {
534 if let Some(backtrace) = backtrace {
535 backtrace.pop_call();
536 }
537 value
538 })
539 }
540
541 #[inline]
542 fn resolve_atom(&self, atom: &Atom<T>) -> Value<'a, T> {
543 match atom {
544 Atom::Register(index) => self.registers[*index].clone(),
545 Atom::Constant(value) => Value::Prim(value.clone()),
546 Atom::Void => Value::void(),
547 }
548 }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compiler::Compiler, executable::ModuleImports, WildcardId};
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};

    #[test]
    fn iterative_evaluation() {
        let block = Untyped::<F32Grammar>::parse_statements("x").unwrap();
        let (mut module, _) = Compiler::compile_module(WildcardId, &block).unwrap();
        assert_eq!(module.inner.register_capacity, 2);
        assert_eq!(module.inner.commands.len(), 1);

        // Supply the undefined variable `x` through module imports, then run the module.
        let mut imports = Registers::new();
        imports.insert_var("x", Value::Prim(5.0));
        module.imports = ModuleImports { inner: imports };

        let value = module.run().unwrap();
        assert_eq!(value, Value::Prim(5.0));
    }
}