1use super::{node_text, LanguageExtractor};
7use crate::ast::{
8 ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
13pub struct JavaExtractor;
15
16impl LanguageExtractor for JavaExtractor {
17 fn language(&self) -> Language {
18 tree_sitter_java::LANGUAGE.into()
19 }
20
21 fn name(&self) -> &'static str {
22 "java"
23 }
24
25 fn extensions(&self) -> &'static [&'static str] {
26 &["java"]
27 }
28
29 fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30 let mut symbols = Vec::new();
31 let root = tree.root_node();
32 self.extract_symbols_recursive(&root, source, &mut symbols, None);
33 Ok(symbols)
34 }
35
36 fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37 let mut imports = Vec::new();
38 let root = tree.root_node();
39 self.extract_imports_recursive(&root, source, &mut imports);
40 Ok(imports)
41 }
42
43 fn extract_calls(
44 &self,
45 tree: &Tree,
46 source: &str,
47 current_function: Option<&str>,
48 ) -> Result<Vec<FunctionCall>> {
49 let mut calls = Vec::new();
50 let root = tree.root_node();
51 self.extract_calls_recursive(&root, source, &mut calls, current_function);
52 Ok(calls)
53 }
54
55 fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56 if let Some(prev) = node.prev_sibling() {
58 if prev.kind() == "block_comment" || prev.kind() == "line_comment" {
59 let comment = node_text(&prev, source);
60 if comment.starts_with("/**") {
61 return Some(Self::clean_javadoc(comment));
62 }
63 }
64 }
65 None
66 }
67}
68
69impl JavaExtractor {
70 fn extract_symbols_recursive(
71 &self,
72 node: &Node,
73 source: &str,
74 symbols: &mut Vec<ExtractedSymbol>,
75 parent: Option<&str>,
76 ) {
77 match node.kind() {
78 "class_declaration" => {
79 if let Some(sym) = self.extract_class(node, source, parent) {
80 let class_name = sym.name.clone();
81 symbols.push(sym);
82
83 if let Some(body) = node.child_by_field_name("body") {
85 self.extract_class_members(&body, source, symbols, Some(&class_name));
86 }
87 return;
88 }
89 }
90
91 "interface_declaration" => {
92 if let Some(sym) = self.extract_interface(node, source, parent) {
93 let interface_name = sym.name.clone();
94 symbols.push(sym);
95
96 if let Some(body) = node.child_by_field_name("body") {
98 self.extract_interface_members(
99 &body,
100 source,
101 symbols,
102 Some(&interface_name),
103 );
104 }
105 return;
106 }
107 }
108
109 "enum_declaration" => {
110 if let Some(sym) = self.extract_enum(node, source, parent) {
111 let enum_name = sym.name.clone();
112 symbols.push(sym);
113
114 if let Some(body) = node.child_by_field_name("body") {
116 self.extract_enum_constants(&body, source, symbols, Some(&enum_name));
117 }
118 return;
119 }
120 }
121
122 "method_declaration" | "constructor_declaration" => {
123 if let Some(sym) = self.extract_method(node, source, parent) {
124 symbols.push(sym);
125 }
126 }
127
128 "field_declaration" => {
129 self.extract_fields(node, source, symbols, parent);
130 }
131
132 _ => {}
133 }
134
135 let mut cursor = node.walk();
137 for child in node.children(&mut cursor) {
138 self.extract_symbols_recursive(&child, source, symbols, parent);
139 }
140 }
141
142 fn extract_class(
143 &self,
144 node: &Node,
145 source: &str,
146 parent: Option<&str>,
147 ) -> Option<ExtractedSymbol> {
148 let name_node = node.child_by_field_name("name")?;
149 let name = node_text(&name_node, source).to_string();
150
151 let mut sym = ExtractedSymbol::new(
152 name,
153 SymbolKind::Class,
154 node.start_position().row + 1,
155 node.end_position().row + 1,
156 )
157 .with_columns(node.start_position().column, node.end_position().column);
158
159 sym.visibility = self.extract_visibility(node, source);
161 if matches!(sym.visibility, Visibility::Public) {
162 sym = sym.exported();
163 }
164
165 let text = node_text(node, source);
167 if text.contains("static") {
168 sym = sym.static_fn();
169 }
170
171 if let Some(type_params) = node.child_by_field_name("type_parameters") {
173 self.extract_generics(&type_params, source, &mut sym);
174 }
175
176 sym.doc_comment = self.extract_doc_comment(node, source);
177
178 if let Some(p) = parent {
179 sym = sym.with_parent(p);
180 }
181
182 Some(sym)
183 }
184
185 fn extract_interface(
186 &self,
187 node: &Node,
188 source: &str,
189 parent: Option<&str>,
190 ) -> Option<ExtractedSymbol> {
191 let name_node = node.child_by_field_name("name")?;
192 let name = node_text(&name_node, source).to_string();
193
194 let mut sym = ExtractedSymbol::new(
195 name,
196 SymbolKind::Interface,
197 node.start_position().row + 1,
198 node.end_position().row + 1,
199 );
200
201 sym.visibility = self.extract_visibility(node, source);
202 if matches!(sym.visibility, Visibility::Public) {
203 sym = sym.exported();
204 }
205
206 if let Some(type_params) = node.child_by_field_name("type_parameters") {
207 self.extract_generics(&type_params, source, &mut sym);
208 }
209
210 sym.doc_comment = self.extract_doc_comment(node, source);
211
212 if let Some(p) = parent {
213 sym = sym.with_parent(p);
214 }
215
216 Some(sym)
217 }
218
219 fn extract_enum(
220 &self,
221 node: &Node,
222 source: &str,
223 parent: Option<&str>,
224 ) -> Option<ExtractedSymbol> {
225 let name_node = node.child_by_field_name("name")?;
226 let name = node_text(&name_node, source).to_string();
227
228 let mut sym = ExtractedSymbol::new(
229 name,
230 SymbolKind::Enum,
231 node.start_position().row + 1,
232 node.end_position().row + 1,
233 );
234
235 sym.visibility = self.extract_visibility(node, source);
236 if matches!(sym.visibility, Visibility::Public) {
237 sym = sym.exported();
238 }
239
240 sym.doc_comment = self.extract_doc_comment(node, source);
241
242 if let Some(p) = parent {
243 sym = sym.with_parent(p);
244 }
245
246 Some(sym)
247 }
248
249 fn extract_method(
250 &self,
251 node: &Node,
252 source: &str,
253 parent: Option<&str>,
254 ) -> Option<ExtractedSymbol> {
255 let is_constructor = node.kind() == "constructor_declaration";
256
257 let name = if is_constructor {
258 parent.map(String::from)?
260 } else {
261 let name_node = node.child_by_field_name("name")?;
262 node_text(&name_node, source).to_string()
263 };
264
265 let mut sym = ExtractedSymbol::new(
266 name,
267 SymbolKind::Method,
268 node.start_position().row + 1,
269 node.end_position().row + 1,
270 );
271
272 sym.visibility = self.extract_visibility(node, source);
273 if matches!(sym.visibility, Visibility::Public) {
274 sym = sym.exported();
275 }
276
277 let text = node_text(node, source);
279 if text.contains("static ") {
280 sym = sym.static_fn();
281 }
282
283 if let Some(type_params) = node.child_by_field_name("type_parameters") {
285 self.extract_generics(&type_params, source, &mut sym);
286 }
287
288 if let Some(params) = node.child_by_field_name("parameters") {
290 self.extract_parameters(¶ms, source, &mut sym);
291 }
292
293 if !is_constructor {
295 if let Some(ret_type) = node.child_by_field_name("type") {
296 sym.return_type = Some(node_text(&ret_type, source).to_string());
297 }
298 }
299
300 sym.doc_comment = self.extract_doc_comment(node, source);
301
302 if let Some(p) = parent {
303 sym = sym.with_parent(p);
304 }
305
306 sym.signature = Some(self.build_method_signature(node, source, is_constructor));
307
308 Some(sym)
309 }
310
311 fn extract_fields(
312 &self,
313 node: &Node,
314 source: &str,
315 symbols: &mut Vec<ExtractedSymbol>,
316 parent: Option<&str>,
317 ) {
318 let visibility = self.extract_visibility(node, source);
319 let is_static = node_text(node, source).contains("static ");
320
321 let type_node = node.child_by_field_name("type");
322 let type_info = type_node.map(|n| node_text(&n, source).to_string());
323
324 let mut cursor = node.walk();
325 for child in node.children(&mut cursor) {
326 if child.kind() == "variable_declarator" {
327 if let Some(name_node) = child.child_by_field_name("name") {
328 let name = node_text(&name_node, source).to_string();
329
330 let mut sym = ExtractedSymbol::new(
331 name,
332 SymbolKind::Field,
333 child.start_position().row + 1,
334 child.end_position().row + 1,
335 );
336
337 sym.visibility = visibility;
338 sym.type_info = type_info.clone();
339
340 if is_static {
341 sym = sym.static_fn();
342 }
343
344 if let Some(p) = parent {
345 sym = sym.with_parent(p);
346 }
347
348 symbols.push(sym);
349 }
350 }
351 }
352 }
353
354 fn extract_class_members(
355 &self,
356 body: &Node,
357 source: &str,
358 symbols: &mut Vec<ExtractedSymbol>,
359 class_name: Option<&str>,
360 ) {
361 let mut cursor = body.walk();
362 for child in body.children(&mut cursor) {
363 match child.kind() {
364 "method_declaration" | "constructor_declaration" => {
365 if let Some(sym) = self.extract_method(&child, source, class_name) {
366 symbols.push(sym);
367 }
368 }
369 "field_declaration" => {
370 self.extract_fields(&child, source, symbols, class_name);
371 }
372 "class_declaration" => {
373 if let Some(sym) = self.extract_class(&child, source, class_name) {
375 let nested_name = sym.name.clone();
376 symbols.push(sym);
377
378 if let Some(nested_body) = child.child_by_field_name("body") {
379 self.extract_class_members(
380 &nested_body,
381 source,
382 symbols,
383 Some(&nested_name),
384 );
385 }
386 }
387 }
388 _ => {}
389 }
390 }
391 }
392
393 fn extract_interface_members(
394 &self,
395 body: &Node,
396 source: &str,
397 symbols: &mut Vec<ExtractedSymbol>,
398 interface_name: Option<&str>,
399 ) {
400 let mut cursor = body.walk();
401 for child in body.children(&mut cursor) {
402 if child.kind() == "method_declaration" {
403 if let Some(sym) = self.extract_method(&child, source, interface_name) {
404 symbols.push(sym);
405 }
406 } else if child.kind() == "constant_declaration" {
407 self.extract_fields(&child, source, symbols, interface_name);
408 }
409 }
410 }
411
412 fn extract_enum_constants(
413 &self,
414 body: &Node,
415 source: &str,
416 symbols: &mut Vec<ExtractedSymbol>,
417 enum_name: Option<&str>,
418 ) {
419 let mut cursor = body.walk();
420 for child in body.children(&mut cursor) {
421 if child.kind() == "enum_constant" {
422 if let Some(name_node) = child.child_by_field_name("name") {
423 let name = node_text(&name_node, source).to_string();
424
425 let mut sym = ExtractedSymbol::new(
426 name,
427 SymbolKind::EnumVariant,
428 child.start_position().row + 1,
429 child.end_position().row + 1,
430 );
431
432 sym.visibility = Visibility::Public;
433 sym = sym.exported();
434
435 if let Some(p) = enum_name {
436 sym = sym.with_parent(p);
437 }
438
439 symbols.push(sym);
440 }
441 }
442 }
443 }
444
445 fn extract_visibility(&self, node: &Node, source: &str) -> Visibility {
446 let mut cursor = node.walk();
447 for child in node.children(&mut cursor) {
448 if child.kind() == "modifiers" {
449 let text = node_text(&child, source);
450 if text.contains("public") {
451 return Visibility::Public;
452 } else if text.contains("private") {
453 return Visibility::Private;
454 } else if text.contains("protected") {
455 return Visibility::Protected;
456 }
457 return Visibility::Internal; }
459 }
460 Visibility::Internal }
462
463 fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
464 let mut cursor = params.walk();
465 for child in params.children(&mut cursor) {
466 if child.kind() == "formal_parameter" || child.kind() == "spread_parameter" {
467 let is_rest = child.kind() == "spread_parameter";
468
469 let name = child
470 .child_by_field_name("name")
471 .map(|n| node_text(&n, source).to_string())
472 .unwrap_or_default();
473
474 let type_info = child
475 .child_by_field_name("type")
476 .map(|n| node_text(&n, source).to_string());
477
478 sym.add_parameter(Parameter {
479 name,
480 type_info,
481 default_value: None,
482 is_rest,
483 is_optional: false,
484 });
485 }
486 }
487 }
488
489 fn extract_generics(&self, type_params: &Node, source: &str, sym: &mut ExtractedSymbol) {
490 let mut cursor = type_params.walk();
491 for child in type_params.children(&mut cursor) {
492 if child.kind() == "type_parameter" {
493 if let Some(name) = child.child_by_field_name("name") {
494 sym.add_generic(node_text(&name, source));
495 } else {
496 let mut inner_cursor = child.walk();
498 for inner in child.children(&mut inner_cursor) {
499 if inner.kind() == "type_identifier" || inner.kind() == "identifier" {
500 sym.add_generic(node_text(&inner, source));
501 break;
502 }
503 }
504 }
505 }
506 }
507 }
508
509 fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
510 if node.kind() == "import_declaration" {
511 if let Some(import) = self.parse_import(node, source) {
512 imports.push(import);
513 }
514 }
515
516 let mut cursor = node.walk();
517 for child in node.children(&mut cursor) {
518 self.extract_imports_recursive(&child, source, imports);
519 }
520 }
521
522 fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
523 let text = node_text(node, source);
524
525 let is_wildcard = text.contains(".*");
527 let _is_static = text.contains("static ");
528
529 let path = text
530 .trim_start_matches("import ")
531 .trim_start_matches("static ")
532 .trim_end_matches(';')
533 .trim()
534 .trim_end_matches(".*")
535 .to_string();
536
537 let (source_path, name) = if is_wildcard {
538 (path.clone(), "*".to_string())
539 } else {
540 let parts: Vec<&str> = path.rsplitn(2, '.').collect();
542 if parts.len() == 2 {
543 (parts[1].to_string(), parts[0].to_string())
544 } else {
545 (String::new(), path)
546 }
547 };
548
549 Some(Import {
550 source: source_path,
551 names: vec![ImportedName { name, alias: None }],
552 is_default: false,
553 is_namespace: is_wildcard,
554 line: node.start_position().row + 1,
555 })
556 }
557
558 fn extract_calls_recursive(
559 &self,
560 node: &Node,
561 source: &str,
562 calls: &mut Vec<FunctionCall>,
563 current_function: Option<&str>,
564 ) {
565 if node.kind() == "method_invocation" {
566 if let Some(call) = self.parse_call(node, source, current_function) {
567 calls.push(call);
568 }
569 }
570
571 let func_name = match node.kind() {
572 "method_declaration" | "constructor_declaration" => node
573 .child_by_field_name("name")
574 .map(|n| node_text(&n, source)),
575 _ => None,
576 };
577
578 let current = func_name
579 .map(String::from)
580 .or_else(|| current_function.map(String::from));
581
582 let mut cursor = node.walk();
583 for child in node.children(&mut cursor) {
584 self.extract_calls_recursive(&child, source, calls, current.as_deref());
585 }
586 }
587
588 fn parse_call(
589 &self,
590 node: &Node,
591 source: &str,
592 current_function: Option<&str>,
593 ) -> Option<FunctionCall> {
594 let name = node
595 .child_by_field_name("name")
596 .map(|n| node_text(&n, source).to_string())?;
597
598 let object = node
599 .child_by_field_name("object")
600 .map(|n| node_text(&n, source).to_string());
601
602 Some(FunctionCall {
603 caller: current_function.unwrap_or("<class>").to_string(),
604 callee: name,
605 line: node.start_position().row + 1,
606 is_method: object.is_some(),
607 receiver: object,
608 })
609 }
610
611 fn build_method_signature(&self, node: &Node, source: &str, is_constructor: bool) -> String {
612 let modifiers = node
613 .children(&mut node.walk())
614 .find(|c| c.kind() == "modifiers")
615 .map(|n| format!("{} ", node_text(&n, source)))
616 .unwrap_or_default();
617
618 let return_type = if is_constructor {
619 String::new()
620 } else {
621 node.child_by_field_name("type")
622 .map(|n| format!("{} ", node_text(&n, source)))
623 .unwrap_or_default()
624 };
625
626 let name = node
627 .child_by_field_name("name")
628 .map(|n| node_text(&n, source))
629 .unwrap_or("unknown");
630
631 let params = node
632 .child_by_field_name("parameters")
633 .map(|n| node_text(&n, source))
634 .unwrap_or("()");
635
636 format!("{}{}{}{}", modifiers, return_type, name, params)
637 }
638
/// Reduces a raw Javadoc comment to its main description.
///
/// The `/**` and `*/` delimiters and the leading `*` decoration of each line
/// are removed, blank lines are dropped, and the remaining lines are joined
/// with `\n`. The description ends at the first block tag (a line starting
/// with `@`): everything after it belongs to tags, including continuation
/// lines of multi-line tags, which previously leaked into the result.
fn clean_javadoc(comment: &str) -> String {
    let body = comment.strip_prefix("/**").unwrap_or(comment);
    let body = body.strip_suffix("*/").unwrap_or(body);

    body.lines()
        .map(|line| line.trim().trim_start_matches('*').trim())
        .take_while(|line| !line.starts_with('@'))
        .filter(|line| !line.is_empty())
        .collect::<Vec<_>>()
        .join("\n")
}
649}
650
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the Java grammar and returns the tree together
    /// with an owned copy of the text.
    fn parse_java(source: &str) -> (Tree, String) {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_java::LANGUAGE.into())
            .unwrap();
        let tree = parser.parse(source, None).unwrap();
        (tree, source.to_string())
    }

    /// Runs the symbol extractor over `source`.
    fn symbols_of(source: &str) -> Vec<ExtractedSymbol> {
        let (tree, src) = parse_java(source);
        JavaExtractor.extract_symbols(&tree, &src).unwrap()
    }

    /// True when `symbols` holds an entry with the given name and kind.
    fn has_symbol(symbols: &[ExtractedSymbol], name: &str, kind: SymbolKind) -> bool {
        symbols.iter().any(|s| s.name == name && s.kind == kind)
    }

    #[test]
    fn test_extract_class() {
        let symbols = symbols_of(
            r#"
public class UserService {
    private String name;

    public UserService(String name) {
        this.name = name;
    }

    public String greet() {
        return "Hello, " + name + "!";
    }
}
"#,
        );

        assert!(has_symbol(&symbols, "UserService", SymbolKind::Class));
        assert!(has_symbol(&symbols, "name", SymbolKind::Field));
        // The constructor is reported as a method named after its class.
        assert!(has_symbol(&symbols, "UserService", SymbolKind::Method));
        assert!(has_symbol(&symbols, "greet", SymbolKind::Method));
    }

    #[test]
    fn test_extract_interface() {
        let symbols = symbols_of(
            r#"
public interface Greeter {
    String greet();
    String farewell();
}
"#,
        );

        assert!(has_symbol(&symbols, "Greeter", SymbolKind::Interface));
        assert!(has_symbol(&symbols, "greet", SymbolKind::Method));
    }

    #[test]
    fn test_extract_enum() {
        let symbols = symbols_of(
            r#"
public enum Status {
    ACTIVE,
    INACTIVE,
    PENDING
}
"#,
        );

        assert!(has_symbol(&symbols, "Status", SymbolKind::Enum));
        assert!(has_symbol(&symbols, "ACTIVE", SymbolKind::EnumVariant));
    }

    #[test]
    fn test_extract_generics() {
        let symbols = symbols_of(
            r#"
public class Container<T> {
    private T value;

    public T getValue() {
        return value;
    }
}
"#,
        );

        let class = symbols.iter().find(|s| s.name == "Container").unwrap();
        assert!(class.generics.contains(&String::from("T")));
    }
}