infiniloom_engine/parser/
core.rs

1//! Core parser implementation for symbol extraction
2//!
3//! This module contains the main Parser struct and symbol extraction logic.
4//! Language definitions and queries are in separate modules for better organization.
5
6use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15/// Parser errors
16#[derive(Debug, Error)]
17pub enum ParserError {
18    #[error("Unsupported language: {0}")]
19    UnsupportedLanguage(String),
20
21    #[error("Parse error: {0}")]
22    ParseError(String),
23
24    #[error("Query error: {0}")]
25    QueryError(String),
26
27    #[error("Invalid UTF-8 in source code")]
28    InvalidUtf8,
29}
30
31/// Main parser struct for extracting code symbols
32/// Uses lazy initialization - parsers are only created when first needed
33///
34/// # Performance
35///
36/// The parser uses "super-queries" that combine symbol extraction, imports, and call
37/// expressions into a single tree traversal per file. This is more efficient than
38/// running multiple separate queries.
39pub struct Parser {
40    parsers: HashMap<Language, TSParser>,
41    queries: HashMap<Language, Query>,
42    /// Super-queries that combine symbols + imports in one pass
43    super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47    /// Create a new parser instance with lazy initialization
48    /// Parsers and queries are created on-demand when parse() is called
49    pub fn new() -> Self {
50        Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51    }
52
53    /// Ensure parser and query are initialized for a language
54    fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55        use std::collections::hash_map::Entry;
56        if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57            let (parser, query, super_query) = match language {
58                Language::Python => (
59                    init::python()?,
60                    query_builder::python_query()?,
61                    query_builder::python_super_query()?,
62                ),
63                Language::JavaScript => (
64                    init::javascript()?,
65                    query_builder::javascript_query()?,
66                    query_builder::javascript_super_query()?,
67                ),
68                Language::TypeScript => (
69                    init::typescript()?,
70                    query_builder::typescript_query()?,
71                    query_builder::typescript_super_query()?,
72                ),
73                Language::Rust => (
74                    init::rust()?,
75                    query_builder::rust_query()?,
76                    query_builder::rust_super_query()?,
77                ),
78                Language::Go => (
79                    init::go()?,
80                    query_builder::go_query()?,
81                    query_builder::go_super_query()?,
82                ),
83                Language::Java => (
84                    init::java()?,
85                    query_builder::java_query()?,
86                    query_builder::java_super_query()?,
87                ),
88                Language::C => (
89                    init::c()?,
90                    query_builder::c_query()?,
91                    query_builder::c_super_query()?,
92                ),
93                Language::Cpp => (
94                    init::cpp()?,
95                    query_builder::cpp_query()?,
96                    query_builder::cpp_super_query()?,
97                ),
98                Language::CSharp => (
99                    init::csharp()?,
100                    query_builder::csharp_query()?,
101                    query_builder::csharp_super_query()?,
102                ),
103                Language::Ruby => (
104                    init::ruby()?,
105                    query_builder::ruby_query()?,
106                    query_builder::ruby_super_query()?,
107                ),
108                Language::Bash => (
109                    init::bash()?,
110                    query_builder::bash_query()?,
111                    query_builder::bash_super_query()?,
112                ),
113                Language::Php => (
114                    init::php()?,
115                    query_builder::php_query()?,
116                    query_builder::php_super_query()?,
117                ),
118                Language::Kotlin => (
119                    init::kotlin()?,
120                    query_builder::kotlin_query()?,
121                    query_builder::kotlin_super_query()?,
122                ),
123                Language::Swift => (
124                    init::swift()?,
125                    query_builder::swift_query()?,
126                    query_builder::swift_super_query()?,
127                ),
128                Language::Scala => (
129                    init::scala()?,
130                    query_builder::scala_query()?,
131                    query_builder::scala_super_query()?,
132                ),
133                Language::Haskell => (
134                    init::haskell()?,
135                    query_builder::haskell_query()?,
136                    query_builder::haskell_super_query()?,
137                ),
138                Language::Elixir => (
139                    init::elixir()?,
140                    query_builder::elixir_query()?,
141                    query_builder::elixir_super_query()?,
142                ),
143                Language::Clojure => (
144                    init::clojure()?,
145                    query_builder::clojure_query()?,
146                    query_builder::clojure_super_query()?,
147                ),
148                Language::OCaml => (
149                    init::ocaml()?,
150                    query_builder::ocaml_query()?,
151                    query_builder::ocaml_super_query()?,
152                ),
153                Language::FSharp => {
154                    return Err(ParserError::UnsupportedLanguage(
155                        "F# not yet supported (no tree-sitter grammar available)".to_owned(),
156                    ));
157                },
158                Language::Lua => (
159                    init::lua()?,
160                    query_builder::lua_query()?,
161                    query_builder::lua_super_query()?,
162                ),
163                Language::R => (
164                    init::r()?,
165                    query_builder::r_query()?,
166                    query_builder::r_super_query()?,
167                ),
168            };
169            parser_entry.insert(parser);
170            self.queries.insert(language, query);
171            self.super_queries.insert(language, super_query);
172        }
173        Ok(())
174    }
175
176    /// Parse source code and extract symbols
177    ///
178    /// This method now uses "super-queries" that combine symbol extraction and imports
179    /// into a single AST traversal for better performance.
180    pub fn parse(
181        &mut self,
182        source_code: &str,
183        language: Language,
184    ) -> Result<Vec<Symbol>, ParserError> {
185        // Lazy initialization - only init parser for this language
186        self.ensure_initialized(language)?;
187
188        let parser = self
189            .parsers
190            .get_mut(&language)
191            .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
192
193        let tree = parser
194            .parse(source_code, None)
195            .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
196
197        // Use super-query for single-pass extraction (symbols + imports)
198        let super_query = self
199            .super_queries
200            .get(&language)
201            .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
202
203        self.extract_symbols_single_pass(&tree, source_code, super_query, language)
204    }
205
206    /// Extract symbols using single-pass super-query (combines symbols + imports)
207    fn extract_symbols_single_pass(
208        &self,
209        tree: &Tree,
210        source_code: &str,
211        query: &Query,
212        language: Language,
213    ) -> Result<Vec<Symbol>, ParserError> {
214        let mut symbols = Vec::new();
215        let mut cursor = QueryCursor::new();
216        let root_node = tree.root_node();
217
218        let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
219        let capture_names: Vec<&str> = query.capture_names().to_vec();
220
221        while let Some(m) = matches.next() {
222            // Process imports (captured with @import)
223            if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
224                symbols.push(import_symbol);
225                continue;
226            }
227
228            // Process regular symbols (functions, classes, etc.)
229            if let Some(symbol) =
230                self.process_match_single_pass(m, source_code, &capture_names, language)
231            {
232                symbols.push(symbol);
233            }
234        }
235
236        Ok(symbols)
237    }
238
239    /// Process an import match from super-query
240    fn process_import_match(
241        &self,
242        m: &tree_sitter::QueryMatch<'_, '_>,
243        source_code: &str,
244        capture_names: &[&str],
245    ) -> Option<Symbol> {
246        let captures = &m.captures;
247
248        // Look for import capture
249        let import_capture = captures.iter().find(|c| {
250            capture_names
251                .get(c.index as usize)
252                .map(|n| *n == "import")
253                .unwrap_or(false)
254        })?;
255
256        let node = import_capture.node;
257        let text = node.utf8_text(source_code.as_bytes()).ok()?;
258
259        let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
260        symbol.start_line = node.start_position().row as u32 + 1;
261        symbol.end_line = node.end_position().row as u32 + 1;
262
263        Some(symbol)
264    }
265
266    /// Process a symbol match from super-query (single-pass version)
267    fn process_match_single_pass(
268        &self,
269        m: &tree_sitter::QueryMatch<'_, '_>,
270        source_code: &str,
271        capture_names: &[&str],
272        language: Language,
273    ) -> Option<Symbol> {
274        let captures = &m.captures;
275
276        // Find name capture
277        let name_node = captures
278            .iter()
279            .find(|c| {
280                capture_names
281                    .get(c.index as usize)
282                    .map(|n| *n == "name")
283                    .unwrap_or(false)
284            })?
285            .node;
286
287        // Find kind capture (function, class, method, etc.)
288        let kind_capture = captures.iter().find(|c| {
289            capture_names
290                .get(c.index as usize)
291                .map(|n| {
292                    ["function", "class", "method", "struct", "enum", "interface", "trait"]
293                        .contains(n)
294                })
295                .unwrap_or(false)
296        })?;
297
298        let kind_name = capture_names.get(kind_capture.index as usize)?;
299        let mut symbol_kind = extraction::map_symbol_kind(kind_name);
300
301        let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
302
303        // Find the definition node (usually the largest capture)
304        let def_node = captures
305            .iter()
306            .max_by_key(|c| c.node.byte_range().len())
307            .map(|c| c.node)
308            .unwrap_or(name_node);
309
310        if language == Language::Kotlin && def_node.kind() == "class_declaration" {
311            let mut cursor = def_node.walk();
312            for child in def_node.children(&mut cursor) {
313                if child.kind() == "interface" {
314                    symbol_kind = SymbolKind::Interface;
315                    break;
316                }
317            }
318        }
319
320        let start_line = def_node.start_position().row as u32 + 1;
321        let end_line = def_node.end_position().row as u32 + 1;
322
323        // Extract signature, docstring, parent, visibility, calls
324        let signature = extraction::extract_signature(def_node, source_code, language);
325        let docstring = extraction::extract_docstring(def_node, source_code, language);
326        let parent = if symbol_kind == SymbolKind::Method {
327            extraction::extract_parent(def_node, source_code)
328        } else {
329            None
330        };
331        let visibility = extraction::extract_visibility(def_node, source_code, language);
332        let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
333            extraction::extract_calls(def_node, source_code, language)
334        } else {
335            Vec::new()
336        };
337
338        // Extract inheritance info for classes, structs, interfaces
339        let (extends, implements) = if matches!(
340            symbol_kind,
341            SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
342        ) {
343            extraction::extract_inheritance(def_node, source_code, language)
344        } else {
345            (None, Vec::new())
346        };
347
348        let mut symbol = Symbol::new(name, symbol_kind);
349        symbol.start_line = start_line;
350        symbol.end_line = end_line;
351        symbol.signature = signature;
352        symbol.docstring = docstring;
353        symbol.parent = parent;
354        symbol.visibility = visibility;
355        symbol.calls = calls;
356        symbol.extends = extends;
357        symbol.implements = implements;
358
359        Some(symbol)
360    }
361}
362
363impl Default for Parser {
364    fn default() -> Self {
365        Self::new()
366    }
367}
368
369// NOTE: The following extraction functions have been moved to super::extraction module:
370// - extract_signature
371// - extract_docstring
372// - extract_parent
373// - extract_visibility
374// - extract_calls
375// - find_body_node
376// - collect_calls_recursive
377// - is_builtin
378// - clean_jsdoc
379// - clean_javadoc
380// - extract_inheritance
381// - map_symbol_kind
382
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `symbols` has an entry with exactly this `name` and `kind`.
    fn has_symbol(symbols: &[Symbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_language_from_extension() {
        let cases = [
            ("py", Some(Language::Python)),
            ("js", Some(Language::JavaScript)),
            ("ts", Some(Language::TypeScript)),
            ("rs", Some(Language::Rust)),
            ("go", Some(Language::Go)),
            ("java", Some(Language::Java)),
            ("unknown", None),
        ];
        for (ext, expected) in cases.iter() {
            assert_eq!(Language::from_extension(ext), *expected);
        }
    }

    #[test]
    fn test_parse_python() {
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;
        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Python).unwrap();

        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "hello_world", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyClass", SymbolKind::Class));
        assert!(has_symbol(&symbols, "method", SymbolKind::Method));
    }

    #[test]
    fn test_parse_rust() {
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;
        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Rust).unwrap();

        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "test_function", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyStruct", SymbolKind::Struct));
        assert!(has_symbol(&symbols, "MyEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_parse_javascript() {
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;
        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::JavaScript).unwrap();

        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "testFunction", SymbolKind::Function));
        assert!(has_symbol(&symbols, "TestClass", SymbolKind::Class));
    }

    #[test]
    fn test_parse_typescript() {
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;
        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::TypeScript).unwrap();

        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "TestInterface", SymbolKind::Interface));
        assert!(has_symbol(&symbols, "TestEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_symbol_metadata() {
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;
        let mut parser = Parser::new();
        let symbols = parser.parse(source, Language::Python).unwrap();

        let Some(func) = symbols.iter().find(|s| s.name == "test_func") else {
            panic!("Function not found");
        };

        assert!(func.start_line > 0);
        assert!(func.end_line >= func.start_line);
        assert!(func.signature.is_some());
        assert!(func.docstring.is_some());
    }
}
565}