probe_code/extract/symbol_finder.rs
//! Functions for finding symbols in files.
//!
//! This module provides functions for finding symbols (functions, structs, classes, etc.)
//! in files using tree-sitter, falling back to a plain-text line search when the
//! syntax tree yields no match.
5
6use anyhow::Result;
7use probe_code::models::SearchResult;
8use std::path::Path;
9
10/// Find a symbol (function, struct, class, etc.) in a file by name
11///
12/// This function searches for a symbol by name in a file and returns the code block
13/// containing that symbol. It uses tree-sitter to parse the code and find the symbol.
14///
15/// # Arguments
16///
17/// * `path` - The path to the file to search in
18/// * `symbol` - The name of the symbol to find
19/// * `content` - The content of the file
20/// * `allow_tests` - Whether to include test files and test code blocks
21/// * `context_lines` - Number of context lines to include
22///
23/// # Returns
24///
25/// A SearchResult containing the extracted code block for the symbol, or an error
26/// if the symbol couldn't be found.
27pub fn find_symbol_in_file(
28 path: &Path,
29 symbol: &str,
30 content: &str,
31 _allow_tests: bool,
32 context_lines: usize,
33) -> Result<SearchResult> {
34 let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
35
36 // Check if the symbol contains a dot, indicating a nested symbol path
37 let symbol_parts: Vec<&str> = symbol.split('.').collect();
38 let is_nested_symbol = symbol_parts.len() > 1;
39
40 // For nested symbols, we'll use the AST-based approach directly
41 // The find_symbol_node function already handles nested symbols
42
43 if debug_mode {
44 println!("\n[DEBUG] ===== Symbol Search =====");
45 if is_nested_symbol {
46 println!("[DEBUG] Searching for nested symbol '{symbol}' in file {path:?}");
47 println!(
48 "[DEBUG] Symbol parts: {:?} (parent: '{}', child: '{}')",
49 symbol_parts,
50 symbol_parts[0],
51 symbol_parts.last().unwrap_or(&"")
52 );
53 } else {
54 println!("[DEBUG] Searching for symbol '{symbol}' in file {path:?}");
55 }
56 println!(
57 "[DEBUG] Content size: {content_len} bytes",
58 content_len = content.len()
59 );
60 println!(
61 "[DEBUG] Line count: {line_count}",
62 line_count = content.lines().count()
63 );
64 }
65
66 // Get the file extension to determine the language
67 let extension = path.extension().and_then(|ext| ext.to_str()).unwrap_or("");
68
69 if debug_mode {
70 println!("[DEBUG] File extension: {extension}");
71 }
72
73 // Get the language implementation for this extension
74 let language_impl = crate::language::factory::get_language_impl(extension)
75 .ok_or_else(|| anyhow::anyhow!("Unsupported language extension: {}", extension))?;
76
77 if debug_mode {
78 println!("[DEBUG] Language detected: {extension}");
79 println!("[DEBUG] Using tree-sitter to parse file");
80 }
81
82 // Parse the file with tree-sitter
83 let mut parser = tree_sitter::Parser::new();
84 parser
85 .set_language(&language_impl.get_tree_sitter_language())
86 .map_err(|e| anyhow::anyhow!("Failed to set language: {}", e))?;
87
88 let tree = parser
89 .parse(content.as_bytes(), None)
90 .ok_or_else(|| anyhow::anyhow!("Failed to parse file"))?;
91
92 let root_node = tree.root_node();
93
94 if debug_mode {
95 println!("[DEBUG] File parsed successfully");
96 println!(
97 "[DEBUG] Root node type: {root_node_kind}",
98 root_node_kind = root_node.kind()
99 );
100 println!(
101 "[DEBUG] Root node range: {}:{} - {}:{}",
102 root_node.start_position().row + 1,
103 root_node.start_position().column + 1,
104 root_node.end_position().row + 1,
105 root_node.end_position().column + 1
106 );
107 println!("[DEBUG] Searching for symbol '{symbol}' in AST");
108 }
109
110 // Function to recursively search for a node with the given symbol name
111 fn find_symbol_node<'a>(
112 node: tree_sitter::Node<'a>,
113 symbol_parts: &[&str],
114 language_impl: &dyn crate::language::language_trait::LanguageImpl,
115 content: &'a [u8],
116 debug_mode: bool,
117 ) -> Option<tree_sitter::Node<'a>> {
118 // If we're looking for a nested symbol (e.g., "Class.method"), we need to:
119 // 1. First find the parent symbol (e.g., "Class")
120 // 2. Then search within that node for the child symbol (e.g., "method")
121 let current_symbol = symbol_parts[0];
122 let is_nested = symbol_parts.len() > 1;
123
124 // Check if this node is an acceptable parent (function, struct, class, etc.)
125 if language_impl.is_acceptable_parent(&node) {
126 if debug_mode {
127 println!(
128 "[DEBUG] Checking node type '{}' at {}:{} for symbol '{}'",
129 node.kind(),
130 node.start_position().row + 1,
131 node.start_position().column + 1,
132 current_symbol
133 );
134 }
135
136 // Try to extract the name of this node
137 let mut cursor = node.walk();
138 for child in node.children(&mut cursor) {
139 if child.kind() == "identifier"
140 || child.kind() == "field_identifier"
141 || child.kind() == "type_identifier"
142 || child.kind() == "function_declarator"
143 {
144 // Get the text of this identifier
145 if let Ok(name) = child.utf8_text(content) {
146 if debug_mode {
147 println!(
148 "[DEBUG] Found identifier: '{name}' (looking for '{current_symbol}')"
149 );
150 }
151
152 if name == current_symbol {
153 if is_nested {
154 // If this is a nested symbol, we found the parent
155 // Now we need to search for the child within this node
156 if debug_mode {
157 println!(
158 "[DEBUG] Found parent symbol '{}' in node type '{}', now searching for child '{}'",
159 current_symbol,
160 node.kind(),
161 symbol_parts[1]
162 );
163 }
164
165 // First, check if there's a direct method definition with this name
166 let mut direct_method_cursor = node.walk();
167 for direct_child in node.children(&mut direct_method_cursor) {
168 if direct_child.kind() == "method_definition" {
169 let mut method_cursor = direct_child.walk();
170 for method_child in
171 direct_child.children(&mut method_cursor)
172 {
173 if method_child.kind() == "property_identifier" {
174 if let Ok(method_name) =
175 method_child.utf8_text(content)
176 {
177 if debug_mode {
178 println!(
179 "[DEBUG] Found direct method: '{}' (looking for '{}')",
180 method_name, symbol_parts[1]
181 );
182 }
183
184 if method_name == symbol_parts[1] {
185 if debug_mode {
186 println!(
187 "[DEBUG] Found child symbol '{}' as direct method_definition",
188 symbol_parts[1]
189 );
190 println!(
191 "[DEBUG] Symbol location: {}:{} - {}:{}",
192 direct_child.start_position().row + 1,
193 direct_child.start_position().column + 1,
194 direct_child.end_position().row + 1,
195 direct_child.end_position().column + 1
196 );
197 }
198 return Some(direct_child);
199 }
200 }
201 }
202 }
203 }
204 }
205
206 // Look for any node that might contain the child symbol
207 let mut child_cursor = node.walk();
208 for child_node in node.children(&mut child_cursor) {
209 if debug_mode {
210 println!(
211 "[DEBUG] Checking child node type '{}' for symbol '{}'",
212 child_node.kind(),
213 symbol_parts[1]
214 );
215 }
216
217 // Check if this node is the child symbol we're looking for
218 if language_impl.is_acceptable_parent(&child_node) {
219 // Try to extract the name of this node
220 let mut subcursor = child_node.walk();
221 for subchild in child_node.children(&mut subcursor) {
222 if subchild.kind() == "identifier"
223 || subchild.kind() == "property_identifier"
224 || subchild.kind() == "field_identifier"
225 || subchild.kind() == "type_identifier"
226 || subchild.kind() == "method_definition"
227 {
228 // For method_definition, we need to get the property_identifier child
229 if subchild.kind() == "method_definition" {
230 let mut method_cursor = subchild.walk();
231 for method_child in
232 subchild.children(&mut method_cursor)
233 {
234 if method_child.kind()
235 == "property_identifier"
236 {
237 if let Ok(method_name) =
238 method_child.utf8_text(content)
239 {
240 if debug_mode {
241 println!(
242 "[DEBUG] Found method: '{}' (looking for '{}')",
243 method_name, symbol_parts[1]
244 );
245 }
246
247 if method_name == symbol_parts[1] {
248 if debug_mode {
249 println!(
250 "[DEBUG] Found child symbol '{}' in method_definition",
251 symbol_parts[1]
252 );
253 println!(
254 "[DEBUG] Symbol location: {}:{} - {}:{}",
255 subchild.start_position().row + 1,
256 subchild.start_position().column + 1,
257 subchild.end_position().row + 1,
258 subchild.end_position().column + 1
259 );
260 }
261 return Some(subchild);
262 }
263 }
264 }
265 }
266 continue;
267 }
268 if let Ok(name) = subchild.utf8_text(content) {
269 if debug_mode {
270 println!(
271 "[DEBUG] Found identifier: '{}' (looking for '{}')",
272 name, symbol_parts[1]
273 );
274 }
275
276 if name == symbol_parts[1] {
277 if debug_mode {
278 println!(
279 "[DEBUG] Found child symbol '{}' in node type '{}'",
280 symbol_parts[1],
281 child_node.kind()
282 );
283 println!(
284 "[DEBUG] Symbol location: {}:{} - {}:{}",
285 child_node.start_position().row + 1,
286 child_node.start_position().column + 1,
287 child_node.end_position().row + 1,
288 child_node.end_position().column + 1
289 );
290 }
291 return Some(child_node);
292 }
293 }
294 }
295 }
296 }
297
298 // Recursively search in this child node
299 if let Some(found) = find_symbol_node(
300 child_node,
301 &symbol_parts[1..],
302 language_impl,
303 content,
304 debug_mode,
305 ) {
306 return Some(found);
307 }
308 }
309 } else {
310 // If this is a simple symbol, we found it
311 if debug_mode {
312 println!(
313 "[DEBUG] Found symbol '{}' in node type '{}'",
314 current_symbol,
315 node.kind()
316 );
317 println!(
318 "[DEBUG] Symbol location: {}:{} - {}:{}",
319 node.start_position().row + 1,
320 node.start_position().column + 1,
321 node.end_position().row + 1,
322 node.end_position().column + 1
323 );
324 }
325 return Some(node);
326 }
327 }
328 }
329
330 // For function_declarator, we need to look deeper
331 if child.kind() == "function_declarator" {
332 if debug_mode {
333 println!("[DEBUG] Checking function_declarator for symbol");
334 }
335
336 let mut subcursor = child.walk();
337 for subchild in child.children(&mut subcursor) {
338 if subchild.kind() == "identifier" {
339 if let Ok(name) = subchild.utf8_text(content) {
340 if debug_mode {
341 println!("[DEBUG] Found function identifier: '{name}' (looking for '{current_symbol}')");
342 }
343
344 if name == current_symbol {
345 if is_nested {
346 // If this is a nested symbol, we found the parent
347 // Now we need to search for the child within this node
348 if debug_mode {
349 println!(
350 "[DEBUG] Found parent symbol '{}' in function_declarator, now searching for child '{}'",
351 current_symbol,
352 symbol_parts[1]
353 );
354 }
355
356 // Recursively search for the child symbol within this node
357 let mut child_cursor = node.walk();
358 for child_node in node.children(&mut child_cursor) {
359 if let Some(found) = find_symbol_node(
360 child_node,
361 &symbol_parts[1..],
362 language_impl,
363 content,
364 debug_mode,
365 ) {
366 return Some(found);
367 }
368 }
369 } else {
370 // If this is a simple symbol, we found it
371 if debug_mode {
372 println!(
373 "[DEBUG] Found symbol '{current_symbol}' in function_declarator"
374 );
375 println!(
376 "[DEBUG] Symbol location: {}:{} - {}:{}",
377 node.start_position().row + 1,
378 node.start_position().column + 1,
379 node.end_position().row + 1,
380 node.end_position().column + 1
381 );
382 }
383 return Some(node);
384 }
385 }
386 }
387 }
388 }
389 }
390 }
391 }
392 }
393
394 // Recursively search in children
395 let mut cursor = node.walk();
396 for child in node.children(&mut cursor) {
397 if let Some(found) =
398 find_symbol_node(child, symbol_parts, language_impl, content, debug_mode)
399 {
400 return Some(found);
401 }
402 }
403
404 None
405 }
406
407 // Search for the symbol in the AST
408 if let Some(found_node) = find_symbol_node(
409 root_node,
410 &symbol_parts,
411 language_impl.as_ref(),
412 content.as_bytes(),
413 debug_mode,
414 ) {
415 let node_start_line = found_node.start_position().row + 1;
416 let node_end_line = found_node.end_position().row + 1;
417
418 if debug_mode {
419 println!("\n[DEBUG] ===== Symbol Found =====");
420 println!("[DEBUG] Found symbol '{symbol}' at lines {node_start_line}-{node_end_line}");
421 println!(
422 "[DEBUG] Node type: {node_kind}",
423 node_kind = found_node.kind()
424 );
425 println!(
426 "[DEBUG] Node range: {}:{} - {}:{}",
427 found_node.start_position().row + 1,
428 found_node.start_position().column + 1,
429 found_node.end_position().row + 1,
430 found_node.end_position().column + 1
431 );
432 }
433
434 // Extract the code block
435 let node_text = &content[found_node.start_byte()..found_node.end_byte()];
436
437 if debug_mode {
438 println!(
439 "[DEBUG] Extracted code size: {code_size} bytes",
440 code_size = node_text.len()
441 );
442 println!(
443 "[DEBUG] Extracted code lines: {line_count}",
444 line_count = node_text.lines().count()
445 );
446 }
447
448 // Tokenize the content
449 let filename = path
450 .file_name()
451 .map(|f| f.to_string_lossy().to_string())
452 .unwrap_or_default();
453 let node_text_str = node_text.to_string();
454 let tokenized_content =
455 crate::ranking::preprocess_text_with_filename(&node_text_str, &filename);
456
457 return Ok(SearchResult {
458 file: path.to_string_lossy().to_string(),
459 lines: (node_start_line, node_end_line),
460 node_type: found_node.kind().to_string(),
461 code: node_text_str,
462 matched_by_filename: None,
463 rank: None,
464 score: None,
465 tfidf_score: None,
466 bm25_score: None,
467 tfidf_rank: None,
468 bm25_rank: None,
469 new_score: None,
470 hybrid2_rank: None,
471 combined_score_rank: None,
472 file_unique_terms: None,
473 file_total_matches: None,
474 file_match_rank: None,
475 block_unique_terms: None,
476 block_total_matches: None,
477 parent_file_id: None,
478 block_id: None,
479 matched_keywords: None,
480 tokenized_content: Some(tokenized_content),
481 });
482 }
483
484 // If we couldn't find the symbol using tree-sitter, try a simple text search as fallback
485 if debug_mode {
486 println!("\n[DEBUG] ===== Symbol Not Found in AST =====");
487 println!("[DEBUG] Symbol '{symbol}' not found in AST");
488 println!("[DEBUG] Trying text search fallback");
489 }
490
491 // Simple text search for the symbol
492 let lines: Vec<&str> = content.lines().collect();
493 let mut found_line = None;
494
495 // For nested symbols, we'll try to find lines that contain all parts
496 // This is a simple fallback and may not be as accurate as AST parsing
497 let search_terms = if is_nested_symbol {
498 // For nested symbols, we'll look for lines containing all parts
499 if debug_mode {
500 println!(
501 "[DEBUG] Using fallback search for nested symbol: looking for lines containing all parts"
502 );
503 }
504 symbol_parts.to_vec()
505 } else {
506 vec![symbol]
507 };
508
509 if debug_mode {
510 println!(
511 "[DEBUG] Performing text search for '{:?}' across {} lines",
512 search_terms,
513 lines.len()
514 );
515 }
516
517 for (i, line) in lines.iter().enumerate() {
518 // Check if the line contains all search terms
519 let found = search_terms.iter().all(|term| line.contains(term));
520
521 if found {
522 found_line = Some(i + 1); // 1-indexed line number
523 if debug_mode {
524 println!(
525 "[DEBUG] Found symbol '{}' in line {}: '{}'",
526 symbol,
527 i + 1,
528 line.trim()
529 );
530 }
531 break;
532 }
533 }
534
535 if let Some(line_num) = found_line {
536 if debug_mode {
537 println!("\n[DEBUG] ===== Symbol Found via Text Search =====");
538 println!("[DEBUG] Found symbol '{symbol}' using text search at line {line_num}");
539 }
540
541 // Extract context around the line
542 let start_line = line_num.saturating_sub(context_lines);
543 let end_line = std::cmp::min(line_num + context_lines, lines.len());
544
545 if debug_mode {
546 println!("[DEBUG] Extracting context around line {line_num}");
547 println!("[DEBUG] Context lines: {context_lines}");
548 println!("[DEBUG] Extracting lines {start_line}-{end_line}");
549 }
550
551 // Adjust start_line to be at least 1 (1-indexed)
552 let start_idx = if start_line > 0 { start_line - 1 } else { 0 };
553
554 let context = lines[start_idx..end_line].join("\n");
555
556 if debug_mode {
557 println!(
558 "[DEBUG] Extracted {line_count} lines of code",
559 line_count = end_line - start_line
560 );
561 println!(
562 "[DEBUG] Content size: {content_size} bytes",
563 content_size = context.len()
564 );
565 }
566
567 // Tokenize the content
568 let filename = path
569 .file_name()
570 .map(|f| f.to_string_lossy().to_string())
571 .unwrap_or_default();
572 let tokenized_content = crate::ranking::preprocess_text_with_filename(&context, &filename);
573
574 return Ok(SearchResult {
575 file: path.to_string_lossy().to_string(),
576 lines: (start_line, end_line),
577 node_type: "text_search".to_string(),
578 code: context,
579 matched_by_filename: None,
580 rank: None,
581 score: None,
582 tfidf_score: None,
583 bm25_score: None,
584 tfidf_rank: None,
585 bm25_rank: None,
586 new_score: None,
587 hybrid2_rank: None,
588 combined_score_rank: None,
589 file_unique_terms: None,
590 file_total_matches: None,
591 file_match_rank: None,
592 block_unique_terms: None,
593 block_total_matches: None,
594 parent_file_id: None,
595 block_id: None,
596 matched_keywords: None,
597 tokenized_content: Some(tokenized_content),
598 });
599 }
600
601 // If we get here, we couldn't find the symbol
602 if debug_mode {
603 println!("\n[DEBUG] ===== Symbol Not Found =====");
604 println!("[DEBUG] Symbol '{symbol}' not found in file {path:?}");
605 println!("[DEBUG] Neither AST parsing nor text search found the symbol");
606 }
607
608 Err(anyhow::anyhow!(
609 "Symbol '{}' not found in file {:?}",
610 symbol,
611 path
612 ))
613}