1use super::{CodeElement, SourceLocation};
8use crate::error::AqlError;
9use crate::types::{AttrName, CodeElementName, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use std::cell::RefCell;
12use std::path::Path;
13
/// [`CodeResolver`](super::CodeResolver) implementation for Rust source files, backed by tree-sitter.
///
/// The resolver holds no state of its own (parsers are cached per thread), so it
/// is a unit value that is free to construct, copy and print.
#[derive(Debug, Default, Clone, Copy)]
pub struct RustResolver;
16
17impl super::CodeResolver for RustResolver {
18 fn resolve(&self, file_path: &Path) -> Result<CodeElement, AqlError> {
19 let source =
20 std::fs::read_to_string(file_path).map_err(|e| format!("Failed to read file: {e}"))?;
21 let root = parse_rust_source(&source, file_path)?;
22 Ok(root)
23 }
24
25 fn extensions(&self) -> &[&str] {
26 &[".rs"]
27 }
28
29 fn code_tags(&self) -> &[&str] {
30 &[
31 "function", "struct", "enum", "trait", "impl", "module", "const", "static", "type",
32 "macro",
33 ]
34 }
35}
36
37thread_local! {
39 static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
40}
41
42fn with_rust_parser<F, R>(f: F) -> Result<R, String>
43where
44 F: FnOnce(&mut tree_sitter::Parser) -> Result<R, String>,
45{
46 RUST_PARSER.with(|cell| {
47 let mut opt = cell.borrow_mut();
48 let parser = opt.get_or_insert_with(|| {
49 let mut p = tree_sitter::Parser::new();
50 p.set_language(&tree_sitter_rust::LANGUAGE.into())
51 .expect("Failed to set Rust language for tree-sitter");
52 p
53 });
54 f(parser)
55 })
56}
57
58fn parse_rust_source(source: &str, file_path: &Path) -> Result<CodeElement, String> {
60 let tree = with_rust_parser(|parser| {
61 parser
62 .parse(source, None)
63 .ok_or_else(|| "Failed to parse source".to_string())
64 })?;
65
66 let root_node = tree.root_node();
67 let src = source.as_bytes();
68 let file_str = file_path.to_string_lossy().to_string();
69
70 let mut children = Vec::new();
71 let mut cursor = root_node.walk();
72 for child in root_node.named_children(&mut cursor) {
73 if let Some(element) = extract_element(&child, src, &file_str) {
74 children.push(element);
75 }
76 }
77
78 let filename = file_path
79 .file_name()
80 .map(|f| f.to_string_lossy().to_string())
81 .unwrap_or_else(|| file_str.clone());
82
83 Ok(CodeElement {
84 tag: TagName::from("module"),
85 name: CodeElementName::from(filename),
86 attrs: FxHashMap::default(),
87 children,
88 source: SourceLocation {
89 file: RelativePath::from(file_str),
90 line: 1,
91 column: 0,
92 end_line: Some(root_node.end_position().row + 1),
93 end_column: Some(root_node.end_position().column),
94 start_byte: root_node.start_byte(),
95 end_byte: root_node.end_byte(),
96 },
97 })
98}
99
100fn extract_element(node: &tree_sitter::Node, src: &[u8], file: &str) -> Option<CodeElement> {
102 match node.kind() {
103 "function_item" => Some(extract_function(node, src, file)),
104 "struct_item" => Some(extract_named_element("struct", node, src, file)),
105 "enum_item" => Some(extract_named_element("enum", node, src, file)),
106 "trait_item" => Some(extract_trait(node, src, file)),
107 "impl_item" => Some(extract_impl(node, src, file)),
108 "mod_item" => Some(extract_named_element("module", node, src, file)),
109 "const_item" => Some(extract_named_element("const", node, src, file)),
110 "static_item" => Some(extract_static(node, src, file)),
111 "type_item" => Some(extract_named_element("type", node, src, file)),
112 "macro_definition" => Some(extract_named_element("macro", node, src, file)),
113 _ => None,
114 }
115}
116
117fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
118 node.utf8_text(src).unwrap_or("")
119}
120
121fn get_name(node: &tree_sitter::Node, src: &[u8]) -> CodeElementName {
122 CodeElementName::from(
123 node.child_by_field_name("name")
124 .map(|n| node_text(&n, src).to_string())
125 .unwrap_or_default(),
126 )
127}
128
129fn get_visibility(node: &tree_sitter::Node, src: &[u8]) -> Option<String> {
130 let mut cursor = node.walk();
131 for child in node.named_children(&mut cursor) {
132 if child.kind() == "visibility_modifier" {
133 return Some(node_text(&child, src).to_string());
134 }
135 }
136 None
137}
138
/// Qualifier keywords detected on a function or trait declaration.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct Modifiers {
    /// The declaration carries `async`.
    is_async: bool,
    /// The declaration carries `unsafe`.
    is_unsafe: bool,
    /// The declaration carries `const`.
    is_const: bool,
}
145
146fn extract_modifiers(node: &tree_sitter::Node, src: &[u8]) -> Modifiers {
148 let mut mods = Modifiers {
149 is_async: false,
150 is_unsafe: false,
151 is_const: false,
152 };
153
154 let mut cursor = node.walk();
155 for child in node.children(&mut cursor) {
156 let text = node_text(&child, src);
157 match text {
158 "async" => mods.is_async = true,
159 "unsafe" => mods.is_unsafe = true,
160 "const" => mods.is_const = true,
161 _ => {}
162 }
163 if child.is_named() {
165 let mut inner_cursor = child.walk();
166 for inner in child.children(&mut inner_cursor) {
167 let inner_text = node_text(&inner, src);
168 match inner_text {
169 "async" => mods.is_async = true,
170 "unsafe" => mods.is_unsafe = true,
171 "const" => mods.is_const = true,
172 _ => {}
173 }
174 }
175 }
176 }
177
178 mods
179}
180
181fn make_source_location(node: &tree_sitter::Node, file: &str) -> SourceLocation {
182 let start = node.start_position();
183 let end = node.end_position();
184 SourceLocation {
185 file: RelativePath::from(file),
186 line: start.row + 1,
187 column: start.column,
188 end_line: Some(end.row + 1),
189 end_column: Some(end.column),
190 start_byte: node.start_byte(),
191 end_byte: node.end_byte(),
192 }
193}
194
195fn extract_named_element(
198 tag: &str,
199 node: &tree_sitter::Node,
200 src: &[u8],
201 file: &str,
202) -> CodeElement {
203 let name = get_name(node, src);
204 let mut attrs = FxHashMap::default();
205 attrs.insert(
206 AttrName::from("name"),
207 serde_json::Value::String(name.to_string()),
208 );
209 if let Some(vis) = get_visibility(node, src) {
210 attrs.insert(AttrName::from("visibility"), serde_json::Value::String(vis));
211 }
212 CodeElement {
213 tag: TagName::from(tag),
214 name,
215 attrs,
216 children: vec![],
217 source: make_source_location(node, file),
218 }
219}
220
221fn extract_function(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
222 let mut el = extract_named_element("function", node, src, file);
223 let mods = extract_modifiers(node, src);
224 if mods.is_async {
225 el.attrs
226 .insert(AttrName::from("async"), serde_json::Value::Bool(true));
227 }
228 if mods.is_unsafe {
229 el.attrs
230 .insert(AttrName::from("unsafe"), serde_json::Value::Bool(true));
231 }
232 if mods.is_const {
233 el.attrs
234 .insert(AttrName::from("const"), serde_json::Value::Bool(true));
235 }
236 el
237}
238
239fn extract_trait(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
240 let mut el = extract_named_element("trait", node, src, file);
241 if extract_modifiers(node, src).is_unsafe {
242 el.attrs
243 .insert(AttrName::from("unsafe"), serde_json::Value::Bool(true));
244 }
245 el
246}
247
248fn extract_impl(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
249 let type_name = node
250 .child_by_field_name("type")
251 .map(|n| node_text(&n, src).to_string())
252 .unwrap_or_default();
253
254 let trait_name = node
255 .child_by_field_name("trait")
256 .map(|n| node_text(&n, src).to_string());
257
258 let mut attrs = FxHashMap::default();
259 attrs.insert(
260 AttrName::from("type"),
261 serde_json::Value::String(type_name.clone()),
262 );
263 if let Some(ref t) = trait_name {
264 attrs.insert(
265 AttrName::from("trait"),
266 serde_json::Value::String(t.clone()),
267 );
268 }
269
270 let mut children = Vec::new();
272 if let Some(body) = node.child_by_field_name("body") {
273 let mut cursor = body.walk();
274 for child in body.named_children(&mut cursor) {
275 if child.kind() == "function_item" {
276 children.push(extract_function(&child, src, file));
277 }
278 }
279 }
280
281 let name = if let Some(ref t) = trait_name {
282 format!("{t} for {type_name}")
283 } else {
284 type_name
285 };
286
287 CodeElement {
288 tag: TagName::from("impl"),
289 name: CodeElementName::from(name),
290 attrs,
291 children,
292 source: make_source_location(node, file),
293 }
294}
295
296fn extract_static(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
297 let mut el = extract_named_element("static", node, src, file);
298 let mut cursor = node.walk();
299 for child in node.named_children(&mut cursor) {
300 if child.kind() == "mutable_specifier" {
301 el.attrs
302 .insert(AttrName::from("mutable"), serde_json::Value::Bool(true));
303 break;
304 }
305 }
306 el
307}
308
#[cfg(test)]
mod tests {
    //! Checks of the elements extracted from small Rust snippets.

