1use std::path::Path;
12
13use crate::{parser::Language, symbol_extraction::find_definitions};
14
/// A symbol location returned by [`resolve_all_symbols`].
///
/// Values come from `symbol_extraction::find_definitions`. Per the tests in
/// this module, line numbers are 1-based and inclusive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedSymbol {
    /// Name of the matched definition.
    pub name: String,
    /// First line of the definition (1-based).
    pub start_line: u32,
    /// Last line of the definition (1-based, inclusive).
    pub end_line: u32,
    /// Name of the enclosing container, if any (e.g. `Foo` for a method
    /// inside `impl Foo { .. }`, as exercised by the tests).
    pub parent_name: Option<String>,
}
27
/// Errors returned by the symbol-resolution entry points in this module.
#[derive(Debug, thiserror::Error)]
pub enum SymbolResolveError {
    /// No tree-sitter grammar is available for the file's language.
    ///
    /// Carries the file extension, or `"<none>"` when the path has no extension.
    #[error("unsupported file extension: {0}")]
    UnsupportedLanguage(String),

    /// The grammar could not be loaded into the parser, or parsing returned no tree.
    #[error("failed to parse source file")]
    ParseFailed,

    /// No definition matched the requested symbol.
    ///
    /// Carries the symbol exactly as the caller passed it, including any `::` qualifier.
    #[error("symbol not found: {0}")]
    SymbolNotFound(String),
}
40
/// Coarse, language-independent category of a [`Definition`].
///
/// The variant docs list the tree-sitter node kinds that `walk_definitions`
/// maps to each category.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DefinitionKind {
    /// `struct_item`, or a Go `type_spec` whose type is a `struct_type`.
    Type,
    /// `trait_item`.
    Trait,
    /// `class_definition`, `class_declaration`, `struct_specifier`, `class_specifier`.
    Class,
    /// `interface_declaration`, or a Go `type_spec` whose type is an `interface_type`.
    Interface,
    /// `type_item`, `type_alias_declaration`, or a Go `type_spec` with any
    /// other underlying type.
    TypeAlias,
    /// `enum_item`, `enum_declaration`, `enum_specifier`.
    EnumDef,
    /// `const_item`, `static_item`, or a `variable_declarator` whose value is
    /// not a function expression.
    ConstDecl,
    /// `mod_item`, `namespace_definition`.
    Module,
    /// Functions, methods and constructors, plus variable declarators bound
    /// to an `arrow_function`, `function` or `function_expression`.
    Function,
    /// Not produced by `walk_definitions`; presumably a catch-all for
    /// callers elsewhere. TODO confirm.
    Other,
}
68
/// A definition discovered by [`extract_definitions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Definition {
    /// The definition's name, taken from the node's `name` field (or the
    /// equivalent child for Go type specs and JS variable declarators).
    pub name: String,
    /// Category of the definition.
    pub kind: DefinitionKind,
    /// First line of the definition node (tree-sitter row + 1, i.e. 1-based).
    pub start_line: u32,
    /// Last line of the definition node (1-based, inclusive).
    pub end_line: u32,
    /// Name of the enclosing `impl` type, class, or Go receiver type, if any.
    pub parent_name: Option<String>,
}
84
85pub fn extract_definitions(
92 source: &[u8],
93 path: &Path,
94) -> Result<Vec<Definition>, SymbolResolveError> {
95 let language = Language::from_path(path).parser_handle().ok_or_else(|| {
96 SymbolResolveError::UnsupportedLanguage(
97 path.extension()
98 .map(|e| e.to_string_lossy().into_owned())
99 .unwrap_or_else(|| "<none>".to_string()),
100 )
101 })?;
102
103 let mut parser = tree_sitter::Parser::new();
104 parser
105 .set_language(&language)
106 .map_err(|_| SymbolResolveError::ParseFailed)?;
107
108 let tree = parser
109 .parse(source, None)
110 .ok_or(SymbolResolveError::ParseFailed)?;
111
112 let mut out = Vec::new();
113 walk_definitions(&tree.root_node(), source, None, &mut out);
114 Ok(out)
115}
116
117fn node_text<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> &'a str {
118 std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
119}
120
121fn push_named_definition(
122 node: &tree_sitter::Node,
123 source: &[u8],
124 dk: DefinitionKind,
125 parent: Option<&str>,
126 out: &mut Vec<Definition>,
127) {
128 if let Some(name_node) = node.child_by_field_name("name") {
129 let name = node_text(&name_node, source).to_string();
130 if name.is_empty() {
131 return;
132 }
133 out.push(Definition {
134 name,
135 kind: dk,
136 start_line: node.start_position().row as u32 + 1,
137 end_line: node.end_position().row as u32 + 1,
138 parent_name: parent.map(String::from),
139 });
140 }
141}
142
143fn walk_definitions(
144 node: &tree_sitter::Node,
145 source: &[u8],
146 current_parent: Option<&str>,
147 out: &mut Vec<Definition>,
148) {
149 let kind = node.kind();
150
151 match kind {
152 "function_item" => {
154 push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
155 }
156 "struct_item" => {
157 push_named_definition(node, source, DefinitionKind::Type, current_parent, out)
158 }
159 "enum_item" => {
160 push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
161 }
162 "trait_item" => {
163 push_named_definition(node, source, DefinitionKind::Trait, current_parent, out)
164 }
165 "type_item" => {
166 push_named_definition(node, source, DefinitionKind::TypeAlias, current_parent, out)
167 }
168 "const_item" | "static_item" => {
169 push_named_definition(node, source, DefinitionKind::ConstDecl, current_parent, out)
170 }
171 "mod_item" => {
172 push_named_definition(node, source, DefinitionKind::Module, current_parent, out)
173 }
174 "impl_item" => {
175 let parent_name = extract_rust_impl_type_name(node, source);
178 let parent = parent_name.as_deref();
179 let mut cursor = node.walk();
180 for child in node.children(&mut cursor) {
181 walk_definitions(&child, source, parent, out);
182 }
183 return;
184 }
185
186 "function_definition" => {
188 push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
189 }
190 "class_definition" => {
191 let class_name = node
192 .child_by_field_name("name")
193 .map(|n| node_text(&n, source).to_string());
194 if let Some(ref name) = class_name
195 && !name.is_empty()
196 {
197 out.push(Definition {
198 name: name.clone(),
199 kind: DefinitionKind::Class,
200 start_line: node.start_position().row as u32 + 1,
201 end_line: node.end_position().row as u32 + 1,
202 parent_name: current_parent.map(String::from),
203 });
204 }
205 let parent = class_name.as_deref();
206 let mut cursor = node.walk();
207 for child in node.children(&mut cursor) {
208 walk_definitions(&child, source, parent, out);
209 }
210 return;
211 }
212
213 "function_declaration" => {
215 push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
218 }
219 "method_declaration" => {
220 if let Some(name_node) = node.child_by_field_name("name") {
221 let name = node_text(&name_node, source).to_string();
222 if !name.is_empty() {
223 let receiver = extract_go_receiver_type(node, source);
224 out.push(Definition {
225 name,
226 kind: DefinitionKind::Function,
227 start_line: node.start_position().row as u32 + 1,
228 end_line: node.end_position().row as u32 + 1,
229 parent_name: receiver.or_else(|| current_parent.map(String::from)),
230 });
231 }
232 }
233 }
234 "type_declaration" => {
235 let mut cursor = node.walk();
237 for child in node.children(&mut cursor) {
238 if child.kind() == "type_spec"
239 && let Some(name_node) = child.child_by_field_name("name")
240 {
241 let name = node_text(&name_node, source).to_string();
242 if name.is_empty() {
243 continue;
244 }
245 let dk = match child.child_by_field_name("type").map(|t| t.kind()) {
246 Some("interface_type") => DefinitionKind::Interface,
247 Some("struct_type") => DefinitionKind::Type,
248 _ => DefinitionKind::TypeAlias,
249 };
250 out.push(Definition {
251 name,
252 kind: dk,
253 start_line: child.start_position().row as u32 + 1,
254 end_line: child.end_position().row as u32 + 1,
255 parent_name: current_parent.map(String::from),
256 });
257 }
258 }
259 }
260
261 "method_definition" => {
263 push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
264 }
265 "class_declaration" => {
266 let class_name = node
267 .child_by_field_name("name")
268 .map(|n| node_text(&n, source).to_string());
269 if let Some(ref name) = class_name
270 && !name.is_empty()
271 {
272 out.push(Definition {
273 name: name.clone(),
274 kind: DefinitionKind::Class,
275 start_line: node.start_position().row as u32 + 1,
276 end_line: node.end_position().row as u32 + 1,
277 parent_name: current_parent.map(String::from),
278 });
279 }
280 let parent = class_name.as_deref();
281 let mut cursor = node.walk();
282 for child in node.children(&mut cursor) {
283 walk_definitions(&child, source, parent, out);
284 }
285 return;
286 }
287 "interface_declaration" => {
288 push_named_definition(node, source, DefinitionKind::Interface, current_parent, out)
289 }
290 "type_alias_declaration" => {
291 push_named_definition(node, source, DefinitionKind::TypeAlias, current_parent, out)
292 }
293 "enum_declaration" => {
294 push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
295 }
296 "lexical_declaration" | "variable_declaration" => {
297 let mut cursor = node.walk();
299 for child in node.children(&mut cursor) {
300 if child.kind() == "variable_declarator"
301 && let Some(name_node) = child.child_by_field_name("name")
302 {
303 let name = node_text(&name_node, source).to_string();
304 if name.is_empty() {
305 continue;
306 }
307 if let Some(value_node) = child.child_by_field_name("value") {
308 let vkind = value_node.kind();
309 let dk = if vkind == "arrow_function"
310 || vkind == "function"
311 || vkind == "function_expression"
312 {
313 DefinitionKind::Function
314 } else {
315 DefinitionKind::ConstDecl
316 };
317 out.push(Definition {
318 name,
319 kind: dk,
320 start_line: node.start_position().row as u32 + 1,
321 end_line: node.end_position().row as u32 + 1,
322 parent_name: current_parent.map(String::from),
323 });
324 }
325 }
326 }
327 }
328
329 "struct_specifier" | "class_specifier" => {
331 push_named_definition(node, source, DefinitionKind::Class, current_parent, out)
332 }
333 "namespace_definition" => {
334 push_named_definition(node, source, DefinitionKind::Module, current_parent, out)
335 }
336 "enum_specifier" => {
337 push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
338 }
339 "constructor_declaration" => {
340 push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
341 }
342
343 _ => {}
344 }
345
346 let mut cursor = node.walk();
348 for child in node.children(&mut cursor) {
349 walk_definitions(&child, source, current_parent, out);
350 }
351}
352
353fn extract_rust_impl_type_name(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
354 let type_node = node.child_by_field_name("type")?;
355 Some(extract_type_identifier(&type_node, source))
356}
357
358fn extract_type_identifier(node: &tree_sitter::Node, source: &[u8]) -> String {
359 match node.kind() {
360 "type_identifier" | "identifier" => node_text(node, source).to_string(),
361 "generic_type" | "scoped_type_identifier" => {
362 let mut cursor = node.walk();
363 for child in node.children(&mut cursor) {
364 if child.kind() == "type_identifier" || child.kind() == "identifier" {
365 return node_text(&child, source).to_string();
366 }
367 }
368 node_text(node, source).to_string()
369 }
370 _ => node_text(node, source).to_string(),
371 }
372}
373
374fn extract_go_receiver_type(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
375 let params = node.child_by_field_name("receiver")?;
376 let mut cursor = params.walk();
377 for child in params.children(&mut cursor) {
378 if child.kind() == "parameter_declaration"
379 && let Some(type_node) = child.child_by_field_name("type")
380 {
381 let text = node_text(&type_node, source);
382 return Some(text.trim_start_matches('*').to_string());
383 }
384 }
385 None
386}
387
388pub fn resolve_symbol_lines(
396 source: &[u8],
397 path: &Path,
398 symbol: &str,
399) -> Result<(u32, u32), SymbolResolveError> {
400 let language = Language::from_path(path).parser_handle().ok_or_else(|| {
401 SymbolResolveError::UnsupportedLanguage(
402 path.extension()
403 .map(|e| e.to_string_lossy().into_owned())
404 .unwrap_or_else(|| "<none>".to_string()),
405 )
406 })?;
407
408 let mut parser = tree_sitter::Parser::new();
409 parser
410 .set_language(&language)
411 .map_err(|_| SymbolResolveError::ParseFailed)?;
412
413 let tree = parser
414 .parse(source, None)
415 .ok_or(SymbolResolveError::ParseFailed)?;
416
417 let (parent_filter, target_name) = if let Some(pos) = symbol.rfind("::") {
419 (Some(&symbol[..pos]), &symbol[pos + 2..])
420 } else {
421 (None, symbol)
422 };
423
424 let definitions = find_definitions(&tree.root_node(), source, target_name);
425
426 let matched = if let Some(parent) = parent_filter {
428 definitions
429 .iter()
430 .find(|d| {
431 d.parent_name
432 .as_deref()
433 .map(|p| p == parent)
434 .unwrap_or(false)
435 })
436 .or_else(|| definitions.first())
437 } else {
438 definitions.first()
439 };
440
441 match matched {
442 Some(sym) => Ok((sym.start_line, sym.end_line)),
443 None => Err(SymbolResolveError::SymbolNotFound(symbol.to_string())),
444 }
445}
446
447pub fn resolve_all_symbols(
452 source: &[u8],
453 path: &Path,
454 symbol: &str,
455) -> Result<Vec<ResolvedSymbol>, SymbolResolveError> {
456 let language = Language::from_path(path).parser_handle().ok_or_else(|| {
457 SymbolResolveError::UnsupportedLanguage(
458 path.extension()
459 .map(|e| e.to_string_lossy().into_owned())
460 .unwrap_or_else(|| "<none>".to_string()),
461 )
462 })?;
463
464 let mut parser = tree_sitter::Parser::new();
465 parser
466 .set_language(&language)
467 .map_err(|_| SymbolResolveError::ParseFailed)?;
468
469 let tree = parser
470 .parse(source, None)
471 .ok_or(SymbolResolveError::ParseFailed)?;
472
473 let (parent_filter, target_name) = if let Some(pos) = symbol.rfind("::") {
474 (Some(&symbol[..pos]), &symbol[pos + 2..])
475 } else {
476 (None, symbol)
477 };
478
479 let definitions = find_definitions(&tree.root_node(), source, target_name);
480
481 if let Some(parent) = parent_filter {
482 let filtered: Vec<_> = definitions
483 .into_iter()
484 .filter(|d| {
485 d.parent_name
486 .as_deref()
487 .map(|p| p == parent)
488 .unwrap_or(false)
489 })
490 .collect();
491 Ok(filtered)
492 } else {
493 Ok(definitions)
494 }
495}
496
/// Returns the bytes of lines `start..=end` of `source`.
///
/// Lines are 1-based, inclusive, and include their trailing `\n`. A `start`
/// of 0 behaves like 1. An empty vector is returned when `start > end` or the
/// range lies past the last line. Lines after `end` are never scanned.
pub fn extract_line_range(source: &[u8], start: u32, end: u32) -> Vec<u8> {
    let (first, last) = (start as usize, end as usize);
    source
        .split_inclusive(|&byte| byte == b'\n')
        .zip(1usize..)
        .take_while(|&(_, line_no)| line_no <= last)
        .filter(|&(_, line_no)| line_no >= first)
        .flat_map(|(line, _)| line.iter().copied())
        .collect()
}
521
#[cfg(test)]
mod tests {
    use super::*;

