1use crate::error::{ParseError, Result};
11use crate::node::{CodeNode, NodeKind};
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15use tree_sitter::{Language, Parser, Query, QueryCursor, Tree};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct SymbolRelation {
24 pub from_id: String,
26 pub to_name: String,
28 pub kind: RelationType,
30 pub line: u32,
32}
33
/// The kind of relationship recorded by a [`SymbolRelation`].
///
/// Only `Calls` and `Imports` are produced by the extractors visible in this
/// file; `Extends` and `Implements` are declared but not emitted here.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RelationType {
    /// The origin invokes the target as a function or method.
    Calls,
    /// The origin file imports / includes the target module.
    Imports,
    /// The origin type extends the target type.
    Extends,
    /// The origin type implements the target interface or trait.
    Implements,
}
46
47#[derive(Debug)]
49pub struct ParseResult {
50 pub symbols: Vec<CodeNode>,
52 pub relations: Vec<SymbolRelation>,
54 pub file_path: String,
56}
57
58pub struct ArborParser {
67 parser: Parser,
69 queries: HashMap<String, CompiledQueries>,
71}
72
73struct CompiledQueries {
75 symbols: Query,
77 imports: Query,
79 calls: Query,
81 language: Language,
83}
84
85impl Default for ArborParser {
86 fn default() -> Self {
87 Self::new().expect("Failed to initialize ArborParser")
88 }
89}
90
91impl ArborParser {
92 pub fn new() -> Result<Self> {
96 let parser = Parser::new();
97 let mut queries = HashMap::new();
98
99 for ext in &["ts", "tsx", "js", "jsx"] {
101 let compiled = Self::compile_typescript_queries()?;
103 queries.insert(ext.to_string(), compiled);
104 }
105
106 let rs_queries = Self::compile_rust_queries()?;
108 queries.insert("rs".to_string(), rs_queries);
109
110 let py_queries = Self::compile_python_queries()?;
112 queries.insert("py".to_string(), py_queries);
113
114 let go_queries = Self::compile_go_queries()?;
117 queries.insert("go".to_string(), go_queries);
118
119 let java_queries = Self::compile_java_queries()?;
121 queries.insert("java".to_string(), java_queries);
122
123 for ext in &["c", "h"] {
125 queries.insert(ext.to_string(), Self::compile_c_queries()?);
126 }
127
128 for ext in &["cpp", "hpp", "cc", "hh", "cxx"] {
130 queries.insert(ext.to_string(), Self::compile_cpp_queries()?);
131 }
132
133 let csharp_queries = Self::compile_csharp_queries()?;
139 queries.insert("cs".to_string(), csharp_queries);
140
141 Ok(Self { parser, queries })
142 }
143
144 pub fn parse_file(&mut self, path: &Path) -> Result<ParseResult> {
156 let source = fs::read_to_string(path).map_err(|e| ParseError::io(path, e))?;
158
159 if source.is_empty() {
160 return Err(ParseError::EmptyFile(path.to_path_buf()));
161 }
162
163 let ext = path
165 .extension()
166 .and_then(|e| e.to_str())
167 .ok_or_else(|| ParseError::UnsupportedLanguage(path.to_path_buf()))?;
168
169 let compiled = self
171 .queries
172 .get(ext)
173 .ok_or_else(|| ParseError::UnsupportedLanguage(path.to_path_buf()))?;
174
175 self.parser
177 .set_language(&compiled.language)
178 .map_err(|e| ParseError::ParserError(format!("Failed to set language: {}", e)))?;
179
180 let tree = self
182 .parser
183 .parse(&source, None)
184 .ok_or_else(|| ParseError::ParserError("Tree-sitter returned no tree".into()))?;
185
186 let file_path = path.to_string_lossy().to_string();
187 let file_name = path
188 .file_name()
189 .and_then(|n| n.to_str())
190 .unwrap_or("unknown");
191
192 let symbols = self.extract_symbols(&tree, &source, &file_path, file_name, compiled);
194
195 let relations = self.extract_relations(&tree, &source, &file_path, &symbols, compiled);
197
198 Ok(ParseResult {
199 symbols,
200 relations,
201 file_path,
202 })
203 }
204
205 pub fn parse_source(
207 &mut self,
208 source: &str,
209 file_path: &str,
210 language: &str,
211 ) -> Result<ParseResult> {
212 if source.is_empty() {
213 return Err(ParseError::EmptyFile(file_path.into()));
214 }
215
216 let compiled = self
217 .queries
218 .get(language)
219 .ok_or_else(|| ParseError::UnsupportedLanguage(file_path.into()))?;
220
221 self.parser
222 .set_language(&compiled.language)
223 .map_err(|e| ParseError::ParserError(format!("Failed to set language: {}", e)))?;
224
225 let tree = self
226 .parser
227 .parse(source, None)
228 .ok_or_else(|| ParseError::ParserError("Tree-sitter returned no tree".into()))?;
229
230 let file_name = Path::new(file_path)
231 .file_name()
232 .and_then(|n| n.to_str())
233 .unwrap_or("unknown");
234
235 let symbols = self.extract_symbols(&tree, source, file_path, file_name, compiled);
236 let relations = self.extract_relations(&tree, source, file_path, &symbols, compiled);
237
238 Ok(ParseResult {
239 symbols,
240 relations,
241 file_path: file_path.to_string(),
242 })
243 }
244
245 fn extract_symbols(
250 &self,
251 tree: &Tree,
252 source: &str,
253 file_path: &str,
254 file_name: &str,
255 compiled: &CompiledQueries,
256 ) -> Vec<CodeNode> {
257 let mut symbols = Vec::new();
258 let mut cursor = QueryCursor::new();
259
260 let matches = cursor.matches(&compiled.symbols, tree.root_node(), source.as_bytes());
261
262 for match_ in matches {
263 let mut name: Option<&str> = None;
265 let mut kind: Option<NodeKind> = None;
266 let mut node = match_.captures.first().map(|c| c.node);
267
268 for capture in match_.captures {
269 let capture_name = compiled.symbols.capture_names()[capture.index as usize];
270 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
271
272 match capture_name {
273 "name" | "function.name" | "class.name" | "interface.name" | "method.name" => {
274 name = Some(text);
275 }
276 "function" | "function_def" => {
277 kind = Some(NodeKind::Function);
278 node = Some(capture.node);
279 }
280 "class" | "class_def" => {
281 kind = Some(NodeKind::Class);
282 node = Some(capture.node);
283 }
284 "interface" | "interface_def" => {
285 kind = Some(NodeKind::Interface);
286 node = Some(capture.node);
287 }
288 "method" | "method_def" => {
289 kind = Some(NodeKind::Method);
290 node = Some(capture.node);
291 }
292 "struct" | "struct_def" => {
293 kind = Some(NodeKind::Struct);
294 node = Some(capture.node);
295 }
296 "enum" | "enum_def" => {
297 kind = Some(NodeKind::Enum);
298 node = Some(capture.node);
299 }
300 "trait" | "trait_def" => {
301 kind = Some(NodeKind::Interface);
302 node = Some(capture.node);
303 }
304 _ => {}
305 }
306 }
307
308 if let (Some(name), Some(kind), Some(node)) = (name, kind, node) {
309 let qualified_name = format!("{}:{}", file_name, name);
311
312 let signature = source
314 .lines()
315 .nth(node.start_position().row)
316 .map(|s| s.trim().to_string());
317
318 let mut symbol = CodeNode::new(name, &qualified_name, kind, file_path)
319 .with_lines(
320 node.start_position().row as u32 + 1,
321 node.end_position().row as u32 + 1,
322 )
323 .with_column(node.start_position().column as u32)
324 .with_bytes(node.start_byte() as u32, node.end_byte() as u32);
325
326 if let Some(sig) = signature {
327 symbol = symbol.with_signature(sig);
328 }
329
330 symbols.push(symbol);
331 }
332 }
333
334 symbols
335 }
336
337 fn extract_relations(
342 &self,
343 tree: &Tree,
344 source: &str,
345 file_path: &str,
346 symbols: &[CodeNode],
347 compiled: &CompiledQueries,
348 ) -> Vec<SymbolRelation> {
349 let mut relations = Vec::new();
350
351 self.extract_imports(tree, source, file_path, &mut relations, compiled);
353
354 self.extract_calls(tree, source, file_path, symbols, &mut relations, compiled);
356
357 relations
358 }
359
360 fn extract_imports(
361 &self,
362 tree: &Tree,
363 source: &str,
364 file_path: &str,
365 relations: &mut Vec<SymbolRelation>,
366 compiled: &CompiledQueries,
367 ) {
368 let mut cursor = QueryCursor::new();
369 let matches = cursor.matches(&compiled.imports, tree.root_node(), source.as_bytes());
370
371 for match_ in matches {
372 let mut module_name: Option<&str> = None;
373 let mut line: u32 = 0;
374
375 for capture in match_.captures {
376 let capture_name = compiled.imports.capture_names()[capture.index as usize];
377 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
378
379 match capture_name {
380 "source" | "module" | "import.source" => {
381 module_name = Some(text.trim_matches(|c| c == '"' || c == '\''));
383 line = capture.node.start_position().row as u32 + 1;
384 }
385 _ => {}
386 }
387 }
388
389 if let Some(module) = module_name {
390 let file_id = format!("{}:__file__", file_path);
392 relations.push(SymbolRelation {
393 from_id: file_id,
394 to_name: module.to_string(),
395 kind: RelationType::Imports,
396 line,
397 });
398 }
399 }
400 }
401
402 fn extract_calls(
403 &self,
404 tree: &Tree,
405 source: &str,
406 file_path: &str,
407 symbols: &[CodeNode],
408 relations: &mut Vec<SymbolRelation>,
409 compiled: &CompiledQueries,
410 ) {
411 let mut cursor = QueryCursor::new();
412 let matches = cursor.matches(&compiled.calls, tree.root_node(), source.as_bytes());
413
414 for match_ in matches {
415 let mut callee_name: Option<&str> = None;
416 let mut call_line: u32 = 0;
417
418 for capture in match_.captures {
419 let capture_name = compiled.calls.capture_names()[capture.index as usize];
420 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
421
422 match capture_name {
423 "callee" | "function" | "call.function" => {
424 if let Some(dot_pos) = text.rfind('.') {
426 callee_name = Some(&text[dot_pos + 1..]);
427 } else {
428 callee_name = Some(text);
429 }
430 call_line = capture.node.start_position().row as u32 + 1;
431 }
432 _ => {}
433 }
434 }
435
436 if let Some(callee) = callee_name {
437 let caller_id = self
439 .find_enclosing_symbol(call_line, symbols)
440 .map(|s| s.id.clone())
441 .unwrap_or_else(|| format!("{}:__file__", file_path));
442
443 relations.push(SymbolRelation {
444 from_id: caller_id,
445 to_name: callee.to_string(),
446 kind: RelationType::Calls,
447 line: call_line,
448 });
449 }
450 }
451 }
452
453 fn find_enclosing_symbol<'a>(
454 &self,
455 line: u32,
456 symbols: &'a [CodeNode],
457 ) -> Option<&'a CodeNode> {
458 symbols
459 .iter()
460 .filter(|s| s.line_start <= line && s.line_end >= line)
461 .min_by_key(|s| s.line_end - s.line_start) }
463
464 fn compile_typescript_queries() -> Result<CompiledQueries> {
469 let language = tree_sitter_typescript::language_typescript();
470
471 let symbols_query = r#"
473 (function_declaration name: (identifier) @name) @function_def
474 (class_declaration name: (type_identifier) @name) @class_def
475 (method_definition name: (property_identifier) @name) @method_def
476 (interface_declaration name: (type_identifier) @name) @interface_def
477 (type_alias_declaration name: (type_identifier) @name) @interface_def
478 "#;
479
480 let imports_query = r#"
482 (import_statement
483 source: (string) @source)
484 "#;
485
486 let calls_query = r#"
488 (call_expression
489 function: (identifier) @callee)
490
491 (call_expression
492 function: (member_expression
493 property: (property_identifier) @callee))
494 "#;
495
496 let symbols = Query::new(&language, symbols_query)
497 .map_err(|e| ParseError::QueryError(e.to_string()))?;
498 let imports = Query::new(&language, imports_query)
499 .map_err(|e| ParseError::QueryError(e.to_string()))?;
500 let calls = Query::new(&language, calls_query)
501 .map_err(|e| ParseError::QueryError(e.to_string()))?;
502
503 Ok(CompiledQueries {
504 symbols,
505 imports,
506 calls,
507 language,
508 })
509 }
510
511 fn compile_rust_queries() -> Result<CompiledQueries> {
512 let language = tree_sitter_rust::language();
513
514 let symbols_query = r#"
516 (function_item name: (identifier) @name) @function_def
517 (struct_item name: (type_identifier) @name) @struct_def
518 (enum_item name: (type_identifier) @name) @enum_def
519 (trait_item name: (type_identifier) @name) @trait_def
520 "#;
521
522 let imports_query = r#"
524 (use_declaration) @source
525 "#;
526
527 let calls_query = r#"
529 (call_expression function: (identifier) @callee)
530 (call_expression function: (field_expression field: (field_identifier) @callee))
531 "#;
532
533 let symbols = Query::new(&language, symbols_query)
534 .map_err(|e| ParseError::QueryError(e.to_string()))?;
535 let imports = Query::new(&language, imports_query)
536 .map_err(|e| ParseError::QueryError(e.to_string()))?;
537 let calls = Query::new(&language, calls_query)
538 .map_err(|e| ParseError::QueryError(e.to_string()))?;
539
540 Ok(CompiledQueries {
541 symbols,
542 imports,
543 calls,
544 language,
545 })
546 }
547
548 fn compile_python_queries() -> Result<CompiledQueries> {
549 let language = tree_sitter_python::language();
550
551 let symbols_query = r#"
553 (function_definition name: (identifier) @name) @function_def
554 (class_definition name: (identifier) @name) @class_def
555 "#;
556
557 let imports_query = r#"
559 (import_statement) @source
560 (import_from_statement) @source
561 "#;
562
563 let calls_query = r#"
565 (call function: (identifier) @callee)
566 (call function: (attribute attribute: (identifier) @callee))
567 "#;
568
569 let symbols = Query::new(&language, symbols_query)
570 .map_err(|e| ParseError::QueryError(e.to_string()))?;
571 let imports = Query::new(&language, imports_query)
572 .map_err(|e| ParseError::QueryError(e.to_string()))?;
573 let calls = Query::new(&language, calls_query)
574 .map_err(|e| ParseError::QueryError(e.to_string()))?;
575
576 Ok(CompiledQueries {
577 symbols,
578 imports,
579 calls,
580 language,
581 })
582 }
583
584 fn compile_go_queries() -> Result<CompiledQueries> {
585 let language = tree_sitter_go::language();
586
587 let symbols_query = r#"
588 (function_declaration name: (identifier) @name) @function_def
589 (method_declaration name: (field_identifier) @name) @method_def
590 (type_declaration (type_spec name: (type_identifier) @name type: (struct_type))) @struct_def
591 (type_declaration (type_spec name: (type_identifier) @name type: (interface_type))) @interface_def
592 "#;
593
594 let imports_query = r#"
595 (import_spec path: (interpreted_string_literal) @source)
596 "#;
597
598 let calls_query = r#"
599 (call_expression function: (identifier) @callee)
600 (call_expression function: (selector_expression field: (field_identifier) @callee))
601 "#;
602
603 let symbols = Query::new(&language, symbols_query)
604 .map_err(|e| ParseError::QueryError(e.to_string()))?;
605 let imports = Query::new(&language, imports_query)
606 .map_err(|e| ParseError::QueryError(e.to_string()))?;
607 let calls = Query::new(&language, calls_query)
608 .map_err(|e| ParseError::QueryError(e.to_string()))?;
609
610 Ok(CompiledQueries {
611 symbols,
612 imports,
613 calls,
614 language,
615 })
616 }
617
618 fn compile_java_queries() -> Result<CompiledQueries> {
619 let language = tree_sitter_java::language();
620
621 let symbols_query = r#"
622 (method_declaration name: (identifier) @name) @method_def
623 (class_declaration name: (identifier) @name) @class_def
624 (interface_declaration name: (identifier) @name) @interface_def
625 (constructor_declaration name: (identifier) @name) @function_def
626 "#;
627
628 let imports_query = r#"
629 (import_declaration) @source
630 "#;
631
632 let calls_query = r#"
633 (method_invocation name: (identifier) @callee)
634 "#;
635
636 let symbols = Query::new(&language, symbols_query)
637 .map_err(|e| ParseError::QueryError(e.to_string()))?;
638 let imports = Query::new(&language, imports_query)
639 .map_err(|e| ParseError::QueryError(e.to_string()))?;
640 let calls = Query::new(&language, calls_query)
641 .map_err(|e| ParseError::QueryError(e.to_string()))?;
642
643 Ok(CompiledQueries {
644 symbols,
645 imports,
646 calls,
647 language,
648 })
649 }
650
651 fn compile_c_queries() -> Result<CompiledQueries> {
652 let language = tree_sitter_c::language();
653
654 let symbols_query = r#"
655 (function_definition declarator: (function_declarator declarator: (identifier) @name)) @function_def
656 (struct_specifier name: (type_identifier) @name) @struct_def
657 (enum_specifier name: (type_identifier) @name) @enum_def
658 "#;
659
660 let imports_query = r#"
661 (preproc_include path: (string_literal) @source)
662 (preproc_include path: (system_lib_string) @source)
663 "#;
664
665 let calls_query = r#"
666 (call_expression function: (identifier) @callee)
667 "#;
668
669 let symbols = Query::new(&language, symbols_query)
670 .map_err(|e| ParseError::QueryError(e.to_string()))?;
671 let imports = Query::new(&language, imports_query)
672 .map_err(|e| ParseError::QueryError(e.to_string()))?;
673 let calls = Query::new(&language, calls_query)
674 .map_err(|e| ParseError::QueryError(e.to_string()))?;
675
676 Ok(CompiledQueries {
677 symbols,
678 imports,
679 calls,
680 language,
681 })
682 }
683
684 fn compile_cpp_queries() -> Result<CompiledQueries> {
685 let language = tree_sitter_cpp::language();
686
687 let symbols_query = r#"
688 (function_definition declarator: (function_declarator declarator: (identifier) @name)) @function_def
689 (function_definition declarator: (function_declarator declarator: (qualified_identifier name: (identifier) @name))) @method_def
690 (class_specifier name: (type_identifier) @name) @class_def
691 (struct_specifier name: (type_identifier) @name) @struct_def
692 "#;
693
694 let imports_query = r#"
695 (preproc_include path: (string_literal) @source)
696 (preproc_include path: (system_lib_string) @source)
697 "#;
698
699 let calls_query = r#"
700 (call_expression function: (identifier) @callee)
701 (call_expression function: (field_expression field: (field_identifier) @callee))
702 "#;
703
704 let symbols = Query::new(&language, symbols_query)
705 .map_err(|e| ParseError::QueryError(e.to_string()))?;
706 let imports = Query::new(&language, imports_query)
707 .map_err(|e| ParseError::QueryError(e.to_string()))?;
708 let calls = Query::new(&language, calls_query)
709 .map_err(|e| ParseError::QueryError(e.to_string()))?;
710
711 Ok(CompiledQueries {
712 symbols,
713 imports,
714 calls,
715 language,
716 })
717 }
718
719 fn compile_csharp_queries() -> Result<CompiledQueries> {
720 let language = tree_sitter_c_sharp::language();
721
722 let symbols_query = r#"
723 (method_declaration name: (identifier) @name) @method_def
724 (class_declaration name: (identifier) @name) @class_def
725 (interface_declaration name: (identifier) @name) @interface_def
726 (struct_declaration name: (identifier) @name) @struct_def
727 (constructor_declaration name: (identifier) @name) @function_def
728 (property_declaration name: (identifier) @name) @method_def
729 "#;
730
731 let imports_query = r#"
732 (using_directive (identifier) @source)
733 (using_directive (qualified_name) @source)
734 "#;
735
736 let calls_query = r#"
737 (invocation_expression function: (identifier) @callee)
738 (invocation_expression function: (member_access_expression name: (identifier) @callee))
739 "#;
740
741 let symbols = Query::new(&language, symbols_query)
742 .map_err(|e| ParseError::QueryError(e.to_string()))?;
743 let imports = Query::new(&language, imports_query)
744 .map_err(|e| ParseError::QueryError(e.to_string()))?;
745 let calls = Query::new(&language, calls_query)
746 .map_err(|e| ParseError::QueryError(e.to_string()))?;
747
748 Ok(CompiledQueries {
749 symbols,
750 imports,
751 calls,
752 language,
753 })
754 }
755
756 }
760
761#[cfg(test)]
766mod tests {
767 use super::*;
768
769 #[test]
770 fn test_parser_initialization() {
771 match ArborParser::new() {
773 Ok(_) => println!("Parser initialized successfully!"),
774 Err(e) => panic!("Parser failed to initialize: {}", e),
775 }
776 }
777
778 #[test]
779 fn test_parse_typescript_symbols() {
780 let mut parser = ArborParser::new().unwrap();
781
782 let source = r#"
783 function greet(name: string): string {
784 return `Hello, ${name}!`;
785 }
786
787 export class UserService {
788 validate(user: User): boolean {
789 return true;
790 }
791 }
792
793 interface User {
794 name: string;
795 email: string;
796 }
797 "#;
798
799 let result = parser.parse_source(source, "test.ts", "ts").unwrap();
800
801 assert!(result.symbols.iter().any(|s| s.name == "greet"));
802 assert!(result.symbols.iter().any(|s| s.name == "UserService"));
803 assert!(result.symbols.iter().any(|s| s.name == "validate"));
804 assert!(result.symbols.iter().any(|s| s.name == "User"));
805 }
806
807 #[test]
808 fn test_parse_typescript_imports() {
809 let mut parser = ArborParser::new().unwrap();
810
811 let source = r#"
812 import { useState } from 'react';
813 import lodash from 'lodash';
814
815 function Component() {
816 const [count, setCount] = useState(0);
817 }
818 "#;
819
820 let result = parser.parse_source(source, "test.ts", "ts").unwrap();
821
822 let imports: Vec<_> = result
823 .relations
824 .iter()
825 .filter(|r| r.kind == RelationType::Imports)
826 .collect();
827
828 assert!(imports.iter().any(|i| i.to_name.contains("react")));
829 assert!(imports.iter().any(|i| i.to_name.contains("lodash")));
830 }
831
832 #[test]
833 fn test_parse_typescript_calls() {
834 let mut parser = ArborParser::new().unwrap();
835
836 let source = r#"
837 function outer() {
838 inner();
839 helper.process();
840 }
841
842 function inner() {
843 console.log("Hello");
844 }
845 "#;
846
847 let result = parser.parse_source(source, "test.ts", "ts").unwrap();
848
849 let calls: Vec<_> = result
850 .relations
851 .iter()
852 .filter(|r| r.kind == RelationType::Calls)
853 .collect();
854
855 assert!(calls.iter().any(|c| c.to_name == "inner"));
856 assert!(calls.iter().any(|c| c.to_name == "process"));
857 assert!(calls.iter().any(|c| c.to_name == "log"));
858 }
859
860 #[test]
861 fn test_parse_rust_symbols() {
862 let mut parser = ArborParser::new().unwrap();
863
864 let source = r#"
865 fn main() {
866 println!("Hello!");
867 }
868
869 pub struct User {
870 name: String,
871 }
872
873 impl User {
874 fn new(name: &str) -> Self {
875 Self { name: name.to_string() }
876 }
877 }
878
879 enum Status {
880 Active,
881 Inactive,
882 }
883 "#;
884
885 let result = parser.parse_source(source, "test.rs", "rs").unwrap();
886
887 assert!(result.symbols.iter().any(|s| s.name == "main"));
888 assert!(result.symbols.iter().any(|s| s.name == "User"));
889 assert!(result.symbols.iter().any(|s| s.name == "new"));
890 assert!(result.symbols.iter().any(|s| s.name == "Status"));
891 }
892
893 #[test]
894 fn test_parse_python_symbols() {
895 let mut parser = ArborParser::new().unwrap();
896
897 let source = r#"
898def greet(name):
899 return f"Hello, {name}!"
900
901class UserService:
902 def validate(self, user):
903 return True
904 "#;
905
906 let result = parser.parse_source(source, "test.py", "py").unwrap();
907
908 assert!(result.symbols.iter().any(|s| s.name == "greet"));
909 assert!(result.symbols.iter().any(|s| s.name == "UserService"));
910 assert!(result.symbols.iter().any(|s| s.name == "validate"));
911 }
912}