    use super::*;

    /// Parses `source` as if it were the contents of `test.rs`.
    fn parse_snippet(source: &str) -> CodeElement {
        parse_rust_source(source, Path::new("test.rs")).unwrap()
    }

    #[test]
    fn parses_simple_named_elements() {
        let struct_root = parse_snippet("struct Bar { x: i32 }");
        let pub_struct_root = parse_snippet("pub struct Baz;");
        let enum_root = parse_snippet("pub enum Color { Red, Green, Blue }");
        let const_root = parse_snippet("pub const MAX: usize = 100;");
        let type_root = parse_snippet("pub type Result<T> = std::result::Result<T, Error>;");
        let macro_root = parse_snippet("macro_rules! my_macro { () => {} }");
        let module_root = parse_snippet("pub mod inner {}");

        let s = &struct_root.children[0];
        let ps = &pub_struct_root.children[0];
        let e = &enum_root.children[0];
        let c = &const_root.children[0];
        let t = &type_root.children[0];
        let m = &macro_root.children[0];
        let md = &module_root.children[0];

        assert_eq!(struct_root.children.len(), 1);
        assert_eq!(s.tag, "struct");
        assert_eq!(s.name, "Bar");

        assert_eq!(ps.tag, "struct");
        assert_eq!(ps.name, "Baz");
        assert_eq!(
            ps.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub".to_string()))
        );

        assert_eq!(e.tag, "enum");
        assert_eq!(e.name, "Color");
        assert_eq!(
            e.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub".to_string()))
        );

        assert_eq!(c.tag, "const");
        assert_eq!(c.name, "MAX");
        assert_eq!(
            c.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub".to_string()))
        );

        assert_eq!(t.tag, "type");
        assert_eq!(t.name, "Result");

        assert_eq!(m.tag, "macro");
        assert_eq!(m.name, "my_macro");

        assert_eq!(md.tag, "module");
        assert_eq!(md.name, "inner");
        assert_eq!(
            md.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub".to_string()))
        );
    }

