1use crate::parsing::ParsedFile;
2use tree_sitter::Node;
3
/// The category of a [`CodeUnit`] extracted from a parsed source file.
///
/// `Hash` is derived (alongside `Eq`) so kinds can be used as `HashMap`/`HashSet` keys,
/// e.g. when tallying units per kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeUnitKind {
    /// A function whose nearest enclosing definition is not a class.
    Function,
    /// A function whose nearest enclosing definition is a class.
    Method,
    /// A class definition.
    Class,
    /// The whole file; emitted once per file as the first unit.
    Module,
    /// Not emitted by the extractor in this module (presumably produced for other
    /// languages — TODO confirm).
    Struct,
    /// Not emitted by the extractor in this module (presumably produced for other
    /// languages — TODO confirm).
    Enum,
    /// Not emitted by the extractor in this module (presumably produced for other
    /// languages — TODO confirm).
    TraitImplMethod,
}

impl CodeUnitKind {
    /// Returns the lowercase snake_case label for this kind.
    ///
    /// This is the same text produced by the `Display` implementation.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Function => "function",
            Self::Method => "method",
            Self::Class => "class",
            Self::Module => "module",
            Self::Struct => "struct",
            Self::Enum => "enum",
            Self::TraitImplMethod => "trait_impl_method",
        }
    }
}

impl std::fmt::Display for CodeUnitKind {
    /// Writes the label returned by [`CodeUnitKind::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
34
35#[derive(Debug)]
36pub struct CodeUnit {
37 pub kind: CodeUnitKind,
38 pub name: String,
39 pub start_line: usize,
40 pub end_line: usize,
41 pub start_byte: usize,
42 pub end_byte: usize,
43}
44
45pub fn extract_code_units(parsed: &ParsedFile) -> Vec<CodeUnit> {
46 let mut units = Vec::new();
47 let root = parsed.tree.root_node();
48
49 units.push(CodeUnit {
50 kind: CodeUnitKind::Module,
51 name: parsed.path.file_stem().map_or_else(
52 || "unknown".to_string(),
53 |s| s.to_string_lossy().into_owned(),
54 ),
55 start_line: 1,
56 end_line: root.end_position().row + 1,
57 start_byte: 0,
58 end_byte: parsed.source.len(),
59 });
60
61 extract_from_node(root, &parsed.source, &mut units, false);
62
63 units
64}
65
66#[must_use]
70pub fn count_code_units(parsed: &ParsedFile) -> usize {
71 let root = parsed.tree.root_node();
72 1 + count_from_node(root)
74}
75
76fn count_from_node(node: Node) -> usize {
77 match node.kind() {
78 "function_definition" | "async_function_definition" | "class_definition" => {
79 let mut count = usize::from(node.child_by_field_name("name").is_some());
80 let mut cursor = node.walk();
81 for child in node.children(&mut cursor) {
82 count += count_from_node(child);
83 }
84 count
85 }
86 _ => {
87 let mut count = 0;
88 let mut cursor = node.walk();
89 for child in node.children(&mut cursor) {
90 count += count_from_node(child);
91 }
92 count
93 }
94 }
95}
96
97fn extract_children(node: Node, source: &str, units: &mut Vec<CodeUnit>, inside_class: bool) {
98 let mut cursor = node.walk();
99 for child in node.children(&mut cursor) {
100 extract_from_node(child, source, units, inside_class);
101 }
102}
103
104fn extract_from_node(node: Node, source: &str, units: &mut Vec<CodeUnit>, inside_class: bool) {
105 match node.kind() {
106 "function_definition" | "async_function_definition" => {
107 if let Some(name) = get_child_by_field(node, "name", source) {
108 units.push(CodeUnit {
109 kind: if inside_class {
110 CodeUnitKind::Method
111 } else {
112 CodeUnitKind::Function
113 },
114 name,
115 start_line: node.start_position().row + 1,
116 end_line: node.end_position().row + 1,
117 start_byte: node.start_byte(),
118 end_byte: node.end_byte(),
119 });
120 }
121 extract_children(node, source, units, false);
122 }
123 "class_definition" => {
124 if let Some(name) = get_child_by_field(node, "name", source) {
125 units.push(CodeUnit {
126 kind: CodeUnitKind::Class,
127 name,
128 start_line: node.start_position().row + 1,
129 end_line: node.end_position().row + 1,
130 start_byte: node.start_byte(),
131 end_byte: node.end_byte(),
132 });
133 }
134 extract_children(node, source, units, true);
135 }
136 _ => extract_children(node, source, units, inside_class),
137 }
138}
139
140pub(crate) fn get_child_by_field(node: Node, field: &str, source: &str) -> Option<String> {
141 node.child_by_field_name(field)
142 .map(|n| source[n.start_byte()..n.end_byte()].to_string())
143}
144
// Unit tests for this module live in `units_test.rs` (loaded via `#[path]`) and are
// compiled only in test builds.
#[cfg(test)]
#[path = "units_test.rs"]
mod tests;