// infiniloom_engine/parser/core.rs

//! Core parser implementation for symbol extraction.
//!
//! This module contains the main [`Parser`] struct and the per-match symbol extraction logic.
//! Grammar initialization (`init`), query construction (`query_builder`) and the detailed
//! extraction helpers (`extraction`) live in sibling modules.
5
6use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15/// Parser errors
16#[derive(Debug, Error)]
17pub enum ParserError {
18    #[error("Unsupported language: {0}")]
19    UnsupportedLanguage(String),
20
21    #[error("Parse error: {0}")]
22    ParseError(String),
23
24    #[error("Query error: {0}")]
25    QueryError(String),
26
27    #[error("Invalid UTF-8 in source code")]
28    InvalidUtf8,
29}
30
31/// Main parser struct for extracting code symbols
32/// Uses lazy initialization - parsers are only created when first needed
33///
34/// # Performance
35///
36/// The parser uses "super-queries" that combine symbol extraction, imports, and call
37/// expressions into a single tree traversal per file. This is more efficient than
38/// running multiple separate queries.
39pub struct Parser {
40    parsers: HashMap<Language, TSParser>,
41    queries: HashMap<Language, Query>,
42    /// Super-queries that combine symbols + imports in one pass
43    super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47    /// Create a new parser instance with lazy initialization
48    /// Parsers and queries are created on-demand when parse() is called
49    pub fn new() -> Self {
50        Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51    }
52
53    /// Ensure parser and query are initialized for a language
54    fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55        use std::collections::hash_map::Entry;
56        if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57            let (parser, query, super_query) = match language {
58                Language::Python => (
59                    init::python()?,
60                    query_builder::python_query()?,
61                    query_builder::python_super_query()?,
62                ),
63                Language::JavaScript => (
64                    init::javascript()?,
65                    query_builder::javascript_query()?,
66                    query_builder::javascript_super_query()?,
67                ),
68                Language::TypeScript => (
69                    init::typescript()?,
70                    query_builder::typescript_query()?,
71                    query_builder::typescript_super_query()?,
72                ),
73                Language::Rust => (
74                    init::rust()?,
75                    query_builder::rust_query()?,
76                    query_builder::rust_super_query()?,
77                ),
78                Language::Go => {
79                    (init::go()?, query_builder::go_query()?, query_builder::go_super_query()?)
80                },
81                Language::Java => (
82                    init::java()?,
83                    query_builder::java_query()?,
84                    query_builder::java_super_query()?,
85                ),
86                Language::C => {
87                    (init::c()?, query_builder::c_query()?, query_builder::c_super_query()?)
88                },
89                Language::Cpp => {
90                    (init::cpp()?, query_builder::cpp_query()?, query_builder::cpp_super_query()?)
91                },
92                Language::CSharp => (
93                    init::csharp()?,
94                    query_builder::csharp_query()?,
95                    query_builder::csharp_super_query()?,
96                ),
97                Language::Ruby => (
98                    init::ruby()?,
99                    query_builder::ruby_query()?,
100                    query_builder::ruby_super_query()?,
101                ),
102                Language::Bash => (
103                    init::bash()?,
104                    query_builder::bash_query()?,
105                    query_builder::bash_super_query()?,
106                ),
107                Language::Php => {
108                    (init::php()?, query_builder::php_query()?, query_builder::php_super_query()?)
109                },
110                Language::Kotlin => (
111                    init::kotlin()?,
112                    query_builder::kotlin_query()?,
113                    query_builder::kotlin_super_query()?,
114                ),
115                Language::Swift => (
116                    init::swift()?,
117                    query_builder::swift_query()?,
118                    query_builder::swift_super_query()?,
119                ),
120                Language::Scala => (
121                    init::scala()?,
122                    query_builder::scala_query()?,
123                    query_builder::scala_super_query()?,
124                ),
125                Language::Haskell => (
126                    init::haskell()?,
127                    query_builder::haskell_query()?,
128                    query_builder::haskell_super_query()?,
129                ),
130                Language::Elixir => (
131                    init::elixir()?,
132                    query_builder::elixir_query()?,
133                    query_builder::elixir_super_query()?,
134                ),
135                Language::Clojure => (
136                    init::clojure()?,
137                    query_builder::clojure_query()?,
138                    query_builder::clojure_super_query()?,
139                ),
140                Language::OCaml => (
141                    init::ocaml()?,
142                    query_builder::ocaml_query()?,
143                    query_builder::ocaml_super_query()?,
144                ),
145                Language::FSharp => {
146                    return Err(ParserError::UnsupportedLanguage(
147                        "F# not yet supported (no tree-sitter grammar available)".to_owned(),
148                    ));
149                },
150                Language::Lua => {
151                    (init::lua()?, query_builder::lua_query()?, query_builder::lua_super_query()?)
152                },
153                Language::R => {
154                    (init::r()?, query_builder::r_query()?, query_builder::r_super_query()?)
155                },
156            };
157            parser_entry.insert(parser);
158            self.queries.insert(language, query);
159            self.super_queries.insert(language, super_query);
160        }
161        Ok(())
162    }
163
164    /// Parse source code and extract symbols
165    ///
166    /// This method now uses "super-queries" that combine symbol extraction and imports
167    /// into a single AST traversal for better performance.
168    pub fn parse(
169        &mut self,
170        source_code: &str,
171        language: Language,
172    ) -> Result<Vec<Symbol>, ParserError> {
173        // Lazy initialization - only init parser for this language
174        self.ensure_initialized(language)?;
175
176        let parser = self
177            .parsers
178            .get_mut(&language)
179            .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
180
181        let tree = parser
182            .parse(source_code, None)
183            .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
184
185        // Use super-query for single-pass extraction (symbols + imports)
186        let super_query = self
187            .super_queries
188            .get(&language)
189            .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
190
191        self.extract_symbols_single_pass(&tree, source_code, super_query, language)
192    }
193
194    /// Extract symbols using single-pass super-query (combines symbols + imports)
195    fn extract_symbols_single_pass(
196        &self,
197        tree: &Tree,
198        source_code: &str,
199        query: &Query,
200        language: Language,
201    ) -> Result<Vec<Symbol>, ParserError> {
202        let mut symbols = Vec::new();
203        let mut cursor = QueryCursor::new();
204        let root_node = tree.root_node();
205
206        let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
207        let capture_names: Vec<&str> = query.capture_names().to_vec();
208
209        while let Some(m) = matches.next() {
210            // Process imports (captured with @import)
211            if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
212                symbols.push(import_symbol);
213                continue;
214            }
215
216            // Process regular symbols (functions, classes, etc.)
217            if let Some(symbol) =
218                self.process_match_single_pass(m, source_code, &capture_names, language)
219            {
220                symbols.push(symbol);
221            }
222        }
223
224        Ok(symbols)
225    }
226
227    /// Process an import match from super-query
228    fn process_import_match(
229        &self,
230        m: &tree_sitter::QueryMatch<'_, '_>,
231        source_code: &str,
232        capture_names: &[&str],
233    ) -> Option<Symbol> {
234        let captures = &m.captures;
235
236        // Look for import capture
237        let import_capture = captures.iter().find(|c| {
238            capture_names
239                .get(c.index as usize)
240                .map(|n| *n == "import")
241                .unwrap_or(false)
242        })?;
243
244        let node = import_capture.node;
245        let text = node.utf8_text(source_code.as_bytes()).ok()?;
246
247        let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
248        symbol.start_line = node.start_position().row as u32 + 1;
249        symbol.end_line = node.end_position().row as u32 + 1;
250
251        Some(symbol)
252    }
253
254    /// Process a symbol match from super-query (single-pass version)
255    fn process_match_single_pass(
256        &self,
257        m: &tree_sitter::QueryMatch<'_, '_>,
258        source_code: &str,
259        capture_names: &[&str],
260        language: Language,
261    ) -> Option<Symbol> {
262        let captures = &m.captures;
263
264        // Find name capture
265        let name_node = captures
266            .iter()
267            .find(|c| {
268                capture_names
269                    .get(c.index as usize)
270                    .map(|n| *n == "name")
271                    .unwrap_or(false)
272            })?
273            .node;
274
275        // Find kind capture (function, class, method, etc.)
276        let kind_capture = captures.iter().find(|c| {
277            capture_names
278                .get(c.index as usize)
279                .map(|n| {
280                    ["function", "class", "method", "struct", "enum", "interface", "trait"]
281                        .contains(n)
282                })
283                .unwrap_or(false)
284        })?;
285
286        let kind_name = capture_names.get(kind_capture.index as usize)?;
287        let mut symbol_kind = extraction::map_symbol_kind(kind_name);
288
289        let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
290
291        // Find the definition node (usually the largest capture)
292        let def_node = captures
293            .iter()
294            .max_by_key(|c| c.node.byte_range().len())
295            .map(|c| c.node)
296            .unwrap_or(name_node);
297
298        if language == Language::Kotlin && def_node.kind() == "class_declaration" {
299            let mut cursor = def_node.walk();
300            for child in def_node.children(&mut cursor) {
301                if child.kind() == "interface" {
302                    symbol_kind = SymbolKind::Interface;
303                    break;
304                }
305            }
306        }
307
308        let start_line = def_node.start_position().row as u32 + 1;
309        let end_line = def_node.end_position().row as u32 + 1;
310
311        // Extract signature, docstring, parent, visibility, calls
312        let signature = extraction::extract_signature(def_node, source_code, language);
313        let docstring = extraction::extract_docstring(def_node, source_code, language);
314        let parent = if symbol_kind == SymbolKind::Method {
315            extraction::extract_parent(def_node, source_code)
316        } else {
317            None
318        };
319        let visibility = extraction::extract_visibility(def_node, source_code, language);
320        let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
321            extraction::extract_calls(def_node, source_code, language)
322        } else {
323            Vec::new()
324        };
325
326        // Extract inheritance info for classes, structs, interfaces
327        let (extends, implements) = if matches!(
328            symbol_kind,
329            SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
330        ) {
331            extraction::extract_inheritance(def_node, source_code, language)
332        } else {
333            (None, Vec::new())
334        };
335
336        let mut symbol = Symbol::new(name, symbol_kind);
337        symbol.start_line = start_line;
338        symbol.end_line = end_line;
339        symbol.signature = signature;
340        symbol.docstring = docstring;
341        symbol.parent = parent;
342        symbol.visibility = visibility;
343        symbol.calls = calls;
344        symbol.extends = extends;
345        symbol.implements = implements;
346
347        Some(symbol)
348    }
349}
350
351impl Default for Parser {
352    fn default() -> Self {
353        Self::new()
354    }
355}
356
// NOTE: The following extraction helpers live in the `super::extraction` module
// (they were moved out of this file):
// - extract_signature
// - extract_docstring
// - extract_parent
// - extract_visibility
// - extract_calls
// - find_body_node
// - collect_calls_recursive
// - is_builtin
// - clean_jsdoc
// - clean_javadoc
// - extract_inheritance
// - map_symbol_kind
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true if `symbols` contains a symbol with exactly this name and kind.
    fn has_symbol(symbols: &[Symbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_language_from_extension() {
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("ts"), Some(Language::TypeScript));
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("go"), Some(Language::Go));
        assert_eq!(Language::from_extension("java"), Some(Language::Java));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    #[test]
    fn test_parse_python() {
        let mut parser = Parser::new();
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;

        let symbols = parser.parse(source, Language::Python).unwrap();
        assert!(!symbols.is_empty(), "no symbols extracted from Python source");

        assert!(
            has_symbol(&symbols, "hello_world", SymbolKind::Function),
            "function `hello_world` not extracted"
        );
        assert!(has_symbol(&symbols, "MyClass", SymbolKind::Class), "class `MyClass` not extracted");
        assert!(has_symbol(&symbols, "method", SymbolKind::Method), "method `method` not extracted");
    }

    #[test]
    fn test_parse_rust() {
        let mut parser = Parser::new();
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;

        let symbols = parser.parse(source, Language::Rust).unwrap();
        assert!(!symbols.is_empty(), "no symbols extracted from Rust source");

        assert!(
            has_symbol(&symbols, "test_function", SymbolKind::Function),
            "function `test_function` not extracted"
        );
        assert!(
            has_symbol(&symbols, "MyStruct", SymbolKind::Struct),
            "struct `MyStruct` not extracted"
        );
        assert!(has_symbol(&symbols, "MyEnum", SymbolKind::Enum), "enum `MyEnum` not extracted");
    }

    #[test]
    fn test_parse_javascript() {
        let mut parser = Parser::new();
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;

        let symbols = parser.parse(source, Language::JavaScript).unwrap();
        assert!(!symbols.is_empty(), "no symbols extracted from JavaScript source");

        assert!(
            has_symbol(&symbols, "testFunction", SymbolKind::Function),
            "function `testFunction` not extracted"
        );
        assert!(
            has_symbol(&symbols, "TestClass", SymbolKind::Class),
            "class `TestClass` not extracted"
        );
    }

    #[test]
    fn test_parse_typescript() {
        let mut parser = Parser::new();
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;

        let symbols = parser.parse(source, Language::TypeScript).unwrap();
        assert!(!symbols.is_empty(), "no symbols extracted from TypeScript source");

        assert!(
            has_symbol(&symbols, "TestInterface", SymbolKind::Interface),
            "interface `TestInterface` not extracted"
        );
        assert!(
            has_symbol(&symbols, "TestEnum", SymbolKind::Enum),
            "enum `TestEnum` not extracted"
        );
    }

    #[test]
    fn test_symbol_metadata() {
        let mut parser = Parser::new();
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;

        let symbols = parser.parse(source, Language::Python).unwrap();
        let func = symbols
            .iter()
            .find(|s| s.name == "test_func")
            .expect("Function not found");

        assert!(func.start_line > 0, "start_line should be 1-based");
        assert!(func.end_line >= func.start_line, "end_line precedes start_line");
        assert!(func.signature.is_some(), "signature not extracted");
        assert!(func.docstring.is_some(), "docstring not extracted");
    }
}