Skip to main content

ast_doc_core/parser/lang/
rust_parser.rs

1//! Rust parser using tree-sitter.
2//!
3//! Detects `#[cfg(test)]` modules and `#[test]` functions for NoTests mode.
4//! Extracts function/struct/trait/enum/impl signatures for Summary mode.
5
6use std::path::Path;
7
8use tree_sitter::{Parser, Tree};
9
10use crate::{
11    error::AstDocError,
12    parser::{
13        Language, LanguageParser, ParsedFile,
14        strategy::{self, RemovalRange, RemovalReason},
15    },
16};
17
/// Stateless parser for Rust source files, backed by the tree-sitter Rust
/// grammar.
#[derive(Debug, Default)]
pub struct RustParser;
21
22impl RustParser {
23    /// Create a new Rust parser.
24    #[must_use]
25    pub const fn new() -> Self {
26        Self
27    }
28
29    /// Parse source with tree-sitter, returning the tree.
30    fn parse_tree(source: &str) -> Result<Tree, AstDocError> {
31        let mut parser = Parser::new();
32        let language = tree_sitter_rust::LANGUAGE;
33        parser.set_language(&language.into()).map_err(|e| AstDocError::Parse {
34            path: Path::new("<inline>").to_path_buf(),
35            message: format!("Failed to set Rust language: {e}"),
36        })?;
37        parser.parse(source.as_bytes(), None).ok_or_else(|| AstDocError::Parse {
38            path: Path::new("<inline>").to_path_buf(),
39            message: "Failed to parse Rust source".to_string(),
40        })
41    }
42}
43
44impl LanguageParser for RustParser {
45    fn parse(&self, source: &str, path: &Path) -> Result<ParsedFile, AstDocError> {
46        let tree = Self::parse_tree(source)?;
47        let root_node = tree.root_node();
48
49        let test_ranges = collect_test_ranges(&root_node, source);
50        let summary_ranges = collect_summary_ranges(&root_node, source);
51
52        let strategies_data = strategy::build_strategies(source, &test_ranges, &summary_ranges);
53
54        Ok(ParsedFile {
55            path: path.to_path_buf(),
56            language: Language::Rust,
57            source: source.to_string(),
58            strategies_data,
59        })
60    }
61}
62
63/// Check if a node has a specific attribute in its preceding siblings.
64fn has_attribute(node: tree_sitter::Node<'_>, source: &str, attr_name: &str) -> bool {
65    // Check attribute_item children of this node
66    let mut cursor = node.walk();
67    for child in node.children(&mut cursor) {
68        if child.kind() == "attribute_item" {
69            let text = &source[child.start_byte()..child.end_byte()];
70            if text.contains(attr_name) {
71                return true;
72            }
73        }
74    }
75
76    // Check preceding siblings in parent
77    if let Some(parent) = node.parent() {
78        let mut pcursor = parent.walk();
79        for sibling in parent.children(&mut pcursor) {
80            if sibling.id() == node.id() {
81                break;
82            }
83            if sibling.kind() == "attribute_item" {
84                let text = &source[sibling.start_byte()..sibling.end_byte()];
85                if text.contains(attr_name) {
86                    return true;
87                }
88            }
89        }
90    }
91
92    false
93}
94
95/// Check if a module is annotated with `#[cfg(test)]`.
96fn is_test_module(node: tree_sitter::Node<'_>, source: &str) -> bool {
97    has_attribute(node, source, "cfg(test)")
98}
99
100/// Check if a function is annotated with `#[test]`.
101///
102/// In tree-sitter-rust, `#[test]` is an `attribute_item` sibling that
103/// precedes the `function_item`, not a child of it.
104fn is_test_function(node: tree_sitter::Node<'_>, source: &str) -> bool {
105    if let Some(parent) = node.parent() {
106        let mut cursor = parent.walk();
107        for sibling in parent.children(&mut cursor) {
108            if sibling.id() == node.id() {
109                break;
110            }
111            if sibling.kind() == "attribute_item" {
112                let text = &source[sibling.start_byte()..sibling.end_byte()];
113                if text == "#[test]" {
114                    return true;
115                }
116            }
117        }
118    }
119    false
120}
121
122/// Collect byte ranges for test modules and test functions.
123fn collect_test_ranges(root: &tree_sitter::Node<'_>, source: &str) -> Vec<RemovalRange> {
124    let mut ranges = Vec::new();
125    collect_test_ranges_recursive(root, source, &mut ranges);
126    ranges
127}
128
129fn collect_test_ranges_recursive(
130    node: &tree_sitter::Node<'_>,
131    source: &str,
132    ranges: &mut Vec<RemovalRange>,
133) {
134    let mut cursor = node.walk();
135    for child in node.children(&mut cursor) {
136        match child.kind() {
137            "mod_item" => {
138                if is_test_module(child, source) {
139                    let start = find_attr_start(&child, source);
140                    ranges.push(RemovalRange {
141                        start,
142                        end: child.end_byte(),
143                        reason: RemovalReason::TestModule,
144                    });
145                    continue;
146                }
147                collect_test_ranges_recursive(&child, source, ranges);
148            }
149            "function_item" => {
150                if is_test_function(child, source) {
151                    let start = find_attr_start(&child, source);
152                    ranges.push(RemovalRange {
153                        start,
154                        end: child.end_byte(),
155                        reason: RemovalReason::TestFunction,
156                    });
157                }
158            }
159            _ => {
160                collect_test_ranges_recursive(&child, source, ranges);
161            }
162        }
163    }
164}
165
166/// Find the start byte of attributes preceding a node.
167fn find_attr_start(node: &tree_sitter::Node<'_>, source: &str) -> usize {
168    if let Some(parent) = node.parent() {
169        let mut cursor = parent.walk();
170        let mut first_attr_start = node.start_byte();
171
172        for sibling in parent.children(&mut cursor) {
173            if sibling.id() == node.id() {
174                break;
175            }
176            if sibling.kind() == "attribute_item" && sibling.end_byte() <= node.start_byte() {
177                let between = &source[sibling.end_byte()..node.start_byte()];
178                if between.trim().is_empty() {
179                    first_attr_start = sibling.start_byte();
180                }
181            }
182        }
183
184        return first_attr_start;
185    }
186
187    node.start_byte()
188}
189
190/// Collect byte ranges for Summary mode: replace implementation bodies.
191fn collect_summary_ranges(root: &tree_sitter::Node<'_>, source: &str) -> Vec<RemovalRange> {
192    let mut ranges = Vec::new();
193    collect_summary_ranges_recursive(root, source, &mut ranges);
194    ranges
195}
196
197fn collect_summary_ranges_recursive(
198    node: &tree_sitter::Node<'_>,
199    source: &str,
200    ranges: &mut Vec<RemovalRange>,
201) {
202    let mut cursor = node.walk();
203    for child in node.children(&mut cursor) {
204        match child.kind() {
205            "function_item" => {
206                if is_test_function(child, source) {
207                    continue;
208                }
209                if let Some(range) = extract_implementation_range(child) {
210                    ranges.push(range);
211                }
212            }
213            "impl_item" => {
214                if let Some(range) = extract_impl_body_range(child) {
215                    ranges.push(range);
216                }
217                collect_summary_ranges_recursive(&child, source, ranges);
218            }
219            "mod_item" => {
220                if is_test_module(child, source) {
221                    continue;
222                }
223                collect_summary_ranges_recursive(&child, source, ranges);
224            }
225            _ => {
226                collect_summary_ranges_recursive(&child, source, ranges);
227            }
228        }
229    }
230}
231
232/// Extract the implementation body range of a function (the `block` node).
233fn extract_implementation_range(node: tree_sitter::Node<'_>) -> Option<RemovalRange> {
234    let mut cursor = node.walk();
235    for child in node.children(&mut cursor) {
236        if child.kind() == "block" {
237            return Some(RemovalRange {
238                start: child.start_byte(),
239                end: child.end_byte(),
240                reason: RemovalReason::Implementation,
241            });
242        }
243    }
244    None
245}
246
247/// Extract the body range of an impl block (the `declaration_list` node).
248fn extract_impl_body_range(node: tree_sitter::Node<'_>) -> Option<RemovalRange> {
249    let mut cursor = node.walk();
250    for child in node.children(&mut cursor) {
251        if child.kind() == "declaration_list" {
252            return Some(RemovalRange {
253                start: child.start_byte(),
254                end: child.end_byte(),
255                reason: RemovalReason::Implementation,
256            });
257        }
258    }
259    None
260}
261
#[cfg(test)]
#[expect(clippy::unwrap_used, clippy::panic)]
mod tests {
    use proptest::prelude::*;