    // Expected line numbers are 1-based and inclusive. Most fixtures begin with
    // a newline right after `br#"`, so their line 1 is empty and the first code
    // line is line 2.

    #[test]
    fn resolve_rust_fn_main() {
        let source = br#"
fn helper() -> bool {
    true
}

fn main() {
    println!("hello");
    let x = 1;
}

fn after() {}
"#;
        let path = Path::new("test.rs");
        let (start, end) = resolve_symbol_lines(source, path, "main").unwrap();
        assert_eq!(start, 6);
        assert_eq!(end, 9);
    }

    // `Repository::open` must pick the method from the inherent `impl` block.
    #[test]
    fn resolve_rust_qualified_impl_method() {
        let source = br#"
struct Repository {
    path: String,
}

impl Repository {
    pub fn open(path: &str) -> Self {
        Repository {
            path: path.to_string(),
        }
    }

    pub fn close(&self) {}
}

impl Default for Repository {
    fn default() -> Self {
        Repository::open(".")
    }
}
"#;
        let path = Path::new("repo.rs");
        let (start, end) = resolve_symbol_lines(source, path, "Repository::open").unwrap();
        assert_eq!(start, 7);
        assert_eq!(end, 11);
    }

    #[test]
    fn resolve_rust_struct() {
        let source = br#"
pub struct Config {
    pub name: String,
    pub value: u32,
}
"#;
        let path = Path::new("config.rs");
        let (start, end) = resolve_symbol_lines(source, path, "Config").unwrap();
        assert_eq!(start, 2);
        assert_eq!(end, 5);
    }

