Skip to main content

cairo_lang_syntax_codegen/
generator.rs

1use std::fs;
2use std::path::PathBuf;
3
4use genco::prelude::*;
5use xshell::Shell;
6
7use crate::cairo_spec::get_spec;
8use crate::spec::{Member, Node, NodeKind, Variant, Variants};
9
10pub fn project_root() -> PathBuf {
11    // This is the directory of Cargo.toml of the syntax_codegen crate.
12    let dir = env!("CARGO_MANIFEST_DIR");
13    // Pop the "/crates/cairo-lang-syntax-codegen" suffix.
14    let res = PathBuf::from(dir).parent().unwrap().parent().unwrap().to_owned();
15    assert!(res.join("Cargo.toml").exists(), "Could not find project root directory.");
16    res
17}
18
/// Writes `content` to `filename`, skipping the write entirely when the file already holds exactly
/// `content`.
///
/// # Panics
///
/// Panics with a message naming the path if the file cannot be written.
pub fn ensure_file_content(filename: PathBuf, content: String) {
    // A missing or unreadable file counts as "different" and falls through to the write.
    if fs::read_to_string(&filename).is_ok_and(|old_contents| old_contents == content) {
        return;
    }

    fs::write(&filename, content)
        .unwrap_or_else(|err| panic!("Failed to write `{}`: {err}", filename.display()));
}
28
29pub fn get_codes() -> Vec<(String, String)> {
30    vec![
31        (
32            "crates/cairo-lang-syntax/src/node/ast.rs".into(),
33            reformat_rust_code(generate_ast_code().to_string().unwrap()),
34        ),
35        (
36            "crates/cairo-lang-syntax/src/node/kind.rs".into(),
37            reformat_rust_code(generate_kinds_code().to_string().unwrap()),
38        ),
39        (
40            "crates/cairo-lang-syntax/src/node/key_fields.rs".into(),
41            reformat_rust_code(generate_key_fields_code().to_string().unwrap()),
42        ),
43    ]
44}
45
46pub fn reformat_rust_code(text: String) -> String {
47    // Since rustfmt is used with nightly features, it takes 2 runs to reach a fixed point.
48    reformat_rust_code_inner(reformat_rust_code_inner(text))
49}
50pub fn reformat_rust_code_inner(text: String) -> String {
51    let sh = Shell::new().unwrap();
52    let cmd = sh.cmd("rustfmt").env("RUSTUP_TOOLCHAIN", "nightly-2025-12-05");
53    let cmd_with_args = cmd.arg("--config-path").arg(project_root().join("rustfmt.toml"));
54    let mut stdout = cmd_with_args.stdin(text).read().unwrap();
55    if !stdout.ends_with('\n') {
56        stdout.push('\n');
57    }
58    stdout
59}
60
61fn generate_kinds_code() -> rust::Tokens {
62    let spec = get_spec();
63    let mut tokens = quote! {
64        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
65        use core::fmt;
66        use serde::{Deserialize, Serialize};
67    };
68
69    // Definition of SyntaxKind.
70    let kinds = name_tokens(&spec, |k| !matches!(k, NodeKind::Enum { .. }));
71    let token_kinds = name_tokens(&spec, |k| matches!(k, NodeKind::Token { .. }));
72    let keyword_token_kinds =
73        name_tokens(&spec, |k| matches!(k, NodeKind::Token { is_keyword } if *is_keyword));
74    let terminal_kinds = name_tokens(&spec, |k| matches!(k, NodeKind::Terminal { .. }));
75    let keyword_terminal_kinds =
76        name_tokens(&spec, |k| matches!(k, NodeKind::Terminal { is_keyword, .. } if *is_keyword));
77    let missing_kinds = spec.iter().filter_map(|n| {
78        if let NodeKind::Enum { missing_variant, .. } = &n.kind {
79            missing_variant.as_ref().map(|v| v.kind.as_str())
80        } else {
81            None
82        }
83    });
84
85    tokens.extend(quote! {
86        #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, salsa::Update, cairo_lang_proc_macros::HeapSize)]
87        pub enum SyntaxKind {
88            $(for t in kinds => $t,)
89        }
90        impl SyntaxKind {
91            pub fn is_token(&self) -> bool {
92                matches!(
93                    *self,
94                    $(for t in token_kinds join ( | ) => SyntaxKind::$t)
95                )
96            }
97            pub fn is_terminal(&self) -> bool {
98                matches!(
99                    *self,
100                    $(for t in terminal_kinds join ( | ) => SyntaxKind::$t)
101                )
102            }
103            pub fn is_keyword_token(&self) -> bool {
104                matches!(
105                    *self,
106                    $(for t in keyword_token_kinds join ( | ) => SyntaxKind::$t)
107                )
108            }
109            pub fn is_keyword_terminal(&self) -> bool {
110                matches!(
111                    *self,
112                    $(for t in keyword_terminal_kinds join ( | ) => SyntaxKind::$t)
113                )
114            }
115            pub fn is_missing(&self) -> bool {
116                matches!(
117                    *self,
118                    $(for t in missing_kinds join ( | ) => SyntaxKind::$t)
119                )
120            }
121        }
122        impl fmt::Display for SyntaxKind {
123            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124                write!(f, "{self:?}")
125            }
126        }
127    });
128    tokens
129}
130
131/// Returns an iterator to the names of the tokens matching `predicate`.
132fn name_tokens(spec: &[Node], predicate: impl Fn(&NodeKind) -> bool) -> impl Iterator<Item = &str> {
133    spec.iter().filter(move |n| predicate(&n.kind)).map(|n| n.name.as_str())
134}
135
136fn generate_key_fields_code() -> rust::Tokens {
137    let spec = get_spec();
138    let mut arms = rust::Tokens::new();
139
140    for Node { name, kind } in spec {
141        match kind {
142            NodeKind::Struct { members } | NodeKind::Terminal { members, .. } => {
143                let mut fields = rust::Tokens::new();
144                let mut key_fields_range = 0..0;
145                for (i, member) in members.into_iter().enumerate() {
146                    let field_name = member.name;
147                    if member.key {
148                        if key_fields_range.is_empty() {
149                            key_fields_range = i..(i + 1);
150                        } else {
151                            assert_eq!(key_fields_range.end, i, "Key fields must be contiguous.");
152                            key_fields_range.end = i + 1;
153                        }
154                        if !fields.is_empty() {
155                            fields.extend(quote! { $(", ") });
156                        }
157                        fields.extend(quote!($field_name));
158                    }
159                }
160                if !fields.is_empty() {
161                    arms.extend(quote! {
162                        $("\n// Key fields:") $fields.$("\n")
163                    });
164                }
165                let key_fields_range =
166                    format!("{}..{}", key_fields_range.start, key_fields_range.end);
167                arms.extend(quote! {
168                    SyntaxKind::$name => $key_fields_range,
169                });
170            }
171            NodeKind::List { .. } | NodeKind::SeparatedList { .. } | NodeKind::Token { .. } => {
172                arms.extend(quote! {
173                    SyntaxKind::$name => 0..0,
174                });
175            }
176            NodeKind::Enum { .. } => {}
177        }
178    }
179    let tokens = quote! {
180        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
181        use super::kind::SyntaxKind;
182        $("/// Gets the vector of children ids that are the indexing key for this SyntaxKind.")
183        $("///")
184        $("/// Each SyntaxKind has some children that are defined in the spec to be its indexing key")
185        $("/// for its stable pointer. See [super::stable_ptr].")
186        pub fn key_fields_range(kind: SyntaxKind) -> core::ops::Range<usize> {
187            match kind {
188                $arms
189            }
190        }
191    };
192    tokens
193}
194
195fn generate_ast_code() -> rust::Tokens {
196    let spec = get_spec();
197    let mut tokens = quote! {
198        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
199        #![allow(clippy::match_single_binding)]
200        #![allow(clippy::too_many_arguments)]
201        #![allow(dead_code)]
202        #![allow(unused_variables)]
203        use std::ops::Deref;
204
205        use cairo_lang_filesystem::span::TextWidth;
206        use cairo_lang_filesystem::ids::SmolStrId;
207        use cairo_lang_utils::{extract_matches, Intern};
208        use cairo_lang_proc_macros::HeapSize;
209
210        use salsa::Database;
211
212        use super::element_list::ElementList;
213        use super::green::GreenNodeDetails;
214        use super::kind::SyntaxKind;
215        use super::{
216            GreenId, GreenNode, SyntaxNode, SyntaxStablePtrId, Terminal, Token, TypedStablePtr,
217            TypedSyntaxNode,
218        };
219        #[path = "ast_ext.rs"]
220        mod ast_ext;
221    };
222    let spec_clone = spec.clone();
223    let all_tokens: Vec<_> =
224        spec_clone.iter().filter(|node| matches!(node.kind, NodeKind::Terminal { .. })).collect();
225    for Node { name, kind } in spec {
226        tokens.extend(match kind {
227            NodeKind::Enum { variants, missing_variant } => {
228                let variants_list = match variants {
229                    Variants::List(variants) => variants,
230                    Variants::AllTokens => all_tokens
231                        .iter()
232                        .map(|node| Variant { name: node.name.clone(), kind: node.name.clone() })
233                        .collect(),
234                };
235                gen_enum_code(name, variants_list, missing_variant)
236            }
237            NodeKind::Struct { members } => gen_struct_code(name, members, false),
238            NodeKind::Terminal { members, .. } => gen_struct_code(name, members, true),
239            NodeKind::Token { .. } => gen_token_code(name),
240            NodeKind::List { element_type } => gen_list_code(name, element_type),
241            NodeKind::SeparatedList { element_type, separator_type } => {
242                gen_separated_list_code(name, element_type, separator_type)
243            }
244        });
245    }
246    tokens
247}
248
/// Generates the typed AST code for a list node `name` whose children are `element_type` nodes.
///
/// The output contains:
/// - the node type: a newtype over `ElementList<'db, element_type, 1>` that derefs to it;
/// - an inherent `new_green`, which interns a node of kind `name` from element green nodes, with
///   width equal to the sum of the children's widths;
/// - the `{name}Ptr` stable-pointer type;
/// - the shared list code from `gen_common_list_code` (`{name}Green` and the `TypedSyntaxNode`
///   impl).
fn gen_list_code(name: String, element_type: String) -> rust::Tokens {
    // TODO(spapini): Change Deref to Borrow.
    let ptr_name = format!("{name}Ptr");
    let green_name = format!("{name}Green");
    let element_green_name = format!("{element_type}Green");
    let common_code = gen_common_list_code(&name, &green_name, &ptr_name);
    // NOTE(review): the `1` const argument of `ElementList` presumably is the child step (elements
    // only, vs. `2` for separated lists) — confirm against `element_list.rs`.
    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db>(ElementList<'db, $(&element_type)<'db>, 1>);
        impl<'db> Deref for $(&name)<'db>{
            type Target = ElementList<'db, $(&element_type)<'db>, 1>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        impl<'db> $(&name)<'db>{
            pub fn new_green(
                db: &'db dyn Database, children: &[$(&element_green_name)<'db>]
            ) -> $(&green_name)<'db> {
                let width = children.iter().map(|id|
                    id.0.long(db).width(db)).sum();
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Node {
                        children: children.iter().map(|x| x.0).collect(),
                        width,
                    },
                }.intern(db))
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        $common_code
    }
}
298
/// Generates the typed AST code for a separated-list node `name` whose children alternate between
/// `element_type` and `separator_type` nodes.
///
/// The output contains:
/// - the node type: a newtype over `ElementList<'db, element_type, 2>` that derefs to it;
/// - an inherent `new_green` taking `{name}ElementOrSeparatorGreen` children, with width equal to
///   the sum of the children's widths;
/// - the `{name}Ptr` stable-pointer type;
/// - the `{name}ElementOrSeparatorGreen` enum (`Separator`/`Element`), with `From` impls from both
///   green types and a private `id()` accessor;
/// - the shared list code from `gen_common_list_code`.
fn gen_separated_list_code(
    name: String,
    element_type: String,
    separator_type: String,
) -> rust::Tokens {
    // TODO(spapini): Change Deref to Borrow.
    let ptr_name = format!("{name}Ptr");
    let green_name = format!("{name}Green");
    let element_or_separator_green_name = format!("{name}ElementOrSeparatorGreen");
    let element_green_name = format!("{element_type}Green");
    let separator_green_name = format!("{separator_type}Green");
    let common_code = gen_common_list_code(&name, &green_name, &ptr_name);
    // NOTE(review): `new_green` does not check that elements and separators actually alternate;
    // presumably callers (the parser) guarantee it — confirm.
    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db>(ElementList<'db, $(&element_type)<'db>, 2>);
        impl<'db> Deref for $(&name)<'db>{
            type Target = ElementList<'db, $(&element_type)<'db>, 2>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        impl<'db> $(&name)<'db>{
            pub fn new_green(
                db: &'db dyn Database, children: &[$(&element_or_separator_green_name)<'db>]
            ) -> $(&green_name)<'db> {
                let width = children.iter().map(|id|
                    id.id().long(db).width(db)).sum();
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Node {
                        children: children.iter().map(|x| x.id()).collect(),
                        width,
                    },
                }.intern(db))
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub enum $(&element_or_separator_green_name)<'db> {
            Separator($(&separator_green_name)<'db>),
            Element($(&element_green_name)<'db>),
        }
        impl<'db> From<$(&separator_green_name)<'db>> for $(&element_or_separator_green_name)<'db> {
            fn from(value: $(&separator_green_name)<'db>) -> Self {
                $(&element_or_separator_green_name)::Separator(value)
            }
        }
        impl<'db> From<$(&element_green_name)<'db>> for $(&element_or_separator_green_name)<'db> {
            fn from(value: $(&element_green_name)<'db>) -> Self {
                $(&element_or_separator_green_name)::Element(value)
            }
        }
        impl<'db> $(&element_or_separator_green_name)<'db> {
            fn id(&self) -> GreenId<'db> {
                match self {
                    $(&element_or_separator_green_name)::Separator(green) => green.0,
                    $(&element_or_separator_green_name)::Element(green) => green.0,
                }
            }
        }
        $common_code
    }
}
377
/// Generates the code shared by list and separated-list nodes: the `green_name` newtype over
/// `GreenId`, and the `TypedSyntaxNode` impl for `name`.
///
/// Behaviour of the generated impl:
/// - `missing` interns an empty node of kind `name` with zero width.
/// - `from_syntax_node` wraps the node without checking its kind; `cast` checks the kind first.
/// - `as_syntax_node`/`stable_ptr` read `self.node`, which presumably resolves through the
///   type's `Deref` to `ElementList` — TODO confirm.
fn gen_common_list_code(name: &str, green_name: &str, ptr_name: &str) -> rust::Tokens {
    quote! {
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub struct $green_name<'db>(pub GreenId<'db>);
        impl<'db> TypedSyntaxNode<'db> for $name<'db> {
            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$name);
            type StablePtr = $ptr_name<'db>;
            type Green = $green_name<'db>;
            fn missing(db: &'db dyn Database) -> Self::Green {
                $green_name(
                    GreenNode {
                        kind: SyntaxKind::$name,
                        details: GreenNodeDetails::Node { children: [].into(), width: TextWidth::default() },
                    }.intern(db)
                )
            }
            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
                Self(ElementList::new(node))
            }
            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
                if node.kind(db) == SyntaxKind::$name {
                    Some(Self(ElementList::new(node)))
                } else {
                    None
                }
            }
            fn as_syntax_node(&self) -> SyntaxNode<'db> {
                self.node
            }
            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
                $ptr_name(self.node.stable_ptr(db))
            }
        }
    }
}
413
414fn gen_enum_code(
415    name: String,
416    variants: Vec<Variant>,
417    missing_variant: Option<Variant>,
418) -> rust::Tokens {
419    let ptr_name = format!("{name}Ptr");
420    let green_name = format!("{name}Green");
421    let mut enum_body = quote! {};
422    let mut from_node_body = quote! {};
423    let mut cast_body = quote! {};
424    let mut ptr_conversions = quote! {};
425    let mut green_conversions = quote! {};
426    for variant in &variants {
427        let n = &variant.name;
428        let k = &variant.kind;
429
430        enum_body.extend(quote! {
431            $n($k<'db>),
432        });
433        from_node_body.extend(quote! {
434            SyntaxKind::$k => $(&name)::$n($k::from_syntax_node(db, node)),
435        });
436        cast_body.extend(quote! {
437            SyntaxKind::$k => Some($(&name)::$n($k::from_syntax_node(db, node))),
438        });
439        let variant_ptr = format!("{k}Ptr");
440        ptr_conversions.extend(quote! {
441            impl<'db> From<$(&variant_ptr)<'db>> for $(&ptr_name)<'db> {
442                fn from(value: $(&variant_ptr)<'db>) -> Self {
443                    Self(value.0)
444                }
445            }
446        });
447        let variant_green = format!("{k}Green");
448        green_conversions.extend(quote! {
449            impl<'db> From<$(&variant_green)<'db>> for $(&green_name)<'db> {
450                fn from(value: $(&variant_green)<'db>) -> Self {
451                    Self(value.0)
452                }
453            }
454        });
455    }
456    let missing_body = match missing_variant {
457        Some(missing) => quote! {
458            $(&green_name)($(missing.kind)::missing(db).0)
459        },
460        None => quote! {
461            panic!("No missing variant.");
462        },
463    };
464    quote! {
465        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
466        pub enum $(&name)<'db>{
467            $enum_body
468        }
469        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
470        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
471        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
472            type SyntaxNode = $(&name)<'db>;
473            fn untyped(self) -> SyntaxStablePtrId<'db> {
474                self.0
475            }
476            fn lookup(&self, db: &'db dyn Database) -> Self::SyntaxNode {
477                $(&name)::from_syntax_node(db, self.0.lookup(db))
478            }
479        }
480        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
481            fn from(ptr: $(&ptr_name)<'db>) -> Self {
482                ptr.untyped()
483            }
484        }
485        $ptr_conversions
486        $green_conversions
487        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
488        pub struct $(&green_name)<'db>(pub GreenId<'db>);
489        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db>{
490            const OPTIONAL_KIND: Option<SyntaxKind> = None;
491            type StablePtr = $(&ptr_name)<'db>;
492            type Green = $(&green_name)<'db>;
493            fn missing(db: &'db dyn Database) -> Self::Green {
494                $missing_body
495            }
496            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
497                let kind = node.kind(db);
498                match kind{
499                    $from_node_body
500                    _ => panic!(
501                        "Unexpected syntax kind {:?} when constructing {}.",
502                        kind,
503                        $[str]($[const](&name))),
504                }
505            }
506            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
507                let kind = node.kind(db);
508                match kind {
509                    $cast_body
510                    _ => None,
511                }
512            }
513            fn as_syntax_node(&self) -> SyntaxNode<'db> {
514                match self {
515                    $(for v in &variants => $(&name)::$(&v.name)(x) => x.as_syntax_node(),)
516                }
517            }
518            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
519                $(&ptr_name)(self.as_syntax_node().stable_ptr(db))
520            }
521        }
522        impl<'db> $(&name)<'db> {
523            $("/// Checks if a kind of a variant of [")$(&name)$("].")
524            pub fn is_variant(kind: SyntaxKind) -> bool {
525                matches!(kind, $(for v in &variants join (|) => SyntaxKind::$(&v.kind)))
526            }
527        }
528    }
529}
530
/// Generates the typed AST code for a token node `name`.
///
/// The output contains:
/// - the node type, with a `Token` impl whose `new_green` interns a `GreenNodeDetails::Token`
///   green node of kind `name` holding `text`;
/// - the `{name}Ptr` stable-pointer type;
/// - the `{name}Green` type with a `text` accessor;
/// - the `TypedSyntaxNode` impl.
///
/// Two details of the generated impl:
/// - `missing` produces a token of kind `SyntaxKind::TokenMissing` with empty text, not a token
///   of kind `name`.
/// - `from_syntax_node`/`cast` only check that the green node is a token; they do not check its
///   kind.
fn gen_token_code(name: String) -> rust::Tokens {
    let green_name = format!("{name}Green");
    let ptr_name = format!("{name}Ptr");

    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db> {
            node: SyntaxNode<'db>,
        }
        impl<'db> Token<'db> for $(&name)<'db> {
            fn new_green(db: &'db dyn Database, text: SmolStrId<'db>) -> Self::Green {
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Token(text),
                }.intern(db))
            }
            fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
                *extract_matches!(&self.node.green_node(db).details,
                    GreenNodeDetails::Token)
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub struct $(&green_name)<'db>(pub GreenId<'db>);
        impl<'db> $(&green_name)<'db> {
            pub fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
                *extract_matches!(&self.0.long(db).details, GreenNodeDetails::Token)
            }
        }
        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db>{
            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$(&name));
            type StablePtr = $(&ptr_name)<'db>;
            type Green = $(&green_name)<'db>;
            fn missing(db: &'db dyn Database) -> Self::Green {
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::TokenMissing,
                    details: GreenNodeDetails::Token(SmolStrId::from(db, "")),
                }.intern(db))
            }
            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
                match node.green_node(db).details {
                    GreenNodeDetails::Token(_) => Self { node },
                    GreenNodeDetails::Node { .. } => panic!(
                        "Expected a token {:?}, not an internal node",
                        SyntaxKind::$(&name)
                    ),
                }
            }
            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
                match node.green_node(db).details {
                    GreenNodeDetails::Token(_) => Some(Self { node }),
                    GreenNodeDetails::Node { .. } => None,
                }
            }
            fn as_syntax_node(&self) -> SyntaxNode<'db> {
                self.node
            }
            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
                $(&ptr_name)(self.node.stable_ptr(db))
            }
        }
    }
}
609
610fn gen_struct_code(name: String, members: Vec<Member>, is_terminal: bool) -> rust::Tokens {
611    let green_name = format!("{name}Green");
612    let mut body = rust::Tokens::new();
613    let mut field_indices = quote! {};
614    let mut args = quote! {};
615    let mut params = quote! {};
616    let mut args_for_missing = quote! {};
617    let mut ptr_getters = quote! {};
618    let mut key_field_index: usize = 0;
619    for (i, Member { name, kind, key }) in members.iter().enumerate() {
620        let index_name = format!("INDEX_{}", name.to_uppercase());
621        field_indices.extend(quote! {
622            pub const $index_name : usize = $i;
623        });
624        let key_name_green = format!("{name}_green");
625        args.extend(quote! {$name.0,});
626        // TODO(spapini): Validate that children SyntaxKinds are as expected.
627
628        let child_green = format!("{kind}Green");
629        params.extend(quote! {$name: $(&child_green)<'db>,});
630        body.extend(quote! {
631            pub fn $name(&self, db: &'db dyn Database) -> $kind<'db> {
632                $kind::from_syntax_node(db, self.node.get_children(db)[$i])
633            }
634        });
635        args_for_missing.extend(quote! {$kind::missing(db).0,});
636
637        if *key {
638            ptr_getters.extend(quote! {
639                pub fn $(&key_name_green)(self, db: &'db dyn Database) -> $(&child_green)<'db> {
640                    $(&child_green)(self.0.0.key_fields(db)[$key_field_index])
641                }
642            });
643            key_field_index += 1;
644        }
645    }
646    let ptr_name = format!("{name}Ptr");
647    let new_green_impl = if is_terminal {
648        let token_name = name.replace("Terminal", "Token");
649        quote! {
650            impl<'db> Terminal<'db> for $(&name)<'db> {
651                const KIND: SyntaxKind = SyntaxKind::$(&name);
652                type TokenType = $(&token_name)<'db>;
653                fn new_green(
654                    db: &'db dyn Database,
655                    leading_trivia: TriviaGreen<'db>,
656                    token: <<$(&name)<'db> as Terminal<'db>>::TokenType as TypedSyntaxNode<'db>>::Green,
657                    trailing_trivia: TriviaGreen<'db>
658                ) -> Self::Green {
659                    let children = [$args];
660                    let width =
661                        children.into_iter().map(|id: GreenId<'_>| id.long(db).width(db)).sum();
662                    $(&green_name)(GreenNode {
663                        kind: SyntaxKind::$(&name),
664                        details: GreenNodeDetails::Node { children: children.into(), width },
665                    }.intern(db))
666                }
667                fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
668                    let GreenNodeDetails::Node{children,..} = &self.node.green_node(db).details else {
669                        unreachable!("Expected a node, not a token");
670                    };
671                    *extract_matches!(&children[1].long(db).details, GreenNodeDetails::Token)
672                }
673            }
674        }
675    } else {
676        quote! {
677            impl<'db> $(&name)<'db> {
678                $field_indices
679                pub fn new_green(db: &'db dyn Database, $params) -> $(&green_name)<'db> {
680                    let children = [$args];
681                    let width =
682                        children.into_iter().map(|id: GreenId<'_>| id.long(db).width(db)).sum();
683                    $(&green_name)(GreenNode {
684                        kind: SyntaxKind::$(&name),
685                        details: GreenNodeDetails::Node { children: children.into(), width },
686                    }.intern(db))
687                }
688            }
689        }
690    };
691    quote! {
692        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
693        pub struct $(&name)<'db> {
694            node: SyntaxNode<'db>,
695        }
696        $new_green_impl
697        impl<'db> $(&name)<'db> {
698            $body
699        }
700        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
701        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
702        impl<'db> $(&ptr_name)<'db> {
703            $ptr_getters
704        }
705        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
706            type SyntaxNode = $(&name)<'db>;
707            fn untyped(self) -> SyntaxStablePtrId<'db> {
708                self.0
709            }
710            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
711                $(&name)::from_syntax_node(db, self.0.lookup(db))
712            }
713        }
714        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
715            fn from(ptr: $(&ptr_name)<'db>) -> Self {
716                ptr.untyped()
717            }
718        }
719        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
720        pub struct $(&green_name)<'db>(pub GreenId<'db>);
721        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db> {
722            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$(&name));
723            type StablePtr = $(&ptr_name)<'db>;
724            type Green = $(&green_name)<'db>;
725            fn missing(db: &'db dyn Database) -> Self::Green {
726                // Note: A missing syntax element should result in an internal green node
727                // of width 0, with as much structure as possible.
728                $(&green_name)(GreenNode {
729                    kind: SyntaxKind::$(&name),
730                    details: GreenNodeDetails::Node {
731                        children: [$args_for_missing].into(),
732                        width: TextWidth::default(),
733                    },
734                }.intern(db))
735            }
736            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
737                let kind = node.kind(db);
738                assert_eq!(kind, SyntaxKind::$(&name), "Unexpected SyntaxKind {:?}. Expected {:?}.", kind, SyntaxKind::$(&name));
739                Self { node }
740            }
741            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
742                let kind = node.kind(db);
743                if kind == SyntaxKind::$(&name) {
744                    Some(Self::from_syntax_node(db, node))
745                } else {
746                    None
747                }
748            }
749            fn as_syntax_node(&self) -> SyntaxNode<'db> {
750                self.node
751            }
752            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
753                $(&ptr_name)(self.node.stable_ptr(db))
754            }
755        }
756    }
757}