    use super::*;
    use crate::config::OutputStrategy;

    /// Parse `source` as `test.rs`, panicking if the parser reports an error.
    fn parse_rust(source: &str) -> ParsedFile {
        RustParser::new().parse(source, Path::new("test.rs")).unwrap()
    }

    /// Borrow the rendered content for `strategy`; panics if it is absent.
    fn content_of(parsed: &ParsedFile, strategy: OutputStrategy) -> &str {
        &parsed.strategies_data[&strategy].content
    }

    #[test]
    fn test_rust_parser_creates_three_strategies() {
        let parsed = parse_rust("fn main() {\n    println!(\"hello\");\n}\n");
        let strategies = &parsed.strategies_data;
        assert!(strategies.contains_key(&OutputStrategy::Full));
        assert!(strategies.contains_key(&OutputStrategy::NoTests));
        assert!(strategies.contains_key(&OutputStrategy::Summary));
    }

    #[test]
    fn test_rust_parser_full_is_verbatim() {
        let source = "fn main() {\n    println!(\"hello\");\n}\n";
        let parsed = parse_rust(source);
        assert_eq!(content_of(&parsed, OutputStrategy::Full), source);
    }

    #[test]
    fn test_rust_parser_detects_cfg_test_module() {
        let parsed = parse_rust(
            "pub fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n\n#[cfg(test)]\nmod tests {\n    #[test]\n    fn test_add() {\n        assert_eq!(add(1, 2), 3);\n    }\n}\n",
        );
        let no_tests = content_of(&parsed, OutputStrategy::NoTests);
        assert!(!no_tests.contains("#[cfg(test)]"), "NoTests should remove #[cfg(test)] module");
        assert!(!no_tests.contains("test_add"), "NoTests should remove test function");
        assert!(no_tests.contains("pub fn add"), "NoTests should preserve non-test code");
    }