    #[test]
    fn parses_function_variants() {
        let pub_async_root = parse_snippet("pub async fn foo() {}");
        let unsafe_root = parse_snippet("unsafe fn danger() {}");
        let const_fn_root = parse_snippet("pub(crate) const fn limit() -> u32 { 1 }");

        let func = &pub_async_root.children[0];
        let uf = &unsafe_root.children[0];
        let cf = &const_fn_root.children[0];

        assert_eq!(pub_async_root.tag, "module");
        assert_eq!(pub_async_root.name, "test.rs");
        assert_eq!(pub_async_root.source.line, 1);
        assert_eq!(pub_async_root.children.len(), 1);
        assert_eq!(func.tag, "function");
        assert_eq!(func.name, "foo");
        assert_eq!(
            func.attrs.get("async"),
            Some(&serde_json::Value::Bool(true))
        );
        assert_eq!(func.attrs.get("unsafe"), None);
        assert_eq!(
            func.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub".to_string()))
        );

        assert_eq!(uf.tag, "function");
        assert_eq!(uf.name, "danger");
        assert_eq!(uf.attrs.get("unsafe"), Some(&serde_json::Value::Bool(true)));

        assert_eq!(cf.tag, "function");
        assert_eq!(cf.name, "limit");
        assert_eq!(cf.attrs.get("const"), Some(&serde_json::Value::Bool(true)));
        assert_eq!(cf.attrs.get("async"), None);
        assert_eq!(
            cf.attrs.get("visibility"),
            Some(&serde_json::Value::String("pub(crate)".to_string()))
        );
    }