    #[test]
    fn resolve_python_function() {
        let source = br#"
def helper():
    pass

def process_data(items):
    result = []
    for item in items:
        result.append(item * 2)
    return result

def cleanup():
    pass
"#;
        let path = Path::new("main.py");
        let (start, end) = resolve_symbol_lines(source, path, "process_data").unwrap();
        assert_eq!(start, 5);
        assert_eq!(end, 9);
    }

    #[test]
    fn resolve_python_class_method() {
        let source = br#"
class Repository:
    def __init__(self, path):
        self.path = path

    def open(self):
        return True
"#;
        let path = Path::new("repo.py");
        let (start, end) = resolve_symbol_lines(source, path, "Repository::open").unwrap();
        assert_eq!(start, 6);
        assert_eq!(end, 7);
    }

    // This fixture has no leading newline: `package main` is line 1.
    #[test]
    #[cfg(feature = "lang-go")]
    fn resolve_go_function() {
        let source = br#"package main

func helper() bool {
    return true
}

func processData(items []int) []int {
    result := make([]int, 0)
    for _, item := range items {
        result = append(result, item*2)
    }
    return result
}
"#;
        let path = Path::new("main.go");
        let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
        assert_eq!(start, 7);
        assert_eq!(end, 13);
    }

    #[test]
    fn resolve_symbol_not_found() {
        let source = br#"
fn main() {}
"#;
        let path = Path::new("test.rs");
        let err = resolve_symbol_lines(source, path, "nonexistent").unwrap_err();
        assert!(matches!(err, SymbolResolveError::SymbolNotFound(_)));
    }