    #[test]
    fn test_rust_parser_removes_test_function() {
        let parsed = parse_rust(
            "pub fn helper() -> i32 {\n    42\n}\n\n#[test]\nfn test_helper() {\n    assert_eq!(helper(), 42);\n}\n",
        );
        let no_tests = content_of(&parsed, OutputStrategy::NoTests);
        assert!(no_tests.contains("pub fn helper"), "should preserve helper");
        assert!(!no_tests.contains("test_helper"), "should remove test function");
    }

    #[test]
    fn test_rust_parser_summary_extracts_signatures() {
        let parsed = parse_rust("pub fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n");
        let summary = content_of(&parsed, OutputStrategy::Summary);
        assert!(summary.contains("pub fn add(a: i32, b: i32) -> i32"), "should preserve signature");
        assert!(!summary.contains("a + b"), "should remove body");
        assert!(summary.contains("✂️ implementations omitted"), "should insert marker");
    }

    #[test]
    fn test_rust_parser_summary_handles_struct() {
        let parsed =
            parse_rust("#[derive(Debug)]\npub struct Point {\n    x: f64,\n    y: f64,\n}\n");
        let summary = content_of(&parsed, OutputStrategy::Summary);
        assert!(summary.contains("struct Point"), "should contain struct");
    }

    #[test]
    fn test_rust_parser_no_tests_fewer_tokens_than_full() {
        let parsed = parse_rust(
            "pub fn lib() -> i32 {\n    42\n}\n\n#[cfg(test)]\nmod tests {\n    #[test]\n    fn test_lib() {\n        assert_eq!(lib(), 42);\n    }\n}\n",
        );
        let full_tokens = parsed.strategies_data[&OutputStrategy::Full].token_count;
        let no_tests_tokens = parsed.strategies_data[&OutputStrategy::NoTests].token_count;
        assert!(
            no_tests_tokens < full_tokens,
            "NoTests ({no_tests_tokens}) should have fewer tokens than Full ({full_tokens})"
        );
    }

    #[test]
    fn test_rust_parser_path_stored() {
        let expected = Path::new("src/main.rs");
        let parsed = RustParser::new().parse("fn main() {}\n", expected).unwrap();
        assert_eq!(parsed.path, expected);
    }

    #[test]
    fn test_rust_parser_language_is_rust() {
        let parsed = parse_rust("fn main() {}\n");
        assert_eq!(parsed.language, Language::Rust);
    }

    #[test]
    fn test_rust_parser_empty_file() {
        let parsed = parse_rust("");
        assert_eq!(content_of(&parsed, OutputStrategy::Full), "");
        assert_eq!(parsed.strategies_data[&OutputStrategy::Full].token_count, 0);
    }

