1use super::BuiltinExtractor;
9use crate::store::Annotation;
10use crate::types::{AttrName, Binding, RelativePath, TagName};
11use rustc_hash::FxHashMap;
12use serde_json::Value as JsonValue;
13use std::cell::RefCell;
14
15pub struct RustStructureExtractor;
17
18impl BuiltinExtractor for RustStructureExtractor {
19 fn name(&self) -> &str {
20 "rust-structure"
21 }
22
23 fn extensions(&self) -> &[&str] {
24 &[".rs"]
25 }
26
27 fn extract(&self, source: &str, file: &RelativePath) -> Vec<Annotation> {
28 let tree = match parse_rust(source) {
29 Some(t) => t,
30 None => return vec![],
31 };
32 let mut annotations = Vec::new();
33 let root = tree.root_node();
34 let src = source.as_bytes();
35 let mut cursor = root.walk();
36 for child in root.named_children(&mut cursor) {
37 if let Some(ann) = extract_element(&child, src, file) {
38 annotations.push(ann);
39 }
40 }
41 annotations
42 }
43}
44
45fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
50 node.utf8_text(src).unwrap_or("")
51}
52
53fn get_name(node: &tree_sitter::Node, src: &[u8]) -> String {
54 node.child_by_field_name("name")
55 .map(|n| node_text(&n, src).to_string())
56 .unwrap_or_default()
57}
58
59fn get_visibility(node: &tree_sitter::Node, src: &[u8]) -> Option<String> {
60 let mut cursor = node.walk();
61 for child in node.named_children(&mut cursor) {
62 if child.kind() == "visibility_modifier" {
63 return Some(node_text(&child, src).to_string());
64 }
65 }
66 None
67}
68
/// Qualifier keywords detected on an item node by `extract_modifiers`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct Modifiers {
    /// An `async` keyword was found.
    is_async: bool,
    /// An `unsafe` keyword was found.
    is_unsafe: bool,
    /// A `const` keyword was found.
    is_const: bool,
}
75
76fn extract_modifiers(node: &tree_sitter::Node, src: &[u8]) -> Modifiers {
78 let mut mods = Modifiers {
79 is_async: false,
80 is_unsafe: false,
81 is_const: false,
82 };
83
84 let mut cursor = node.walk();
85 for child in node.children(&mut cursor) {
86 let text = node_text(&child, src);
87 match text {
88 "async" => mods.is_async = true,
89 "unsafe" => mods.is_unsafe = true,
90 "const" => mods.is_const = true,
91 _ => {}
92 }
93 if child.is_named() {
94 let mut inner_cursor = child.walk();
95 for inner in child.children(&mut inner_cursor) {
96 let inner_text = node_text(&inner, src);
97 match inner_text {
98 "async" => mods.is_async = true,
99 "unsafe" => mods.is_unsafe = true,
100 "const" => mods.is_const = true,
101 _ => {}
102 }
103 }
104 }
105 }
106
107 mods
108}
109
110fn make_annotation(
111 tag: &str,
112 binding: String,
113 attrs: FxHashMap<AttrName, JsonValue>,
114 file: &RelativePath,
115 children: Vec<Annotation>,
116) -> Annotation {
117 Annotation {
118 tag: TagName::from(tag),
119 attrs,
120 binding: Binding::from(binding),
121 file: file.clone(),
122 children,
123 }
124}
125
126fn extract_element(
132 node: &tree_sitter::Node,
133 src: &[u8],
134 file: &RelativePath,
135) -> Option<Annotation> {
136 match node.kind() {
137 "function_item" => extract_function(node, src, file),
138 "struct_item" => extract_named_element("struct", node, src, file),
139 "enum_item" => extract_enum(node, src, file),
140 "trait_item" => extract_trait(node, src, file),
141 "impl_item" => extract_impl(node, src, file),
142 "mod_item" => extract_named_element("module", node, src, file),
143 "const_item" => extract_named_element("const", node, src, file),
144 "static_item" => extract_static(node, src, file),
145 "type_item" => extract_named_element("type", node, src, file),
146 "macro_definition" => extract_named_element("macro", node, src, file),
147 _ => None,
148 }
149}
150
151fn extract_named_element(
153 tag: &str,
154 node: &tree_sitter::Node,
155 src: &[u8],
156 file: &RelativePath,
157) -> Option<Annotation> {
158 let name = get_name(node, src);
159 if name.is_empty() {
160 return None;
161 }
162 let mut attrs = FxHashMap::default();
163 attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
164 if let Some(vis) = get_visibility(node, src) {
165 attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
166 }
167 Some(make_annotation(tag, name, attrs, file, vec![]))
168}
169
170fn extract_function(
171 node: &tree_sitter::Node,
172 src: &[u8],
173 file: &RelativePath,
174) -> Option<Annotation> {
175 let name = get_name(node, src);
176 if name.is_empty() {
177 return None;
178 }
179 let mut attrs = FxHashMap::default();
180 attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
181 if let Some(vis) = get_visibility(node, src) {
182 attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
183 }
184 let mods = extract_modifiers(node, src);
185 if mods.is_async {
186 attrs.insert(AttrName::from("async"), JsonValue::Bool(true));
187 }
188 if mods.is_unsafe {
189 attrs.insert(AttrName::from("unsafe"), JsonValue::Bool(true));
190 }
191 if mods.is_const {
192 attrs.insert(AttrName::from("const"), JsonValue::Bool(true));
193 }
194 Some(make_annotation("function", name, attrs, file, vec![]))
195}
196
197fn extract_trait(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
198 let name = get_name(node, src);
199 if name.is_empty() {
200 return None;
201 }
202 let mut attrs = FxHashMap::default();
203 attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
204 if let Some(vis) = get_visibility(node, src) {
205 attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
206 }
207 if extract_modifiers(node, src).is_unsafe {
208 attrs.insert(AttrName::from("unsafe"), JsonValue::Bool(true));
209 }
210
211 let mut children = Vec::new();
213 if let Some(body) = node.child_by_field_name("body") {
214 let mut cursor = body.walk();
215 for child in body.named_children(&mut cursor) {
216 if child.kind() == "function_item" || child.kind() == "function_signature_item" {
217 let method_name = get_name(&child, src);
218 if !method_name.is_empty() {
219 let mut method_attrs = FxHashMap::default();
220 method_attrs.insert(
221 AttrName::from("name"),
222 JsonValue::String(method_name.clone()),
223 );
224 children.push(make_annotation(
225 "method",
226 method_name,
227 method_attrs,
228 file,
229 vec![],
230 ));
231 }
232 }
233 }
234 }
235
236 Some(make_annotation("trait", name, attrs, file, children))
237}
238
239fn extract_impl(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
240 let type_name = node
241 .child_by_field_name("type")
242 .map(|n| node_text(&n, src).to_string())
243 .unwrap_or_default();
244
245 if type_name.is_empty() {
246 return None;
247 }
248
249 let trait_name = node
250 .child_by_field_name("trait")
251 .map(|n| node_text(&n, src).to_string());
252
253 let mut attrs = FxHashMap::default();
254 attrs.insert(AttrName::from("type"), JsonValue::String(type_name.clone()));
255 if let Some(ref t) = trait_name {
256 attrs.insert(AttrName::from("trait"), JsonValue::String(t.clone()));
257 }
258
259 let mut children = Vec::new();
261 if let Some(body) = node.child_by_field_name("body") {
262 let mut cursor = body.walk();
263 for child in body.named_children(&mut cursor) {
264 if child.kind() == "function_item" {
265 if let Some(ann) = extract_function(&child, src, file) {
266 let method = Annotation {
268 tag: TagName::from("method"),
269 ..ann
270 };
271 children.push(method);
272 }
273 }
274 }
275 }
276
277 let binding = if let Some(ref t) = trait_name {
278 format!("{t} for {type_name}")
279 } else {
280 type_name
281 };
282
283 Some(make_annotation("impl", binding, attrs, file, children))
284}
285
286fn extract_enum(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
287 let name = get_name(node, src);
288 if name.is_empty() {
289 return None;
290 }
291 let mut attrs = FxHashMap::default();
292 attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
293 if let Some(vis) = get_visibility(node, src) {
294 attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
295 }
296
297 let mut children = Vec::new();
299 if let Some(body) = node.child_by_field_name("body") {
300 let mut cursor = body.walk();
301 for child in body.named_children(&mut cursor) {
302 if child.kind() == "enum_variant" {
303 let variant_name = get_name(&child, src);
304 if !variant_name.is_empty() {
305 let mut variant_attrs = FxHashMap::default();
306 variant_attrs.insert(
307 AttrName::from("name"),
308 JsonValue::String(variant_name.clone()),
309 );
310 children.push(make_annotation(
311 "variant",
312 variant_name,
313 variant_attrs,
314 file,
315 vec![],
316 ));
317 }
318 }
319 }
320 }
321
322 Some(make_annotation("enum", name, attrs, file, children))
323}
324
325fn extract_static(node: &tree_sitter::Node, src: &[u8], file: &RelativePath) -> Option<Annotation> {
326 let name = get_name(node, src);
327 if name.is_empty() {
328 return None;
329 }
330 let mut attrs = FxHashMap::default();
331 attrs.insert(AttrName::from("name"), JsonValue::String(name.clone()));
332 if let Some(vis) = get_visibility(node, src) {
333 attrs.insert(AttrName::from("visibility"), JsonValue::String(vis));
334 }
335 let mut cursor = node.walk();
336 for child in node.named_children(&mut cursor) {
337 if child.kind() == "mutable_specifier" {
338 attrs.insert(AttrName::from("mutable"), JsonValue::Bool(true));
339 break;
340 }
341 }
342 Some(make_annotation("static", name, attrs, file, vec![]))
343}
344
345thread_local! {
350 static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
351}
352
353fn parse_rust(source: &str) -> Option<tree_sitter::Tree> {
354 RUST_PARSER.with(|cell| {
355 let mut opt = cell.borrow_mut();
356 let parser = opt.get_or_insert_with(|| {
357 let mut p = tree_sitter::Parser::new();
358 p.set_language(&tree_sitter_rust::LANGUAGE.into())
359 .expect("Failed to set Rust language");
360 p
361 });
362 parser.parse(source, None)
363 })
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor over `source` as though it were `src/lib.rs`.
    fn run(source: &str) -> Vec<Annotation> {
        let path = RelativePath::from("src/lib.rs");
        RustStructureExtractor.extract(source, &path)
    }

    /// Looks up attribute `key` on `ann`.
    fn attr<'a>(ann: &'a Annotation, key: &'static str) -> Option<&'a JsonValue> {
        ann.attrs.get(&AttrName::from(key))
    }

    #[test]
    fn extracts_functions() {
        let anns = run("pub async fn foo() {}\nunsafe fn danger() {}");
        assert_eq!(anns.len(), 2, "should find 2 functions");

        let foo = &anns[0];
        assert_eq!(foo.tag.as_ref(), "function", "should be function");
        assert_eq!(foo.binding.as_ref(), "foo", "function name");
        assert_eq!(
            attr(foo, "async"),
            Some(&JsonValue::Bool(true)),
            "foo should be async"
        );
        assert_eq!(
            attr(foo, "visibility"),
            Some(&JsonValue::String("pub".to_string())),
            "foo should be pub"
        );

        let danger = &anns[1];
        assert_eq!(
            attr(danger, "unsafe"),
            Some(&JsonValue::Bool(true)),
            "danger should be unsafe"
        );
    }

    #[test]
    fn extracts_structs_and_enums() {
        let anns = run("pub struct Foo { x: i32 }\npub enum Color { Red, Green, Blue }");
        assert_eq!(anns.len(), 2, "should find struct + enum");

        let (strukt, color) = (&anns[0], &anns[1]);
        assert_eq!(strukt.tag.as_ref(), "struct", "should be struct");
        assert_eq!(strukt.binding.as_ref(), "Foo", "struct name");
        assert_eq!(color.tag.as_ref(), "enum", "should be enum");
        assert_eq!(color.binding.as_ref(), "Color", "enum name");
        assert_eq!(color.children.len(), 3, "should have 3 variants");
    }

    #[test]
    fn extracts_impl_with_methods() {
        let anns = run("impl Foo { fn bar() {} fn baz(&self) {} }");
        assert_eq!(anns.len(), 1, "should find 1 impl");

        let block = &anns[0];
        assert_eq!(block.tag.as_ref(), "impl", "should be impl");
        assert_eq!(block.binding.as_ref(), "Foo", "impl type");
        assert_eq!(block.children.len(), 2, "should have 2 methods");

        let first = &block.children[0];
        assert_eq!(first.tag.as_ref(), "method", "child should be method");
        assert_eq!(first.binding.as_ref(), "bar", "method name");
    }

    #[test]
    fn extracts_trait_impl() {
        let anns = run("impl Display for Foo { fn fmt(&self) {} }");
        assert_eq!(anns.len(), 1, "should find 1 impl");

        let block = &anns[0];
        assert_eq!(
            attr(block, "trait"),
            Some(&JsonValue::String("Display".to_string())),
            "should have trait attr"
        );
        assert_eq!(block.binding.as_ref(), "Display for Foo", "impl binding");
    }

    #[test]
    fn extracts_traits() {
        let anns = run("pub trait Resolver { fn resolve(&self); }");
        assert_eq!(anns.len(), 1, "should find 1 trait");

        let resolver = &anns[0];
        assert_eq!(resolver.tag.as_ref(), "trait", "should be trait");
        assert_eq!(resolver.binding.as_ref(), "Resolver", "trait name");
    }

    #[test]
    fn extracts_statics() {
        let anns = run("static mut COUNTER: u32 = 0;");
        assert_eq!(anns.len(), 1, "should find 1 static");

        let counter = &anns[0];
        assert_eq!(counter.tag.as_ref(), "static", "should be static");
        assert_eq!(
            attr(counter, "mutable"),
            Some(&JsonValue::Bool(true)),
            "should be mutable"
        );
    }

    #[test]
    fn extracts_other_items() {
        let source = "pub const MAX: usize = 100;\npub type Result<T> = std::result::Result<T, Error>;\nmacro_rules! my_macro { () => {} }\npub mod inner {}";
        let anns = run(source);
        assert_eq!(anns.len(), 4, "should find const + type + macro + mod");

        assert_eq!(anns[0].tag.as_ref(), "const", "should be const");
        assert_eq!(anns[0].binding.as_ref(), "MAX", "const name");
        assert_eq!(anns[1].tag.as_ref(), "type", "should be type");
        assert_eq!(anns[2].tag.as_ref(), "macro", "should be macro");
        assert_eq!(anns[3].tag.as_ref(), "module", "should be module");
    }
}