arithmetic_eval/executable/
registers.rs

1//! `Registers` for executing commands and closely related types.
2
3use hashbrown::HashMap;
4
5use crate::{
6    alloc::{vec, Box, Rc, String, ToOwned, Vec},
7    arith::OrdArithmetic,
8    error::{Backtrace, CodeInModule, EvalResult, TupleLenMismatchContext},
9    executable::command::{Atom, Command, CompiledExpr, FieldName, SpannedAtom, SpannedCommand},
10    CallContext, Environment, Error, ErrorKind, Function, InterpretedFn, ModuleId, SpannedValue,
11    Value,
12};
13use arithmetic_parser::{BinaryOp, LvalueLen, MaybeSpanned, StripCode, UnaryOp};
14
/// Sequence of instructions that can be executed with the `Registers`.
///
/// Besides its own commands, an executable owns the compiled bodies of functions defined
/// inside it; these are referenced by index from `CompiledExpr::DefineFunction`.
#[derive(Debug)]
pub(crate) struct Executable<'a, T> {
    /// ID of the module this executable belongs to; attached to errors and call spans.
    id: Box<dyn ModuleId>,
    /// Commands executed sequentially by `Registers::execute()`.
    commands: Vec<SpannedCommand<'a, T>>,
    /// Compiled child functions; `push_child_fn()` returns an index into this vector.
    child_fns: Vec<Rc<ExecutableFn<'a, T>>>,
    /// Hint how many registers the executable requires. Used to reserve register storage
    /// before execution starts.
    register_capacity: usize,
}
24
25impl<'a, T: Clone> Clone for Executable<'a, T> {
26    fn clone(&self) -> Self {
27        Self {
28            id: self.id.clone_boxed(),
29            commands: self.commands.clone(),
30            child_fns: self.child_fns.clone(),
31            register_capacity: self.register_capacity,
32        }
33    }
34}
35
36impl<T: 'static + Clone> StripCode for Executable<'_, T> {
37    type Stripped = Executable<'static, T>;
38
39    fn strip_code(self) -> Self::Stripped {
40        Executable {
41            id: self.id,
42            commands: self
43                .commands
44                .into_iter()
45                .map(|command| command.map_extra(StripCode::strip_code).strip_code())
46                .collect(),
47            child_fns: self
48                .child_fns
49                .into_iter()
50                .map(|function| Rc::new(function.to_stripped_code()))
51                .collect(),
52            register_capacity: self.register_capacity,
53        }
54    }
55}
56
57impl<'a, T> Executable<'a, T> {
58    pub fn new(id: Box<dyn ModuleId>) -> Self {
59        Self {
60            id,
61            commands: vec![],
62            child_fns: vec![],
63            register_capacity: 0,
64        }
65    }
66
67    pub fn id(&self) -> &dyn ModuleId {
68        self.id.as_ref()
69    }
70
71    fn create_error<U>(&self, span: &MaybeSpanned<'a, U>, err: ErrorKind) -> Error<'a> {
72        Error::new(self.id.as_ref(), span, err)
73    }
74
75    pub fn push_command(&mut self, command: impl Into<SpannedCommand<'a, T>>) {
76        self.commands.push(command.into());
77    }
78
79    pub fn push_child_fn(&mut self, child_fn: ExecutableFn<'a, T>) -> usize {
80        let fn_ptr = self.child_fns.len();
81        self.child_fns.push(Rc::new(child_fn));
82        fn_ptr
83    }
84
85    pub fn finalize_function(&mut self, register_count: usize) {
86        // We check number of arguments in `InterpretedFn::evaluate()` in order to provide
87        // a more precise error.
88        match &mut self.commands[0].extra {
89            Command::Destructure { unchecked, .. } => {
90                *unchecked = true;
91            }
92            _ => unreachable!(),
93        }
94        self.register_capacity = register_count;
95    }
96
97    pub fn finalize_block(&mut self, register_count: usize) {
98        self.register_capacity = register_count;
99    }
100}
101
102impl<'a, T: Clone> Executable<'a, T> {
103    pub fn call_function(
104        &self,
105        captures: Vec<Value<'a, T>>,
106        args: Vec<Value<'a, T>>,
107        ctx: &mut CallContext<'_, 'a, T>,
108    ) -> EvalResult<'a, T> {
109        let mut registers = captures;
110        registers.push(Value::Tuple(args));
111        let mut env = Registers {
112            registers,
113            ..Registers::new()
114        };
115        env.execute(self, ctx.arithmetic(), ctx.backtrace())
116    }
117}
118
/// `Executable` together with function-specific info.
#[derive(Debug)]
pub(crate) struct ExecutableFn<'a, T> {
    /// Compiled function body.
    pub inner: Executable<'a, T>,
    /// Span of the function definition.
    pub def_span: MaybeSpanned<'a>,
    /// Number of arguments the function accepts, expressed as an `LvalueLen`.
    pub arg_count: LvalueLen,
}
126
127impl<T: 'static + Clone> ExecutableFn<'_, T> {
128    pub fn to_stripped_code(&self) -> ExecutableFn<'static, T> {
129        ExecutableFn {
130            inner: self.inner.clone().strip_code(),
131            def_span: self.def_span.strip_code(),
132            arg_count: self.arg_count,
133        }
134    }
135}
136
137impl<T: 'static + Clone> StripCode for ExecutableFn<'_, T> {
138    type Stripped = ExecutableFn<'static, T>;
139
140    fn strip_code(self) -> Self::Stripped {
141        ExecutableFn {
142            inner: self.inner.strip_code(),
143            def_span: self.def_span.strip_code(),
144            arg_count: self.arg_count,
145        }
146    }
147}
148
/// Register values used while executing an `Executable`, together with a mapping
/// of variable names to register indices.
#[derive(Debug)]
pub(crate) struct Registers<'a, T> {
    // TODO: restore `SmallVec` wrapped into a covariant wrapper.
    /// Register values, indexed by register number.
    registers: Vec<Value<'a, T>>,
    /// Maps variables to registers. Variables are mapped only from the global scope;
    /// thus, we don't need to remove them on error in an inner scope.
    // TODO: investigate using stack-hosted small strings for keys.
    vars: HashMap<String, usize>,
    /// Marks the start of a first inner scope currently being evaluated. This is used
    /// to quickly remove registers from the inner scopes on error.
    inner_scope_start: Option<usize>,
}
161
162impl<T: Clone> Clone for Registers<'_, T> {
163    fn clone(&self) -> Self {
164        Self {
165            registers: self.registers.clone(),
166            vars: self.vars.clone(),
167            inner_scope_start: self.inner_scope_start,
168        }
169    }
170}
171
172impl<T: 'static + Clone> StripCode for Registers<'_, T> {
173    type Stripped = Registers<'static, T>;
174
175    fn strip_code(self) -> Self::Stripped {
176        Registers {
177            registers: self
178                .registers
179                .into_iter()
180                .map(StripCode::strip_code)
181                .collect(),
182            vars: self.vars,
183            inner_scope_start: self.inner_scope_start,
184        }
185    }
186}
187
188impl<'a, T> Registers<'a, T> {
189    pub fn new() -> Self {
190        Self {
191            registers: vec![],
192            vars: HashMap::new(),
193            inner_scope_start: None,
194        }
195    }
196
197    pub fn get_var(&self, name: &str) -> Option<&Value<'a, T>> {
198        let register = *self.vars.get(name)?;
199        Some(&self.registers[register])
200    }
201
202    pub fn variables(&self) -> impl Iterator<Item = (&str, &Value<'a, T>)> + '_ {
203        self.vars
204            .iter()
205            .map(move |(name, register)| (name.as_str(), &self.registers[*register]))
206    }
207
208    pub fn variables_map(&self) -> &HashMap<String, usize> {
209        &self.vars
210    }
211
212    pub fn register_count(&self) -> usize {
213        self.registers.len()
214    }
215
216    pub fn set_var(&mut self, name: &str, value: Value<'a, T>) {
217        let register = *self.vars.get(name).unwrap_or_else(|| {
218            panic!("Variable `{}` is not defined", name);
219        });
220        self.registers[register] = value;
221    }
222
223    /// Allocates a new register with the specified name if the name was not allocated previously.
224    pub fn insert_var(&mut self, name: &str, value: Value<'a, T>) -> bool {
225        if self.vars.contains_key(name) {
226            false
227        } else {
228            let register = self.registers.len();
229            self.registers.push(value);
230            self.vars.insert(name.to_owned(), register);
231
232            true
233        }
234    }
235}
236
237impl<'a, T: Clone> Registers<'a, T> {
238    /// Updates from the specified environment. Updates are performed in place.
239    pub fn update_from_env(&mut self, env: &Environment<'a, T>) {
240        for (var_name, register) in &self.vars {
241            if let Some(value) = env.get(var_name) {
242                self.registers[*register] = value.clone();
243            }
244        }
245    }
246
247    /// Updates environment from this instance.
248    pub fn update_env(&self, env: &mut Environment<'a, T>) {
249        for (var_name, register) in &self.vars {
250            let value = self.registers[*register].clone();
251            // ^-- We cannot move `value` from `registers` because multiple names may be pointing
252            // to the same register.
253
254            env.insert(var_name, value);
255        }
256    }
257
258    pub fn into_variables(self) -> impl Iterator<Item = (String, Value<'a, T>)> {
259        let registers = self.registers;
260        // Moving out of `registers` is not sound because of possible aliasing.
261        self.vars
262            .into_iter()
263            .map(move |(name, register)| (name, registers[register].clone()))
264    }
265}
266
impl<'a, T: Clone> Registers<'a, T> {
    /// Executes all commands of `executable` on these registers and returns the resulting
    /// value (the last register, or void if there are no registers left).
    ///
    /// If an error occurs while an inner scope is active, registers allocated since the start
    /// of that scope are truncated and the inner-scope marker is reset.
    pub fn execute(
        &mut self,
        executable: &Executable<'a, T>,
        arithmetic: &dyn OrdArithmetic<T>,
        backtrace: Option<&mut Backtrace<'a>>,
    ) -> EvalResult<'a, T> {
        self.execute_inner(executable, arithmetic, backtrace)
            .map_err(|err| {
                // Variables are mapped only from the global scope, so dropping inner-scope
                // registers does not require touching `vars`.
                if let Some(scope_start) = self.inner_scope_start.take() {
                    self.registers.truncate(scope_start);
                }
                err
            })
    }

    /// Runs the commands of `executable` sequentially. Errors are returned immediately
    /// without any cleanup; cleanup is performed by `execute()`.
    fn execute_inner(
        &mut self,
        executable: &Executable<'a, T>,
        arithmetic: &dyn OrdArithmetic<T>,
        mut backtrace: Option<&mut Backtrace<'a>>,
    ) -> EvalResult<'a, T> {
        // Reserve enough storage for `register_capacity` registers up front
        // (no-op if there are already at least that many registers).
        if let Some(additional_capacity) = executable
            .register_capacity
            .checked_sub(self.registers.len())
        {
            self.registers.reserve(additional_capacity);
        }

        for command in &executable.commands {
            match &command.extra {
                // Evaluates an expression and stores the result in a new register.
                Command::Push(expr) => {
                    let expr_span = command.with_no_extra();
                    let expr_value = self.execute_expr(
                        expr_span,
                        expr,
                        executable,
                        arithmetic,
                        backtrace.as_deref_mut(),
                    )?;
                    self.registers.push(expr_value);
                }

                // Overwrites the `destination` register with a clone of `source`.
                Command::Copy {
                    source,
                    destination,
                } => {
                    self.registers[*destination] = self.registers[*source].clone();
                }

                // Drops all registers with index `>= size`.
                Command::TruncateRegisters(size) => {
                    self.registers.truncate(*size);
                }

                // Unpacks a tuple into new registers: first the `start_len` leading elements,
                // then one register with a tuple of the middle elements (possibly empty),
                // then the `end_len` trailing elements.
                Command::Destructure {
                    source,
                    start_len,
                    end_len,
                    lvalue_len,
                    unchecked,
                } => {
                    let source = self.registers[*source].clone();
                    if let Value::Tuple(mut elements) = source {
                        // Length is not checked for `unchecked` destructuring; see
                        // `Executable::finalize_function()`.
                        if !*unchecked && !lvalue_len.matches(elements.len()) {
                            let err = ErrorKind::TupleLenMismatch {
                                lhs: *lvalue_len,
                                rhs: elements.len(),
                                context: TupleLenMismatchContext::Assignment,
                            };
                            return Err(executable.create_error(command, err));
                        }

                        let mut tail = elements.split_off(*start_len);
                        self.registers.extend(elements);
                        let end = tail.split_off(tail.len() - *end_len);
                        self.registers.push(Value::Tuple(tail));
                        self.registers.extend(end);
                    } else {
                        let err = ErrorKind::CannotDestructure;
                        return Err(executable.create_error(command, err));
                    }
                }

                // Binds variable `name` to `register`, replacing any previous binding.
                Command::Annotate { register, name } => {
                    self.vars.insert(name.clone(), *register);
                }

                // Only a single (outermost) inner scope is tracked; nesting is asserted
                // against in debug builds.
                Command::StartInnerScope => {
                    debug_assert!(self.inner_scope_start.is_none());
                    self.inner_scope_start = Some(self.registers.len());
                }
                Command::EndInnerScope => {
                    debug_assert!(self.inner_scope_start.is_some());
                    self.inner_scope_start = None;
                }
            }
        }

        Ok(self.registers.pop().unwrap_or_else(Value::void))
    }

    /// Evaluates a single compiled expression; `span` is used to locate errors.
    fn execute_expr(
        &self,
        span: MaybeSpanned<'a>,
        expr: &CompiledExpr<'a, T>,
        executable: &Executable<'a, T>,
        arithmetic: &dyn OrdArithmetic<T>,
        backtrace: Option<&mut Backtrace<'a>>,
    ) -> EvalResult<'a, T> {
        match expr {
            CompiledExpr::Atom(atom) => Ok(self.resolve_atom(atom)),

            CompiledExpr::Tuple(atoms) => {
                let values = atoms.iter().map(|atom| self.resolve_atom(atom)).collect();
                Ok(Value::Tuple(values))
            }
            CompiledExpr::Object(fields) => {
                let fields = fields
                    .iter()
                    .map(|(name, atom)| (name.clone(), self.resolve_atom(atom)));
                Ok(Value::Object(fields.collect()))
            }

            CompiledExpr::Unary { op, inner } => {
                // Only `Neg` and `Not` can reach this point; other ops are rejected earlier.
                let inner_value = self.resolve_atom(&inner.extra);
                match op {
                    UnaryOp::Neg => inner_value.try_neg(arithmetic),
                    UnaryOp::Not => inner_value.try_not(),
                    _ => unreachable!("Checked during compilation"),
                }
                .map_err(|err| executable.create_error(&span, err))
            }

            CompiledExpr::Binary { op, lhs, rhs } => {
                self.execute_binary_expr(executable.id(), span, *op, lhs, rhs, arithmetic)
            }

            // Tuple indexing. The resolved tuple is an owned clone, so the element
            // can be extracted with `swap_remove`.
            CompiledExpr::FieldAccess {
                receiver,
                field: FieldName::Index(index),
            } => {
                if let Value::Tuple(mut tuple) = self.resolve_atom(&receiver.extra) {
                    let len = tuple.len();
                    if *index >= len {
                        Err(executable.create_error(
                            &span,
                            ErrorKind::IndexOutOfBounds { index: *index, len },
                        ))
                    } else {
                        Ok(tuple.swap_remove(*index))
                    }
                } else {
                    Err(executable.create_error(&span, ErrorKind::CannotIndex))
                }
            }

            // Object field access; on a missing field, the error lists the available fields.
            CompiledExpr::FieldAccess {
                receiver,
                field: FieldName::Name(name),
            } => {
                if let Value::Object(mut obj) = self.resolve_atom(&receiver.extra) {
                    obj.remove(name).ok_or_else(|| {
                        let err = ErrorKind::NoField {
                            field: name.clone(),
                            available_fields: obj.keys().cloned().collect(),
                        };
                        executable.create_error(&span, err)
                    })
                } else {
                    Err(executable.create_error(&span, ErrorKind::CannotAccessFields))
                }
            }

            // Function call. Each argument keeps its own span.
            CompiledExpr::Function {
                name,
                original_name,
                args,
            } => {
                if let Value::Function(function) = self.resolve_atom(&name.extra) {
                    let fn_name = original_name.as_deref().unwrap_or("(anonymous function)");
                    let arg_values = args
                        .iter()
                        .map(|arg| arg.copy_with_extra(self.resolve_atom(&arg.extra)))
                        .collect();
                    Self::eval_function(
                        &function,
                        fn_name,
                        executable.id.as_ref(),
                        span,
                        arg_values,
                        arithmetic,
                        backtrace,
                    )
                } else {
                    Err(executable.create_error(&span, ErrorKind::CannotCall))
                }
            }

            // Creates a closure from the child function at index `ptr`, capturing
            // the current values of the `captures` atoms.
            CompiledExpr::DefineFunction {
                ptr,
                captures,
                capture_names,
            } => {
                let fn_executable = Rc::clone(&executable.child_fns[*ptr]);
                let captured_values = captures
                    .iter()
                    .map(|capture| self.resolve_atom(&capture.extra))
                    .collect();

                let function =
                    InterpretedFn::new(fn_executable, captured_values, capture_names.clone());
                Ok(Value::interpreted_fn(function))
            }
        }
    }

    /// Evaluates a binary operation. Arithmetic ops, `And`/`Or`, and comparisons are
    /// delegated to `Value`; `Eq`/`NotEq` use `eq_by_arithmetic` and always succeed.
    fn execute_binary_expr(
        &self,
        module_id: &dyn ModuleId,
        span: MaybeSpanned<'a>,
        op: BinaryOp,
        lhs: &SpannedAtom<'a, T>,
        rhs: &SpannedAtom<'a, T>,
        arithmetic: &dyn OrdArithmetic<T>,
    ) -> EvalResult<'a, T> {
        let lhs_value = lhs.copy_with_extra(self.resolve_atom(&lhs.extra));
        let rhs_value = rhs.copy_with_extra(self.resolve_atom(&rhs.extra));

        match op {
            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
                Value::try_binary_op(module_id, span, lhs_value, rhs_value, op, arithmetic)
            }

            BinaryOp::Eq | BinaryOp::NotEq => {
                let is_eq = lhs_value
                    .extra
                    .eq_by_arithmetic(&rhs_value.extra, arithmetic);
                Ok(Value::Bool(if op == BinaryOp::Eq { is_eq } else { !is_eq }))
            }

            BinaryOp::And => Value::try_and(module_id, &lhs_value, &rhs_value),
            BinaryOp::Or => Value::try_or(module_id, &lhs_value, &rhs_value),

            BinaryOp::Gt | BinaryOp::Lt | BinaryOp::Ge | BinaryOp::Le => {
                Value::compare(module_id, &lhs_value, &rhs_value, op, arithmetic)
            }

            _ => unreachable!("Checked during compilation"),
        }
    }

    /// Calls `function` with the given arguments.
    ///
    /// A backtrace entry (if a backtrace is provided) is pushed before the call and popped
    /// only when the call succeeds, so on error the entry stays in the backtrace.
    fn eval_function(
        function: &Function<'a, T>,
        fn_name: &str,
        module_id: &dyn ModuleId,
        call_span: MaybeSpanned<'a>,
        arg_values: Vec<SpannedValue<'a, T>>,
        arithmetic: &dyn OrdArithmetic<T>,
        mut backtrace: Option<&mut Backtrace<'a>>,
    ) -> EvalResult<'a, T> {
        let full_call_span = CodeInModule::new(module_id, call_span);
        if let Some(backtrace) = backtrace.as_deref_mut() {
            backtrace.push_call(fn_name, function.def_span(), full_call_span.clone());
        }
        let mut context = CallContext::new(full_call_span, backtrace.as_deref_mut(), arithmetic);

        function.evaluate(arg_values, &mut context).map(|value| {
            if let Some(backtrace) = backtrace {
                backtrace.pop_call();
            }
            value
        })
    }

    /// Resolves an atom into an owned value: register contents are cloned, constants are
    /// wrapped into `Value::Prim`, and `Atom::Void` becomes the void value.
    #[inline]
    fn resolve_atom(&self, atom: &Atom<T>) -> Value<'a, T> {
        match atom {
            Atom::Register(index) => self.registers[*index].clone(),
            Atom::Constant(value) => Value::Prim(value.clone()),
            Atom::Void => Value::void(),
        }
    }
}
550
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compiler::Compiler, executable::ModuleImports, WildcardId};
    use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};

    /// A module reading a single imported variable should compile to one command
    /// and evaluate to the imported value.
    #[test]
    fn iterative_evaluation() {
        let statements = Untyped::<F32Grammar>::parse_statements("x").unwrap();
        let (mut module, _) = Compiler::compile_module(WildcardId, &statements).unwrap();

        // The only command pushes `x` from r0 into r1.
        assert_eq!(module.inner.register_capacity, 2);
        assert_eq!(module.inner.commands.len(), 1);

        let mut imports = Registers::new();
        imports.insert_var("x", Value::Prim(5.0));
        module.imports = ModuleImports { inner: imports };

        let output = module.run().unwrap();
        assert_eq!(output, Value::Prim(5.0));
    }
}
570}