Skip to main content

kiss/rust_units/
mod.rs

1use crate::rust_parsing::ParsedRustFile;
2use crate::units::CodeUnitKind;
3use syn::visit::Visit;
4use syn::{ImplItem, Item};
5
6#[derive(Debug)]
7pub struct RustCodeUnit {
8    pub kind: CodeUnitKind,
9    pub name: String,
10    pub start_line: usize,
11    pub end_line: usize,
12    pub parent_type: Option<String>,
13}
14
15struct CodeUnitVisitor {
16    units: Vec<RustCodeUnit>,
17    current_impl_type: Option<String>,
18    pub(super) source_lines: usize,
19}
20
21impl CodeUnitVisitor {
22    fn new(source: &str) -> Self {
23        Self {
24            units: Vec::new(),
25            current_impl_type: None,
26            source_lines: source.lines().count(),
27        }
28    }
29}
30
31impl<'ast> Visit<'ast> for CodeUnitVisitor {
32    fn visit_item(&mut self, item: &'ast Item) {
33        match item {
34            Item::Fn(func) => self.visit_top_level_fn(func),
35            Item::Struct(s) => self.record_struct(s),
36            Item::Enum(e) => self.record_enum(e),
37            Item::Impl(impl_block) => self.visit_impl_block(impl_block),
38            Item::Mod(m) => self.visit_item_mod(m),
39            _ => syn::visit::visit_item(self, item),
40        }
41    }
42}
43
44impl<'ast> CodeUnitVisitor {
45    fn visit_top_level_fn(&mut self, func: &'ast syn::ItemFn) {
46        let start_line = func.sig.ident.span().start().line;
47        let end_line = start_line + estimate_block_lines(&func.block);
48
49        self.units.push(RustCodeUnit {
50            kind: CodeUnitKind::Function,
51            name: func.sig.ident.to_string(),
52            start_line,
53            end_line,
54            parent_type: None,
55        });
56        syn::visit::visit_item_fn(self, func);
57    }
58
59    fn record_struct(&mut self, s: &syn::ItemStruct) {
60        let start_line = s.ident.span().start().line;
61        self.units.push(RustCodeUnit {
62            kind: CodeUnitKind::Class,
63            name: s.ident.to_string(),
64            start_line,
65            end_line: start_line,
66            parent_type: None,
67        });
68    }
69
70    fn record_enum(&mut self, e: &syn::ItemEnum) {
71        let start_line = e.ident.span().start().line;
72        self.units.push(RustCodeUnit {
73            kind: CodeUnitKind::Class,
74            name: e.ident.to_string(),
75            start_line,
76            end_line: start_line,
77            parent_type: None,
78        });
79    }
80
81    fn visit_impl_block(&mut self, impl_block: &'ast syn::ItemImpl) {
82        let type_name = if let syn::Type::Path(type_path) = impl_block.self_ty.as_ref() {
83            type_path.path.segments.last().map(|s| s.ident.to_string())
84        } else {
85            None
86        };
87
88        self.current_impl_type = type_name;
89        for impl_item in &impl_block.items {
90            if let ImplItem::Fn(method) = impl_item {
91                let start_line = method.sig.ident.span().start().line;
92                let end_line = start_line + estimate_block_lines(&method.block);
93
94                self.units.push(RustCodeUnit {
95                    kind: CodeUnitKind::Method,
96                    name: method.sig.ident.to_string(),
97                    start_line,
98                    end_line,
99                    parent_type: self.current_impl_type.clone(),
100                });
101            }
102        }
103
104        self.current_impl_type = None;
105    }
106
107    fn visit_item_mod(&mut self, m: &'ast syn::ItemMod) {
108        if m.content.is_some() {
109            let start_line = m.ident.span().start().line;
110            self.units.push(RustCodeUnit {
111                kind: CodeUnitKind::Module,
112                name: m.ident.to_string(),
113                start_line,
114                end_line: start_line,
115                parent_type: None,
116            });
117        }
118        syn::visit::visit_item_mod(self, m);
119    }
120}
121
122fn estimate_block_lines(block: &syn::Block) -> usize {
123    if block.stmts.is_empty() {
124        return 1;
125    }
126    let start = block.brace_token.span.open().start().line;
127    let end = block.brace_token.span.close().end().line;
128
129    if end >= start {
130        end - start + 1
131    } else {
132        block.stmts.len().max(1)
133    }
134}
135
136pub fn extract_rust_code_units(parsed: &ParsedRustFile) -> Vec<RustCodeUnit> {
137    let mut visitor = CodeUnitVisitor::new(&parsed.source);
138    visitor.units.push(RustCodeUnit {
139        kind: CodeUnitKind::Module,
140        name: parsed.path.file_stem().map_or_else(
141            || "unknown".to_string(),
142            |s| s.to_string_lossy().into_owned(),
143        ),
144        start_line: 1,
145        end_line: visitor.source_lines,
146        parent_type: None,
147    });
148    for item in &parsed.ast.items {
149        visitor.visit_item(item);
150    }
151
152    visitor.units
153}
154
#[cfg(test)]
mod inline_coverage_tests {
    // Tests that drive the visitor helpers and `estimate_block_lines` directly.
    use super::*;
    use syn::visit::Visit;

