plotnik_macros/
lib.rs

1use arborium_tree_sitter as tree_sitter;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::quote;
5use syn::{LitStr, parse_macro_input};
6use tree_sitter::Language;
7
8use plotnik_core::NodeTypes;
9
10/// Generate a StaticNodeTypes constant for a language.
11///
12/// Usage: `generate_node_types!("javascript")`
13///
14/// This reads the node-types.json at compile time and uses the tree-sitter
15/// Language to resolve node/field names to IDs, producing efficient lookup tables.
16/// The output is fully statically allocated - no runtime initialization needed.
17#[proc_macro]
18pub fn generate_node_types(input: TokenStream) -> TokenStream {
19    let lang_key = parse_macro_input!(input as LitStr).value();
20
21    let env_var = format!("PLOTNIK_NODE_TYPES_{}", lang_key.to_uppercase());
22
23    let json_path = std::env::var(&env_var).unwrap_or_else(|_| {
24        panic!(
25            "Environment variable {} not set. Is build.rs configured correctly?",
26            env_var
27        )
28    });
29
30    let json_content = std::fs::read_to_string(&json_path)
31        .unwrap_or_else(|e| panic!("Failed to read {}: {}", json_path, e));
32
33    let raw_nodes: Vec<plotnik_core::RawNode> = serde_json::from_str(&json_content)
34        .unwrap_or_else(|e| panic!("Failed to parse {}: {}", json_path, e));
35
36    let ts_lang = get_language_for_key(&lang_key);
37
38    let const_name = syn::Ident::new(
39        &format!("{}_NODE_TYPES", lang_key.to_uppercase()),
40        Span::call_site(),
41    );
42
43    let generated = generate_static_node_types_code(&raw_nodes, &ts_lang, &lang_key, &const_name);
44
45    generated.into()
46}
47
48fn get_language_for_key(key: &str) -> Language {
49    match key.to_lowercase().as_str() {
50        #[cfg(feature = "lang-ada")]
51        "ada" => arborium_ada::language().into(),
52        #[cfg(feature = "lang-agda")]
53        "agda" => arborium_agda::language().into(),
54        #[cfg(feature = "lang-asciidoc")]
55        "asciidoc" => arborium_asciidoc::language().into(),
56        #[cfg(feature = "lang-asm")]
57        "asm" => arborium_asm::language().into(),
58        #[cfg(feature = "lang-awk")]
59        "awk" => arborium_awk::language().into(),
60        #[cfg(feature = "lang-bash")]
61        "bash" => arborium_bash::language().into(),
62        #[cfg(feature = "lang-batch")]
63        "batch" => arborium_batch::language().into(),
64        #[cfg(feature = "lang-c")]
65        "c" => arborium_c::language().into(),
66        #[cfg(feature = "lang-c-sharp")]
67        "c_sharp" => arborium_c_sharp::language().into(),
68        #[cfg(feature = "lang-caddy")]
69        "caddy" => arborium_caddy::language().into(),
70        #[cfg(feature = "lang-capnp")]
71        "capnp" => arborium_capnp::language().into(),
72        #[cfg(feature = "lang-clojure")]
73        "clojure" => arborium_clojure::language().into(),
74        #[cfg(feature = "lang-cmake")]
75        "cmake" => arborium_cmake::language().into(),
76        #[cfg(feature = "lang-commonlisp")]
77        "commonlisp" => arborium_commonlisp::language().into(),
78        #[cfg(feature = "lang-cpp")]
79        "cpp" => arborium_cpp::language().into(),
80        #[cfg(feature = "lang-css")]
81        "css" => arborium_css::language().into(),
82        #[cfg(feature = "lang-d")]
83        "d" => arborium_d::language().into(),
84        #[cfg(feature = "lang-dart")]
85        "dart" => arborium_dart::language().into(),
86        #[cfg(feature = "lang-devicetree")]
87        "devicetree" => arborium_devicetree::language().into(),
88        #[cfg(feature = "lang-diff")]
89        "diff" => arborium_diff::language().into(),
90        #[cfg(feature = "lang-dockerfile")]
91        "dockerfile" => arborium_dockerfile::language().into(),
92        #[cfg(feature = "lang-dot")]
93        "dot" => arborium_dot::language().into(),
94        #[cfg(feature = "lang-elisp")]
95        "elisp" => arborium_elisp::language().into(),
96        #[cfg(feature = "lang-elixir")]
97        "elixir" => arborium_elixir::language().into(),
98        #[cfg(feature = "lang-elm")]
99        "elm" => arborium_elm::language().into(),
100        #[cfg(feature = "lang-erlang")]
101        "erlang" => arborium_erlang::language().into(),
102        #[cfg(feature = "lang-fish")]
103        "fish" => arborium_fish::language().into(),
104        #[cfg(feature = "lang-fsharp")]
105        "fsharp" => arborium_fsharp::language().into(),
106        #[cfg(feature = "lang-gleam")]
107        "gleam" => arborium_gleam::language().into(),
108        #[cfg(feature = "lang-glsl")]
109        "glsl" => arborium_glsl::language().into(),
110        #[cfg(feature = "lang-go")]
111        "go" => arborium_go::language().into(),
112        #[cfg(feature = "lang-graphql")]
113        "graphql" => arborium_graphql::language().into(),
114        #[cfg(feature = "lang-groovy")]
115        "groovy" => arborium_groovy::language().into(),
116        #[cfg(feature = "lang-haskell")]
117        "haskell" => arborium_haskell::language().into(),
118        #[cfg(feature = "lang-hcl")]
119        "hcl" => arborium_hcl::language().into(),
120        #[cfg(feature = "lang-hlsl")]
121        "hlsl" => arborium_hlsl::language().into(),
122        #[cfg(feature = "lang-html")]
123        "html" => arborium_html::language().into(),
124        #[cfg(feature = "lang-idris")]
125        "idris" => arborium_idris::language().into(),
126        #[cfg(feature = "lang-ini")]
127        "ini" => arborium_ini::language().into(),
128        #[cfg(feature = "lang-java")]
129        "java" => arborium_java::language().into(),
130        #[cfg(feature = "lang-javascript")]
131        "javascript" => arborium_javascript::language().into(),
132        #[cfg(feature = "lang-jinja2")]
133        "jinja2" => arborium_jinja2::language().into(),
134        #[cfg(feature = "lang-jq")]
135        "jq" => arborium_jq::language().into(),
136        #[cfg(feature = "lang-json")]
137        "json" => arborium_json::language().into(),
138        #[cfg(feature = "lang-julia")]
139        "julia" => arborium_julia::language().into(),
140        #[cfg(feature = "lang-kdl")]
141        "kdl" => arborium_kdl::language().into(),
142        #[cfg(feature = "lang-kotlin")]
143        "kotlin" => arborium_kotlin::language().into(),
144        #[cfg(feature = "lang-lean")]
145        "lean" => arborium_lean::language().into(),
146        #[cfg(feature = "lang-lua")]
147        "lua" => arborium_lua::language().into(),
148        #[cfg(feature = "lang-markdown")]
149        "markdown" => arborium_markdown::language().into(),
150        #[cfg(feature = "lang-matlab")]
151        "matlab" => arborium_matlab::language().into(),
152        #[cfg(feature = "lang-meson")]
153        "meson" => arborium_meson::language().into(),
154        #[cfg(feature = "lang-nginx")]
155        "nginx" => arborium_nginx::language().into(),
156        #[cfg(feature = "lang-ninja")]
157        "ninja" => arborium_ninja::language().into(),
158        #[cfg(feature = "lang-nix")]
159        "nix" => arborium_nix::language().into(),
160        #[cfg(feature = "lang-objc")]
161        "objc" => arborium_objc::language().into(),
162        #[cfg(feature = "lang-ocaml")]
163        "ocaml" => arborium_ocaml::language().into(),
164        #[cfg(feature = "lang-perl")]
165        "perl" => arborium_perl::language().into(),
166        #[cfg(feature = "lang-php")]
167        "php" => arborium_php::language().into(),
168        #[cfg(feature = "lang-postscript")]
169        "postscript" => arborium_postscript::language().into(),
170        #[cfg(feature = "lang-powershell")]
171        "powershell" => arborium_powershell::language().into(),
172        #[cfg(feature = "lang-prolog")]
173        "prolog" => arborium_prolog::language().into(),
174        #[cfg(feature = "lang-python")]
175        "python" => arborium_python::language().into(),
176        #[cfg(feature = "lang-query")]
177        "query" => arborium_query::language().into(),
178        #[cfg(feature = "lang-r")]
179        "r" => arborium_r::language().into(),
180        #[cfg(feature = "lang-rescript")]
181        "rescript" => arborium_rescript::language().into(),
182        #[cfg(feature = "lang-ron")]
183        "ron" => arborium_ron::language().into(),
184        #[cfg(feature = "lang-ruby")]
185        "ruby" => arborium_ruby::language().into(),
186        #[cfg(feature = "lang-rust")]
187        "rust" => arborium_rust::language().into(),
188        #[cfg(feature = "lang-scala")]
189        "scala" => arborium_scala::language().into(),
190        #[cfg(feature = "lang-scheme")]
191        "scheme" => arborium_scheme::language().into(),
192        #[cfg(feature = "lang-scss")]
193        "scss" => arborium_scss::language().into(),
194        #[cfg(feature = "lang-sparql")]
195        "sparql" => arborium_sparql::language().into(),
196        #[cfg(feature = "lang-sql")]
197        "sql" => arborium_sql::language().into(),
198        #[cfg(feature = "lang-ssh-config")]
199        "ssh_config" => arborium_ssh_config::language().into(),
200        #[cfg(feature = "lang-starlark")]
201        "starlark" => arborium_starlark::language().into(),
202        #[cfg(feature = "lang-svelte")]
203        "svelte" => arborium_svelte::language().into(),
204        #[cfg(feature = "lang-swift")]
205        "swift" => arborium_swift::language().into(),
206        #[cfg(feature = "lang-textproto")]
207        "textproto" => arborium_textproto::language().into(),
208        #[cfg(feature = "lang-thrift")]
209        "thrift" => arborium_thrift::language().into(),
210        #[cfg(feature = "lang-tlaplus")]
211        "tlaplus" => arborium_tlaplus::language().into(),
212        #[cfg(feature = "lang-toml")]
213        "toml" => arborium_toml::language().into(),
214        #[cfg(feature = "lang-tsx")]
215        "tsx" => arborium_tsx::language().into(),
216        #[cfg(feature = "lang-typescript")]
217        "typescript" => arborium_typescript::language().into(),
218        #[cfg(feature = "lang-typst")]
219        "typst" => arborium_typst::language().into(),
220        #[cfg(feature = "lang-uiua")]
221        "uiua" => arborium_uiua::language().into(),
222        #[cfg(feature = "lang-vb")]
223        "vb" => arborium_vb::language().into(),
224        #[cfg(feature = "lang-verilog")]
225        "verilog" => arborium_verilog::language().into(),
226        #[cfg(feature = "lang-vhdl")]
227        "vhdl" => arborium_vhdl::language().into(),
228        #[cfg(feature = "lang-vim")]
229        "vim" => arborium_vim::language().into(),
230        #[cfg(feature = "lang-vue")]
231        "vue" => arborium_vue::language().into(),
232        #[cfg(feature = "lang-wit")]
233        "wit" => arborium_wit::language().into(),
234        #[cfg(feature = "lang-x86asm")]
235        "x86asm" => arborium_x86asm::language().into(),
236        #[cfg(feature = "lang-xml")]
237        "xml" => arborium_xml::language().into(),
238        #[cfg(feature = "lang-yaml")]
239        "yaml" => arborium_yaml::language().into(),
240        #[cfg(feature = "lang-yuri")]
241        "yuri" => arborium_yuri::language().into(),
242        #[cfg(feature = "lang-zig")]
243        "zig" => arborium_zig::language().into(),
244        #[cfg(feature = "lang-zsh")]
245        "zsh" => arborium_zsh::language().into(),
246        _ => panic!("Unknown or disabled language key: {}", key),
247    }
248}
249
250struct FieldCodeGen {
251    array_defs: Vec<proc_macro2::TokenStream>,
252    entries: Vec<proc_macro2::TokenStream>,
253}
254
255fn generate_field_code(
256    prefix: &str,
257    node_id: std::num::NonZeroU16,
258    field_id: &std::num::NonZeroU16,
259    field_info: &plotnik_core::FieldInfo,
260) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
261    let valid_types_raw: Vec<u16> = field_info.valid_types.iter().map(|id| id.get()).collect();
262    let valid_types_name = syn::Ident::new(
263        &format!("{}_N{}_F{}_TYPES", prefix, node_id.get(), field_id),
264        Span::call_site(),
265    );
266
267    let multiple = field_info.cardinality.multiple;
268    let required = field_info.cardinality.required;
269    let types_len = valid_types_raw.len();
270
271    let array_def = quote! {
272        static #valid_types_name: [std::num::NonZeroU16; #types_len] = [
273            #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),*
274        ];
275    };
276
277    let field_id_raw = field_id.get();
278    let entry = quote! {
279        (std::num::NonZeroU16::new(#field_id_raw).unwrap(), plotnik_core::StaticFieldInfo {
280            cardinality: plotnik_core::Cardinality {
281                multiple: #multiple,
282                required: #required,
283            },
284            valid_types: &#valid_types_name,
285        })
286    };
287
288    (array_def, entry)
289}
290
291fn generate_fields_for_node(
292    prefix: &str,
293    node_id: std::num::NonZeroU16,
294    fields: &std::collections::HashMap<std::num::NonZeroU16, plotnik_core::FieldInfo>,
295) -> FieldCodeGen {
296    let mut sorted_fields: Vec<_> = fields.iter().collect();
297    sorted_fields.sort_by_key(|(fid, _)| *fid);
298
299    let mut array_defs = Vec::new();
300    let mut entries = Vec::new();
301
302    for (field_id, field_info) in sorted_fields {
303        let (array_def, entry) = generate_field_code(prefix, node_id, field_id, field_info);
304        array_defs.push(array_def);
305        entries.push(entry);
306    }
307
308    FieldCodeGen {
309        array_defs,
310        entries,
311    }
312}
313
314fn generate_children_code(
315    prefix: &str,
316    node_id: std::num::NonZeroU16,
317    children: &plotnik_core::ChildrenInfo,
318    static_defs: &mut Vec<proc_macro2::TokenStream>,
319) -> proc_macro2::TokenStream {
320    let valid_types_raw: Vec<u16> = children.valid_types.iter().map(|id| id.get()).collect();
321    let children_types_name = syn::Ident::new(
322        &format!("{}_N{}_CHILDREN_TYPES", prefix, node_id.get()),
323        Span::call_site(),
324    );
325    let types_len = valid_types_raw.len();
326
327    static_defs.push(quote! {
328        static #children_types_name: [std::num::NonZeroU16; #types_len] = [
329            #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),*
330        ];
331    });
332
333    let multiple = children.cardinality.multiple;
334    let required = children.cardinality.required;
335
336    quote! {
337        Some(plotnik_core::StaticChildrenInfo {
338            cardinality: plotnik_core::Cardinality {
339                multiple: #multiple,
340                required: #required,
341            },
342            valid_types: &#children_types_name,
343        })
344    }
345}
346
347fn generate_static_node_types_code(
348    raw_nodes: &[plotnik_core::RawNode],
349    ts_lang: &Language,
350    lang_key: &str,
351    const_name: &syn::Ident,
352) -> proc_macro2::TokenStream {
353    let node_types = plotnik_core::DynamicNodeTypes::build(
354        raw_nodes,
355        |name, named| {
356            let id = ts_lang.id_for_node_kind(name, named);
357            std::num::NonZeroU16::new(id)
358        },
359        |name| ts_lang.field_id_for_name(name),
360    );
361
362    let prefix = lang_key.to_uppercase();
363    let mut static_defs = Vec::new();
364    let mut node_entries = Vec::new();
365
366    let extras_raw: Vec<u16> = node_types
367        .sorted_extras()
368        .iter()
369        .map(|id| id.get())
370        .collect();
371    let root = node_types.root();
372    let sorted_node_ids = node_types.sorted_node_ids();
373
374    for &node_id in &sorted_node_ids {
375        let info = node_types.get(node_id).unwrap();
376
377        let node_id_raw = node_id.get();
378        let field_gen = generate_fields_for_node(&prefix, node_id, &info.fields);
379        static_defs.extend(field_gen.array_defs);
380
381        let fields_ref = if field_gen.entries.is_empty() {
382            quote! { &[] }
383        } else {
384            let fields_array_name = syn::Ident::new(
385                &format!("{}_N{}_FIELDS", prefix, node_id_raw),
386                Span::call_site(),
387            );
388            let fields_len = field_gen.entries.len();
389            let field_entries = &field_gen.entries;
390
391            static_defs.push(quote! {
392                static #fields_array_name: [(std::num::NonZeroU16, plotnik_core::StaticFieldInfo); #fields_len] = [
393                    #(#field_entries),*
394                ];
395            });
396
397            quote! { &#fields_array_name }
398        };
399
400        let children_code = match &info.children {
401            Some(children) => generate_children_code(&prefix, node_id, children, &mut static_defs),
402            None => quote! { None },
403        };
404
405        let name = &info.name;
406        let named = info.named;
407
408        node_entries.push(quote! {
409            (std::num::NonZeroU16::new(#node_id_raw).unwrap(), plotnik_core::StaticNodeTypeInfo {
410                name: #name,
411                named: #named,
412                fields: #fields_ref,
413                children: #children_code,
414            })
415        });
416    }
417
418    let nodes_array_name = syn::Ident::new(&format!("{}_NODES", prefix), Span::call_site());
419    let nodes_len = sorted_node_ids.len();
420
421    let extras_array_name = syn::Ident::new(&format!("{}_EXTRAS", prefix), Span::call_site());
422    let extras_len = extras_raw.len();
423
424    let root_code = match root {
425        Some(id) => {
426            let id_raw = id.get();
427            quote! { Some(std::num::NonZeroU16::new(#id_raw).unwrap()) }
428        }
429        None => quote! { None },
430    };
431
432    quote! {
433        #(#static_defs)*
434
435        static #nodes_array_name: [(std::num::NonZeroU16, plotnik_core::StaticNodeTypeInfo); #nodes_len] = [
436            #(#node_entries),*
437        ];
438
439        static #extras_array_name: [std::num::NonZeroU16; #extras_len] = [
440            #(std::num::NonZeroU16::new(#extras_raw).unwrap()),*
441        ];
442
443        pub static #const_name: plotnik_core::StaticNodeTypes = plotnik_core::StaticNodeTypes::new(
444            &#nodes_array_name,
445            &#extras_array_name,
446            #root_code,
447        );
448    }
449}