mf_expression/compiler/
compiler.rs

1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{
6    ClosureFunction, FunctionKind, InternalFunction, MethodRegistry,
7};
8use crate::lexer::{
9    ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator,
10};
11use crate::parser::Node;
12use rust_decimal::prelude::ToPrimitive;
13use rust_decimal::Decimal;
14use rust_decimal_macros::dec;
15use std::sync::Arc;
16
17#[derive(Debug)]
18pub struct Compiler {
19    bytecode: Vec<Opcode>,
20}
21
22impl Compiler {
23    pub fn new() -> Self {
24        Self { bytecode: Default::default() }
25    }
26
27    pub fn compile(
28        &mut self,
29        root: &Node,
30    ) -> CompilerResult<&[Opcode]> {
31        self.bytecode.clear();
32
33        CompilerInner::new(&mut self.bytecode, root).compile()?;
34        Ok(self.bytecode.as_slice())
35    }
36
37    pub fn get_bytecode(&self) -> &[Opcode] {
38        self.bytecode.as_slice()
39    }
40}
41
42#[derive(Debug)]
43struct CompilerInner<'arena, 'bytecode_ref> {
44    root: &'arena Node<'arena>,
45    bytecode: &'bytecode_ref mut Vec<Opcode>,
46}
47
48impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
49    pub fn new(
50        bytecode: &'bytecode_ref mut Vec<Opcode>,
51        root: &'arena Node<'arena>,
52    ) -> Self {
53        Self { root, bytecode }
54    }
55
56    pub fn compile(&mut self) -> CompilerResult<()> {
57        self.compile_node(self.root)?;
58        Ok(())
59    }
60
61    fn emit(
62        &mut self,
63        op: Opcode,
64    ) -> usize {
65        self.bytecode.push(op);
66        self.bytecode.len()
67    }
68
69    fn emit_loop<F>(
70        &mut self,
71        body: F,
72    ) -> CompilerResult<()>
73    where
74        F: FnOnce(&mut Self) -> CompilerResult<()>,
75    {
76        let begin = self.bytecode.len();
77        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
78
79        body(self)?;
80
81        self.emit(Opcode::IncrementIt);
82        let e = self.emit(Opcode::Jump(
83            Jump::Backward,
84            self.calc_backward_jump(begin) as u32,
85        ));
86        self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
87        Ok(())
88    }
89
90    fn emit_cond<F>(
91        &mut self,
92        mut body: F,
93    ) where
94        F: FnMut(&mut Self),
95    {
96        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
97        self.emit(Opcode::Pop);
98
99        body(self);
100
101        let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
102        self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
103        let e = self.emit(Opcode::Pop);
104        self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
105    }
106
107    fn replace(
108        &mut self,
109        at: usize,
110        op: Opcode,
111    ) {
112        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
113    }
114
115    fn calc_backward_jump(
116        &self,
117        to: usize,
118    ) -> usize {
119        self.bytecode.len() + 1 - to
120    }
121
122    fn compile_argument<T: ToString>(
123        &mut self,
124        function_kind: T,
125        arguments: &[&'arena Node<'arena>],
126        index: usize,
127    ) -> CompilerResult<usize> {
128        let arg = arguments.get(index).ok_or_else(|| {
129            CompilerError::ArgumentNotFound {
130                index,
131                function: function_kind.to_string(),
132            }
133        })?;
134
135        self.compile_node(arg)
136    }
137
138    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
139    fn compile_member_fast(
140        &mut self,
141        node: &'arena Node<'arena>,
142    ) -> Option<Vec<FetchFastTarget>> {
143        match node {
144            Node::Root => Some(vec![FetchFastTarget::Root]),
145            Node::Identifier(v) => Some(vec![
146                FetchFastTarget::Begin,
147                FetchFastTarget::String(Arc::from(*v)),
148            ]),
149            Node::Member { node, property } => {
150                let mut path = self.compile_member_fast(node)?;
151                match property {
152                    Node::String(v) => {
153                        path.push(FetchFastTarget::String(Arc::from(*v)));
154                        Some(path)
155                    },
156                    Node::Number(v) => {
157                        if let Some(idx) = v.to_u32() {
158                            path.push(FetchFastTarget::Number(idx));
159                            Some(path)
160                        } else {
161                            None
162                        }
163                    },
164                    _ => None,
165                }
166            },
167            _ => None,
168        }
169    }
170
171    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
172    fn compile_node(
173        &mut self,
174        node: &'arena Node<'arena>,
175    ) -> CompilerResult<usize> {
176        match node {
177            Node::Null => Ok(self.emit(Opcode::PushNull)),
178            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
179            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
180            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
181            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
182            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
183            Node::Array(v) => {
184                v.iter().try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
185                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
186                Ok(self.emit(Opcode::Array))
187            },
188            Node::Object(v) => {
189                v.iter().try_for_each(|&(key, value)| {
190                    self.compile_node(key).map(|_| ())?;
191                    self.emit(Opcode::CallFunction {
192                        arg_count: 1,
193                        kind: FunctionKind::Internal(InternalFunction::String),
194                    });
195                    self.compile_node(value).map(|_| ())?;
196                    Ok(())
197                })?;
198
199                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
200                Ok(self.emit(Opcode::Object))
201            },
202            Node::Assignments { list, output } => {
203                self.emit(Opcode::AssignedObjectBegin);
204                list.iter().try_for_each(|&(key, value)| {
205                    self.compile_node(key).map(|_| ())?;
206                    self.compile_node(value).map(|_| ())?;
207                    self.emit(Opcode::AssignedObjectStep);
208
209                    Ok(())
210                })?;
211
212                if let Some(output) = output {
213                    self.compile_node(output).map(|_| ())?;
214                }
215
216                Ok(self.emit(Opcode::AssignedObjectEnd {
217                    with_return: output.is_some(),
218                }))
219            },
220            Node::Identifier(v) => {
221                Ok(self.emit(Opcode::FetchEnv(Arc::from(*v))))
222            },
223            Node::Closure(v) => self.compile_node(v),
224            Node::Parenthesized(v) => self.compile_node(v),
225            Node::Member { node: n, property: p } => {
226                if let Some(path) = self.compile_member_fast(node) {
227                    Ok(self.emit(Opcode::FetchFast(path)))
228                } else {
229                    self.compile_node(n)?;
230                    self.compile_node(p)?;
231                    Ok(self.emit(Opcode::Fetch))
232                }
233            },
234            Node::TemplateString(parts) => {
235                parts.iter().try_for_each(|&n| {
236                    self.compile_node(n).map(|_| ())?;
237                    self.emit(Opcode::CallFunction {
238                        arg_count: 1,
239                        kind: FunctionKind::Internal(InternalFunction::String),
240                    });
241                    Ok(())
242                })?;
243
244                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
245                self.emit(Opcode::Array);
246                self.emit(Opcode::PushString(Arc::from("")));
247                Ok(self.emit(Opcode::Join))
248            },
249            Node::Slice { node, to, from } => {
250                self.compile_node(node)?;
251                if let Some(t) = to {
252                    self.compile_node(t)?;
253                } else {
254                    self.emit(Opcode::Len);
255                    self.emit(Opcode::PushNumber(dec!(1)));
256                    self.emit(Opcode::Subtract);
257                }
258
259                if let Some(f) = from {
260                    self.compile_node(f)?;
261                } else {
262                    self.emit(Opcode::PushNumber(dec!(0)));
263                }
264
265                Ok(self.emit(Opcode::Slice))
266            },
267            Node::Interval { left, right, left_bracket, right_bracket } => {
268                self.compile_node(left)?;
269                self.compile_node(right)?;
270                Ok(self.emit(Opcode::Interval {
271                    left_bracket: *left_bracket,
272                    right_bracket: *right_bracket,
273                }))
274            },
275            Node::Conditional { condition, on_true, on_false } => {
276                self.compile_node(condition)?;
277                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
278
279                self.emit(Opcode::Pop);
280                self.compile_node(on_true)?;
281                let end = self.emit(Opcode::Jump(Jump::Forward, 0));
282
283                self.replace(
284                    otherwise,
285                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
286                );
287                self.emit(Opcode::Pop);
288                let b = self.compile_node(on_false)?;
289                self.replace(
290                    end,
291                    Opcode::Jump(Jump::Forward, (b - end) as u32),
292                );
293
294                Ok(b)
295            },
296            Node::Unary { node, operator } => {
297                let curr = self.compile_node(node)?;
298                match *operator {
299                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
300                    Operator::Arithmetic(ArithmeticOperator::Subtract) => {
301                        Ok(self.emit(Opcode::Negate))
302                    },
303                    Operator::Logical(LogicalOperator::Not) => {
304                        Ok(self.emit(Opcode::Not))
305                    },
306                    _ => Err(CompilerError::UnknownUnaryOperator {
307                        operator: operator.to_string(),
308                    }),
309                }
310            },
311            Node::Binary { left, right, operator } => match *operator {
312                Operator::Comparison(ComparisonOperator::Equal) => {
313                    self.compile_node(left)?;
314                    self.compile_node(right)?;
315
316                    Ok(self.emit(Opcode::Equal))
317                },
318                Operator::Comparison(ComparisonOperator::NotEqual) => {
319                    self.compile_node(left)?;
320                    self.compile_node(right)?;
321
322                    self.emit(Opcode::Equal);
323                    Ok(self.emit(Opcode::Not))
324                },
325                Operator::Logical(LogicalOperator::Or) => {
326                    self.compile_node(left)?;
327                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
328                    self.emit(Opcode::Pop);
329                    let r = self.compile_node(right)?;
330                    self.replace(
331                        end,
332                        Opcode::Jump(Jump::IfTrue, (r - end) as u32),
333                    );
334
335                    Ok(r)
336                },
337                Operator::Logical(LogicalOperator::And) => {
338                    self.compile_node(left)?;
339                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
340                    self.emit(Opcode::Pop);
341                    let r = self.compile_node(right)?;
342                    self.replace(
343                        end,
344                        Opcode::Jump(Jump::IfFalse, (r - end) as u32),
345                    );
346
347                    Ok(r)
348                },
349                Operator::Logical(LogicalOperator::NullishCoalescing) => {
350                    self.compile_node(left)?;
351                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
352                    self.emit(Opcode::Pop);
353                    let r = self.compile_node(right)?;
354                    self.replace(
355                        end,
356                        Opcode::Jump(Jump::IfNotNull, (r - end) as u32),
357                    );
358
359                    Ok(r)
360                },
361                Operator::Comparison(ComparisonOperator::In) => {
362                    self.compile_node(left)?;
363                    self.compile_node(right)?;
364                    Ok(self.emit(Opcode::In))
365                },
366                Operator::Comparison(ComparisonOperator::NotIn) => {
367                    self.compile_node(left)?;
368                    self.compile_node(right)?;
369                    self.emit(Opcode::In);
370                    Ok(self.emit(Opcode::Not))
371                },
372                Operator::Comparison(ComparisonOperator::LessThan) => {
373                    self.compile_node(left)?;
374                    self.compile_node(right)?;
375                    Ok(self.emit(Opcode::Compare(Compare::Less)))
376                },
377                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
378                    self.compile_node(left)?;
379                    self.compile_node(right)?;
380                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
381                },
382                Operator::Comparison(ComparisonOperator::GreaterThan) => {
383                    self.compile_node(left)?;
384                    self.compile_node(right)?;
385                    Ok(self.emit(Opcode::Compare(Compare::More)))
386                },
387                Operator::Comparison(
388                    ComparisonOperator::GreaterThanOrEqual,
389                ) => {
390                    self.compile_node(left)?;
391                    self.compile_node(right)?;
392                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
393                },
394                Operator::Arithmetic(ArithmeticOperator::Add) => {
395                    self.compile_node(left)?;
396                    self.compile_node(right)?;
397                    Ok(self.emit(Opcode::Add))
398                },
399                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
400                    self.compile_node(left)?;
401                    self.compile_node(right)?;
402                    Ok(self.emit(Opcode::Subtract))
403                },
404                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
405                    self.compile_node(left)?;
406                    self.compile_node(right)?;
407                    Ok(self.emit(Opcode::Multiply))
408                },
409                Operator::Arithmetic(ArithmeticOperator::Divide) => {
410                    self.compile_node(left)?;
411                    self.compile_node(right)?;
412                    Ok(self.emit(Opcode::Divide))
413                },
414                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
415                    self.compile_node(left)?;
416                    self.compile_node(right)?;
417                    Ok(self.emit(Opcode::Modulo))
418                },
419                Operator::Arithmetic(ArithmeticOperator::Power) => {
420                    self.compile_node(left)?;
421                    self.compile_node(right)?;
422                    Ok(self.emit(Opcode::Exponent))
423                },
424                _ => Err(CompilerError::UnknownBinaryOperator {
425                    operator: operator.to_string(),
426                }),
427            },
428            Node::FunctionCall { kind, arguments } => match kind {
429                FunctionKind::Internal(_)
430                | FunctionKind::Deprecated(_)
431                | FunctionKind::Mf(_) => {
432                    let function = FunctionRegistry::get_definition(kind)
433                        .ok_or_else(|| CompilerError::UnknownFunction {
434                            name: kind.to_string(),
435                        })?;
436
437                    let min_params = function.required_parameters();
438                    let max_params =
439                        min_params + function.optional_parameters();
440                    if arguments.len() < min_params
441                        || arguments.len() > max_params
442                    {
443                        return Err(CompilerError::InvalidFunctionCall {
444                            name: kind.to_string(),
445                            message: "Invalid number of arguments".to_string(),
446                        });
447                    }
448
449                    for i in 0..arguments.len() {
450                        self.compile_argument(kind, arguments, i)?;
451                    }
452
453                    Ok(self.emit(Opcode::CallFunction {
454                        kind: kind.clone(),
455                        arg_count: arguments.len() as u32,
456                    }))
457                },
458                FunctionKind::Closure(c) => match c {
459                    ClosureFunction::All => {
460                        self.compile_argument(kind, arguments, 0)?;
461                        self.emit(Opcode::Begin);
462                        let mut loop_break: usize = 0;
463                        self.emit_loop(|c| {
464                            c.compile_argument(kind, arguments, 1)?;
465                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
466                            c.emit(Opcode::Pop);
467                            Ok(())
468                        })?;
469                        let e = self.emit(Opcode::PushBool(true));
470                        self.replace(
471                            loop_break,
472                            Opcode::Jump(
473                                Jump::IfFalse,
474                                (e - loop_break) as u32,
475                            ),
476                        );
477                        Ok(self.emit(Opcode::End))
478                    },
479                    ClosureFunction::None => {
480                        self.compile_argument(kind, arguments, 0)?;
481                        self.emit(Opcode::Begin);
482                        let mut loop_break: usize = 0;
483                        self.emit_loop(|c| {
484                            c.compile_argument(kind, arguments, 1)?;
485                            c.emit(Opcode::Not);
486                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
487                            c.emit(Opcode::Pop);
488                            Ok(())
489                        })?;
490                        let e = self.emit(Opcode::PushBool(true));
491                        self.replace(
492                            loop_break,
493                            Opcode::Jump(
494                                Jump::IfFalse,
495                                (e - loop_break) as u32,
496                            ),
497                        );
498                        Ok(self.emit(Opcode::End))
499                    },
500                    ClosureFunction::Some => {
501                        self.compile_argument(kind, arguments, 0)?;
502                        self.emit(Opcode::Begin);
503                        let mut loop_break: usize = 0;
504                        self.emit_loop(|c| {
505                            c.compile_argument(kind, arguments, 1)?;
506                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
507                            c.emit(Opcode::Pop);
508                            Ok(())
509                        })?;
510                        let e = self.emit(Opcode::PushBool(false));
511                        self.replace(
512                            loop_break,
513                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
514                        );
515                        Ok(self.emit(Opcode::End))
516                    },
517                    ClosureFunction::One => {
518                        self.compile_argument(kind, arguments, 0)?;
519                        self.emit(Opcode::Begin);
520                        self.emit_loop(|c| {
521                            c.compile_argument(kind, arguments, 1)?;
522                            c.emit_cond(|c| {
523                                c.emit(Opcode::IncrementCount);
524                            });
525                            Ok(())
526                        })?;
527                        self.emit(Opcode::GetCount);
528                        self.emit(Opcode::PushNumber(dec!(1)));
529                        self.emit(Opcode::Equal);
530                        Ok(self.emit(Opcode::End))
531                    },
532                    ClosureFunction::Filter => {
533                        self.compile_argument(kind, arguments, 0)?;
534                        self.emit(Opcode::Begin);
535                        self.emit_loop(|c| {
536                            c.compile_argument(kind, arguments, 1)?;
537                            c.emit_cond(|c| {
538                                c.emit(Opcode::IncrementCount);
539                                c.emit(Opcode::Pointer);
540                            });
541                            Ok(())
542                        })?;
543                        self.emit(Opcode::GetCount);
544                        self.emit(Opcode::End);
545                        Ok(self.emit(Opcode::Array))
546                    },
547                    ClosureFunction::Map => {
548                        self.compile_argument(kind, arguments, 0)?;
549                        self.emit(Opcode::Begin);
550                        self.emit_loop(|c| {
551                            c.compile_argument(kind, arguments, 1)?;
552                            Ok(())
553                        })?;
554                        self.emit(Opcode::GetLen);
555                        self.emit(Opcode::End);
556                        Ok(self.emit(Opcode::Array))
557                    },
558                    ClosureFunction::FlatMap => {
559                        self.compile_argument(kind, arguments, 0)?;
560                        self.emit(Opcode::Begin);
561                        self.emit_loop(|c| {
562                            c.compile_argument(kind, arguments, 1)?;
563                            Ok(())
564                        })?;
565                        self.emit(Opcode::GetLen);
566                        self.emit(Opcode::End);
567                        self.emit(Opcode::Array);
568                        Ok(self.emit(Opcode::Flatten))
569                    },
570                    ClosureFunction::Count => {
571                        self.compile_argument(kind, arguments, 0)?;
572                        self.emit(Opcode::Begin);
573                        self.emit_loop(|c| {
574                            c.compile_argument(kind, arguments, 1)?;
575                            c.emit_cond(|c| {
576                                c.emit(Opcode::IncrementCount);
577                            });
578                            Ok(())
579                        })?;
580                        self.emit(Opcode::GetCount);
581                        Ok(self.emit(Opcode::End))
582                    },
583                },
584            },
585            Node::MethodCall { kind, this, arguments } => {
586                let method =
587                    MethodRegistry::get_definition(kind).ok_or_else(|| {
588                        CompilerError::UnknownFunction {
589                            name: kind.to_string(),
590                        }
591                    })?;
592
593                self.compile_node(this)?;
594
595                let min_params = method.required_parameters() - 1;
596                let max_params = min_params + method.optional_parameters();
597                if arguments.len() < min_params || arguments.len() > max_params
598                {
599                    return Err(CompilerError::InvalidMethodCall {
600                        name: kind.to_string(),
601                        message: "Invalid number of arguments".to_string(),
602                    });
603                }
604
605                for i in 0..arguments.len() {
606                    self.compile_argument(kind, arguments, i)?;
607                }
608
609                Ok(self.emit(Opcode::CallMethod {
610                    kind: kind.clone(),
611                    arg_count: arguments.len() as u32,
612                }))
613            },
614            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
615        }
616    }
617}