1use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15#[derive(Debug, Error)]
17pub enum ParserError {
18 #[error("Unsupported language: {0}")]
19 UnsupportedLanguage(String),
20
21 #[error("Parse error: {0}")]
22 ParseError(String),
23
24 #[error("Query error: {0}")]
25 QueryError(String),
26
27 #[error("Invalid UTF-8 in source code")]
28 InvalidUtf8,
29}
30
31pub struct Parser {
40 parsers: HashMap<Language, TSParser>,
41 queries: HashMap<Language, Query>,
42 super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47 pub fn new() -> Self {
50 Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51 }
52
53 fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55 use std::collections::hash_map::Entry;
56 if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57 let (parser, query, super_query) = match language {
58 Language::Python => (
59 init::python()?,
60 query_builder::python_query()?,
61 query_builder::python_super_query()?,
62 ),
63 Language::JavaScript => (
64 init::javascript()?,
65 query_builder::javascript_query()?,
66 query_builder::javascript_super_query()?,
67 ),
68 Language::TypeScript => (
69 init::typescript()?,
70 query_builder::typescript_query()?,
71 query_builder::typescript_super_query()?,
72 ),
73 Language::Rust => (
74 init::rust()?,
75 query_builder::rust_query()?,
76 query_builder::rust_super_query()?,
77 ),
78 Language::Go => (
79 init::go()?,
80 query_builder::go_query()?,
81 query_builder::go_super_query()?,
82 ),
83 Language::Java => (
84 init::java()?,
85 query_builder::java_query()?,
86 query_builder::java_super_query()?,
87 ),
88 Language::C => (
89 init::c()?,
90 query_builder::c_query()?,
91 query_builder::c_super_query()?,
92 ),
93 Language::Cpp => (
94 init::cpp()?,
95 query_builder::cpp_query()?,
96 query_builder::cpp_super_query()?,
97 ),
98 Language::CSharp => (
99 init::csharp()?,
100 query_builder::csharp_query()?,
101 query_builder::csharp_super_query()?,
102 ),
103 Language::Ruby => (
104 init::ruby()?,
105 query_builder::ruby_query()?,
106 query_builder::ruby_super_query()?,
107 ),
108 Language::Bash => (
109 init::bash()?,
110 query_builder::bash_query()?,
111 query_builder::bash_super_query()?,
112 ),
113 Language::Php => (
114 init::php()?,
115 query_builder::php_query()?,
116 query_builder::php_super_query()?,
117 ),
118 Language::Kotlin => (
119 init::kotlin()?,
120 query_builder::kotlin_query()?,
121 query_builder::kotlin_super_query()?,
122 ),
123 Language::Swift => (
124 init::swift()?,
125 query_builder::swift_query()?,
126 query_builder::swift_super_query()?,
127 ),
128 Language::Scala => (
129 init::scala()?,
130 query_builder::scala_query()?,
131 query_builder::scala_super_query()?,
132 ),
133 Language::Haskell => (
134 init::haskell()?,
135 query_builder::haskell_query()?,
136 query_builder::haskell_super_query()?,
137 ),
138 Language::Elixir => (
139 init::elixir()?,
140 query_builder::elixir_query()?,
141 query_builder::elixir_super_query()?,
142 ),
143 Language::Clojure => (
144 init::clojure()?,
145 query_builder::clojure_query()?,
146 query_builder::clojure_super_query()?,
147 ),
148 Language::OCaml => (
149 init::ocaml()?,
150 query_builder::ocaml_query()?,
151 query_builder::ocaml_super_query()?,
152 ),
153 Language::FSharp => {
154 return Err(ParserError::UnsupportedLanguage(
155 "F# not yet supported (no tree-sitter grammar available)".to_owned(),
156 ));
157 },
158 Language::Lua => (
159 init::lua()?,
160 query_builder::lua_query()?,
161 query_builder::lua_super_query()?,
162 ),
163 Language::R => (
164 init::r()?,
165 query_builder::r_query()?,
166 query_builder::r_super_query()?,
167 ),
168 };
169 parser_entry.insert(parser);
170 self.queries.insert(language, query);
171 self.super_queries.insert(language, super_query);
172 }
173 Ok(())
174 }
175
176 pub fn parse(
181 &mut self,
182 source_code: &str,
183 language: Language,
184 ) -> Result<Vec<Symbol>, ParserError> {
185 self.ensure_initialized(language)?;
187
188 let parser = self
189 .parsers
190 .get_mut(&language)
191 .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
192
193 let tree = parser
194 .parse(source_code, None)
195 .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
196
197 let super_query = self
199 .super_queries
200 .get(&language)
201 .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
202
203 self.extract_symbols_single_pass(&tree, source_code, super_query, language)
204 }
205
206 fn extract_symbols_single_pass(
208 &self,
209 tree: &Tree,
210 source_code: &str,
211 query: &Query,
212 language: Language,
213 ) -> Result<Vec<Symbol>, ParserError> {
214 let mut symbols = Vec::new();
215 let mut cursor = QueryCursor::new();
216 let root_node = tree.root_node();
217
218 let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
219 let capture_names: Vec<&str> = query.capture_names().to_vec();
220
221 while let Some(m) = matches.next() {
222 if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
224 symbols.push(import_symbol);
225 continue;
226 }
227
228 if let Some(symbol) =
230 self.process_match_single_pass(m, source_code, &capture_names, language)
231 {
232 symbols.push(symbol);
233 }
234 }
235
236 Ok(symbols)
237 }
238
239 fn process_import_match(
241 &self,
242 m: &tree_sitter::QueryMatch<'_, '_>,
243 source_code: &str,
244 capture_names: &[&str],
245 ) -> Option<Symbol> {
246 let captures = &m.captures;
247
248 let import_capture = captures.iter().find(|c| {
250 capture_names
251 .get(c.index as usize)
252 .map(|n| *n == "import")
253 .unwrap_or(false)
254 })?;
255
256 let node = import_capture.node;
257 let text = node.utf8_text(source_code.as_bytes()).ok()?;
258
259 let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
260 symbol.start_line = node.start_position().row as u32 + 1;
261 symbol.end_line = node.end_position().row as u32 + 1;
262
263 Some(symbol)
264 }
265
266 fn process_match_single_pass(
268 &self,
269 m: &tree_sitter::QueryMatch<'_, '_>,
270 source_code: &str,
271 capture_names: &[&str],
272 language: Language,
273 ) -> Option<Symbol> {
274 let captures = &m.captures;
275
276 let name_node = captures
278 .iter()
279 .find(|c| {
280 capture_names
281 .get(c.index as usize)
282 .map(|n| *n == "name")
283 .unwrap_or(false)
284 })?
285 .node;
286
287 let kind_capture = captures.iter().find(|c| {
289 capture_names
290 .get(c.index as usize)
291 .map(|n| {
292 ["function", "class", "method", "struct", "enum", "interface", "trait"]
293 .contains(n)
294 })
295 .unwrap_or(false)
296 })?;
297
298 let kind_name = capture_names.get(kind_capture.index as usize)?;
299 let mut symbol_kind = extraction::map_symbol_kind(kind_name);
300
301 let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
302
303 let def_node = captures
305 .iter()
306 .max_by_key(|c| c.node.byte_range().len())
307 .map(|c| c.node)
308 .unwrap_or(name_node);
309
310 if language == Language::Kotlin && def_node.kind() == "class_declaration" {
311 let mut cursor = def_node.walk();
312 for child in def_node.children(&mut cursor) {
313 if child.kind() == "interface" {
314 symbol_kind = SymbolKind::Interface;
315 break;
316 }
317 }
318 }
319
320 let start_line = def_node.start_position().row as u32 + 1;
321 let end_line = def_node.end_position().row as u32 + 1;
322
323 let signature = extraction::extract_signature(def_node, source_code, language);
325 let docstring = extraction::extract_docstring(def_node, source_code, language);
326 let parent = if symbol_kind == SymbolKind::Method {
327 extraction::extract_parent(def_node, source_code)
328 } else {
329 None
330 };
331 let visibility = extraction::extract_visibility(def_node, source_code, language);
332 let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
333 extraction::extract_calls(def_node, source_code, language)
334 } else {
335 Vec::new()
336 };
337
338 let (extends, implements) = if matches!(
340 symbol_kind,
341 SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
342 ) {
343 extraction::extract_inheritance(def_node, source_code, language)
344 } else {
345 (None, Vec::new())
346 };
347
348 let mut symbol = Symbol::new(name, symbol_kind);
349 symbol.start_line = start_line;
350 symbol.end_line = end_line;
351 symbol.signature = signature;
352 symbol.docstring = docstring;
353 symbol.parent = parent;
354 symbol.visibility = visibility;
355 symbol.calls = calls;
356 symbol.extends = extends;
357 symbol.implements = implements;
358
359 Some(symbol)
360 }
361}
362
363impl Default for Parser {
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `symbols` contains a symbol with exactly this name and kind.
    fn has_symbol(symbols: &[Symbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_language_from_extension() {
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("ts"), Some(Language::TypeScript));
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("go"), Some(Language::Go));
        assert_eq!(Language::from_extension("java"), Some(Language::Java));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    #[test]
    fn test_parse_python() {
        let mut parser = Parser::new();
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;

        let symbols = parser.parse(source, Language::Python).unwrap();
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "hello_world", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyClass", SymbolKind::Class));
        assert!(has_symbol(&symbols, "method", SymbolKind::Method));
    }

    #[test]
    fn test_parse_rust() {
        let mut parser = Parser::new();
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;

        let symbols = parser.parse(source, Language::Rust).unwrap();
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "test_function", SymbolKind::Function));
        assert!(has_symbol(&symbols, "MyStruct", SymbolKind::Struct));
        assert!(has_symbol(&symbols, "MyEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_parse_javascript() {
        let mut parser = Parser::new();
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;

        let symbols = parser.parse(source, Language::JavaScript).unwrap();
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "testFunction", SymbolKind::Function));
        assert!(has_symbol(&symbols, "TestClass", SymbolKind::Class));
    }

    #[test]
    fn test_parse_typescript() {
        let mut parser = Parser::new();
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;

        let symbols = parser.parse(source, Language::TypeScript).unwrap();
        assert!(!symbols.is_empty());
        assert!(has_symbol(&symbols, "TestInterface", SymbolKind::Interface));
        assert!(has_symbol(&symbols, "TestEnum", SymbolKind::Enum));
    }

    #[test]
    fn test_symbol_metadata() {
        let mut parser = Parser::new();
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;

        let symbols = parser.parse(source, Language::Python).unwrap();
        let func = symbols
            .iter()
            .find(|s| s.name == "test_func")
            .expect("Function not found");

        assert!(func.start_line > 0);
        assert!(func.end_line >= func.start_line);
        assert!(func.signature.is_some());
        assert!(func.docstring.is_some());
    }

    #[test]
    fn test_unsupported_language_is_reported() {
        // F# has no grammar; parsing must fail cleanly instead of panicking.
        let mut parser = Parser::new();
        let result = parser.parse("let x = 1", Language::FSharp);
        assert!(matches!(result, Err(ParserError::UnsupportedLanguage(_))));
    }

    #[test]
    fn test_parser_reuses_cached_language() {
        // The second call goes through the cached parser/query path.
        let mut parser = Parser::new();
        let source = "def f():\n    return 1\n";

        let first = parser.parse(source, Language::Python).unwrap();
        let second = parser.parse(source, Language::Python).unwrap();

        assert_eq!(first.len(), second.len());
        assert!(has_symbol(&second, "f", SymbolKind::Function));
    }
}