1use std::sync::Arc;
2
3use tree_sitter::Language;
4
5pub use plotnik_core::{Cardinality, NodeFieldId, NodeTypeId, NodeTypes, StaticNodeTypes};
6
7pub mod builtin;
8pub mod dynamic;
9
10pub use builtin::*;
11
12pub type Lang = Arc<dyn LangImpl>;
14
15pub trait LangImpl: Send + Sync {
18 fn name(&self) -> &str;
19
20 fn parse(&self, source: &str) -> tree_sitter::Tree;
22
23 fn resolve_named_node(&self, kind: &str) -> Option<NodeTypeId>;
24 fn resolve_anonymous_node(&self, kind: &str) -> Option<NodeTypeId>;
25 fn resolve_field(&self, name: &str) -> Option<NodeFieldId>;
26
27 fn all_named_node_kinds(&self) -> Vec<&'static str>;
29 fn all_field_names(&self) -> Vec<&'static str>;
30 fn node_type_name(&self, node_type_id: NodeTypeId) -> Option<&'static str>;
31 fn field_name(&self, field_id: NodeFieldId) -> Option<&'static str>;
32 fn fields_for_node_type(&self, node_type_id: NodeTypeId) -> Vec<&'static str>;
33
34 fn is_supertype(&self, node_type_id: NodeTypeId) -> bool;
35 fn subtypes(&self, supertype: NodeTypeId) -> &[u16];
36
37 fn root(&self) -> Option<NodeTypeId>;
38 fn is_extra(&self, node_type_id: NodeTypeId) -> bool;
39
40 fn has_field(&self, node_type_id: NodeTypeId, node_field_id: NodeFieldId) -> bool;
41 fn field_cardinality(
42 &self,
43 node_type_id: NodeTypeId,
44 node_field_id: NodeFieldId,
45 ) -> Option<Cardinality>;
46 fn valid_field_types(
47 &self,
48 node_type_id: NodeTypeId,
49 node_field_id: NodeFieldId,
50 ) -> &[NodeTypeId];
51 fn is_valid_field_type(
52 &self,
53 node_type_id: NodeTypeId,
54 node_field_id: NodeFieldId,
55 child: NodeTypeId,
56 ) -> bool;
57
58 fn children_cardinality(&self, node_type_id: NodeTypeId) -> Option<Cardinality>;
59 fn valid_child_types(&self, node_type_id: NodeTypeId) -> &[NodeTypeId];
60 fn is_valid_child_type(&self, node_type_id: NodeTypeId, child: NodeTypeId) -> bool;
61}
62
63#[derive(Debug)]
68pub struct LangInner<N: NodeTypes> {
69 name: String,
70 ts_lang: Language,
71 node_types: N,
72}
73
74impl LangInner<&'static StaticNodeTypes> {
75 pub fn new_static(name: &str, ts_lang: Language, node_types: &'static StaticNodeTypes) -> Self {
76 Self {
77 name: name.to_owned(),
78 ts_lang,
79 node_types,
80 }
81 }
82
83 pub fn node_types(&self) -> &'static StaticNodeTypes {
84 self.node_types
85 }
86}
87
88impl<N: NodeTypes + Send + Sync> LangImpl for LangInner<N> {
89 fn name(&self) -> &str {
90 &self.name
91 }
92
93 fn parse(&self, source: &str) -> tree_sitter::Tree {
94 let mut parser = tree_sitter::Parser::new();
95 parser
96 .set_language(&self.ts_lang)
97 .expect("failed to set language");
98 parser.parse(source, None).expect("failed to parse source")
99 }
100
101 fn resolve_named_node(&self, kind: &str) -> Option<NodeTypeId> {
102 let id = self.ts_lang.id_for_node_kind(kind, true);
103 (id != 0).then_some(id)
105 }
106
107 fn resolve_anonymous_node(&self, kind: &str) -> Option<NodeTypeId> {
108 let id = self.ts_lang.id_for_node_kind(kind, false);
109 if id != 0 {
112 return Some(id);
113 }
114 (self.ts_lang.node_kind_for_id(0) == Some(kind)).then_some(0)
115 }
116
117 fn resolve_field(&self, name: &str) -> Option<NodeFieldId> {
118 self.ts_lang.field_id_for_name(name)
119 }
120
121 fn all_named_node_kinds(&self) -> Vec<&'static str> {
122 let count = self.ts_lang.node_kind_count();
123 (0..count as u16)
124 .filter(|&id| self.ts_lang.node_kind_is_named(id))
125 .filter_map(|id| self.ts_lang.node_kind_for_id(id))
126 .collect()
127 }
128
129 fn all_field_names(&self) -> Vec<&'static str> {
130 let count = self.ts_lang.field_count();
131 (1..=count as u16)
132 .filter_map(|id| self.ts_lang.field_name_for_id(id))
133 .collect()
134 }
135
136 fn node_type_name(&self, node_type_id: NodeTypeId) -> Option<&'static str> {
137 self.ts_lang.node_kind_for_id(node_type_id)
138 }
139
140 fn field_name(&self, field_id: NodeFieldId) -> Option<&'static str> {
141 self.ts_lang.field_name_for_id(field_id.get())
142 }
143
144 fn fields_for_node_type(&self, node_type_id: NodeTypeId) -> Vec<&'static str> {
145 let count = self.ts_lang.field_count();
146 (1..=count as u16)
147 .filter_map(|id| {
148 let field_id = std::num::NonZeroU16::new(id)?;
149 if self.node_types.has_field(node_type_id, field_id) {
150 self.ts_lang.field_name_for_id(id)
151 } else {
152 None
153 }
154 })
155 .collect()
156 }
157
158 fn is_supertype(&self, node_type_id: NodeTypeId) -> bool {
159 self.ts_lang.node_kind_is_supertype(node_type_id)
160 }
161
162 fn subtypes(&self, supertype: NodeTypeId) -> &[u16] {
163 self.ts_lang.subtypes_for_supertype(supertype)
164 }
165
166 fn root(&self) -> Option<NodeTypeId> {
167 self.node_types.root()
168 }
169
170 fn is_extra(&self, node_type_id: NodeTypeId) -> bool {
171 self.node_types.is_extra(node_type_id)
172 }
173
174 fn has_field(&self, node_type_id: NodeTypeId, node_field_id: NodeFieldId) -> bool {
175 self.node_types.has_field(node_type_id, node_field_id)
176 }
177
178 fn field_cardinality(
179 &self,
180 node_type_id: NodeTypeId,
181 node_field_id: NodeFieldId,
182 ) -> Option<Cardinality> {
183 self.node_types
184 .field_cardinality(node_type_id, node_field_id)
185 }
186
187 fn valid_field_types(
188 &self,
189 node_type_id: NodeTypeId,
190 node_field_id: NodeFieldId,
191 ) -> &[NodeTypeId] {
192 self.node_types
193 .valid_field_types(node_type_id, node_field_id)
194 }
195
196 fn is_valid_field_type(
197 &self,
198 node_type_id: NodeTypeId,
199 node_field_id: NodeFieldId,
200 child: NodeTypeId,
201 ) -> bool {
202 self.node_types
203 .is_valid_field_type(node_type_id, node_field_id, child)
204 }
205
206 fn children_cardinality(&self, node_type_id: NodeTypeId) -> Option<Cardinality> {
207 self.node_types.children_cardinality(node_type_id)
208 }
209
210 fn valid_child_types(&self, node_type_id: NodeTypeId) -> &[NodeTypeId] {
211 self.node_types.valid_child_types(node_type_id)
212 }
213
214 fn is_valid_child_type(&self, node_type_id: NodeTypeId, child: NodeTypeId) -> bool {
215 self.node_types.is_valid_child_type(node_type_id, child)
216 }
217}
218
#[cfg(test)]
mod tests {
    // Exercises the `LangImpl` trait through the built-in, feature-gated languages.
    use super::*;

