// kcl_lib/walk/ast_node.rs

1use crate::{
2    parsing::ast::types::{self, NodeRef},
3    source_range::SourceRange,
4};
5
/// The "Node" type wraps all the AST elements we're able to find in a KCL
/// file. Tokens we walk through will be one of these.
///
/// Every variant holds a borrowed reference with lifetime `'a`, which is why
/// the enum can derive `Copy`. Most variants carry a [NodeRef]; `ElseIf`,
/// `Parameter` and `KclNone` hold a plain reference to the inner type instead.
#[derive(Copy, Clone, Debug)]
pub enum Node<'a> {
    Program(NodeRef<'a, types::Program>),

    // Variants produced when converting from a `types::BodyItem`.
    ImportStatement(NodeRef<'a, types::ImportStatement>),
    ExpressionStatement(NodeRef<'a, types::ExpressionStatement>),
    VariableDeclaration(NodeRef<'a, types::VariableDeclaration>),
    TypeDeclaration(NodeRef<'a, types::TypeDeclaration>),
    ReturnStatement(NodeRef<'a, types::ReturnStatement>),

    VariableDeclarator(NodeRef<'a, types::VariableDeclarator>),

    // Expression-like variants; most are reachable by converting from a
    // `types::Expr`, `types::BinaryPart`, `types::MemberObject` or
    // `types::LiteralIdentifier`.
    Literal(NodeRef<'a, types::Literal>),
    TagDeclarator(NodeRef<'a, types::TagDeclarator>),
    Identifier(NodeRef<'a, types::Identifier>),
    Name(NodeRef<'a, types::Name>),
    BinaryExpression(NodeRef<'a, types::BinaryExpression>),
    FunctionExpression(NodeRef<'a, types::FunctionExpression>),
    CallExpressionKw(NodeRef<'a, types::CallExpressionKw>),
    PipeExpression(NodeRef<'a, types::PipeExpression>),
    PipeSubstitution(NodeRef<'a, types::PipeSubstitution>),
    ArrayExpression(NodeRef<'a, types::ArrayExpression>),
    ArrayRangeExpression(NodeRef<'a, types::ArrayRangeExpression>),
    ObjectExpression(NodeRef<'a, types::ObjectExpression>),
    MemberExpression(NodeRef<'a, types::MemberExpression>),
    UnaryExpression(NodeRef<'a, types::UnaryExpression>),
    IfExpression(NodeRef<'a, types::IfExpression>),
    // Holds the bare `ElseIf` rather than a `NodeRef`; its source range is
    // derived from its `cond` field when converting to a `SourceRange`.
    ElseIf(&'a types::ElseIf),
    LabelledExpression(NodeRef<'a, types::LabelledExpression>),
    AscribedExpression(NodeRef<'a, types::AscribedExpression>),

    // Holds the bare `Parameter`; its source range is derived from its
    // `identifier` field when converting to a `SourceRange`.
    Parameter(&'a types::Parameter),

    ObjectProperty(NodeRef<'a, types::ObjectProperty>),

    // Holds the bare `KclNone`; converting this variant to a `SourceRange`
    // fails with `AstNodeError::NoSourceForAKclNone`.
    KclNone(&'a types::KclNone),
}
45
46impl Node<'_> {
47    /// Return the digest of the [Node], pulling the underlying Digest from
48    /// the AST types.
49    ///
50    /// The Digest type may change over time.
51    pub fn digest(&self) -> Option<[u8; 32]> {
52        match self {
53            Node::Program(n) => n.digest,
54            Node::ImportStatement(n) => n.digest,
55            Node::ExpressionStatement(n) => n.digest,
56            Node::VariableDeclaration(n) => n.digest,
57            Node::TypeDeclaration(n) => n.digest,
58            Node::ReturnStatement(n) => n.digest,
59            Node::VariableDeclarator(n) => n.digest,
60            Node::Literal(n) => n.digest,
61            Node::TagDeclarator(n) => n.digest,
62            Node::Identifier(n) => n.digest,
63            Node::Name(n) => n.digest,
64            Node::BinaryExpression(n) => n.digest,
65            Node::FunctionExpression(n) => n.digest,
66            Node::CallExpressionKw(n) => n.digest,
67            Node::PipeExpression(n) => n.digest,
68            Node::PipeSubstitution(n) => n.digest,
69            Node::ArrayExpression(n) => n.digest,
70            Node::ArrayRangeExpression(n) => n.digest,
71            Node::ObjectExpression(n) => n.digest,
72            Node::MemberExpression(n) => n.digest,
73            Node::UnaryExpression(n) => n.digest,
74            Node::Parameter(p) => p.digest,
75            Node::ObjectProperty(n) => n.digest,
76            Node::IfExpression(n) => n.digest,
77            Node::ElseIf(n) => n.digest,
78            Node::KclNone(n) => n.digest,
79            Node::LabelledExpression(n) => n.digest,
80            Node::AscribedExpression(n) => n.digest,
81        }
82    }
83
84    /// Check to see if this [Node] points to the same underlying specific
85    /// borrowed object as another [Node]. This is not the same as `Eq` or
86    /// even `PartialEq` -- anything that is `true` here is absolutely `Eq`,
87    /// but it's possible this node is `Eq` to another with this being `false`.
88    ///
89    /// This merely indicates that this [Node] specifically is the exact same
90    /// borrowed object as [Node].
91    pub fn ptr_eq(&self, other: Node) -> bool {
92        unsafe { std::ptr::eq(self.ptr(), other.ptr()) }
93    }
94
95    unsafe fn ptr(&self) -> *const () {
96        match self {
97            Node::Program(n) => *n as *const _ as *const (),
98            Node::ImportStatement(n) => *n as *const _ as *const (),
99            Node::ExpressionStatement(n) => *n as *const _ as *const (),
100            Node::VariableDeclaration(n) => *n as *const _ as *const (),
101            Node::TypeDeclaration(n) => *n as *const _ as *const (),
102            Node::ReturnStatement(n) => *n as *const _ as *const (),
103            Node::VariableDeclarator(n) => *n as *const _ as *const (),
104            Node::Literal(n) => *n as *const _ as *const (),
105            Node::TagDeclarator(n) => *n as *const _ as *const (),
106            Node::Identifier(n) => *n as *const _ as *const (),
107            Node::Name(n) => *n as *const _ as *const (),
108            Node::BinaryExpression(n) => *n as *const _ as *const (),
109            Node::FunctionExpression(n) => *n as *const _ as *const (),
110            Node::CallExpressionKw(n) => *n as *const _ as *const (),
111            Node::PipeExpression(n) => *n as *const _ as *const (),
112            Node::PipeSubstitution(n) => *n as *const _ as *const (),
113            Node::ArrayExpression(n) => *n as *const _ as *const (),
114            Node::ArrayRangeExpression(n) => *n as *const _ as *const (),
115            Node::ObjectExpression(n) => *n as *const _ as *const (),
116            Node::MemberExpression(n) => *n as *const _ as *const (),
117            Node::UnaryExpression(n) => *n as *const _ as *const (),
118            Node::Parameter(p) => *p as *const _ as *const (),
119            Node::ObjectProperty(n) => *n as *const _ as *const (),
120            Node::IfExpression(n) => *n as *const _ as *const (),
121            Node::ElseIf(n) => *n as *const _ as *const (),
122            Node::KclNone(n) => *n as *const _ as *const (),
123            Node::LabelledExpression(n) => *n as *const _ as *const (),
124            Node::AscribedExpression(n) => *n as *const _ as *const (),
125        }
126    }
127}
128
/// Returned during source_range conversion.
///
/// Implements [std::fmt::Display] and [std::error::Error] so it can be
/// propagated with `?` into boxed or otherwise generic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AstNodeError {
    /// Returned if we try and [SourceRange] a [types::KclNone].
    NoSourceForAKclNone,
}

impl std::fmt::Display for AstNodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AstNodeError::NoSourceForAKclNone => f.write_str("a KclNone has no source range"),
        }
    }
}

impl std::error::Error for AstNodeError {}
135
136impl TryFrom<&Node<'_>> for SourceRange {
137    type Error = AstNodeError;
138
139    fn try_from(node: &Node) -> Result<Self, Self::Error> {
140        Ok(match node {
141            Node::Program(n) => SourceRange::from(*n),
142            Node::ImportStatement(n) => SourceRange::from(*n),
143            Node::ExpressionStatement(n) => SourceRange::from(*n),
144            Node::VariableDeclaration(n) => SourceRange::from(*n),
145            Node::TypeDeclaration(n) => SourceRange::from(*n),
146            Node::ReturnStatement(n) => SourceRange::from(*n),
147            Node::VariableDeclarator(n) => SourceRange::from(*n),
148            Node::Literal(n) => SourceRange::from(*n),
149            Node::TagDeclarator(n) => SourceRange::from(*n),
150            Node::Identifier(n) => SourceRange::from(*n),
151            Node::Name(n) => SourceRange::from(*n),
152            Node::BinaryExpression(n) => SourceRange::from(*n),
153            Node::FunctionExpression(n) => SourceRange::from(*n),
154            Node::CallExpressionKw(n) => SourceRange::from(*n),
155            Node::PipeExpression(n) => SourceRange::from(*n),
156            Node::PipeSubstitution(n) => SourceRange::from(*n),
157            Node::ArrayExpression(n) => SourceRange::from(*n),
158            Node::ArrayRangeExpression(n) => SourceRange::from(*n),
159            Node::ObjectExpression(n) => SourceRange::from(*n),
160            Node::MemberExpression(n) => SourceRange::from(*n),
161            Node::UnaryExpression(n) => SourceRange::from(*n),
162            Node::Parameter(p) => SourceRange::from(&p.identifier),
163            Node::ObjectProperty(n) => SourceRange::from(*n),
164            Node::IfExpression(n) => SourceRange::from(*n),
165            Node::LabelledExpression(n) => SourceRange::from(*n),
166            Node::AscribedExpression(n) => SourceRange::from(*n),
167
168            // This is broken too
169            Node::ElseIf(n) => SourceRange::new(n.cond.start(), n.cond.end(), n.cond.module_id()),
170
171            // The KclNone type here isn't an actual node, so it has no
172            // start/end information.
173            Node::KclNone(_) => return Err(Self::Error::NoSourceForAKclNone),
174        })
175    }
176}
177
178impl<'tree> From<&'tree types::BodyItem> for Node<'tree> {
179    fn from(node: &'tree types::BodyItem) -> Self {
180        match node {
181            types::BodyItem::ImportStatement(v) => v.as_ref().into(),
182            types::BodyItem::ExpressionStatement(v) => v.into(),
183            types::BodyItem::VariableDeclaration(v) => v.as_ref().into(),
184            types::BodyItem::TypeDeclaration(v) => v.as_ref().into(),
185            types::BodyItem::ReturnStatement(v) => v.into(),
186        }
187    }
188}
189
190impl<'tree> From<&'tree types::Expr> for Node<'tree> {
191    fn from(node: &'tree types::Expr) -> Self {
192        match node {
193            types::Expr::Literal(lit) => lit.as_ref().into(),
194            types::Expr::TagDeclarator(tag) => tag.as_ref().into(),
195            types::Expr::Name(id) => id.as_ref().into(),
196            types::Expr::BinaryExpression(be) => be.as_ref().into(),
197            types::Expr::FunctionExpression(fe) => fe.as_ref().into(),
198            types::Expr::CallExpressionKw(ce) => ce.as_ref().into(),
199            types::Expr::PipeExpression(pe) => pe.as_ref().into(),
200            types::Expr::PipeSubstitution(ps) => ps.as_ref().into(),
201            types::Expr::ArrayExpression(ae) => ae.as_ref().into(),
202            types::Expr::ArrayRangeExpression(are) => are.as_ref().into(),
203            types::Expr::ObjectExpression(oe) => oe.as_ref().into(),
204            types::Expr::MemberExpression(me) => me.as_ref().into(),
205            types::Expr::UnaryExpression(ue) => ue.as_ref().into(),
206            types::Expr::IfExpression(e) => e.as_ref().into(),
207            types::Expr::LabelledExpression(e) => e.as_ref().into(),
208            types::Expr::AscribedExpression(e) => e.as_ref().into(),
209            types::Expr::None(n) => n.into(),
210        }
211    }
212}
213
214impl<'tree> From<&'tree types::BinaryPart> for Node<'tree> {
215    fn from(node: &'tree types::BinaryPart) -> Self {
216        match node {
217            types::BinaryPart::Literal(lit) => lit.as_ref().into(),
218            types::BinaryPart::Name(id) => id.as_ref().into(),
219            types::BinaryPart::BinaryExpression(be) => be.as_ref().into(),
220            types::BinaryPart::CallExpressionKw(ce) => ce.as_ref().into(),
221            types::BinaryPart::UnaryExpression(ue) => ue.as_ref().into(),
222            types::BinaryPart::MemberExpression(me) => me.as_ref().into(),
223            types::BinaryPart::IfExpression(e) => e.as_ref().into(),
224            types::BinaryPart::AscribedExpression(e) => e.as_ref().into(),
225        }
226    }
227}
228
229impl<'tree> From<&'tree types::MemberObject> for Node<'tree> {
230    fn from(node: &'tree types::MemberObject) -> Self {
231        match node {
232            types::MemberObject::MemberExpression(me) => me.as_ref().into(),
233            types::MemberObject::Identifier(id) => id.as_ref().into(),
234        }
235    }
236}
237
238impl<'tree> From<&'tree types::LiteralIdentifier> for Node<'tree> {
239    fn from(node: &'tree types::LiteralIdentifier) -> Self {
240        match node {
241            types::LiteralIdentifier::Identifier(id) => id.as_ref().into(),
242            types::LiteralIdentifier::Literal(lit) => lit.as_ref().into(),
243        }
244    }
245}
246
/// Implement `From<NodeRef<'a, types::$t>>` for the enum `$node`, wrapping the
/// reference in the variant of the same name (`$node::$t`).
///
/// `$node` must name an enum with a single lifetime parameter and a
/// one-field tuple variant called `$t`.
macro_rules! impl_from {
    ($node:ident, $t:ident) => {
        impl<'a> From<NodeRef<'a, types::$t>> for $node<'a> {
            fn from(v: NodeRef<'a, types::$t>) -> Self {
                // Use the enum the caller named rather than a hard-coded one,
                // so the first macro argument is actually honoured.
                $node::$t(v)
            }
        }
    };
}
256
/// Implement `From<&'a types::$t>` for the enum `$node`, wrapping the plain
/// reference in the variant of the same name (`$node::$t`).
///
/// `$node` must name an enum with a single lifetime parameter and a
/// one-field tuple variant called `$t` that holds `&'a types::$t`.
macro_rules! impl_from_ref {
    ($node:ident, $t:ident) => {
        impl<'a> From<&'a types::$t> for $node<'a> {
            fn from(v: &'a types::$t) -> Self {
                // Use the enum the caller named rather than a hard-coded one,
                // so the first macro argument is actually honoured.
                $node::$t(v)
            }
        }
    };
}
266
267impl_from!(Node, Program);
268impl_from!(Node, ImportStatement);
269impl_from!(Node, ExpressionStatement);
270impl_from!(Node, VariableDeclaration);
271impl_from!(Node, TypeDeclaration);
272impl_from!(Node, ReturnStatement);
273impl_from!(Node, VariableDeclarator);
274impl_from!(Node, Literal);
275impl_from!(Node, TagDeclarator);
276impl_from!(Node, Identifier);
277impl_from!(Node, Name);
278impl_from!(Node, BinaryExpression);
279impl_from!(Node, FunctionExpression);
280impl_from!(Node, CallExpressionKw);
281impl_from!(Node, PipeExpression);
282impl_from!(Node, PipeSubstitution);
283impl_from!(Node, ArrayExpression);
284impl_from!(Node, ArrayRangeExpression);
285impl_from!(Node, ObjectExpression);
286impl_from!(Node, MemberExpression);
287impl_from!(Node, UnaryExpression);
288impl_from!(Node, ObjectProperty);
289impl_from_ref!(Node, Parameter);
290impl_from!(Node, IfExpression);
291impl_from!(Node, ElseIf);
292impl_from!(Node, LabelledExpression);
293impl_from!(Node, AscribedExpression);
294impl_from!(Node, KclNone);
295
#[cfg(test)]
mod tests {
    use super::*;

    /// KCL program with two top-level declarations followed by a function.
    const SOURCE: &str = "
foo = 1
bar = foo + 1

fn myfn() {
    foo = 2
    sin(foo)
}
";

    /// A node is pointer-equal to another node built from the same body item,
    /// and not pointer-equal to a node built from a different one.
    #[test]
    fn check_ptr_eq() {
        let program = crate::parsing::top_level_parse(SOURCE).unwrap();

        let first_item = &program.body[0];
        let second_item = &program.body[1];

        let first = Node::from(first_item);
        assert!(first.ptr_eq(Node::from(first_item)));
        assert!(!first.ptr_eq(Node::from(second_item)));
    }
}