1use crate::db::{IndexDb, index_db_path};
6use crate::project::ProjectRoot;
7use crate::project::is_excluded;
8use crate::symbols::language_for_path;
9use anyhow::Result;
10use serde::Serialize;
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::fs;
13use tree_sitter::{Node, Parser};
14use walkdir::WalkDir;
15
/// One type declaration discovered while scanning the project's source files.
#[derive(Debug, Clone, Serialize)]
pub struct TypeNode {
    /// Type name as written at the declaration site; the hierarchy walkers look types up
    /// by this bare name.
    pub name: String,
    /// Path of the declaring file as recorded by the scanner (presumably relative to the
    /// project root — TODO confirm against `ProjectRoot::to_relative` / the index paths).
    pub file_path: String,
    /// 1-based line on which the declaration node starts.
    pub line: usize,
    /// Category of the declaration (class, interface, trait, enum, struct).
    pub kind: TypeNodeKind,
    /// Names of the direct supertypes (base classes, extended/implemented interfaces,
    /// embedded or implemented types) as extracted from the source text. An entry need not
    /// correspond to any `TypeNode` in the project (e.g. library types).
    pub supertypes: Vec<String>,
}
24
/// Category of a [`TypeNode`]; serialized in snake_case (e.g. `"class"`, `"trait"`).
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TypeNodeKind {
    /// A class declaration.
    Class,
    /// An interface declaration.
    Interface,
    /// A trait declaration.
    Trait,
    /// An enum declaration.
    Enum,
    /// A struct-like record type.
    Struct,
}
34
/// Result of a type-hierarchy query returned by `get_type_hierarchy_native`.
#[derive(Debug, Clone, Serialize)]
pub struct TypeHierarchyResult {
    /// The queried type name, exactly as passed by the caller.
    pub root: String,
    /// The requested direction, echoed back verbatim; `"super"`, `"sub"` and `"both"` are
    /// the values the query acts on.
    pub hierarchy_type: String,
    /// The root's own node (when found) followed by the collected supertypes and/or
    /// subtypes, with duplicate file/name pairs removed.
    pub nodes: Vec<TypeNode>,
}
41
42pub fn get_type_hierarchy_native(
47 project: &ProjectRoot,
48 type_name: &str,
49 _file_path: Option<&str>,
50 hierarchy_type: &str,
51 depth: usize,
52) -> Result<TypeHierarchyResult> {
53 let type_map = build_type_map(project)?;
55
56 let max_depth = if depth == 0 { 50 } else { depth };
57 let mut result_nodes = Vec::new();
58
59 if let Some(root) = type_map.get(type_name) {
61 result_nodes.push(root.clone());
62 }
63
64 if hierarchy_type == "super" || hierarchy_type == "both" {
65 collect_supertypes(type_name, &type_map, max_depth, &mut result_nodes);
66 }
67
68 if hierarchy_type == "sub" || hierarchy_type == "both" {
69 collect_subtypes(type_name, &type_map, max_depth, &mut result_nodes);
70 }
71
72 let mut seen = HashSet::new();
74 result_nodes.retain(|n| seen.insert(format!("{}:{}", n.file_path, n.name)));
75
76 Ok(TypeHierarchyResult {
77 root: type_name.to_string(),
78 hierarchy_type: hierarchy_type.to_string(),
79 nodes: result_nodes,
80 })
81}
82
83fn build_type_map(project: &ProjectRoot) -> Result<HashMap<String, TypeNode>> {
85 let mut map = HashMap::new();
86
87 let db_path = index_db_path(project.as_path());
89 let type_file_paths = IndexDb::open(&db_path).ok().and_then(|db| {
90 db.files_with_symbol_kinds(&["class", "interface", "enum", "module"])
91 .ok()
92 .filter(|paths| !paths.is_empty()) });
94
95 if let Some(rel_paths) = type_file_paths {
96 for rel_path in &rel_paths {
98 let abs_path = project.as_path().join(rel_path);
99 let Some(config) = language_for_path(&abs_path) else {
100 continue;
101 };
102 let source = match fs::read_to_string(&abs_path) {
103 Ok(s) => s,
104 Err(_) => continue,
105 };
106 let mut parser = Parser::new();
107 if parser.set_language(&config.language).is_err() {
108 continue;
109 }
110 let Some(tree) = parser.parse(&source, None) else {
111 continue;
112 };
113 extract_types_from_node(
114 tree.root_node(),
115 source.as_bytes(),
116 rel_path,
117 config.extension,
118 &mut map,
119 );
120 }
121 } else {
122 for entry in WalkDir::new(project.as_path())
124 .into_iter()
125 .filter_entry(|e| !is_excluded(e.path()))
126 {
127 let entry = entry?;
128 if !entry.file_type().is_file() {
129 continue;
130 }
131 let Some(config) = language_for_path(entry.path()) else {
132 continue;
133 };
134 let source = match fs::read_to_string(entry.path()) {
135 Ok(s) => s,
136 Err(_) => continue,
137 };
138 let rel = project.to_relative(entry.path());
139 let mut parser = Parser::new();
140 if parser.set_language(&config.language).is_err() {
141 continue;
142 }
143 let Some(tree) = parser.parse(&source, None) else {
144 continue;
145 };
146 extract_types_from_node(
147 tree.root_node(),
148 source.as_bytes(),
149 &rel,
150 config.extension,
151 &mut map,
152 );
153 }
154 }
155
156 Ok(map)
157}
158
159fn extract_types_from_node(
161 node: Node,
162 source: &[u8],
163 file_path: &str,
164 ext: &str,
165 map: &mut HashMap<String, TypeNode>,
166) {
167 let kind = node.kind();
168
169 match kind {
170 "class_definition" => {
172 if let Some(name) = node.child_by_field_name("name") {
173 let type_name = node_text(name, source).to_string();
174 let supertypes = extract_python_supertypes(node, source);
175 map.insert(
176 type_name.clone(),
177 TypeNode {
178 name: type_name,
179 file_path: file_path.to_string(),
180 line: node.start_position().row + 1,
181 kind: TypeNodeKind::Class,
182 supertypes,
183 },
184 );
185 }
186 }
187 "class_declaration" => {
189 if let Some(name) = node.child_by_field_name("name") {
190 let type_name = node_text(name, source).to_string();
191 let supertypes = extract_js_ts_supertypes(node, source);
192 let node_kind = if ext == "java" || ext == "kt" {
193 TypeNodeKind::Class
195 } else {
196 TypeNodeKind::Class
197 };
198 map.insert(
199 type_name.clone(),
200 TypeNode {
201 name: type_name,
202 file_path: file_path.to_string(),
203 line: node.start_position().row + 1,
204 kind: node_kind,
205 supertypes,
206 },
207 );
208 }
209 }
210 "interface_declaration" => {
212 if let Some(name) = node.child_by_field_name("name") {
213 let type_name = node_text(name, source).to_string();
214 let supertypes = extract_js_ts_supertypes(node, source);
215 map.insert(
216 type_name.clone(),
217 TypeNode {
218 name: type_name,
219 file_path: file_path.to_string(),
220 line: node.start_position().row + 1,
221 kind: TypeNodeKind::Interface,
222 supertypes,
223 },
224 );
225 }
226 }
227 "struct_item" => {
229 if let Some(name) = node.child_by_field_name("name") {
230 let type_name = node_text(name, source).to_string();
231 map.insert(
232 type_name.clone(),
233 TypeNode {
234 name: type_name,
235 file_path: file_path.to_string(),
236 line: node.start_position().row + 1,
237 kind: TypeNodeKind::Struct,
238 supertypes: Vec::new(),
239 },
240 );
241 }
242 }
243 "impl_item" => {
245 let by_field = node
247 .child_by_field_name("trait")
248 .zip(node.child_by_field_name("type"));
249 if let Some((trait_node, type_node)) = by_field {
250 let struct_name = node_text(type_node, source).to_string();
251 let trait_name = node_text(trait_node, source).to_string();
252 if let Some(existing) = map.get_mut(&struct_name)
253 && !existing.supertypes.contains(&trait_name)
254 {
255 existing.supertypes.push(trait_name);
256 }
257 } else {
258 let mut type_ids = Vec::new();
260 let mut has_for = false;
261 for i in 0..node.child_count() {
262 if let Some(child) = node.child(i) {
263 if child.kind() == "type_identifier" {
264 type_ids.push(node_text(child, source).to_string());
265 }
266 if node_text(child, source) == "for" {
267 has_for = true;
268 }
269 }
270 }
271 if has_for && type_ids.len() >= 2 {
272 let trait_name = &type_ids[0];
273 let struct_name = &type_ids[1];
274 if let Some(existing) = map.get_mut(struct_name)
275 && !existing.supertypes.contains(trait_name)
276 {
277 existing.supertypes.push(trait_name.clone());
278 }
279 }
280 }
281 }
282 "type_declaration" | "type_spec" => {
284 if let Some(name) = node.child_by_field_name("name") {
285 let type_name = node_text(name, source).to_string();
286 let supertypes = extract_go_embedded_types(node, source);
287 map.insert(
288 type_name.clone(),
289 TypeNode {
290 name: type_name,
291 file_path: file_path.to_string(),
292 line: node.start_position().row + 1,
293 kind: TypeNodeKind::Struct,
294 supertypes,
295 },
296 );
297 }
298 }
299 "enum_declaration" | "enum_item" => {
301 if let Some(name) = node.child_by_field_name("name") {
302 let type_name = node_text(name, source).to_string();
303 map.insert(
304 type_name.clone(),
305 TypeNode {
306 name: type_name,
307 file_path: file_path.to_string(),
308 line: node.start_position().row + 1,
309 kind: TypeNodeKind::Enum,
310 supertypes: Vec::new(),
311 },
312 );
313 }
314 }
315 _ => {}
316 }
317
318 for i in 0..node.child_count() {
320 if let Some(child) = node.child(i) {
321 extract_types_from_node(child, source, file_path, ext, map);
322 }
323 }
324}
325
326fn extract_python_supertypes(class_node: Node, source: &[u8]) -> Vec<String> {
329 let mut supers = Vec::new();
330 if let Some(args) = class_node.child_by_field_name("superclasses") {
331 for i in 0..args.child_count() {
332 if let Some(child) = args.child(i) {
333 let kind = child.kind();
334 if kind == "identifier" || kind == "attribute" {
335 supers.push(node_text(child, source).to_string());
336 }
337 }
338 }
339 }
340 supers
341}
342
343fn extract_js_ts_supertypes(class_node: Node, source: &[u8]) -> Vec<String> {
344 let mut supers = Vec::new();
345 for i in 0..class_node.child_count() {
346 let Some(child) = class_node.child(i) else {
347 continue;
348 };
349 let kind = child.kind();
350 if kind.contains("extends") || kind.contains("implements") || kind == "class_heritage" {
352 collect_type_identifiers(child, source, &mut supers);
353 }
354 if kind == "superclass" || kind == "super_interfaces" {
356 collect_type_identifiers(child, source, &mut supers);
357 }
358 if kind == "delegation_specifier" || kind == "delegation_specifiers" {
360 collect_type_identifiers(child, source, &mut supers);
361 }
362 }
363 supers
364}
365
366fn extract_go_embedded_types(type_node: Node, source: &[u8]) -> Vec<String> {
367 let mut supers = Vec::new();
368 for i in 0..type_node.child_count() {
370 let Some(child) = type_node.child(i) else {
371 continue;
372 };
373 if child.kind() == "struct_type" || child.kind() == "field_declaration_list" {
374 for j in 0..child.child_count() {
375 if let Some(field) = child.child(j)
376 && (field.kind() == "field_declaration"
377 || field.kind() == "field_declaration_list")
378 {
379 if field.child_by_field_name("name").is_none()
381 && let Some(type_child) = field.child_by_field_name("type")
382 {
383 supers.push(node_text(type_child, source).to_string());
384 }
385 }
386 }
387 supers.extend(extract_go_embedded_types(child, source));
389 }
390 }
391 supers
392}
393
394fn collect_type_identifiers(node: Node, source: &[u8], out: &mut Vec<String>) {
395 let kind = node.kind();
396 if kind == "type_identifier" || kind == "identifier" {
397 let text = node_text(node, source).to_string();
398 if !text.is_empty()
399 && text
400 .chars()
401 .next()
402 .map(|c| c.is_uppercase())
403 .unwrap_or(false)
404 {
405 out.push(text);
406 }
407 }
408 if kind == "generic_type" || kind == "parameterized_type" {
410 if let Some(first) = node.child(0) {
411 let text = node_text(first, source).to_string();
412 if !text.is_empty() {
413 out.push(text);
414 }
415 }
416 return; }
418 for i in 0..node.child_count() {
419 if let Some(child) = node.child(i) {
420 collect_type_identifiers(child, source, out);
421 }
422 }
423}
424
425fn collect_supertypes(
428 type_name: &str,
429 type_map: &HashMap<String, TypeNode>,
430 max_depth: usize,
431 out: &mut Vec<TypeNode>,
432) {
433 let mut queue = VecDeque::new();
434 let mut visited = HashSet::new();
435 visited.insert(type_name.to_string());
436
437 if let Some(root) = type_map.get(type_name) {
438 for s in &root.supertypes {
439 queue.push_back((s.clone(), 1usize));
440 }
441 }
442
443 while let Some((name, depth)) = queue.pop_front() {
444 if depth > max_depth || !visited.insert(name.clone()) {
445 continue;
446 }
447 if let Some(node) = type_map.get(&name) {
448 out.push(node.clone());
449 for s in &node.supertypes {
450 queue.push_back((s.clone(), depth + 1));
451 }
452 }
453 }
454}
455
456fn collect_subtypes(
457 type_name: &str,
458 type_map: &HashMap<String, TypeNode>,
459 max_depth: usize,
460 out: &mut Vec<TypeNode>,
461) {
462 let mut queue = VecDeque::new();
463 let mut visited = HashSet::new();
464 visited.insert(type_name.to_string());
465
466 for node in type_map.values() {
468 if node.supertypes.contains(&type_name.to_string()) {
469 queue.push_back((node.name.clone(), 1usize));
470 }
471 }
472
473 while let Some((name, depth)) = queue.pop_front() {
474 if depth > max_depth || !visited.insert(name.clone()) {
475 continue;
476 }
477 if let Some(node) = type_map.get(&name) {
478 out.push(node.clone());
479 for child in type_map.values() {
481 if child.supertypes.contains(&name) {
482 queue.push_back((child.name.clone(), depth + 1));
483 }
484 }
485 }
486 }
487}
488
489fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
490 std::str::from_utf8(&source[node.byte_range()]).unwrap_or("")
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;

    /// Temporary project directory that is removed (best effort) when dropped, so test runs
    /// no longer leave `codelens-*` directories behind in the system temp dir.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        /// Creates a fresh, uniquely named directory under the system temp dir.
        fn new(name: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "codelens-{name}-{}",
                std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap()
                    .as_nanos()
            ));
            fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        fn path(&self) -> &std::path::PathBuf {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // A cleanup failure must not mask the test's own outcome.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Names of the returned nodes, in result order.
    fn node_names(result: &TypeHierarchyResult) -> Vec<&str> {
        result.nodes.iter().map(|n| n.name.as_str()).collect()
    }

    #[test]
    fn python_class_inheritance() {
        let dir = TempDir::new("py-hier");
        fs::write(
            dir.path().join("models.py"),
            "class Animal:\n pass\n\nclass Dog(Animal):\n pass\n\nclass GoldenRetriever(Dog):\n pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result =
            get_type_hierarchy_native(&project, "GoldenRetriever", None, "super", 0).unwrap();
        let names = node_names(&result);
        assert!(
            names.contains(&"GoldenRetriever"),
            "should include root: {names:?}"
        );
        assert!(names.contains(&"Dog"), "should include Dog: {names:?}");
        assert!(
            names.contains(&"Animal"),
            "should include Animal: {names:?}"
        );
    }

    #[test]
    fn python_subtypes() {
        let dir = TempDir::new("py-sub");
        fs::write(
            dir.path().join("models.py"),
            "class Base:\n pass\n\nclass ChildA(Base):\n pass\n\nclass ChildB(Base):\n pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result = get_type_hierarchy_native(&project, "Base", None, "sub", 0).unwrap();
        let names = node_names(&result);
        assert!(names.contains(&"ChildA"), "should find ChildA: {names:?}");
        assert!(names.contains(&"ChildB"), "should find ChildB: {names:?}");
    }

    #[test]
    fn typescript_extends() {
        let dir = TempDir::new("ts-hier");
        fs::write(
            dir.path().join("models.ts"),
            "class Base {}\nclass Child extends Base {}\ninterface Printable {}\nclass PrintableChild extends Child implements Printable {}\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result =
            get_type_hierarchy_native(&project, "PrintableChild", None, "super", 0).unwrap();
        let names = node_names(&result);
        assert!(names.contains(&"Child"), "should find Child: {names:?}");
        assert!(names.contains(&"Base"), "should find Base: {names:?}");
    }

    #[test]
    fn both_direction() {
        let dir = TempDir::new("both");
        fs::write(
            dir.path().join("hier.py"),
            "class A:\n pass\n\nclass B(A):\n pass\n\nclass C(B):\n pass\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result = get_type_hierarchy_native(&project, "B", None, "both", 0).unwrap();
        let names = node_names(&result);
        assert!(names.contains(&"A"), "super: {names:?}");
        assert!(names.contains(&"C"), "sub: {names:?}");
        assert!(names.contains(&"B"), "self: {names:?}");
    }

    #[test]
    fn java_class_hierarchy() {
        let dir = TempDir::new("java-hier");
        fs::write(dir.path().join("Animal.java"), "public class Animal {}\n").unwrap();
        fs::write(
            dir.path().join("Dog.java"),
            "public class Dog extends Animal {}\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result = get_type_hierarchy_native(&project, "Dog", None, "super", 0).unwrap();
        let names = node_names(&result);
        assert!(names.contains(&"Animal"), "should find Animal: {names:?}");
    }

    #[test]
    fn rust_trait_impl() {
        let dir = TempDir::new("rs-impl");
        fs::write(
            dir.path().join("lib.rs"),
            "pub trait Drawable { fn draw(&self); }\npub struct Circle { pub radius: f64 }\nimpl Drawable for Circle { fn draw(&self) {} }\n",
        )
        .unwrap();
        let project = ProjectRoot::new(dir.path()).unwrap();

        let result = get_type_hierarchy_native(&project, "Circle", None, "super", 0).unwrap();
        let names = node_names(&result);
        assert!(
            names.contains(&"Circle"),
            "should include Circle: {names:?}"
        );
        let circle = result.nodes.iter().find(|n| n.name == "Circle").unwrap();
        assert!(
            circle.supertypes.contains(&"Drawable".to_string()),
            "Circle should impl Drawable: {:?}",
            circle.supertypes
        );
    }

    #[test]
    fn type_node_kind_serialization() {
        assert_eq!(
            serde_json::to_string(&TypeNodeKind::Class).unwrap(),
            "\"class\""
        );
        assert_eq!(
            serde_json::to_string(&TypeNodeKind::Trait).unwrap(),
            "\"trait\""
        );
    }
}