1use hashbrown::HashMap;
4
5use crate::{
6 alloc::{Box, String, ToOwned},
7 executable::{Atom, Command, CompiledExpr, Executable, ExecutableModule, FieldName, Registers},
8 Error, ErrorKind, ModuleId, Value,
9};
10use arithmetic_parser::{
11 grammars::Grammar, BinaryOp, Block, Destructure, FnDefinition, InputSpan, Lvalue,
12 ObjectDestructure, Spanned, SpannedLvalue, UnaryOp,
13};
14
15mod captures;
16mod expr;
17
18use self::captures::{CapturesExtractor, CompilerExtTarget};
19
/// Spans keyed by (owned) variable name.
///
/// `Compiler::compile_module` returns this map alongside the compiled module; it holds one
/// entry per variable that the captures extractor reported for the module block.
pub(crate) type ImportSpans<'a> = HashMap<String, Spanned<'a>>;
21
/// State used while turning a parsed block into an `Executable`.
#[derive(Debug)]
pub(crate) struct Compiler {
    /// Mapping from a variable name to the index of the register holding its value.
    vars_to_registers: HashMap<String, usize>,
    /// Current nesting depth; `insert_var` emits `Command::Annotate` only when this is 0.
    // NOTE(review): the code that increments the depth is not in this file — presumably
    // the block / function compilation in the `expr` submodule; confirm there.
    scope_depth: usize,
    /// Number of registers allocated so far; also the index of the next free register.
    register_count: usize,
    /// Identifier of the module being compiled; used when creating errors.
    module_id: Box<dyn ModuleId>,
}
30
31impl Compiler {
32 fn new(module_id: Box<dyn ModuleId>) -> Self {
33 Self {
34 vars_to_registers: HashMap::new(),
35 scope_depth: 0,
36 register_count: 0,
37 module_id,
38 }
39 }
40
41 fn from_env<T>(module_id: Box<dyn ModuleId>, env: &Registers<'_, T>) -> Self {
42 Self {
43 vars_to_registers: env.variables_map().clone(),
44 register_count: env.register_count(),
45 scope_depth: 0,
46 module_id,
47 }
48 }
49
50 fn backup(&mut self) -> Self {
52 Self {
53 vars_to_registers: self.vars_to_registers.clone(),
54 scope_depth: self.scope_depth,
55 register_count: self.register_count,
56 module_id: self.module_id.clone_boxed(),
57 }
58 }
59
60 fn create_error<'a, T>(&self, span: &Spanned<'a, T>, err: ErrorKind) -> Error<'a> {
61 Error::new(self.module_id.as_ref(), span, err)
62 }
63
64 fn check_unary_op<'a>(&self, op: &Spanned<'a, UnaryOp>) -> Result<UnaryOp, Error<'a>> {
65 match op.extra {
66 UnaryOp::Neg | UnaryOp::Not => Ok(op.extra),
67 _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
68 }
69 }
70
71 fn check_binary_op<'a>(&self, op: &Spanned<'a, BinaryOp>) -> Result<BinaryOp, Error<'a>> {
72 match op.extra {
73 BinaryOp::Add
74 | BinaryOp::Sub
75 | BinaryOp::Mul
76 | BinaryOp::Div
77 | BinaryOp::Power
78 | BinaryOp::And
79 | BinaryOp::Or
80 | BinaryOp::Eq
81 | BinaryOp::NotEq
82 | BinaryOp::Gt
83 | BinaryOp::Ge
84 | BinaryOp::Lt
85 | BinaryOp::Le => Ok(op.extra),
86
87 _ => Err(self.create_error(op, ErrorKind::unsupported(op.extra))),
88 }
89 }
90
91 fn get_var(&self, name: &str) -> usize {
92 *self
93 .vars_to_registers
94 .get(name)
95 .expect("Captures must created during module compilation")
96 }
97
98 fn push_assignment<'a, T, U>(
99 &mut self,
100 executable: &mut Executable<'a, T>,
101 rhs: CompiledExpr<'a, T>,
102 rhs_span: &Spanned<'a, U>,
103 ) -> usize {
104 let register = self.register_count;
105 let command = Command::Push(rhs);
106 executable.push_command(rhs_span.copy_with_extra(command));
107 self.register_count += 1;
108 register
109 }
110
111 pub fn compile_module<'a, Id: ModuleId, T: Grammar<'a>>(
112 module_id: Id,
113 block: &Block<'a, T>,
114 ) -> Result<(ExecutableModule<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
115 let module_id = Box::new(module_id) as Box<dyn ModuleId>;
116 let (captures, import_spans) = Self::extract_captures(module_id.clone_boxed(), block)?;
117 let mut compiler = Self::from_env(module_id.clone_boxed(), &captures);
118
119 let mut executable = Executable::new(module_id);
120 let empty_span = InputSpan::new("");
121 let last_atom = compiler
122 .compile_block_inner(&mut executable, block)?
123 .map_or(Atom::Void, |spanned| spanned.extra);
124 compiler.push_assignment(
126 &mut executable,
127 CompiledExpr::Atom(last_atom),
128 &empty_span.into(),
129 );
130
131 executable.finalize_block(compiler.register_count);
132 let module = ExecutableModule::from_parts(executable, captures);
133 Ok((module, import_spans))
134 }
135
136 fn extract_captures<'a, T: Grammar<'a>>(
137 module_id: Box<dyn ModuleId>,
138 block: &Block<'a, T>,
139 ) -> Result<(Registers<'a, T::Lit>, ImportSpans<'a>), Error<'a>> {
140 let mut extractor = CapturesExtractor::new(module_id);
141 extractor.eval_block(&block)?;
142
143 let mut captures = Registers::new();
144 for &var_name in extractor.captures.keys() {
145 captures.insert_var(var_name, Value::void());
146 }
147
148 let import_spans = extractor
149 .captures
150 .into_iter()
151 .map(|(var_name, var_span)| (var_name.to_owned(), var_span))
152 .collect();
153
154 Ok((captures, import_spans))
155 }
156
157 fn assign<'a, T, Ty>(
158 &mut self,
159 executable: &mut Executable<'a, T>,
160 lhs: &SpannedLvalue<'a, Ty>,
161 rhs_register: usize,
162 ) -> Result<(), Error<'a>> {
163 match &lhs.extra {
164 Lvalue::Variable { .. } => {
165 self.insert_var(executable, lhs.with_no_extra(), rhs_register);
166 }
167
168 Lvalue::Tuple(destructure) => {
169 let span = lhs.with_no_extra();
170 self.destructure(executable, destructure, span, rhs_register)?;
171 }
172
173 Lvalue::Object(destructure) => {
174 let span = lhs.with_no_extra();
175 self.destructure_object(executable, destructure, span, rhs_register)?;
176 }
177
178 _ => {
179 let err = ErrorKind::unsupported(lhs.extra.ty());
180 return Err(self.create_error(lhs, err));
181 }
182 }
183
184 Ok(())
185 }
186
187 fn insert_var<'a, T>(
188 &mut self,
189 executable: &mut Executable<'a, T>,
190 var_span: Spanned<'a>,
191 register: usize,
192 ) {
193 let var_name = *var_span.fragment();
194 if var_name != "_" {
195 self.vars_to_registers.insert(var_name.to_owned(), register);
196
197 if self.scope_depth == 0 {
200 let command = Command::Annotate {
201 register,
202 name: var_name.to_owned(),
203 };
204 executable.push_command(var_span.copy_with_extra(command));
205 }
206 }
207 }
208
209 fn destructure<'a, T, Ty>(
210 &mut self,
211 executable: &mut Executable<'a, T>,
212 destructure: &Destructure<'a, Ty>,
213 span: Spanned<'a>,
214 rhs_register: usize,
215 ) -> Result<(), Error<'a>> {
216 let command = Command::Destructure {
217 source: rhs_register,
218 start_len: destructure.start.len(),
219 end_len: destructure.end.len(),
220 lvalue_len: destructure.len(),
221 unchecked: false,
222 };
223 executable.push_command(span.copy_with_extra(command));
224 let start_register = self.register_count;
225 self.register_count += destructure.start.len() + destructure.end.len() + 1;
226
227 for (i, lvalue) in (start_register..).zip(&destructure.start) {
228 self.assign(executable, lvalue, i)?;
229 }
230
231 let start_register = start_register + destructure.start.len();
232 if let Some(middle) = &destructure.middle {
233 if let Some(lvalue) = middle.extra.to_lvalue() {
234 self.assign(executable, &lvalue, start_register)?;
235 }
236 }
237
238 let start_register = start_register + 1;
239 for (i, lvalue) in (start_register..).zip(&destructure.end) {
240 self.assign(executable, lvalue, i)?;
241 }
242
243 Ok(())
244 }
245
246 fn destructure_object<'a, T, Ty>(
247 &mut self,
248 executable: &mut Executable<'a, T>,
249 destructure: &ObjectDestructure<'a, Ty>,
250 span: Spanned<'a>,
251 rhs_register: usize,
252 ) -> Result<(), Error<'a>> {
253 for field in &destructure.fields {
254 let field_name = FieldName::Name((*field.field_name.fragment()).to_owned());
255 let field_access = CompiledExpr::FieldAccess {
256 receiver: span.copy_with_extra(Atom::Register(rhs_register)).into(),
257 field: field_name,
258 };
259 let register = self.push_assignment(executable, field_access, &field.field_name);
260 if let Some(binding) = &field.binding {
261 self.assign(executable, binding, register)?;
262 } else {
263 self.insert_var(executable, field.field_name, register);
264 }
265 }
266 Ok(())
267 }
268}
269
/// Extension trait for AST nodes (implemented in this file for `Block` and `FnDefinition`)
/// that reports variables used by the node but not defined within it.
pub trait CompilerExt<'a> {
    /// Returns the undefined variables of this node, keyed by variable name, each with
    /// a span locating a use of the variable in the source.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying analysis of the node fails.
    // NOTE(review): the exact failure conditions live in the `captures` submodule,
    // which is not visible here — confirm before documenting them in more detail.
    fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>>;
}
306
307impl<'a, T: Grammar<'a>> CompilerExt<'a> for Block<'a, T> {
308 fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
309 CompilerExtTarget::Block(self).get_undefined_variables()
310 }
311}
312
313impl<'a, T: Grammar<'a>> CompilerExt<'a> for FnDefinition<'a, T> {
314 fn undefined_variables(&self) -> Result<HashMap<&'a str, Spanned<'a>>, Error<'a>> {
315 CompilerExtTarget::FnDefinition(self).get_undefined_variables()
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Value, WildcardId};

    use arithmetic_parser::grammars::{F32Grammar, Parse, ParseLiteral, Typed, Untyped};
    use arithmetic_parser::{Expr, NomResult};

    /// A program with a nested block compiles and evaluates to `true`.
    #[test]
    fn compilation_basics() {
        let source = "x = 3; 1 + { y = 2; y * x } == 7";
        let parsed = Untyped::<F32Grammar>::parse_statements(source).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    /// A function defined in the program can be called from the same program.
    #[test]
    fn compiled_function() {
        let source = "add = |x, y| x + y; add(2, 3) == 5";
        let parsed = Untyped::<F32Grammar>::parse_statements(source).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    /// A function can use a variable defined before it in the enclosing scope.
    #[test]
    fn compiled_function_with_capture() {
        let source = "A = 2; add = |x, y| x + y / A; add(2, 3) == 3.5";
        let parsed = Untyped::<F32Grammar>::parse_statements(source).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }

    /// `y` is reported as undefined; `x`, defined and used inside the inner block, is not.
    #[test]
    fn variable_extraction() {
        let source = "|a, b| ({ x = a * b + y; x - 2 }, a / b)";
        let parsed = Untyped::<F32Grammar>::parse_statements(source).unwrap();
        let return_value = parsed.return_value.unwrap();
        let fn_def = match return_value.extra {
            Expr::FnDefinition(fn_def) => fn_def,
            other => panic!("Unexpected function parsing result: {:?}", other),
        };

        let undefined = fn_def.undefined_variables().unwrap();
        assert_eq!(undefined["y"].location_offset(), 22);
        assert!(!undefined.contains_key("x"));
    }

    /// `x` used outside of the block that defines it is reported as undefined.
    #[test]
    fn variable_extraction_with_scoping() {
        let source = "|a, b| ({ x = a * b + y; x - 2 }, a / x)";
        let parsed = Untyped::<F32Grammar>::parse_statements(source).unwrap();
        let return_value = parsed.return_value.unwrap();
        let fn_def = match return_value.extra {
            Expr::FnDefinition(fn_def) => fn_def,
            other => panic!("Unexpected function parsing result: {:?}", other),
        };

        let undefined = fn_def.undefined_variables().unwrap();
        assert_eq!(undefined["y"].location_offset(), 22);
        assert_eq!(undefined["x"].location_offset(), 38);
    }

    /// The only capture of the program is `x`, with the span of its first mention.
    #[test]
    fn extracting_captures() {
        let program = "y = 5 * x; y - 3 + x";
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();
        let module_id = Box::new(WildcardId);
        let (registers, import_spans) = Compiler::extract_captures(module_id, &parsed).unwrap();

        assert_eq!(registers.register_count(), 1);
        assert_eq!(*registers.get_var("x").unwrap(), Value::void());
        assert_eq!(import_spans.len(), 1);
        assert_eq!(import_spans["x"], Spanned::from_str(program, 8..9));
    }

    /// Captures of inner functions that are not defined in the module become module captures.
    #[test]
    fn extracting_captures_with_inner_fns() {
        let program = r#"
            y = 5 * x; // x is a capture
            fun = |z| { // z is not a capture
                z * x + y * PI // y is not a capture for the entire module, PI is
            };
        "#;
        let parsed = Untyped::<F32Grammar>::parse_statements(program).unwrap();

        let module_id = Box::new(WildcardId);
        let (registers, import_spans) = Compiler::extract_captures(module_id, &parsed).unwrap();
        assert_eq!(registers.register_count(), 2);
        let variables = registers.variables_map();
        assert!(variables.contains_key("x"));
        assert!(variables.contains_key("PI"));
        assert_eq!(import_spans["x"].location_line(), 2);
    }

    /// Type casts in a typed grammar do not change the result of evaluation.
    #[test]
    fn type_casts_are_ignored() {
        struct TypedGrammar;

        impl ParseLiteral for TypedGrammar {
            type Lit = f32;

            fn parse_literal(input: InputSpan<'_>) -> NomResult<'_, Self::Lit> {
                F32Grammar::parse_literal(input)
            }
        }

        impl Grammar<'_> for TypedGrammar {
            type Type = ();

            fn parse_type(input: InputSpan<'_>) -> NomResult<'_, Self::Type> {
                use nom::{bytes::complete::tag, combinator::map};
                map(tag("Num"), drop)(input)
            }
        }

        let source = "x = 3 as Num; 1 + { y = 2; y * x as Num } == 7";
        let parsed = Typed::<TypedGrammar>::parse_statements(source).unwrap();
        let (module, _) = Compiler::compile_module(WildcardId, &parsed).unwrap();
        let output = module.run().unwrap();
        assert_eq!(output, Value::Bool(true));
    }
}