1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15 bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19 pub fn new() -> Self {
20 Self {
21 bytecode: Default::default(),
22 }
23 }
24
25 pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26 self.bytecode.clear();
27
28 CompilerInner::new(&mut self.bytecode, root).compile()?;
29 Ok(self.bytecode.as_slice())
30 }
31
32 pub fn get_bytecode(&self) -> &[Opcode] {
33 self.bytecode.as_slice()
34 }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39 root: &'arena Node<'arena>,
40 bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44 pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45 Self { root, bytecode }
46 }
47
48 pub fn compile(&mut self) -> CompilerResult<()> {
49 self.compile_node(self.root)?;
50 Ok(())
51 }
52
53 fn emit(&mut self, op: Opcode) -> usize {
54 self.bytecode.push(op);
55 self.bytecode.len()
56 }
57
58 fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59 where
60 F: FnOnce(&mut Self) -> CompilerResult<()>,
61 {
62 let begin = self.bytecode.len();
63 let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65 body(self)?;
66
67 self.emit(Opcode::IncrementIt);
68 let e = self.emit(Opcode::Jump(
69 Jump::Backward,
70 self.calc_backward_jump(begin) as u32,
71 ));
72 self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73 Ok(())
74 }
75
76 fn emit_cond<F>(&mut self, mut body: F)
77 where
78 F: FnMut(&mut Self),
79 {
80 let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81 self.emit(Opcode::Pop);
82
83 body(self);
84
85 let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86 self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87 let e = self.emit(Opcode::Pop);
88 self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89 }
90
91 fn replace(&mut self, at: usize, op: Opcode) {
92 let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93 }
94
95 fn calc_backward_jump(&self, to: usize) -> usize {
96 self.bytecode.len() + 1 - to
97 }
98
99 fn compile_argument<T: ToString>(
100 &mut self,
101 function_kind: T,
102 arguments: &[&'arena Node<'arena>],
103 index: usize,
104 ) -> CompilerResult<usize> {
105 let arg = arguments
106 .get(index)
107 .ok_or_else(|| CompilerError::ArgumentNotFound {
108 index,
109 function: function_kind.to_string(),
110 })?;
111
112 self.compile_node(arg)
113 }
114
115 fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
116 match node {
117 Node::Root => Some(vec![FetchFastTarget::Root]),
118 Node::Identifier(v) => Some(vec![
119 FetchFastTarget::Root,
120 FetchFastTarget::String(Arc::from(*v)),
121 ]),
122 Node::Member { node, property } => {
123 let mut path = self.compile_member_fast(node)?;
124 match property {
125 Node::String(v) => {
126 path.push(FetchFastTarget::String(Arc::from(*v)));
127 Some(path)
128 }
129 Node::Number(v) => {
130 if let Some(idx) = v.to_u32() {
131 path.push(FetchFastTarget::Number(idx));
132 Some(path)
133 } else {
134 None
135 }
136 }
137 _ => None,
138 }
139 }
140 _ => None,
141 }
142 }
143
144 fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
145 match node {
146 Node::Null => Ok(self.emit(Opcode::PushNull)),
147 Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
148 Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
149 Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
150 Node::Pointer => Ok(self.emit(Opcode::Pointer)),
151 Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
152 Node::Array(v) => {
153 v.iter()
154 .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
155 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
156 Ok(self.emit(Opcode::Array))
157 }
158 Node::Object(v) => {
159 v.iter().try_for_each(|&(key, value)| {
160 self.compile_node(key).map(|_| ())?;
161 self.emit(Opcode::CallFunction {
162 arg_count: 1,
163 kind: FunctionKind::Internal(InternalFunction::String),
164 });
165 self.compile_node(value).map(|_| ())?;
166 Ok(())
167 })?;
168
169 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
170 Ok(self.emit(Opcode::Object))
171 }
172 Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
173 Node::Closure(v) => self.compile_node(v),
174 Node::Parenthesized(v) => self.compile_node(v),
175 Node::Member {
176 node: n,
177 property: p,
178 } => {
179 if let Some(path) = self.compile_member_fast(node) {
180 Ok(self.emit(Opcode::FetchFast(path)))
181 } else {
182 self.compile_node(n)?;
183 self.compile_node(p)?;
184 Ok(self.emit(Opcode::Fetch))
185 }
186 }
187 Node::TemplateString(parts) => {
188 parts.iter().try_for_each(|&n| {
189 self.compile_node(n).map(|_| ())?;
190 self.emit(Opcode::CallFunction {
191 arg_count: 1,
192 kind: FunctionKind::Internal(InternalFunction::String),
193 });
194 Ok(())
195 })?;
196
197 self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
198 self.emit(Opcode::Array);
199 self.emit(Opcode::PushString(Arc::from("")));
200 Ok(self.emit(Opcode::Join))
201 }
202 Node::Slice { node, to, from } => {
203 self.compile_node(node)?;
204 if let Some(t) = to {
205 self.compile_node(t)?;
206 } else {
207 self.emit(Opcode::Len);
208 self.emit(Opcode::PushNumber(dec!(1)));
209 self.emit(Opcode::Subtract);
210 }
211
212 if let Some(f) = from {
213 self.compile_node(f)?;
214 } else {
215 self.emit(Opcode::PushNumber(dec!(0)));
216 }
217
218 Ok(self.emit(Opcode::Slice))
219 }
220 Node::Interval {
221 left,
222 right,
223 left_bracket,
224 right_bracket,
225 } => {
226 self.compile_node(left)?;
227 self.compile_node(right)?;
228 Ok(self.emit(Opcode::Interval {
229 left_bracket: *left_bracket,
230 right_bracket: *right_bracket,
231 }))
232 }
233 Node::Conditional {
234 condition,
235 on_true,
236 on_false,
237 } => {
238 self.compile_node(condition)?;
239 let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
240
241 self.emit(Opcode::Pop);
242 self.compile_node(on_true)?;
243 let end = self.emit(Opcode::Jump(Jump::Forward, 0));
244
245 self.replace(
246 otherwise,
247 Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
248 );
249 self.emit(Opcode::Pop);
250 let b = self.compile_node(on_false)?;
251 self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
252
253 Ok(b)
254 }
255 Node::Unary { node, operator } => {
256 let curr = self.compile_node(node)?;
257 match *operator {
258 Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
259 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
260 Ok(self.emit(Opcode::Negate))
261 }
262 Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
263 _ => Err(CompilerError::UnknownUnaryOperator {
264 operator: operator.to_string(),
265 }),
266 }
267 }
268 Node::Binary {
269 left,
270 right,
271 operator,
272 } => match *operator {
273 Operator::Comparison(ComparisonOperator::Equal) => {
274 self.compile_node(left)?;
275 self.compile_node(right)?;
276
277 Ok(self.emit(Opcode::Equal))
278 }
279 Operator::Comparison(ComparisonOperator::NotEqual) => {
280 self.compile_node(left)?;
281 self.compile_node(right)?;
282
283 self.emit(Opcode::Equal);
284 Ok(self.emit(Opcode::Not))
285 }
286 Operator::Logical(LogicalOperator::Or) => {
287 self.compile_node(left)?;
288 let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
289 self.emit(Opcode::Pop);
290 let r = self.compile_node(right)?;
291 self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
292
293 Ok(r)
294 }
295 Operator::Logical(LogicalOperator::And) => {
296 self.compile_node(left)?;
297 let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
298 self.emit(Opcode::Pop);
299 let r = self.compile_node(right)?;
300 self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
301
302 Ok(r)
303 }
304 Operator::Logical(LogicalOperator::NullishCoalescing) => {
305 self.compile_node(left)?;
306 let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
307 self.emit(Opcode::Pop);
308 let r = self.compile_node(right)?;
309 self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
310
311 Ok(r)
312 }
313 Operator::Comparison(ComparisonOperator::In) => {
314 self.compile_node(left)?;
315 self.compile_node(right)?;
316 Ok(self.emit(Opcode::In))
317 }
318 Operator::Comparison(ComparisonOperator::NotIn) => {
319 self.compile_node(left)?;
320 self.compile_node(right)?;
321 self.emit(Opcode::In);
322 Ok(self.emit(Opcode::Not))
323 }
324 Operator::Comparison(ComparisonOperator::LessThan) => {
325 self.compile_node(left)?;
326 self.compile_node(right)?;
327 Ok(self.emit(Opcode::Compare(Compare::Less)))
328 }
329 Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
330 self.compile_node(left)?;
331 self.compile_node(right)?;
332 Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
333 }
334 Operator::Comparison(ComparisonOperator::GreaterThan) => {
335 self.compile_node(left)?;
336 self.compile_node(right)?;
337 Ok(self.emit(Opcode::Compare(Compare::More)))
338 }
339 Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
340 self.compile_node(left)?;
341 self.compile_node(right)?;
342 Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
343 }
344 Operator::Arithmetic(ArithmeticOperator::Add) => {
345 self.compile_node(left)?;
346 self.compile_node(right)?;
347 Ok(self.emit(Opcode::Add))
348 }
349 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
350 self.compile_node(left)?;
351 self.compile_node(right)?;
352 Ok(self.emit(Opcode::Subtract))
353 }
354 Operator::Arithmetic(ArithmeticOperator::Multiply) => {
355 self.compile_node(left)?;
356 self.compile_node(right)?;
357 Ok(self.emit(Opcode::Multiply))
358 }
359 Operator::Arithmetic(ArithmeticOperator::Divide) => {
360 self.compile_node(left)?;
361 self.compile_node(right)?;
362 Ok(self.emit(Opcode::Divide))
363 }
364 Operator::Arithmetic(ArithmeticOperator::Modulus) => {
365 self.compile_node(left)?;
366 self.compile_node(right)?;
367 Ok(self.emit(Opcode::Modulo))
368 }
369 Operator::Arithmetic(ArithmeticOperator::Power) => {
370 self.compile_node(left)?;
371 self.compile_node(right)?;
372 Ok(self.emit(Opcode::Exponent))
373 }
374 _ => Err(CompilerError::UnknownBinaryOperator {
375 operator: operator.to_string(),
376 }),
377 },
378 Node::FunctionCall { kind, arguments } => match kind {
379 FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
380 let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
381 CompilerError::UnknownFunction {
382 name: kind.to_string(),
383 }
384 })?;
385
386 let min_params = function.required_parameters();
387 let max_params = min_params + function.optional_parameters();
388 if arguments.len() < min_params || arguments.len() > max_params {
389 return Err(CompilerError::InvalidFunctionCall {
390 name: kind.to_string(),
391 message: "Invalid number of arguments".to_string(),
392 });
393 }
394
395 for i in 0..arguments.len() {
396 self.compile_argument(kind, arguments, i)?;
397 }
398
399 Ok(self.emit(Opcode::CallFunction {
400 kind: kind.clone(),
401 arg_count: arguments.len() as u32,
402 }))
403 }
404 FunctionKind::Closure(c) => match c {
405 ClosureFunction::All => {
406 self.compile_argument(kind, arguments, 0)?;
407 self.emit(Opcode::Begin);
408 let mut loop_break: usize = 0;
409 self.emit_loop(|c| {
410 c.compile_argument(kind, arguments, 1)?;
411 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
412 c.emit(Opcode::Pop);
413 Ok(())
414 })?;
415 let e = self.emit(Opcode::PushBool(true));
416 self.replace(
417 loop_break,
418 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
419 );
420 Ok(self.emit(Opcode::End))
421 }
422 ClosureFunction::None => {
423 self.compile_argument(kind, arguments, 0)?;
424 self.emit(Opcode::Begin);
425 let mut loop_break: usize = 0;
426 self.emit_loop(|c| {
427 c.compile_argument(kind, arguments, 1)?;
428 c.emit(Opcode::Not);
429 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
430 c.emit(Opcode::Pop);
431 Ok(())
432 })?;
433 let e = self.emit(Opcode::PushBool(true));
434 self.replace(
435 loop_break,
436 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
437 );
438 Ok(self.emit(Opcode::End))
439 }
440 ClosureFunction::Some => {
441 self.compile_argument(kind, arguments, 0)?;
442 self.emit(Opcode::Begin);
443 let mut loop_break: usize = 0;
444 self.emit_loop(|c| {
445 c.compile_argument(kind, arguments, 1)?;
446 loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
447 c.emit(Opcode::Pop);
448 Ok(())
449 })?;
450 let e = self.emit(Opcode::PushBool(false));
451 self.replace(
452 loop_break,
453 Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
454 );
455 Ok(self.emit(Opcode::End))
456 }
457 ClosureFunction::One => {
458 self.compile_argument(kind, arguments, 0)?;
459 self.emit(Opcode::Begin);
460 self.emit_loop(|c| {
461 c.compile_argument(kind, arguments, 1)?;
462 c.emit_cond(|c| {
463 c.emit(Opcode::IncrementCount);
464 });
465 Ok(())
466 })?;
467 self.emit(Opcode::GetCount);
468 self.emit(Opcode::PushNumber(dec!(1)));
469 self.emit(Opcode::Equal);
470 Ok(self.emit(Opcode::End))
471 }
472 ClosureFunction::Filter => {
473 self.compile_argument(kind, arguments, 0)?;
474 self.emit(Opcode::Begin);
475 self.emit_loop(|c| {
476 c.compile_argument(kind, arguments, 1)?;
477 c.emit_cond(|c| {
478 c.emit(Opcode::IncrementCount);
479 c.emit(Opcode::Pointer);
480 });
481 Ok(())
482 })?;
483 self.emit(Opcode::GetCount);
484 self.emit(Opcode::End);
485 Ok(self.emit(Opcode::Array))
486 }
487 ClosureFunction::Map => {
488 self.compile_argument(kind, arguments, 0)?;
489 self.emit(Opcode::Begin);
490 self.emit_loop(|c| {
491 c.compile_argument(kind, arguments, 1)?;
492 Ok(())
493 })?;
494 self.emit(Opcode::GetLen);
495 self.emit(Opcode::End);
496 Ok(self.emit(Opcode::Array))
497 }
498 ClosureFunction::FlatMap => {
499 self.compile_argument(kind, arguments, 0)?;
500 self.emit(Opcode::Begin);
501 self.emit_loop(|c| {
502 c.compile_argument(kind, arguments, 1)?;
503 Ok(())
504 })?;
505 self.emit(Opcode::GetLen);
506 self.emit(Opcode::End);
507 self.emit(Opcode::Array);
508 Ok(self.emit(Opcode::Flatten))
509 }
510 ClosureFunction::Count => {
511 self.compile_argument(kind, arguments, 0)?;
512 self.emit(Opcode::Begin);
513 self.emit_loop(|c| {
514 c.compile_argument(kind, arguments, 1)?;
515 c.emit_cond(|c| {
516 c.emit(Opcode::IncrementCount);
517 });
518 Ok(())
519 })?;
520 self.emit(Opcode::GetCount);
521 Ok(self.emit(Opcode::End))
522 }
523 },
524 },
525 Node::MethodCall {
526 kind,
527 this,
528 arguments,
529 } => {
530 let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
531 CompilerError::UnknownFunction {
532 name: kind.to_string(),
533 }
534 })?;
535
536 self.compile_node(this)?;
537
538 let min_params = method.required_parameters() - 1;
539 let max_params = min_params + method.optional_parameters();
540 if arguments.len() < min_params || arguments.len() > max_params {
541 return Err(CompilerError::InvalidMethodCall {
542 name: kind.to_string(),
543 message: "Invalid number of arguments".to_string(),
544 });
545 }
546
547 for i in 0..arguments.len() {
548 self.compile_argument(kind, arguments, i)?;
549 }
550
551 Ok(self.emit(Opcode::CallMethod {
552 kind: kind.clone(),
553 arg_count: arguments.len() as u32,
554 }))
555 }
556 Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
557 }
558 }
559}