acp/ast/languages/
java.rs

1//! @acp:module "Java Extractor"
2//! @acp:summary "Symbol extraction for Java source files"
3//! @acp:domain cli
4//! @acp:layer parsing
5
6use super::{node_text, LanguageExtractor};
7use crate::ast::{
8    ExtractedSymbol, FunctionCall, Import, ImportedName, Parameter, SymbolKind, Visibility,
9};
10use crate::error::Result;
11use tree_sitter::{Language, Node, Tree};
12
/// Java language extractor.
///
/// A stateless unit struct that implements [`LanguageExtractor`] on top of the
/// `tree_sitter_java` grammar. All extraction logic lives in the trait impl
/// and in the private helper methods of `impl JavaExtractor`.
pub struct JavaExtractor;
15
16impl LanguageExtractor for JavaExtractor {
17    fn language(&self) -> Language {
18        tree_sitter_java::LANGUAGE.into()
19    }
20
21    fn name(&self) -> &'static str {
22        "java"
23    }
24
25    fn extensions(&self) -> &'static [&'static str] {
26        &["java"]
27    }
28
29    fn extract_symbols(&self, tree: &Tree, source: &str) -> Result<Vec<ExtractedSymbol>> {
30        let mut symbols = Vec::new();
31        let root = tree.root_node();
32        self.extract_symbols_recursive(&root, source, &mut symbols, None);
33        Ok(symbols)
34    }
35
36    fn extract_imports(&self, tree: &Tree, source: &str) -> Result<Vec<Import>> {
37        let mut imports = Vec::new();
38        let root = tree.root_node();
39        self.extract_imports_recursive(&root, source, &mut imports);
40        Ok(imports)
41    }
42
43    fn extract_calls(
44        &self,
45        tree: &Tree,
46        source: &str,
47        current_function: Option<&str>,
48    ) -> Result<Vec<FunctionCall>> {
49        let mut calls = Vec::new();
50        let root = tree.root_node();
51        self.extract_calls_recursive(&root, source, &mut calls, current_function);
52        Ok(calls)
53    }
54
55    fn extract_doc_comment(&self, node: &Node, source: &str) -> Option<String> {
56        // Look for Javadoc comments
57        if let Some(prev) = node.prev_sibling() {
58            if prev.kind() == "block_comment" || prev.kind() == "line_comment" {
59                let comment = node_text(&prev, source);
60                if comment.starts_with("/**") {
61                    return Some(Self::clean_javadoc(comment));
62                }
63            }
64        }
65        None
66    }
67}
68
69impl JavaExtractor {
70    fn extract_symbols_recursive(
71        &self,
72        node: &Node,
73        source: &str,
74        symbols: &mut Vec<ExtractedSymbol>,
75        parent: Option<&str>,
76    ) {
77        match node.kind() {
78            "class_declaration" => {
79                if let Some(sym) = self.extract_class(node, source, parent) {
80                    let class_name = sym.name.clone();
81                    symbols.push(sym);
82
83                    // Extract class body
84                    if let Some(body) = node.child_by_field_name("body") {
85                        self.extract_class_members(&body, source, symbols, Some(&class_name));
86                    }
87                    return;
88                }
89            }
90
91            "interface_declaration" => {
92                if let Some(sym) = self.extract_interface(node, source, parent) {
93                    let interface_name = sym.name.clone();
94                    symbols.push(sym);
95
96                    // Extract interface body
97                    if let Some(body) = node.child_by_field_name("body") {
98                        self.extract_interface_members(
99                            &body,
100                            source,
101                            symbols,
102                            Some(&interface_name),
103                        );
104                    }
105                    return;
106                }
107            }
108
109            "enum_declaration" => {
110                if let Some(sym) = self.extract_enum(node, source, parent) {
111                    let enum_name = sym.name.clone();
112                    symbols.push(sym);
113
114                    // Extract enum body
115                    if let Some(body) = node.child_by_field_name("body") {
116                        self.extract_enum_constants(&body, source, symbols, Some(&enum_name));
117                    }
118                    return;
119                }
120            }
121
122            "method_declaration" | "constructor_declaration" => {
123                if let Some(sym) = self.extract_method(node, source, parent) {
124                    symbols.push(sym);
125                }
126            }
127
128            "field_declaration" => {
129                self.extract_fields(node, source, symbols, parent);
130            }
131
132            _ => {}
133        }
134
135        // Recurse into children
136        let mut cursor = node.walk();
137        for child in node.children(&mut cursor) {
138            self.extract_symbols_recursive(&child, source, symbols, parent);
139        }
140    }
141
142    fn extract_class(
143        &self,
144        node: &Node,
145        source: &str,
146        parent: Option<&str>,
147    ) -> Option<ExtractedSymbol> {
148        let name_node = node.child_by_field_name("name")?;
149        let name = node_text(&name_node, source).to_string();
150
151        let mut sym = ExtractedSymbol::new(
152            name,
153            SymbolKind::Class,
154            node.start_position().row + 1,
155            node.end_position().row + 1,
156        )
157        .with_columns(node.start_position().column, node.end_position().column);
158
159        // Extract modifiers for visibility
160        sym.visibility = self.extract_visibility(node, source);
161        if matches!(sym.visibility, Visibility::Public) {
162            sym = sym.exported();
163        }
164
165        // Check for static
166        let text = node_text(node, source);
167        if text.contains("static") {
168            sym = sym.static_fn();
169        }
170
171        // Extract generics
172        if let Some(type_params) = node.child_by_field_name("type_parameters") {
173            self.extract_generics(&type_params, source, &mut sym);
174        }
175
176        sym.doc_comment = self.extract_doc_comment(node, source);
177
178        if let Some(p) = parent {
179            sym = sym.with_parent(p);
180        }
181
182        Some(sym)
183    }
184
185    fn extract_interface(
186        &self,
187        node: &Node,
188        source: &str,
189        parent: Option<&str>,
190    ) -> Option<ExtractedSymbol> {
191        let name_node = node.child_by_field_name("name")?;
192        let name = node_text(&name_node, source).to_string();
193
194        let mut sym = ExtractedSymbol::new(
195            name,
196            SymbolKind::Interface,
197            node.start_position().row + 1,
198            node.end_position().row + 1,
199        );
200
201        sym.visibility = self.extract_visibility(node, source);
202        if matches!(sym.visibility, Visibility::Public) {
203            sym = sym.exported();
204        }
205
206        if let Some(type_params) = node.child_by_field_name("type_parameters") {
207            self.extract_generics(&type_params, source, &mut sym);
208        }
209
210        sym.doc_comment = self.extract_doc_comment(node, source);
211
212        if let Some(p) = parent {
213            sym = sym.with_parent(p);
214        }
215
216        Some(sym)
217    }
218
219    fn extract_enum(
220        &self,
221        node: &Node,
222        source: &str,
223        parent: Option<&str>,
224    ) -> Option<ExtractedSymbol> {
225        let name_node = node.child_by_field_name("name")?;
226        let name = node_text(&name_node, source).to_string();
227
228        let mut sym = ExtractedSymbol::new(
229            name,
230            SymbolKind::Enum,
231            node.start_position().row + 1,
232            node.end_position().row + 1,
233        );
234
235        sym.visibility = self.extract_visibility(node, source);
236        if matches!(sym.visibility, Visibility::Public) {
237            sym = sym.exported();
238        }
239
240        sym.doc_comment = self.extract_doc_comment(node, source);
241
242        if let Some(p) = parent {
243            sym = sym.with_parent(p);
244        }
245
246        Some(sym)
247    }
248
249    fn extract_method(
250        &self,
251        node: &Node,
252        source: &str,
253        parent: Option<&str>,
254    ) -> Option<ExtractedSymbol> {
255        let is_constructor = node.kind() == "constructor_declaration";
256
257        let name = if is_constructor {
258            // Constructor name is same as class name
259            parent.map(String::from)?
260        } else {
261            let name_node = node.child_by_field_name("name")?;
262            node_text(&name_node, source).to_string()
263        };
264
265        let mut sym = ExtractedSymbol::new(
266            name,
267            SymbolKind::Method,
268            node.start_position().row + 1,
269            node.end_position().row + 1,
270        );
271
272        sym.visibility = self.extract_visibility(node, source);
273        if matches!(sym.visibility, Visibility::Public) {
274            sym = sym.exported();
275        }
276
277        // Check for static
278        let text = node_text(node, source);
279        if text.contains("static ") {
280            sym = sym.static_fn();
281        }
282
283        // Extract generics
284        if let Some(type_params) = node.child_by_field_name("type_parameters") {
285            self.extract_generics(&type_params, source, &mut sym);
286        }
287
288        // Extract parameters
289        if let Some(params) = node.child_by_field_name("parameters") {
290            self.extract_parameters(&params, source, &mut sym);
291        }
292
293        // Extract return type (not for constructors)
294        if !is_constructor {
295            if let Some(ret_type) = node.child_by_field_name("type") {
296                sym.return_type = Some(node_text(&ret_type, source).to_string());
297            }
298        }
299
300        sym.doc_comment = self.extract_doc_comment(node, source);
301
302        if let Some(p) = parent {
303            sym = sym.with_parent(p);
304        }
305
306        sym.signature = Some(self.build_method_signature(node, source, is_constructor));
307
308        Some(sym)
309    }
310
311    fn extract_fields(
312        &self,
313        node: &Node,
314        source: &str,
315        symbols: &mut Vec<ExtractedSymbol>,
316        parent: Option<&str>,
317    ) {
318        let visibility = self.extract_visibility(node, source);
319        let is_static = node_text(node, source).contains("static ");
320
321        let type_node = node.child_by_field_name("type");
322        let type_info = type_node.map(|n| node_text(&n, source).to_string());
323
324        let mut cursor = node.walk();
325        for child in node.children(&mut cursor) {
326            if child.kind() == "variable_declarator" {
327                if let Some(name_node) = child.child_by_field_name("name") {
328                    let name = node_text(&name_node, source).to_string();
329
330                    let mut sym = ExtractedSymbol::new(
331                        name,
332                        SymbolKind::Field,
333                        child.start_position().row + 1,
334                        child.end_position().row + 1,
335                    );
336
337                    sym.visibility = visibility;
338                    sym.type_info = type_info.clone();
339
340                    if is_static {
341                        sym = sym.static_fn();
342                    }
343
344                    if let Some(p) = parent {
345                        sym = sym.with_parent(p);
346                    }
347
348                    symbols.push(sym);
349                }
350            }
351        }
352    }
353
354    fn extract_class_members(
355        &self,
356        body: &Node,
357        source: &str,
358        symbols: &mut Vec<ExtractedSymbol>,
359        class_name: Option<&str>,
360    ) {
361        let mut cursor = body.walk();
362        for child in body.children(&mut cursor) {
363            match child.kind() {
364                "method_declaration" | "constructor_declaration" => {
365                    if let Some(sym) = self.extract_method(&child, source, class_name) {
366                        symbols.push(sym);
367                    }
368                }
369                "field_declaration" => {
370                    self.extract_fields(&child, source, symbols, class_name);
371                }
372                "class_declaration" => {
373                    // Nested class
374                    if let Some(sym) = self.extract_class(&child, source, class_name) {
375                        let nested_name = sym.name.clone();
376                        symbols.push(sym);
377
378                        if let Some(nested_body) = child.child_by_field_name("body") {
379                            self.extract_class_members(
380                                &nested_body,
381                                source,
382                                symbols,
383                                Some(&nested_name),
384                            );
385                        }
386                    }
387                }
388                _ => {}
389            }
390        }
391    }
392
393    fn extract_interface_members(
394        &self,
395        body: &Node,
396        source: &str,
397        symbols: &mut Vec<ExtractedSymbol>,
398        interface_name: Option<&str>,
399    ) {
400        let mut cursor = body.walk();
401        for child in body.children(&mut cursor) {
402            if child.kind() == "method_declaration" {
403                if let Some(sym) = self.extract_method(&child, source, interface_name) {
404                    symbols.push(sym);
405                }
406            } else if child.kind() == "constant_declaration" {
407                self.extract_fields(&child, source, symbols, interface_name);
408            }
409        }
410    }
411
412    fn extract_enum_constants(
413        &self,
414        body: &Node,
415        source: &str,
416        symbols: &mut Vec<ExtractedSymbol>,
417        enum_name: Option<&str>,
418    ) {
419        let mut cursor = body.walk();
420        for child in body.children(&mut cursor) {
421            if child.kind() == "enum_constant" {
422                if let Some(name_node) = child.child_by_field_name("name") {
423                    let name = node_text(&name_node, source).to_string();
424
425                    let mut sym = ExtractedSymbol::new(
426                        name,
427                        SymbolKind::EnumVariant,
428                        child.start_position().row + 1,
429                        child.end_position().row + 1,
430                    );
431
432                    sym.visibility = Visibility::Public;
433                    sym = sym.exported();
434
435                    if let Some(p) = enum_name {
436                        sym = sym.with_parent(p);
437                    }
438
439                    symbols.push(sym);
440                }
441            }
442        }
443    }
444
445    fn extract_visibility(&self, node: &Node, source: &str) -> Visibility {
446        let mut cursor = node.walk();
447        for child in node.children(&mut cursor) {
448            if child.kind() == "modifiers" {
449                let text = node_text(&child, source);
450                if text.contains("public") {
451                    return Visibility::Public;
452                } else if text.contains("private") {
453                    return Visibility::Private;
454                } else if text.contains("protected") {
455                    return Visibility::Protected;
456                }
457                return Visibility::Internal; // package-private
458            }
459        }
460        Visibility::Internal // default is package-private
461    }
462
463    fn extract_parameters(&self, params: &Node, source: &str, sym: &mut ExtractedSymbol) {
464        let mut cursor = params.walk();
465        for child in params.children(&mut cursor) {
466            if child.kind() == "formal_parameter" || child.kind() == "spread_parameter" {
467                let is_rest = child.kind() == "spread_parameter";
468
469                let name = child
470                    .child_by_field_name("name")
471                    .map(|n| node_text(&n, source).to_string())
472                    .unwrap_or_default();
473
474                let type_info = child
475                    .child_by_field_name("type")
476                    .map(|n| node_text(&n, source).to_string());
477
478                sym.add_parameter(Parameter {
479                    name,
480                    type_info,
481                    default_value: None,
482                    is_rest,
483                    is_optional: false,
484                });
485            }
486        }
487    }
488
489    fn extract_generics(&self, type_params: &Node, source: &str, sym: &mut ExtractedSymbol) {
490        let mut cursor = type_params.walk();
491        for child in type_params.children(&mut cursor) {
492            if child.kind() == "type_parameter" {
493                if let Some(name) = child.child_by_field_name("name") {
494                    sym.add_generic(node_text(&name, source));
495                } else {
496                    // Fallback to first identifier
497                    let mut inner_cursor = child.walk();
498                    for inner in child.children(&mut inner_cursor) {
499                        if inner.kind() == "type_identifier" || inner.kind() == "identifier" {
500                            sym.add_generic(node_text(&inner, source));
501                            break;
502                        }
503                    }
504                }
505            }
506        }
507    }
508
509    fn extract_imports_recursive(&self, node: &Node, source: &str, imports: &mut Vec<Import>) {
510        if node.kind() == "import_declaration" {
511            if let Some(import) = self.parse_import(node, source) {
512                imports.push(import);
513            }
514        }
515
516        let mut cursor = node.walk();
517        for child in node.children(&mut cursor) {
518            self.extract_imports_recursive(&child, source, imports);
519        }
520    }
521
522    fn parse_import(&self, node: &Node, source: &str) -> Option<Import> {
523        let text = node_text(node, source);
524
525        // Check for wildcard import
526        let is_wildcard = text.contains(".*");
527        let _is_static = text.contains("static ");
528
529        let path = text
530            .trim_start_matches("import ")
531            .trim_start_matches("static ")
532            .trim_end_matches(';')
533            .trim()
534            .trim_end_matches(".*")
535            .to_string();
536
537        let (source_path, name) = if is_wildcard {
538            (path.clone(), "*".to_string())
539        } else {
540            // Split package.Class into package and class
541            let parts: Vec<&str> = path.rsplitn(2, '.').collect();
542            if parts.len() == 2 {
543                (parts[1].to_string(), parts[0].to_string())
544            } else {
545                (String::new(), path)
546            }
547        };
548
549        Some(Import {
550            source: source_path,
551            names: vec![ImportedName { name, alias: None }],
552            is_default: false,
553            is_namespace: is_wildcard,
554            line: node.start_position().row + 1,
555        })
556    }
557
558    fn extract_calls_recursive(
559        &self,
560        node: &Node,
561        source: &str,
562        calls: &mut Vec<FunctionCall>,
563        current_function: Option<&str>,
564    ) {
565        if node.kind() == "method_invocation" {
566            if let Some(call) = self.parse_call(node, source, current_function) {
567                calls.push(call);
568            }
569        }
570
571        let func_name = match node.kind() {
572            "method_declaration" | "constructor_declaration" => node
573                .child_by_field_name("name")
574                .map(|n| node_text(&n, source)),
575            _ => None,
576        };
577
578        let current = func_name
579            .map(String::from)
580            .or_else(|| current_function.map(String::from));
581
582        let mut cursor = node.walk();
583        for child in node.children(&mut cursor) {
584            self.extract_calls_recursive(&child, source, calls, current.as_deref());
585        }
586    }
587
588    fn parse_call(
589        &self,
590        node: &Node,
591        source: &str,
592        current_function: Option<&str>,
593    ) -> Option<FunctionCall> {
594        let name = node
595            .child_by_field_name("name")
596            .map(|n| node_text(&n, source).to_string())?;
597
598        let object = node
599            .child_by_field_name("object")
600            .map(|n| node_text(&n, source).to_string());
601
602        Some(FunctionCall {
603            caller: current_function.unwrap_or("<class>").to_string(),
604            callee: name,
605            line: node.start_position().row + 1,
606            is_method: object.is_some(),
607            receiver: object,
608        })
609    }
610
611    fn build_method_signature(&self, node: &Node, source: &str, is_constructor: bool) -> String {
612        let modifiers = node
613            .children(&mut node.walk())
614            .find(|c| c.kind() == "modifiers")
615            .map(|n| format!("{} ", node_text(&n, source)))
616            .unwrap_or_default();
617
618        let return_type = if is_constructor {
619            String::new()
620        } else {
621            node.child_by_field_name("type")
622                .map(|n| format!("{} ", node_text(&n, source)))
623                .unwrap_or_default()
624        };
625
626        let name = node
627            .child_by_field_name("name")
628            .map(|n| node_text(&n, source))
629            .unwrap_or("unknown");
630
631        let params = node
632            .child_by_field_name("parameters")
633            .map(|n| node_text(&n, source))
634            .unwrap_or("()");
635
636        format!("{}{}{}{}", modifiers, return_type, name, params)
637    }
638
639    fn clean_javadoc(comment: &str) -> String {
640        comment
641            .trim_start_matches("/**")
642            .trim_end_matches("*/")
643            .lines()
644            .map(|line| line.trim().trim_start_matches('*').trim())
645            .filter(|line| !line.is_empty() && !line.starts_with('@'))
646            .collect::<Vec<_>>()
647            .join("\n")
648    }
649}
650
651#[cfg(test)]
652mod tests {
653    use super::*;
654
655    fn parse_java(source: &str) -> (Tree, String) {
656        let mut parser = tree_sitter::Parser::new();
657        parser
658            .set_language(&tree_sitter_java::LANGUAGE.into())
659            .unwrap();
660        let tree = parser.parse(source, None).unwrap();
661        (tree, source.to_string())
662    }
663
664    #[test]
665    fn test_extract_class() {
666        let source = r#"
667public class UserService {
668    private String name;
669
670    public UserService(String name) {
671        this.name = name;
672    }
673
674    public String greet() {
675        return "Hello, " + name + "!";
676    }
677}
678"#;
679        let (tree, src) = parse_java(source);
680        let extractor = JavaExtractor;
681        let symbols = extractor.extract_symbols(&tree, &src).unwrap();
682
683        assert!(symbols
684            .iter()
685            .any(|s| s.name == "UserService" && s.kind == SymbolKind::Class));
686        assert!(symbols
687            .iter()
688            .any(|s| s.name == "name" && s.kind == SymbolKind::Field));
689        assert!(symbols
690            .iter()
691            .any(|s| s.name == "UserService" && s.kind == SymbolKind::Method)); // constructor
692        assert!(symbols
693            .iter()
694            .any(|s| s.name == "greet" && s.kind == SymbolKind::Method));
695    }
696
697    #[test]
698    fn test_extract_interface() {
699        let source = r#"
700public interface Greeter {
701    String greet();
702    String farewell();
703}
704"#;
705        let (tree, src) = parse_java(source);
706        let extractor = JavaExtractor;
707        let symbols = extractor.extract_symbols(&tree, &src).unwrap();
708
709        assert!(symbols
710            .iter()
711            .any(|s| s.name == "Greeter" && s.kind == SymbolKind::Interface));
712        assert!(symbols
713            .iter()
714            .any(|s| s.name == "greet" && s.kind == SymbolKind::Method));
715    }
716
717    #[test]
718    fn test_extract_enum() {
719        let source = r#"
720public enum Status {
721    ACTIVE,
722    INACTIVE,
723    PENDING
724}
725"#;
726        let (tree, src) = parse_java(source);
727        let extractor = JavaExtractor;
728        let symbols = extractor.extract_symbols(&tree, &src).unwrap();
729
730        assert!(symbols
731            .iter()
732            .any(|s| s.name == "Status" && s.kind == SymbolKind::Enum));
733        assert!(symbols
734            .iter()
735            .any(|s| s.name == "ACTIVE" && s.kind == SymbolKind::EnumVariant));
736    }
737
738    #[test]
739    fn test_extract_generics() {
740        let source = r#"
741public class Container<T> {
742    private T value;
743
744    public T getValue() {
745        return value;
746    }
747}
748"#;
749        let (tree, src) = parse_java(source);
750        let extractor = JavaExtractor;
751        let symbols = extractor.extract_symbols(&tree, &src).unwrap();
752
753        let class = symbols.iter().find(|s| s.name == "Container").unwrap();
754        assert!(class.generics.contains(&"T".to_string()));
755    }
756}