Skip to main content

kiss/units/
mod.rs

1use crate::parsing::ParsedFile;
2use tree_sitter::Node;
3
/// The syntactic category of a [`CodeUnit`].
///
/// The Python extraction in this file emits only `Module`, `Class`, `Function`
/// and `Method`; the remaining variants are not produced here (presumably they
/// are used by other extractors — verify against callers).
///
/// `Hash` is derived so kinds can be used directly as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeUnitKind {
    /// A function definition whose nearest enclosing definition is not a class
    /// (including a function nested inside another function or method).
    Function,
    /// A function definition whose nearest enclosing definition is a class.
    Method,
    /// A class definition.
    Class,
    /// The synthetic unit spanning the whole source file.
    Module,
    /// A struct definition (not emitted by the Python extraction in this file).
    Struct,
    /// An enum definition (not emitted by the Python extraction in this file).
    Enum,
    /// A method inside a trait implementation (not emitted by the Python
    /// extraction in this file).
    TraitImplMethod,
}
14
15impl CodeUnitKind {
16    pub const fn as_str(&self) -> &'static str {
17        match self {
18            Self::Function => "function",
19            Self::Method => "method",
20            Self::Class => "class",
21            Self::Module => "module",
22            Self::Struct => "struct",
23            Self::Enum => "enum",
24            Self::TraitImplMethod => "trait_impl_method",
25        }
26    }
27}
28
29impl std::fmt::Display for CodeUnitKind {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}", self.as_str())
32    }
33}
34
35#[derive(Debug)]
36pub struct CodeUnit {
37    pub kind: CodeUnitKind,
38    pub name: String,
39    pub start_line: usize,
40    pub end_line: usize,
41    pub start_byte: usize,
42    pub end_byte: usize,
43}
44
45pub fn extract_code_units(parsed: &ParsedFile) -> Vec<CodeUnit> {
46    let mut units = Vec::new();
47    let root = parsed.tree.root_node();
48
49    units.push(CodeUnit {
50        kind: CodeUnitKind::Module,
51        name: parsed.path.file_stem().map_or_else(
52            || "unknown".to_string(),
53            |s| s.to_string_lossy().into_owned(),
54        ),
55        start_line: 1,
56        end_line: root.end_position().row + 1,
57        start_byte: 0,
58        end_byte: parsed.source.len(),
59    });
60
61    extract_from_node(root, &parsed.source, &mut units, false);
62
63    units
64}
65
66/// Fast-path for callers that only need the *count* of units.
67///
68/// This matches `extract_code_units(parsed).len()` but avoids allocations and string copies.
69#[must_use]
70pub fn count_code_units(parsed: &ParsedFile) -> usize {
71    let root = parsed.tree.root_node();
72    // Always include the synthetic module unit.
73    1 + count_from_node(root)
74}
75
76fn count_from_node(node: Node) -> usize {
77    match node.kind() {
78        "function_definition" | "async_function_definition" | "class_definition" => {
79            let mut count = usize::from(node.child_by_field_name("name").is_some());
80            let mut cursor = node.walk();
81            for child in node.children(&mut cursor) {
82                count += count_from_node(child);
83            }
84            count
85        }
86        _ => {
87            let mut count = 0;
88            let mut cursor = node.walk();
89            for child in node.children(&mut cursor) {
90                count += count_from_node(child);
91            }
92            count
93        }
94    }
95}
96
97fn extract_children(node: Node, source: &str, units: &mut Vec<CodeUnit>, inside_class: bool) {
98    let mut cursor = node.walk();
99    for child in node.children(&mut cursor) {
100        extract_from_node(child, source, units, inside_class);
101    }
102}
103
104fn extract_from_node(node: Node, source: &str, units: &mut Vec<CodeUnit>, inside_class: bool) {
105    match node.kind() {
106        "function_definition" | "async_function_definition" => {
107            if let Some(name) = get_child_by_field(node, "name", source) {
108                units.push(CodeUnit {
109                    kind: if inside_class {
110                        CodeUnitKind::Method
111                    } else {
112                        CodeUnitKind::Function
113                    },
114                    name,
115                    start_line: node.start_position().row + 1,
116                    end_line: node.end_position().row + 1,
117                    start_byte: node.start_byte(),
118                    end_byte: node.end_byte(),
119                });
120            }
121            extract_children(node, source, units, false);
122        }
123        "class_definition" => {
124            if let Some(name) = get_child_by_field(node, "name", source) {
125                units.push(CodeUnit {
126                    kind: CodeUnitKind::Class,
127                    name,
128                    start_line: node.start_position().row + 1,
129                    end_line: node.end_position().row + 1,
130                    start_byte: node.start_byte(),
131                    end_byte: node.end_byte(),
132                });
133            }
134            extract_children(node, source, units, true);
135        }
136        _ => extract_children(node, source, units, inside_class),
137    }
138}
139
140pub(crate) fn get_child_by_field(node: Node, field: &str, source: &str) -> Option<String> {
141    node.child_by_field_name(field)
142        .map(|n| source[n.start_byte()..n.end_byte()].to_string())
143}
144
// Unit tests for this module live in a separate file, `units_test.rs`. Per
// the `path` attribute rules for non-inline modules, it is resolved relative to
// this source file's directory. It is compiled only for test builds.
#[path = "units_test.rs"]
#[cfg(test)]
mod tests;