arithmetic_eval/compiler/
captures.rs

1//! Captures extractor.
2
3use hashbrown::HashMap;
4
5use core::iter;
6
7use crate::{
8    alloc::{vec, Box, Vec},
9    error::{AuxErrorInfo, RepeatedAssignmentContext},
10    Error, ErrorKind, ModuleId, WildcardId,
11};
12use arithmetic_parser::{
13    grammars::Grammar, Block, Destructure, Expr, FnDefinition, Lvalue, Spanned, SpannedExpr,
14    SpannedLvalue, SpannedStatement, Statement,
15};
16
17/// Helper context for symbolic execution of a function body or a block in order to determine
18/// variables captured by it.
19#[derive(Debug)]
20pub(super) struct CapturesExtractor<'a> {
21    module_id: Box<dyn ModuleId>,
22    local_vars: Vec<HashMap<&'a str, Spanned<'a>>>,
23    pub captures: HashMap<&'a str, Spanned<'a>>,
24}
25
26impl<'a> CapturesExtractor<'a> {
27    pub fn new(module_id: Box<dyn ModuleId>) -> Self {
28        Self {
29            module_id,
30            local_vars: vec![],
31            captures: HashMap::new(),
32        }
33    }
34
35    /// Collects variables captured by the function into a single `Scope`.
36    pub fn eval_function<T: Grammar<'a>>(
37        &mut self,
38        definition: &FnDefinition<'a, T>,
39    ) -> Result<(), Error<'a>> {
40        let mut fn_local_vars = HashMap::new();
41        extract_vars(
42            self.module_id.as_ref(),
43            &mut fn_local_vars,
44            &definition.args.extra,
45            RepeatedAssignmentContext::FnArgs,
46        )?;
47        self.eval_block_inner(&definition.body, fn_local_vars)
48    }
49
50    fn has_var(&self, var_name: &str) -> bool {
51        self.local_vars.iter().any(|set| set.contains_key(var_name))
52    }
53
54    /// Processes a local variable in the rvalue position.
55    fn eval_local_var<T>(&mut self, var_span: &Spanned<'a, T>) {
56        let var_name = *var_span.fragment();
57        if !self.has_var(var_name) && !self.captures.contains_key(var_name) {
58            self.captures.insert(var_name, var_span.with_no_extra());
59        }
60    }
61
62    fn create_error<T>(&self, span: &Spanned<'a, T>, err: ErrorKind) -> Error<'a> {
63        Error::new(self.module_id.as_ref(), span, err)
64    }
65
66    /// Evaluates an expression with the function validation semantics, i.e., to determine
67    /// function captures.
68    fn eval<T: Grammar<'a>>(&mut self, expr: &SpannedExpr<'a, T>) -> Result<(), Error<'a>> {
69        match &expr.extra {
70            Expr::Variable => {
71                self.eval_local_var(expr);
72            }
73
74            Expr::Literal(_) => { /* no action */ }
75
76            Expr::Tuple(fragments) => {
77                for fragment in fragments {
78                    self.eval(fragment)?;
79                }
80            }
81            Expr::Unary { inner, .. } => {
82                self.eval(inner)?;
83            }
84            Expr::Binary { lhs, rhs, .. } => {
85                self.eval(lhs)?;
86                self.eval(rhs)?;
87            }
88
89            Expr::Function { args, name } => {
90                for arg in args {
91                    self.eval(arg)?;
92                }
93                self.eval(name)?;
94            }
95
96            Expr::FieldAccess { receiver, .. } => {
97                self.eval(receiver)?;
98            }
99
100            Expr::Method {
101                args,
102                receiver,
103                name,
104            } => {
105                self.eval(receiver)?;
106                for arg in args {
107                    self.eval(arg)?;
108                }
109
110                self.eval_local_var(name);
111            }
112
113            Expr::Block(block) => {
114                self.eval_block_inner(block, HashMap::new())?;
115            }
116            Expr::Object(object) => {
117                // Check that all field names are unique.
118                let mut object_fields = HashMap::new();
119                for (name, _) in &object.fields {
120                    let field_str = *name.fragment();
121                    if let Some(prev_span) = object_fields.insert(field_str, *name) {
122                        let err = ErrorKind::RepeatedField;
123                        return Err(Error::new(self.module_id.as_ref(), name, err)
124                            .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
125                    }
126                }
127
128                for (name, field_expr) in &object.fields {
129                    if let Some(field_expr) = field_expr {
130                        self.eval(field_expr)?;
131                    } else {
132                        self.eval_local_var(name);
133                    }
134                }
135            }
136            Expr::TypeCast { value, .. } => {
137                self.eval(value)?;
138            }
139
140            Expr::FnDefinition(def) => {
141                self.eval_function(def)?;
142            }
143
144            _ => {
145                let err = ErrorKind::unsupported(expr.extra.ty());
146                return Err(self.create_error(expr, err));
147            }
148        }
149
150        Ok(())
151    }
152
153    /// Evaluates a statement using the provided context.
154    fn eval_statement<T: Grammar<'a>>(
155        &mut self,
156        statement: &SpannedStatement<'a, T>,
157    ) -> Result<(), Error<'a>> {
158        match &statement.extra {
159            Statement::Expr(expr) => self.eval(expr),
160
161            Statement::Assignment { lhs, rhs } => {
162                self.eval(rhs)?;
163                let mut new_vars = HashMap::new();
164                extract_vars_iter(
165                    self.module_id.as_ref(),
166                    &mut new_vars,
167                    iter::once(lhs),
168                    RepeatedAssignmentContext::Assignment,
169                )?;
170                self.local_vars.last_mut().unwrap().extend(&new_vars);
171                Ok(())
172            }
173
174            _ => {
175                let err = ErrorKind::unsupported(statement.extra.ty());
176                Err(self.create_error(statement, err))
177            }
178        }
179    }
180
181    fn eval_block_inner<T: Grammar<'a>>(
182        &mut self,
183        block: &Block<'a, T>,
184        local_vars: HashMap<&'a str, Spanned<'a>>,
185    ) -> Result<(), Error<'a>> {
186        self.local_vars.push(local_vars);
187        for statement in &block.statements {
188            self.eval_statement(statement)?;
189        }
190        if let Some(ref return_expr) = block.return_value {
191            self.eval(return_expr)?;
192        }
193        self.local_vars.pop();
194        Ok(())
195    }
196
197    pub fn eval_block<T: Grammar<'a>>(&mut self, block: &Block<'a, T>) -> Result<(), Error<'a>> {
198        self.eval_block_inner(block, HashMap::new())
199    }
200}
201
202fn extract_vars<'a, T>(
203    module_id: &dyn ModuleId,
204    vars: &mut HashMap<&'a str, Spanned<'a>>,
205    lvalues: &Destructure<'a, T>,
206    context: RepeatedAssignmentContext,
207) -> Result<(), Error<'a>> {
208    let middle = lvalues
209        .middle
210        .as_ref()
211        .and_then(|rest| rest.extra.to_lvalue());
212    let all_lvalues = lvalues
213        .start
214        .iter()
215        .chain(middle.as_ref())
216        .chain(&lvalues.end);
217    extract_vars_iter(module_id, vars, all_lvalues, context)
218}
219
220fn add_var<'a>(
221    module_id: &dyn ModuleId,
222    vars: &mut HashMap<&'a str, Spanned<'a>>,
223    var_span: Spanned<'a>,
224    context: RepeatedAssignmentContext,
225) -> Result<(), Error<'a>> {
226    let var_name = *var_span.fragment();
227    if var_name != "_" {
228        if let Some(prev_span) = vars.insert(var_name, var_span) {
229            let err = ErrorKind::RepeatedAssignment { context };
230            return Err(Error::new(module_id, &var_span, err)
231                .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
232        }
233    }
234    Ok(())
235}
236
237pub(super) fn extract_vars_iter<'it, 'a: 'it, T: 'it>(
238    module_id: &dyn ModuleId,
239    vars: &mut HashMap<&'a str, Spanned<'a>>,
240    lvalues: impl Iterator<Item = &'it SpannedLvalue<'a, T>>,
241    context: RepeatedAssignmentContext,
242) -> Result<(), Error<'a>> {
243    for lvalue in lvalues {
244        match &lvalue.extra {
245            Lvalue::Variable { .. } => {
246                add_var(module_id, vars, lvalue.with_no_extra(), context)?;
247            }
248
249            Lvalue::Tuple(tuple) => {
250                extract_vars(module_id, vars, tuple, context)?;
251            }
252
253            Lvalue::Object(object) => {
254                let mut object_fields = HashMap::new();
255                for field in &object.fields {
256                    let field_str = *field.field_name.fragment();
257                    if let Some(prev_span) = object_fields.insert(field_str, field.field_name) {
258                        let err = ErrorKind::RepeatedField;
259                        return Err(Error::new(module_id, &field.field_name, err)
260                            .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
261                    }
262
263                    if let Some(binding) = &field.binding {
264                        extract_vars_iter(module_id, vars, iter::once(binding), context)?;
265                    } else {
266                        add_var(module_id, vars, field.field_name, context)?;
267                    }
268                }
269            }
270
271            _ => {
272                let err = ErrorKind::unsupported(lvalue.extra.ty());
273                return Err(Error::new(module_id, lvalue, err));
274            }
275        }
276    }
277    Ok(())
278}
279
280/// Helper enum for `CompilerExt` implementations that allows to reduce code duplication.
281#[derive(Debug)]
282pub(super) enum CompilerExtTarget<'r, 'a, T: Grammar<'a>> {
283    Block(&'r Block<'a, T>),
284    FnDefinition(&'r FnDefinition<'a, T>),
285}
286
287impl<'a, T: Grammar<'a>> CompilerExtTarget<'_, 'a, T> {
288    pub fn get_undefined_variables(self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
289        let mut extractor = CapturesExtractor::new(Box::new(WildcardId));
290
291        match self {
292            Self::Block(block) => extractor.eval_block_inner(block, HashMap::new())?,
293            Self::FnDefinition(definition) => extractor.eval_function(definition)?,
294        }
295
296        Ok(extractor.captures)
297    }
298}