arithmetic_eval/compiler/
captures.rs1use hashbrown::HashMap;
4
5use core::iter;
6
7use crate::{
8 alloc::{vec, Box, Vec},
9 error::{AuxErrorInfo, RepeatedAssignmentContext},
10 Error, ErrorKind, ModuleId, WildcardId,
11};
12use arithmetic_parser::{
13 grammars::Grammar, Block, Destructure, Expr, FnDefinition, Lvalue, Spanned, SpannedExpr,
14 SpannedLvalue, SpannedStatement, Statement,
15};
16
17#[derive(Debug)]
20pub(super) struct CapturesExtractor<'a> {
21 module_id: Box<dyn ModuleId>,
22 local_vars: Vec<HashMap<&'a str, Spanned<'a>>>,
23 pub captures: HashMap<&'a str, Spanned<'a>>,
24}
25
26impl<'a> CapturesExtractor<'a> {
27 pub fn new(module_id: Box<dyn ModuleId>) -> Self {
28 Self {
29 module_id,
30 local_vars: vec![],
31 captures: HashMap::new(),
32 }
33 }
34
35 pub fn eval_function<T: Grammar<'a>>(
37 &mut self,
38 definition: &FnDefinition<'a, T>,
39 ) -> Result<(), Error<'a>> {
40 let mut fn_local_vars = HashMap::new();
41 extract_vars(
42 self.module_id.as_ref(),
43 &mut fn_local_vars,
44 &definition.args.extra,
45 RepeatedAssignmentContext::FnArgs,
46 )?;
47 self.eval_block_inner(&definition.body, fn_local_vars)
48 }
49
50 fn has_var(&self, var_name: &str) -> bool {
51 self.local_vars.iter().any(|set| set.contains_key(var_name))
52 }
53
54 fn eval_local_var<T>(&mut self, var_span: &Spanned<'a, T>) {
56 let var_name = *var_span.fragment();
57 if !self.has_var(var_name) && !self.captures.contains_key(var_name) {
58 self.captures.insert(var_name, var_span.with_no_extra());
59 }
60 }
61
62 fn create_error<T>(&self, span: &Spanned<'a, T>, err: ErrorKind) -> Error<'a> {
63 Error::new(self.module_id.as_ref(), span, err)
64 }
65
66 fn eval<T: Grammar<'a>>(&mut self, expr: &SpannedExpr<'a, T>) -> Result<(), Error<'a>> {
69 match &expr.extra {
70 Expr::Variable => {
71 self.eval_local_var(expr);
72 }
73
74 Expr::Literal(_) => { }
75
76 Expr::Tuple(fragments) => {
77 for fragment in fragments {
78 self.eval(fragment)?;
79 }
80 }
81 Expr::Unary { inner, .. } => {
82 self.eval(inner)?;
83 }
84 Expr::Binary { lhs, rhs, .. } => {
85 self.eval(lhs)?;
86 self.eval(rhs)?;
87 }
88
89 Expr::Function { args, name } => {
90 for arg in args {
91 self.eval(arg)?;
92 }
93 self.eval(name)?;
94 }
95
96 Expr::FieldAccess { receiver, .. } => {
97 self.eval(receiver)?;
98 }
99
100 Expr::Method {
101 args,
102 receiver,
103 name,
104 } => {
105 self.eval(receiver)?;
106 for arg in args {
107 self.eval(arg)?;
108 }
109
110 self.eval_local_var(name);
111 }
112
113 Expr::Block(block) => {
114 self.eval_block_inner(block, HashMap::new())?;
115 }
116 Expr::Object(object) => {
117 let mut object_fields = HashMap::new();
119 for (name, _) in &object.fields {
120 let field_str = *name.fragment();
121 if let Some(prev_span) = object_fields.insert(field_str, *name) {
122 let err = ErrorKind::RepeatedField;
123 return Err(Error::new(self.module_id.as_ref(), name, err)
124 .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
125 }
126 }
127
128 for (name, field_expr) in &object.fields {
129 if let Some(field_expr) = field_expr {
130 self.eval(field_expr)?;
131 } else {
132 self.eval_local_var(name);
133 }
134 }
135 }
136 Expr::TypeCast { value, .. } => {
137 self.eval(value)?;
138 }
139
140 Expr::FnDefinition(def) => {
141 self.eval_function(def)?;
142 }
143
144 _ => {
145 let err = ErrorKind::unsupported(expr.extra.ty());
146 return Err(self.create_error(expr, err));
147 }
148 }
149
150 Ok(())
151 }
152
153 fn eval_statement<T: Grammar<'a>>(
155 &mut self,
156 statement: &SpannedStatement<'a, T>,
157 ) -> Result<(), Error<'a>> {
158 match &statement.extra {
159 Statement::Expr(expr) => self.eval(expr),
160
161 Statement::Assignment { lhs, rhs } => {
162 self.eval(rhs)?;
163 let mut new_vars = HashMap::new();
164 extract_vars_iter(
165 self.module_id.as_ref(),
166 &mut new_vars,
167 iter::once(lhs),
168 RepeatedAssignmentContext::Assignment,
169 )?;
170 self.local_vars.last_mut().unwrap().extend(&new_vars);
171 Ok(())
172 }
173
174 _ => {
175 let err = ErrorKind::unsupported(statement.extra.ty());
176 Err(self.create_error(statement, err))
177 }
178 }
179 }
180
181 fn eval_block_inner<T: Grammar<'a>>(
182 &mut self,
183 block: &Block<'a, T>,
184 local_vars: HashMap<&'a str, Spanned<'a>>,
185 ) -> Result<(), Error<'a>> {
186 self.local_vars.push(local_vars);
187 for statement in &block.statements {
188 self.eval_statement(statement)?;
189 }
190 if let Some(ref return_expr) = block.return_value {
191 self.eval(return_expr)?;
192 }
193 self.local_vars.pop();
194 Ok(())
195 }
196
197 pub fn eval_block<T: Grammar<'a>>(&mut self, block: &Block<'a, T>) -> Result<(), Error<'a>> {
198 self.eval_block_inner(block, HashMap::new())
199 }
200}
201
202fn extract_vars<'a, T>(
203 module_id: &dyn ModuleId,
204 vars: &mut HashMap<&'a str, Spanned<'a>>,
205 lvalues: &Destructure<'a, T>,
206 context: RepeatedAssignmentContext,
207) -> Result<(), Error<'a>> {
208 let middle = lvalues
209 .middle
210 .as_ref()
211 .and_then(|rest| rest.extra.to_lvalue());
212 let all_lvalues = lvalues
213 .start
214 .iter()
215 .chain(middle.as_ref())
216 .chain(&lvalues.end);
217 extract_vars_iter(module_id, vars, all_lvalues, context)
218}
219
220fn add_var<'a>(
221 module_id: &dyn ModuleId,
222 vars: &mut HashMap<&'a str, Spanned<'a>>,
223 var_span: Spanned<'a>,
224 context: RepeatedAssignmentContext,
225) -> Result<(), Error<'a>> {
226 let var_name = *var_span.fragment();
227 if var_name != "_" {
228 if let Some(prev_span) = vars.insert(var_name, var_span) {
229 let err = ErrorKind::RepeatedAssignment { context };
230 return Err(Error::new(module_id, &var_span, err)
231 .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
232 }
233 }
234 Ok(())
235}
236
237pub(super) fn extract_vars_iter<'it, 'a: 'it, T: 'it>(
238 module_id: &dyn ModuleId,
239 vars: &mut HashMap<&'a str, Spanned<'a>>,
240 lvalues: impl Iterator<Item = &'it SpannedLvalue<'a, T>>,
241 context: RepeatedAssignmentContext,
242) -> Result<(), Error<'a>> {
243 for lvalue in lvalues {
244 match &lvalue.extra {
245 Lvalue::Variable { .. } => {
246 add_var(module_id, vars, lvalue.with_no_extra(), context)?;
247 }
248
249 Lvalue::Tuple(tuple) => {
250 extract_vars(module_id, vars, tuple, context)?;
251 }
252
253 Lvalue::Object(object) => {
254 let mut object_fields = HashMap::new();
255 for field in &object.fields {
256 let field_str = *field.field_name.fragment();
257 if let Some(prev_span) = object_fields.insert(field_str, field.field_name) {
258 let err = ErrorKind::RepeatedField;
259 return Err(Error::new(module_id, &field.field_name, err)
260 .with_span(&prev_span.into(), AuxErrorInfo::PrevAssignment));
261 }
262
263 if let Some(binding) = &field.binding {
264 extract_vars_iter(module_id, vars, iter::once(binding), context)?;
265 } else {
266 add_var(module_id, vars, field.field_name, context)?;
267 }
268 }
269 }
270
271 _ => {
272 let err = ErrorKind::unsupported(lvalue.extra.ty());
273 return Err(Error::new(module_id, lvalue, err));
274 }
275 }
276 }
277 Ok(())
278}
279
280#[derive(Debug)]
282pub(super) enum CompilerExtTarget<'r, 'a, T: Grammar<'a>> {
283 Block(&'r Block<'a, T>),
284 FnDefinition(&'r FnDefinition<'a, T>),
285}
286
287impl<'a, T: Grammar<'a>> CompilerExtTarget<'_, 'a, T> {
288 pub fn get_undefined_variables(self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
289 let mut extractor = CapturesExtractor::new(Box::new(WildcardId));
290
291 match self {
292 Self::Block(block) => extractor.eval_block_inner(block, HashMap::new())?,
293 Self::FnDefinition(definition) => extractor.eval_function(definition)?,
294 }
295
296 Ok(extractor.captures)
297 }
298}