probe_code/extract/
symbol_finder.rs

1//! Functions for finding symbols in files.
2//!
3//! This module provides functions for finding symbols (functions, structs, classes, etc.)
4//! in files using tree-sitter.
5
6use anyhow::Result;
7use probe_code::models::SearchResult;
8use std::path::Path;
9
10/// Find a symbol (function, struct, class, etc.) in a file by name
11///
12/// This function searches for a symbol by name in a file and returns the code block
13/// containing that symbol. It uses tree-sitter to parse the code and find the symbol.
14///
15/// # Arguments
16///
17/// * `path` - The path to the file to search in
18/// * `symbol` - The name of the symbol to find
19/// * `content` - The content of the file
20/// * `allow_tests` - Whether to include test files and test code blocks
21/// * `context_lines` - Number of context lines to include
22///
23/// # Returns
24///
25/// A SearchResult containing the extracted code block for the symbol, or an error
26/// if the symbol couldn't be found.
27pub fn find_symbol_in_file(
28    path: &Path,
29    symbol: &str,
30    content: &str,
31    _allow_tests: bool,
32    context_lines: usize,
33) -> Result<SearchResult> {
34    let debug_mode = std::env::var("DEBUG").unwrap_or_default() == "1";
35
36    // Check if the symbol contains a dot, indicating a nested symbol path
37    let symbol_parts: Vec<&str> = symbol.split('.').collect();
38    let is_nested_symbol = symbol_parts.len() > 1;
39
40    // For nested symbols, we'll use the AST-based approach directly
41    // The find_symbol_node function already handles nested symbols
42
43    if debug_mode {
44        println!("\n[DEBUG] ===== Symbol Search =====");
45        if is_nested_symbol {
46            println!("[DEBUG] Searching for nested symbol '{symbol}' in file {path:?}");
47            println!(
48                "[DEBUG] Symbol parts: {:?} (parent: '{}', child: '{}')",
49                symbol_parts,
50                symbol_parts[0],
51                symbol_parts.last().unwrap_or(&"")
52            );
53        } else {
54            println!("[DEBUG] Searching for symbol '{symbol}' in file {path:?}");
55        }
56        println!(
57            "[DEBUG] Content size: {content_len} bytes",
58            content_len = content.len()
59        );
60        println!(
61            "[DEBUG] Line count: {line_count}",
62            line_count = content.lines().count()
63        );
64    }
65
66    // Get the file extension to determine the language
67    let extension = path.extension().and_then(|ext| ext.to_str()).unwrap_or("");
68
69    if debug_mode {
70        println!("[DEBUG] File extension: {extension}");
71    }
72
73    // Get the language implementation for this extension
74    let language_impl = crate::language::factory::get_language_impl(extension)
75        .ok_or_else(|| anyhow::anyhow!("Unsupported language extension: {}", extension))?;
76
77    if debug_mode {
78        println!("[DEBUG] Language detected: {extension}");
79        println!("[DEBUG] Using tree-sitter to parse file");
80    }
81
82    // Parse the file with tree-sitter
83    let mut parser = tree_sitter::Parser::new();
84    parser
85        .set_language(&language_impl.get_tree_sitter_language())
86        .map_err(|e| anyhow::anyhow!("Failed to set language: {}", e))?;
87
88    let tree = parser
89        .parse(content.as_bytes(), None)
90        .ok_or_else(|| anyhow::anyhow!("Failed to parse file"))?;
91
92    let root_node = tree.root_node();
93
94    if debug_mode {
95        println!("[DEBUG] File parsed successfully");
96        println!(
97            "[DEBUG] Root node type: {root_node_kind}",
98            root_node_kind = root_node.kind()
99        );
100        println!(
101            "[DEBUG] Root node range: {}:{} - {}:{}",
102            root_node.start_position().row + 1,
103            root_node.start_position().column + 1,
104            root_node.end_position().row + 1,
105            root_node.end_position().column + 1
106        );
107        println!("[DEBUG] Searching for symbol '{symbol}' in AST");
108    }
109
110    // Function to recursively search for a node with the given symbol name
111    fn find_symbol_node<'a>(
112        node: tree_sitter::Node<'a>,
113        symbol_parts: &[&str],
114        language_impl: &dyn crate::language::language_trait::LanguageImpl,
115        content: &'a [u8],
116        debug_mode: bool,
117    ) -> Option<tree_sitter::Node<'a>> {
118        // If we're looking for a nested symbol (e.g., "Class.method"), we need to:
119        // 1. First find the parent symbol (e.g., "Class")
120        // 2. Then search within that node for the child symbol (e.g., "method")
121        let current_symbol = symbol_parts[0];
122        let is_nested = symbol_parts.len() > 1;
123
124        // Check if this node is an acceptable parent (function, struct, class, etc.)
125        if language_impl.is_acceptable_parent(&node) {
126            if debug_mode {
127                println!(
128                    "[DEBUG] Checking node type '{}' at {}:{} for symbol '{}'",
129                    node.kind(),
130                    node.start_position().row + 1,
131                    node.start_position().column + 1,
132                    current_symbol
133                );
134            }
135
136            // Try to extract the name of this node
137            let mut cursor = node.walk();
138            for child in node.children(&mut cursor) {
139                if child.kind() == "identifier"
140                    || child.kind() == "field_identifier"
141                    || child.kind() == "type_identifier"
142                    || child.kind() == "function_declarator"
143                {
144                    // Get the text of this identifier
145                    if let Ok(name) = child.utf8_text(content) {
146                        if debug_mode {
147                            println!(
148                                "[DEBUG] Found identifier: '{name}' (looking for '{current_symbol}')"
149                            );
150                        }
151
152                        if name == current_symbol {
153                            if is_nested {
154                                // If this is a nested symbol, we found the parent
155                                // Now we need to search for the child within this node
156                                if debug_mode {
157                                    println!(
158                                        "[DEBUG] Found parent symbol '{}' in node type '{}', now searching for child '{}'",
159                                        current_symbol,
160                                        node.kind(),
161                                        symbol_parts[1]
162                                    );
163                                }
164
165                                // First, check if there's a direct method definition with this name
166                                let mut direct_method_cursor = node.walk();
167                                for direct_child in node.children(&mut direct_method_cursor) {
168                                    if direct_child.kind() == "method_definition" {
169                                        let mut method_cursor = direct_child.walk();
170                                        for method_child in
171                                            direct_child.children(&mut method_cursor)
172                                        {
173                                            if method_child.kind() == "property_identifier" {
174                                                if let Ok(method_name) =
175                                                    method_child.utf8_text(content)
176                                                {
177                                                    if debug_mode {
178                                                        println!(
179                                                            "[DEBUG] Found direct method: '{}' (looking for '{}')",
180                                                            method_name, symbol_parts[1]
181                                                        );
182                                                    }
183
184                                                    if method_name == symbol_parts[1] {
185                                                        if debug_mode {
186                                                            println!(
187                                                                "[DEBUG] Found child symbol '{}' as direct method_definition",
188                                                                symbol_parts[1]
189                                                            );
190                                                            println!(
191                                                                "[DEBUG] Symbol location: {}:{} - {}:{}",
192                                                                direct_child.start_position().row + 1,
193                                                                direct_child.start_position().column + 1,
194                                                                direct_child.end_position().row + 1,
195                                                                direct_child.end_position().column + 1
196                                                            );
197                                                        }
198                                                        return Some(direct_child);
199                                                    }
200                                                }
201                                            }
202                                        }
203                                    }
204                                }
205
206                                // Look for any node that might contain the child symbol
207                                let mut child_cursor = node.walk();
208                                for child_node in node.children(&mut child_cursor) {
209                                    if debug_mode {
210                                        println!(
211                                            "[DEBUG] Checking child node type '{}' for symbol '{}'",
212                                            child_node.kind(),
213                                            symbol_parts[1]
214                                        );
215                                    }
216
217                                    // Check if this node is the child symbol we're looking for
218                                    if language_impl.is_acceptable_parent(&child_node) {
219                                        // Try to extract the name of this node
220                                        let mut subcursor = child_node.walk();
221                                        for subchild in child_node.children(&mut subcursor) {
222                                            if subchild.kind() == "identifier"
223                                                || subchild.kind() == "property_identifier"
224                                                || subchild.kind() == "field_identifier"
225                                                || subchild.kind() == "type_identifier"
226                                                || subchild.kind() == "method_definition"
227                                            {
228                                                // For method_definition, we need to get the property_identifier child
229                                                if subchild.kind() == "method_definition" {
230                                                    let mut method_cursor = subchild.walk();
231                                                    for method_child in
232                                                        subchild.children(&mut method_cursor)
233                                                    {
234                                                        if method_child.kind()
235                                                            == "property_identifier"
236                                                        {
237                                                            if let Ok(method_name) =
238                                                                method_child.utf8_text(content)
239                                                            {
240                                                                if debug_mode {
241                                                                    println!(
242                                                                        "[DEBUG] Found method: '{}' (looking for '{}')",
243                                                                        method_name, symbol_parts[1]
244                                                                    );
245                                                                }
246
247                                                                if method_name == symbol_parts[1] {
248                                                                    if debug_mode {
249                                                                        println!(
250                                                                            "[DEBUG] Found child symbol '{}' in method_definition",
251                                                                            symbol_parts[1]
252                                                                        );
253                                                                        println!(
254                                                                            "[DEBUG] Symbol location: {}:{} - {}:{}",
255                                                                            subchild.start_position().row + 1,
256                                                                            subchild.start_position().column + 1,
257                                                                            subchild.end_position().row + 1,
258                                                                            subchild.end_position().column + 1
259                                                                        );
260                                                                    }
261                                                                    return Some(subchild);
262                                                                }
263                                                            }
264                                                        }
265                                                    }
266                                                    continue;
267                                                }
268                                                if let Ok(name) = subchild.utf8_text(content) {
269                                                    if debug_mode {
270                                                        println!(
271                                                            "[DEBUG] Found identifier: '{}' (looking for '{}')",
272                                                            name, symbol_parts[1]
273                                                        );
274                                                    }
275
276                                                    if name == symbol_parts[1] {
277                                                        if debug_mode {
278                                                            println!(
279                                                                "[DEBUG] Found child symbol '{}' in node type '{}'",
280                                                                symbol_parts[1],
281                                                                child_node.kind()
282                                                            );
283                                                            println!(
284                                                                "[DEBUG] Symbol location: {}:{} - {}:{}",
285                                                                child_node.start_position().row + 1,
286                                                                child_node.start_position().column + 1,
287                                                                child_node.end_position().row + 1,
288                                                                child_node.end_position().column + 1
289                                                            );
290                                                        }
291                                                        return Some(child_node);
292                                                    }
293                                                }
294                                            }
295                                        }
296                                    }
297
298                                    // Recursively search in this child node
299                                    if let Some(found) = find_symbol_node(
300                                        child_node,
301                                        &symbol_parts[1..],
302                                        language_impl,
303                                        content,
304                                        debug_mode,
305                                    ) {
306                                        return Some(found);
307                                    }
308                                }
309                            } else {
310                                // If this is a simple symbol, we found it
311                                if debug_mode {
312                                    println!(
313                                        "[DEBUG] Found symbol '{}' in node type '{}'",
314                                        current_symbol,
315                                        node.kind()
316                                    );
317                                    println!(
318                                        "[DEBUG] Symbol location: {}:{} - {}:{}",
319                                        node.start_position().row + 1,
320                                        node.start_position().column + 1,
321                                        node.end_position().row + 1,
322                                        node.end_position().column + 1
323                                    );
324                                }
325                                return Some(node);
326                            }
327                        }
328                    }
329
330                    // For function_declarator, we need to look deeper
331                    if child.kind() == "function_declarator" {
332                        if debug_mode {
333                            println!("[DEBUG] Checking function_declarator for symbol");
334                        }
335
336                        let mut subcursor = child.walk();
337                        for subchild in child.children(&mut subcursor) {
338                            if subchild.kind() == "identifier" {
339                                if let Ok(name) = subchild.utf8_text(content) {
340                                    if debug_mode {
341                                        println!("[DEBUG] Found function identifier: '{name}' (looking for '{current_symbol}')");
342                                    }
343
344                                    if name == current_symbol {
345                                        if is_nested {
346                                            // If this is a nested symbol, we found the parent
347                                            // Now we need to search for the child within this node
348                                            if debug_mode {
349                                                println!(
350                                                    "[DEBUG] Found parent symbol '{}' in function_declarator, now searching for child '{}'",
351                                                    current_symbol,
352                                                    symbol_parts[1]
353                                                );
354                                            }
355
356                                            // Recursively search for the child symbol within this node
357                                            let mut child_cursor = node.walk();
358                                            for child_node in node.children(&mut child_cursor) {
359                                                if let Some(found) = find_symbol_node(
360                                                    child_node,
361                                                    &symbol_parts[1..],
362                                                    language_impl,
363                                                    content,
364                                                    debug_mode,
365                                                ) {
366                                                    return Some(found);
367                                                }
368                                            }
369                                        } else {
370                                            // If this is a simple symbol, we found it
371                                            if debug_mode {
372                                                println!(
373                                                    "[DEBUG] Found symbol '{current_symbol}' in function_declarator"
374                                                );
375                                                println!(
376                                                    "[DEBUG] Symbol location: {}:{} - {}:{}",
377                                                    node.start_position().row + 1,
378                                                    node.start_position().column + 1,
379                                                    node.end_position().row + 1,
380                                                    node.end_position().column + 1
381                                                );
382                                            }
383                                            return Some(node);
384                                        }
385                                    }
386                                }
387                            }
388                        }
389                    }
390                }
391            }
392        }
393
394        // Recursively search in children
395        let mut cursor = node.walk();
396        for child in node.children(&mut cursor) {
397            if let Some(found) =
398                find_symbol_node(child, symbol_parts, language_impl, content, debug_mode)
399            {
400                return Some(found);
401            }
402        }
403
404        None
405    }
406
407    // Search for the symbol in the AST
408    if let Some(found_node) = find_symbol_node(
409        root_node,
410        &symbol_parts,
411        language_impl.as_ref(),
412        content.as_bytes(),
413        debug_mode,
414    ) {
415        let node_start_line = found_node.start_position().row + 1;
416        let node_end_line = found_node.end_position().row + 1;
417
418        if debug_mode {
419            println!("\n[DEBUG] ===== Symbol Found =====");
420            println!("[DEBUG] Found symbol '{symbol}' at lines {node_start_line}-{node_end_line}");
421            println!(
422                "[DEBUG] Node type: {node_kind}",
423                node_kind = found_node.kind()
424            );
425            println!(
426                "[DEBUG] Node range: {}:{} - {}:{}",
427                found_node.start_position().row + 1,
428                found_node.start_position().column + 1,
429                found_node.end_position().row + 1,
430                found_node.end_position().column + 1
431            );
432        }
433
434        // Extract the code block
435        let node_text = &content[found_node.start_byte()..found_node.end_byte()];
436
437        if debug_mode {
438            println!(
439                "[DEBUG] Extracted code size: {code_size} bytes",
440                code_size = node_text.len()
441            );
442            println!(
443                "[DEBUG] Extracted code lines: {line_count}",
444                line_count = node_text.lines().count()
445            );
446        }
447
448        // Tokenize the content
449        let filename = path
450            .file_name()
451            .map(|f| f.to_string_lossy().to_string())
452            .unwrap_or_default();
453        let node_text_str = node_text.to_string();
454        let tokenized_content =
455            crate::ranking::preprocess_text_with_filename(&node_text_str, &filename);
456
457        return Ok(SearchResult {
458            file: path.to_string_lossy().to_string(),
459            lines: (node_start_line, node_end_line),
460            node_type: found_node.kind().to_string(),
461            code: node_text_str,
462            matched_by_filename: None,
463            rank: None,
464            score: None,
465            tfidf_score: None,
466            bm25_score: None,
467            tfidf_rank: None,
468            bm25_rank: None,
469            new_score: None,
470            hybrid2_rank: None,
471            combined_score_rank: None,
472            file_unique_terms: None,
473            file_total_matches: None,
474            file_match_rank: None,
475            block_unique_terms: None,
476            block_total_matches: None,
477            parent_file_id: None,
478            block_id: None,
479            matched_keywords: None,
480            tokenized_content: Some(tokenized_content),
481        });
482    }
483
484    // If we couldn't find the symbol using tree-sitter, try a simple text search as fallback
485    if debug_mode {
486        println!("\n[DEBUG] ===== Symbol Not Found in AST =====");
487        println!("[DEBUG] Symbol '{symbol}' not found in AST");
488        println!("[DEBUG] Trying text search fallback");
489    }
490
491    // Simple text search for the symbol
492    let lines: Vec<&str> = content.lines().collect();
493    let mut found_line = None;
494
495    // For nested symbols, we'll try to find lines that contain all parts
496    // This is a simple fallback and may not be as accurate as AST parsing
497    let search_terms = if is_nested_symbol {
498        // For nested symbols, we'll look for lines containing all parts
499        if debug_mode {
500            println!(
501                "[DEBUG] Using fallback search for nested symbol: looking for lines containing all parts"
502            );
503        }
504        symbol_parts.to_vec()
505    } else {
506        vec![symbol]
507    };
508
509    if debug_mode {
510        println!(
511            "[DEBUG] Performing text search for '{:?}' across {} lines",
512            search_terms,
513            lines.len()
514        );
515    }
516
517    for (i, line) in lines.iter().enumerate() {
518        // Check if the line contains all search terms
519        let found = search_terms.iter().all(|term| line.contains(term));
520
521        if found {
522            found_line = Some(i + 1); // 1-indexed line number
523            if debug_mode {
524                println!(
525                    "[DEBUG] Found symbol '{}' in line {}: '{}'",
526                    symbol,
527                    i + 1,
528                    line.trim()
529                );
530            }
531            break;
532        }
533    }
534
535    if let Some(line_num) = found_line {
536        if debug_mode {
537            println!("\n[DEBUG] ===== Symbol Found via Text Search =====");
538            println!("[DEBUG] Found symbol '{symbol}' using text search at line {line_num}");
539        }
540
541        // Extract context around the line
542        let start_line = line_num.saturating_sub(context_lines);
543        let end_line = std::cmp::min(line_num + context_lines, lines.len());
544
545        if debug_mode {
546            println!("[DEBUG] Extracting context around line {line_num}");
547            println!("[DEBUG] Context lines: {context_lines}");
548            println!("[DEBUG] Extracting lines {start_line}-{end_line}");
549        }
550
551        // Adjust start_line to be at least 1 (1-indexed)
552        let start_idx = if start_line > 0 { start_line - 1 } else { 0 };
553
554        let context = lines[start_idx..end_line].join("\n");
555
556        if debug_mode {
557            println!(
558                "[DEBUG] Extracted {line_count} lines of code",
559                line_count = end_line - start_line
560            );
561            println!(
562                "[DEBUG] Content size: {content_size} bytes",
563                content_size = context.len()
564            );
565        }
566
567        // Tokenize the content
568        let filename = path
569            .file_name()
570            .map(|f| f.to_string_lossy().to_string())
571            .unwrap_or_default();
572        let tokenized_content = crate::ranking::preprocess_text_with_filename(&context, &filename);
573
574        return Ok(SearchResult {
575            file: path.to_string_lossy().to_string(),
576            lines: (start_line, end_line),
577            node_type: "text_search".to_string(),
578            code: context,
579            matched_by_filename: None,
580            rank: None,
581            score: None,
582            tfidf_score: None,
583            bm25_score: None,
584            tfidf_rank: None,
585            bm25_rank: None,
586            new_score: None,
587            hybrid2_rank: None,
588            combined_score_rank: None,
589            file_unique_terms: None,
590            file_total_matches: None,
591            file_match_rank: None,
592            block_unique_terms: None,
593            block_total_matches: None,
594            parent_file_id: None,
595            block_id: None,
596            matched_keywords: None,
597            tokenized_content: Some(tokenized_content),
598        });
599    }
600
601    // If we get here, we couldn't find the symbol
602    if debug_mode {
603        println!("\n[DEBUG] ===== Symbol Not Found =====");
604        println!("[DEBUG] Symbol '{symbol}' not found in file {path:?}");
605        println!("[DEBUG] Neither AST parsing nor text search found the symbol");
606    }
607
608    Err(anyhow::anyhow!(
609        "Symbol '{}' not found in file {:?}",
610        symbol,
611        path
612    ))
613}