1use crate::{
2 parsing::ast::types::{self, NodeRef},
3 source_range::SourceRange,
4};
5
/// A borrowed, type-tagged view of a single node in the KCL AST.
///
/// The enum is `Copy`, so it can be passed around by value cheaply; it only
/// holds references into an AST owned elsewhere (lifetime `'a`).
///
/// Most variants hold a `NodeRef` to the wrapped AST node. `ElseIf`,
/// `Parameter` and `KclNone` instead hold a plain reference to the inner
/// type.
#[derive(Copy, Clone, Debug)]
pub enum Node<'a> {
    // Whole program.
    Program(NodeRef<'a, types::Program>),

    // Statements that can appear in a program body.
    ImportStatement(NodeRef<'a, types::ImportStatement>),
    ExpressionStatement(NodeRef<'a, types::ExpressionStatement>),
    VariableDeclaration(NodeRef<'a, types::VariableDeclaration>),
    TypeDeclaration(NodeRef<'a, types::TypeDeclaration>),
    ReturnStatement(NodeRef<'a, types::ReturnStatement>),

    // The declarator inside a variable declaration.
    VariableDeclarator(NodeRef<'a, types::VariableDeclarator>),

    // Expression kinds.
    Literal(NodeRef<'a, types::Literal>),
    TagDeclarator(NodeRef<'a, types::TagDeclarator>),
    Identifier(NodeRef<'a, types::Identifier>),
    Name(NodeRef<'a, types::Name>),
    BinaryExpression(NodeRef<'a, types::BinaryExpression>),
    FunctionExpression(NodeRef<'a, types::FunctionExpression>),
    CallExpressionKw(NodeRef<'a, types::CallExpressionKw>),
    PipeExpression(NodeRef<'a, types::PipeExpression>),
    PipeSubstitution(NodeRef<'a, types::PipeSubstitution>),
    ArrayExpression(NodeRef<'a, types::ArrayExpression>),
    ArrayRangeExpression(NodeRef<'a, types::ArrayRangeExpression>),
    ObjectExpression(NodeRef<'a, types::ObjectExpression>),
    MemberExpression(NodeRef<'a, types::MemberExpression>),
    UnaryExpression(NodeRef<'a, types::UnaryExpression>),
    IfExpression(NodeRef<'a, types::IfExpression>),
    // Plain reference (not a `NodeRef`).
    ElseIf(&'a types::ElseIf),
    LabelledExpression(NodeRef<'a, types::LabelledExpression>),
    AscribedExpression(NodeRef<'a, types::AscribedExpression>),

    // Function parameter; plain reference (not a `NodeRef`).
    Parameter(&'a types::Parameter),

    // A property of an object expression.
    ObjectProperty(NodeRef<'a, types::ObjectProperty>),

    // The KCL `none` value; plain reference (not a `NodeRef`).
    KclNone(&'a types::KclNone),
}
45
46impl Node<'_> {
47 pub fn digest(&self) -> Option<[u8; 32]> {
52 match self {
53 Node::Program(n) => n.digest,
54 Node::ImportStatement(n) => n.digest,
55 Node::ExpressionStatement(n) => n.digest,
56 Node::VariableDeclaration(n) => n.digest,
57 Node::TypeDeclaration(n) => n.digest,
58 Node::ReturnStatement(n) => n.digest,
59 Node::VariableDeclarator(n) => n.digest,
60 Node::Literal(n) => n.digest,
61 Node::TagDeclarator(n) => n.digest,
62 Node::Identifier(n) => n.digest,
63 Node::Name(n) => n.digest,
64 Node::BinaryExpression(n) => n.digest,
65 Node::FunctionExpression(n) => n.digest,
66 Node::CallExpressionKw(n) => n.digest,
67 Node::PipeExpression(n) => n.digest,
68 Node::PipeSubstitution(n) => n.digest,
69 Node::ArrayExpression(n) => n.digest,
70 Node::ArrayRangeExpression(n) => n.digest,
71 Node::ObjectExpression(n) => n.digest,
72 Node::MemberExpression(n) => n.digest,
73 Node::UnaryExpression(n) => n.digest,
74 Node::Parameter(p) => p.digest,
75 Node::ObjectProperty(n) => n.digest,
76 Node::IfExpression(n) => n.digest,
77 Node::ElseIf(n) => n.digest,
78 Node::KclNone(n) => n.digest,
79 Node::LabelledExpression(n) => n.digest,
80 Node::AscribedExpression(n) => n.digest,
81 }
82 }
83
84 pub fn ptr_eq(&self, other: Node) -> bool {
92 unsafe { std::ptr::eq(self.ptr(), other.ptr()) }
93 }
94
95 unsafe fn ptr(&self) -> *const () {
96 match self {
97 Node::Program(n) => *n as *const _ as *const (),
98 Node::ImportStatement(n) => *n as *const _ as *const (),
99 Node::ExpressionStatement(n) => *n as *const _ as *const (),
100 Node::VariableDeclaration(n) => *n as *const _ as *const (),
101 Node::TypeDeclaration(n) => *n as *const _ as *const (),
102 Node::ReturnStatement(n) => *n as *const _ as *const (),
103 Node::VariableDeclarator(n) => *n as *const _ as *const (),
104 Node::Literal(n) => *n as *const _ as *const (),
105 Node::TagDeclarator(n) => *n as *const _ as *const (),
106 Node::Identifier(n) => *n as *const _ as *const (),
107 Node::Name(n) => *n as *const _ as *const (),
108 Node::BinaryExpression(n) => *n as *const _ as *const (),
109 Node::FunctionExpression(n) => *n as *const _ as *const (),
110 Node::CallExpressionKw(n) => *n as *const _ as *const (),
111 Node::PipeExpression(n) => *n as *const _ as *const (),
112 Node::PipeSubstitution(n) => *n as *const _ as *const (),
113 Node::ArrayExpression(n) => *n as *const _ as *const (),
114 Node::ArrayRangeExpression(n) => *n as *const _ as *const (),
115 Node::ObjectExpression(n) => *n as *const _ as *const (),
116 Node::MemberExpression(n) => *n as *const _ as *const (),
117 Node::UnaryExpression(n) => *n as *const _ as *const (),
118 Node::Parameter(p) => *p as *const _ as *const (),
119 Node::ObjectProperty(n) => *n as *const _ as *const (),
120 Node::IfExpression(n) => *n as *const _ as *const (),
121 Node::ElseIf(n) => *n as *const _ as *const (),
122 Node::KclNone(n) => *n as *const _ as *const (),
123 Node::LabelledExpression(n) => *n as *const _ as *const (),
124 Node::AscribedExpression(n) => *n as *const _ as *const (),
125 }
126 }
127}
128
/// Errors produced when deriving information from an AST `Node`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstNodeError {
    /// Returned when a `KclNone` node is converted into a `SourceRange`.
    /// That variant holds no source location to report.
    NoSourceForAKclNone,
}

impl std::fmt::Display for AstNodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AstNodeError::NoSourceForAKclNone => f.write_str("a KclNone AST node has no source range"),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` / error-chaining types.
impl std::error::Error for AstNodeError {}
135
136impl TryFrom<&Node<'_>> for SourceRange {
137 type Error = AstNodeError;
138
139 fn try_from(node: &Node) -> Result<Self, Self::Error> {
140 Ok(match node {
141 Node::Program(n) => SourceRange::from(*n),
142 Node::ImportStatement(n) => SourceRange::from(*n),
143 Node::ExpressionStatement(n) => SourceRange::from(*n),
144 Node::VariableDeclaration(n) => SourceRange::from(*n),
145 Node::TypeDeclaration(n) => SourceRange::from(*n),
146 Node::ReturnStatement(n) => SourceRange::from(*n),
147 Node::VariableDeclarator(n) => SourceRange::from(*n),
148 Node::Literal(n) => SourceRange::from(*n),
149 Node::TagDeclarator(n) => SourceRange::from(*n),
150 Node::Identifier(n) => SourceRange::from(*n),
151 Node::Name(n) => SourceRange::from(*n),
152 Node::BinaryExpression(n) => SourceRange::from(*n),
153 Node::FunctionExpression(n) => SourceRange::from(*n),
154 Node::CallExpressionKw(n) => SourceRange::from(*n),
155 Node::PipeExpression(n) => SourceRange::from(*n),
156 Node::PipeSubstitution(n) => SourceRange::from(*n),
157 Node::ArrayExpression(n) => SourceRange::from(*n),
158 Node::ArrayRangeExpression(n) => SourceRange::from(*n),
159 Node::ObjectExpression(n) => SourceRange::from(*n),
160 Node::MemberExpression(n) => SourceRange::from(*n),
161 Node::UnaryExpression(n) => SourceRange::from(*n),
162 Node::Parameter(p) => SourceRange::from(&p.identifier),
163 Node::ObjectProperty(n) => SourceRange::from(*n),
164 Node::IfExpression(n) => SourceRange::from(*n),
165 Node::LabelledExpression(n) => SourceRange::from(*n),
166 Node::AscribedExpression(n) => SourceRange::from(*n),
167
168 Node::ElseIf(n) => SourceRange::new(n.cond.start(), n.cond.end(), n.cond.module_id()),
170
171 Node::KclNone(_) => return Err(Self::Error::NoSourceForAKclNone),
174 })
175 }
176}
177
178impl<'tree> From<&'tree types::BodyItem> for Node<'tree> {
179 fn from(node: &'tree types::BodyItem) -> Self {
180 match node {
181 types::BodyItem::ImportStatement(v) => v.as_ref().into(),
182 types::BodyItem::ExpressionStatement(v) => v.into(),
183 types::BodyItem::VariableDeclaration(v) => v.as_ref().into(),
184 types::BodyItem::TypeDeclaration(v) => v.as_ref().into(),
185 types::BodyItem::ReturnStatement(v) => v.into(),
186 }
187 }
188}
189
190impl<'tree> From<&'tree types::Expr> for Node<'tree> {
191 fn from(node: &'tree types::Expr) -> Self {
192 match node {
193 types::Expr::Literal(lit) => lit.as_ref().into(),
194 types::Expr::TagDeclarator(tag) => tag.as_ref().into(),
195 types::Expr::Name(id) => id.as_ref().into(),
196 types::Expr::BinaryExpression(be) => be.as_ref().into(),
197 types::Expr::FunctionExpression(fe) => fe.as_ref().into(),
198 types::Expr::CallExpressionKw(ce) => ce.as_ref().into(),
199 types::Expr::PipeExpression(pe) => pe.as_ref().into(),
200 types::Expr::PipeSubstitution(ps) => ps.as_ref().into(),
201 types::Expr::ArrayExpression(ae) => ae.as_ref().into(),
202 types::Expr::ArrayRangeExpression(are) => are.as_ref().into(),
203 types::Expr::ObjectExpression(oe) => oe.as_ref().into(),
204 types::Expr::MemberExpression(me) => me.as_ref().into(),
205 types::Expr::UnaryExpression(ue) => ue.as_ref().into(),
206 types::Expr::IfExpression(e) => e.as_ref().into(),
207 types::Expr::LabelledExpression(e) => e.as_ref().into(),
208 types::Expr::AscribedExpression(e) => e.as_ref().into(),
209 types::Expr::None(n) => n.into(),
210 }
211 }
212}
213
214impl<'tree> From<&'tree types::BinaryPart> for Node<'tree> {
215 fn from(node: &'tree types::BinaryPart) -> Self {
216 match node {
217 types::BinaryPart::Literal(lit) => lit.as_ref().into(),
218 types::BinaryPart::Name(id) => id.as_ref().into(),
219 types::BinaryPart::BinaryExpression(be) => be.as_ref().into(),
220 types::BinaryPart::CallExpressionKw(ce) => ce.as_ref().into(),
221 types::BinaryPart::UnaryExpression(ue) => ue.as_ref().into(),
222 types::BinaryPart::MemberExpression(me) => me.as_ref().into(),
223 types::BinaryPart::IfExpression(e) => e.as_ref().into(),
224 types::BinaryPart::AscribedExpression(e) => e.as_ref().into(),
225 }
226 }
227}
228
229impl<'tree> From<&'tree types::MemberObject> for Node<'tree> {
230 fn from(node: &'tree types::MemberObject) -> Self {
231 match node {
232 types::MemberObject::MemberExpression(me) => me.as_ref().into(),
233 types::MemberObject::Identifier(id) => id.as_ref().into(),
234 }
235 }
236}
237
238impl<'tree> From<&'tree types::LiteralIdentifier> for Node<'tree> {
239 fn from(node: &'tree types::LiteralIdentifier) -> Self {
240 match node {
241 types::LiteralIdentifier::Identifier(id) => id.as_ref().into(),
242 types::LiteralIdentifier::Literal(lit) => lit.as_ref().into(),
243 }
244 }
245}
246
// Implements `From<NodeRef<'a, types::$t>>` for `$node<'a>`. Each impl wraps
// the reference in the same-named variant `$node::$t`.
//
// `$node` is the target enum; it used to be accepted but ignored (the enum
// name was hard-coded).
macro_rules! impl_from {
    ($node:ident, $t:ident) => {
        impl<'a> From<NodeRef<'a, types::$t>> for $node<'a> {
            fn from(v: NodeRef<'a, types::$t>) -> Self {
                $node::$t(v)
            }
        }
    };
}
256
// Implements `From<&'a types::$t>` for `$node<'a>`. Each impl wraps a plain
// reference to the AST type in the same-named variant `$node::$t`.
//
// Use this for variants that hold `&types::T` rather than a `NodeRef`.
// `$node` is the target enum; it used to be accepted but ignored (the enum
// name was hard-coded).
macro_rules! impl_from_ref {
    ($node:ident, $t:ident) => {
        impl<'a> From<&'a types::$t> for $node<'a> {
            fn from(v: &'a types::$t) -> Self {
                $node::$t(v)
            }
        }
    };
}
266
267impl_from!(Node, Program);
268impl_from!(Node, ImportStatement);
269impl_from!(Node, ExpressionStatement);
270impl_from!(Node, VariableDeclaration);
271impl_from!(Node, TypeDeclaration);
272impl_from!(Node, ReturnStatement);
273impl_from!(Node, VariableDeclarator);
274impl_from!(Node, Literal);
275impl_from!(Node, TagDeclarator);
276impl_from!(Node, Identifier);
277impl_from!(Node, Name);
278impl_from!(Node, BinaryExpression);
279impl_from!(Node, FunctionExpression);
280impl_from!(Node, CallExpressionKw);
281impl_from!(Node, PipeExpression);
282impl_from!(Node, PipeSubstitution);
283impl_from!(Node, ArrayExpression);
284impl_from!(Node, ArrayRangeExpression);
285impl_from!(Node, ObjectExpression);
286impl_from!(Node, MemberExpression);
287impl_from!(Node, UnaryExpression);
288impl_from!(Node, ObjectProperty);
289impl_from_ref!(Node, Parameter);
290impl_from!(Node, IfExpression);
291impl_from!(Node, ElseIf);
292impl_from!(Node, LabelledExpression);
293impl_from!(Node, AscribedExpression);
294impl_from!(Node, KclNone);
295
#[cfg(test)]
mod tests {
    use super::*;

    // Parses a KCL source string, panicking if parsing fails.
    macro_rules! kcl {
        ( $kcl:expr ) => {{
            $crate::parsing::top_level_parse($kcl).unwrap()
        }};
    }

    #[test]
    fn check_ptr_eq() {
        let program = kcl!(
            "
foo = 1
bar = foo + 1

fn myfn() {
  foo = 2
  sin(foo)
}
"
        );

        let first_item = &program.body[0];
        let second_item = &program.body[1];

        // The same body item converted twice is the same object...
        let first: Node = first_item.into();
        assert!(first.ptr_eq(first_item.into()));
        // ...while a different body item is not.
        assert!(!first.ptr_eq(second_item.into()));
    }
}