    #[test]
    fn resolve_unsupported_extension() {
        let source = b"some content";
        let path = Path::new("test.xyz");
        let err = resolve_symbol_lines(source, path, "main").unwrap_err();
        assert!(matches!(err, SymbolResolveError::UnsupportedLanguage(_)));
    }

    #[test]
    fn extract_line_range_basic() {
        let source = b"line 1\nline 2\nline 3\nline 4\nline 5\n";
        let result = extract_line_range(source, 2, 4);
        assert_eq!(result, b"line 2\nline 3\nline 4\n");
    }

    #[test]
    fn extract_line_range_single_line() {
        let source = b"line 1\nline 2\nline 3\n";
        let result = extract_line_range(source, 2, 2);
        assert_eq!(result, b"line 2\n");
    }

    #[test]
    fn resolve_js_function_declaration() {
        let source = br#"
function helper() {
    return true;
}

function processData(items) {
    return items.map(x => x * 2);
}
"#;
        let path = Path::new("main.js");
        let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
        assert_eq!(start, 6);
        assert_eq!(end, 8);
    }

    // A `const` bound to an arrow function resolves to the whole declaration.
    #[test]
    fn resolve_js_arrow_function_const() {
        let source = br#"
const helper = () => true;

const processData = (items) => {
    return items.map(x => x * 2);
};
"#;
        let path = Path::new("utils.js");
        let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
        assert_eq!(start, 4);
        assert_eq!(end, 6);
    }

