moduforge_rules_expression/parser/
ast.rs

1use crate::functions::{FunctionKind, MethodKind};
2use crate::lexer::{Bracket, Operator};
3use rust_decimal::Decimal;
4use std::cell::Cell;
5use strum_macros::IntoStaticStr;
6use thiserror::Error;
7
8/// 抽象语法树节点枚举
9/// 定义了表达式中所有可能的节点类型
10#[derive(Debug, PartialEq, Clone, IntoStaticStr)]
11pub enum Node<'a> {
12    /// 空值节点
13    Null,
14    /// 布尔值节点
15    Bool(bool),
16    /// 数字节点(使用高精度十进制)
17    Number(Decimal),
18    /// 字符串节点
19    String(&'a str),
20    /// 模板字符串节点(包含多个子节点的数组)
21    TemplateString(&'a [&'a Node<'a>]),
22    /// 指针节点(回调引用 #)
23    Pointer,
24    /// 数组节点
25    Array(&'a [&'a Node<'a>]),
26    /// 对象节点(键值对数组)
27    Object(&'a [(&'a Node<'a>, &'a Node<'a>)]),
28    /// 标识符节点
29    Identifier(&'a str),
30    /// 闭包节点
31    Closure(&'a Node<'a>),
32    /// 括号表达式节点
33    Parenthesized(&'a Node<'a>),
34    /// 根节点($ 引用)
35    Root,
36    /// 成员访问节点(对象.属性 或 对象["属性"])
37    Member {
38        node: &'a Node<'a>,     // 被访问的对象
39        property: &'a Node<'a>, // 属性名
40    },
41    /// 切片节点(数组[from:to])
42    Slice {
43        node: &'a Node<'a>,         // 被切片的对象
44        from: Option<&'a Node<'a>>, // 开始位置(可选)
45        to: Option<&'a Node<'a>>,   // 结束位置(可选)
46    },
47    /// 区间节点([a, b] 或 (a, b) 等)
48    Interval {
49        left: &'a Node<'a>,     // 左边界
50        right: &'a Node<'a>,    // 右边界
51        left_bracket: Bracket,  // 左括号类型
52        right_bracket: Bracket, // 右括号类型
53    },
54    /// 条件表达式节点(三元操作符 condition ? true_expr : false_expr)
55    Conditional {
56        condition: &'a Node<'a>, // 条件表达式
57        on_true: &'a Node<'a>,   // 条件为真时的表达式
58        on_false: &'a Node<'a>,  // 条件为假时的表达式
59    },
60    /// 一元操作节点(如 -x, !x, +x)
61    Unary {
62        node: &'a Node<'a>, // 操作数
63        operator: Operator, // 操作符
64    },
65    /// 二元操作节点(如 x + y, x == y)
66    Binary {
67        left: &'a Node<'a>,  // 左操作数
68        operator: Operator,  // 操作符
69        right: &'a Node<'a>, // 右操作数
70    },
71    /// 函数调用节点
72    FunctionCall {
73        kind: FunctionKind,            // 函数类型
74        arguments: &'a [&'a Node<'a>], // 参数列表
75    },
76    /// 方法调用节点
77    MethodCall {
78        kind: MethodKind,              // 方法类型
79        this: &'a Node<'a>,            // 调用对象(this)
80        arguments: &'a [&'a Node<'a>], // 参数列表
81    },
82    /// 错误节点(包含解析错误信息)
83    Error {
84        node: Option<&'a Node<'a>>, // 可选的关联节点
85        error: AstNodeError<'a>,    // 错误信息
86    },
87}
88
89impl<'a> Node<'a> {
90    /// 遍历AST节点
91    /// 对每个节点(包括子节点)执行指定的函数
92    pub fn walk<F>(
93        &self,
94        mut func: F,
95    ) where
96        F: FnMut(&Self) + Clone,
97    {
98        // 先对当前节点执行函数
99        {
100            func(self);
101        };
102
103        // 然后递归遍历子节点
104        match self {
105            // 叶子节点:无子节点
106            Node::Null => {},
107            Node::Bool(_) => {},
108            Node::Number(_) => {},
109            Node::String(_) => {},
110            Node::Pointer => {},
111            Node::Identifier(_) => {},
112            Node::Root => {},
113
114            // 错误节点:可能包含一个子节点
115            Node::Error { node, .. } => {
116                if let Some(n) = node {
117                    n.walk(func.clone())
118                }
119            },
120
121            // 包含多个子节点的节点
122            Node::TemplateString(parts) => {
123                parts.iter().for_each(|n| n.walk(func.clone()))
124            },
125            Node::Array(parts) => {
126                parts.iter().for_each(|n| n.walk(func.clone()))
127            },
128            Node::Object(obj) => obj.iter().for_each(|(k, v)| {
129                k.walk(func.clone());
130                v.walk(func.clone());
131            }),
132
133            // 包含单个子节点的节点
134            Node::Closure(closure) => closure.walk(func.clone()),
135            Node::Parenthesized(c) => c.walk(func.clone()),
136
137            // 包含两个子节点的节点
138            Node::Member { node, property } => {
139                node.walk(func.clone());
140                property.walk(func.clone());
141            },
142            Node::Slice { node, to, from } => {
143                node.walk(func.clone());
144                if let Some(to) = to {
145                    to.walk(func.clone());
146                }
147                if let Some(from) = from {
148                    from.walk(func.clone());
149                }
150            },
151            Node::Interval { left, right, .. } => {
152                left.walk(func.clone());
153                right.walk(func.clone());
154            },
155
156            // 一元操作节点
157            Node::Unary { node, .. } => {
158                node.walk(func);
159            },
160
161            // 二元操作节点
162            Node::Binary { left, right, .. } => {
163                left.walk(func.clone());
164                right.walk(func.clone());
165            },
166
167            // 函数调用节点
168            Node::FunctionCall { arguments, .. } => {
169                arguments.iter().for_each(|n| n.walk(func.clone()));
170            },
171
172            // 方法调用节点
173            Node::MethodCall { this, arguments, .. } => {
174                this.walk(func.clone());
175                arguments.iter().for_each(|n| n.walk(func.clone()));
176            },
177
178            // 条件表达式节点
179            Node::Conditional { on_true, condition, on_false } => {
180                condition.walk(func.clone());
181                on_true.walk(func.clone());
182                on_false.walk(func.clone());
183            },
184        };
185    }
186
187    /// 查找AST中的第一个错误
188    /// 返回第一个遇到的错误节点中的错误信息
189    pub fn first_error(&self) -> Option<AstNodeError> {
190        let error_cell = Cell::new(None);
191        self.walk(|n| {
192            if let Node::Error { error, .. } = n {
193                error_cell.set(Some(error.clone()))
194            }
195        });
196
197        error_cell.into_inner()
198    }
199
200    /// 检查AST是否包含错误节点
201    pub fn has_error(&self) -> bool {
202        self.first_error().is_some()
203    }
204
205    /// 获取节点的位置范围
206    /// 只有错误节点才有位置信息
207    pub(crate) fn span(&self) -> Option<(u32, u32)> {
208        match self {
209            Node::Error { error, .. } => match error {
210                AstNodeError::UnknownBuiltIn { span, .. } => Some(span.clone()),
211                AstNodeError::UnknownMethod { span, .. } => Some(span.clone()),
212                AstNodeError::UnexpectedIdentifier { span, .. } => {
213                    Some(span.clone())
214                },
215                AstNodeError::UnexpectedToken { span, .. } => {
216                    Some(span.clone())
217                },
218                AstNodeError::InvalidNumber { span, .. } => Some(span.clone()),
219                AstNodeError::InvalidBoolean { span, .. } => Some(span.clone()),
220                AstNodeError::InvalidProperty { span, .. } => {
221                    Some(span.clone())
222                },
223                AstNodeError::MissingToken { position, .. } => {
224                    Some((*position as u32, *position as u32))
225                },
226                AstNodeError::Custom { span, .. } => Some(span.clone()),
227            },
228            _ => None,
229        }
230    }
231}
232
233/// AST节点错误枚举
234/// 定义了AST节点中可能出现的各种错误类型
235#[derive(Debug, PartialEq, Eq, Clone, Error)]
236pub enum AstNodeError<'a> {
237    /// 未知内置函数错误
238    #[error("Unknown function `{name}` at ({}, {})", span.0, span.1)]
239    UnknownBuiltIn { name: &'a str, span: (u32, u32) },
240
241    /// 未知方法错误
242    #[error("Unknown method `{name}` at ({}, {})", span.0, span.1)]
243    UnknownMethod { name: &'a str, span: (u32, u32) },
244
245    /// 意外标识符错误
246    #[error("Unexpected identifier: {received} at ({}, {}); Expected {expected}.", span.0, span.1)]
247    UnexpectedIdentifier {
248        received: &'a str, // 实际收到的标识符
249        expected: &'a str, // 期望的标识符
250        span: (u32, u32),  // 位置范围
251    },
252
253    /// 意外令牌错误
254    #[error("Unexpected token: {received} at ({}, {}); Expected {expected}.", span.0, span.1)]
255    UnexpectedToken {
256        received: &'a str, // 实际收到的令牌
257        expected: &'a str, // 期望的令牌
258        span: (u32, u32),  // 位置范围
259    },
260
261    /// 无效数字错误
262    #[error("Invalid number: {number} at ({}, {})", span.0, span.1)]
263    InvalidNumber { number: &'a str, span: (u32, u32) },
264
265    /// 无效布尔值错误
266    #[error("Invalid boolean: {boolean} at ({}, {})", span.0, span.1)]
267    InvalidBoolean { boolean: &'a str, span: (u32, u32) },
268
269    /// 无效属性错误
270    #[error("Invalid property: {property} at ({}, {})", span.0, span.1)]
271    InvalidProperty { property: &'a str, span: (u32, u32) },
272
273    /// 缺少期望令牌错误
274    #[error("Missing expected token: {expected} at {position}")]
275    MissingToken { expected: &'a str, position: usize },
276
277    /// 自定义错误
278    #[error("{message} at ({}, {})", span.0, span.1)]
279    Custom { message: &'a str, span: (u32, u32) },
280}