    #[test]
    fn direct_visitor_helpers_exercised() {
        let src = "enum E { A, B }\nstruct S { x: i32 }\nmod inner { fn nested() {} }\nimpl S { fn f(&self) { let _ = 1; } }\nfn top() { let _ = 2; }\n";
        let file: syn::File = syn::parse_str(src).unwrap();
        let mut visitor = CodeUnitVisitor::new(src);
        for item in &file.items {
            match item {
                syn::Item::Fn(f) => visitor.visit_top_level_fn(f),
                syn::Item::Struct(s) => visitor.record_struct(s),
                syn::Item::Enum(e) => visitor.record_enum(e),
                syn::Item::Impl(i) => visitor.visit_impl_block(i),
                syn::Item::Mod(m) => visitor.visit_item_mod(m),
                other => visitor.visit_item(other),
            }
        }
        let names: Vec<_> = visitor.units.iter().map(|u| u.name.as_str()).collect();
        for expected in ["E", "S", "inner", "f", "top", "nested"] {
            assert!(names.contains(&expected), "missing {expected} in {names:?}");
        }
        assert_eq!(names.len(), 6, "unexpected units: {names:?}");

        let find = |name: &str| {
            visitor
                .units
                .iter()
                .find(|u| u.name == name)
                .unwrap_or_else(|| panic!("unit {name} not recorded"))
        };

        // Structs, enums and inline modules are single-line units on their
        // identifier's (1-based) line.
        let enum_unit = find("E");
        assert!(matches!(enum_unit.kind, CodeUnitKind::Class));
        assert_eq!((enum_unit.start_line, enum_unit.end_line), (1, 1));
        let struct_unit = find("S");
        assert!(matches!(struct_unit.kind, CodeUnitKind::Class));
        assert_eq!((struct_unit.start_line, struct_unit.end_line), (2, 2));
        let mod_unit = find("inner");
        assert!(matches!(mod_unit.kind, CodeUnitKind::Module));
        assert_eq!((mod_unit.start_line, mod_unit.end_line), (3, 3));

        // Methods carry their impl's self type; free fns (even nested) do not.
        let method = find("f");
        assert!(matches!(method.kind, CodeUnitKind::Method));
        assert_eq!(method.parent_type.as_deref(), Some("S"));
        let nested = find("nested");
        assert!(matches!(nested.kind, CodeUnitKind::Function));
        assert!(nested.parent_type.is_none());
    }

    #[test]
    fn estimate_block_lines_counts_lines_spanned_by_braces() {
        let one_line: syn::Block = syn::parse_str("{ let x = 1; }").unwrap();
        assert_eq!(estimate_block_lines(&one_line), 1);

        let three_lines: syn::Block = syn::parse_str("{\n    let x = 1;\n}").unwrap();
        assert_eq!(estimate_block_lines(&three_lines), 3);

        // Empty blocks always count as a single line, however they are laid out.
        let empty: syn::Block = syn::parse_str("{\n}").unwrap();
        assert_eq!(estimate_block_lines(&empty), 1);
    }
}
187
188#[cfg(test)]
189#[path = "rust_units_test.rs"]
190mod rust_units_test;