plotnik_langs/
lib.rs

1use std::sync::Arc;
2
3use tree_sitter::Language;
4
5pub use plotnik_core::{Cardinality, NodeFieldId, NodeTypeId, NodeTypes, StaticNodeTypes};
6
7pub mod builtin;
8pub mod dynamic;
9
10pub use builtin::*;
11
/// User-facing language type. Works with any language (static or dynamic).
///
/// This is a reference-counted `LangImpl` trait object, so callers interact with
/// every language through the same trait methods regardless of how its node
/// types are stored.
pub type Lang = Arc<dyn LangImpl>;
14
/// Trait providing a unified facade for tree-sitter's Language API
/// combined with our node type constraints.
///
/// Implementors must be `Send + Sync`, which lets a language handle be shared
/// between threads.
pub trait LangImpl: Send + Sync {
    /// Returns the name of the language (e.g. `"javascript"`).
    fn name(&self) -> &str;

    /// Parse source code into a tree-sitter tree.
    ///
    /// The return type has no error channel; the `LangInner` implementation in
    /// this file panics if the parser cannot be configured or yields no tree.
    fn parse(&self, source: &str) -> tree_sitter::Tree;

    /// Resolves a named node kind to its id, or `None` if the kind is unknown.
    fn resolve_named_node(&self, kind: &str) -> Option<NodeTypeId>;
    /// Resolves an anonymous node kind to its id, or `None` if the kind is unknown.
    ///
    /// Unlike named nodes, id `0` can be a valid result here (the anonymous
    /// `end` node), so `Some(0)` and `None` are distinct outcomes.
    fn resolve_anonymous_node(&self, kind: &str) -> Option<NodeTypeId>;
    /// Resolves a field name to its id, or `None` if the field is unknown.
    fn resolve_field(&self, name: &str) -> Option<NodeFieldId>;

    // Enumeration methods for suggestions
    /// Returns the names of all named node kinds known to the language.
    fn all_named_node_kinds(&self) -> Vec<&'static str>;
    /// Returns the names of all fields known to the language.
    fn all_field_names(&self) -> Vec<&'static str>;
    /// Returns the name of the node kind with the given id, if it has one.
    fn node_type_name(&self, node_type_id: NodeTypeId) -> Option<&'static str>;
    /// Returns the name of the field with the given id, if it has one.
    fn field_name(&self, field_id: NodeFieldId) -> Option<&'static str>;
    /// Returns the names of the fields that `has_field` reports for `node_type_id`.
    fn fields_for_node_type(&self, node_type_id: NodeTypeId) -> Vec<&'static str>;

    /// Returns `true` if `node_type_id` is a supertype node kind.
    fn is_supertype(&self, node_type_id: NodeTypeId) -> bool;
    /// Returns the ids of the subtypes of `supertype`.
    fn subtypes(&self, supertype: NodeTypeId) -> &[u16];

    /// Returns the id of the root node type, if one is known
    /// (for JavaScript the tests expect this to be `program`).
    fn root(&self) -> Option<NodeTypeId>;
    /// Returns `true` if `node_type_id` is marked as an extra node type.
    fn is_extra(&self, node_type_id: NodeTypeId) -> bool;

    /// Returns `true` if nodes of type `node_type_id` have the field `node_field_id`.
    fn has_field(&self, node_type_id: NodeTypeId, node_field_id: NodeFieldId) -> bool;
    /// Returns the cardinality of field `node_field_id` on `node_type_id`.
    ///
    /// NOTE(review): `None` presumably means the node type has no such field —
    /// confirm against the `NodeTypes` implementation.
    fn field_cardinality(
        &self,
        node_type_id: NodeTypeId,
        node_field_id: NodeFieldId,
    ) -> Option<Cardinality>;
    /// Returns the node types accepted as values of field `node_field_id` on `node_type_id`.
    fn valid_field_types(
        &self,
        node_type_id: NodeTypeId,
        node_field_id: NodeFieldId,
    ) -> &[NodeTypeId];
    /// Returns `true` if `child` is an accepted value type for field
    /// `node_field_id` on `node_type_id`.
    fn is_valid_field_type(
        &self,
        node_type_id: NodeTypeId,
        node_field_id: NodeFieldId,
        child: NodeTypeId,
    ) -> bool;

