Skip to main content

codelens_engine/
type_hierarchy.rs

1//! Tree-sitter based type hierarchy analysis.
2//!
3//! Replaces JetBrains PSI `getTypeHierarchy` with direct AST node traversal.
4
5use crate::db::{IndexDb, index_db_path};
6use crate::project::ProjectRoot;
7use crate::project::is_excluded;
8use crate::symbols::language_for_path;
9use anyhow::Result;
10use serde::Serialize;
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::fs;
13use tree_sitter::{Node, Parser};
14use walkdir::WalkDir;
15
/// A single type declaration discovered while scanning the project.
#[derive(Debug, Clone, Serialize)]
pub struct TypeNode {
    /// Declared name of the type, taken from the declaration's `name` field.
    pub name: String,
    /// Path of the declaring file, exactly as handed to the extractor by the
    /// type-map builder (a project-relative path).
    pub file_path: String,
    /// 1-based line on which the declaration starts (tree-sitter row + 1).
    pub line: usize,
    /// Category of the declaration.
    pub kind: TypeNodeKind,
    /// Names of the direct supertypes as extracted from the source text.
    pub supertypes: Vec<String>,
}
24
/// Category of a [`TypeNode`].
///
/// Serialized in snake_case, e.g. `Class` becomes `"class"`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TypeNodeKind {
    /// A class declaration.
    Class,
    /// An interface declaration.
    Interface,
    /// A trait.
    Trait,
    /// An enum declaration.
    Enum,
    /// A struct declaration.
    Struct,
}
34
/// Outcome of a type-hierarchy query.
#[derive(Debug, Clone, Serialize)]
pub struct TypeHierarchyResult {
    /// The type name the query was started from, echoed back verbatim.
    pub root: String,
    /// The requested direction, echoed back verbatim.
    pub hierarchy_type: String,
    /// The root type (when it was found) followed by the related types that
    /// were collected, without duplicates.
    pub nodes: Vec<TypeNode>,
}
41
42/// Get the type hierarchy for a named type.
43///
44/// - `hierarchy_type`: `"super"`, `"sub"`, or `"both"`
45/// - `depth`: max traversal depth (0 = unlimited)
46pub fn get_type_hierarchy_native(
47    project: &ProjectRoot,
48    type_name: &str,
49    _file_path: Option<&str>,
50    hierarchy_type: &str,
51    depth: usize,
52) -> Result<TypeHierarchyResult> {
53    // Step 1: Build a project-wide type map { name -> TypeNode }
54    let type_map = build_type_map(project)?;
55
56    let max_depth = if depth == 0 { 50 } else { depth };
57    let mut result_nodes = Vec::new();
58
59    // Include root type
60    if let Some(root) = type_map.get(type_name) {
61        result_nodes.push(root.clone());
62    }
63
64    if hierarchy_type == "super" || hierarchy_type == "both" {
65        collect_supertypes(type_name, &type_map, max_depth, &mut result_nodes);
66    }
67
68    if hierarchy_type == "sub" || hierarchy_type == "both" {
69        collect_subtypes(type_name, &type_map, max_depth, &mut result_nodes);
70    }
71
72    // Deduplicate
73    let mut seen = HashSet::new();
74    result_nodes.retain(|n| seen.insert(format!("{}:{}", n.file_path, n.name)));
75
76    Ok(TypeHierarchyResult {
77        root: type_name.to_string(),
78        hierarchy_type: hierarchy_type.to_string(),
79        nodes: result_nodes,
80    })
81}
82
83/// Build a map of all types in the project with their supertypes.
84fn build_type_map(project: &ProjectRoot) -> Result<HashMap<String, TypeNode>> {
85    let mut map = HashMap::new();
86
87    // Try DB-accelerated path: only parse files that contain type declarations
88    let db_path = index_db_path(project.as_path());
89    let type_file_paths = IndexDb::open(&db_path).ok().and_then(|db| {
90        db.files_with_symbol_kinds(&["class", "interface", "enum", "module"])
91            .ok()
92            .filter(|paths| !paths.is_empty()) // empty DB → fallback to walk
93    });
94
95    if let Some(rel_paths) = type_file_paths {
96        // Fast path: only parse files known to have type declarations
97        for rel_path in &rel_paths {
98            let abs_path = project.as_path().join(rel_path);
99            let Some(config) = language_for_path(&abs_path) else {
100                continue;
101            };
102            let source = match fs::read_to_string(&abs_path) {
103                Ok(s) => s,
104                Err(_) => continue,
105            };
106            let mut parser = Parser::new();
107            if parser.set_language(&config.language).is_err() {
108                continue;
109            }
110            let Some(tree) = parser.parse(&source, None) else {
111                continue;
112            };
113            extract_types_from_node(
114                tree.root_node(),
115                source.as_bytes(),
116                rel_path,
117                config.extension,
118                &mut map,
119            );
120        }
121    } else {
122        // Fallback: full walk (no index available)
123        for entry in WalkDir::new(project.as_path())
124            .into_iter()
125            .filter_entry(|e| !is_excluded(e.path()))
126        {
127            let entry = entry?;
128            if !entry.file_type().is_file() {
129                continue;
130            }
131            let Some(config) = language_for_path(entry.path()) else {
132                continue;
133            };
134            let source = match fs::read_to_string(entry.path()) {
135                Ok(s) => s,
136                Err(_) => continue,
137            };
138            let rel = project.to_relative(entry.path());
139            let mut parser = Parser::new();
140            if parser.set_language(&config.language).is_err() {
141                continue;
142            }
143            let Some(tree) = parser.parse(&source, None) else {
144                continue;
145            };
146            extract_types_from_node(
147                tree.root_node(),
148                source.as_bytes(),
149                &rel,
150                config.extension,
151                &mut map,
152            );
153        }
154    }
155
156    Ok(map)
157}
158
159/// Walk AST to find class/interface/struct/trait/enum declarations and their supertypes.
160fn extract_types_from_node(
161    node: Node,
162    source: &[u8],
163    file_path: &str,
164    ext: &str,
165    map: &mut HashMap<String, TypeNode>,
166) {
167    let kind = node.kind();
168
169    match kind {
170        // Python: class Foo(Bar, Baz):
171        "class_definition" => {
172            if let Some(name) = node.child_by_field_name("name") {
173                let type_name = node_text(name, source).to_string();
174                let supertypes = extract_python_supertypes(node, source);
175                map.insert(
176                    type_name.clone(),
177                    TypeNode {
178                        name: type_name,
179                        file_path: file_path.to_string(),
180                        line: node.start_position().row + 1,
181                        kind: TypeNodeKind::Class,
182                        supertypes,
183                    },
184                );
185            }
186        }
187        // JS/TS: class Foo extends Bar implements I {}
188        "class_declaration" => {
189            if let Some(name) = node.child_by_field_name("name") {
190                let type_name = node_text(name, source).to_string();
191                let supertypes = extract_js_ts_supertypes(node, source);
192                let node_kind = if ext == "java" || ext == "kt" {
193                    // Java/Kotlin also use class_declaration
194                    TypeNodeKind::Class
195                } else {
196                    TypeNodeKind::Class
197                };
198                map.insert(
199                    type_name.clone(),
200                    TypeNode {
201                        name: type_name,
202                        file_path: file_path.to_string(),
203                        line: node.start_position().row + 1,
204                        kind: node_kind,
205                        supertypes,
206                    },
207                );
208            }
209        }
210        // TS: interface Foo extends Bar {}
211        "interface_declaration" => {
212            if let Some(name) = node.child_by_field_name("name") {
213                let type_name = node_text(name, source).to_string();
214                let supertypes = extract_js_ts_supertypes(node, source);
215                map.insert(
216                    type_name.clone(),
217                    TypeNode {
218                        name: type_name,
219                        file_path: file_path.to_string(),
220                        line: node.start_position().row + 1,
221                        kind: TypeNodeKind::Interface,
222                        supertypes,
223                    },
224                );
225            }
226        }
227        // Rust: struct Foo {}
228        "struct_item" => {
229            if let Some(name) = node.child_by_field_name("name") {
230                let type_name = node_text(name, source).to_string();
231                map.insert(
232                    type_name.clone(),
233                    TypeNode {
234                        name: type_name,
235                        file_path: file_path.to_string(),
236                        line: node.start_position().row + 1,
237                        kind: TypeNodeKind::Struct,
238                        supertypes: Vec::new(),
239                    },
240                );
241            }
242        }
243        // Rust: impl Trait for Struct — adds Trait as supertype of Struct
244        "impl_item" => {
245            // Try field names first
246            let by_field = node
247                .child_by_field_name("trait")
248                .zip(node.child_by_field_name("type"));
249            if let Some((trait_node, type_node)) = by_field {
250                let struct_name = node_text(type_node, source).to_string();
251                let trait_name = node_text(trait_node, source).to_string();
252                if let Some(existing) = map.get_mut(&struct_name)
253                    && !existing.supertypes.contains(&trait_name)
254                {
255                    existing.supertypes.push(trait_name);
256                }
257            } else {
258                // Fallback: scan child type_identifiers — pattern: impl TRAIT for TYPE
259                let mut type_ids = Vec::new();
260                let mut has_for = false;
261                for i in 0..node.child_count() {
262                    if let Some(child) = node.child(i) {
263                        if child.kind() == "type_identifier" {
264                            type_ids.push(node_text(child, source).to_string());
265                        }
266                        if node_text(child, source) == "for" {
267                            has_for = true;
268                        }
269                    }
270                }
271                if has_for && type_ids.len() >= 2 {
272                    let trait_name = &type_ids[0];
273                    let struct_name = &type_ids[1];
274                    if let Some(existing) = map.get_mut(struct_name)
275                        && !existing.supertypes.contains(trait_name)
276                    {
277                        existing.supertypes.push(trait_name.clone());
278                    }
279                }
280            }
281        }
282        // Go: type Foo struct { Bar }  (embedded fields = inheritance)
283        "type_declaration" | "type_spec" => {
284            if let Some(name) = node.child_by_field_name("name") {
285                let type_name = node_text(name, source).to_string();
286                let supertypes = extract_go_embedded_types(node, source);
287                map.insert(
288                    type_name.clone(),
289                    TypeNode {
290                        name: type_name,
291                        file_path: file_path.to_string(),
292                        line: node.start_position().row + 1,
293                        kind: TypeNodeKind::Struct,
294                        supertypes,
295                    },
296                );
297            }
298        }
299        // Enum declarations
300        "enum_declaration" | "enum_item" => {
301            if let Some(name) = node.child_by_field_name("name") {
302                let type_name = node_text(name, source).to_string();
303                map.insert(
304                    type_name.clone(),
305                    TypeNode {
306                        name: type_name,
307                        file_path: file_path.to_string(),
308                        line: node.start_position().row + 1,
309                        kind: TypeNodeKind::Enum,
310                        supertypes: Vec::new(),
311                    },
312                );
313            }
314        }
315        _ => {}
316    }
317
318    // Recurse
319    for i in 0..node.child_count() {
320        if let Some(child) = node.child(i) {
321            extract_types_from_node(child, source, file_path, ext, map);
322        }
323    }
324}
325
326// ── Language-specific supertype extraction ────────────────────────────────
327
328fn extract_python_supertypes(class_node: Node, source: &[u8]) -> Vec<String> {
329    let mut supers = Vec::new();
330    if let Some(args) = class_node.child_by_field_name("superclasses") {
331        for i in 0..args.child_count() {
332            if let Some(child) = args.child(i) {
333                let kind = child.kind();
334                if kind == "identifier" || kind == "attribute" {
335                    supers.push(node_text(child, source).to_string());
336                }
337            }
338        }
339    }
340    supers
341}
342
343fn extract_js_ts_supertypes(class_node: Node, source: &[u8]) -> Vec<String> {
344    let mut supers = Vec::new();
345    for i in 0..class_node.child_count() {
346        let Some(child) = class_node.child(i) else {
347            continue;
348        };
349        let kind = child.kind();
350        // extends_clause, implements_clause, class_heritage
351        if kind.contains("extends") || kind.contains("implements") || kind == "class_heritage" {
352            collect_type_identifiers(child, source, &mut supers);
353        }
354        // Java: superclass / superinterfaces fields
355        if kind == "superclass" || kind == "super_interfaces" {
356            collect_type_identifiers(child, source, &mut supers);
357        }
358        // Kotlin: delegation_specifier
359        if kind == "delegation_specifier" || kind == "delegation_specifiers" {
360            collect_type_identifiers(child, source, &mut supers);
361        }
362    }
363    supers
364}
365
366fn extract_go_embedded_types(type_node: Node, source: &[u8]) -> Vec<String> {
367    let mut supers = Vec::new();
368    // Look for struct_type -> field_declaration_list -> field_declaration with no name (embedded)
369    for i in 0..type_node.child_count() {
370        let Some(child) = type_node.child(i) else {
371            continue;
372        };
373        if child.kind() == "struct_type" || child.kind() == "field_declaration_list" {
374            for j in 0..child.child_count() {
375                if let Some(field) = child.child(j)
376                    && (field.kind() == "field_declaration"
377                        || field.kind() == "field_declaration_list")
378                {
379                    // Embedded field: only type, no name
380                    if field.child_by_field_name("name").is_none()
381                        && let Some(type_child) = field.child_by_field_name("type")
382                    {
383                        supers.push(node_text(type_child, source).to_string());
384                    }
385                }
386            }
387            // Recurse into field_declaration_list
388            supers.extend(extract_go_embedded_types(child, source));
389        }
390    }
391    supers
392}
393
394fn collect_type_identifiers(node: Node, source: &[u8], out: &mut Vec<String>) {
395    let kind = node.kind();
396    if kind == "type_identifier" || kind == "identifier" {
397        let text = node_text(node, source).to_string();
398        if !text.is_empty()
399            && text
400                .chars()
401                .next()
402                .map(|c| c.is_uppercase())
403                .unwrap_or(false)
404        {
405            out.push(text);
406        }
407    }
408    // Generic types: extract the base name
409    if kind == "generic_type" || kind == "parameterized_type" {
410        if let Some(first) = node.child(0) {
411            let text = node_text(first, source).to_string();
412            if !text.is_empty() {
413                out.push(text);
414            }
415        }
416        return; // Don't recurse into type parameters
417    }
418    for i in 0..node.child_count() {
419        if let Some(child) = node.child(i) {
420            collect_type_identifiers(child, source, out);
421        }
422    }
423}
424
425// ── Hierarchy traversal ──────────────────────────────────────────────────
426
427fn collect_supertypes(
428    type_name: &str,
429    type_map: &HashMap<String, TypeNode>,
430    max_depth: usize,
431    out: &mut Vec<TypeNode>,
432) {
433    let mut queue = VecDeque::new();
434    let mut visited = HashSet::new();
435    visited.insert(type_name.to_string());
436
437    if let Some(root) = type_map.get(type_name) {
438        for s in &root.supertypes {
439            queue.push_back((s.clone(), 1usize));
440        }
441    }
442
443    while let Some((name, depth)) = queue.pop_front() {
444        if depth > max_depth || !visited.insert(name.clone()) {
445            continue;
446        }
447        if let Some(node) = type_map.get(&name) {
448            out.push(node.clone());
449            for s in &node.supertypes {
450                queue.push_back((s.clone(), depth + 1));
451            }
452        }
453    }
454}
455
456fn collect_subtypes(
457    type_name: &str,
458    type_map: &HashMap<String, TypeNode>,
459    max_depth: usize,
460    out: &mut Vec<TypeNode>,
461) {
462    let mut queue = VecDeque::new();
463    let mut visited = HashSet::new();
464    visited.insert(type_name.to_string());
465
466    // Find direct subtypes: types whose supertypes include type_name
467    for node in type_map.values() {
468        if node.supertypes.contains(&type_name.to_string()) {
469            queue.push_back((node.name.clone(), 1usize));
470        }
471    }
472
473    while let Some((name, depth)) = queue.pop_front() {
474        if depth > max_depth || !visited.insert(name.clone()) {
475            continue;
476        }
477        if let Some(node) = type_map.get(&name) {
478            out.push(node.clone());
479            // Find types that extend this subtype
480            for child in type_map.values() {
481                if child.supertypes.contains(&name) {
482                    queue.push_back((child.name.clone(), depth + 1));
483                }
484            }
485        }
486    }
487}
488
489fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
490    std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;

    /// Scratch project directory that is deleted again when the guard drops,
    /// so test runs do not accumulate directories under the system temp dir.
    struct ScratchProject {
        dir: std::path::PathBuf,
    }

    impl ScratchProject {
        /// Create a uniquely named directory and populate it with `files`,
        /// given as `(file name, contents)` pairs.
        fn with_files(name: &str, files: &[(&str, &str)]) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "codelens-{name}-{}",
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_nanos()
            ));
            fs::create_dir_all(&dir).unwrap();
            for (file, contents) in files {
                fs::write(dir.join(file), contents).unwrap();
            }
            Self { dir }
        }

        /// Run an unlimited-depth hierarchy query against this project.
        fn hierarchy(&self, type_name: &str, hierarchy_type: &str) -> TypeHierarchyResult {
            let project = ProjectRoot::new(&self.dir).unwrap();
            get_type_hierarchy_native(&project, type_name, None, hierarchy_type, 0).unwrap()
        }
    }

    impl Drop for ScratchProject {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not mask the test outcome.
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    /// Names of all nodes in `result`, in result order.
    fn names(result: &TypeHierarchyResult) -> Vec<&str> {
        result.nodes.iter().map(|n| n.name.as_str()).collect()
    }

    #[test]
    fn python_class_inheritance() {
        let scratch = ScratchProject::with_files(
            "py-hier",
            &[(
                "models.py",
                "class Animal:\n    pass\n\nclass Dog(Animal):\n    pass\n\nclass GoldenRetriever(Dog):\n    pass\n",
            )],
        );

        let result = scratch.hierarchy("GoldenRetriever", "super");
        let names = names(&result);
        assert!(
            names.contains(&"GoldenRetriever"),
            "should include root: {names:?}"
        );
        assert!(names.contains(&"Dog"), "should include Dog: {names:?}");
        assert!(
            names.contains(&"Animal"),
            "should include Animal: {names:?}"
        );
    }

    #[test]
    fn python_subtypes() {
        let scratch = ScratchProject::with_files(
            "py-sub",
            &[(
                "models.py",
                "class Base:\n    pass\n\nclass ChildA(Base):\n    pass\n\nclass ChildB(Base):\n    pass\n",
            )],
        );

        let result = scratch.hierarchy("Base", "sub");
        let names = names(&result);
        assert!(names.contains(&"ChildA"), "should find ChildA: {names:?}");
        assert!(names.contains(&"ChildB"), "should find ChildB: {names:?}");
    }

    #[test]
    fn typescript_extends() {
        let scratch = ScratchProject::with_files(
            "ts-hier",
            &[(
                "models.ts",
                "class Base {}\nclass Child extends Base {}\ninterface Printable {}\nclass PrintableChild extends Child implements Printable {}\n",
            )],
        );

        let result = scratch.hierarchy("PrintableChild", "super");
        let names = names(&result);
        assert!(names.contains(&"Child"), "should find Child: {names:?}");
        assert!(names.contains(&"Base"), "should find Base: {names:?}");
    }

    #[test]
    fn both_direction() {
        let scratch = ScratchProject::with_files(
            "both",
            &[(
                "hier.py",
                "class A:\n    pass\n\nclass B(A):\n    pass\n\nclass C(B):\n    pass\n",
            )],
        );

        let result = scratch.hierarchy("B", "both");
        let names = names(&result);
        assert!(names.contains(&"A"), "super: {names:?}");
        assert!(names.contains(&"C"), "sub: {names:?}");
        assert!(names.contains(&"B"), "self: {names:?}");
    }

    #[test]
    fn java_class_hierarchy() {
        let scratch = ScratchProject::with_files(
            "java-hier",
            &[
                ("Animal.java", "public class Animal {}\n"),
                ("Dog.java", "public class Dog extends Animal {}\n"),
            ],
        );

        let result = scratch.hierarchy("Dog", "super");
        let names = names(&result);
        assert!(names.contains(&"Animal"), "should find Animal: {names:?}");
    }

    #[test]
    fn rust_trait_impl() {
        let scratch = ScratchProject::with_files(
            "rs-impl",
            &[(
                "lib.rs",
                "pub trait Drawable { fn draw(&self); }\npub struct Circle { pub radius: f64 }\nimpl Drawable for Circle { fn draw(&self) {} }\n",
            )],
        );

        let result = scratch.hierarchy("Circle", "super");
        let names = names(&result);
        assert!(
            names.contains(&"Circle"),
            "should include Circle: {names:?}"
        );
        // Circle should have Drawable as supertype
        let circle = result.nodes.iter().find(|n| n.name == "Circle").unwrap();
        assert!(
            circle.supertypes.contains(&"Drawable".to_string()),
            "Circle should impl Drawable: {:?}",
            circle.supertypes
        );
    }

    #[test]
    fn type_node_kind_serialization() {
        assert_eq!(
            serde_json::to_string(&TypeNodeKind::Class).unwrap(),
            "\"class\""
        );
        assert_eq!(
            serde_json::to_string(&TypeNodeKind::Trait).unwrap(),
            "\"trait\""
        );
    }
}