1use arborium_tree_sitter as tree_sitter;
2use proc_macro::TokenStream;
3use proc_macro2::Span;
4use quote::quote;
5use syn::{LitStr, parse_macro_input};
6use tree_sitter::Language;
7
8use plotnik_core::NodeTypes;
9
10#[proc_macro]
18pub fn generate_node_types(input: TokenStream) -> TokenStream {
19 let lang_key = parse_macro_input!(input as LitStr).value();
20
21 let env_var = format!("PLOTNIK_NODE_TYPES_{}", lang_key.to_uppercase());
22
23 let json_path = std::env::var(&env_var).unwrap_or_else(|_| {
24 panic!(
25 "Environment variable {} not set. Is build.rs configured correctly?",
26 env_var
27 )
28 });
29
30 let json_content = std::fs::read_to_string(&json_path)
31 .unwrap_or_else(|e| panic!("Failed to read {}: {}", json_path, e));
32
33 let raw_nodes: Vec<plotnik_core::RawNode> = serde_json::from_str(&json_content)
34 .unwrap_or_else(|e| panic!("Failed to parse {}: {}", json_path, e));
35
36 let ts_lang = get_language_for_key(&lang_key);
37
38 let const_name = syn::Ident::new(
39 &format!("{}_NODE_TYPES", lang_key.to_uppercase()),
40 Span::call_site(),
41 );
42
43 let generated = generate_static_node_types_code(&raw_nodes, &ts_lang, &lang_key, &const_name);
44
45 generated.into()
46}
47
48fn get_language_for_key(key: &str) -> Language {
49 match key.to_lowercase().as_str() {
50 #[cfg(feature = "lang-ada")]
51 "ada" => arborium_ada::language().into(),
52 #[cfg(feature = "lang-agda")]
53 "agda" => arborium_agda::language().into(),
54 #[cfg(feature = "lang-asciidoc")]
55 "asciidoc" => arborium_asciidoc::language().into(),
56 #[cfg(feature = "lang-asm")]
57 "asm" => arborium_asm::language().into(),
58 #[cfg(feature = "lang-awk")]
59 "awk" => arborium_awk::language().into(),
60 #[cfg(feature = "lang-bash")]
61 "bash" => arborium_bash::language().into(),
62 #[cfg(feature = "lang-batch")]
63 "batch" => arborium_batch::language().into(),
64 #[cfg(feature = "lang-c")]
65 "c" => arborium_c::language().into(),
66 #[cfg(feature = "lang-c-sharp")]
67 "c_sharp" => arborium_c_sharp::language().into(),
68 #[cfg(feature = "lang-caddy")]
69 "caddy" => arborium_caddy::language().into(),
70 #[cfg(feature = "lang-capnp")]
71 "capnp" => arborium_capnp::language().into(),
72 #[cfg(feature = "lang-clojure")]
73 "clojure" => arborium_clojure::language().into(),
74 #[cfg(feature = "lang-cmake")]
75 "cmake" => arborium_cmake::language().into(),
76 #[cfg(feature = "lang-commonlisp")]
77 "commonlisp" => arborium_commonlisp::language().into(),
78 #[cfg(feature = "lang-cpp")]
79 "cpp" => arborium_cpp::language().into(),
80 #[cfg(feature = "lang-css")]
81 "css" => arborium_css::language().into(),
82 #[cfg(feature = "lang-d")]
83 "d" => arborium_d::language().into(),
84 #[cfg(feature = "lang-dart")]
85 "dart" => arborium_dart::language().into(),
86 #[cfg(feature = "lang-devicetree")]
87 "devicetree" => arborium_devicetree::language().into(),
88 #[cfg(feature = "lang-diff")]
89 "diff" => arborium_diff::language().into(),
90 #[cfg(feature = "lang-dockerfile")]
91 "dockerfile" => arborium_dockerfile::language().into(),
92 #[cfg(feature = "lang-dot")]
93 "dot" => arborium_dot::language().into(),
94 #[cfg(feature = "lang-elisp")]
95 "elisp" => arborium_elisp::language().into(),
96 #[cfg(feature = "lang-elixir")]
97 "elixir" => arborium_elixir::language().into(),
98 #[cfg(feature = "lang-elm")]
99 "elm" => arborium_elm::language().into(),
100 #[cfg(feature = "lang-erlang")]
101 "erlang" => arborium_erlang::language().into(),
102 #[cfg(feature = "lang-fish")]
103 "fish" => arborium_fish::language().into(),
104 #[cfg(feature = "lang-fsharp")]
105 "fsharp" => arborium_fsharp::language().into(),
106 #[cfg(feature = "lang-gleam")]
107 "gleam" => arborium_gleam::language().into(),
108 #[cfg(feature = "lang-glsl")]
109 "glsl" => arborium_glsl::language().into(),
110 #[cfg(feature = "lang-go")]
111 "go" => arborium_go::language().into(),
112 #[cfg(feature = "lang-graphql")]
113 "graphql" => arborium_graphql::language().into(),
114 #[cfg(feature = "lang-groovy")]
115 "groovy" => arborium_groovy::language().into(),
116 #[cfg(feature = "lang-haskell")]
117 "haskell" => arborium_haskell::language().into(),
118 #[cfg(feature = "lang-hcl")]
119 "hcl" => arborium_hcl::language().into(),
120 #[cfg(feature = "lang-hlsl")]
121 "hlsl" => arborium_hlsl::language().into(),
122 #[cfg(feature = "lang-html")]
123 "html" => arborium_html::language().into(),
124 #[cfg(feature = "lang-idris")]
125 "idris" => arborium_idris::language().into(),
126 #[cfg(feature = "lang-ini")]
127 "ini" => arborium_ini::language().into(),
128 #[cfg(feature = "lang-java")]
129 "java" => arborium_java::language().into(),
130 #[cfg(feature = "lang-javascript")]
131 "javascript" => arborium_javascript::language().into(),
132 #[cfg(feature = "lang-jinja2")]
133 "jinja2" => arborium_jinja2::language().into(),
134 #[cfg(feature = "lang-jq")]
135 "jq" => arborium_jq::language().into(),
136 #[cfg(feature = "lang-json")]
137 "json" => arborium_json::language().into(),
138 #[cfg(feature = "lang-julia")]
139 "julia" => arborium_julia::language().into(),
140 #[cfg(feature = "lang-kdl")]
141 "kdl" => arborium_kdl::language().into(),
142 #[cfg(feature = "lang-kotlin")]
143 "kotlin" => arborium_kotlin::language().into(),
144 #[cfg(feature = "lang-lean")]
145 "lean" => arborium_lean::language().into(),
146 #[cfg(feature = "lang-lua")]
147 "lua" => arborium_lua::language().into(),
148 #[cfg(feature = "lang-markdown")]
149 "markdown" => arborium_markdown::language().into(),
150 #[cfg(feature = "lang-matlab")]
151 "matlab" => arborium_matlab::language().into(),
152 #[cfg(feature = "lang-meson")]
153 "meson" => arborium_meson::language().into(),
154 #[cfg(feature = "lang-nginx")]
155 "nginx" => arborium_nginx::language().into(),
156 #[cfg(feature = "lang-ninja")]
157 "ninja" => arborium_ninja::language().into(),
158 #[cfg(feature = "lang-nix")]
159 "nix" => arborium_nix::language().into(),
160 #[cfg(feature = "lang-objc")]
161 "objc" => arborium_objc::language().into(),
162 #[cfg(feature = "lang-ocaml")]
163 "ocaml" => arborium_ocaml::language().into(),
164 #[cfg(feature = "lang-perl")]
165 "perl" => arborium_perl::language().into(),
166 #[cfg(feature = "lang-php")]
167 "php" => arborium_php::language().into(),
168 #[cfg(feature = "lang-postscript")]
169 "postscript" => arborium_postscript::language().into(),
170 #[cfg(feature = "lang-powershell")]
171 "powershell" => arborium_powershell::language().into(),
172 #[cfg(feature = "lang-prolog")]
173 "prolog" => arborium_prolog::language().into(),
174 #[cfg(feature = "lang-python")]
175 "python" => arborium_python::language().into(),
176 #[cfg(feature = "lang-query")]
177 "query" => arborium_query::language().into(),
178 #[cfg(feature = "lang-r")]
179 "r" => arborium_r::language().into(),
180 #[cfg(feature = "lang-rescript")]
181 "rescript" => arborium_rescript::language().into(),
182 #[cfg(feature = "lang-ron")]
183 "ron" => arborium_ron::language().into(),
184 #[cfg(feature = "lang-ruby")]
185 "ruby" => arborium_ruby::language().into(),
186 #[cfg(feature = "lang-rust")]
187 "rust" => arborium_rust::language().into(),
188 #[cfg(feature = "lang-scala")]
189 "scala" => arborium_scala::language().into(),
190 #[cfg(feature = "lang-scheme")]
191 "scheme" => arborium_scheme::language().into(),
192 #[cfg(feature = "lang-scss")]
193 "scss" => arborium_scss::language().into(),
194 #[cfg(feature = "lang-sparql")]
195 "sparql" => arborium_sparql::language().into(),
196 #[cfg(feature = "lang-sql")]
197 "sql" => arborium_sql::language().into(),
198 #[cfg(feature = "lang-ssh-config")]
199 "ssh_config" => arborium_ssh_config::language().into(),
200 #[cfg(feature = "lang-starlark")]
201 "starlark" => arborium_starlark::language().into(),
202 #[cfg(feature = "lang-svelte")]
203 "svelte" => arborium_svelte::language().into(),
204 #[cfg(feature = "lang-swift")]
205 "swift" => arborium_swift::language().into(),
206 #[cfg(feature = "lang-textproto")]
207 "textproto" => arborium_textproto::language().into(),
208 #[cfg(feature = "lang-thrift")]
209 "thrift" => arborium_thrift::language().into(),
210 #[cfg(feature = "lang-tlaplus")]
211 "tlaplus" => arborium_tlaplus::language().into(),
212 #[cfg(feature = "lang-toml")]
213 "toml" => arborium_toml::language().into(),
214 #[cfg(feature = "lang-tsx")]
215 "tsx" => arborium_tsx::language().into(),
216 #[cfg(feature = "lang-typescript")]
217 "typescript" => arborium_typescript::language().into(),
218 #[cfg(feature = "lang-typst")]
219 "typst" => arborium_typst::language().into(),
220 #[cfg(feature = "lang-uiua")]
221 "uiua" => arborium_uiua::language().into(),
222 #[cfg(feature = "lang-vb")]
223 "vb" => arborium_vb::language().into(),
224 #[cfg(feature = "lang-verilog")]
225 "verilog" => arborium_verilog::language().into(),
226 #[cfg(feature = "lang-vhdl")]
227 "vhdl" => arborium_vhdl::language().into(),
228 #[cfg(feature = "lang-vim")]
229 "vim" => arborium_vim::language().into(),
230 #[cfg(feature = "lang-vue")]
231 "vue" => arborium_vue::language().into(),
232 #[cfg(feature = "lang-wit")]
233 "wit" => arborium_wit::language().into(),
234 #[cfg(feature = "lang-x86asm")]
235 "x86asm" => arborium_x86asm::language().into(),
236 #[cfg(feature = "lang-xml")]
237 "xml" => arborium_xml::language().into(),
238 #[cfg(feature = "lang-yaml")]
239 "yaml" => arborium_yaml::language().into(),
240 #[cfg(feature = "lang-yuri")]
241 "yuri" => arborium_yuri::language().into(),
242 #[cfg(feature = "lang-zig")]
243 "zig" => arborium_zig::language().into(),
244 #[cfg(feature = "lang-zsh")]
245 "zsh" => arborium_zsh::language().into(),
246 _ => panic!("Unknown or disabled language key: {}", key),
247 }
248}
249
250struct FieldCodeGen {
251 array_defs: Vec<proc_macro2::TokenStream>,
252 entries: Vec<proc_macro2::TokenStream>,
253}
254
255fn generate_field_code(
256 prefix: &str,
257 node_id: std::num::NonZeroU16,
258 field_id: &std::num::NonZeroU16,
259 field_info: &plotnik_core::FieldInfo,
260) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
261 let valid_types_raw: Vec<u16> = field_info.valid_types.iter().map(|id| id.get()).collect();
262 let valid_types_name = syn::Ident::new(
263 &format!("{}_N{}_F{}_TYPES", prefix, node_id.get(), field_id),
264 Span::call_site(),
265 );
266
267 let multiple = field_info.cardinality.multiple;
268 let required = field_info.cardinality.required;
269 let types_len = valid_types_raw.len();
270
271 let array_def = quote! {
272 static #valid_types_name: [std::num::NonZeroU16; #types_len] = [
273 #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),*
274 ];
275 };
276
277 let field_id_raw = field_id.get();
278 let entry = quote! {
279 (std::num::NonZeroU16::new(#field_id_raw).unwrap(), plotnik_core::StaticFieldInfo {
280 cardinality: plotnik_core::Cardinality {
281 multiple: #multiple,
282 required: #required,
283 },
284 valid_types: &#valid_types_name,
285 })
286 };
287
288 (array_def, entry)
289}
290
291fn generate_fields_for_node(
292 prefix: &str,
293 node_id: std::num::NonZeroU16,
294 fields: &std::collections::HashMap<std::num::NonZeroU16, plotnik_core::FieldInfo>,
295) -> FieldCodeGen {
296 let mut sorted_fields: Vec<_> = fields.iter().collect();
297 sorted_fields.sort_by_key(|(fid, _)| *fid);
298
299 let mut array_defs = Vec::new();
300 let mut entries = Vec::new();
301
302 for (field_id, field_info) in sorted_fields {
303 let (array_def, entry) = generate_field_code(prefix, node_id, field_id, field_info);
304 array_defs.push(array_def);
305 entries.push(entry);
306 }
307
308 FieldCodeGen {
309 array_defs,
310 entries,
311 }
312}
313
314fn generate_children_code(
315 prefix: &str,
316 node_id: std::num::NonZeroU16,
317 children: &plotnik_core::ChildrenInfo,
318 static_defs: &mut Vec<proc_macro2::TokenStream>,
319) -> proc_macro2::TokenStream {
320 let valid_types_raw: Vec<u16> = children.valid_types.iter().map(|id| id.get()).collect();
321 let children_types_name = syn::Ident::new(
322 &format!("{}_N{}_CHILDREN_TYPES", prefix, node_id.get()),
323 Span::call_site(),
324 );
325 let types_len = valid_types_raw.len();
326
327 static_defs.push(quote! {
328 static #children_types_name: [std::num::NonZeroU16; #types_len] = [
329 #(std::num::NonZeroU16::new(#valid_types_raw).unwrap()),*
330 ];
331 });
332
333 let multiple = children.cardinality.multiple;
334 let required = children.cardinality.required;
335
336 quote! {
337 Some(plotnik_core::StaticChildrenInfo {
338 cardinality: plotnik_core::Cardinality {
339 multiple: #multiple,
340 required: #required,
341 },
342 valid_types: &#children_types_name,
343 })
344 }
345}
346
347fn generate_static_node_types_code(
348 raw_nodes: &[plotnik_core::RawNode],
349 ts_lang: &Language,
350 lang_key: &str,
351 const_name: &syn::Ident,
352) -> proc_macro2::TokenStream {
353 let node_types = plotnik_core::DynamicNodeTypes::build(
354 raw_nodes,
355 |name, named| {
356 let id = ts_lang.id_for_node_kind(name, named);
357 std::num::NonZeroU16::new(id)
358 },
359 |name| ts_lang.field_id_for_name(name),
360 );
361
362 let prefix = lang_key.to_uppercase();
363 let mut static_defs = Vec::new();
364 let mut node_entries = Vec::new();
365
366 let extras_raw: Vec<u16> = node_types
367 .sorted_extras()
368 .iter()
369 .map(|id| id.get())
370 .collect();
371 let root = node_types.root();
372 let sorted_node_ids = node_types.sorted_node_ids();
373
374 for &node_id in &sorted_node_ids {
375 let info = node_types.get(node_id).unwrap();
376
377 let node_id_raw = node_id.get();
378 let field_gen = generate_fields_for_node(&prefix, node_id, &info.fields);
379 static_defs.extend(field_gen.array_defs);
380
381 let fields_ref = if field_gen.entries.is_empty() {
382 quote! { &[] }
383 } else {
384 let fields_array_name = syn::Ident::new(
385 &format!("{}_N{}_FIELDS", prefix, node_id_raw),
386 Span::call_site(),
387 );
388 let fields_len = field_gen.entries.len();
389 let field_entries = &field_gen.entries;
390
391 static_defs.push(quote! {
392 static #fields_array_name: [(std::num::NonZeroU16, plotnik_core::StaticFieldInfo); #fields_len] = [
393 #(#field_entries),*
394 ];
395 });
396
397 quote! { &#fields_array_name }
398 };
399
400 let children_code = match &info.children {
401 Some(children) => generate_children_code(&prefix, node_id, children, &mut static_defs),
402 None => quote! { None },
403 };
404
405 let name = &info.name;
406 let named = info.named;
407
408 node_entries.push(quote! {
409 (std::num::NonZeroU16::new(#node_id_raw).unwrap(), plotnik_core::StaticNodeTypeInfo {
410 name: #name,
411 named: #named,
412 fields: #fields_ref,
413 children: #children_code,
414 })
415 });
416 }
417
418 let nodes_array_name = syn::Ident::new(&format!("{}_NODES", prefix), Span::call_site());
419 let nodes_len = sorted_node_ids.len();
420
421 let extras_array_name = syn::Ident::new(&format!("{}_EXTRAS", prefix), Span::call_site());
422 let extras_len = extras_raw.len();
423
424 let root_code = match root {
425 Some(id) => {
426 let id_raw = id.get();
427 quote! { Some(std::num::NonZeroU16::new(#id_raw).unwrap()) }
428 }
429 None => quote! { None },
430 };
431
432 quote! {
433 #(#static_defs)*
434
435 static #nodes_array_name: [(std::num::NonZeroU16, plotnik_core::StaticNodeTypeInfo); #nodes_len] = [
436 #(#node_entries),*
437 ];
438
439 static #extras_array_name: [std::num::NonZeroU16; #extras_len] = [
440 #(std::num::NonZeroU16::new(#extras_raw).unwrap()),*
441 ];
442
443 pub static #const_name: plotnik_core::StaticNodeTypes = plotnik_core::StaticNodeTypes::new(
444 &#nodes_array_name,
445 &#extras_array_name,
446 #root_code,
447 );
448 }
449}