    #[test]
    fn test_rust_parser_multiple_test_functions() {
        let parsed = parse_rust(
            "pub fn add(a: i32, b: i32) -> i32 { a + b }\npub fn sub(a: i32, b: i32) -> i32 { a - b }\n\n#[test]\nfn test_add() { assert_eq!(add(1, 2), 3); }\n\n#[test]\nfn test_sub() { assert_eq!(sub(3, 1), 2); }\n",
        );
        let no_tests = content_of(&parsed, OutputStrategy::NoTests);
        assert!(no_tests.contains("pub fn add"), "should preserve add");
        assert!(no_tests.contains("pub fn sub"), "should preserve sub");
        assert!(!no_tests.contains("test_add"), "should remove test_add");
        assert!(!no_tests.contains("test_sub"), "should remove test_sub");
    }

    #[test]
    fn test_rust_parser_nested_test_module() {
        let parsed = parse_rust(
            "pub fn helper() {}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn test_helper() {\n        helper();\n    }\n}\n",
        );
        let no_tests = content_of(&parsed, OutputStrategy::NoTests);
        assert!(no_tests.contains("pub fn helper"));
        assert!(!no_tests.contains("test_helper"));
    }

    #[test]
    fn test_rust_parser_impl_block_summary() {
        let parsed = parse_rust(
            "pub struct Counter {\n    count: u32,\n}\n\nimpl Counter {\n    pub fn new() -> Self {\n        Self { count: 0 }\n    }\n\n    pub fn increment(&mut self) {\n        self.count += 1;\n    }\n}\n",
        );
        let summary = content_of(&parsed, OutputStrategy::Summary);
        assert!(summary.contains("impl Counter"), "should contain impl");
        assert!(summary.contains("struct Counter"), "should contain struct");
    }

    /// Generate small Rust-like sources: one to four `pub fn` items with
    /// arbitrary bodies, optionally followed by a `#[cfg(test)]` module.
    fn rust_source_strategy() -> impl Strategy<Value = String> {
        let names =
            proptest::collection::vec(proptest::string::string_regex("[a-z_]{1,10}").unwrap(), 1..5);
        let body_texts = proptest::collection::vec(
            proptest::string::string_regex("[a-z0-9_ +\\-*/;(){}\n\t]{0,50}").unwrap(),
            1..5,
        );

        (names, body_texts, proptest::bool::ANY).prop_map(|(fn_names, bodies, add_test_module)| {
            // Bodies are reused cyclically when there are more names than bodies.
            let mut source: String = fn_names
                .iter()
                .enumerate()
                .map(|(i, name)| {
                    let body = &bodies[i % bodies.len()];
                    format!("pub fn {name}() {{\n    {body}\n}}\n\n")
                })
                .collect();
            if add_test_module {
                source.push_str("#[cfg(test)]\nmod tests {\n    #[test]\n    fn test_something() {\n        assert!(true);\n    }\n}\n");
            }
            source
        })
    }

    /// Remove the known omission markers from `text`, so that what remains
    /// can be compared against the original source.
    fn strip_markers(text: &str) -> String {
        ["// ✂️ test module omitted\n", "// ✂️ implementations omitted"]
            .iter()
            .fold(text.to_string(), |acc, marker| acc.replace(*marker, ""))
    }

    /// Return `true` when every character of `candidate` occurs in `source`
    /// in the same relative order.
    fn is_subsequence(source: &str, candidate: &str) -> bool {
        let mut remaining = source.chars();
        // `any` consumes `remaining` up to and including the match, so each
        // candidate character must be found after the previous one.
        candidate
            .chars()
            .all(|wanted| remaining.any(|have| have == wanted))
    }

    proptest! {
        #[test]
        fn parser_content_subset_invariant(source in rust_source_strategy()) {
            let parsed = parse_rust(&source);
            for strategy in [OutputStrategy::Full, OutputStrategy::NoTests, OutputStrategy::Summary] {
                let Some(data) = parsed.strategies_data.get(&strategy) else {
                    continue;
                };
                let stripped = strip_markers(&data.content);
                prop_assert!(
                    is_subsequence(&source, &stripped),
                    "strategy {strategy}: stripped content is not a subsequence of source.\n\
                     source len={}, stripped len={}",
                    source.len(),
                    stripped.len(),
                );
            }
        }
    }
}