plotnik_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{LitStr, parse_macro_input};
5use tree_sitter::Language;
6
7use plotnik_core::NodeTypes;
8
9/// Generate a StaticNodeTypes constant for a language.
10///
11/// Usage: `generate_node_types!("javascript")`
12///
13/// This reads the node-types.json at compile time and uses the tree-sitter
14/// Language to resolve node/field names to IDs, producing efficient lookup tables.
15/// The output is fully statically allocated - no runtime initialization needed.
16#[proc_macro]
17pub fn generate_node_types(input: TokenStream) -> TokenStream {
18    let lang_key = parse_macro_input!(input as LitStr).value();
19
20    let env_var = format!("PLOTNIK_NODE_TYPES_{}", lang_key.to_uppercase());
21
22    let json_path = std::env::var(&env_var).unwrap_or_else(|_| {
23        panic!(
24            "Environment variable {} not set. Is build.rs configured correctly?",
25            env_var
26        )
27    });
28
29    let json_content = std::fs::read_to_string(&json_path)
30        .unwrap_or_else(|e| panic!("Failed to read {}: {}", json_path, e));
31
32    let raw_nodes: Vec<plotnik_core::RawNode> = serde_json::from_str(&json_content)
33        .unwrap_or_else(|e| panic!("Failed to parse {}: {}", json_path, e));
34
35    let ts_lang = get_language_for_key(&lang_key);
36
37    let const_name = syn::Ident::new(
38        &format!("{}_NODE_TYPES", lang_key.to_uppercase()),
39        Span::call_site(),
40    );
41
42    let generated = generate_static_node_types_code(&raw_nodes, &ts_lang, &lang_key, &const_name);
43
44    generated.into()
45}
46
47fn get_language_for_key(key: &str) -> Language {
48    match key.to_lowercase().as_str() {
49        #[cfg(feature = "bash")]
50        "bash" => tree_sitter_bash::LANGUAGE.into(),
51        #[cfg(feature = "c")]
52        "c" => tree_sitter_c::LANGUAGE.into(),
53        #[cfg(feature = "cpp")]
54        "cpp" => tree_sitter_cpp::LANGUAGE.into(),
55        #[cfg(feature = "csharp")]
56        "csharp" => tree_sitter_c_sharp::LANGUAGE.into(),
57        #[cfg(feature = "css")]
58        "css" => tree_sitter_css::LANGUAGE.into(),
59        #[cfg(feature = "elixir")]
60        "elixir" => tree_sitter_elixir::LANGUAGE.into(),
61        #[cfg(feature = "go")]
62        "go" => tree_sitter_go::LANGUAGE.into(),
63        #[cfg(feature = "haskell")]
64        "haskell" => tree_sitter_haskell::LANGUAGE.into(),
65        #[cfg(feature = "hcl")]
66        "hcl" => tree_sitter_hcl::LANGUAGE.into(),
67        #[cfg(feature = "html")]
68        "html" => tree_sitter_html::LANGUAGE.into(),
69        #[cfg(feature = "java")]
70        "java" => tree_sitter_java::LANGUAGE.into(),
71        #[cfg(feature = "javascript")]
72        "javascript" => tree_sitter_javascript::LANGUAGE.into(),
73        #[cfg(feature = "json")]
74        "json" => tree_sitter_json::LANGUAGE.into(),
75        #[cfg(feature = "kotlin")]
76        "kotlin" => tree_sitter_kotlin::LANGUAGE.into(),
77        #[cfg(feature = "lua")]
78        "lua" => tree_sitter_lua::LANGUAGE.into(),
79        #[cfg(feature = "nix")]
80        "nix" => tree_sitter_nix::LANGUAGE.into(),
81        #[cfg(feature = "php")]
82        "php" => tree_sitter_php::LANGUAGE_PHP.into(),
83        #[cfg(feature = "python")]
84        "python" => tree_sitter_python::LANGUAGE.into(),
85        #[cfg(feature = "ruby")]
86        "ruby" => tree_sitter_ruby::LANGUAGE.into(),
87        #[cfg(feature = "rust")]
88        "rust" => tree_sitter_rust::LANGUAGE.into(),
89        #[cfg(feature = "scala")]
90        "scala" => tree_sitter_scala::LANGUAGE.into(),
91        #[cfg(feature = "solidity")]
92        "solidity" => tree_sitter_solidity::LANGUAGE.into(),
93        #[cfg(feature = "swift")]
94        "swift" => tree_sitter_swift::LANGUAGE.into(),
95        #[cfg(feature = "typescript")]
96        "typescript" => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
97        #[cfg(feature = "typescript")]
98        "typescript_tsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
99        #[cfg(feature = "yaml")]
100        "yaml" => tree_sitter_yaml::LANGUAGE.into(),
101        _ => panic!("Unknown or disabled language key: {}", key),
102    }
103}
104
105struct FieldCodeGen {
106    array_defs: Vec<proc_macro2::TokenStream>,
107    entries: Vec<proc_macro2::TokenStream>,
108}
109
110fn generate_field_code(
111    prefix: &str,
112    node_id: u16,
113    field_id: &std::num::NonZeroU16,
114    field_info: &plotnik_core::FieldInfo,
115) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
116    let valid_types = field_info.valid_types.to_vec();
117    let valid_types_name = syn::Ident::new(
118        &format!("{}_N{}_F{}_TYPES", prefix, node_id, field_id),
119        Span::call_site(),
120    );
121
122    let multiple = field_info.cardinality.multiple;
123    let required = field_info.cardinality.required;
124    let types_len = valid_types.len();
125
126    let array_def = quote! {
127        static #valid_types_name: [u16; #types_len] = [#(#valid_types),*];
128    };
129
130    let field_id_raw = field_id.get();
131    let entry = quote! {
132        (std::num::NonZeroU16::new(#field_id_raw).unwrap(), plotnik_core::StaticFieldInfo {
133            cardinality: plotnik_core::Cardinality {
134                multiple: #multiple,
135                required: #required,
136            },
137            valid_types: &#valid_types_name,
138        })
139    };
140
141    (array_def, entry)
142}
143
144fn generate_fields_for_node(
145    prefix: &str,
146    node_id: u16,
147    fields: &std::collections::HashMap<std::num::NonZeroU16, plotnik_core::FieldInfo>,
148) -> FieldCodeGen {
149    let mut sorted_fields: Vec<_> = fields.iter().collect();
150    sorted_fields.sort_by_key(|(fid, _)| *fid);
151
152    let mut array_defs = Vec::new();
153    let mut entries = Vec::new();
154
155    for (field_id, field_info) in sorted_fields {
156        let (array_def, entry) = generate_field_code(prefix, node_id, field_id, field_info);
157        array_defs.push(array_def);
158        entries.push(entry);
159    }
160
161    FieldCodeGen {
162        array_defs,
163        entries,
164    }
165}
166
167fn generate_children_code(
168    prefix: &str,
169    node_id: u16,
170    children: &plotnik_core::ChildrenInfo,
171    static_defs: &mut Vec<proc_macro2::TokenStream>,
172) -> proc_macro2::TokenStream {
173    let valid_types = children.valid_types.to_vec();
174    let children_types_name = syn::Ident::new(
175        &format!("{}_N{}_CHILDREN_TYPES", prefix, node_id),
176        Span::call_site(),
177    );
178    let types_len = valid_types.len();
179
180    static_defs.push(quote! {
181        static #children_types_name: [u16; #types_len] = [#(#valid_types),*];
182    });
183
184    let multiple = children.cardinality.multiple;
185    let required = children.cardinality.required;
186
187    quote! {
188        Some(plotnik_core::StaticChildrenInfo {
189            cardinality: plotnik_core::Cardinality {
190                multiple: #multiple,
191                required: #required,
192            },
193            valid_types: &#children_types_name,
194        })
195    }
196}
197
198fn generate_static_node_types_code(
199    raw_nodes: &[plotnik_core::RawNode],
200    ts_lang: &Language,
201    lang_key: &str,
202    const_name: &syn::Ident,
203) -> proc_macro2::TokenStream {
204    let node_types = plotnik_core::DynamicNodeTypes::build(
205        raw_nodes,
206        |name, named| {
207            let id = ts_lang.id_for_node_kind(name, named);
208            if id == 0 && named { None } else { Some(id) }
209        },
210        |name| ts_lang.field_id_for_name(name),
211    );
212
213    let prefix = lang_key.to_uppercase();
214    let mut static_defs = Vec::new();
215    let mut node_entries = Vec::new();
216
217    let extras = node_types.sorted_extras();
218    let root = node_types.root();
219    let sorted_node_ids = node_types.sorted_node_ids();
220
221    for &node_id in &sorted_node_ids {
222        let info = node_types.get(node_id).unwrap();
223
224        let field_gen = generate_fields_for_node(&prefix, node_id, &info.fields);
225        static_defs.extend(field_gen.array_defs);
226
227        let fields_ref = if field_gen.entries.is_empty() {
228            quote! { &[] }
229        } else {
230            let fields_array_name = syn::Ident::new(
231                &format!("{}_N{}_FIELDS", prefix, node_id),
232                Span::call_site(),
233            );
234            let fields_len = field_gen.entries.len();
235            let field_entries = &field_gen.entries;
236
237            static_defs.push(quote! {
238                static #fields_array_name: [(std::num::NonZeroU16, plotnik_core::StaticFieldInfo); #fields_len] = [
239                    #(#field_entries),*
240                ];
241            });
242
243            quote! { &#fields_array_name }
244        };
245
246        let children_code = match &info.children {
247            Some(children) => generate_children_code(&prefix, node_id, children, &mut static_defs),
248            None => quote! { None },
249        };
250
251        let name = &info.name;
252        let named = info.named;
253
254        node_entries.push(quote! {
255            (#node_id, plotnik_core::StaticNodeTypeInfo {
256                name: #name,
257                named: #named,
258                fields: #fields_ref,
259                children: #children_code,
260            })
261        });
262    }
263
264    let nodes_array_name = syn::Ident::new(&format!("{}_NODES", prefix), Span::call_site());
265    let nodes_len = sorted_node_ids.len();
266
267    let extras_array_name = syn::Ident::new(&format!("{}_EXTRAS", prefix), Span::call_site());
268    let extras_len = extras.len();
269
270    let root_code = match root {
271        Some(id) => quote! { Some(#id) },
272        None => quote! { None },
273    };
274
275    quote! {
276        #(#static_defs)*
277
278        static #nodes_array_name: [(u16, plotnik_core::StaticNodeTypeInfo); #nodes_len] = [
279            #(#node_entries),*
280        ];
281
282        static #extras_array_name: [u16; #extras_len] = [#(#extras),*];
283
284        pub static #const_name: plotnik_core::StaticNodeTypes = plotnik_core::StaticNodeTypes::new(
285            &#nodes_array_name,
286            &#extras_array_name,
287            #root_code,
288        );
289    }
290}