1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{
6    ClosureFunction, FunctionKind, InternalFunction, MethodRegistry,
7};
8use crate::lexer::{
9    ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator,
10};
11use crate::parser::Node;
12use rust_decimal::prelude::ToPrimitive;
13use rust_decimal::Decimal;
14use rust_decimal_macros::dec;
15use std::sync::Arc;
16
17#[derive(Debug)]
20pub struct Compiler {
21    bytecode: Vec<Opcode>, }
23
24impl Compiler {
25    pub fn new() -> Self {
27        Self { bytecode: Default::default() }
28    }
29
30    pub fn compile(
38        &mut self,
39        root: &Node,
40    ) -> CompilerResult<&[Opcode]> {
41        self.bytecode.clear();
42
43        CompilerInner::new(&mut self.bytecode, root).compile()?;
44        Ok(self.bytecode.as_slice())
45    }
46
47    pub fn get_bytecode(&self) -> &[Opcode] {
49        self.bytecode.as_slice()
50    }
51}
52
53#[derive(Debug)]
56struct CompilerInner<'arena, 'bytecode_ref> {
57    root: &'arena Node<'arena>,               bytecode: &'bytecode_ref mut Vec<Opcode>, }
60
61impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
62    pub fn new(
64        bytecode: &'bytecode_ref mut Vec<Opcode>,
65        root: &'arena Node<'arena>,
66    ) -> Self {
67        Self { root, bytecode }
68    }
69
70    pub fn compile(&mut self) -> CompilerResult<()> {
72        self.compile_node(self.root)?;
73        Ok(())
74    }
75
76    fn emit(
84        &mut self,
85        op: Opcode,
86    ) -> usize {
87        self.bytecode.push(op);
88        self.bytecode.len()
89    }
90
91    fn emit_loop<F>(
96        &mut self,
97        body: F,
98    ) -> CompilerResult<()>
99    where
100        F: FnOnce(&mut Self) -> CompilerResult<()>,
101    {
102        let begin = self.bytecode.len();
103        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0)); body(self)?; self.emit(Opcode::IncrementIt); let e = self.emit(Opcode::Jump(
109            Jump::Backward,
110            self.calc_backward_jump(begin) as u32,
111        )); self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
115        Ok(())
116    }
117
118    fn emit_cond<F>(
123        &mut self,
124        mut body: F,
125    ) where
126        F: FnMut(&mut Self),
127    {
128        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop); body(self); let jmp = self.emit(Opcode::Jump(Jump::Forward, 0)); self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
135        let e = self.emit(Opcode::Pop); self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
137    }
138
139    fn replace(
145        &mut self,
146        at: usize,
147        op: Opcode,
148    ) {
149        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
150    }
151
152    fn calc_backward_jump(
160        &self,
161        to: usize,
162    ) -> usize {
163        self.bytecode.len() + 1 - to
164    }
165
166    fn compile_argument<T: ToString>(
173        &mut self,
174        function_kind: T,
175        arguments: &[&'arena Node<'arena>],
176        index: usize,
177    ) -> CompilerResult<usize> {
178        let arg = arguments.get(index).ok_or_else(|| {
179            CompilerError::ArgumentNotFound {
180                index,
181                function: function_kind.to_string(),
182            }
183        })?;
184
185        self.compile_node(arg)
186    }
187
188    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
197    fn compile_member_fast(
198        &mut self,
199        node: &'arena Node<'arena>,
200    ) -> Option<Vec<FetchFastTarget>> {
201        match node {
202            Node::Root => Some(vec![FetchFastTarget::Root]),
203            Node::Identifier(v) => Some(vec![
204                FetchFastTarget::Root,
205                FetchFastTarget::String(Arc::from(*v)),
206            ]),
207            Node::Member { node, property } => {
208                let mut path = self.compile_member_fast(node)?;
209                match property {
210                    Node::String(v) => {
211                        path.push(FetchFastTarget::String(Arc::from(*v)));
212                        Some(path)
213                    },
214                    Node::Number(v) => {
215                        if let Some(idx) = v.to_u32() {
216                            path.push(FetchFastTarget::Number(idx));
217                            Some(path)
218                        } else {
219                            None
220                        }
221                    },
222                    _ => None,
223                }
224            },
225            _ => None,
226        }
227    }
228
229    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
238    fn compile_node(
239        &mut self,
240        node: &'arena Node<'arena>,
241    ) -> CompilerResult<usize> {
242        match node {
243            Node::Null => Ok(self.emit(Opcode::PushNull)),
245            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
246            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
247            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
248            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
249            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
250
251            Node::Array(v) => {
253                v.iter().try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
254                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
255                Ok(self.emit(Opcode::Array))
256            },
257
258            Node::Object(v) => {
260                v.iter().try_for_each(|&(key, value)| {
261                    self.compile_node(key).map(|_| ())?;
262                    self.emit(Opcode::CallFunction {
264                        arg_count: 1,
265                        kind: FunctionKind::Internal(InternalFunction::String),
266                    });
267                    self.compile_node(value).map(|_| ())?;
268                    Ok(())
269                })?;
270
271                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
272                Ok(self.emit(Opcode::Object))
273            },
274
275            Node::Identifier(v) => {
277                Ok(self.emit(Opcode::FetchEnv(Arc::from(*v))))
278            },
279
280            Node::Closure(v) => self.compile_node(v),
282            Node::Parenthesized(v) => self.compile_node(v),
283
284            Node::Member { node: n, property: p } => {
286                if let Some(path) = self.compile_member_fast(node) {
287                    Ok(self.emit(Opcode::FetchFast(path)))
288                } else {
289                    self.compile_node(n)?;
290                    self.compile_node(p)?;
291                    Ok(self.emit(Opcode::Fetch))
292                }
293            },
294
295            Node::TemplateString(parts) => {
297                parts.iter().try_for_each(|&n| {
298                    self.compile_node(n).map(|_| ())?;
299                    self.emit(Opcode::CallFunction {
301                        arg_count: 1,
302                        kind: FunctionKind::Internal(InternalFunction::String),
303                    });
304                    Ok(())
305                })?;
306
307                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
308                self.emit(Opcode::Array);
309                self.emit(Opcode::PushString(Arc::from("")));
310                Ok(self.emit(Opcode::Join))
311            },
312
313            Node::Slice { node, to, from } => {
315                self.compile_node(node)?;
316                if let Some(t) = to {
317                    self.compile_node(t)?;
318                } else {
319                    self.emit(Opcode::Len);
321                    self.emit(Opcode::PushNumber(dec!(1)));
322                    self.emit(Opcode::Subtract);
323                }
324
325                if let Some(f) = from {
326                    self.compile_node(f)?;
327                } else {
328                    self.emit(Opcode::PushNumber(dec!(0)));
330                }
331
332                Ok(self.emit(Opcode::Slice))
333            },
334
335            Node::Interval { left, right, left_bracket, right_bracket } => {
337                self.compile_node(left)?;
338                self.compile_node(right)?;
339                Ok(self.emit(Opcode::Interval {
340                    left_bracket: *left_bracket,
341                    right_bracket: *right_bracket,
342                }))
343            },
344
345            Node::Conditional { condition, on_true, on_false } => {
347                self.compile_node(condition)?;
348                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop); self.compile_node(on_true)?; let end = self.emit(Opcode::Jump(Jump::Forward, 0)); self.replace(
356                    otherwise,
357                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
358                );
359                self.emit(Opcode::Pop); let b = self.compile_node(on_false)?; self.replace(
362                    end,
363                    Opcode::Jump(Jump::Forward, (b - end) as u32),
364                );
365
366                Ok(b)
367            },
368
369            Node::Unary { node, operator } => {
371                let curr = self.compile_node(node)?;
372                match *operator {
373                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr), Operator::Arithmetic(ArithmeticOperator::Subtract) => {
375                        Ok(self.emit(Opcode::Negate)) },
377                    Operator::Logical(LogicalOperator::Not) => {
378                        Ok(self.emit(Opcode::Not))
379                    }, _ => Err(CompilerError::UnknownUnaryOperator {
381                        operator: operator.to_string(),
382                    }),
383                }
384            },
385
386            Node::Binary { left, right, operator } => match *operator {
388                Operator::Comparison(ComparisonOperator::Equal) => {
390                    self.compile_node(left)?;
391                    self.compile_node(right)?;
392                    Ok(self.emit(Opcode::Equal))
393                },
394
395                Operator::Comparison(ComparisonOperator::NotEqual) => {
397                    self.compile_node(left)?;
398                    self.compile_node(right)?;
399                    self.emit(Opcode::Equal);
400                    Ok(self.emit(Opcode::Not))
401                },
402
403                Operator::Logical(LogicalOperator::Or) => {
405                    self.compile_node(left)?;
406                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0)); self.emit(Opcode::Pop);
408                    let r = self.compile_node(right)?;
409                    self.replace(
410                        end,
411                        Opcode::Jump(Jump::IfTrue, (r - end) as u32),
412                    );
413                    Ok(r)
414                },
415
416                Operator::Logical(LogicalOperator::And) => {
418                    self.compile_node(left)?;
419                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0)); self.emit(Opcode::Pop);
421                    let r = self.compile_node(right)?;
422                    self.replace(
423                        end,
424                        Opcode::Jump(Jump::IfFalse, (r - end) as u32),
425                    );
426                    Ok(r)
427                },
428
429                Operator::Logical(LogicalOperator::NullishCoalescing) => {
431                    self.compile_node(left)?;
432                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0)); self.emit(Opcode::Pop);
434                    let r = self.compile_node(right)?;
435                    self.replace(
436                        end,
437                        Opcode::Jump(Jump::IfNotNull, (r - end) as u32),
438                    );
439                    Ok(r)
440                },
441
442                Operator::Comparison(ComparisonOperator::In) => {
444                    self.compile_node(left)?;
445                    self.compile_node(right)?;
446                    Ok(self.emit(Opcode::In))
447                },
448
449                Operator::Comparison(ComparisonOperator::NotIn) => {
451                    self.compile_node(left)?;
452                    self.compile_node(right)?;
453                    self.emit(Opcode::In);
454                    Ok(self.emit(Opcode::Not))
455                },
456
457                Operator::Comparison(ComparisonOperator::LessThan) => {
459                    self.compile_node(left)?;
460                    self.compile_node(right)?;
461                    Ok(self.emit(Opcode::Compare(Compare::Less)))
462                },
463                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
464                    self.compile_node(left)?;
465                    self.compile_node(right)?;
466                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
467                },
468                Operator::Comparison(ComparisonOperator::GreaterThan) => {
469                    self.compile_node(left)?;
470                    self.compile_node(right)?;
471                    Ok(self.emit(Opcode::Compare(Compare::More)))
472                },
473                Operator::Comparison(
474                    ComparisonOperator::GreaterThanOrEqual,
475                ) => {
476                    self.compile_node(left)?;
477                    self.compile_node(right)?;
478                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
479                },
480
481                Operator::Arithmetic(ArithmeticOperator::Add) => {
483                    self.compile_node(left)?;
484                    self.compile_node(right)?;
485                    Ok(self.emit(Opcode::Add))
486                },
487                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
488                    self.compile_node(left)?;
489                    self.compile_node(right)?;
490                    Ok(self.emit(Opcode::Subtract))
491                },
492                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
493                    self.compile_node(left)?;
494                    self.compile_node(right)?;
495                    Ok(self.emit(Opcode::Multiply))
496                },
497                Operator::Arithmetic(ArithmeticOperator::Divide) => {
498                    self.compile_node(left)?;
499                    self.compile_node(right)?;
500                    Ok(self.emit(Opcode::Divide))
501                },
502                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
503                    self.compile_node(left)?;
504                    self.compile_node(right)?;
505                    Ok(self.emit(Opcode::Modulo))
506                },
507                Operator::Arithmetic(ArithmeticOperator::Power) => {
508                    self.compile_node(left)?;
509                    self.compile_node(right)?;
510                    Ok(self.emit(Opcode::Exponent))
511                },
512                _ => Err(CompilerError::UnknownBinaryOperator {
513                    operator: operator.to_string(),
514                }),
515            },
516
517            Node::FunctionCall { kind, arguments } => match kind {
519                FunctionKind::Internal(_)
520                | FunctionKind::Deprecated(_)
521                | FunctionKind::Custom(_) => {
522                    let function = FunctionRegistry::get_definition(kind)
523                        .ok_or_else(|| CompilerError::UnknownFunction {
524                            name: kind.to_string(),
525                        })?;
526
527                    let min_params = function.required_parameters();
529                    let max_params =
530                        min_params + function.optional_parameters();
531                    if arguments.len() < min_params
532                        || arguments.len() > max_params
533                    {
534                        return Err(CompilerError::InvalidFunctionCall {
535                            name: kind.to_string(),
536                            message: "无效的参数数量".to_string(),
537                        });
538                    }
539
540                    for i in 0..arguments.len() {
542                        self.compile_argument(kind, arguments, i)?;
543                    }
544
545                    Ok(self.emit(Opcode::CallFunction {
546                        kind: kind.clone(),
547                        arg_count: arguments.len() as u32,
548                    }))
549                },
550
551                FunctionKind::Closure(c) => match c {
553                    ClosureFunction::All => {
555                        self.compile_argument(kind, arguments, 0)?; self.emit(Opcode::Begin);
557                        let mut loop_break: usize = 0;
558                        self.emit_loop(|c| {
559                            c.compile_argument(kind, arguments, 1)?; loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0)); c.emit(Opcode::Pop);
562                            Ok(())
563                        })?;
564                        let e = self.emit(Opcode::PushBool(true)); self.replace(
566                            loop_break,
567                            Opcode::Jump(
568                                Jump::IfFalse,
569                                (e - loop_break) as u32,
570                            ),
571                        );
572                        Ok(self.emit(Opcode::End))
573                    },
574
575                    ClosureFunction::None => {
577                        self.compile_argument(kind, arguments, 0)?;
578                        self.emit(Opcode::Begin);
579                        let mut loop_break: usize = 0;
580                        self.emit_loop(|c| {
581                            c.compile_argument(kind, arguments, 1)?;
582                            c.emit(Opcode::Not); loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
584                            c.emit(Opcode::Pop);
585                            Ok(())
586                        })?;
587                        let e = self.emit(Opcode::PushBool(true));
588                        self.replace(
589                            loop_break,
590                            Opcode::Jump(
591                                Jump::IfFalse,
592                                (e - loop_break) as u32,
593                            ),
594                        );
595                        Ok(self.emit(Opcode::End))
596                    },
597
598                    ClosureFunction::Some => {
600                        self.compile_argument(kind, arguments, 0)?;
601                        self.emit(Opcode::Begin);
602                        let mut loop_break: usize = 0;
603                        self.emit_loop(|c| {
604                            c.compile_argument(kind, arguments, 1)?;
605                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0)); c.emit(Opcode::Pop);
607                            Ok(())
608                        })?;
609                        let e = self.emit(Opcode::PushBool(false)); self.replace(
611                            loop_break,
612                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
613                        );
614                        Ok(self.emit(Opcode::End))
615                    },
616
617                    ClosureFunction::One => {
619                        self.compile_argument(kind, arguments, 0)?;
620                        self.emit(Opcode::Begin);
621                        self.emit_loop(|c| {
622                            c.compile_argument(kind, arguments, 1)?;
623                            c.emit_cond(|c| {
624                                c.emit(Opcode::IncrementCount); });
626                            Ok(())
627                        })?;
628                        self.emit(Opcode::GetCount);
629                        self.emit(Opcode::PushNumber(dec!(1)));
630                        self.emit(Opcode::Equal); Ok(self.emit(Opcode::End))
632                    },
633
634                    ClosureFunction::Filter => {
636                        self.compile_argument(kind, arguments, 0)?;
637                        self.emit(Opcode::Begin);
638                        self.emit_loop(|c| {
639                            c.compile_argument(kind, arguments, 1)?;
640                            c.emit_cond(|c| {
641                                c.emit(Opcode::IncrementCount);
642                                c.emit(Opcode::Pointer); });
644                            Ok(())
645                        })?;
646                        self.emit(Opcode::GetCount);
647                        self.emit(Opcode::End);
648                        Ok(self.emit(Opcode::Array))
649                    },
650
651                    ClosureFunction::Map => {
653                        self.compile_argument(kind, arguments, 0)?;
654                        self.emit(Opcode::Begin);
655                        self.emit_loop(|c| {
656                            c.compile_argument(kind, arguments, 1)?; Ok(())
658                        })?;
659                        self.emit(Opcode::GetLen);
660                        self.emit(Opcode::End);
661                        Ok(self.emit(Opcode::Array))
662                    },
663
664                    ClosureFunction::FlatMap => {
666                        self.compile_argument(kind, arguments, 0)?;
667                        self.emit(Opcode::Begin);
668                        self.emit_loop(|c| {
669                            c.compile_argument(kind, arguments, 1)?;
670                            Ok(())
671                        })?;
672                        self.emit(Opcode::GetLen);
673                        self.emit(Opcode::End);
674                        self.emit(Opcode::Array);
675                        Ok(self.emit(Opcode::Flatten)) },
677
678                    ClosureFunction::Count => {
680                        self.compile_argument(kind, arguments, 0)?;
681                        self.emit(Opcode::Begin);
682                        self.emit_loop(|c| {
683                            c.compile_argument(kind, arguments, 1)?;
684                            c.emit_cond(|c| {
685                                c.emit(Opcode::IncrementCount);
686                            });
687                            Ok(())
688                        })?;
689                        self.emit(Opcode::GetCount);
690                        Ok(self.emit(Opcode::End))
691                    },
692                },
693            },
694
695            Node::MethodCall { kind, this, arguments } => {
697                let method =
698                    MethodRegistry::get_definition(kind).ok_or_else(|| {
699                        CompilerError::UnknownFunction {
700                            name: kind.to_string(),
701                        }
702                    })?;
703
704                self.compile_node(this)?; let min_params = method.required_parameters() - 1;
708                let max_params = min_params + method.optional_parameters();
709                if arguments.len() < min_params || arguments.len() > max_params
710                {
711                    return Err(CompilerError::InvalidMethodCall {
712                        name: kind.to_string(),
713                        message: "Invalid number of arguments".to_string(),
714                    });
715                }
716
717                for i in 0..arguments.len() {
719                    self.compile_argument(kind, arguments, i)?;
720                }
721
722                Ok(self.emit(Opcode::CallMethod {
723                    kind: kind.clone(),
724                    arg_count: arguments.len() as u32,
725                }))
726            },
727
728            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
730        }
731    }
732}