Skip to main content

gdscript_syntax/
ast.rs

//! WS4 — the typed AST.
//!
//! A thin typed *view* over the lossless CST (the rust-analyzer model). An AST node
//! just wraps a red [`GdNode`] of the matching [`SyntaxKind`], and its accessors are
//! filtered child lookups. Wrapping a node copies no tree data; only the text accessors
//! allocate, because they return owned `String`s. Since [`GdNode`] is a *resolved* node
//! (it carries the interner), text accessors need no extra resolver argument — e.g.
//! `Name::text()` takes only `&self`.
8
9use cstree::util::NodeOrToken;
10
11use crate::SyntaxKind;
12use crate::syntax_kind::{GdNode, GdToken};
13
14/// A typed node: a checked view over a red node of one [`SyntaxKind`].
15pub trait AstNode: Sized {
16    /// Whether a node of `kind` can be viewed as `Self`.
17    fn can_cast(kind: SyntaxKind) -> bool;
18    /// View `node` as `Self`, if its kind matches.
19    fn cast(node: GdNode) -> Option<Self>;
20    /// The underlying red node.
21    fn syntax(&self) -> &GdNode;
22}
23
/// Declares `pub struct $name(GdNode)` and implements [`AstNode`] for it.
///
/// The struct name doubles as the [`SyntaxKind`] variant it accepts, so each wrapper
/// matches exactly one kind. Leading attributes (including `///` docs) are forwarded
/// onto the generated struct.
macro_rules! ast_node {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone)]
        pub struct $name(GdNode);

        impl AstNode for $name {
            fn can_cast(kind: SyntaxKind) -> bool {
                kind == SyntaxKind::$name
            }

            fn cast(node: GdNode) -> Option<Self> {
                // Reuse the kind test above so the two can never disagree.
                if !Self::can_cast(node.kind()) {
                    return None;
                }
                Some(Self(node))
            }

            fn syntax(&self) -> &GdNode {
                let Self(inner) = self;
                inner
            }
        }
    };
}
48
49ast_node!(
50    /// The whole file.
51    SourceFile
52);
53ast_node!(ClassNameDecl);
54ast_node!(ExtendsClause);
55ast_node!(Annotation);
56ast_node!(FuncDecl);
57ast_node!(VarDecl);
58ast_node!(ConstDecl);
59ast_node!(EnumDecl);
60ast_node!(EnumVariant);
61ast_node!(SignalDecl);
62ast_node!(InnerClassDecl);
63ast_node!(ClassBody);
64ast_node!(ParamList);
65ast_node!(Param);
66ast_node!(Block);
67ast_node!(TypeRef);
68ast_node!(
69    /// A declaration's name (wraps the declared identifier).
70    Name
71);
72
73// ---- generic navigation helpers -------------------------------------------------
74
75/// The first child node castable to `N`.
76fn child<N: AstNode>(node: &GdNode) -> Option<N> {
77    node.children().find_map(|c| N::cast(c.clone()))
78}
79
80/// All child nodes castable to `N`.
81fn children<N: AstNode>(node: &GdNode) -> impl Iterator<Item = N> + '_ {
82    node.children().filter_map(|c| N::cast(c.clone()))
83}
84
85/// Whether `node` has a direct child token of `kind`.
86fn has_token(node: &GdNode, kind: SyntaxKind) -> bool {
87    node.children_with_tokens()
88        .filter_map(NodeOrToken::into_token)
89        .any(|t| t.kind() == kind)
90}
91
92/// The text of the first direct child token of `kind`.
93fn token_text(node: &GdNode, kind: SyntaxKind) -> Option<String> {
94    node.children_with_tokens()
95        .filter_map(NodeOrToken::into_token)
96        .find(|t| t.kind() == kind)
97        .map(|t| t.text().to_owned())
98}
99
100// ---- accessors ------------------------------------------------------------------
101
102impl Name {
103    /// The identifier text.
104    #[must_use]
105    pub fn text(&self) -> Option<String> {
106        token_text(&self.0, SyntaxKind::Ident)
107    }
108}
109
110impl SourceFile {
111    /// The top-level declarations, in source order.
112    pub fn decls(&self) -> impl Iterator<Item = Decl> + '_ {
113        self.0.children().filter_map(|c| Decl::cast(c.clone()))
114    }
115}
116
117impl FuncDecl {
118    /// The function name.
119    #[must_use]
120    pub fn name(&self) -> Option<Name> {
121        child(&self.0)
122    }
123    /// The parameter list.
124    #[must_use]
125    pub fn param_list(&self) -> Option<ParamList> {
126        child(&self.0)
127    }
128    /// The body block.
129    #[must_use]
130    pub fn body(&self) -> Option<Block> {
131        child(&self.0)
132    }
133    /// The declared return type, if any (the `TypeRef` after `->`).
134    #[must_use]
135    pub fn return_type(&self) -> Option<TypeRef> {
136        child(&self.0)
137    }
138    /// Whether this is a `static func`.
139    #[must_use]
140    pub fn is_static(&self) -> bool {
141        has_token(&self.0, SyntaxKind::StaticKw)
142    }
143}
144
145impl ParamList {
146    /// The parameters (excludes vararg rest params).
147    pub fn params(&self) -> impl Iterator<Item = Param> + '_ {
148        children(&self.0)
149    }
150}
151
152impl Param {
153    /// The parameter name.
154    #[must_use]
155    pub fn name(&self) -> Option<Name> {
156        child(&self.0)
157    }
158    /// The declared type, if any.
159    #[must_use]
160    pub fn type_ref(&self) -> Option<TypeRef> {
161        child(&self.0)
162    }
163}
164
165impl VarDecl {
166    /// The variable name.
167    #[must_use]
168    pub fn name(&self) -> Option<Name> {
169        child(&self.0)
170    }
171    /// The declared type, if any.
172    #[must_use]
173    pub fn type_ref(&self) -> Option<TypeRef> {
174        child(&self.0)
175    }
176    /// Whether this is a `static var`.
177    #[must_use]
178    pub fn is_static(&self) -> bool {
179        has_token(&self.0, SyntaxKind::StaticKw)
180    }
181}
182
183impl ConstDecl {
184    /// The constant name.
185    #[must_use]
186    pub fn name(&self) -> Option<Name> {
187        child(&self.0)
188    }
189}
190
191impl EnumDecl {
192    /// The enum's name, if it is a named enum.
193    #[must_use]
194    pub fn name(&self) -> Option<Name> {
195        child(&self.0)
196    }
197    /// The enum variants.
198    pub fn variants(&self) -> impl Iterator<Item = EnumVariant> + '_ {
199        children(&self.0)
200    }
201}
202
203impl EnumVariant {
204    /// The variant name.
205    #[must_use]
206    pub fn text(&self) -> Option<String> {
207        token_text(&self.0, SyntaxKind::Ident)
208    }
209}
210
211impl SignalDecl {
212    /// The signal name.
213    #[must_use]
214    pub fn name(&self) -> Option<Name> {
215        child(&self.0)
216    }
217    /// The typed parameter list, if any.
218    #[must_use]
219    pub fn param_list(&self) -> Option<ParamList> {
220        child(&self.0)
221    }
222}
223
224impl ClassNameDecl {
225    /// The registered global class name.
226    #[must_use]
227    pub fn name(&self) -> Option<Name> {
228        child(&self.0)
229    }
230}
231
232impl InnerClassDecl {
233    /// The inner class name.
234    #[must_use]
235    pub fn name(&self) -> Option<Name> {
236        child(&self.0)
237    }
238    /// The class body (its members), if present.
239    #[must_use]
240    pub fn body(&self) -> Option<ClassBody> {
241        child(&self.0)
242    }
243}
244
245impl ClassBody {
246    /// The member declarations.
247    pub fn decls(&self) -> impl Iterator<Item = Decl> + '_ {
248        self.0.children().filter_map(|c| Decl::cast(c.clone()))
249    }
250}
251
252impl Annotation {
253    /// The annotation name (the identifier after `@`).
254    #[must_use]
255    pub fn name(&self) -> Option<String> {
256        token_text(&self.0, SyntaxKind::Ident)
257    }
258}
259
260impl TypeRef {
261    /// The leading type identifier (e.g. `int`, `Array`).
262    #[must_use]
263    pub fn text(&self) -> Option<String> {
264        self.0
265            .children_with_tokens()
266            .filter_map(NodeOrToken::into_token)
267            .find(|t| matches!(t.kind(), SyntaxKind::Ident | SyntaxKind::VoidKw))
268            .map(|t| t.text().to_owned())
269    }
270}
271
272/// Any top-level or class-body declaration — the unit `document_symbols` iterates.
273#[derive(Debug, Clone)]
274pub enum Decl {
275    /// `class_name X`
276    ClassName(ClassNameDecl),
277    /// `func f(...)`
278    Func(FuncDecl),
279    /// `var x`
280    Var(VarDecl),
281    /// `const X`
282    Const(ConstDecl),
283    /// `enum E { ... }`
284    Enum(EnumDecl),
285    /// `signal s`
286    Signal(SignalDecl),
287    /// `class Inner: ...`
288    Class(InnerClassDecl),
289}
290
291impl Decl {
292    /// View a node as a declaration, if it is one.
293    #[must_use]
294    pub fn cast(node: GdNode) -> Option<Self> {
295        match node.kind() {
296            SyntaxKind::ClassNameDecl => ClassNameDecl::cast(node).map(Self::ClassName),
297            SyntaxKind::FuncDecl => FuncDecl::cast(node).map(Self::Func),
298            SyntaxKind::VarDecl => VarDecl::cast(node).map(Self::Var),
299            SyntaxKind::ConstDecl => ConstDecl::cast(node).map(Self::Const),
300            SyntaxKind::EnumDecl => EnumDecl::cast(node).map(Self::Enum),
301            SyntaxKind::SignalDecl => SignalDecl::cast(node).map(Self::Signal),
302            SyntaxKind::InnerClassDecl => InnerClassDecl::cast(node).map(Self::Class),
303            _ => None,
304        }
305    }
306
307    /// The declaration's underlying node.
308    #[must_use]
309    pub fn syntax(&self) -> &GdNode {
310        match self {
311            Self::ClassName(d) => d.syntax(),
312            Self::Func(d) => d.syntax(),
313            Self::Var(d) => d.syntax(),
314            Self::Const(d) => d.syntax(),
315            Self::Enum(d) => d.syntax(),
316            Self::Signal(d) => d.syntax(),
317            Self::Class(d) => d.syntax(),
318        }
319    }
320
321    /// The declaration's name, if it has one.
322    #[must_use]
323    pub fn name(&self) -> Option<String> {
324        let name = match self {
325            Self::ClassName(d) => d.name(),
326            Self::Func(d) => d.name(),
327            Self::Var(d) => d.name(),
328            Self::Const(d) => d.name(),
329            Self::Enum(d) => d.name(),
330            Self::Signal(d) => d.name(),
331            Self::Class(d) => d.name(),
332        };
333        name.and_then(|n| n.text())
334    }
335}
336
337/// A pre-order walk over every node in the tree (depth-first), for visitors that need
338/// to inspect all declarations/blocks (e.g. folding ranges).
339#[must_use]
340pub fn descendants(root: &GdNode) -> Vec<GdNode> {
341    let mut out = vec![root.clone()];
342    for child in root.children() {
343        out.extend(descendants(child));
344    }
345    out
346}
347
348/// The token (if any) at `offset`, right-biased — the completion-context probe.
349#[must_use]
350pub fn token_at(root: &GdNode, offset: text_size::TextSize) -> Option<GdToken> {
351    root.token_at_offset(offset).right_biased()
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;

    /// Reference recursive pre-order walk, used to cross-check [`descendants`].
    fn preorder_kinds(node: &GdNode, out: &mut Vec<SyntaxKind>) {
        out.push(node.kind());
        for child in node.children() {
            preorder_kinds(child, out);
        }
    }

    #[test]
    fn func_accessors() {
        let parse = parse("static func add(a: int, b := 1) -> int:\n\treturn a + b\n");
        let file = SourceFile::cast(parse.syntax_node()).unwrap();
        let func = file
            .decls()
            .find_map(|d| match d {
                Decl::Func(f) => Some(f),
                _ => None,
            })
            .unwrap();
        assert!(func.is_static());
        assert_eq!(func.name().and_then(|n| n.text()).as_deref(), Some("add"));
        assert_eq!(
            func.return_type().and_then(|t| t.text()).as_deref(),
            Some("int")
        );
        let params: Vec<_> = func
            .param_list()
            .unwrap()
            .params()
            .filter_map(|p| p.name().and_then(|n| n.text()))
            .collect();
        assert_eq!(params, vec!["a", "b"]);
        assert!(func.body().is_some());
    }

    #[test]
    fn declarations_are_enumerated() {
        let parse = parse(
            "class_name Foo\nconst K = 1\nvar x: int\nsignal s\nenum E { A, B }\nfunc f():\n\tpass\nclass Inner:\n\tvar y = 2\n",
        );
        let file = SourceFile::cast(parse.syntax_node()).unwrap();
        let names: Vec<_> = file.decls().map(|d| d.name().unwrap_or_default()).collect();
        assert_eq!(names, vec!["Foo", "K", "x", "s", "E", "f", "Inner"]);
    }

    #[test]
    fn decl_syntax_round_trips_through_cast() {
        let parse = parse(
            "class_name Foo\nconst K = 1\nvar x: int\nsignal s\nenum E { A, B }\nfunc f():\n\tpass\nclass Inner:\n\tvar y = 2\n",
        );
        let file = SourceFile::cast(parse.syntax_node()).unwrap();
        for decl in file.decls() {
            let again = Decl::cast(decl.syntax().clone())
                .expect("a Decl's own node must cast back to a Decl");
            assert_eq!(again.name(), decl.name());
        }
    }

    #[test]
    fn enum_variants_and_inner_class_members() {
        let parse =
            parse("enum E { A, B = 5, C }\nclass Inner:\n\tvar a = 1\n\tfunc m():\n\t\tpass\n");
        let file = SourceFile::cast(parse.syntax_node()).unwrap();

        let en = file
            .decls()
            .find_map(|d| match d {
                Decl::Enum(e) => Some(e),
                _ => None,
            })
            .unwrap();
        let variants: Vec<_> = en.variants().filter_map(|v| v.text()).collect();
        assert_eq!(variants, vec!["A", "B", "C"]);

        let inner = file
            .decls()
            .find_map(|d| match d {
                Decl::Class(c) => Some(c),
                _ => None,
            })
            .unwrap();
        let member_names: Vec<_> = inner
            .body()
            .unwrap()
            .decls()
            .map(|d| d.name().unwrap_or_default())
            .collect();
        assert_eq!(member_names, vec!["a", "m"]);
    }

    #[test]
    fn descendants_is_a_preorder_walk_including_root() {
        let parse =
            parse("enum E { A, B = 5, C }\nclass Inner:\n\tvar a = 1\n\tfunc m():\n\t\tpass\n");
        let root = parse.syntax_node();
        let nodes = descendants(&root);

        assert_eq!(nodes.first().map(|n| n.kind()), Some(SyntaxKind::SourceFile));

        let mut expected = Vec::new();
        preorder_kinds(&root, &mut expected);
        let actual: Vec<_> = nodes.iter().map(|n| n.kind()).collect();
        assert_eq!(actual, expected);

        let variant_count = nodes
            .iter()
            .filter(|n| n.kind() == SyntaxKind::EnumVariant)
            .count();
        assert_eq!(variant_count, 3);
    }

    #[test]
    fn token_at_offset_finds_identifier() {
        let src = "var hello = 1\n";
        let parse = parse(src);
        let node = parse.syntax_node();
        // offset 5 is inside "hello" (bytes 4..9); go through the public probe.
        let tok = token_at(&node, text_size::TextSize::new(5)).unwrap();
        assert_eq!(tok.kind(), SyntaxKind::Ident);
        assert_eq!(tok.text(), "hello");
    }

    #[test]
    fn token_at_is_right_biased_on_a_token_boundary() {
        let parse = parse("var hello = 1\n");
        let node = parse.syntax_node();
        // offset 4 is exactly between the space after `var` and `hello`.
        let tok = token_at(&node, text_size::TextSize::new(4)).unwrap();
        assert_eq!(tok.kind(), SyntaxKind::Ident);
        assert_eq!(tok.text(), "hello");
    }
}