    // An arrow function stored as an object-literal property (`insert` on
    // line 6) must be resolvable. The bounds are loose so either the
    // property or its value node may be reported as the span.
    #[test]
    fn resolve_typescript_object_literal_property_arrow_function() {
        let source = br#"
export const db = {
    query: async (sql: string) => {
        return [];
    },
    insert: async (table: string, data: Record<string, any>) => {
        const keys = Object.keys(data);
        return keys;
    },
};
"#;
        let path = Path::new("db.ts");
        let (start, end) = resolve_symbol_lines(source, path, "insert").unwrap();
        assert!((5..=7).contains(&start), "got start={start}");
        assert!(end > start && end <= 10, "got end={end}");
    }

    #[test]
    fn resolve_typescript_function() {
        let source = br#"
function helper(): boolean {
    return true;
}

function processData(items: number[]): number[] {
    return items.map(x => x * 2);
}
"#;
        let path = Path::new("main.ts");
        let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
        assert_eq!(start, 6);
        assert_eq!(end, 8);
    }

    // An unqualified name defined in two `impl` blocks yields both matches,
    // in document order, each tagged with its `impl` type as parent.
    #[test]
    fn resolve_all_returns_multiple_matches() {
        let source = br#"
impl Foo {
    fn do_thing(&self) {}
}

impl Bar {
    fn do_thing(&self) {}
}
"#;
        let path = Path::new("test.rs");
        let results = resolve_all_symbols(source, path, "do_thing").unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].parent_name.as_deref(), Some("Foo"));
        assert_eq!(results[1].parent_name.as_deref(), Some("Bar"));
    }
}