project_rag/relations/repomap/
symbol_extractor.rs

1//! Symbol extraction from AST nodes.
2//!
3//! This module extracts symbol definitions (functions, classes, methods, etc.)
4//! from source code using tree-sitter AST parsing.
5
6use anyhow::{Context, Result};
7use chrono::Utc;
8use tree_sitter::{Language, Node, Parser};
9
10use crate::indexer::FileInfo;
11use crate::relations::types::{Definition, SymbolId, SymbolKind, Visibility};
12
/// Extracts symbol definitions from source code using AST parsing.
///
/// The extractor carries no state of its own, so one instance can be reused
/// for any number of files.
#[derive(Debug, Clone)]
pub struct SymbolExtractor {
    // No persistent state needed - parser created per-file
}
17
18impl SymbolExtractor {
19    /// Create a new symbol extractor
20    pub fn new() -> Self {
21        Self {}
22    }
23
24    /// Extract all symbol definitions from a file
25    pub fn extract_definitions(&self, file_info: &FileInfo) -> Result<Vec<Definition>> {
26        let extension = file_info.extension.as_deref().unwrap_or("");
27
28        // Get language and parser
29        let (language, language_name) = match get_language_for_extension(extension) {
30            Some(lang) => lang,
31            None => return Ok(Vec::new()), // Unsupported language
32        };
33
34        let mut parser = Parser::new();
35        parser
36            .set_language(&language)
37            .context("Failed to set parser language")?;
38
39        let tree = parser
40            .parse(&file_info.content, None)
41            .context("Failed to parse source code")?;
42
43        let root_node = tree.root_node();
44        let mut definitions = Vec::new();
45
46        // Extract definitions recursively
47        self.extract_from_node(
48            root_node,
49            &file_info.content,
50            &language_name,
51            file_info,
52            None,
53            &mut definitions,
54        );
55
56        Ok(definitions)
57    }
58
59    /// Extract definitions from a node and its children
60    fn extract_from_node(
61        &self,
62        node: Node,
63        source: &str,
64        language: &str,
65        file_info: &FileInfo,
66        parent_id: Option<String>,
67        result: &mut Vec<Definition>,
68    ) {
69        let kind = node.kind();
70
71        // Check if this node is a definition we care about
72        if is_definition_node(kind, language) {
73            if let Some(def) = self.node_to_definition(node, source, language, file_info, &parent_id)
74            {
75                let new_parent_id = Some(def.to_storage_id());
76                result.push(def);
77
78                // Extract nested definitions with this as parent
79                let mut cursor = node.walk();
80                for child in node.children(&mut cursor) {
81                    self.extract_from_node(
82                        child,
83                        source,
84                        language,
85                        file_info,
86                        new_parent_id.clone(),
87                        result,
88                    );
89                }
90                return;
91            }
92        }
93
94        // Recurse into children
95        let mut cursor = node.walk();
96        for child in node.children(&mut cursor) {
97            self.extract_from_node(child, source, language, file_info, parent_id.clone(), result);
98        }
99    }
100
101    /// Convert an AST node to a Definition
102    fn node_to_definition(
103        &self,
104        node: Node,
105        source: &str,
106        language: &str,
107        file_info: &FileInfo,
108        parent_id: &Option<String>,
109    ) -> Option<Definition> {
110        let kind = node.kind();
111        let symbol_kind = SymbolKind::from_ast_kind(kind);
112
113        // Extract the symbol name
114        let name = extract_symbol_name(node, source, language)?;
115
116        // Get position info
117        let start_pos = node.start_position();
118        let end_pos = node.end_position();
119
120        // Extract signature (first line or declaration)
121        let signature = extract_signature(node, source, language);
122
123        // Extract doc comment
124        let doc_comment = extract_doc_comment(node, source, language);
125
126        // Determine visibility
127        let node_text = &source[node.start_byte()..node.end_byte().min(source.len())];
128        let visibility = Visibility::from_keywords(node_text);
129
130        Some(Definition {
131            symbol_id: SymbolId::new(
132                &file_info.relative_path,
133                name,
134                symbol_kind,
135                start_pos.row + 1, // Convert to 1-based
136                start_pos.column,
137            ),
138            root_path: Some(file_info.root_path.clone()),
139            project: file_info.project.clone(),
140            end_line: end_pos.row + 1,
141            end_col: end_pos.column,
142            signature,
143            doc_comment,
144            visibility,
145            parent_id: parent_id.clone(),
146            indexed_at: Utc::now().timestamp(),
147        })
148    }
149}
150
151impl Default for SymbolExtractor {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157/// Get the tree-sitter language for a file extension
158fn get_language_for_extension(extension: &str) -> Option<(Language, String)> {
159    match extension.to_lowercase().as_str() {
160        "rs" => Some((tree_sitter_rust::LANGUAGE.into(), "Rust".to_string())),
161        "py" => Some((tree_sitter_python::LANGUAGE.into(), "Python".to_string())),
162        "js" | "mjs" | "cjs" | "jsx" => Some((
163            tree_sitter_javascript::LANGUAGE.into(),
164            "JavaScript".to_string(),
165        )),
166        "ts" | "tsx" => Some((
167            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
168            "TypeScript".to_string(),
169        )),
170        "go" => Some((tree_sitter_go::LANGUAGE.into(), "Go".to_string())),
171        "java" => Some((tree_sitter_java::LANGUAGE.into(), "Java".to_string())),
172        "swift" => Some((tree_sitter_swift::LANGUAGE.into(), "Swift".to_string())),
173        "c" | "h" => Some((tree_sitter_c::LANGUAGE.into(), "C".to_string())),
174        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => {
175            Some((tree_sitter_cpp::LANGUAGE.into(), "C++".to_string()))
176        }
177        "cs" => Some((tree_sitter_c_sharp::LANGUAGE.into(), "C#".to_string())),
178        "rb" => Some((tree_sitter_ruby::LANGUAGE.into(), "Ruby".to_string())),
179        "php" => Some((tree_sitter_php::LANGUAGE_PHP.into(), "PHP".to_string())),
180        _ => None,
181    }
182}
183
/// Check if a node kind represents a definition.
///
/// `language` is the display name produced when the grammar was selected
/// (for example `"Rust"` or `"C++"`); `kind` is the tree-sitter node kind.
/// Unknown languages never match.
fn is_definition_node(kind: &str, language: &str) -> bool {
    // Node kinds that count as definitions, per language.
    let definition_kinds: &[&str] = match language {
        "Rust" => &[
            "function_item",
            "impl_item",
            "trait_item",
            "struct_item",
            "enum_item",
            "mod_item",
            "const_item",
            "static_item",
            "type_item",
        ],
        "Python" => &[
            "function_definition",
            "class_definition",
            "decorated_definition",
        ],
        "JavaScript" | "TypeScript" => &[
            "function_declaration",
            "function_expression",
            "arrow_function",
            "method_definition",
            "class_declaration",
            "interface_declaration",
            "type_alias_declaration",
        ],
        "Go" => &[
            "function_declaration",
            "method_declaration",
            "type_declaration",
        ],
        "Java" => &[
            "method_declaration",
            "class_declaration",
            "interface_declaration",
            "constructor_declaration",
            "enum_declaration",
        ],
        "Swift" => &[
            "function_declaration",
            "class_declaration",
            "struct_declaration",
            "enum_declaration",
            "protocol_declaration",
        ],
        "C" => &["function_definition", "struct_specifier", "enum_specifier"],
        "C++" => &[
            "function_definition",
            "class_specifier",
            "struct_specifier",
            "enum_specifier",
            "namespace_definition",
        ],
        "C#" => &[
            "method_declaration",
            "class_declaration",
            "struct_declaration",
            "interface_declaration",
            "enum_declaration",
            "constructor_declaration",
        ],
        "Ruby" => &["method", "singleton_method", "class", "module"],
        "PHP" => &[
            "function_definition",
            "method_declaration",
            "class_declaration",
            "interface_declaration",
            "trait_declaration",
        ],
        _ => return false,
    };

    definition_kinds.contains(&kind)
}
266
267/// Extract the symbol name from an AST node
268fn extract_symbol_name(node: Node, source: &str, language: &str) -> Option<String> {
269    // Strategy: Find the identifier/name child node based on language
270    let name_node = find_name_node(node, language)?;
271
272    let start = name_node.start_byte();
273    let end = name_node.end_byte();
274
275    if end > source.len() {
276        return None;
277    }
278
279    let name = source[start..end].to_string();
280
281    // Filter out empty or whitespace-only names
282    if name.trim().is_empty() {
283        return None;
284    }
285
286    Some(name)
287}
288
289/// Find the child node containing the symbol name
290fn find_name_node<'a>(node: Node<'a>, language: &str) -> Option<Node<'a>> {
291    let kind = node.kind();
292
293    // Language-specific name extraction
294    match language {
295        "Rust" => {
296            // Rust: name is usually in "name" field or first identifier
297            if let Some(name_node) = node.child_by_field_name("name") {
298                return Some(name_node);
299            }
300            // For impl items, look for type name
301            if kind == "impl_item" {
302                if let Some(type_node) = node.child_by_field_name("type") {
303                    return Some(type_node);
304                }
305            }
306        }
307        "Python" => {
308            // Python: class and function have "name" field
309            if let Some(name_node) = node.child_by_field_name("name") {
310                return Some(name_node);
311            }
312            // Decorated definitions: look inside for the actual definition
313            if kind == "decorated_definition" {
314                let mut cursor = node.walk();
315                for child in node.children(&mut cursor) {
316                    if child.kind() == "function_definition" || child.kind() == "class_definition" {
317                        return find_name_node(child, language);
318                    }
319                }
320            }
321        }
322        "JavaScript" | "TypeScript" => {
323            // JS/TS: "name" field for most declarations
324            if let Some(name_node) = node.child_by_field_name("name") {
325                return Some(name_node);
326            }
327            // Arrow functions in variable declarations need special handling
328            if kind == "arrow_function" {
329                // Look at parent for variable name
330                if let Some(parent) = node.parent() {
331                    if parent.kind() == "variable_declarator" {
332                        if let Some(name_node) = parent.child_by_field_name("name") {
333                            return Some(name_node);
334                        }
335                    }
336                }
337            }
338        }
339        "Go" => {
340            if let Some(name_node) = node.child_by_field_name("name") {
341                return Some(name_node);
342            }
343        }
344        "Java" => {
345            if let Some(name_node) = node.child_by_field_name("name") {
346                return Some(name_node);
347            }
348        }
349        "Swift" => {
350            if let Some(name_node) = node.child_by_field_name("name") {
351                return Some(name_node);
352            }
353        }
354        "C" | "C++" => {
355            // C/C++: declarator contains the name
356            if let Some(declarator) = node.child_by_field_name("declarator") {
357                // Navigate through possible pointer/reference declarators
358                return find_innermost_identifier(declarator);
359            }
360            // For struct/class, name is in the type specifier
361            if kind == "struct_specifier" || kind == "class_specifier" || kind == "enum_specifier" {
362                if let Some(name_node) = node.child_by_field_name("name") {
363                    return Some(name_node);
364                }
365            }
366        }
367        "C#" => {
368            if let Some(name_node) = node.child_by_field_name("name") {
369                return Some(name_node);
370            }
371        }
372        "Ruby" => {
373            if let Some(name_node) = node.child_by_field_name("name") {
374                return Some(name_node);
375            }
376        }
377        "PHP" => {
378            if let Some(name_node) = node.child_by_field_name("name") {
379                return Some(name_node);
380            }
381        }
382        _ => {}
383    }
384
385    // Fallback: find first identifier child
386    let mut cursor = node.walk();
387    for child in node.children(&mut cursor) {
388        if child.kind() == "identifier"
389            || child.kind() == "type_identifier"
390            || child.kind() == "name"
391        {
392            return Some(child);
393        }
394    }
395
396    None
397}
398
399/// Find the innermost identifier in a declarator chain (for C/C++)
400fn find_innermost_identifier<'a>(node: Node<'a>) -> Option<Node<'a>> {
401    // If this is an identifier, return it
402    if node.kind() == "identifier" || node.kind() == "field_identifier" {
403        return Some(node);
404    }
405
406    // Check for name field
407    if let Some(name_node) = node.child_by_field_name("declarator") {
408        return find_innermost_identifier(name_node);
409    }
410
411    // Fallback: look through children
412    let mut cursor = node.walk();
413    for child in node.children(&mut cursor) {
414        if let Some(id) = find_innermost_identifier(child) {
415            return Some(id);
416        }
417    }
418
419    None
420}
421
422/// Extract the signature (first line of declaration)
423fn extract_signature(node: Node, source: &str, _language: &str) -> String {
424    let start = node.start_byte();
425    let end = node.end_byte().min(source.len());
426    let text = &source[start..end];
427
428    // Get first line or first 200 chars, whichever is shorter
429    let first_line = text.lines().next().unwrap_or("");
430    if first_line.len() > 200 {
431        format!("{}...", &first_line[..200])
432    } else {
433        first_line.to_string()
434    }
435}
436
437/// Extract documentation comment preceding the node
438fn extract_doc_comment(node: Node, source: &str, language: &str) -> Option<String> {
439    // Look for comment sibling before this node
440    let mut prev = node.prev_sibling();
441
442    while let Some(sibling) = prev {
443        let kind = sibling.kind();
444
445        // Check if it's a comment
446        let is_doc_comment = match language {
447            "Rust" => kind == "line_comment" || kind == "block_comment",
448            "Python" => kind == "comment" || kind == "expression_statement", // docstrings
449            "JavaScript" | "TypeScript" => kind == "comment",
450            "Java" => kind == "line_comment" || kind == "block_comment",
451            "Go" => kind == "comment",
452            "C" | "C++" => kind == "comment",
453            "C#" => kind == "comment",
454            "Ruby" => kind == "comment",
455            "PHP" => kind == "comment",
456            _ => kind.contains("comment"),
457        };
458
459        if is_doc_comment {
460            let start = sibling.start_byte();
461            let end = sibling.end_byte().min(source.len());
462            let comment = source[start..end].trim().to_string();
463
464            // Clean up comment syntax
465            let cleaned = clean_comment(&comment, language);
466            if !cleaned.is_empty() {
467                return Some(cleaned);
468            }
469        }
470
471        // Stop if we hit a non-comment, non-whitespace node
472        if !kind.contains("comment") && kind != "decorator" && kind != "attribute" {
473            break;
474        }
475
476        prev = sibling.prev_sibling();
477    }
478
479    None
480}
481
/// Clean comment syntax from a comment string.
///
/// Every line is trimmed and stripped of leading comment markers (`///`,
/// `//!`, `//`, `/*`, `*/`, `*`, `#`, `"""`, `'''`). When the comment as a
/// whole opens with a block-comment or triple-quote delimiter, the matching
/// closing delimiter at its very end is removed too, so `/** Foo */` becomes
/// `Foo` rather than `Foo */`. Lines that end up empty are dropped and the
/// rest are joined with `\n`.
fn clean_comment(comment: &str, _language: &str) -> String {
    let trimmed = comment.trim();

    // Only delimited comments have a closer. Line comments are left alone so
    // that text which merely ends in `*/` (e.g. a glob pattern) is preserved.
    let closer = if trimmed.starts_with("/*") {
        Some("*/")
    } else if trimmed.starts_with("\"\"\"") {
        Some("\"\"\"")
    } else if trimmed.starts_with("'''") {
        Some("'''")
    } else {
        None
    };
    let body = closer
        .and_then(|closer| trimmed.strip_suffix(closer))
        .unwrap_or(trimmed);

    let cleaned: Vec<&str> = body
        .lines()
        .map(|line| {
            let mut s = line.trim();
            // Remove common prefixes
            for prefix in ["///", "//!", "//", "/*", "*/", "*", "#", "\"\"\"", "'''"] {
                s = s.trim_start_matches(prefix);
            }
            s.trim()
        })
        .filter(|s| !s.is_empty())
        .collect();

    cleaned.join("\n")
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build an in-memory `FileInfo` named `test.<extension>` with `content`.
    fn make_file_info(content: &str, extension: &str) -> FileInfo {
        let file_name = format!("test.{}", extension);
        FileInfo {
            path: PathBuf::from(&file_name),
            relative_path: file_name,
            root_path: "/test".to_string(),
            project: None,
            extension: Some(extension.to_string()),
            language: None,
            content: content.to_string(),
            hash: "test_hash".to_string(),
        }
    }

    /// Run the extractor over `source` as if it were a file with `extension`.
    fn extract(source: &str, extension: &str) -> Vec<Definition> {
        let file_info = make_file_info(source, extension);
        SymbolExtractor::new()
            .extract_definitions(&file_info)
            .unwrap()
    }

    #[test]
    fn test_rust_extraction() {
        let source = r#"
/// A greeting function
pub fn greet(name: &str) -> String {
    format!("Hello, {}!", name)
}

struct Person {
    name: String,
}

impl Person {
    fn new(name: String) -> Self {
        Self { name }
    }
}
"#;
        let definitions = extract(source, "rs");
        assert!(!definitions.is_empty());

        // The documented public function must be found with full metadata.
        let greet = definitions
            .iter()
            .find(|d| d.name() == "greet")
            .expect("Should find greet function");

        assert_eq!(greet.kind(), SymbolKind::Function);
        assert_eq!(greet.visibility, Visibility::Public);
        assert!(greet.doc_comment.is_some());
    }

    #[test]
    fn test_python_extraction() {
        let source = r#"
def hello(name):
    """Say hello."""
    print(f"Hello, {name}!")

class MyClass:
    def __init__(self, value):
        self.value = value
"#;
        let definitions = extract(source, "py");
        assert!(!definitions.is_empty());

        // Both the top-level function and the class must be present.
        assert!(
            definitions.iter().any(|d| d.name() == "hello"),
            "Should find hello function"
        );
        assert!(
            definitions.iter().any(|d| d.name() == "MyClass"),
            "Should find MyClass"
        );
    }

    #[test]
    fn test_javascript_extraction() {
        let source = r#"
function add(a, b) {
    return a + b;
}

class Calculator {
    constructor() {
        this.result = 0;
    }

    add(x) {
        this.result += x;
    }
}
"#;
        let definitions = extract(source, "js");
        assert!(!definitions.is_empty());

        assert!(
            definitions.iter().any(|d| d.name() == "add"),
            "Should find add function"
        );
    }

    #[test]
    fn test_unsupported_extension() {
        // Unknown extensions are not an error; they simply yield nothing.
        let definitions = extract("some content", "xyz");
        assert!(definitions.is_empty());
    }

    #[test]
    fn test_definition_storage_id() {
        let definitions = extract("fn foo() {}", "rs");
        assert!(!definitions.is_empty());

        let storage_id = definitions[0].to_storage_id();
        assert!(storage_id.contains("foo"));
    }
}