Skip to main content

cairo_lang_syntax_codegen/
generator.rs

1use std::fs;
2use std::path::PathBuf;
3
4use genco::prelude::*;
5use xshell::Shell;
6
7use crate::cairo_spec::get_spec;
8use crate::spec::{Member, Node, NodeKind, Variant, Variants};
9
10pub fn project_root() -> PathBuf {
11    // This is the directory of Cargo.toml of the syntax_codegen crate.
12    let dir = env!("CARGO_MANIFEST_DIR");
13    // Pop the "/crates/cairo-lang-syntax-codegen" suffix.
14    let res = PathBuf::from(dir).parent().unwrap().parent().unwrap().to_owned();
15    assert!(res.join("Cargo.toml").exists(), "Could not find project root directory.");
16    res
17}
18
/// Writes `content` to `filename`, unless the file already holds exactly that content.
///
/// When the existing contents are identical the file is not touched at all. A file that is
/// missing or cannot be read is simply (re)written.
///
/// # Panics
///
/// Panics if the file cannot be written; the panic message names the offending path.
pub fn ensure_file_content(filename: PathBuf, content: String) {
    let unchanged = fs::read_to_string(&filename).is_ok_and(|old_contents| old_contents == content);
    if unchanged {
        return;
    }
    fs::write(&filename, content).unwrap_or_else(|err| {
        panic!("failed to write generated file {}: {err}", filename.display())
    });
}
28
29pub fn get_codes() -> Vec<(String, String)> {
30    vec![
31        (
32            "crates/cairo-lang-syntax/src/node/ast.rs".into(),
33            reformat_rust_code(generate_ast_code().to_string().unwrap()),
34        ),
35        (
36            "crates/cairo-lang-syntax/src/node/kind.rs".into(),
37            reformat_rust_code(generate_kinds_code().to_string().unwrap()),
38        ),
39        (
40            "crates/cairo-lang-syntax/src/node/key_fields.rs".into(),
41            reformat_rust_code(generate_key_fields_code().to_string().unwrap()),
42        ),
43    ]
44}
45
46pub fn reformat_rust_code(text: String) -> String {
47    // Since rustfmt is used with nightly features, it takes 2 runs to reach a fixed point.
48    reformat_rust_code_inner(reformat_rust_code_inner(text))
49}
50pub fn reformat_rust_code_inner(text: String) -> String {
51    let sh = Shell::new().unwrap();
52    let cmd = sh.cmd("rustfmt").env("RUSTUP_TOOLCHAIN", "nightly-2026-04-09");
53    let cmd_with_args = cmd.arg("--config-path").arg(project_root().join("rustfmt.toml"));
54    let mut stdout = cmd_with_args.stdin(text).read().unwrap();
55    if !stdout.ends_with('\n') {
56        stdout.push('\n');
57    }
58    stdout
59}
60
61fn generate_kinds_code() -> rust::Tokens {
62    let spec = get_spec();
63    let mut tokens = quote! {
64        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
65        use core::fmt;
66        use serde::{Deserialize, Serialize};
67    };
68
69    // Definition of SyntaxKind.
70    let kinds = name_tokens(&spec, |k| !matches!(k, NodeKind::Enum { .. }));
71    let token_kinds = name_tokens(&spec, |k| matches!(k, NodeKind::Token { .. }));
72    let keyword_token_kinds =
73        name_tokens(&spec, |k| matches!(k, NodeKind::Token { is_keyword } if *is_keyword));
74    let terminal_kinds = name_tokens(&spec, |k| matches!(k, NodeKind::Terminal { .. }));
75    let keyword_terminal_kinds =
76        name_tokens(&spec, |k| matches!(k, NodeKind::Terminal { is_keyword, .. } if *is_keyword));
77    let missing_kinds = spec.iter().filter_map(|n| {
78        if let NodeKind::Enum { missing_variant, .. } = &n.kind {
79            missing_variant.as_ref().map(|v| v.kind.as_str())
80        } else {
81            None
82        }
83    });
84
85    tokens.extend(quote! {
86        #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, salsa::Update, cairo_lang_proc_macros::HeapSize)]
87        pub enum SyntaxKind {
88            $(for t in kinds => $t,)
89        }
90        impl SyntaxKind {
91            pub fn is_token(&self) -> bool {
92                matches!(
93                    *self,
94                    $(for t in token_kinds join ( | ) => SyntaxKind::$t)
95                )
96            }
97            pub fn is_terminal(&self) -> bool {
98                matches!(
99                    *self,
100                    $(for t in terminal_kinds join ( | ) => SyntaxKind::$t)
101                )
102            }
103            pub fn is_keyword_token(&self) -> bool {
104                matches!(
105                    *self,
106                    $(for t in keyword_token_kinds join ( | ) => SyntaxKind::$t)
107                )
108            }
109            pub fn is_keyword_terminal(&self) -> bool {
110                matches!(
111                    *self,
112                    $(for t in keyword_terminal_kinds join ( | ) => SyntaxKind::$t)
113                )
114            }
115            pub fn is_missing(&self) -> bool {
116                matches!(
117                    *self,
118                    $(for t in missing_kinds join ( | ) => SyntaxKind::$t)
119                )
120            }
121        }
122        impl fmt::Display for SyntaxKind {
123            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124                write!(f, "{self:?}")
125            }
126        }
127    });
128    tokens
129}
130
131/// Returns an iterator to the names of the tokens matching `predicate`.
132fn name_tokens(spec: &[Node], predicate: impl Fn(&NodeKind) -> bool) -> impl Iterator<Item = &str> {
133    spec.iter().filter(move |n| predicate(&n.kind)).map(|n| n.name.as_str())
134}
135
136fn generate_key_fields_code() -> rust::Tokens {
137    let spec = get_spec();
138    let mut arms = rust::Tokens::new();
139
140    for Node { name, kind } in spec {
141        match kind {
142            NodeKind::Struct { members } | NodeKind::Terminal { members, .. } => {
143                let mut fields = rust::Tokens::new();
144                let mut key_fields_range = 0..0;
145                for (i, member) in members.into_iter().enumerate() {
146                    let field_name = member.name;
147                    if member.key {
148                        if key_fields_range.is_empty() {
149                            key_fields_range = i..(i + 1);
150                        } else {
151                            assert_eq!(key_fields_range.end, i, "Key fields must be contiguous.");
152                            key_fields_range.end = i + 1;
153                        }
154                        if !fields.is_empty() {
155                            fields.extend(quote! { $(", ") });
156                        }
157                        fields.extend(quote!($field_name));
158                    }
159                }
160                if !fields.is_empty() {
161                    arms.extend(quote! {
162                        $("\n// Key fields:") $fields.$("\n")
163                    });
164                }
165                let key_fields_range =
166                    format!("{}..{}", key_fields_range.start, key_fields_range.end);
167                arms.extend(quote! {
168                    SyntaxKind::$name => $key_fields_range,
169                });
170            }
171            NodeKind::List { .. } | NodeKind::SeparatedList { .. } | NodeKind::Token { .. } => {
172                arms.extend(quote! {
173                    SyntaxKind::$name => 0..0,
174                });
175            }
176            NodeKind::Enum { .. } => {}
177        }
178    }
179    let tokens = quote! {
180        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
181        use super::kind::SyntaxKind;
182        $("/// Gets the vector of children ids that are the indexing key for this SyntaxKind.")
183        $("///")
184        $("/// Each SyntaxKind has some children that are defined in the spec to be its indexing key")
185        $("/// for its stable pointer. See [super::stable_ptr].")
186        pub fn key_fields_range(kind: SyntaxKind) -> core::ops::Range<usize> {
187            match kind {
188                $arms
189            }
190        }
191    };
192    tokens
193}
194
195fn generate_ast_code() -> rust::Tokens {
196    let spec = get_spec();
197    let mut tokens = quote! {
198        $("// Autogenerated file. To regenerate, please run `cargo run --bin generate-syntax`.")
199        #![allow(clippy::match_single_binding)]
200        #![allow(clippy::too_many_arguments)]
201        #![allow(dead_code)]
202        #![allow(unused_variables)]
203        use std::ops::Deref;
204
205        use cairo_lang_filesystem::span::TextWidth;
206        use cairo_lang_filesystem::ids::SmolStrId;
207        use cairo_lang_utils::{extract_matches, Intern};
208        use cairo_lang_proc_macros::HeapSize;
209
210        use salsa::Database;
211
212        use super::element_list::ElementList;
213        use super::green::GreenNodeDetails;
214        use super::kind::SyntaxKind;
215        use super::{
216            GreenId, GreenNode, SyntaxNode, SyntaxStablePtrId, Terminal, Token, TypedStablePtr,
217            TypedSyntaxNode,
218        };
219        #[path = "ast_ext.rs"]
220        mod ast_ext;
221    };
222    let all_token_variants: Vec<_> = spec
223        .iter()
224        .filter(|node| matches!(node.kind, NodeKind::Terminal { .. }))
225        .map(|node| Variant { name: node.name.clone(), kind: node.name.clone() })
226        .collect();
227    for Node { name, kind } in spec {
228        tokens.extend(match kind {
229            NodeKind::Enum { variants, missing_variant } => {
230                let variants_list = match variants {
231                    Variants::List(variants) => variants,
232                    Variants::AllTokens => all_token_variants.clone(),
233                };
234                gen_enum_code(name, variants_list, missing_variant)
235            }
236            NodeKind::Struct { members } => gen_struct_code(name, members, false),
237            NodeKind::Terminal { members, .. } => gen_struct_code(name, members, true),
238            NodeKind::Token { .. } => gen_token_code(name),
239            NodeKind::List { element_type } => gen_list_code(name, element_type),
240            NodeKind::SeparatedList { element_type, separator_type } => {
241                gen_separated_list_code(name, element_type, separator_type)
242            }
243        });
244    }
245    tokens
246}
247
/// Generates the AST code for a list node `name` whose children are all `element_type` nodes.
///
/// The output contains:
/// - `name` as a newtype over `ElementList<'db, element_type, 1>`, with a `Deref` to it.
/// - An inherent `new_green` that interns a `SyntaxKind::name` node whose width is the sum of
///   its children's widths.
/// - The `{name}Ptr` stable-pointer type with its `TypedStablePtr` impl and conversion into
///   `SyntaxStablePtrId`.
/// - The code shared with separated lists, from `gen_common_list_code`.
// The `quote!` template below *is* the generated code, so comments are kept outside it.
fn gen_list_code(name: String, element_type: String) -> rust::Tokens {
    // TODO(spapini): Change Deref to Borrow.
    let ptr_name = format!("{name}Ptr");
    let green_name = format!("{name}Green");
    let element_green_name = format!("{element_type}Green");
    // Green newtype + `TypedSyntaxNode` impl, shared with separated lists.
    let common_code = gen_common_list_code(&name, &green_name, &ptr_name);
    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db>(ElementList<'db, $(&element_type)<'db>, 1>);
        impl<'db> Deref for $(&name)<'db>{
            type Target = ElementList<'db, $(&element_type)<'db>, 1>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        impl<'db> $(&name)<'db>{
            pub fn new_green(
                db: &'db dyn Database, children: &[$(&element_green_name)<'db>]
            ) -> $(&green_name)<'db> {
                let width = children.iter().map(|id|
                    id.0.long(db).width(db)).sum();
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Node {
                        children: children.iter().map(|x| x.0).collect(),
                        width,
                    },
                }.intern(db))
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        $common_code
    }
}
297
/// Generates the AST code for a separated-list node `name`: `element_type` items interleaved
/// with `separator_type` separators.
///
/// The output contains:
/// - `name` as a newtype over `ElementList<'db, element_type, 2>`, with a `Deref` to it.
///   NOTE(review): the const `2` presumably reflects the element/separator interleaving — see
///   `ElementList` to confirm.
/// - An inherent `new_green` that takes a slice of `{name}ElementOrSeparatorGreen` and interns
///   a `SyntaxKind::name` node whose width is the sum of the children's widths.
/// - The `{name}Ptr` stable-pointer type with its `TypedStablePtr` impl and conversion into
///   `SyntaxStablePtrId`.
/// - The `{name}ElementOrSeparatorGreen` enum, with `From` impls from both green types and a
///   private `id()` accessor.
/// - The code shared with plain lists, from `gen_common_list_code`.
// The `quote!` template below *is* the generated code, so comments are kept outside it.
fn gen_separated_list_code(
    name: String,
    element_type: String,
    separator_type: String,
) -> rust::Tokens {
    // TODO(spapini): Change Deref to Borrow.
    let ptr_name = format!("{name}Ptr");
    let green_name = format!("{name}Green");
    let element_or_separator_green_name = format!("{name}ElementOrSeparatorGreen");
    let element_green_name = format!("{element_type}Green");
    let separator_green_name = format!("{separator_type}Green");
    // Green newtype + `TypedSyntaxNode` impl, shared with plain lists.
    let common_code = gen_common_list_code(&name, &green_name, &ptr_name);
    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db>(ElementList<'db, $(&element_type)<'db>, 2>);
        impl<'db> Deref for $(&name)<'db>{
            type Target = ElementList<'db, $(&element_type)<'db>, 2>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        impl<'db> $(&name)<'db>{
            pub fn new_green(
                db: &'db dyn Database, children: &[$(&element_or_separator_green_name)<'db>]
            ) -> $(&green_name)<'db> {
                let width = children.iter().map(|id|
                    id.id().long(db).width(db)).sum();
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Node {
                        children: children.iter().map(|x| x.id()).collect(),
                        width,
                    },
                }.intern(db))
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub enum $(&element_or_separator_green_name)<'db> {
            Separator($(&separator_green_name)<'db>),
            Element($(&element_green_name)<'db>),
        }
        impl<'db> From<$(&separator_green_name)<'db>> for $(&element_or_separator_green_name)<'db> {
            fn from(value: $(&separator_green_name)<'db>) -> Self {
                $(&element_or_separator_green_name)::Separator(value)
            }
        }
        impl<'db> From<$(&element_green_name)<'db>> for $(&element_or_separator_green_name)<'db> {
            fn from(value: $(&element_green_name)<'db>) -> Self {
                $(&element_or_separator_green_name)::Element(value)
            }
        }
        impl<'db> $(&element_or_separator_green_name)<'db> {
            fn id(&self) -> GreenId<'db> {
                match self {
                    $(&element_or_separator_green_name)::Separator(green) => green.0,
                    $(&element_or_separator_green_name)::Element(green) => green.0,
                }
            }
        }
        $common_code
    }
}
376
/// Generates the code shared by list and separated-list nodes.
///
/// The output contains the `green_name` newtype over a `GreenId` and the `TypedSyntaxNode`
/// impl for `name`, with `ptr_name` as its stable-pointer type. In the generated impl:
/// - `missing` interns an empty `SyntaxKind::name` node of default (zero) width.
/// - `cast` accepts only nodes whose kind is `SyntaxKind::name`.
/// - `from_syntax_node` wraps the node in an `ElementList` without checking its kind.
// The `quote!` template below *is* the generated code, so comments are kept outside it.
fn gen_common_list_code(name: &str, green_name: &str, ptr_name: &str) -> rust::Tokens {
    quote! {
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub struct $green_name<'db>(pub GreenId<'db>);
        impl<'db> TypedSyntaxNode<'db> for $name<'db> {
            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$name);
            type StablePtr = $ptr_name<'db>;
            type Green = $green_name<'db>;
            fn missing(db: &'db dyn Database) -> Self::Green {
                $green_name(
                    GreenNode {
                        kind: SyntaxKind::$name,
                        details: GreenNodeDetails::Node { children: [].into(), width: TextWidth::default() },
                    }.intern(db)
                )
            }
            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
                Self(ElementList::new(node))
            }
            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
                if node.kind(db) == SyntaxKind::$name {
                    Some(Self(ElementList::new(node)))
                } else {
                    None
                }
            }
            fn as_syntax_node(&self) -> SyntaxNode<'db> {
                self.node
            }
            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
                $ptr_name(self.node.stable_ptr(db))
            }
        }
    }
}
412
413fn gen_enum_code(
414    name: String,
415    variants: Vec<Variant>,
416    missing_variant: Option<Variant>,
417) -> rust::Tokens {
418    let ptr_name = format!("{name}Ptr");
419    let green_name = format!("{name}Green");
420    let mut enum_body = quote! {};
421    let mut from_node_body = quote! {};
422    let mut cast_body = quote! {};
423    let mut ptr_conversions = quote! {};
424    let mut green_conversions = quote! {};
425    for variant in &variants {
426        let n = &variant.name;
427        let k = &variant.kind;
428
429        enum_body.extend(quote! {
430            $n($k<'db>),
431        });
432        from_node_body.extend(quote! {
433            SyntaxKind::$k => $(&name)::$n($k::from_syntax_node(db, node)),
434        });
435        cast_body.extend(quote! {
436            SyntaxKind::$k => Some($(&name)::$n($k::from_syntax_node(db, node))),
437        });
438        let variant_ptr = format!("{k}Ptr");
439        ptr_conversions.extend(quote! {
440            impl<'db> From<$(&variant_ptr)<'db>> for $(&ptr_name)<'db> {
441                fn from(value: $(&variant_ptr)<'db>) -> Self {
442                    Self(value.0)
443                }
444            }
445        });
446        let variant_green = format!("{k}Green");
447        green_conversions.extend(quote! {
448            impl<'db> From<$(&variant_green)<'db>> for $(&green_name)<'db> {
449                fn from(value: $(&variant_green)<'db>) -> Self {
450                    Self(value.0)
451                }
452            }
453        });
454    }
455    let missing_body = match missing_variant {
456        Some(missing) => quote! {
457            $(&green_name)($(missing.kind)::missing(db).0)
458        },
459        None => quote! {
460            panic!("No missing variant.");
461        },
462    };
463    quote! {
464        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
465        pub enum $(&name)<'db>{
466            $enum_body
467        }
468        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
469        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
470        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
471            type SyntaxNode = $(&name)<'db>;
472            fn untyped(self) -> SyntaxStablePtrId<'db> {
473                self.0
474            }
475            fn lookup(&self, db: &'db dyn Database) -> Self::SyntaxNode {
476                $(&name)::from_syntax_node(db, self.0.lookup(db))
477            }
478        }
479        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
480            fn from(ptr: $(&ptr_name)<'db>) -> Self {
481                ptr.untyped()
482            }
483        }
484        $ptr_conversions
485        $green_conversions
486        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
487        pub struct $(&green_name)<'db>(pub GreenId<'db>);
488        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db>{
489            const OPTIONAL_KIND: Option<SyntaxKind> = None;
490            type StablePtr = $(&ptr_name)<'db>;
491            type Green = $(&green_name)<'db>;
492            fn missing(db: &'db dyn Database) -> Self::Green {
493                $missing_body
494            }
495            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
496                let kind = node.kind(db);
497                match kind{
498                    $from_node_body
499                    _ => panic!(
500                        "Unexpected syntax kind {:?} when constructing {}.",
501                        kind,
502                        $[str]($[const](&name))),
503                }
504            }
505            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
506                let kind = node.kind(db);
507                match kind {
508                    $cast_body
509                    _ => None,
510                }
511            }
512            fn as_syntax_node(&self) -> SyntaxNode<'db> {
513                match self {
514                    $(for v in &variants => $(&name)::$(&v.name)(x) => x.as_syntax_node(),)
515                }
516            }
517            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
518                $(&ptr_name)(self.as_syntax_node().stable_ptr(db))
519            }
520        }
521        impl<'db> $(&name)<'db> {
522            $("/// Checks if a kind of a variant of [")$(&name)$("].")
523            pub fn is_variant(kind: SyntaxKind) -> bool {
524                matches!(kind, $(for v in &variants join (|) => SyntaxKind::$(&v.kind)))
525            }
526        }
527    }
528}
529
/// Generates the AST code for a token node `name`.
///
/// The output contains:
/// - The token struct wrapping a `SyntaxNode`, with its `Token` impl (`new_green` from text,
///   and `text`).
/// - The `{name}Ptr` stable-pointer type.
/// - The `{name}Green` newtype with a `text` accessor.
/// - The `TypedSyntaxNode` impl.
///
/// In the generated impl:
/// - `missing` interns a `SyntaxKind::TokenMissing` token with empty text, not a node of kind
///   `name`.
/// - `from_syntax_node` panics on internal (non-token) green nodes.
/// - `cast` returns `None` for internal (non-token) green nodes.
// The `quote!` template below *is* the generated code, so comments are kept outside it.
fn gen_token_code(name: String) -> rust::Tokens {
    let green_name = format!("{name}Green");
    let ptr_name = format!("{name}Ptr");

    // NOTE(review): the generated `cast`/`from_syntax_node` check only that the green node is a
    // token, not that its kind is `SyntaxKind::name`; confirm this leniency is intended.
    quote! {
        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
        pub struct $(&name)<'db> {
            node: SyntaxNode<'db>,
        }
        impl<'db> Token<'db> for $(&name)<'db> {
            fn new_green(db: &'db dyn Database, text: SmolStrId<'db>) -> Self::Green {
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::$(&name),
                    details: GreenNodeDetails::Token(text),
                }.intern(db))
            }
            fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
                *extract_matches!(&self.node.green_node(db).details,
                    GreenNodeDetails::Token)
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
            type SyntaxNode = $(&name)<'db>;
            fn untyped(self) -> SyntaxStablePtrId<'db> {
                self.0
            }
            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
                $(&name)::from_syntax_node(db, self.0.lookup(db))
            }
        }
        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
            fn from(ptr: $(&ptr_name)<'db>) -> Self {
                ptr.untyped()
            }
        }
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
        pub struct $(&green_name)<'db>(pub GreenId<'db>);
        impl<'db> $(&green_name)<'db> {
            pub fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
                *extract_matches!(&self.0.long(db).details, GreenNodeDetails::Token)
            }
        }
        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db>{
            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$(&name));
            type StablePtr = $(&ptr_name)<'db>;
            type Green = $(&green_name)<'db>;
            fn missing(db: &'db dyn Database) -> Self::Green {
                $(&green_name)(GreenNode {
                    kind: SyntaxKind::TokenMissing,
                    details: GreenNodeDetails::Token(SmolStrId::from(db, "")),
                }.intern(db))
            }
            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
                match node.green_node(db).details {
                    GreenNodeDetails::Token(_) => Self { node },
                    GreenNodeDetails::Node { .. } => panic!(
                        "Expected a token {:?}, not an internal node",
                        SyntaxKind::$(&name)
                    ),
                }
            }
            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
                match node.green_node(db).details {
                    GreenNodeDetails::Token(_) => Some(Self { node }),
                    GreenNodeDetails::Node { .. } => None,
                }
            }
            fn as_syntax_node(&self) -> SyntaxNode<'db> {
                self.node
            }
            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
                $(&ptr_name)(self.node.stable_ptr(db))
            }
        }
    }
}
608
609fn gen_struct_code(name: String, members: Vec<Member>, is_terminal: bool) -> rust::Tokens {
610    let green_name = format!("{name}Green");
611    let mut body = rust::Tokens::new();
612    let mut field_indices = quote! {};
613    let mut args = quote! {};
614    let mut params = quote! {};
615    let mut args_for_missing = quote! {};
616    let mut ptr_getters = quote! {};
617    let mut key_field_index: usize = 0;
618    for (i, Member { name, kind, key }) in members.iter().enumerate() {
619        let index_name = format!("INDEX_{}", name.to_uppercase());
620        field_indices.extend(quote! {
621            pub const $index_name : usize = $i;
622        });
623        let key_name_green = format!("{name}_green");
624        args.extend(quote! {$name.0,});
625        // TODO(spapini): Validate that children SyntaxKinds are as expected.
626
627        let child_green = format!("{kind}Green");
628        params.extend(quote! {$name: $(&child_green)<'db>,});
629        body.extend(quote! {
630            pub fn $name(&self, db: &'db dyn Database) -> $kind<'db> {
631                $kind::from_syntax_node(db, self.node.get_children(db)[$i])
632            }
633        });
634        args_for_missing.extend(quote! {$kind::missing(db).0,});
635
636        if *key {
637            ptr_getters.extend(quote! {
638                pub fn $(&key_name_green)(self, db: &'db dyn Database) -> $(&child_green)<'db> {
639                    $(&child_green)(self.0.0.key_fields(db)[$key_field_index])
640                }
641            });
642            key_field_index += 1;
643        }
644    }
645    let ptr_name = format!("{name}Ptr");
646    let new_green_impl = if is_terminal {
647        let token_name = name.replace("Terminal", "Token");
648        quote! {
649            impl<'db> Terminal<'db> for $(&name)<'db> {
650                const KIND: SyntaxKind = SyntaxKind::$(&name);
651                type TokenType = $(&token_name)<'db>;
652                fn new_green(
653                    db: &'db dyn Database,
654                    leading_trivia: TriviaGreen<'db>,
655                    token: <<$(&name)<'db> as Terminal<'db>>::TokenType as TypedSyntaxNode<'db>>::Green,
656                    trailing_trivia: TriviaGreen<'db>
657                ) -> Self::Green {
658                    let children = [$args];
659                    let width =
660                        children.into_iter().map(|id: GreenId<'_>| id.long(db).width(db)).sum();
661                    $(&green_name)(GreenNode {
662                        kind: SyntaxKind::$(&name),
663                        details: GreenNodeDetails::Node { children: children.into(), width },
664                    }.intern(db))
665                }
666                fn text(&self, db: &'db dyn Database) -> SmolStrId<'db> {
667                    let GreenNodeDetails::Node{children,..} = &self.node.green_node(db).details else {
668                        unreachable!("Expected a node, not a token");
669                    };
670                    *extract_matches!(&children[1].long(db).details, GreenNodeDetails::Token)
671                }
672            }
673        }
674    } else {
675        quote! {
676            impl<'db> $(&name)<'db> {
677                $field_indices
678                pub fn new_green(db: &'db dyn Database, $params) -> $(&green_name)<'db> {
679                    let children = [$args];
680                    let width =
681                        children.into_iter().map(|id: GreenId<'_>| id.long(db).width(db)).sum();
682                    $(&green_name)(GreenNode {
683                        kind: SyntaxKind::$(&name),
684                        details: GreenNodeDetails::Node { children: children.into(), width },
685                    }.intern(db))
686                }
687            }
688        }
689    };
690    quote! {
691        #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
692        pub struct $(&name)<'db> {
693            node: SyntaxNode<'db>,
694        }
695        $new_green_impl
696        impl<'db> $(&name)<'db> {
697            $body
698        }
699        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update, HeapSize)]
700        pub struct $(&ptr_name)<'db>(pub SyntaxStablePtrId<'db>);
701        impl<'db> $(&ptr_name)<'db> {
702            $ptr_getters
703        }
704        impl<'db> TypedStablePtr<'db> for $(&ptr_name)<'db> {
705            type SyntaxNode = $(&name)<'db>;
706            fn untyped(self) -> SyntaxStablePtrId<'db> {
707                self.0
708            }
709            fn lookup(&self, db: &'db dyn Database) -> $(&name)<'db> {
710                $(&name)::from_syntax_node(db, self.0.lookup(db))
711            }
712        }
713        impl<'db> From<$(&ptr_name)<'db>> for SyntaxStablePtrId<'db> {
714            fn from(ptr: $(&ptr_name)<'db>) -> Self {
715                ptr.untyped()
716            }
717        }
718        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, salsa::Update)]
719        pub struct $(&green_name)<'db>(pub GreenId<'db>);
720        impl<'db> TypedSyntaxNode<'db> for $(&name)<'db> {
721            const OPTIONAL_KIND: Option<SyntaxKind> = Some(SyntaxKind::$(&name));
722            type StablePtr = $(&ptr_name)<'db>;
723            type Green = $(&green_name)<'db>;
724            fn missing(db: &'db dyn Database) -> Self::Green {
725                // Note: A missing syntax element should result in an internal green node
726                // of width 0, with as much structure as possible.
727                $(&green_name)(GreenNode {
728                    kind: SyntaxKind::$(&name),
729                    details: GreenNodeDetails::Node {
730                        children: [$args_for_missing].into(),
731                        width: TextWidth::default(),
732                    },
733                }.intern(db))
734            }
735            fn from_syntax_node(db: &'db dyn Database, node: SyntaxNode<'db>) -> Self {
736                let kind = node.kind(db);
737                assert_eq!(kind, SyntaxKind::$(&name), "Unexpected SyntaxKind {:?}. Expected {:?}.", kind, SyntaxKind::$(&name));
738                Self { node }
739            }
740            fn cast(db: &'db dyn Database, node: SyntaxNode<'db>) -> Option<Self> {
741                let kind = node.kind(db);
742                if kind == SyntaxKind::$(&name) {
743                    Some(Self::from_syntax_node(db, node))
744                } else {
745                    None
746                }
747            }
748            fn as_syntax_node(&self) -> SyntaxNode<'db> {
749                self.node
750            }
751            fn stable_ptr(&self, db: &'db dyn Database) -> Self::StablePtr {
752                $(&ptr_name)(self.node.stable_ptr(db))
753            }
754        }
755    }
756}