1use std::path::Path;
6
7use regex::Regex;
8use walkdir::WalkDir;
9
10use super::{DocTest, DocTestCorpus};
11use crate::Result;
12
/// One `def` or `class` header recognised while scanning Python source.
#[derive(Debug, Clone)]
struct DefContext {
    /// Byte length of the whitespace captured in front of the keyword.
    indent: usize,
    /// Identifier that follows `def` / `class`.
    name: String,
    /// `true` when the header keyword was `class`.
    is_class: bool,
    /// Reconstructed `def name(params)[ -> ret]` text. `None` for classes and
    /// for headers whose parameter list was not captured on the same line.
    signature: Option<String>,
}
25
26#[derive(Debug)]
28pub struct DocTestParser {
29 docstring_re: Regex,
31 def_re: Regex,
33}
34
35impl Default for DocTestParser {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl DocTestParser {
42 #[must_use]
48 #[allow(clippy::expect_used)]
49 pub fn new() -> Self {
50 Self {
51 docstring_re: Regex::new(r#"(?s)"""(.*?)""""#).expect("valid regex"),
54 def_re: Regex::new(r"(?m)^(\s*)(def|class)\s+(\w+)(\([^)]*\)(?:\s*->\s*[^:]+)?)?")
58 .expect("valid regex"),
59 }
60 }
61
62 #[must_use]
68 pub fn parse_source(&self, source: &str, module: &str) -> Vec<DocTest> {
69 let mut results = Vec::new();
70 let lines: Vec<&str> = source.lines().collect();
71
72 let mut context_map: Vec<Option<DefContext>> = vec![None; lines.len()];
74 let mut context_stack: Vec<DefContext> = Vec::new();
75
76 for (i, line) in lines.iter().enumerate() {
77 if let Some(caps) = self.def_re.captures(line) {
78 let indent = caps.get(1).map_or(0, |m| m.as_str().len());
79 let kind = caps.get(2).map_or("", |m| m.as_str());
80 let name = caps.get(3).map_or("", |m| m.as_str()).to_string();
81 let params = caps.get(4).map(|m| m.as_str().to_string());
82 let is_class = kind == "class";
83
84 let signature = if is_class {
86 None
87 } else {
88 params.map(|p| format!("def {name}{p}"))
89 };
90
91 while context_stack.last().is_some_and(|ctx| ctx.indent >= indent) {
93 context_stack.pop();
94 }
95
96 context_stack.push(DefContext {
97 indent,
98 name,
99 is_class,
100 signature,
101 });
102 }
103
104 context_map[i] = context_stack.last().cloned();
106 }
107
108 for caps in self.docstring_re.captures_iter(source) {
110 let Some(docstring_match) = caps.get(0) else {
111 continue;
112 };
113 let content = caps.get(1).map_or("", |m| m.as_str());
114
115 let start_byte = docstring_match.start();
117 let line_num = source[..start_byte].matches('\n').count();
118
119 let (function_name, signature) =
121 Self::get_function_context(line_num, &context_map, &lines);
122
123 let doctests = Self::extract_from_docstring_with_sig(
125 content,
126 module,
127 &function_name,
128 signature.as_deref(),
129 );
130 results.extend(doctests);
131 }
132
133 results
134 }
135
136 fn get_function_context(
139 line_num: usize,
140 context_map: &[Option<DefContext>],
141 lines: &[&str],
142 ) -> (String, Option<String>) {
143 if line_num < lines.len() {
145 if let Some(ctx) = context_map.get(line_num).and_then(|c| c.clone()) {
147 if !ctx.is_class {
149 for i in (0..line_num).rev() {
151 if let Some(class_ctx) = context_map.get(i).and_then(|c| c.clone()) {
152 if class_ctx.is_class {
153 let full_name = format!("{}.{}", class_ctx.name, ctx.name);
154 return (full_name, ctx.signature);
155 }
156 }
157 }
158 }
159 return (ctx.name.clone(), ctx.signature);
160 }
161 }
162
163 ("__module__".to_string(), None)
165 }
166
167 fn extract_from_docstring_with_sig(
169 content: &str,
170 module: &str,
171 function: &str,
172 signature: Option<&str>,
173 ) -> Vec<DocTest> {
174 let mut results = Vec::new();
175 let lines: Vec<&str> = content.lines().collect();
176 let mut i = 0;
177
178 while i < lines.len() {
179 let line = lines[i].trim();
180
181 if let Some(input_start) = line.strip_prefix(">>>") {
183 let mut input_lines = vec![format!(">>>{}", input_start)];
184 i += 1;
185
186 while i < lines.len() {
188 let next_line = lines[i].trim();
189 if let Some(cont) = next_line.strip_prefix("...") {
190 input_lines.push(format!("...{}", cont));
191 i += 1;
192 } else {
193 break;
194 }
195 }
196
197 let mut expected_lines: Vec<&str> = Vec::new();
199 let base_indent = lines
201 .get(i.saturating_sub(1))
202 .map(|l| l.len() - l.trim_start().len())
203 .unwrap_or(0);
204
205 while i < lines.len() {
206 let next_line = lines[i];
207 let trimmed = next_line.trim();
208
209 if trimmed.starts_with(">>>") {
211 break;
212 }
213
214 if is_prose_continuation(trimmed) {
216 break;
217 }
218
219 if trimmed.is_empty() && !expected_lines.is_empty() {
221 let mut j = i + 1;
223 while j < lines.len() && lines[j].trim().is_empty() {
224 j += 1;
225 }
226 if j >= lines.len() || lines[j].trim().starts_with(">>>") {
227 break;
228 }
229 }
232
233 if !trimmed.is_empty() || !expected_lines.is_empty() {
234 let stripped = if next_line.len() > base_indent {
236 &next_line
237 [base_indent.min(next_line.len() - next_line.trim_start().len())..]
238 } else {
239 trimmed
240 };
241 expected_lines.push(stripped.trim_end());
242 }
243 i += 1;
244 }
245
246 while expected_lines.last().is_some_and(|l| l.is_empty()) {
248 expected_lines.pop();
249 }
250
251 let input = input_lines.join("\n");
252 let expected = expected_lines.join("\n");
253
254 let mut doctest = DocTest::new(module, function, input, expected);
256 if let Some(sig) = signature {
257 doctest = doctest.with_signature(sig);
258 }
259 results.push(doctest);
260 } else {
261 i += 1;
262 }
263 }
264
265 results
266 }
267
268 #[allow(dead_code)]
270 fn extract_from_docstring(content: &str, module: &str, function: &str) -> Vec<DocTest> {
271 Self::extract_from_docstring_with_sig(content, module, function, None)
272 }
273
274 pub fn parse_file(&self, path: &Path, module: &str) -> Result<Vec<DocTest>> {
276 let source = std::fs::read_to_string(path).map_err(|e| crate::Error::Io {
277 path: Some(path.to_path_buf()),
278 source: e,
279 })?;
280 Ok(self.parse_source(&source, module))
281 }
282
283 pub fn parse_directory(
290 &self,
291 dir: &Path,
292 source: &str,
293 version: &str,
294 ) -> Result<DocTestCorpus> {
295 let mut corpus = DocTestCorpus::new(source, version);
296
297 for entry in WalkDir::new(dir)
298 .follow_links(true)
299 .into_iter()
300 .filter_map(|e| e.ok())
301 {
302 let path = entry.path();
303 if path.extension().is_some_and(|ext| ext == "py") {
304 let module = path_to_module(dir, path);
305 let doctests = self.parse_file(path, &module)?;
306 for dt in doctests {
307 corpus.push(dt);
308 }
309 }
310 }
311
312 Ok(corpus)
313 }
314}
315
/// Converts a Python file path into a dotted module name.
///
/// `path` is made relative to `base` when possible, its extension is dropped,
/// and the remaining normal path components are joined with `.`. Root,
/// prefix, `.` and `..` components are ignored, so a path outside `base`
/// never yields a name with a leading dot.
///
/// A trailing `__init__` component is removed once when a parent component
/// remains (`pkg/__init__.py` → `pkg`). A bare `__init__.py` directly under
/// `base` keeps the name `__init__`.
fn path_to_module(base: &Path, path: &Path) -> String {
    let relative = path.strip_prefix(base).unwrap_or(path);
    let stem = relative.with_extension("");

    let mut parts: Vec<String> = stem
        .components()
        .filter_map(|component| match component {
            std::path::Component::Normal(part) => Some(part.to_string_lossy().into_owned()),
            _ => None,
        })
        .collect();

    if parts.len() > 1 && parts.last().map(String::as_str) == Some("__init__") {
        parts.pop();
    }

    parts.join(".")
}
327
/// Field-list prefixes (`:param`, `:return`, …). A trimmed line that starts
/// with one of them is always classified as prose.
const DOC_MARKERS: &[&str] = &[
    ":param",
    ":return",
    ":raises",
    ":type",
    ":rtype",
    ":arg",
    ":args:",
    ":keyword",
    ":ivar",
    ":cvar",
];
355
/// Prefixes that commonly open an English sentence or a section heading in a
/// docstring. Matching is case-sensitive, and the trailing space or colon is
/// part of each prefix.
const PROSE_STARTERS: &[&str] = &[
    "The ", "This ",
    "Note:", "Warning:", "Example:", "Examples:",
    "See ", "If ", "When ", "For ", "An ", "A ", "It ",
    "Returns ", "Raises ",
    "Args:", "Arguments:", "Parameters:",
    "By ", "Use ", "Set ", "Get ",
    "You ", "We ", "They ",
];
384
/// Capitalised first words that are never treated as the start of a prose
/// sentence: the Python literals plus the `Traceback` header word.
const PYTHON_CONSTANTS: &[&str] = &[
    "True",
    "False",
    "None",
    "Traceback",
];
387
/// Returns `true` when `word` looks like a Python exception class name.
///
/// A word qualifies when it ends in `Error`, `Exception` or `Warning`, has at
/// least one character before that suffix, and contains at least two
/// uppercase letters. The bare words `Error`, `Exception` and `Warning`
/// therefore do not qualify, while short names such as `OSError` do.
fn is_python_exception_word(word: &str) -> bool {
    const SUFFIXES: [&str; 3] = ["Error", "Exception", "Warning"];

    let has_named_prefix = SUFFIXES.iter().any(|suffix| {
        word.strip_suffix(*suffix)
            .is_some_and(|prefix| !prefix.is_empty())
    });

    has_named_prefix && word.chars().filter(|c| c.is_uppercase()).count() >= 2
}
395
396fn matches_prose_starter(trimmed: &str) -> Option<bool> {
398 if PROSE_STARTERS.iter().any(|s| trimmed.starts_with(s)) {
399 Some(!trimmed.contains(">>>"))
400 } else {
401 None
402 }
403}
404
405fn is_sentence_heuristic(trimmed: &str, first_word: &str) -> bool {
407 let chars: Vec<char> = trimmed.chars().collect();
408 if chars.len() < 2 || !chars[0].is_uppercase() || !chars[1].is_lowercase() {
409 return false;
410 }
411
412 if PYTHON_CONSTANTS.contains(&first_word) {
413 return false;
414 }
415
416 if first_word.chars().all(|c| c.is_alphanumeric() || c == '_')
418 && trimmed.split_whitespace().count() == 1
419 {
420 return false;
421 }
422
423 if trimmed.split_whitespace().count() > 2 {
425 return !looks_like_code_output(trimmed);
426 }
427
428 false
429}
430
/// Returns `true` when the line contains a `>>>` prompt, starts with `...`,
/// or starts with one of the opening characters `<`, `[`, `{`, `(`.
fn looks_like_code_output(trimmed: &str) -> bool {
    if trimmed.contains(">>>") || trimmed.starts_with("...") {
        return true;
    }

    matches!(trimmed.chars().next(), Some('<' | '[' | '{' | '('))
}
440
441#[must_use]
442pub fn is_prose_continuation(line: &str) -> bool {
443 let trimmed = line.trim();
444 if trimmed.is_empty() {
445 return false;
446 }
447
448 let first_word: &str = trimmed
449 .split(|c: char| c == ':' || c.is_whitespace())
450 .next()
451 .unwrap_or("");
452
453 if is_python_exception_word(first_word) {
455 return false;
456 }
457
458 if DOC_MARKERS.iter().any(|m| trimmed.starts_with(m)) {
460 return true;
461 }
462
463 if let Some(is_prose) = matches_prose_starter(trimmed) {
465 return is_prose;
466 }
467
468 if trimmed.starts_with(".. ") || trimmed.starts_with(">>>") {
470 return false;
471 }
472
473 is_sentence_heuristic(trimmed, first_word)
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Asserts that every given line is classified as prose.
    fn assert_all_prose(samples: &[&str]) {
        for sample in samples {
            assert!(is_prose_continuation(sample));
        }
    }

    /// Asserts that none of the given lines is classified as prose.
    fn assert_none_prose(samples: &[&str]) {
        for sample in samples {
            assert!(!is_prose_continuation(sample));
        }
    }

    /// Parses `source` with a fresh parser.
    fn parse(source: &str, module: &str) -> Vec<DocTest> {
        DocTestParser::new().parse_source(source, module)
    }

    // ----- prose classification -------------------------------------------

    #[test]
    fn test_prose_detection_sentence() {
        assert_all_prose(&[
            "The stdout argument is not allowed.",
            "This function returns a value.",
            "Note: This is important.",
            "Warning: Use with caution.",
        ]);
    }

    #[test]
    fn test_prose_detection_docstring_markers() {
        assert_all_prose(&[
            ":param x: the input value",
            ":return: the computed result",
            ":raises ValueError: if invalid",
            ":type x: int",
        ]);
    }

    #[test]
    fn test_prose_detection_common_starters() {
        assert_all_prose(&[
            "If you use this argument...",
            "When the value is negative...",
            "For more information...",
            "Returns the computed value.",
        ]);
    }

    #[test]
    fn test_prose_detection_false_negatives() {
        assert_none_prose(&[
            "True",
            "False",
            "None",
            "123",
            "'hello world'",
            "b'bytes'",
            "[1, 2, 3]",
            "{'key': 'value'}",
            "(0, '/bin/ls')",
            "Point(x=11, y=22)",
            "",
        ]);
    }

    #[test]
    fn test_prose_detection_edge_cases() {
        assert_none_prose(&[
            "ValueError",
            "MyClass",
            "Traceback (most recent call last):",
            "",
            " ",
        ]);
    }

    // ----- extraction -----------------------------------------------------

    #[test]
    fn test_extract_with_prose_contamination() {
        let source = r#"
def check_output():
    """
    >>> check_output(["ls", "-l"])
    b'output\n'

    The stdout argument is not allowed as it is used internally.
    To capture standard error, use stderr=STDOUT.

    >>> check_output(["echo", "hi"])
    b'hi\n'
    """
    pass
"#;
        let found = parse(source, "test");
        assert_eq!(found.len(), 2);

        let (first, second) = (&found[0], &found[1]);
        assert_eq!(first.expected, "b'output\\n'");
        assert!(!first.expected.contains("stdout argument"));
        assert_eq!(second.expected, "b'hi\\n'");
    }

    // ----- module names ---------------------------------------------------

    #[test]
    fn test_path_to_module_simple() {
        let module = path_to_module(Path::new("/lib"), Path::new("/lib/os.py"));
        assert_eq!(module, "os");
    }

    #[test]
    fn test_path_to_module_nested() {
        let module = path_to_module(Path::new("/lib"), Path::new("/lib/os/path.py"));
        assert_eq!(module, "os.path");
    }

    #[test]
    fn test_path_to_module_init() {
        let module = path_to_module(
            Path::new("/lib"),
            Path::new("/lib/collections/__init__.py"),
        );
        assert_eq!(module, "collections");
    }

    // ----- extraction, continued -------------------------------------------

    #[test]
    fn test_extract_simple() {
        let source = r#"
def foo():
    """
    >>> 1 + 1
    2
    """
    pass
"#;
        let found = parse(source, "test");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].input, ">>> 1 + 1");
        assert_eq!(found[0].expected, "2");
    }

    #[test]
    fn test_extract_multiline_input() {
        let source = r#"
def foo():
    """
    >>> x = (
    ... 1 + 2
    ... )
    >>> x
    3
    """
    pass
"#;
        let found = parse(source, "test");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].input, ">>> x = (\n... 1 + 2\n... )");
        assert_eq!(found[0].expected, "");
    }

    #[test]
    fn test_extract_signature() {
        let source = r#"
def add(a: int, b: int) -> int:
    """Add two numbers.

    >>> add(1, 2)
    3
    """
    return a + b
"#;
        let found = parse(source, "math");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].function, "add");
        assert_eq!(
            found[0].signature,
            Some("def add(a: int, b: int) -> int".to_string())
        );
    }

    #[test]
    fn test_extract_signature_no_return_type() {
        let source = r#"
def greet(name: str):
    """Greet someone.

    >>> greet("world")
    'Hello, world!'
    """
    return f"Hello, {name}!"
"#;
        let found = parse(source, "hello");
        assert_eq!(found.len(), 1);
        assert_eq!(
            found[0].signature,
            Some("def greet(name: str)".to_string())
        );
    }

    #[test]
    fn test_module_doctest_no_signature() {
        let source = r#"
"""Module docstring.

>>> 1 + 1
2
"""
"#;
        let found = parse(source, "mymodule");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].function, "__module__");
        assert!(found[0].signature.is_none());
    }

    // ----- property tests -------------------------------------------------

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        /// Whitespace-only input is never prose.
        #[test]
        fn prop_empty_never_prose(blank in "\\s*") {
            assert!(!is_prose_continuation(&blank));
        }

        /// The Python literals are always treated as output.
        #[test]
        fn prop_python_literals_never_prose(literal in prop_oneof![
            Just("True"),
            Just("False"),
            Just("None"),
        ]) {
            assert!(!is_prose_continuation(literal));
        }

        /// The final line of a traceback must stay part of the expected output.
        #[test]
        fn prop_exception_lines_never_prose(exc in prop_oneof![
            Just("ValueError: invalid input"),
            Just("TypeError: expected str"),
            Just("ZeroDivisionError: division by zero"),
            Just("KeyError: 'missing'"),
            Just("IndexError: out of range"),
            Just("RuntimeError: something went wrong"),
        ]) {
            assert!(!is_prose_continuation(exc), "Exception detected as prose: {}", exc);
        }

        /// Field-list markers are always prose.
        #[test]
        fn prop_docstring_markers_are_prose(marker in prop_oneof![
            Just(":param x: value"),
            Just(":return: result"),
            Just(":raises ValueError: msg"),
            Just(":type x: int"),
        ]) {
            assert!(is_prose_continuation(marker));
        }

        /// Typical reprs and literals are kept as output.
        #[test]
        fn prop_code_output_preserved(output in prop_oneof![
            Just("[1, 2, 3]"),
            Just("{'a': 1}"),
            Just("(1, 2)"),
            Just("<object at 0x...>"),
            Just("123"),
            Just("'string'"),
        ]) {
            assert!(!is_prose_continuation(output));
        }

        /// Classifying the same input twice gives the same answer.
        #[test]
        fn prop_deterministic(text in ".*") {
            let first = is_prose_continuation(&text);
            let second = is_prose_continuation(&text);
            assert_eq!(first, second);
        }
    }
}