// zen_expression/compiler/compiler.rs

use crate::compiler::error::{CompilerError, CompilerResult};
use crate::compiler::opcode::{FetchFastTarget, Jump};
use crate::compiler::{Compare, Opcode};
use crate::functions::registry::FunctionRegistry;
use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
use crate::parser::Node;
use rust_decimal::prelude::ToPrimitive;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use std::sync::Arc;

13#[derive(Debug)]
14pub struct Compiler {
15    bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19    pub fn new() -> Self {
20        Self {
21            bytecode: Default::default(),
22        }
23    }
24
25    pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26        self.bytecode.clear();
27
28        CompilerInner::new(&mut self.bytecode, root).compile()?;
29        Ok(self.bytecode.as_slice())
30    }
31
32    pub fn get_bytecode(&self) -> &[Opcode] {
33        self.bytecode.as_slice()
34    }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39    root: &'arena Node<'arena>,
40    bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44    pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45        Self { root, bytecode }
46    }
47
48    pub fn compile(&mut self) -> CompilerResult<()> {
49        self.compile_node(self.root)?;
50        Ok(())
51    }
52
53    fn emit(&mut self, op: Opcode) -> usize {
54        self.bytecode.push(op);
55        self.bytecode.len()
56    }
57
58    fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59    where
60        F: FnOnce(&mut Self) -> CompilerResult<()>,
61    {
62        let begin = self.bytecode.len();
63        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65        body(self)?;
66
67        self.emit(Opcode::IncrementIt);
68        let e = self.emit(Opcode::Jump(
69            Jump::Backward,
70            self.calc_backward_jump(begin) as u32,
71        ));
72        self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73        Ok(())
74    }
75
76    fn emit_cond<F>(&mut self, mut body: F)
77    where
78        F: FnMut(&mut Self),
79    {
80        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81        self.emit(Opcode::Pop);
82
83        body(self);
84
85        let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86        self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87        let e = self.emit(Opcode::Pop);
88        self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89    }
90
91    fn replace(&mut self, at: usize, op: Opcode) {
92        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93    }
94
95    fn calc_backward_jump(&self, to: usize) -> usize {
96        self.bytecode.len() + 1 - to
97    }
98
99    fn compile_argument<T: ToString>(
100        &mut self,
101        function_kind: T,
102        arguments: &[&'arena Node<'arena>],
103        index: usize,
104    ) -> CompilerResult<usize> {
105        let arg = arguments
106            .get(index)
107            .ok_or_else(|| CompilerError::ArgumentNotFound {
108                index,
109                function: function_kind.to_string(),
110            })?;
111
112        self.compile_node(arg)
113    }
114
115    fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
116        match node {
117            Node::Root => Some(vec![FetchFastTarget::Root]),
118            Node::Identifier(v) => Some(vec![
119                FetchFastTarget::Root,
120                FetchFastTarget::String(Arc::from(*v)),
121            ]),
122            Node::Member { node, property } => {
123                let mut path = self.compile_member_fast(node)?;
124                match property {
125                    Node::String(v) => {
126                        path.push(FetchFastTarget::String(Arc::from(*v)));
127                        Some(path)
128                    }
129                    Node::Number(v) => {
130                        if let Some(idx) = v.to_u32() {
131                            path.push(FetchFastTarget::Number(idx));
132                            Some(path)
133                        } else {
134                            None
135                        }
136                    }
137                    _ => None,
138                }
139            }
140            _ => None,
141        }
142    }
143
144    fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
145        match node {
146            Node::Null => Ok(self.emit(Opcode::PushNull)),
147            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
148            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
149            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
150            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
151            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
152            Node::Array(v) => {
153                v.iter()
154                    .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
155                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
156                Ok(self.emit(Opcode::Array))
157            }
158            Node::Object(v) => {
159                v.iter().try_for_each(|&(key, value)| {
160                    self.compile_node(key).map(|_| ())?;
161                    self.emit(Opcode::CallFunction {
162                        arg_count: 1,
163                        kind: FunctionKind::Internal(InternalFunction::String),
164                    });
165                    self.compile_node(value).map(|_| ())?;
166                    Ok(())
167                })?;
168
169                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
170                Ok(self.emit(Opcode::Object))
171            }
172            Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
173            Node::Closure(v) => self.compile_node(v),
174            Node::Parenthesized(v) => self.compile_node(v),
175            Node::Member {
176                node: n,
177                property: p,
178            } => {
179                if let Some(path) = self.compile_member_fast(node) {
180                    Ok(self.emit(Opcode::FetchFast(path)))
181                } else {
182                    self.compile_node(n)?;
183                    self.compile_node(p)?;
184                    Ok(self.emit(Opcode::Fetch))
185                }
186            }
187            Node::TemplateString(parts) => {
188                parts.iter().try_for_each(|&n| {
189                    self.compile_node(n).map(|_| ())?;
190                    self.emit(Opcode::CallFunction {
191                        arg_count: 1,
192                        kind: FunctionKind::Internal(InternalFunction::String),
193                    });
194                    Ok(())
195                })?;
196
197                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
198                self.emit(Opcode::Array);
199                self.emit(Opcode::PushString(Arc::from("")));
200                Ok(self.emit(Opcode::Join))
201            }
202            Node::Slice { node, to, from } => {
203                self.compile_node(node)?;
204                if let Some(t) = to {
205                    self.compile_node(t)?;
206                } else {
207                    self.emit(Opcode::Len);
208                    self.emit(Opcode::PushNumber(dec!(1)));
209                    self.emit(Opcode::Subtract);
210                }
211
212                if let Some(f) = from {
213                    self.compile_node(f)?;
214                } else {
215                    self.emit(Opcode::PushNumber(dec!(0)));
216                }
217
218                Ok(self.emit(Opcode::Slice))
219            }
220            Node::Interval {
221                left,
222                right,
223                left_bracket,
224                right_bracket,
225            } => {
226                self.compile_node(left)?;
227                self.compile_node(right)?;
228                Ok(self.emit(Opcode::Interval {
229                    left_bracket: *left_bracket,
230                    right_bracket: *right_bracket,
231                }))
232            }
233            Node::Conditional {
234                condition,
235                on_true,
236                on_false,
237            } => {
238                self.compile_node(condition)?;
239                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
240
241                self.emit(Opcode::Pop);
242                self.compile_node(on_true)?;
243                let end = self.emit(Opcode::Jump(Jump::Forward, 0));
244
245                self.replace(
246                    otherwise,
247                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
248                );
249                self.emit(Opcode::Pop);
250                let b = self.compile_node(on_false)?;
251                self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
252
253                Ok(b)
254            }
255            Node::Unary { node, operator } => {
256                let curr = self.compile_node(node)?;
257                match *operator {
258                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
259                    Operator::Arithmetic(ArithmeticOperator::Subtract) => {
260                        Ok(self.emit(Opcode::Negate))
261                    }
262                    Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
263                    _ => Err(CompilerError::UnknownUnaryOperator {
264                        operator: operator.to_string(),
265                    }),
266                }
267            }
268            Node::Binary {
269                left,
270                right,
271                operator,
272            } => match *operator {
273                Operator::Comparison(ComparisonOperator::Equal) => {
274                    self.compile_node(left)?;
275                    self.compile_node(right)?;
276
277                    Ok(self.emit(Opcode::Equal))
278                }
279                Operator::Comparison(ComparisonOperator::NotEqual) => {
280                    self.compile_node(left)?;
281                    self.compile_node(right)?;
282
283                    self.emit(Opcode::Equal);
284                    Ok(self.emit(Opcode::Not))
285                }
286                Operator::Logical(LogicalOperator::Or) => {
287                    self.compile_node(left)?;
288                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
289                    self.emit(Opcode::Pop);
290                    let r = self.compile_node(right)?;
291                    self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
292
293                    Ok(r)
294                }
295                Operator::Logical(LogicalOperator::And) => {
296                    self.compile_node(left)?;
297                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
298                    self.emit(Opcode::Pop);
299                    let r = self.compile_node(right)?;
300                    self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
301
302                    Ok(r)
303                }
304                Operator::Logical(LogicalOperator::NullishCoalescing) => {
305                    self.compile_node(left)?;
306                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
307                    self.emit(Opcode::Pop);
308                    let r = self.compile_node(right)?;
309                    self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
310
311                    Ok(r)
312                }
313                Operator::Comparison(ComparisonOperator::In) => {
314                    self.compile_node(left)?;
315                    self.compile_node(right)?;
316                    Ok(self.emit(Opcode::In))
317                }
318                Operator::Comparison(ComparisonOperator::NotIn) => {
319                    self.compile_node(left)?;
320                    self.compile_node(right)?;
321                    self.emit(Opcode::In);
322                    Ok(self.emit(Opcode::Not))
323                }
324                Operator::Comparison(ComparisonOperator::LessThan) => {
325                    self.compile_node(left)?;
326                    self.compile_node(right)?;
327                    Ok(self.emit(Opcode::Compare(Compare::Less)))
328                }
329                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
330                    self.compile_node(left)?;
331                    self.compile_node(right)?;
332                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
333                }
334                Operator::Comparison(ComparisonOperator::GreaterThan) => {
335                    self.compile_node(left)?;
336                    self.compile_node(right)?;
337                    Ok(self.emit(Opcode::Compare(Compare::More)))
338                }
339                Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
340                    self.compile_node(left)?;
341                    self.compile_node(right)?;
342                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
343                }
344                Operator::Arithmetic(ArithmeticOperator::Add) => {
345                    self.compile_node(left)?;
346                    self.compile_node(right)?;
347                    Ok(self.emit(Opcode::Add))
348                }
349                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
350                    self.compile_node(left)?;
351                    self.compile_node(right)?;
352                    Ok(self.emit(Opcode::Subtract))
353                }
354                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
355                    self.compile_node(left)?;
356                    self.compile_node(right)?;
357                    Ok(self.emit(Opcode::Multiply))
358                }
359                Operator::Arithmetic(ArithmeticOperator::Divide) => {
360                    self.compile_node(left)?;
361                    self.compile_node(right)?;
362                    Ok(self.emit(Opcode::Divide))
363                }
364                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
365                    self.compile_node(left)?;
366                    self.compile_node(right)?;
367                    Ok(self.emit(Opcode::Modulo))
368                }
369                Operator::Arithmetic(ArithmeticOperator::Power) => {
370                    self.compile_node(left)?;
371                    self.compile_node(right)?;
372                    Ok(self.emit(Opcode::Exponent))
373                }
374                _ => Err(CompilerError::UnknownBinaryOperator {
375                    operator: operator.to_string(),
376                }),
377            },
378            Node::FunctionCall { kind, arguments } => match kind {
379                FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
380                    let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
381                        CompilerError::UnknownFunction {
382                            name: kind.to_string(),
383                        }
384                    })?;
385
386                    let min_params = function.required_parameters();
387                    let max_params = min_params + function.optional_parameters();
388                    if arguments.len() < min_params || arguments.len() > max_params {
389                        return Err(CompilerError::InvalidFunctionCall {
390                            name: kind.to_string(),
391                            message: "Invalid number of arguments".to_string(),
392                        });
393                    }
394
395                    for i in 0..arguments.len() {
396                        self.compile_argument(kind, arguments, i)?;
397                    }
398
399                    Ok(self.emit(Opcode::CallFunction {
400                        kind: kind.clone(),
401                        arg_count: arguments.len() as u32,
402                    }))
403                }
404                FunctionKind::Closure(c) => match c {
405                    ClosureFunction::All => {
406                        self.compile_argument(kind, arguments, 0)?;
407                        self.emit(Opcode::Begin);
408                        let mut loop_break: usize = 0;
409                        self.emit_loop(|c| {
410                            c.compile_argument(kind, arguments, 1)?;
411                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
412                            c.emit(Opcode::Pop);
413                            Ok(())
414                        })?;
415                        let e = self.emit(Opcode::PushBool(true));
416                        self.replace(
417                            loop_break,
418                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
419                        );
420                        Ok(self.emit(Opcode::End))
421                    }
422                    ClosureFunction::None => {
423                        self.compile_argument(kind, arguments, 0)?;
424                        self.emit(Opcode::Begin);
425                        let mut loop_break: usize = 0;
426                        self.emit_loop(|c| {
427                            c.compile_argument(kind, arguments, 1)?;
428                            c.emit(Opcode::Not);
429                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
430                            c.emit(Opcode::Pop);
431                            Ok(())
432                        })?;
433                        let e = self.emit(Opcode::PushBool(true));
434                        self.replace(
435                            loop_break,
436                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
437                        );
438                        Ok(self.emit(Opcode::End))
439                    }
440                    ClosureFunction::Some => {
441                        self.compile_argument(kind, arguments, 0)?;
442                        self.emit(Opcode::Begin);
443                        let mut loop_break: usize = 0;
444                        self.emit_loop(|c| {
445                            c.compile_argument(kind, arguments, 1)?;
446                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
447                            c.emit(Opcode::Pop);
448                            Ok(())
449                        })?;
450                        let e = self.emit(Opcode::PushBool(false));
451                        self.replace(
452                            loop_break,
453                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
454                        );
455                        Ok(self.emit(Opcode::End))
456                    }
457                    ClosureFunction::One => {
458                        self.compile_argument(kind, arguments, 0)?;
459                        self.emit(Opcode::Begin);
460                        self.emit_loop(|c| {
461                            c.compile_argument(kind, arguments, 1)?;
462                            c.emit_cond(|c| {
463                                c.emit(Opcode::IncrementCount);
464                            });
465                            Ok(())
466                        })?;
467                        self.emit(Opcode::GetCount);
468                        self.emit(Opcode::PushNumber(dec!(1)));
469                        self.emit(Opcode::Equal);
470                        Ok(self.emit(Opcode::End))
471                    }
472                    ClosureFunction::Filter => {
473                        self.compile_argument(kind, arguments, 0)?;
474                        self.emit(Opcode::Begin);
475                        self.emit_loop(|c| {
476                            c.compile_argument(kind, arguments, 1)?;
477                            c.emit_cond(|c| {
478                                c.emit(Opcode::IncrementCount);
479                                c.emit(Opcode::Pointer);
480                            });
481                            Ok(())
482                        })?;
483                        self.emit(Opcode::GetCount);
484                        self.emit(Opcode::End);
485                        Ok(self.emit(Opcode::Array))
486                    }
487                    ClosureFunction::Map => {
488                        self.compile_argument(kind, arguments, 0)?;
489                        self.emit(Opcode::Begin);
490                        self.emit_loop(|c| {
491                            c.compile_argument(kind, arguments, 1)?;
492                            Ok(())
493                        })?;
494                        self.emit(Opcode::GetLen);
495                        self.emit(Opcode::End);
496                        Ok(self.emit(Opcode::Array))
497                    }
498                    ClosureFunction::FlatMap => {
499                        self.compile_argument(kind, arguments, 0)?;
500                        self.emit(Opcode::Begin);
501                        self.emit_loop(|c| {
502                            c.compile_argument(kind, arguments, 1)?;
503                            Ok(())
504                        })?;
505                        self.emit(Opcode::GetLen);
506                        self.emit(Opcode::End);
507                        self.emit(Opcode::Array);
508                        Ok(self.emit(Opcode::Flatten))
509                    }
510                    ClosureFunction::Count => {
511                        self.compile_argument(kind, arguments, 0)?;
512                        self.emit(Opcode::Begin);
513                        self.emit_loop(|c| {
514                            c.compile_argument(kind, arguments, 1)?;
515                            c.emit_cond(|c| {
516                                c.emit(Opcode::IncrementCount);
517                            });
518                            Ok(())
519                        })?;
520                        self.emit(Opcode::GetCount);
521                        Ok(self.emit(Opcode::End))
522                    }
523                },
524            },
525            Node::MethodCall {
526                kind,
527                this,
528                arguments,
529            } => {
530                let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
531                    CompilerError::UnknownFunction {
532                        name: kind.to_string(),
533                    }
534                })?;
535
536                self.compile_node(this)?;
537
538                let min_params = method.required_parameters() - 1;
539                let max_params = min_params + method.optional_parameters();
540                if arguments.len() < min_params || arguments.len() > max_params {
541                    return Err(CompilerError::InvalidMethodCall {
542                        name: kind.to_string(),
543                        message: "Invalid number of arguments".to_string(),
544                    });
545                }
546
547                for i in 0..arguments.len() {
548                    self.compile_argument(kind, arguments, i)?;
549                }
550
551                Ok(self.emit(Opcode::CallMethod {
552                    kind: kind.clone(),
553                    arg_count: arguments.len() as u32,
554                }))
555            }
556            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
557        }
558    }
559}