1use crate::error::{ParseError, Result};
11use crate::node::{CodeNode, NodeKind};
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15use tree_sitter::{Language, Parser, Query, QueryCursor, Tree};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct SymbolRelation {
24 pub from_id: String,
26 pub to_name: String,
28 pub kind: RelationType,
30 pub line: u32,
32}
33
/// Categories of relationship that a [`SymbolRelation`] can describe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RelationType {
    /// A call expression that names a function or method.
    Calls,
    /// An import of another module or file.
    Imports,
    /// NOTE(review): presumably inheritance from a base type; confirm with
    /// the code that produces this variant.
    Extends,
    /// NOTE(review): presumably implementation of an interface or trait;
    /// confirm with the code that produces this variant.
    Implements,
}
46
47#[derive(Debug)]
49pub struct ParseResult {
50 pub symbols: Vec<CodeNode>,
52 pub relations: Vec<SymbolRelation>,
54 pub file_path: String,
56}
57
58pub struct ArborParser {
67 parser: Parser,
69 queries: HashMap<String, CompiledQueries>,
71}
72
73struct CompiledQueries {
75 symbols: Query,
77 imports: Query,
79 calls: Query,
81 language: Language,
83}
84
85impl Default for ArborParser {
86 fn default() -> Self {
87 Self::new().expect("Failed to initialize ArborParser")
88 }
89}
90
91impl ArborParser {
92 pub fn new() -> Result<Self> {
96 let parser = Parser::new();
97 let mut queries = HashMap::new();
98
99 for ext in &["ts", "tsx", "js", "jsx"] {
101 let compiled = Self::compile_typescript_queries()?;
103 queries.insert(ext.to_string(), compiled);
104 }
105
106 let rs_queries = Self::compile_rust_queries()?;
108 queries.insert("rs".to_string(), rs_queries);
109
110 let py_queries = Self::compile_python_queries()?;
112 queries.insert("py".to_string(), py_queries);
113
114 Ok(Self { parser, queries })
115 }
116
117 pub fn parse_file(&mut self, path: &Path) -> Result<ParseResult> {
129 let source = fs::read_to_string(path).map_err(|e| ParseError::io(path, e))?;
131
132 if source.is_empty() {
133 return Err(ParseError::EmptyFile(path.to_path_buf()));
134 }
135
136 let ext = path
138 .extension()
139 .and_then(|e| e.to_str())
140 .ok_or_else(|| ParseError::UnsupportedLanguage(path.to_path_buf()))?;
141
142 let compiled = self
144 .queries
145 .get(ext)
146 .ok_or_else(|| ParseError::UnsupportedLanguage(path.to_path_buf()))?;
147
148 self.parser
150 .set_language(compiled.language)
151 .map_err(|e| ParseError::ParserError(format!("Failed to set language: {}", e)))?;
152
153 let tree = self
155 .parser
156 .parse(&source, None)
157 .ok_or_else(|| ParseError::ParserError("Tree-sitter returned no tree".into()))?;
158
159 let file_path = path.to_string_lossy().to_string();
160 let file_name = path
161 .file_name()
162 .and_then(|n| n.to_str())
163 .unwrap_or("unknown");
164
165 let symbols = self.extract_symbols(&tree, &source, &file_path, file_name, compiled);
167
168 let relations = self.extract_relations(&tree, &source, &file_path, &symbols, compiled);
170
171 Ok(ParseResult {
172 symbols,
173 relations,
174 file_path,
175 })
176 }
177
178 pub fn parse_source(
180 &mut self,
181 source: &str,
182 file_path: &str,
183 language: &str,
184 ) -> Result<ParseResult> {
185 if source.is_empty() {
186 return Err(ParseError::EmptyFile(file_path.into()));
187 }
188
189 let compiled = self
190 .queries
191 .get(language)
192 .ok_or_else(|| ParseError::UnsupportedLanguage(file_path.into()))?;
193
194 self.parser
195 .set_language(compiled.language)
196 .map_err(|e| ParseError::ParserError(format!("Failed to set language: {}", e)))?;
197
198 let tree = self
199 .parser
200 .parse(source, None)
201 .ok_or_else(|| ParseError::ParserError("Tree-sitter returned no tree".into()))?;
202
203 let file_name = Path::new(file_path)
204 .file_name()
205 .and_then(|n| n.to_str())
206 .unwrap_or("unknown");
207
208 let symbols = self.extract_symbols(&tree, source, file_path, file_name, compiled);
209 let relations = self.extract_relations(&tree, source, file_path, &symbols, compiled);
210
211 Ok(ParseResult {
212 symbols,
213 relations,
214 file_path: file_path.to_string(),
215 })
216 }
217
218 fn extract_symbols(
223 &self,
224 tree: &Tree,
225 source: &str,
226 file_path: &str,
227 file_name: &str,
228 compiled: &CompiledQueries,
229 ) -> Vec<CodeNode> {
230 let mut symbols = Vec::new();
231 let mut cursor = QueryCursor::new();
232
233 let matches = cursor.matches(&compiled.symbols, tree.root_node(), source.as_bytes());
234
235 for match_ in matches {
236 let mut name: Option<&str> = None;
238 let mut kind: Option<NodeKind> = None;
239 let mut node = match_.captures.first().map(|c| c.node);
240
241 for capture in match_.captures {
242 let capture_name = &compiled.symbols.capture_names()[capture.index as usize];
243 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
244
245 match capture_name.as_str() {
246 "name" | "function.name" | "class.name" | "interface.name" | "method.name" => {
247 name = Some(text);
248 }
249 "function" | "function_def" => {
250 kind = Some(NodeKind::Function);
251 node = Some(capture.node);
252 }
253 "class" | "class_def" => {
254 kind = Some(NodeKind::Class);
255 node = Some(capture.node);
256 }
257 "interface" | "interface_def" => {
258 kind = Some(NodeKind::Interface);
259 node = Some(capture.node);
260 }
261 "method" | "method_def" => {
262 kind = Some(NodeKind::Method);
263 node = Some(capture.node);
264 }
265 "struct" | "struct_def" => {
266 kind = Some(NodeKind::Struct);
267 node = Some(capture.node);
268 }
269 "enum" | "enum_def" => {
270 kind = Some(NodeKind::Enum);
271 node = Some(capture.node);
272 }
273 "trait" | "trait_def" => {
274 kind = Some(NodeKind::Interface);
275 node = Some(capture.node);
276 }
277 _ => {}
278 }
279 }
280
281 if let (Some(name), Some(kind), Some(node)) = (name, kind, node) {
282 let qualified_name = format!("{}:{}", file_name, name);
284
285 let signature = source
287 .lines()
288 .nth(node.start_position().row)
289 .map(|s| s.trim().to_string());
290
291 let mut symbol = CodeNode::new(name, &qualified_name, kind, file_path)
292 .with_lines(
293 node.start_position().row as u32 + 1,
294 node.end_position().row as u32 + 1,
295 )
296 .with_column(node.start_position().column as u32)
297 .with_bytes(node.start_byte() as u32, node.end_byte() as u32);
298
299 if let Some(sig) = signature {
300 symbol = symbol.with_signature(sig);
301 }
302
303 symbols.push(symbol);
304 }
305 }
306
307 symbols
308 }
309
310 fn extract_relations(
315 &self,
316 tree: &Tree,
317 source: &str,
318 file_path: &str,
319 symbols: &[CodeNode],
320 compiled: &CompiledQueries,
321 ) -> Vec<SymbolRelation> {
322 let mut relations = Vec::new();
323
324 self.extract_imports(tree, source, file_path, &mut relations, compiled);
326
327 self.extract_calls(tree, source, file_path, symbols, &mut relations, compiled);
329
330 relations
331 }
332
333 fn extract_imports(
334 &self,
335 tree: &Tree,
336 source: &str,
337 file_path: &str,
338 relations: &mut Vec<SymbolRelation>,
339 compiled: &CompiledQueries,
340 ) {
341 let mut cursor = QueryCursor::new();
342 let matches = cursor.matches(&compiled.imports, tree.root_node(), source.as_bytes());
343
344 for match_ in matches {
345 let mut module_name: Option<&str> = None;
346 let mut line: u32 = 0;
347
348 for capture in match_.captures {
349 let capture_name = &compiled.imports.capture_names()[capture.index as usize];
350 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
351
352 match capture_name.as_str() {
353 "source" | "module" | "import.source" => {
354 module_name = Some(text.trim_matches(|c| c == '"' || c == '\''));
356 line = capture.node.start_position().row as u32 + 1;
357 }
358 _ => {}
359 }
360 }
361
362 if let Some(module) = module_name {
363 let file_id = format!("{}:__file__", file_path);
365 relations.push(SymbolRelation {
366 from_id: file_id,
367 to_name: module.to_string(),
368 kind: RelationType::Imports,
369 line,
370 });
371 }
372 }
373 }
374
375 fn extract_calls(
376 &self,
377 tree: &Tree,
378 source: &str,
379 file_path: &str,
380 symbols: &[CodeNode],
381 relations: &mut Vec<SymbolRelation>,
382 compiled: &CompiledQueries,
383 ) {
384 let mut cursor = QueryCursor::new();
385 let matches = cursor.matches(&compiled.calls, tree.root_node(), source.as_bytes());
386
387 for match_ in matches {
388 let mut callee_name: Option<&str> = None;
389 let mut call_line: u32 = 0;
390
391 for capture in match_.captures {
392 let capture_name = &compiled.calls.capture_names()[capture.index as usize];
393 let text = capture.node.utf8_text(source.as_bytes()).unwrap_or("");
394
395 match capture_name.as_str() {
396 "callee" | "function" | "call.function" => {
397 if let Some(dot_pos) = text.rfind('.') {
399 callee_name = Some(&text[dot_pos + 1..]);
400 } else {
401 callee_name = Some(text);
402 }
403 call_line = capture.node.start_position().row as u32 + 1;
404 }
405 _ => {}
406 }
407 }
408
409 if let Some(callee) = callee_name {
410 let caller_id = self
412 .find_enclosing_symbol(call_line, symbols)
413 .map(|s| s.id.clone())
414 .unwrap_or_else(|| format!("{}:__file__", file_path));
415
416 relations.push(SymbolRelation {
417 from_id: caller_id,
418 to_name: callee.to_string(),
419 kind: RelationType::Calls,
420 line: call_line,
421 });
422 }
423 }
424 }
425
426 fn find_enclosing_symbol<'a>(
427 &self,
428 line: u32,
429 symbols: &'a [CodeNode],
430 ) -> Option<&'a CodeNode> {
431 symbols
432 .iter()
433 .filter(|s| s.line_start <= line && s.line_end >= line)
434 .min_by_key(|s| s.line_end - s.line_start) }
436
437 fn compile_typescript_queries() -> Result<CompiledQueries> {
442 let language = tree_sitter_typescript::language_typescript();
443
444 let symbols_query = r#"
446 (function_declaration name: (identifier) @name) @function_def
447 (class_declaration name: (type_identifier) @name) @class_def
448 (method_definition name: (property_identifier) @name) @method_def
449 (interface_declaration name: (type_identifier) @name) @interface_def
450 (type_alias_declaration name: (type_identifier) @name) @interface_def
451 "#;
452
453 let imports_query = r#"
455 (import_statement
456 source: (string) @source)
457 "#;
458
459 let calls_query = r#"
461 (call_expression
462 function: (identifier) @callee)
463
464 (call_expression
465 function: (member_expression
466 property: (property_identifier) @callee))
467 "#;
468
469 let symbols =
470 Query::new(language, symbols_query).map_err(|e| ParseError::QueryError(e.message))?;
471 let imports =
472 Query::new(language, imports_query).map_err(|e| ParseError::QueryError(e.message))?;
473 let calls =
474 Query::new(language, calls_query).map_err(|e| ParseError::QueryError(e.message))?;
475
476 Ok(CompiledQueries {
477 symbols,
478 imports,
479 calls,
480 language,
481 })
482 }
483
484 fn compile_rust_queries() -> Result<CompiledQueries> {
485 let language = tree_sitter_rust::language();
486
487 let symbols_query = r#"
489 (function_item name: (identifier) @name) @function_def
490 (struct_item name: (type_identifier) @name) @struct_def
491 (enum_item name: (type_identifier) @name) @enum_def
492 (trait_item name: (type_identifier) @name) @trait_def
493 "#;
494
495 let imports_query = r#"
497 (use_declaration) @source
498 "#;
499
500 let calls_query = r#"
502 (call_expression function: (identifier) @callee)
503 (call_expression function: (field_expression field: (field_identifier) @callee))
504 "#;
505
506 let symbols =
507 Query::new(language, symbols_query).map_err(|e| ParseError::QueryError(e.message))?;
508 let imports =
509 Query::new(language, imports_query).map_err(|e| ParseError::QueryError(e.message))?;
510 let calls =
511 Query::new(language, calls_query).map_err(|e| ParseError::QueryError(e.message))?;
512
513 Ok(CompiledQueries {
514 symbols,
515 imports,
516 calls,
517 language,
518 })
519 }
520
521 fn compile_python_queries() -> Result<CompiledQueries> {
522 let language = tree_sitter_python::language();
523
524 let symbols_query = r#"
526 (function_definition name: (identifier) @name) @function_def
527 (class_definition name: (identifier) @name) @class_def
528 "#;
529
530 let imports_query = r#"
532 (import_statement) @source
533 (import_from_statement) @source
534 "#;
535
536 let calls_query = r#"
538 (call function: (identifier) @callee)
539 (call function: (attribute attribute: (identifier) @callee))
540 "#;
541
542 let symbols =
543 Query::new(language, symbols_query).map_err(|e| ParseError::QueryError(e.message))?;
544 let imports =
545 Query::new(language, imports_query).map_err(|e| ParseError::QueryError(e.message))?;
546 let calls =
547 Query::new(language, calls_query).map_err(|e| ParseError::QueryError(e.message))?;
548
549 Ok(CompiledQueries {
550 symbols,
551 imports,
552 calls,
553 language,
554 })
555 }
556}
557
558#[cfg(test)]
563mod tests {
564 use super::*;
565
/// Constructing the parser must succeed, i.e. every built-in query compiles.
#[test]
fn test_parser_initialization() {
    if let Err(e) = ArborParser::new() {
        panic!("Parser failed to initialize: {}", e);
    }
    println!("Parser initialized successfully!");
}
574
/// Functions, classes, methods and interfaces are all reported for TypeScript.
#[test]
fn test_parse_typescript_symbols() {
    let source = r#"
        function greet(name: string): string {
            return `Hello, ${name}!`;
        }

        export class UserService {
            validate(user: User): boolean {
                return true;
            }
        }

        interface User {
            name: string;
            email: string;
        }
    "#;

    let mut parser = ArborParser::new().unwrap();
    let result = parser.parse_source(source, "test.ts", "ts").unwrap();

    let has_symbol = |wanted: &str| result.symbols.iter().any(|s| s.name == wanted);

    assert!(has_symbol("greet"));
    assert!(has_symbol("UserService"));
    assert!(has_symbol("validate"));
    assert!(has_symbol("User"));
}
603
/// Both named and default imports produce an `Imports` relation.
#[test]
fn test_parse_typescript_imports() {
    let source = r#"
        import { useState } from 'react';
        import lodash from 'lodash';

        function Component() {
            const [count, setCount] = useState(0);
        }
    "#;

    let mut parser = ArborParser::new().unwrap();
    let result = parser.parse_source(source, "test.ts", "ts").unwrap();

    let imported = |needle: &str| {
        result
            .relations
            .iter()
            .filter(|r| r.kind == RelationType::Imports)
            .any(|r| r.to_name.contains(needle))
    };

    assert!(imported("react"));
    assert!(imported("lodash"));
}
628
/// Plain calls and member calls are reported by their final name segment.
#[test]
fn test_parse_typescript_calls() {
    let source = r#"
        function outer() {
            inner();
            helper.process();
        }

        function inner() {
            console.log("Hello");
        }
    "#;

    let mut parser = ArborParser::new().unwrap();
    let result = parser.parse_source(source, "test.ts", "ts").unwrap();

    let called = |wanted: &str| {
        result
            .relations
            .iter()
            .filter(|r| r.kind == RelationType::Calls)
            .any(|r| r.to_name == wanted)
    };

    assert!(called("inner"));
    assert!(called("process"));
    assert!(called("log"));
}
656
/// Free functions, structs, methods inside `impl` blocks and enums are all
/// reported for Rust.
#[test]
fn test_parse_rust_symbols() {
    let source = r#"
        fn main() {
            println!("Hello!");
        }

        pub struct User {
            name: String,
        }

        impl User {
            fn new(name: &str) -> Self {
                Self { name: name.to_string() }
            }
        }

        enum Status {
            Active,
            Inactive,
        }
    "#;

    let mut parser = ArborParser::new().unwrap();
    let result = parser.parse_source(source, "test.rs", "rs").unwrap();

    let has_symbol = |wanted: &str| result.symbols.iter().any(|s| s.name == wanted);

    assert!(has_symbol("main"));
    assert!(has_symbol("User"));
    assert!(has_symbol("new"));
    assert!(has_symbol("Status"));
}
689
/// Functions, classes and methods are all reported for Python.
#[test]
fn test_parse_python_symbols() {
    let mut parser = ArborParser::new().unwrap();

    // Python is indentation-sensitive: the fixture starts at column 0 and each
    // body is indented one level deeper than its `def` / `class` header so the
    // snippet is valid Python.
    let source = r#"
def greet(name):
    return f"Hello, {name}!"

class UserService:
    def validate(self, user):
        return True
"#;

    let result = parser.parse_source(source, "test.py", "py").unwrap();

    assert!(result.symbols.iter().any(|s| s.name == "greet"));
    assert!(result.symbols.iter().any(|s| s.name == "UserService"));
    assert!(result.symbols.iter().any(|s| s.name == "validate"));
}
709}