1use crate::{
2 parsing::ast::types::{self, NodeRef},
3 source_range::SourceRange,
4};
5
6#[derive(Copy, Clone, Debug)]
9pub enum Node<'a> {
10 Program(NodeRef<'a, types::Program>),
11
12 ImportStatement(NodeRef<'a, types::ImportStatement>),
13 ExpressionStatement(NodeRef<'a, types::ExpressionStatement>),
14 VariableDeclaration(NodeRef<'a, types::VariableDeclaration>),
15 ReturnStatement(NodeRef<'a, types::ReturnStatement>),
16
17 VariableDeclarator(NodeRef<'a, types::VariableDeclarator>),
18
19 Literal(NodeRef<'a, types::Literal>),
20 TagDeclarator(NodeRef<'a, types::TagDeclarator>),
21 Identifier(NodeRef<'a, types::Identifier>),
22 BinaryExpression(NodeRef<'a, types::BinaryExpression>),
23 FunctionExpression(NodeRef<'a, types::FunctionExpression>),
24 CallExpression(NodeRef<'a, types::CallExpression>),
25 CallExpressionKw(NodeRef<'a, types::CallExpressionKw>),
26 PipeExpression(NodeRef<'a, types::PipeExpression>),
27 PipeSubstitution(NodeRef<'a, types::PipeSubstitution>),
28 ArrayExpression(NodeRef<'a, types::ArrayExpression>),
29 ArrayRangeExpression(NodeRef<'a, types::ArrayRangeExpression>),
30 ObjectExpression(NodeRef<'a, types::ObjectExpression>),
31 MemberExpression(NodeRef<'a, types::MemberExpression>),
32 UnaryExpression(NodeRef<'a, types::UnaryExpression>),
33 IfExpression(NodeRef<'a, types::IfExpression>),
34 ElseIf(&'a types::ElseIf),
35 LabelledExpression(NodeRef<'a, types::LabelledExpression>),
36 Ascription(NodeRef<'a, types::Ascription>),
37
38 Parameter(&'a types::Parameter),
39
40 ObjectProperty(NodeRef<'a, types::ObjectProperty>),
41
42 KclNone(&'a types::KclNone),
43}
44
45impl Node<'_> {
46 pub fn digest(&self) -> Option<[u8; 32]> {
51 match self {
52 Node::Program(n) => n.digest,
53 Node::ImportStatement(n) => n.digest,
54 Node::ExpressionStatement(n) => n.digest,
55 Node::VariableDeclaration(n) => n.digest,
56 Node::ReturnStatement(n) => n.digest,
57 Node::VariableDeclarator(n) => n.digest,
58 Node::Literal(n) => n.digest,
59 Node::TagDeclarator(n) => n.digest,
60 Node::Identifier(n) => n.digest,
61 Node::BinaryExpression(n) => n.digest,
62 Node::FunctionExpression(n) => n.digest,
63 Node::CallExpression(n) => n.digest,
64 Node::CallExpressionKw(n) => n.digest,
65 Node::PipeExpression(n) => n.digest,
66 Node::PipeSubstitution(n) => n.digest,
67 Node::ArrayExpression(n) => n.digest,
68 Node::ArrayRangeExpression(n) => n.digest,
69 Node::ObjectExpression(n) => n.digest,
70 Node::MemberExpression(n) => n.digest,
71 Node::UnaryExpression(n) => n.digest,
72 Node::Parameter(p) => p.digest,
73 Node::ObjectProperty(n) => n.digest,
74 Node::IfExpression(n) => n.digest,
75 Node::ElseIf(n) => n.digest,
76 Node::KclNone(n) => n.digest,
77 Node::LabelledExpression(n) => n.digest,
78 Node::Ascription(n) => n.digest,
79 }
80 }
81
82 pub fn ptr_eq(&self, other: Node) -> bool {
90 unsafe { std::ptr::eq(self.ptr(), other.ptr()) }
91 }
92
93 unsafe fn ptr(&self) -> *const () {
94 match self {
95 Node::Program(n) => *n as *const _ as *const (),
96 Node::ImportStatement(n) => *n as *const _ as *const (),
97 Node::ExpressionStatement(n) => *n as *const _ as *const (),
98 Node::VariableDeclaration(n) => *n as *const _ as *const (),
99 Node::ReturnStatement(n) => *n as *const _ as *const (),
100 Node::VariableDeclarator(n) => *n as *const _ as *const (),
101 Node::Literal(n) => *n as *const _ as *const (),
102 Node::TagDeclarator(n) => *n as *const _ as *const (),
103 Node::Identifier(n) => *n as *const _ as *const (),
104 Node::BinaryExpression(n) => *n as *const _ as *const (),
105 Node::FunctionExpression(n) => *n as *const _ as *const (),
106 Node::CallExpression(n) => *n as *const _ as *const (),
107 Node::CallExpressionKw(n) => *n as *const _ as *const (),
108 Node::PipeExpression(n) => *n as *const _ as *const (),
109 Node::PipeSubstitution(n) => *n as *const _ as *const (),
110 Node::ArrayExpression(n) => *n as *const _ as *const (),
111 Node::ArrayRangeExpression(n) => *n as *const _ as *const (),
112 Node::ObjectExpression(n) => *n as *const _ as *const (),
113 Node::MemberExpression(n) => *n as *const _ as *const (),
114 Node::UnaryExpression(n) => *n as *const _ as *const (),
115 Node::Parameter(p) => *p as *const _ as *const (),
116 Node::ObjectProperty(n) => *n as *const _ as *const (),
117 Node::IfExpression(n) => *n as *const _ as *const (),
118 Node::ElseIf(n) => *n as *const _ as *const (),
119 Node::KclNone(n) => *n as *const _ as *const (),
120 Node::LabelledExpression(n) => *n as *const _ as *const (),
121 Node::Ascription(n) => *n as *const _ as *const (),
122 }
123 }
124}
125
/// Errors produced when converting a `Node` into another representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstNodeError {
    /// A `Node::KclNone` has no source range, so it cannot be converted into
    /// a `SourceRange`.
    NoSourceForAKclNone,
}

impl std::fmt::Display for AstNodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AstNodeError::NoSourceForAKclNone => f.write_str("a KclNone node has no source range"),
        }
    }
}

// Lets callers propagate this with `?` into `Box<dyn std::error::Error>`
// and other error-trait based handling.
impl std::error::Error for AstNodeError {}
132
133impl TryFrom<&Node<'_>> for SourceRange {
134 type Error = AstNodeError;
135
136 fn try_from(node: &Node) -> Result<Self, Self::Error> {
137 Ok(match node {
138 Node::Program(n) => SourceRange::from(*n),
139 Node::ImportStatement(n) => SourceRange::from(*n),
140 Node::ExpressionStatement(n) => SourceRange::from(*n),
141 Node::VariableDeclaration(n) => SourceRange::from(*n),
142 Node::ReturnStatement(n) => SourceRange::from(*n),
143 Node::VariableDeclarator(n) => SourceRange::from(*n),
144 Node::Literal(n) => SourceRange::from(*n),
145 Node::TagDeclarator(n) => SourceRange::from(*n),
146 Node::Identifier(n) => SourceRange::from(*n),
147 Node::BinaryExpression(n) => SourceRange::from(*n),
148 Node::FunctionExpression(n) => SourceRange::from(*n),
149 Node::CallExpression(n) => SourceRange::from(*n),
150 Node::CallExpressionKw(n) => SourceRange::from(*n),
151 Node::PipeExpression(n) => SourceRange::from(*n),
152 Node::PipeSubstitution(n) => SourceRange::from(*n),
153 Node::ArrayExpression(n) => SourceRange::from(*n),
154 Node::ArrayRangeExpression(n) => SourceRange::from(*n),
155 Node::ObjectExpression(n) => SourceRange::from(*n),
156 Node::MemberExpression(n) => SourceRange::from(*n),
157 Node::UnaryExpression(n) => SourceRange::from(*n),
158 Node::Parameter(p) => SourceRange::from(&p.identifier),
159 Node::ObjectProperty(n) => SourceRange::from(*n),
160 Node::IfExpression(n) => SourceRange::from(*n),
161 Node::LabelledExpression(n) => SourceRange::from(*n),
162 Node::Ascription(n) => SourceRange::from(*n),
163
164 Node::ElseIf(n) => SourceRange::new(n.cond.start(), n.cond.end(), n.cond.module_id()),
166
167 Node::KclNone(_) => return Err(Self::Error::NoSourceForAKclNone),
170 })
171 }
172}
173
174impl<'tree> From<&'tree types::BodyItem> for Node<'tree> {
175 fn from(node: &'tree types::BodyItem) -> Self {
176 match node {
177 types::BodyItem::ImportStatement(v) => v.as_ref().into(),
178 types::BodyItem::ExpressionStatement(v) => v.into(),
179 types::BodyItem::VariableDeclaration(v) => v.as_ref().into(),
180 types::BodyItem::ReturnStatement(v) => v.into(),
181 }
182 }
183}
184
185impl<'tree> From<&'tree types::Expr> for Node<'tree> {
186 fn from(node: &'tree types::Expr) -> Self {
187 match node {
188 types::Expr::Literal(lit) => lit.as_ref().into(),
189 types::Expr::TagDeclarator(tag) => tag.as_ref().into(),
190 types::Expr::Identifier(id) => id.as_ref().into(),
191 types::Expr::BinaryExpression(be) => be.as_ref().into(),
192 types::Expr::FunctionExpression(fe) => fe.as_ref().into(),
193 types::Expr::CallExpression(ce) => ce.as_ref().into(),
194 types::Expr::CallExpressionKw(ce) => ce.as_ref().into(),
195 types::Expr::PipeExpression(pe) => pe.as_ref().into(),
196 types::Expr::PipeSubstitution(ps) => ps.as_ref().into(),
197 types::Expr::ArrayExpression(ae) => ae.as_ref().into(),
198 types::Expr::ArrayRangeExpression(are) => are.as_ref().into(),
199 types::Expr::ObjectExpression(oe) => oe.as_ref().into(),
200 types::Expr::MemberExpression(me) => me.as_ref().into(),
201 types::Expr::UnaryExpression(ue) => ue.as_ref().into(),
202 types::Expr::IfExpression(e) => e.as_ref().into(),
203 types::Expr::LabelledExpression(e) => e.as_ref().into(),
204 types::Expr::AscribedExpression(e) => e.as_ref().into(),
205 types::Expr::None(n) => n.into(),
206 }
207 }
208}
209
210impl<'tree> From<&'tree types::BinaryPart> for Node<'tree> {
211 fn from(node: &'tree types::BinaryPart) -> Self {
212 match node {
213 types::BinaryPart::Literal(lit) => lit.as_ref().into(),
214 types::BinaryPart::Identifier(id) => id.as_ref().into(),
215 types::BinaryPart::BinaryExpression(be) => be.as_ref().into(),
216 types::BinaryPart::CallExpression(ce) => ce.as_ref().into(),
217 types::BinaryPart::CallExpressionKw(ce) => ce.as_ref().into(),
218 types::BinaryPart::UnaryExpression(ue) => ue.as_ref().into(),
219 types::BinaryPart::MemberExpression(me) => me.as_ref().into(),
220 types::BinaryPart::IfExpression(e) => e.as_ref().into(),
221 }
222 }
223}
224
225impl<'tree> From<&'tree types::MemberObject> for Node<'tree> {
226 fn from(node: &'tree types::MemberObject) -> Self {
227 match node {
228 types::MemberObject::MemberExpression(me) => me.as_ref().into(),
229 types::MemberObject::Identifier(id) => id.as_ref().into(),
230 }
231 }
232}
233
234impl<'tree> From<&'tree types::LiteralIdentifier> for Node<'tree> {
235 fn from(node: &'tree types::LiteralIdentifier) -> Self {
236 match node {
237 types::LiteralIdentifier::Identifier(id) => id.as_ref().into(),
238 types::LiteralIdentifier::Literal(lit) => lit.as_ref().into(),
239 }
240 }
241}
242
/// Implement `From<NodeRef<'a, types::$t>>` for the enum `$node<'a>`,
/// wrapping the reference in the variant `$node::$t`.
///
/// `$node` names the target enum. Earlier versions accepted it but ignored
/// it and hard-coded `Node`.
macro_rules! impl_from {
    ($node:ident, $t: ident) => {
        impl<'a> From<NodeRef<'a, types::$t>> for $node<'a> {
            fn from(v: NodeRef<'a, types::$t>) -> Self {
                $node::$t(v)
            }
        }
    };
}
252
/// Implement `From<&'a types::$t>` for the enum `$node<'a>`, wrapping the
/// plain reference in the variant `$node::$t`.
///
/// Used for variants that hold `&'a types::$t` rather than a `NodeRef`.
/// `$node` names the target enum. Earlier versions accepted it but ignored
/// it and hard-coded `Node`.
macro_rules! impl_from_ref {
    ($node:ident, $t: ident) => {
        impl<'a> From<&'a types::$t> for $node<'a> {
            fn from(v: &'a types::$t) -> Self {
                $node::$t(v)
            }
        }
    };
}
262
263impl_from!(Node, Program);
264impl_from!(Node, ImportStatement);
265impl_from!(Node, ExpressionStatement);
266impl_from!(Node, VariableDeclaration);
267impl_from!(Node, ReturnStatement);
268impl_from!(Node, VariableDeclarator);
269impl_from!(Node, Literal);
270impl_from!(Node, TagDeclarator);
271impl_from!(Node, Identifier);
272impl_from!(Node, BinaryExpression);
273impl_from!(Node, FunctionExpression);
274impl_from!(Node, CallExpression);
275impl_from!(Node, CallExpressionKw);
276impl_from!(Node, PipeExpression);
277impl_from!(Node, PipeSubstitution);
278impl_from!(Node, ArrayExpression);
279impl_from!(Node, ArrayRangeExpression);
280impl_from!(Node, ObjectExpression);
281impl_from!(Node, MemberExpression);
282impl_from!(Node, UnaryExpression);
283impl_from!(Node, ObjectProperty);
284impl_from_ref!(Node, Parameter);
285impl_from!(Node, IfExpression);
286impl_from!(Node, ElseIf);
287impl_from!(Node, LabelledExpression);
288impl_from!(Node, Ascription);
289impl_from!(Node, KclNone);
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a KCL source string, panicking if it does not parse.
    macro_rules! kcl {
        ( $kcl:expr ) => {{
            $crate::parsing::top_level_parse($kcl).unwrap()
        }};
    }

    /// `ptr_eq` holds for two nodes built from the same body item and fails
    /// for nodes built from different body items.
    #[test]
    fn check_ptr_eq() {
        let program = kcl!(
            "
const foo = 1
const bar = foo + 1

fn myfn = () => {
    const foo = 2
    sin(foo)
}
"
        );

        let first_item = &program.body[0];
        let second_item = &program.body[1];

        let first_node = Node::from(first_item);
        assert!(first_node.ptr_eq(Node::from(first_item)));
        assert!(!first_node.ptr_eq(Node::from(second_item)));
    }
}