    #[test]
    #[cfg(feature = "javascript")]
    fn lang_from_name() {
        // Lookup by name is case-insensitive and accepts the short alias.
        for alias in ["js", "JavaScript"] {
            assert_eq!(from_name(alias).unwrap().name(), "javascript");
        }
        assert!(from_name("unknown").is_none());
    }

    #[test]
    #[cfg(feature = "go")]
    fn lang_from_name_golang() {
        for alias in ["go", "golang", "GOLANG"] {
            assert_eq!(from_name(alias).unwrap().name(), "go");
        }
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn lang_from_extension() {
        for ext in ["js", "mjs"] {
            assert_eq!(from_ext(ext).unwrap().name(), "javascript");
        }
    }

    #[test]
    #[cfg(feature = "typescript")]
    fn typescript_and_tsx() {
        assert_eq!(typescript().name(), "typescript");
        assert_eq!(tsx().name(), "tsx");
        for (ext, expected) in [("ts", "typescript"), ("tsx", "tsx")] {
            assert_eq!(from_ext(ext).unwrap().name(), expected);
        }
    }

    #[test]
    fn all_returns_enabled_langs() {
        let langs = all();
        assert!(!langs.is_empty());
        assert!(langs.iter().all(|lang| !lang.name().is_empty()));
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn resolve_node_and_field() {
        let lang = javascript();

        assert!(lang.resolve_named_node("function_declaration").is_some());
        assert!(lang.resolve_named_node("nonexistent_node_type").is_none());

        assert!(lang.resolve_field("name").is_some());
        assert!(lang.resolve_field("nonexistent_field").is_none());
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn supertype_via_lang_trait() {
        let lang = javascript();

        let expression = lang.resolve_named_node("expression").unwrap();
        assert!(lang.is_supertype(expression));
        assert!(!lang.subtypes(expression).is_empty());

        let function = lang.resolve_named_node("function_declaration").unwrap();
        assert!(!lang.is_supertype(function));
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn field_validation_via_trait() {
        let lang = javascript();

        let function = lang.resolve_named_node("function_declaration").unwrap();
        let name = lang.resolve_field("name").unwrap();
        let body = lang.resolve_field("body").unwrap();

        for field in [name, body] {
            assert!(lang.has_field(function, field));
        }

        // Each field paired with a child kind it must accept.
        for (field, child_kind) in [(name, "identifier"), (body, "statement_block")] {
            let child = lang.resolve_named_node(child_kind).unwrap();
            assert!(lang.is_valid_field_type(function, field, child));
        }
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn root_via_trait() {
        let lang = javascript();
        let root = lang.root();
        assert!(root.is_some());
        assert_eq!(root, lang.resolve_named_node("program"));
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn unresolved_returns_none() {
        let lang = javascript();
        let missing_node = lang.resolve_named_node("nonexistent_node_type");
        let missing_field = lang.resolve_field("nonexistent_field");
        assert!(missing_node.is_none());
        assert!(missing_field.is_none());
    }

    #[test]
    #[cfg(feature = "rust")]
    fn rust_lang_works() {
        assert!(rust().resolve_named_node("function_item").is_some());
    }

    #[test]
    #[cfg(feature = "javascript")]
    fn tree_sitter_id_zero_disambiguation() {
        let lang = javascript();

        // ID 0 doubles as tree-sitter's "not found" value; make sure it is not leaked.
        assert!(lang.resolve_named_node("fake_named").is_none());

        let end = lang.resolve_anonymous_node("end");
        let fake = lang.resolve_anonymous_node("totally_fake_node");

        assert!(end.is_some(), "Valid 'end' node should resolve");
        assert_eq!(end, Some(0), "'end' should have ID 0");
        assert!(fake.is_none(), "Non-existent node should be None");

        assert!(lang.resolve_field("name").is_some());
        assert!(lang.resolve_field("fake_field").is_none());
    }
}