infiniloom_engine/parser/
core.rs

1//! Core parser implementation for symbol extraction
2//!
3//! This module contains the main Parser struct and symbol extraction logic.
4//! Language definitions and queries are in separate modules for better organization.
5
6use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15/// Parser errors
16#[derive(Debug, Error)]
17pub enum ParserError {
18    #[error("Unsupported language: {0}")]
19    UnsupportedLanguage(String),
20
21    #[error("Parse error: {0}")]
22    ParseError(String),
23
24    #[error("Query error: {0}")]
25    QueryError(String),
26
27    #[error("Invalid UTF-8 in source code")]
28    InvalidUtf8,
29}
30
31/// Main parser struct for extracting code symbols
32/// Uses lazy initialization - parsers are only created when first needed
33///
34/// # Performance
35///
36/// The parser uses "super-queries" that combine symbol extraction, imports, and call
37/// expressions into a single tree traversal per file. This is more efficient than
38/// running multiple separate queries.
39pub struct Parser {
40    parsers: HashMap<Language, TSParser>,
41    queries: HashMap<Language, Query>,
42    /// Super-queries that combine symbols + imports in one pass
43    super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47    /// Create a new parser instance with lazy initialization
48    /// Parsers and queries are created on-demand when parse() is called
49    pub fn new() -> Self {
50        Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51    }
52
53    /// Ensure parser and query are initialized for a language
54    fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55        use std::collections::hash_map::Entry;
56        if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57            let (parser, query, super_query) = match language {
58                Language::Python => (
59                    init::python()?,
60                    query_builder::python_query()?,
61                    query_builder::python_super_query()?,
62                ),
63                Language::JavaScript => (
64                    init::javascript()?,
65                    query_builder::javascript_query()?,
66                    query_builder::javascript_super_query()?,
67                ),
68                Language::TypeScript => (
69                    init::typescript()?,
70                    query_builder::typescript_query()?,
71                    query_builder::typescript_super_query()?,
72                ),
73                Language::Rust => (
74                    init::rust()?,
75                    query_builder::rust_query()?,
76                    query_builder::rust_super_query()?,
77                ),
78                Language::Go => {
79                    (init::go()?, query_builder::go_query()?, query_builder::go_super_query()?)
80                },
81                Language::Java => (
82                    init::java()?,
83                    query_builder::java_query()?,
84                    query_builder::java_super_query()?,
85                ),
86                Language::C => {
87                    (init::c()?, query_builder::c_query()?, query_builder::c_super_query()?)
88                },
89                Language::Cpp => {
90                    (init::cpp()?, query_builder::cpp_query()?, query_builder::cpp_super_query()?)
91                },
92                Language::CSharp => (
93                    init::csharp()?,
94                    query_builder::csharp_query()?,
95                    query_builder::csharp_super_query()?,
96                ),
97                Language::Ruby => (
98                    init::ruby()?,
99                    query_builder::ruby_query()?,
100                    query_builder::ruby_super_query()?,
101                ),
102                Language::Bash => (
103                    init::bash()?,
104                    query_builder::bash_query()?,
105                    query_builder::bash_super_query()?,
106                ),
107                Language::Php => {
108                    (init::php()?, query_builder::php_query()?, query_builder::php_super_query()?)
109                },
110                Language::Kotlin => (
111                    init::kotlin()?,
112                    query_builder::kotlin_query()?,
113                    query_builder::kotlin_super_query()?,
114                ),
115                Language::Swift => (
116                    init::swift()?,
117                    query_builder::swift_query()?,
118                    query_builder::swift_super_query()?,
119                ),
120                Language::Scala => (
121                    init::scala()?,
122                    query_builder::scala_query()?,
123                    query_builder::scala_super_query()?,
124                ),
125                Language::Haskell => (
126                    init::haskell()?,
127                    query_builder::haskell_query()?,
128                    query_builder::haskell_super_query()?,
129                ),
130                Language::Elixir => (
131                    init::elixir()?,
132                    query_builder::elixir_query()?,
133                    query_builder::elixir_super_query()?,
134                ),
135                Language::Clojure => (
136                    init::clojure()?,
137                    query_builder::clojure_query()?,
138                    query_builder::clojure_super_query()?,
139                ),
140                Language::OCaml => (
141                    init::ocaml()?,
142                    query_builder::ocaml_query()?,
143                    query_builder::ocaml_super_query()?,
144                ),
145                Language::FSharp => {
146                    return Err(ParserError::UnsupportedLanguage(
147                        "F# not yet supported (no tree-sitter grammar available)".to_owned(),
148                    ));
149                },
150                Language::Lua => {
151                    (init::lua()?, query_builder::lua_query()?, query_builder::lua_super_query()?)
152                },
153                Language::R => {
154                    (init::r()?, query_builder::r_query()?, query_builder::r_super_query()?)
155                },
156            };
157            parser_entry.insert(parser);
158            self.queries.insert(language, query);
159            self.super_queries.insert(language, super_query);
160        }
161        Ok(())
162    }
163
164    /// Parse source code and extract symbols
165    ///
166    /// This method now uses "super-queries" that combine symbol extraction and imports
167    /// into a single AST traversal for better performance.
168    pub fn parse(
169        &mut self,
170        source_code: &str,
171        language: Language,
172    ) -> Result<Vec<Symbol>, ParserError> {
173        // Lazy initialization - only init parser for this language
174        self.ensure_initialized(language)?;
175
176        let parser = self
177            .parsers
178            .get_mut(&language)
179            .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
180
181        let tree = parser
182            .parse(source_code, None)
183            .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
184
185        // Use super-query for single-pass extraction (symbols + imports)
186        let super_query = self
187            .super_queries
188            .get(&language)
189            .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
190
191        self.extract_symbols_single_pass(&tree, source_code, super_query, language)
192    }
193
194    /// Extract symbols using single-pass super-query (combines symbols + imports)
195    fn extract_symbols_single_pass(
196        &self,
197        tree: &Tree,
198        source_code: &str,
199        query: &Query,
200        language: Language,
201    ) -> Result<Vec<Symbol>, ParserError> {
202        let mut symbols = Vec::new();
203        let mut cursor = QueryCursor::new();
204        let root_node = tree.root_node();
205
206        let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
207        let capture_names: Vec<&str> = query.capture_names().to_vec();
208
209        while let Some(m) = matches.next() {
210            // Process imports (captured with @import)
211            if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
212                symbols.push(import_symbol);
213                continue;
214            }
215
216            // Process regular symbols (functions, classes, etc.)
217            if let Some(symbol) =
218                self.process_match_single_pass(m, source_code, &capture_names, language)
219            {
220                symbols.push(symbol);
221            }
222        }
223
224        Ok(symbols)
225    }
226
227    /// Process an import match from super-query
228    fn process_import_match(
229        &self,
230        m: &tree_sitter::QueryMatch<'_, '_>,
231        source_code: &str,
232        capture_names: &[&str],
233    ) -> Option<Symbol> {
234        let captures = &m.captures;
235
236        // Look for import capture
237        let import_capture = captures.iter().find(|c| {
238            capture_names
239                .get(c.index as usize)
240                .is_some_and(|n| *n == "import")
241        })?;
242
243        let node = import_capture.node;
244        let text = node.utf8_text(source_code.as_bytes()).ok()?;
245
246        let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
247        symbol.start_line = node.start_position().row as u32 + 1;
248        symbol.end_line = node.end_position().row as u32 + 1;
249
250        Some(symbol)
251    }
252
253    /// Process a symbol match from super-query (single-pass version)
254    fn process_match_single_pass(
255        &self,
256        m: &tree_sitter::QueryMatch<'_, '_>,
257        source_code: &str,
258        capture_names: &[&str],
259        language: Language,
260    ) -> Option<Symbol> {
261        let captures = &m.captures;
262
263        // Find name capture
264        let name_node = captures
265            .iter()
266            .find(|c| {
267                capture_names
268                    .get(c.index as usize)
269                    .is_some_and(|n| *n == "name")
270            })?
271            .node;
272
273        // Find kind capture (function, class, method, etc.)
274        let kind_capture = captures.iter().find(|c| {
275            capture_names.get(c.index as usize).is_some_and(|n| {
276                ["function", "class", "method", "struct", "enum", "interface", "trait"].contains(n)
277            })
278        })?;
279
280        let kind_name = capture_names.get(kind_capture.index as usize)?;
281        let mut symbol_kind = extraction::map_symbol_kind(kind_name);
282
283        let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
284
285        // Find the definition node (usually the largest capture)
286        let def_node = captures
287            .iter()
288            .max_by_key(|c| c.node.byte_range().len())
289            .map_or(name_node, |c| c.node);
290
291        if language == Language::Kotlin && def_node.kind() == "class_declaration" {
292            let mut cursor = def_node.walk();
293            for child in def_node.children(&mut cursor) {
294                if child.kind() == "interface" {
295                    symbol_kind = SymbolKind::Interface;
296                    break;
297                }
298            }
299        }
300
301        let start_line = def_node.start_position().row as u32 + 1;
302        let end_line = def_node.end_position().row as u32 + 1;
303
304        // Extract signature, docstring, parent, visibility, calls
305        let signature = extraction::extract_signature(def_node, source_code, language);
306        let docstring = extraction::extract_docstring(def_node, source_code, language);
307        let parent = if symbol_kind == SymbolKind::Method {
308            extraction::extract_parent(def_node, source_code)
309        } else {
310            None
311        };
312        let visibility = extraction::extract_visibility(def_node, source_code, language);
313        let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
314            extraction::extract_calls(def_node, source_code, language)
315        } else {
316            Vec::new()
317        };
318
319        // Extract inheritance info for classes, structs, interfaces
320        let (extends, implements) = if matches!(
321            symbol_kind,
322            SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
323        ) {
324            extraction::extract_inheritance(def_node, source_code, language)
325        } else {
326            (None, Vec::new())
327        };
328
329        let mut symbol = Symbol::new(name, symbol_kind);
330        symbol.start_line = start_line;
331        symbol.end_line = end_line;
332        symbol.signature = signature;
333        symbol.docstring = docstring;
334        symbol.parent = parent;
335        symbol.visibility = visibility;
336        symbol.calls = calls;
337        symbol.extends = extends;
338        symbol.implements = implements;
339
340        Some(symbol)
341    }
342}
343
344impl Default for Parser {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350// NOTE: The following extraction functions have been moved to super::extraction module:
351// - extract_signature
352// - extract_docstring
353// - extract_parent
354// - extract_visibility
355// - extract_calls
356// - find_body_node
357// - collect_calls_recursive
358// - is_builtin
359// - clean_jsdoc
360// - clean_javadoc
361// - extract_inheritance
362// - map_symbol_kind
363
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `symbols` holds an entry with exactly this name and kind.
    fn has_symbol(symbols: &[Symbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    /// Known file extensions map to their language; unknown ones map to `None`.
    #[test]
    fn test_language_from_extension() {
        let expectations = [
            ("py", Some(Language::Python)),
            ("js", Some(Language::JavaScript)),
            ("ts", Some(Language::TypeScript)),
            ("rs", Some(Language::Rust)),
            ("go", Some(Language::Go)),
            ("java", Some(Language::Java)),
            ("unknown", None),
        ];
        for (extension, expected) in expectations {
            assert_eq!(Language::from_extension(extension), expected);
        }
    }

    /// Python: a top-level function, a class and a method are all reported.
    #[test]
    fn test_parse_python() {
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;

        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Python).unwrap();
        assert!(!symbols.is_empty());

        assert!(has_symbol(&symbols, "hello_world", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyClass", SymbolKind::Class));
        assert!(has_symbol(&symbols, "method", SymbolKind::Method));
    }

    /// Rust: a function, a struct and an enum are all reported.
    #[test]
    fn test_parse_rust() {
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;

        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Rust).unwrap();
        assert!(!symbols.is_empty());

        assert!(has_symbol(&symbols, "test_function", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyStruct", SymbolKind::Struct));
        assert!(has_symbol(&symbols, "MyEnum", SymbolKind::Enum));
    }

    /// JavaScript: a function declaration and a class are reported.
    #[test]
    fn test_parse_javascript() {
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;

        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::JavaScript).unwrap();
        assert!(!symbols.is_empty());

        assert!(has_symbol(&symbols, "testFunction", SymbolKind::Function));
        assert!(has_symbol(&symbols, "TestClass", SymbolKind::Class));
    }

    /// TypeScript: an interface and an enum are reported.
    #[test]
    fn test_parse_typescript() {
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;

        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::TypeScript).unwrap();
        assert!(!symbols.is_empty());

        assert!(has_symbol(&symbols, "TestInterface", SymbolKind::Interface));
        assert!(has_symbol(&symbols, "TestEnum", SymbolKind::Enum));
    }

    /// A parsed Python function carries line numbers, a signature and a docstring.
    #[test]
    fn test_symbol_metadata() {
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;

        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Python).unwrap();
        let func = symbols
            .iter()
            .find(|s| s.name == "test_func")
            .expect("Function not found");

        assert!(func.start_line > 0);
        assert!(func.end_line >= func.start_line);
        assert!(func.signature.is_some());
        assert!(func.docstring.is_some());
    }
}