    #[test]
    fn parses_compound_elements() {
        let impl_root = parse_snippet("impl Foo { fn bar() {} fn baz(&self) {} }");
        let trait_impl_root = parse_snippet("impl Display for Foo { fn fmt(&self) {} }");
        let trait_root = parse_snippet("pub trait Resolver { fn resolve(&self); }");
        let unsafe_trait_root = parse_snippet("pub unsafe trait Marker {}");
        let static_root = parse_snippet("static mut COUNTER: u32 = 0;");

        let imp = &impl_root.children[0];
        let ti = &trait_impl_root.children[0];
        let tr = &trait_root.children[0];
        let ut = &unsafe_trait_root.children[0];
        let st = &static_root.children[0];

        assert_eq!(impl_root.children.len(), 1);
        assert_eq!(imp.tag, "impl");
        assert_eq!(imp.name, "Foo");
        assert_eq!(
            imp.attrs.get("type"),
            Some(&serde_json::Value::String("Foo".to_string()))
        );
        assert_eq!(imp.children.len(), 2);
        assert_eq!(imp.children[0].tag, "function");
        assert_eq!(imp.children[0].name, "bar");
        assert_eq!(imp.children[1].name, "baz");

        assert_eq!(ti.tag, "impl");
        assert_eq!(ti.name, "Display for Foo");
        assert_eq!(
            ti.attrs.get("trait"),
            Some(&serde_json::Value::String("Display".to_string()))
        );
        assert_eq!(
            ti.attrs.get("type"),
            Some(&serde_json::Value::String("Foo".to_string()))
        );
        assert_eq!(ti.children.len(), 1);

        assert_eq!(tr.tag, "trait");
        assert_eq!(tr.name, "Resolver");
        assert_eq!(tr.attrs.get("unsafe"), None);

        assert_eq!(ut.tag, "trait");
        assert_eq!(ut.name, "Marker");
        assert_eq!(ut.attrs.get("unsafe"), Some(&serde_json::Value::Bool(true)));

        assert_eq!(st.tag, "static");
        assert_eq!(st.name, "COUNTER");
        assert_eq!(
            st.attrs.get("mutable"),
            Some(&serde_json::Value::Bool(true))
        );
    }
}