kcl_lib/walk/
ast_node.rs

1use crate::{
2    parsing::ast::types::{self, NodeRef},
3    source_range::SourceRange,
4};
5
6/// The "Node" type wraps all the AST elements we're able to find in a KCL
7/// file. Tokens we walk through will be one of these.
8#[derive(Copy, Clone, Debug)]
9pub enum Node<'a> {
10    Program(NodeRef<'a, types::Program>),
11
12    ImportStatement(NodeRef<'a, types::ImportStatement>),
13    ExpressionStatement(NodeRef<'a, types::ExpressionStatement>),
14    VariableDeclaration(NodeRef<'a, types::VariableDeclaration>),
15    ReturnStatement(NodeRef<'a, types::ReturnStatement>),
16
17    VariableDeclarator(NodeRef<'a, types::VariableDeclarator>),
18
19    Literal(NodeRef<'a, types::Literal>),
20    TagDeclarator(NodeRef<'a, types::TagDeclarator>),
21    Identifier(NodeRef<'a, types::Identifier>),
22    BinaryExpression(NodeRef<'a, types::BinaryExpression>),
23    FunctionExpression(NodeRef<'a, types::FunctionExpression>),
24    CallExpression(NodeRef<'a, types::CallExpression>),
25    CallExpressionKw(NodeRef<'a, types::CallExpressionKw>),
26    PipeExpression(NodeRef<'a, types::PipeExpression>),
27    PipeSubstitution(NodeRef<'a, types::PipeSubstitution>),
28    ArrayExpression(NodeRef<'a, types::ArrayExpression>),
29    ArrayRangeExpression(NodeRef<'a, types::ArrayRangeExpression>),
30    ObjectExpression(NodeRef<'a, types::ObjectExpression>),
31    MemberExpression(NodeRef<'a, types::MemberExpression>),
32    UnaryExpression(NodeRef<'a, types::UnaryExpression>),
33    IfExpression(NodeRef<'a, types::IfExpression>),
34    ElseIf(&'a types::ElseIf),
35    LabelledExpression(NodeRef<'a, types::LabelledExpression>),
36    Ascription(NodeRef<'a, types::Ascription>),
37
38    Parameter(&'a types::Parameter),
39
40    ObjectProperty(NodeRef<'a, types::ObjectProperty>),
41
42    KclNone(&'a types::KclNone),
43}
44
45impl Node<'_> {
46    /// Return the digest of the [Node], pulling the underlying Digest from
47    /// the AST types.
48    ///
49    /// The Digest type may change over time.
50    pub fn digest(&self) -> Option<[u8; 32]> {
51        match self {
52            Node::Program(n) => n.digest,
53            Node::ImportStatement(n) => n.digest,
54            Node::ExpressionStatement(n) => n.digest,
55            Node::VariableDeclaration(n) => n.digest,
56            Node::ReturnStatement(n) => n.digest,
57            Node::VariableDeclarator(n) => n.digest,
58            Node::Literal(n) => n.digest,
59            Node::TagDeclarator(n) => n.digest,
60            Node::Identifier(n) => n.digest,
61            Node::BinaryExpression(n) => n.digest,
62            Node::FunctionExpression(n) => n.digest,
63            Node::CallExpression(n) => n.digest,
64            Node::CallExpressionKw(n) => n.digest,
65            Node::PipeExpression(n) => n.digest,
66            Node::PipeSubstitution(n) => n.digest,
67            Node::ArrayExpression(n) => n.digest,
68            Node::ArrayRangeExpression(n) => n.digest,
69            Node::ObjectExpression(n) => n.digest,
70            Node::MemberExpression(n) => n.digest,
71            Node::UnaryExpression(n) => n.digest,
72            Node::Parameter(p) => p.digest,
73            Node::ObjectProperty(n) => n.digest,
74            Node::IfExpression(n) => n.digest,
75            Node::ElseIf(n) => n.digest,
76            Node::KclNone(n) => n.digest,
77            Node::LabelledExpression(n) => n.digest,
78            Node::Ascription(n) => n.digest,
79        }
80    }
81
82    /// Check to see if this [Node] points to the same underlying specific
83    /// borrowed object as another [Node]. This is not the same as `Eq` or
84    /// even `PartialEq` -- anything that is `true` here is absolutely `Eq`,
85    /// but it's possible this node is `Eq` to another with this being `false`.
86    ///
87    /// This merely indicates that this [Node] specifically is the exact same
88    /// borrowed object as [Node].
89    pub fn ptr_eq(&self, other: Node) -> bool {
90        unsafe { std::ptr::eq(self.ptr(), other.ptr()) }
91    }
92
93    unsafe fn ptr(&self) -> *const () {
94        match self {
95            Node::Program(n) => *n as *const _ as *const (),
96            Node::ImportStatement(n) => *n as *const _ as *const (),
97            Node::ExpressionStatement(n) => *n as *const _ as *const (),
98            Node::VariableDeclaration(n) => *n as *const _ as *const (),
99            Node::ReturnStatement(n) => *n as *const _ as *const (),
100            Node::VariableDeclarator(n) => *n as *const _ as *const (),
101            Node::Literal(n) => *n as *const _ as *const (),
102            Node::TagDeclarator(n) => *n as *const _ as *const (),
103            Node::Identifier(n) => *n as *const _ as *const (),
104            Node::BinaryExpression(n) => *n as *const _ as *const (),
105            Node::FunctionExpression(n) => *n as *const _ as *const (),
106            Node::CallExpression(n) => *n as *const _ as *const (),
107            Node::CallExpressionKw(n) => *n as *const _ as *const (),
108            Node::PipeExpression(n) => *n as *const _ as *const (),
109            Node::PipeSubstitution(n) => *n as *const _ as *const (),
110            Node::ArrayExpression(n) => *n as *const _ as *const (),
111            Node::ArrayRangeExpression(n) => *n as *const _ as *const (),
112            Node::ObjectExpression(n) => *n as *const _ as *const (),
113            Node::MemberExpression(n) => *n as *const _ as *const (),
114            Node::UnaryExpression(n) => *n as *const _ as *const (),
115            Node::Parameter(p) => *p as *const _ as *const (),
116            Node::ObjectProperty(n) => *n as *const _ as *const (),
117            Node::IfExpression(n) => *n as *const _ as *const (),
118            Node::ElseIf(n) => *n as *const _ as *const (),
119            Node::KclNone(n) => *n as *const _ as *const (),
120            Node::LabelledExpression(n) => *n as *const _ as *const (),
121            Node::Ascription(n) => *n as *const _ as *const (),
122        }
123    }
124}
125
/// Error returned when converting a [Node] into a [SourceRange] fails.
///
/// Implements [std::fmt::Display] and [std::error::Error] so it can be
/// propagated with `?` into boxed or wrapped error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstNodeError {
    /// Returned if we try and [SourceRange] a [types::KclNone].
    NoSourceForAKclNone,
}

impl std::fmt::Display for AstNodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AstNodeError::NoSourceForAKclNone => {
                f.write_str("no source range is available for a KclNone node")
            }
        }
    }
}

impl std::error::Error for AstNodeError {}
132
133impl TryFrom<&Node<'_>> for SourceRange {
134    type Error = AstNodeError;
135
136    fn try_from(node: &Node) -> Result<Self, Self::Error> {
137        Ok(match node {
138            Node::Program(n) => SourceRange::from(*n),
139            Node::ImportStatement(n) => SourceRange::from(*n),
140            Node::ExpressionStatement(n) => SourceRange::from(*n),
141            Node::VariableDeclaration(n) => SourceRange::from(*n),
142            Node::ReturnStatement(n) => SourceRange::from(*n),
143            Node::VariableDeclarator(n) => SourceRange::from(*n),
144            Node::Literal(n) => SourceRange::from(*n),
145            Node::TagDeclarator(n) => SourceRange::from(*n),
146            Node::Identifier(n) => SourceRange::from(*n),
147            Node::BinaryExpression(n) => SourceRange::from(*n),
148            Node::FunctionExpression(n) => SourceRange::from(*n),
149            Node::CallExpression(n) => SourceRange::from(*n),
150            Node::CallExpressionKw(n) => SourceRange::from(*n),
151            Node::PipeExpression(n) => SourceRange::from(*n),
152            Node::PipeSubstitution(n) => SourceRange::from(*n),
153            Node::ArrayExpression(n) => SourceRange::from(*n),
154            Node::ArrayRangeExpression(n) => SourceRange::from(*n),
155            Node::ObjectExpression(n) => SourceRange::from(*n),
156            Node::MemberExpression(n) => SourceRange::from(*n),
157            Node::UnaryExpression(n) => SourceRange::from(*n),
158            Node::Parameter(p) => SourceRange::from(&p.identifier),
159            Node::ObjectProperty(n) => SourceRange::from(*n),
160            Node::IfExpression(n) => SourceRange::from(*n),
161            Node::LabelledExpression(n) => SourceRange::from(*n),
162            Node::Ascription(n) => SourceRange::from(*n),
163
164            // This is broken too
165            Node::ElseIf(n) => SourceRange::new(n.cond.start(), n.cond.end(), n.cond.module_id()),
166
167            // The KclNone type here isn't an actual node, so it has no
168            // start/end information.
169            Node::KclNone(_) => return Err(Self::Error::NoSourceForAKclNone),
170        })
171    }
172}
173
174impl<'tree> From<&'tree types::BodyItem> for Node<'tree> {
175    fn from(node: &'tree types::BodyItem) -> Self {
176        match node {
177            types::BodyItem::ImportStatement(v) => v.as_ref().into(),
178            types::BodyItem::ExpressionStatement(v) => v.into(),
179            types::BodyItem::VariableDeclaration(v) => v.as_ref().into(),
180            types::BodyItem::ReturnStatement(v) => v.into(),
181        }
182    }
183}
184
185impl<'tree> From<&'tree types::Expr> for Node<'tree> {
186    fn from(node: &'tree types::Expr) -> Self {
187        match node {
188            types::Expr::Literal(lit) => lit.as_ref().into(),
189            types::Expr::TagDeclarator(tag) => tag.as_ref().into(),
190            types::Expr::Identifier(id) => id.as_ref().into(),
191            types::Expr::BinaryExpression(be) => be.as_ref().into(),
192            types::Expr::FunctionExpression(fe) => fe.as_ref().into(),
193            types::Expr::CallExpression(ce) => ce.as_ref().into(),
194            types::Expr::CallExpressionKw(ce) => ce.as_ref().into(),
195            types::Expr::PipeExpression(pe) => pe.as_ref().into(),
196            types::Expr::PipeSubstitution(ps) => ps.as_ref().into(),
197            types::Expr::ArrayExpression(ae) => ae.as_ref().into(),
198            types::Expr::ArrayRangeExpression(are) => are.as_ref().into(),
199            types::Expr::ObjectExpression(oe) => oe.as_ref().into(),
200            types::Expr::MemberExpression(me) => me.as_ref().into(),
201            types::Expr::UnaryExpression(ue) => ue.as_ref().into(),
202            types::Expr::IfExpression(e) => e.as_ref().into(),
203            types::Expr::LabelledExpression(e) => e.as_ref().into(),
204            types::Expr::AscribedExpression(e) => e.as_ref().into(),
205            types::Expr::None(n) => n.into(),
206        }
207    }
208}
209
210impl<'tree> From<&'tree types::BinaryPart> for Node<'tree> {
211    fn from(node: &'tree types::BinaryPart) -> Self {
212        match node {
213            types::BinaryPart::Literal(lit) => lit.as_ref().into(),
214            types::BinaryPart::Identifier(id) => id.as_ref().into(),
215            types::BinaryPart::BinaryExpression(be) => be.as_ref().into(),
216            types::BinaryPart::CallExpression(ce) => ce.as_ref().into(),
217            types::BinaryPart::CallExpressionKw(ce) => ce.as_ref().into(),
218            types::BinaryPart::UnaryExpression(ue) => ue.as_ref().into(),
219            types::BinaryPart::MemberExpression(me) => me.as_ref().into(),
220            types::BinaryPart::IfExpression(e) => e.as_ref().into(),
221        }
222    }
223}
224
225impl<'tree> From<&'tree types::MemberObject> for Node<'tree> {
226    fn from(node: &'tree types::MemberObject) -> Self {
227        match node {
228            types::MemberObject::MemberExpression(me) => me.as_ref().into(),
229            types::MemberObject::Identifier(id) => id.as_ref().into(),
230        }
231    }
232}
233
234impl<'tree> From<&'tree types::LiteralIdentifier> for Node<'tree> {
235    fn from(node: &'tree types::LiteralIdentifier) -> Self {
236        match node {
237            types::LiteralIdentifier::Identifier(id) => id.as_ref().into(),
238            types::LiteralIdentifier::Literal(lit) => lit.as_ref().into(),
239        }
240    }
241}
242
243macro_rules! impl_from {
244    ($node:ident, $t: ident) => {
245        impl<'a> From<NodeRef<'a, types::$t>> for Node<'a> {
246            fn from(v: NodeRef<'a, types::$t>) -> Self {
247                Node::$t(v)
248            }
249        }
250    };
251}
252
253macro_rules! impl_from_ref {
254    ($node:ident, $t: ident) => {
255        impl<'a> From<&'a types::$t> for Node<'a> {
256            fn from(v: &'a types::$t) -> Self {
257                Node::$t(v)
258            }
259        }
260    };
261}
262
263impl_from!(Node, Program);
264impl_from!(Node, ImportStatement);
265impl_from!(Node, ExpressionStatement);
266impl_from!(Node, VariableDeclaration);
267impl_from!(Node, ReturnStatement);
268impl_from!(Node, VariableDeclarator);
269impl_from!(Node, Literal);
270impl_from!(Node, TagDeclarator);
271impl_from!(Node, Identifier);
272impl_from!(Node, BinaryExpression);
273impl_from!(Node, FunctionExpression);
274impl_from!(Node, CallExpression);
275impl_from!(Node, CallExpressionKw);
276impl_from!(Node, PipeExpression);
277impl_from!(Node, PipeSubstitution);
278impl_from!(Node, ArrayExpression);
279impl_from!(Node, ArrayRangeExpression);
280impl_from!(Node, ObjectExpression);
281impl_from!(Node, MemberExpression);
282impl_from!(Node, UnaryExpression);
283impl_from!(Node, ObjectProperty);
284impl_from_ref!(Node, Parameter);
285impl_from!(Node, IfExpression);
286impl_from!(Node, ElseIf);
287impl_from!(Node, LabelledExpression);
288impl_from!(Node, Ascription);
289impl_from!(Node, KclNone);
290
#[cfg(test)]
mod tests {
    use super::*;

    /// `ptr_eq` is true for two handles to the same body item and false for
    /// handles to different body items.
    #[test]
    fn check_ptr_eq() {
        const SOURCE: &str = "
const foo = 1
const bar = foo + 1

fn myfn = () => {
    const foo = 2
    sin(foo)
}
";
        let program = crate::parsing::top_level_parse(SOURCE).unwrap();
        let body = &program.body;

        let first = Node::from(&body[0]);
        assert!(first.ptr_eq(Node::from(&body[0])));
        assert!(!first.ptr_eq(Node::from(&body[1])));
    }
}