    /// Returns the cardinality of the children of `node_type_id`.
    ///
    /// NOTE(review): presumably covers children that are not attached to a
    /// field, and `None` when there are none — confirm against `NodeTypes`.
    fn children_cardinality(&self, node_type_id: NodeTypeId) -> Option<Cardinality>;
    /// Returns the node types accepted as children of `node_type_id`.
    fn valid_child_types(&self, node_type_id: NodeTypeId) -> &[NodeTypeId];
    /// Returns `true` if `child` is an accepted child type of `node_type_id`.
    fn is_valid_child_type(&self, node_type_id: NodeTypeId, child: NodeTypeId) -> bool;
}
62
/// Generic language implementation parameterized by node types.
///
/// This struct provides a single implementation of `LangImpl` that works with
/// any `NodeTypes` implementation (static or dynamic).
#[derive(Debug)]
pub struct LangInner<N: NodeTypes> {
    /// Name returned by `LangImpl::name`.
    name: String,
    /// Tree-sitter grammar used for parsing and for kind/field name lookups.
    ts_lang: Language,
    /// Node type data answering the root/extra/field/children queries.
    node_types: N,
}
73
74impl LangInner<&'static StaticNodeTypes> {
75    pub fn new_static(name: &str, ts_lang: Language, node_types: &'static StaticNodeTypes) -> Self {
76        Self {
77            name: name.to_owned(),
78            ts_lang,
79            node_types,
80        }
81    }
82
83    pub fn node_types(&self) -> &'static StaticNodeTypes {
84        self.node_types
85    }
86}
87
88impl<N: NodeTypes + Send + Sync> LangImpl for LangInner<N> {
89    fn name(&self) -> &str {
90        &self.name
91    }
92
93    fn parse(&self, source: &str) -> tree_sitter::Tree {
94        let mut parser = tree_sitter::Parser::new();
95        parser
96            .set_language(&self.ts_lang)
97            .expect("failed to set language");
98        parser.parse(source, None).expect("failed to parse source")
99    }
100
101    fn resolve_named_node(&self, kind: &str) -> Option<NodeTypeId> {
102        let id = self.ts_lang.id_for_node_kind(kind, true);
103        // For named nodes, 0 always means "not found"
104        (id != 0).then_some(id)
105    }
106
107    fn resolve_anonymous_node(&self, kind: &str) -> Option<NodeTypeId> {
108        let id = self.ts_lang.id_for_node_kind(kind, false);
109        // Tree-sitter returns 0 for both "not found" AND the valid anonymous "end" node.
110        // We disambiguate via reverse lookup.
111        if id != 0 {
112            return Some(id);
113        }
114        (self.ts_lang.node_kind_for_id(0) == Some(kind)).then_some(0)
115    }
116
117    fn resolve_field(&self, name: &str) -> Option<NodeFieldId> {
118        self.ts_lang.field_id_for_name(name)
119    }
120
121    fn all_named_node_kinds(&self) -> Vec<&'static str> {
122        let count = self.ts_lang.node_kind_count();
123        (0..count as u16)
124            .filter(|&id| self.ts_lang.node_kind_is_named(id))
125            .filter_map(|id| self.ts_lang.node_kind_for_id(id))
126            .collect()
127    }
128
129    fn all_field_names(&self) -> Vec<&'static str> {
130        let count = self.ts_lang.field_count();
131        (1..=count as u16)
132            .filter_map(|id| self.ts_lang.field_name_for_id(id))
133            .collect()
134    }
135
136    fn node_type_name(&self, node_type_id: NodeTypeId) -> Option<&'static str> {
137        self.ts_lang.node_kind_for_id(node_type_id)
138    }
139
140    fn field_name(&self, field_id: NodeFieldId) -> Option<&'static str> {
141        self.ts_lang.field_name_for_id(field_id.get())
142    }
143
144    fn fields_for_node_type(&self, node_type_id: NodeTypeId) -> Vec<&'static str> {
145        let count = self.ts_lang.field_count();
146        (1..=count as u16)
147            .filter_map(|id| {
148                let field_id = std::num::NonZeroU16::new(id)?;
149                if self.node_types.has_field(node_type_id, field_id) {
150                    self.ts_lang.field_name_for_id(id)
151                } else {
152                    None
153                }
154            })
155            .collect()
156    }
157
158    fn is_supertype(&self, node_type_id: NodeTypeId) -> bool {
159        self.ts_lang.node_kind_is_supertype(node_type_id)
160    }
161
162    fn subtypes(&self, supertype: NodeTypeId) -> &[u16] {
163        self.ts_lang.subtypes_for_supertype(supertype)
164    }
165
166    fn root(&self) -> Option<NodeTypeId> {
167        self.node_types.root()
168    }
169
170    fn is_extra(&self, node_type_id: NodeTypeId) -> bool {
171        self.node_types.is_extra(node_type_id)
172    }
173
174    fn has_field(&self, node_type_id: NodeTypeId, node_field_id: NodeFieldId) -> bool {
175        self.node_types.has_field(node_type_id, node_field_id)
176    }
177
178    fn field_cardinality(
179        &self,
180        node_type_id: NodeTypeId,
181        node_field_id: NodeFieldId,
182    ) -> Option<Cardinality> {
183        self.node_types
184            .field_cardinality(node_type_id, node_field_id)
185    }
186
187    fn valid_field_types(
188        &self,
189        node_type_id: NodeTypeId,
190        node_field_id: NodeFieldId,
191    ) -> &[NodeTypeId] {
192        self.node_types
193            .valid_field_types(node_type_id, node_field_id)
194    }
195
196    fn is_valid_field_type(
197        &self,
198        node_type_id: NodeTypeId,
199        node_field_id: NodeFieldId,
200        child: NodeTypeId,
201    ) -> bool {
202        self.node_types
203            .is_valid_field_type(node_type_id, node_field_id, child)
204    }
205
206    fn children_cardinality(&self, node_type_id: NodeTypeId) -> Option<Cardinality> {
207        self.node_types.children_cardinality(node_type_id)
208    }
209
210    fn valid_child_types(&self, node_type_id: NodeTypeId) -> &[NodeTypeId] {
211        self.node_types.valid_child_types(node_type_id)
212    }
213
214    fn is_valid_child_type(&self, node_type_id: NodeTypeId, child: NodeTypeId) -> bool {
215        self.node_types.is_valid_child_type(node_type_id, child)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Name lookup accepts aliases in any letter case and rejects unknown names.
    #[test]
    #[cfg(feature = "javascript")]
    fn lang_from_name() {
        for alias in ["js", "JavaScript"] {
            assert_eq!(from_name(alias).unwrap().name(), "javascript");
        }
        assert!(from_name("unknown").is_none());
    }

    /// Every Go alias maps to the canonical name `go`.
    #[test]
    #[cfg(feature = "go")]
    fn lang_from_name_golang() {
        for alias in ["go", "golang", "GOLANG"] {
            assert_eq!(from_name(alias).unwrap().name(), "go");
        }
    }

    /// File extensions map to the JavaScript language.
    #[test]
    #[cfg(feature = "javascript")]
    fn lang_from_extension() {
        for ext in ["js", "mjs"] {
            assert_eq!(from_ext(ext).unwrap().name(), "javascript");
        }
    }

    /// TypeScript and TSX are distinct languages, by constructor and by extension.
    #[test]
    #[cfg(feature = "typescript")]
    fn typescript_and_tsx() {
        assert_eq!(typescript().name(), "typescript");
        assert_eq!(tsx().name(), "tsx");

        for (ext, expected) in [("ts", "typescript"), ("tsx", "tsx")] {
            assert_eq!(from_ext(ext).unwrap().name(), expected);
        }
    }

    /// `all()` yields at least one language and every language has a name.
    #[test]
    fn all_returns_enabled_langs() {
        let enabled = all();
        assert!(!enabled.is_empty());
        for lang in &enabled {
            let name = lang.name();
            assert!(!name.is_empty());
        }
    }

    /// Known node kinds and fields resolve; unknown ones do not.
    #[test]
    #[cfg(feature = "javascript")]
    fn resolve_node_and_field() {
        let lang = javascript();

        assert!(lang.resolve_named_node("function_declaration").is_some());
        assert!(lang.resolve_named_node("nonexistent_node_type").is_none());

        assert!(lang.resolve_field("name").is_some());
        assert!(lang.resolve_field("nonexistent_field").is_none());
    }

    /// `expression` is a supertype with subtypes; a concrete node kind is not.
    #[test]
    #[cfg(feature = "javascript")]
    fn supertype_via_lang_trait() {
        let lang = javascript();

        let expression = lang.resolve_named_node("expression").unwrap();
        assert!(lang.is_supertype(expression));
        assert!(!lang.subtypes(expression).is_empty());

        let function = lang.resolve_named_node("function_declaration").unwrap();
        assert!(!lang.is_supertype(function));
    }

    /// Field presence and accepted field types are visible through the trait.
    #[test]
    #[cfg(feature = "javascript")]
    fn field_validation_via_trait() {
        let lang = javascript();

        let function = lang.resolve_named_node("function_declaration").unwrap();
        let name = lang.resolve_field("name").unwrap();
        let body = lang.resolve_field("body").unwrap();

        assert!(lang.has_field(function, name));
        assert!(lang.has_field(function, body));

        for (field, child_kind) in [(name, "identifier"), (body, "statement_block")] {
            let child = lang.resolve_named_node(child_kind).unwrap();
            assert!(lang.is_valid_field_type(function, field, child));
        }
    }

    /// The root node type of JavaScript is `program`.
    #[test]
    #[cfg(feature = "javascript")]
    fn root_via_trait() {
        let lang = javascript();

        let root = lang.root();
        assert!(root.is_some());
        assert_eq!(root, lang.resolve_named_node("program"));
    }

    /// Unknown names yield `None` rather than a sentinel id.
    #[test]
    #[cfg(feature = "javascript")]
    fn unresolved_returns_none() {
        let lang = javascript();

        let missing_node = lang.resolve_named_node("nonexistent_node_type");
        assert!(missing_node.is_none());

        let missing_field = lang.resolve_field("nonexistent_field");
        assert!(missing_field.is_none());
    }

    /// The Rust grammar is wired up and resolves a basic node kind.
    #[test]
    #[cfg(feature = "rust")]
    fn rust_lang_works() {
        let lang = rust();
        assert!(lang.resolve_named_node("function_item").is_some());
    }

    /// Id 0 is ambiguous in tree-sitter; the wrapper must tell "end" from "not found".
    #[test]
    #[cfg(feature = "javascript")]
    fn tree_sitter_id_zero_disambiguation() {
        let lang = javascript();

        // Named lookup: 0 can only mean "not found".
        assert!(lang.resolve_named_node("fake_named").is_none());

        // Anonymous lookup: 0 is checked by a reverse lookup.
        let end = lang.resolve_anonymous_node("end");
        let fake = lang.resolve_anonymous_node("totally_fake_node");

        assert!(end.is_some(), "Valid 'end' node should resolve");
        assert_eq!(end, Some(0), "'end' should have ID 0");
        assert!(fake.is_none(), "Non-existent node should be None");

        // Field lookup already distinguishes found from missing.
        assert!(lang.resolve_field("name").is_some());
        assert!(lang.resolve_field("fake_field").is_none());
    }
}