1use super::extraction;
7use super::init;
8use super::language::Language;
9use super::query_builder;
10use crate::types::{Symbol, SymbolKind};
11use std::collections::HashMap;
12use thiserror::Error;
13use tree_sitter::{Parser as TSParser, Query, QueryCursor, StreamingIterator, Tree};
14
15#[derive(Debug, Error)]
17pub enum ParserError {
18 #[error("Unsupported language: {0}")]
19 UnsupportedLanguage(String),
20
21 #[error("Parse error: {0}")]
22 ParseError(String),
23
24 #[error("Query error: {0}")]
25 QueryError(String),
26
27 #[error("Invalid UTF-8 in source code")]
28 InvalidUtf8,
29}
30
31pub struct Parser {
40 parsers: HashMap<Language, TSParser>,
41 queries: HashMap<Language, Query>,
42 super_queries: HashMap<Language, Query>,
44}
45
46impl Parser {
47 pub fn new() -> Self {
50 Self { parsers: HashMap::new(), queries: HashMap::new(), super_queries: HashMap::new() }
51 }
52
53 fn ensure_initialized(&mut self, language: Language) -> Result<(), ParserError> {
55 use std::collections::hash_map::Entry;
56 if let Entry::Vacant(parser_entry) = self.parsers.entry(language) {
57 let (parser, query, super_query) = match language {
58 Language::Python => (
59 init::python()?,
60 query_builder::python_query()?,
61 query_builder::python_super_query()?,
62 ),
63 Language::JavaScript => (
64 init::javascript()?,
65 query_builder::javascript_query()?,
66 query_builder::javascript_super_query()?,
67 ),
68 Language::TypeScript => (
69 init::typescript()?,
70 query_builder::typescript_query()?,
71 query_builder::typescript_super_query()?,
72 ),
73 Language::Rust => (
74 init::rust()?,
75 query_builder::rust_query()?,
76 query_builder::rust_super_query()?,
77 ),
78 Language::Go => {
79 (init::go()?, query_builder::go_query()?, query_builder::go_super_query()?)
80 },
81 Language::Java => (
82 init::java()?,
83 query_builder::java_query()?,
84 query_builder::java_super_query()?,
85 ),
86 Language::C => {
87 (init::c()?, query_builder::c_query()?, query_builder::c_super_query()?)
88 },
89 Language::Cpp => {
90 (init::cpp()?, query_builder::cpp_query()?, query_builder::cpp_super_query()?)
91 },
92 Language::CSharp => (
93 init::csharp()?,
94 query_builder::csharp_query()?,
95 query_builder::csharp_super_query()?,
96 ),
97 Language::Ruby => (
98 init::ruby()?,
99 query_builder::ruby_query()?,
100 query_builder::ruby_super_query()?,
101 ),
102 Language::Bash => (
103 init::bash()?,
104 query_builder::bash_query()?,
105 query_builder::bash_super_query()?,
106 ),
107 Language::Php => {
108 (init::php()?, query_builder::php_query()?, query_builder::php_super_query()?)
109 },
110 Language::Kotlin => (
111 init::kotlin()?,
112 query_builder::kotlin_query()?,
113 query_builder::kotlin_super_query()?,
114 ),
115 Language::Swift => (
116 init::swift()?,
117 query_builder::swift_query()?,
118 query_builder::swift_super_query()?,
119 ),
120 Language::Scala => (
121 init::scala()?,
122 query_builder::scala_query()?,
123 query_builder::scala_super_query()?,
124 ),
125 Language::Haskell => (
126 init::haskell()?,
127 query_builder::haskell_query()?,
128 query_builder::haskell_super_query()?,
129 ),
130 Language::Elixir => (
131 init::elixir()?,
132 query_builder::elixir_query()?,
133 query_builder::elixir_super_query()?,
134 ),
135 Language::Clojure => (
136 init::clojure()?,
137 query_builder::clojure_query()?,
138 query_builder::clojure_super_query()?,
139 ),
140 Language::OCaml => (
141 init::ocaml()?,
142 query_builder::ocaml_query()?,
143 query_builder::ocaml_super_query()?,
144 ),
145 Language::FSharp => {
146 return Err(ParserError::UnsupportedLanguage(
147 "F# not yet supported (no tree-sitter grammar available)".to_owned(),
148 ));
149 },
150 Language::Lua => {
151 (init::lua()?, query_builder::lua_query()?, query_builder::lua_super_query()?)
152 },
153 Language::R => {
154 (init::r()?, query_builder::r_query()?, query_builder::r_super_query()?)
155 },
156 };
157 parser_entry.insert(parser);
158 self.queries.insert(language, query);
159 self.super_queries.insert(language, super_query);
160 }
161 Ok(())
162 }
163
164 pub fn parse(
169 &mut self,
170 source_code: &str,
171 language: Language,
172 ) -> Result<Vec<Symbol>, ParserError> {
173 self.ensure_initialized(language)?;
175
176 let parser = self
177 .parsers
178 .get_mut(&language)
179 .ok_or_else(|| ParserError::UnsupportedLanguage(language.name().to_owned()))?;
180
181 let tree = parser
182 .parse(source_code, None)
183 .ok_or_else(|| ParserError::ParseError("Failed to parse source code".to_owned()))?;
184
185 let super_query = self
187 .super_queries
188 .get(&language)
189 .ok_or_else(|| ParserError::QueryError("No super-query available".to_owned()))?;
190
191 self.extract_symbols_single_pass(&tree, source_code, super_query, language)
192 }
193
194 fn extract_symbols_single_pass(
196 &self,
197 tree: &Tree,
198 source_code: &str,
199 query: &Query,
200 language: Language,
201 ) -> Result<Vec<Symbol>, ParserError> {
202 let mut symbols = Vec::new();
203 let mut cursor = QueryCursor::new();
204 let root_node = tree.root_node();
205
206 let mut matches = cursor.matches(query, root_node, source_code.as_bytes());
207 let capture_names: Vec<&str> = query.capture_names().to_vec();
208
209 while let Some(m) = matches.next() {
210 if let Some(import_symbol) = self.process_import_match(m, source_code, &capture_names) {
212 symbols.push(import_symbol);
213 continue;
214 }
215
216 if let Some(symbol) =
218 self.process_match_single_pass(m, source_code, &capture_names, language)
219 {
220 symbols.push(symbol);
221 }
222 }
223
224 Ok(symbols)
225 }
226
227 fn process_import_match(
229 &self,
230 m: &tree_sitter::QueryMatch<'_, '_>,
231 source_code: &str,
232 capture_names: &[&str],
233 ) -> Option<Symbol> {
234 let captures = &m.captures;
235
236 let import_capture = captures.iter().find(|c| {
238 capture_names
239 .get(c.index as usize)
240 .map(|n| *n == "import")
241 .unwrap_or(false)
242 })?;
243
244 let node = import_capture.node;
245 let text = node.utf8_text(source_code.as_bytes()).ok()?;
246
247 let mut symbol = Symbol::new(text.trim(), SymbolKind::Import);
248 symbol.start_line = node.start_position().row as u32 + 1;
249 symbol.end_line = node.end_position().row as u32 + 1;
250
251 Some(symbol)
252 }
253
254 fn process_match_single_pass(
256 &self,
257 m: &tree_sitter::QueryMatch<'_, '_>,
258 source_code: &str,
259 capture_names: &[&str],
260 language: Language,
261 ) -> Option<Symbol> {
262 let captures = &m.captures;
263
264 let name_node = captures
266 .iter()
267 .find(|c| {
268 capture_names
269 .get(c.index as usize)
270 .map(|n| *n == "name")
271 .unwrap_or(false)
272 })?
273 .node;
274
275 let kind_capture = captures.iter().find(|c| {
277 capture_names
278 .get(c.index as usize)
279 .map(|n| {
280 ["function", "class", "method", "struct", "enum", "interface", "trait"]
281 .contains(n)
282 })
283 .unwrap_or(false)
284 })?;
285
286 let kind_name = capture_names.get(kind_capture.index as usize)?;
287 let mut symbol_kind = extraction::map_symbol_kind(kind_name);
288
289 let name = name_node.utf8_text(source_code.as_bytes()).ok()?;
290
291 let def_node = captures
293 .iter()
294 .max_by_key(|c| c.node.byte_range().len())
295 .map(|c| c.node)
296 .unwrap_or(name_node);
297
298 if language == Language::Kotlin && def_node.kind() == "class_declaration" {
299 let mut cursor = def_node.walk();
300 for child in def_node.children(&mut cursor) {
301 if child.kind() == "interface" {
302 symbol_kind = SymbolKind::Interface;
303 break;
304 }
305 }
306 }
307
308 let start_line = def_node.start_position().row as u32 + 1;
309 let end_line = def_node.end_position().row as u32 + 1;
310
311 let signature = extraction::extract_signature(def_node, source_code, language);
313 let docstring = extraction::extract_docstring(def_node, source_code, language);
314 let parent = if symbol_kind == SymbolKind::Method {
315 extraction::extract_parent(def_node, source_code)
316 } else {
317 None
318 };
319 let visibility = extraction::extract_visibility(def_node, source_code, language);
320 let calls = if matches!(symbol_kind, SymbolKind::Function | SymbolKind::Method) {
321 extraction::extract_calls(def_node, source_code, language)
322 } else {
323 Vec::new()
324 };
325
326 let (extends, implements) = if matches!(
328 symbol_kind,
329 SymbolKind::Class | SymbolKind::Struct | SymbolKind::Interface
330 ) {
331 extraction::extract_inheritance(def_node, source_code, language)
332 } else {
333 (None, Vec::new())
334 };
335
336 let mut symbol = Symbol::new(name, symbol_kind);
337 symbol.start_line = start_line;
338 symbol.end_line = end_line;
339 symbol.signature = signature;
340 symbol.docstring = docstring;
341 symbol.parent = parent;
342 symbol.visibility = visibility;
343 symbol.calls = calls;
344 symbol.extends = extends;
345 symbol.implements = implements;
346
347 Some(symbol)
348 }
349}
350
351impl Default for Parser {
352 fn default() -> Self {
353 Self::new()
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the symbol called `name` with kind `kind`.
    ///
    /// Panics with the names of every extracted symbol when the symbol is missing,
    /// so a failing test shows what the parser actually produced.
    fn expect_symbol<'a>(symbols: &'a [Symbol], name: &str, kind: SymbolKind) -> &'a Symbol {
        symbols.iter().find(|s| s.name == name && s.kind == kind).unwrap_or_else(|| {
            let found: Vec<String> = symbols.iter().map(|s| s.name.to_string()).collect();
            panic!("expected symbol `{}` not found; extracted: {:?}", name, found)
        })
    }

    #[test]
    fn test_language_from_extension() {
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("ts"), Some(Language::TypeScript));
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("go"), Some(Language::Go));
        assert_eq!(Language::from_extension("java"), Some(Language::Java));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    #[test]
    fn test_parse_python() {
        let mut parser = Parser::new();
        let source = r#"
def hello_world():
    """This is a docstring"""
    print("Hello, World!")

class MyClass:
    def method(self, x):
        return x * 2
"#;

        let symbols = parser.parse(source, Language::Python).expect("Python source should parse");
        assert!(!symbols.is_empty());

        expect_symbol(&symbols, "hello_world", SymbolKind::Function);
        expect_symbol(&symbols, "MyClass", SymbolKind::Class);
        expect_symbol(&symbols, "method", SymbolKind::Method);
    }

    #[test]
    fn test_parse_rust() {
        let mut parser = Parser::new();
        let source = r#"
/// A test function
fn test_function() -> i32 {
    42
}

struct MyStruct {
    field: i32,
}

enum MyEnum {
    Variant1,
    Variant2,
}
"#;

        let symbols = parser.parse(source, Language::Rust).expect("Rust source should parse");
        assert!(!symbols.is_empty());

        expect_symbol(&symbols, "test_function", SymbolKind::Function);
        expect_symbol(&symbols, "MyStruct", SymbolKind::Struct);
        expect_symbol(&symbols, "MyEnum", SymbolKind::Enum);
    }

    #[test]
    fn test_parse_javascript() {
        let mut parser = Parser::new();
        let source = r#"
function testFunction() {
    return 42;
}

class TestClass {
    testMethod() {
        return "test";
    }
}

const arrowFunc = () => {
    console.log("arrow");
};
"#;

        let symbols =
            parser.parse(source, Language::JavaScript).expect("JavaScript source should parse");
        assert!(!symbols.is_empty());

        expect_symbol(&symbols, "testFunction", SymbolKind::Function);
        expect_symbol(&symbols, "TestClass", SymbolKind::Class);
    }

    #[test]
    fn test_parse_typescript() {
        let mut parser = Parser::new();
        let source = r#"
interface TestInterface {
    method(): void;
}

enum TestEnum {
    Value1,
    Value2
}

class TestClass implements TestInterface {
    method(): void {
        console.log("test");
    }
}
"#;

        let symbols =
            parser.parse(source, Language::TypeScript).expect("TypeScript source should parse");
        assert!(!symbols.is_empty());

        expect_symbol(&symbols, "TestInterface", SymbolKind::Interface);
        expect_symbol(&symbols, "TestEnum", SymbolKind::Enum);
    }

    #[test]
    fn test_symbol_metadata() {
        let mut parser = Parser::new();
        let source = r#"
def test_func(x, y):
    """A test function with params"""
    return x + y
"#;

        let symbols = parser.parse(source, Language::Python).expect("Python source should parse");
        let func = symbols.iter().find(|s| s.name == "test_func").expect("Function not found");

        assert!(func.start_line > 0);
        assert!(func.end_line >= func.start_line);
        assert!(func.signature.is_some());
        assert!(func.docstring.is_some());
    }
}