arithmetic_eval/compiler/
mod.rs

//! Transformation of the AST produced by the parser into a non-recursive format.
2
3use hashbrown::HashMap;
4
5use crate::{
6    alloc::{Box, String, ToOwned},
7    executable::{Atom, Command, CompiledExpr, Executable, ExecutableModule, FieldName, Registers},
8    Error, ErrorKind, ModuleId, Value,
9};
10use arithmetic_parser::{
11    grammars::Grammar, BinaryOp, Block, Destructure, FnDefinition, InputSpan, Lvalue,
12    ObjectDestructure, Spanned, SpannedLvalue, UnaryOp,
13};
14
15mod captures;
16mod expr;
17
18use self::captures::{CapturesExtractor, CompilerExtTarget};
19
/// Spans of variables imported into a module (i.e., its captures), keyed by variable name.
/// Each span points to the first occurrence of the variable in the module.
pub(crate) type ImportSpans<'a> = HashMap<String, Spanned<'a>>;
21
/// Compiler transforming parsed AST into a flat sequence of commands of an [`Executable`].
#[derive(Debug)]
pub(crate) struct Compiler {
    /// Mapping between named variables and the registers holding their values.
    vars_to_registers: HashMap<String, usize>,
    /// Nesting depth of the scope being compiled; `0` corresponds to the top level.
    /// Variable name annotations are only emitted at depth `0`.
    scope_depth: usize,
    /// Number of registers allocated so far; this is also the index of the next register
    /// to be allocated.
    register_count: usize,
    /// Identifier of the compiled module, attached to the errors produced by the compiler.
    module_id: Box<dyn ModuleId>,
}
30
31impl Compiler {
32    fn new(module_id: Box<dyn ModuleId>) -> Self {
33        Self {
34            vars_to_registers: HashMap::new(),
35            scope_depth: 0,
36            register_count: 0,
37            module_id,
38        }
39    }
40
41    fn from_env<T>(module_id: Box<dyn ModuleId>, env: &Registers<'_, T>) -> Self {
42        Self {
43            vars_to_registers: env.variables_map().clone(),
44            register_count: env.register_count(),
45            scope_depth: 0,
46            module_id,
47        }
48    }
49
50    /// Backups this instance. This effectively clones all fields.
51    fn backup(&mut self) -> Self {
52        Self {
53            vars_to_registers: self.vars_to_registers.clone(),
54            scope_depth: self.scope_depth,
55            register_count: self.register_count,
56            module_id: self.module_id.clone_boxed(),
57        }
58    }
59
60    fn create_error<'a, T>(&self, span: &Spanned<'a, T>, err: ErrorKind) -> Error<'a> {
61        Error::new(self.module_id.as_ref(), span, err)
62    }
63
64    fn check_unary_op<'a>(&self, op: &Spanned<'a, UnaryOp>) -> Result<UnaryOp, Error<'a>> {
65        match op.extra {
66            UnaryOp::Neg | UnaryOp::Not => Ok(op.extra),
67            _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
68        }
69    }
70
71    fn check_binary_op<'a>(&self, op: &Spanned<'a, BinaryOp>) -> Result<BinaryOp, Error<'a>> {
72        match op.extra {
73            BinaryOp::Add
74            | BinaryOp::Sub
75            | BinaryOp::Mul
76            | BinaryOp::Div
77            | BinaryOp::Power
78            | BinaryOp::And
79            | BinaryOp::Or
80            | BinaryOp::Eq
81            | BinaryOp::NotEq
82            | BinaryOp::Gt
83            | BinaryOp::Ge
84            | BinaryOp::Lt
85            | BinaryOp::Le => Ok(op.extra),
86
87            _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
88        }
89    }
90
91    fn get_var(&self, name: &str) -> usize {
92        *self
93            .vars_to_registers
94            .get(name)
95            .expect("Captures must created during module compilation")
96    }
97
98    fn push_assignment<'a, T, U>(
99        &mut self,
100        executable: &mut Executable<'a, T>,
101        rhs: CompiledExpr<'a, T>,
102        rhs_span: &Spanned<'a, U>,
103    ) -> usize {
104        let register = self.register_count;
105        let command = Command::Push(rhs);
106        executable.push_command(rhs_span.copy_with_extra(command));
107        self.register_count += 1;
108        register
109    }
110
111    pub fn compile_module<'a, Id: ModuleId, T: Grammar<'a>>(
112        module_id: Id,
113        block: &Block<'a, T>,
114    ) -> Result<(ExecutableModule<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
115        let module_id = Box::new(module_id) as Box<dyn ModuleId>;
116        let (captures, import_spans) = Self::extract_captures(module_id.clone_boxed(), block)?;
117        let mut compiler = Self::from_env(module_id.clone_boxed(), &captures);
118
119        let mut executable = Executable::new(module_id);
120        let empty_span = InputSpan::new("");
121        let last_atom = compiler
122            .compile_block_inner(&mut executable, block)?
123            .map_or(Atom::Void, |spanned| spanned.extra);
124        // Push the last variable to a register to be popped during execution.
125        compiler.push_assignment(
126            &mut executable,
127            CompiledExpr::Atom(last_atom),
128            &empty_span.into(),
129        );
130
131        executable.finalize_block(compiler.register_count);
132        let module = ExecutableModule::from_parts(executable, captures);
133        Ok((module, import_spans))
134    }
135
136    fn extract_captures<'a, T: Grammar<'a>>(
137        module_id: Box<dyn ModuleId>,
138        block: &Block<'a, T>,
139    ) -> Result<(Registers<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
140        let mut extractor = CapturesExtractor::new(module_id);
141        extractor.eval_block(&block)?;
142
143        let mut captures = Registers::new();
144        for &var_name in extractor.captures.keys() {
145            captures.insert_var(var_name, Value::void());
146        }
147
148        let import_spans = extractor
149            .captures
150            .into_iter()
151            .map(|(var_name, var_span)| (var_name.to_owned(), var_span))
152            .collect();
153
154        Ok((captures, import_spans))
155    }
156
157    fn assign<'a, T, Ty>(
158        &mut self,
159        executable: &mut Executable<'a, T>,
160        lhs: &SpannedLvalue<'a, Ty>,
161        rhs_register: usize,
162    ) -> Result<(), Error<'a>> {
163        match &lhs.extra {
164            Lvalue::Variable { .. } => {
165                self.insert_var(executable, lhs.with_no_extra(), rhs_register);
166            }
167
168            Lvalue::Tuple(destructure) => {
169                let span = lhs.with_no_extra();
170                self.destructure(executable, destructure, span, rhs_register)?;
171            }
172
173            Lvalue::Object(destructure) => {
174                let span = lhs.with_no_extra();
175                self.destructure_object(executable, destructure, span, rhs_register)?;
176            }
177
178            _ => {
179                let err = ErrorKind::unsupported(lhs.extra.ty());
180                return Err(self.create_error(lhs, err));
181            }
182        }
183
184        Ok(())
185    }
186
187    fn insert_var<'a, T>(
188        &mut self,
189        executable: &mut Executable<'a, T>,
190        var_span: Spanned<'a>,
191        register: usize,
192    ) {
193        let var_name = *var_span.fragment();
194        if var_name != "_" {
195            self.vars_to_registers.insert(var_name.to_owned(), register);
196
197            // It does not make sense to annotate vars in the inner scopes, since
198            // they cannot be accessed externally.
199            if self.scope_depth == 0 {
200                let command = Command::Annotate {
201                    register,
202                    name: var_name.to_owned(),
203                };
204                executable.push_command(var_span.copy_with_extra(command));
205            }
206        }
207    }
208
209    fn destructure<'a, T, Ty>(
210        &mut self,
211        executable: &mut Executable<'a, T>,
212        destructure: &Destructure<'a, Ty>,
213        span: Spanned<'a>,
214        rhs_register: usize,
215    ) -> Result<(), Error<'a>> {
216        let command = Command::Destructure {
217            source: rhs_register,
218            start_len: destructure.start.len(),
219            end_len: destructure.end.len(),
220            lvalue_len: destructure.len(),
221            unchecked: false,
222        };
223        executable.push_command(span.copy_with_extra(command));
224        let start_register = self.register_count;
225        self.register_count += destructure.start.len() + destructure.end.len() + 1;
226
227        for (i, lvalue) in (start_register..).zip(&destructure.start) {
228            self.assign(executable, lvalue, i)?;
229        }
230
231        let start_register = start_register + destructure.start.len();
232        if let Some(middle) = &destructure.middle {
233            if let Some(lvalue) = middle.extra.to_lvalue() {
234                self.assign(executable, &lvalue, start_register)?;
235            }
236        }
237
238        let start_register = start_register + 1;
239        for (i, lvalue) in (start_register..).zip(&destructure.end) {
240            self.assign(executable, lvalue, i)?;
241        }
242
243        Ok(())
244    }
245
246    fn destructure_object<'a, T, Ty>(
247        &mut self,
248        executable: &mut Executable<'a, T>,
249        destructure: &ObjectDestructure<'a, Ty>,
250        span: Spanned<'a>,
251        rhs_register: usize,
252    ) -> Result<(), Error<'a>> {
253        for field in &destructure.fields {
254            let field_name = FieldName::Name((*field.field_name.fragment()).to_owned());
255            let field_access = CompiledExpr::FieldAccess {
256                receiver: span.copy_with_extra(Atom::Register(rhs_register)).into(),
257                field: field_name,
258            };
259            let register = self.push_assignment(executable, field_access, &field.field_name);
260            if let Some(binding) = &field.binding {
261                self.assign(executable, binding, register)?;
262            } else {
263                self.insert_var(executable, field.field_name, register);
264            }
265        }
266        Ok(())
267    }
268}
269
/// Compiler extensions for AST nodes, most notably [`Block`]. The trait is also
/// implemented for [`FnDefinition`].
///
/// # Examples
///
/// ```
/// use arithmetic_parser::grammars::{F32Grammar, Parse, Untyped};
/// use arithmetic_eval::CompilerExt;
/// # use hashbrown::HashSet;
/// # use core::iter::FromIterator;
///
/// # fn main() -> anyhow::Result<()> {
/// let block = "x = sin(0.5) / PI; y = x * E; (x, y)";
/// let block = Untyped::<F32Grammar>::parse_statements(block)?;
/// let undefined_vars = block.undefined_variables()?;
/// assert_eq!(
///     undefined_vars.keys().copied().collect::<HashSet<_>>(),
///     HashSet::from_iter(vec!["sin", "PI", "E"])
/// );
/// assert_eq!(undefined_vars["PI"].location_offset(), 15);
/// # Ok(())
/// # }
/// ```
pub trait CompilerExt<'a> {
    /// Returns variables used but not defined within the AST node, together with the span
    /// of their first occurrence. Variables defined within the node (such as `x` and `y`
    /// in the example above) are not reported.
    ///
    /// # Errors
    ///
    /// - Returns an error if the AST is intrinsically malformed. This may be the case if it
    ///   contains destructuring with the same variable on the left-hand side,
    ///   such as `(x, x) = ...`.
    ///
    /// The absence of an error does not guarantee that the AST node will evaluate
    /// successfully once all variables are assigned.
    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>>;
}
306
307impl<'a, T: Grammar<'a>> CompilerExt<'a> for Block<'a, T> {
308    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
309        CompilerExtTarget::Block(self).get_undefined_variables()
310    }
311}
312
313impl<'a, T: Grammar<'a>> CompilerExt<'a> for FnDefinition<'a, T> {
314    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
315        CompilerExtTarget::FnDefinition(self).get_undefined_variables()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Value, WildcardId};

    use arithmetic_parser::grammars::{F32Grammar, Parse, ParseLiteral, Typed, Untyped};
    use arithmetic_parser::{Expr, NomResult};

    /// Parses `def` as a standalone function definition and returns the variables it captures
    /// (i.e., the result of `CompilerExt::undefined_variables`).
    ///
    /// # Panics
    ///
    /// Panics if `def` cannot be parsed, has no return value, or its return value
    /// is not a function definition.
    fn fn_captures(def: &str) -> HashMap<&str, Spanned<'_>> {
        let def = Untyped::<F32Grammar>::parse_statements(def)
            .unwrap()
            .return_value
            .unwrap();
        let def = match def.extra {
            Expr::FnDefinition(def) => def,
            other => panic!("Unexpected function parsing result: {:?}", other),
        };
        def.undefined_variables().unwrap()
    }

    #[test]
    fn compilation_basics() {
        let block = "x = 3; 1 + { y = 2; y * x } == 7";
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &block).unwrap();
        let value = module.run().unwrap();
        assert_eq!(value, Value::Bool(true));
    }

    #[test]
    fn compiled_function() {
        let block = "add = |x, y| x + y; add(2, 3) == 5";
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &block).unwrap();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    #[test]
    fn compiled_function_with_capture() {
        let block = "A = 2; add = |x, y| x + y / A; add(2, 3) == 3.5";
        let block = Untyped::<F32Grammar>::parse_statements(block).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &block).unwrap();
        assert_eq!(module.run().unwrap(), Value::Bool(true));
    }

    #[test]
    fn variable_extraction() {
        let captures = fn_captures("|a, b| ({ x = a * b + y; x - 2 }, a / b)");
        assert_eq!(captures["y"].location_offset(), 22);
        assert!(!captures.contains_key("x"));
    }

    #[test]
    fn variable_extraction_with_scoping() {
        let captures = fn_captures("|a, b| ({ x = a * b + y; x - 2 }, a / x)");
        assert_eq!(captures["y"].location_offset(), 22);
        assert_eq!(captures["x"].location_offset(), 38);
    }

    #[test]
    fn extracting_captures() {
        let program = "y = 5 * x; y - 3 + x";
        let module = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let (registers, import_spans) =
            Compiler::extract_captures(Box::new(WildcardId), &module).unwrap();

        assert_eq!(registers.register_count(), 1);
        assert_eq!(*registers.get_var("x").unwrap(), Value::void());
        assert_eq!(import_spans.len(), 1);
        assert_eq!(import_spans["x"], Spanned::from_str(program, 8..9));
    }

    #[test]
    fn extracting_captures_with_inner_fns() {
        let program = r#"
            y = 5 * x;          // x is a capture
            fun = |z| {         // z is not a capture
                z * x + y * PI  // y is not a capture for the entire module, PI is
            };
        "#;
        let module = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let (registers, import_spans) =
            Compiler::extract_captures(Box::new(WildcardId), &module).unwrap();
        assert_eq!(registers.register_count(), 2);
        assert!(registers.variables_map().contains_key("x"));
        assert!(registers.variables_map().contains_key("PI"));
        assert_eq!(import_spans["x"].location_line(), 2); // should be the first mention
    }

    #[test]
    fn type_casts_are_ignored() {
        struct TypedGrammar;

        impl ParseLiteral for TypedGrammar {
            type Lit = f32;

            fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
                F32Grammar::parse_literal(input)
            }
        }

        impl Grammar<'_> for TypedGrammar {
            type Type = ();

            fn parse_type(input: InputSpan<'_>) -> NomResult<'_, Self::Type> {
                use nom::{bytes::complete::tag, combinator::map};
                map(tag("Num"), drop)(input)
            }
        }

        let block = "x = 3 as Num; 1 + { y = 2; y * x as Num } == 7";
        let block = Typed::<TypedGrammar>::parse_statements(block).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &block).unwrap();
        let value = module.run().unwrap();
        assert_eq!(value, Value::Bool(true));
    }
}