Skip to main content

infiniloom_engine/parser/
core.rs

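//! Core parser implementation for symbol extraction.
//!
//! This module contains the main `Parser` struct, which lazily initializes tree-sitter
//! parsers and runs each language's super-query to collect symbols and imports.
//! Language definitions, query construction, and the per-symbol extraction helpers
//! (signatures, docstrings, calls, inheritance) live in separate modules for better
//! organization.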
5
6use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15/// Parser errors
16#[derive(Debug, Error)]
17pub enum ParserError {
18    #[error("Unsupported language: {0}")]
19    UnsupportedLanguage(String),
20
21    #[error("Parse error: {0}")]
22    ParseError(String),
23
24    #[error("Query error: {0}")]
25    QueryError(String),
26
27    #[error("Invalid UTF-8 in source code")]
28    InvalidUtf8,
29}
30
31/// Main parser struct for extracting code symbols
32/// Uses lazy initialization - parsers are only created when first needed
33///
34/// # Performance
35///
36/// The parser uses "super-queries" that combine symbol extraction, imports, and call
37/// expressions into a single tree traversal per file. This is more efficient than
38/// running multiple separate queries.
39pub struct Parser {
40    parsers: HashMap<Language, TSParser>,
41    queries: HashMap<Language, Query>,
42    /// Super-queries that combine symbols + imports in one pass
43    super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47    /// Create a new parser instance with lazy initialization
48    /// Parsers and queries are created on-demand when parse() is called
49    pub fn new() -> Self {
50        Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51    }
52
53    /// Ensure parser and query are initialized for a language
54    fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55        use std::collections::hash_map::Entry;
56        if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57            let (parser, query, super_query) = match language {
58                Language::Python => (
59                    init::python()?,
60                    query_builder::python_query()?,
61                    query_builder::python_super_query()?,
62                ),
63                Language::JavaScript => (
64                    init::javascript()?,
65                    query_builder::javascript_query()?,
66                    query_builder::javascript_super_query()?,
67                ),
68                Language::TypeScript => (
69                    init::typescript()?,
70                    query_builder::typescript_query()?,
71                    query_builder::typescript_super_query()?,
72                ),
73                Language::Rust => (
74                    init::rust()?,
75                    query_builder::rust_query()?,
76                    query_builder::rust_super_query()?,
77                ),
78                Language::Go => {
79                    (init::go()?, query_builder::go_query()?, query_builder::go_super_query()?)
80                },
81                Language::Java => (
82                    init::java()?,
83                    query_builder::java_query()?,
84                    query_builder::java_super_query()?,
85                ),
86                Language::C => {
87                    (init::c()?, query_builder::c_query()?, query_builder::c_super_query()?)
88                },
89                Language::Cpp => {
90                    (init::cpp()?, query_builder::cpp_query()?, query_builder::cpp_super_query()?)
91                },
92                Language::CSharp => (
93                    init::csharp()?,
94                    query_builder::csharp_query()?,
95                    query_builder::csharp_super_query()?,
96                ),
97                Language::Ruby => (
98                    init::ruby()?,
99                    query_builder::ruby_query()?,
100                    query_builder::ruby_super_query()?,
101                ),
102                Language::Bash => (
103                    init::bash()?,
104                    query_builder::bash_query()?,
105                    query_builder::bash_super_query()?,
106                ),
107                Language::Php => {
108                    (init::php()?, query_builder::php_query()?, query_builder::php_super_query()?)
109                },
110                Language::Kotlin => (
111                    init::kotlin()?,
112                    query_builder::kotlin_query()?,
113                    query_builder::kotlin_super_query()?,
114                ),
115                Language::Swift => (
116                    init::swift()?,
117                    query_builder::swift_query()?,
118                    query_builder::swift_super_query()?,
119                ),
120                Language::Scala => (
121                    init::scala()?,
122                    query_builder::scala_query()?,
123                    query_builder::scala_super_query()?,
124                ),
125                Language::Haskell => (
126                    init::haskell()?,
127                    query_builder::haskell_query()?,
128                    query_builder::haskell_super_query()?,
129                ),
130                Language::Elixir => (
131                    init::elixir()?,
132                    query_builder::elixir_query()?,
133                    query_builder::elixir_super_query()?,
134                ),
135                Language::Clojure => (
136                    init::clojure()?,
137                    query_builder::clojure_query()?,
138                    query_builder::clojure_super_query()?,
139                ),
140                Language::OCaml => (
141                    init::ocaml()?,
142                    query_builder::ocaml_query()?,
143                    query_builder::ocaml_super_query()?,
144                ),
145                Language::FSharp => {
146                    return Err(ParserError::UnsupportedLanguage(
147                        "F# not yet supported (no tree-sitter grammar available)".to_owned(),
148                    ));
149                },
150                Language::Lua => {
151                    (init::lua()?, query_builder::lua_query()?, query_builder::lua_super_query()?)
152                },
153                Language::R => {
154                    (init::r()?, query_builder::r_query()?, query_builder::r_super_query()?)
155                },
156                Language::Hcl => {
157                    (init::hcl()?, query_builder::hcl_query()?, query_builder::hcl_super_query()?)
158                },
159                Language::Zig => {
160                    (init::zig()?, query_builder::zig_query()?, query_builder::zig_super_query()?)
161                },
162                Language::Dart => (
163                    init::dart()?,
164                    query_builder::dart_query()?,
165                    query_builder::dart_super_query()?,
166                ),
167            };
168            parser_entry.insert(parser);
169            self.queries.insert(language, query);
170            self.super_queries.insert(language, super_query);
171        }
172        Ok(())
173    }
174
175    /// Parse source code and extract symbols
176    ///
177    /// This method now uses "super-queries" that combine symbol extraction and imports
178    /// into a single AST traversal for better performance.
179    pub fn parse(
180        &mut self,
181        source_code: &str,
182        language: Language,
183    ) -> Result<Vec<Symbol>, ParserError> {
184        // Lazy initialization - only init parser for this language
185        self.ensure_initialized(language)?;
186
187        let parser = self
188            .parsers
189            .get_mut(&language)
190            .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
191
192        let tree = parser
193            .parse(source_code, None)
194            .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
195
196        // Use super-query for single-pass extraction (symbols + imports)
197        let super_query = self
198            .super_queries
199            .get(&language)
200            .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
201
202        self.extract_symbols_single_pass(&tree, source_code, super_query, language)
203    }
204
205    /// Extract symbols using single-pass super-query (combines symbols + imports)
206    fn extract_symbols_single_pass(
207        &self,
208        tree: &Tree,
209        source_code: &str,
210        query: &Query,
211        language: Language,
212    ) -> Result<Vec<Symbol>, ParserError> {
213        let mut symbols = Vec::new();
214        let mut cursor = QueryCursor::new();
215        let root_node = tree.root_node();
216
217        let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
218        let capture_names: Vec<&str> = query.capture_names().to_vec();
219
220        while let Some(m) = matches.next() {
221            // Process imports (captured with @import)
222            if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
223                symbols.push(import_symbol);
224                continue;
225            }
226
227            // Process regular symbols (functions, classes, etc.)
228            if let Some(symbol) =
229                self.process_match_single_pass(m, source_code, &capture_names, language)
230            {
231                symbols.push(symbol);
232            }
233        }
234
235        Ok(symbols)
236    }
237
238    /// Process an import match from super-query
239    fn process_import_match(
240        &self,
241        m: &tree_sitter::QueryMatch<'_, '_>,
242        source_code: &str,
243        capture_names: &[&str],
244    ) -> Option<Symbol> {
245        let captures = &m.captures;
246
247        // Look for import capture
248        let import_capture = captures.iter().find(|c| {
249            capture_names
250                .get(c.index as usize)
251                .is_some_and(|n| *n == "import")
252        })?;
253
254        let node = import_capture.node;
255        let text = node.utf8_text(source_code.as_bytes()).ok()?;
256
257        let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
258        symbol.start_line = node.start_position().row as u32 + 1;
259        symbol.end_line = node.end_position().row as u32 + 1;
260
261        Some(symbol)
262    }
263
264    /// Process a symbol match from super-query (single-pass version)
265    fn process_match_single_pass(
266        &self,
267        m: &tree_sitter::QueryMatch<'_, '_>,
268        source_code: &str,
269        capture_names: &[&str],
270        language: Language,
271    ) -> Option<Symbol> {
272        let captures = &m.captures;
273
274        // Find name capture
275        let name_node = captures
276            .iter()
277            .find(|c| {
278                capture_names
279                    .get(c.index as usize)
280                    .is_some_and(|n| *n == "name")
281            })?
282            .node;
283
284        // Find kind capture (function, class, method, etc.)
285        let kind_capture = captures.iter().find(|c| {
286            capture_names.get(c.index as usize).is_some_and(|n| {
287                [
288                    "function",
289                    "class",
290                    "method",
291                    "struct",
292                    "enum",
293                    "interface",
294                    "trait",
295                    "constant",
296                    "module",
297                ]
298                .contains(n)
299            })
300        })?;
301
302        let kind_name = capture_names.get(kind_capture.index as usize)?;
303        let mut symbol_kind = extraction::map_symbol_kind(kind_name);
304
305        let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
306
307        // Find the definition node (usually the largest capture)
308        let def_node = captures
309            .iter()
310            .max_by_key(|c| c.node.byte_range().len())
311            .map_or(name_node, |c| c.node);
312
313        if language == Language::Kotlin && def_node.kind() == "class_declaration" {
314            let mut cursor = def_node.walk();
315            for child in def_node.children(&mut cursor) {
316                if child.kind() == "interface" {
317                    symbol_kind = SymbolKind::Interface;
318                    break;
319                }
320            }
321        }
322
323        let start_line = def_node.start_position().row as u32 + 1;
324        let end_line = def_node.end_position().row as u32 + 1;
325
326        // Extract signature, docstring, parent, visibility, calls
327        let signature = extraction::extract_signature(def_node, source_code, language);
328        let docstring = extraction::extract_docstring(def_node, source_code, language);
329        let parent = if symbol_kind == SymbolKind::Method {
330            extraction::extract_parent(def_node, source_code)
331        } else {
332            None
333        };
334        let visibility = extraction::extract_visibility(def_node, source_code, language);
335        let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
336            extraction::extract_calls(def_node, source_code, language)
337        } else {
338            Vec::new()
339        };
340
341        // Extract inheritance info for classes, structs, interfaces
342        let (extends, implements) = if matches!(
343            symbol_kind,
344            SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
345        ) {
346            extraction::extract_inheritance(def_node, source_code, language)
347        } else {
348            (None, Vec::new())
349        };
350
351        let mut symbol = Symbol::new(name, symbol_kind);
352        symbol.start_line = start_line;
353        symbol.end_line = end_line;
354        symbol.signature = signature;
355        symbol.docstring = docstring;
356        symbol.parent = parent;
357        symbol.visibility = visibility;
358        symbol.calls = calls;
359        symbol.extends = extends;
360        symbol.implements = implements;
361
362        Some(symbol)
363    }
364}
365
366impl Default for Parser {
367    fn default() -> Self {
368        Self::new()
369    }
370}
371
372// NOTE: The following extraction functions have been moved to super::extraction module:
373// - extract_signature
374// - extract_docstring
375// - extract_parent
376// - extract_visibility
377// - extract_calls
378// - find_body_node
379// - collect_calls_recursive
380// - is_builtin
381// - clean_jsdoc
382// - clean_javadoc
383// - extract_inheritance
384// - map_symbol_kind
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` with a fresh parser, panicking on any parser error.
    fn parse_symbols(source: &str, language: Language) -> Vec<Symbol> {
        Parser::new().parse(source, language).unwrap()
    }

    /// True when `symbols` contains an entry named `name` with kind `kind`.
    fn has_symbol(symbols: &[Symbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_language_from_extension() {
        let cases = [
            ("py", Some(Language::Python)),
            ("js", Some(Language::JavaScript)),
            ("ts", Some(Language::TypeScript)),
            ("rs", Some(Language::Rust)),
            ("go", Some(Language::Go)),
            ("java", Some(Language::Java)),
            ("unknown", None),
        ];
        for (ext, expected) in cases {
            assert_eq!(Language::from_extension(ext), expected);
        }
    }

    #[test]
    fn test_parse_python() {
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;

        let symbols = parse_symbols(source, Language::Python);
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "hello_world", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyClass", SymbolKind::Class));
        assert!(has_symbol(&symbols, "method", SymbolKind::Method));
    }

    #[test]
    fn test_parse_rust() {
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;

        let symbols = parse_symbols(source, Language::Rust);
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "test_function", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyStruct", SymbolKind::Struct));
        assert!(has_symbol(&symbols, "MyEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_parse_javascript() {
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;

        let symbols = parse_symbols(source, Language::JavaScript);
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "testFunction", SymbolKind::Function));
        assert!(has_symbol(&symbols, "TestClass", SymbolKind::Class));
    }

    #[test]
    fn test_parse_typescript() {
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;

        let symbols = parse_symbols(source, Language::TypeScript);
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "TestInterface", SymbolKind::Interface));
        assert!(has_symbol(&symbols, "TestEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_symbol_metadata() {
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;

        let symbols = parse_symbols(source, Language::Python);
        let func = symbols
            .iter()
            .find(|s| s.name == "test_func")
            .expect("Function not found");

        assert!(func.start_line > 0);
        assert!(func.end_line >= func.start_line);
        assert!(func.signature.is_some());
        assert!(func.docstring.is_some());
    }
}