ast_doc_core/parser/lang/
rust_parser.rs1use std::path::Path;
7
8use tree_sitter::{Parser, Tree};
9
10use crate::{
11 error::AstDocError,
12 parser::{
13 Language, LanguageParser, ParsedFile,
14 strategy::{self, RemovalRange, RemovalReason},
15 },
16};
17
18#[derive(Debug, Default)]
20pub struct RustParser;
21
22impl RustParser {
23 #[must_use]
25 pub const fn new() -> Self {
26 Self
27 }
28
29 fn parse_tree(source: &str) -> Result<Tree, AstDocError> {
31 let mut parser = Parser::new();
32 let language = tree_sitter_rust::LANGUAGE;
33 parser.set_language(&language.into()).map_err(|e| AstDocError::Parse {
34 path: Path::new("<inline>").to_path_buf(),
35 message: format!("Failed to set Rust language: {e}"),
36 })?;
37 parser.parse(source.as_bytes(), None).ok_or_else(|| AstDocError::Parse {
38 path: Path::new("<inline>").to_path_buf(),
39 message: "Failed to parse Rust source".to_string(),
40 })
41 }
42}
43
44impl LanguageParser for RustParser {
45 fn parse(&self, source: &str, path: &Path) -> Result<ParsedFile, AstDocError> {
46 let tree = Self::parse_tree(source)?;
47 let root_node = tree.root_node();
48
49 let test_ranges = collect_test_ranges(&root_node, source);
50 let summary_ranges = collect_summary_ranges(&root_node, source);
51
52 let strategies_data = strategy::build_strategies(source, &test_ranges, &summary_ranges);
53
54 Ok(ParsedFile {
55 path: path.to_path_buf(),
56 language: Language::Rust,
57 source: source.to_string(),
58 strategies_data,
59 })
60 }
61}
62
63fn has_attribute(node: tree_sitter::Node<'_>, source: &str, attr_name: &str) -> bool {
65 let mut cursor = node.walk();
67 for child in node.children(&mut cursor) {
68 if child.kind() == "attribute_item" {
69 let text = &source[child.start_byte()..child.end_byte()];
70 if text.contains(attr_name) {
71 return true;
72 }
73 }
74 }
75
76 if let Some(parent) = node.parent() {
78 let mut pcursor = parent.walk();
79 for sibling in parent.children(&mut pcursor) {
80 if sibling.id() == node.id() {
81 break;
82 }
83 if sibling.kind() == "attribute_item" {
84 let text = &source[sibling.start_byte()..sibling.end_byte()];
85 if text.contains(attr_name) {
86 return true;
87 }
88 }
89 }
90 }
91
92 false
93}
94
95fn is_test_module(node: tree_sitter::Node<'_>, source: &str) -> bool {
97 has_attribute(node, source, "cfg(test)")
98}
99
100fn is_test_function(node: tree_sitter::Node<'_>, source: &str) -> bool {
105 if let Some(parent) = node.parent() {
106 let mut cursor = parent.walk();
107 for sibling in parent.children(&mut cursor) {
108 if sibling.id() == node.id() {
109 break;
110 }
111 if sibling.kind() == "attribute_item" {
112 let text = &source[sibling.start_byte()..sibling.end_byte()];
113 if text == "#[test]" {
114 return true;
115 }
116 }
117 }
118 }
119 false
120}
121
122fn collect_test_ranges(root: &tree_sitter::Node<'_>, source: &str) -> Vec<RemovalRange> {
124 let mut ranges = Vec::new();
125 collect_test_ranges_recursive(root, source, &mut ranges);
126 ranges
127}
128
129fn collect_test_ranges_recursive(
130 node: &tree_sitter::Node<'_>,
131 source: &str,
132 ranges: &mut Vec<RemovalRange>,
133) {
134 let mut cursor = node.walk();
135 for child in node.children(&mut cursor) {
136 match child.kind() {
137 "mod_item" => {
138 if is_test_module(child, source) {
139 let start = find_attr_start(&child, source);
140 ranges.push(RemovalRange {
141 start,
142 end: child.end_byte(),
143 reason: RemovalReason::TestModule,
144 });
145 continue;
146 }
147 collect_test_ranges_recursive(&child, source, ranges);
148 }
149 "function_item" => {
150 if is_test_function(child, source) {
151 let start = find_attr_start(&child, source);
152 ranges.push(RemovalRange {
153 start,
154 end: child.end_byte(),
155 reason: RemovalReason::TestFunction,
156 });
157 }
158 }
159 _ => {
160 collect_test_ranges_recursive(&child, source, ranges);
161 }
162 }
163 }
164}
165
166fn find_attr_start(node: &tree_sitter::Node<'_>, source: &str) -> usize {
168 if let Some(parent) = node.parent() {
169 let mut cursor = parent.walk();
170 let mut first_attr_start = node.start_byte();
171
172 for sibling in parent.children(&mut cursor) {
173 if sibling.id() == node.id() {
174 break;
175 }
176 if sibling.kind() == "attribute_item" && sibling.end_byte() <= node.start_byte() {
177 let between = &source[sibling.end_byte()..node.start_byte()];
178 if between.trim().is_empty() {
179 first_attr_start = sibling.start_byte();
180 }
181 }
182 }
183
184 return first_attr_start;
185 }
186
187 node.start_byte()
188}
189
190fn collect_summary_ranges(root: &tree_sitter::Node<'_>, source: &str) -> Vec<RemovalRange> {
192 let mut ranges = Vec::new();
193 collect_summary_ranges_recursive(root, source, &mut ranges);
194 ranges
195}
196
197fn collect_summary_ranges_recursive(
198 node: &tree_sitter::Node<'_>,
199 source: &str,
200 ranges: &mut Vec<RemovalRange>,
201) {
202 let mut cursor = node.walk();
203 for child in node.children(&mut cursor) {
204 match child.kind() {
205 "function_item" => {
206 if is_test_function(child, source) {
207 continue;
208 }
209 if let Some(range) = extract_implementation_range(child) {
210 ranges.push(range);
211 }
212 }
213 "impl_item" => {
214 if let Some(range) = extract_impl_body_range(child) {
215 ranges.push(range);
216 }
217 collect_summary_ranges_recursive(&child, source, ranges);
218 }
219 "mod_item" => {
220 if is_test_module(child, source) {
221 continue;
222 }
223 collect_summary_ranges_recursive(&child, source, ranges);
224 }
225 _ => {
226 collect_summary_ranges_recursive(&child, source, ranges);
227 }
228 }
229 }
230}
231
232fn extract_implementation_range(node: tree_sitter::Node<'_>) -> Option<RemovalRange> {
234 let mut cursor = node.walk();
235 for child in node.children(&mut cursor) {
236 if child.kind() == "block" {
237 return Some(RemovalRange {
238 start: child.start_byte(),
239 end: child.end_byte(),
240 reason: RemovalReason::Implementation,
241 });
242 }
243 }
244 None
245}
246
247fn extract_impl_body_range(node: tree_sitter::Node<'_>) -> Option<RemovalRange> {
249 let mut cursor = node.walk();
250 for child in node.children(&mut cursor) {
251 if child.kind() == "declaration_list" {
252 return Some(RemovalRange {
253 start: child.start_byte(),
254 end: child.end_byte(),
255 reason: RemovalReason::Implementation,
256 });
257 }
258 }
259 None
260}
261
262#[cfg(test)]
263#[expect(clippy::unwrap_used, clippy::panic)]
264mod tests {
265 use super::*;
266 use crate::config::OutputStrategy;
267
268 fn parse_rust(source: &str) -> ParsedFile {
269 let parser = RustParser::new();
270 parser.parse(source, Path::new("test.rs")).unwrap()
271 }
272
/// Parsing must produce an entry for each of the three output strategies.
#[test]
fn test_rust_parser_creates_three_strategies() {
    const SOURCE: &str = "fn main() {\n println!(\"hello\");\n}\n";
    let strategies = parse_rust(SOURCE).strategies_data;
    assert!(strategies.contains_key(&OutputStrategy::Full));
    assert!(strategies.contains_key(&OutputStrategy::NoTests));
    assert!(strategies.contains_key(&OutputStrategy::Summary));
}
281
/// The `Full` strategy must reproduce the input unchanged.
#[test]
fn test_rust_parser_full_is_verbatim() {
    const SOURCE: &str = "fn main() {\n println!(\"hello\");\n}\n";
    let file = parse_rust(SOURCE);
    let full = &file.strategies_data[&OutputStrategy::Full];
    assert_eq!(full.content, SOURCE);
}
288
/// A `#[cfg(test)]` module must be dropped from the `NoTests` output while
/// regular code is kept.
#[test]
fn test_rust_parser_detects_cfg_test_module() {
    const SOURCE: &str = "pub fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n\n#[cfg(test)]\nmod tests {\n #[test]\n fn test_add() {\n assert_eq!(add(1, 2), 3);\n }\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::NoTests].content;
    assert!(!output.contains("#[cfg(test)]"), "NoTests should remove #[cfg(test)] module");
    assert!(!output.contains("test_add"), "NoTests should remove test function");
    assert!(output.contains("pub fn add"), "NoTests should preserve non-test code");
}
298
/// A free-standing `#[test]` function must be dropped from the `NoTests` output.
#[test]
fn test_rust_parser_removes_test_function() {
    const SOURCE: &str = "pub fn helper() -> i32 {\n 42\n}\n\n#[test]\nfn test_helper() {\n assert_eq!(helper(), 42);\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::NoTests].content;
    assert!(output.contains("pub fn helper"), "should preserve helper");
    assert!(!output.contains("test_helper"), "should remove test function");
}
307
/// The `Summary` output keeps the signature, drops the body, and inserts the
/// omission marker.
#[test]
fn test_rust_parser_summary_extracts_signatures() {
    const SOURCE: &str = "pub fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::Summary].content;
    assert!(output.contains("pub fn add(a: i32, b: i32) -> i32"), "should preserve signature");
    assert!(!output.contains("a + b"), "should remove body");
    assert!(output.contains("✂️ implementations omitted"), "should insert marker");
}
317
/// Struct definitions must survive in the `Summary` output.
#[test]
fn test_rust_parser_summary_handles_struct() {
    const SOURCE: &str = "#[derive(Debug)]\npub struct Point {\n x: f64,\n y: f64,\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::Summary].content;
    assert!(output.contains("struct Point"), "should contain struct");
}
325
/// Removing a test module must lower the token count relative to `Full`.
#[test]
fn test_rust_parser_no_tests_fewer_tokens_than_full() {
    const SOURCE: &str = "pub fn lib() -> i32 {\n 42\n}\n\n#[cfg(test)]\nmod tests {\n #[test]\n fn test_lib() {\n assert_eq!(lib(), 42);\n }\n}\n";
    let strategies = parse_rust(SOURCE).strategies_data;
    let full_tokens = strategies[&OutputStrategy::Full].token_count;
    let no_tests_tokens = strategies[&OutputStrategy::NoTests].token_count;
    assert!(
        no_tests_tokens < full_tokens,
        "NoTests ({no_tests_tokens}) should have fewer tokens than Full ({full_tokens})"
    );
}
337
/// The path handed to `parse` must be stored on the parsed file.
#[test]
fn test_rust_parser_path_stored() {
    let file = RustParser::new()
        .parse("fn main() {}\n", Path::new("src/main.rs"))
        .unwrap();
    assert_eq!(file.path, Path::new("src/main.rs"));
}
345
/// Parsed files must be tagged with `Language::Rust`.
#[test]
fn test_rust_parser_language_is_rust() {
    let file = parse_rust("fn main() {}\n");
    assert_eq!(file.language, Language::Rust);
}
352
/// An empty input yields empty `Full` content with a token count of zero.
#[test]
fn test_rust_parser_empty_file() {
    let file = parse_rust("");
    let full = &file.strategies_data[&OutputStrategy::Full];
    assert_eq!(full.content, "");
    assert_eq!(full.token_count, 0);
}
360
/// Several free-standing `#[test]` functions are all removed while the regular
/// functions stay.
#[test]
fn test_rust_parser_multiple_test_functions() {
    const SOURCE: &str = "pub fn add(a: i32, b: i32) -> i32 { a + b }\npub fn sub(a: i32, b: i32) -> i32 { a - b }\n\n#[test]\nfn test_add() { assert_eq!(add(1, 2), 3); }\n\n#[test]\nfn test_sub() { assert_eq!(sub(3, 1), 2); }\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::NoTests].content;
    assert!(output.contains("pub fn add"), "should preserve add");
    assert!(output.contains("pub fn sub"), "should preserve sub");
    assert!(!output.contains("test_add"), "should remove test_add");
    assert!(!output.contains("test_sub"), "should remove test_sub");
}
371
/// A test function inside a `#[cfg(test)]` module disappears together with the
/// module.
#[test]
fn test_rust_parser_nested_test_module() {
    const SOURCE: &str = "pub fn helper() {}\n\n#[cfg(test)]\nmod tests {\n use super::*;\n\n #[test]\n fn test_helper() {\n helper();\n }\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::NoTests].content;
    assert!(output.contains("pub fn helper"));
    assert!(!output.contains("test_helper"));
}
380
/// The `Summary` output still names both the struct and its `impl` block.
#[test]
fn test_rust_parser_impl_block_summary() {
    const SOURCE: &str = "pub struct Counter {\n count: u32,\n}\n\nimpl Counter {\n pub fn new() -> Self {\n Self { count: 0 }\n }\n\n pub fn increment(&mut self) {\n self.count += 1;\n }\n}\n";
    let file = parse_rust(SOURCE);
    let output = &file.strategies_data[&OutputStrategy::Summary].content;
    assert!(output.contains("impl Counter"), "should contain impl");
    assert!(output.contains("struct Counter"), "should contain struct");
}
389
390 use proptest::prelude::*;
391
392 fn rust_source_strategy() -> impl Strategy<Value = String> {
393 (
394 proptest::collection::vec(proptest::string::string_regex("[a-z_]{1,10}").unwrap(), 1..5),
395 proptest::collection::vec(proptest::string::string_regex("[a-z0-9_ +\\-*/;(){}\n\t]{0,50}").unwrap(), 1..5),
396 proptest::bool::ANY,
397 ).prop_map(|(fn_names, bodies, add_test_module)| {
398 let mut source = String::new();
399 for (i, name) in fn_names.iter().enumerate() {
400 let body = &bodies[i % bodies.len()];
401 source.push_str(&format!("pub fn {name}() {{\n {body}\n}}\n\n"));
402 }
403 if add_test_module {
404 source.push_str("#[cfg(test)]\nmod tests {\n #[test]\n fn test_something() {\n assert!(true);\n }\n}\n");
405 }
406 source
407 })
408 }
409
/// Returns `text` with every occurrence of the two omission markers removed.
///
/// The test-module marker is only removed together with its trailing newline;
/// the implementations marker is removed wherever it appears.
fn strip_markers(text: &str) -> String {
    ["// ✂️ test module omitted\n", "// ✂️ implementations omitted"]
        .iter()
        .fold(text.to_string(), |cleaned, marker| cleaned.replace(*marker, ""))
}
420
/// Returns `true` when the characters of `candidate` appear in `source` in the
/// same order, not necessarily contiguously.
///
/// An empty `candidate` is a subsequence of every `source`.
fn is_subsequence(source: &str, candidate: &str) -> bool {
    let mut remaining = source.chars();
    // For each wanted character, consume `source` up to and including its next
    // occurrence; running out of `source` means the order cannot be satisfied.
    candidate
        .chars()
        .all(|wanted| remaining.any(|seen| seen == wanted))
}
442
443 proptest! {
444 #[test]
445 fn parser_content_subset_invariant(source in rust_source_strategy()) {
446 let parsed = parse_rust(&source);
447 for strategy in [OutputStrategy::Full, OutputStrategy::NoTests, OutputStrategy::Summary] {
448 if let Some(data) = parsed.strategies_data.get(&strategy) {
449 let stripped = strip_markers(&data.content);
450 prop_assert!(
451 is_subsequence(&source, &stripped),
452 "strategy {strategy}: stripped content is not a subsequence of source.\n\
453 source len={}, stripped len={}",
454 source.len(),
455 stripped.len(),
456 );
457 }
458 }
459 }
460 }
461}