1use std::path::Path;
6
7use regex::Regex;
8use walkdir::WalkDir;
9
10use super::{DocTest, DocTestCorpus};
11use crate::Result;
12
/// A `def` or `class` header seen while scanning Python source line by line.
#[derive(Debug, Clone)]
struct DefContext {
    /// Byte length of the whitespace in front of the `def`/`class` keyword.
    indent: usize,
    /// Identifier that follows the keyword.
    name: String,
    /// `true` when the keyword was `class`, `false` when it was `def`.
    is_class: bool,
    /// `def name(params)` text (plus any `-> ret` annotation) for functions whose parameter
    /// list was matched; always `None` for classes.
    signature: Option<String>,
}
25
/// Extracts doctest examples from Python source text, files and directory trees.
#[derive(Debug)]
pub struct DocTestParser {
    /// Matches a `"""`-delimited string (possibly spanning lines) and captures its contents.
    docstring_re: Regex,
    /// Matches a `def`/`class` header, capturing its indentation, keyword, name and the
    /// optional parameter list with return annotation.
    def_re: Regex,
}
34
35impl Default for DocTestParser {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl DocTestParser {
42 #[must_use]
48 #[allow(clippy::expect_used)]
49 pub fn new() -> Self {
50 Self {
51 docstring_re: Regex::new(r#"(?s)"""(.*?)""""#).expect("valid regex"),
54 def_re: Regex::new(r"(?m)^(\s*)(def|class)\s+(\w+)(\([^)]*\)(?:\s*->\s*[^:]+)?)?")
58 .expect("valid regex"),
59 }
60 }
61
62 #[must_use]
68 pub fn parse_source(&self, source: &str, module: &str) -> Vec<DocTest> {
69 let mut results = Vec::new();
70 let lines: Vec<&str> = source.lines().collect();
71
72 let mut context_map: Vec<Option<DefContext>> = vec![None; lines.len()];
74 let mut context_stack: Vec<DefContext> = Vec::new();
75
76 for (i, line) in lines.iter().enumerate() {
77 if let Some(caps) = self.def_re.captures(line) {
78 let indent = caps.get(1).map_or(0, |m| m.as_str().len());
79 let kind = caps.get(2).map_or("", |m| m.as_str());
80 let name = caps.get(3).map_or("", |m| m.as_str()).to_string();
81 let params = caps.get(4).map(|m| m.as_str().to_string());
82 let is_class = kind == "class";
83
84 let signature = if is_class {
86 None
87 } else {
88 params.map(|p| format!("def {name}{p}"))
89 };
90
91 while context_stack.last().is_some_and(|ctx| ctx.indent >= indent) {
93 context_stack.pop();
94 }
95
96 context_stack.push(DefContext {
97 indent,
98 name,
99 is_class,
100 signature,
101 });
102 }
103
104 context_map[i] = context_stack.last().cloned();
106 }
107
108 for caps in self.docstring_re.captures_iter(source) {
110 let Some(docstring_match) = caps.get(0) else {
111 continue;
112 };
113 let content = caps.get(1).map_or("", |m| m.as_str());
114
115 let start_byte = docstring_match.start();
117 let line_num = source[..start_byte].matches('\n').count();
118
119 let (function_name, signature) =
121 Self::get_function_context(line_num, &context_map, &lines);
122
123 let doctests = Self::extract_from_docstring_with_sig(
125 content,
126 module,
127 &function_name,
128 signature.as_deref(),
129 );
130 results.extend(doctests);
131 }
132
133 results
134 }
135
136 fn get_function_context(
139 line_num: usize,
140 context_map: &[Option<DefContext>],
141 lines: &[&str],
142 ) -> (String, Option<String>) {
143 if line_num < lines.len() {
145 if let Some(ctx) = context_map.get(line_num).and_then(|c| c.clone()) {
147 if !ctx.is_class {
149 for i in (0..line_num).rev() {
151 if let Some(class_ctx) = context_map.get(i).and_then(|c| c.clone()) {
152 if class_ctx.is_class {
153 let full_name = format!("{}.{}", class_ctx.name, ctx.name);
154 return (full_name, ctx.signature);
155 }
156 }
157 }
158 }
159 return (ctx.name.clone(), ctx.signature);
160 }
161 }
162
163 ("__module__".to_string(), None)
165 }
166
167 fn extract_from_docstring_with_sig(
169 content: &str,
170 module: &str,
171 function: &str,
172 signature: Option<&str>,
173 ) -> Vec<DocTest> {
174 let mut results = Vec::new();
175 let lines: Vec<&str> = content.lines().collect();
176 let mut i = 0;
177
178 while i < lines.len() {
179 let line = lines[i].trim();
180
181 if let Some(input_start) = line.strip_prefix(">>>") {
183 let mut input_lines = vec![format!(">>>{}", input_start)];
184 i += 1;
185
186 while i < lines.len() {
188 let next_line = lines[i].trim();
189 if let Some(cont) = next_line.strip_prefix("...") {
190 input_lines.push(format!("...{}", cont));
191 i += 1;
192 } else {
193 break;
194 }
195 }
196
197 let mut expected_lines: Vec<&str> = Vec::new();
199 let base_indent = lines
201 .get(i.saturating_sub(1))
202 .map(|l| l.len() - l.trim_start().len())
203 .unwrap_or(0);
204
205 while i < lines.len() {
206 let next_line = lines[i];
207 let trimmed = next_line.trim();
208
209 if trimmed.starts_with(">>>") {
211 break;
212 }
213
214 if is_prose_continuation(trimmed) {
216 break;
217 }
218
219 if trimmed.is_empty() && !expected_lines.is_empty() {
221 let mut j = i + 1;
223 while j < lines.len() && lines[j].trim().is_empty() {
224 j += 1;
225 }
226 if j >= lines.len() || lines[j].trim().starts_with(">>>") {
227 break;
228 }
229 }
232
233 if !trimmed.is_empty() || !expected_lines.is_empty() {
234 let stripped = if next_line.len() > base_indent {
236 &next_line
237 [base_indent.min(next_line.len() - next_line.trim_start().len())..]
238 } else {
239 trimmed
240 };
241 expected_lines.push(stripped.trim_end());
242 }
243 i += 1;
244 }
245
246 while expected_lines.last().is_some_and(|l| l.is_empty()) {
248 expected_lines.pop();
249 }
250
251 let input = input_lines.join("\n");
252 let expected = expected_lines.join("\n");
253
254 let mut doctest = DocTest::new(module, function, input, expected);
256 if let Some(sig) = signature {
257 doctest = doctest.with_signature(sig);
258 }
259 results.push(doctest);
260 } else {
261 i += 1;
262 }
263 }
264
265 results
266 }
267
    /// Same as `extract_from_docstring_with_sig` with no signature attached to the examples.
    // No caller is visible in this file, hence the lint exemption.
    #[allow(dead_code)]
    fn extract_from_docstring(content: &str, module: &str, function: &str) -> Vec<DocTest> {
        Self::extract_from_docstring_with_sig(content, module, function, None)
    }
273
274 pub fn parse_file(&self, path: &Path, module: &str) -> Result<Vec<DocTest>> {
276 let source = std::fs::read_to_string(path).map_err(|e| crate::Error::Io {
277 path: Some(path.to_path_buf()),
278 source: e,
279 })?;
280 Ok(self.parse_source(&source, module))
281 }
282
283 pub fn parse_directory(
290 &self,
291 dir: &Path,
292 source: &str,
293 version: &str,
294 ) -> Result<DocTestCorpus> {
295 let mut corpus = DocTestCorpus::new(source, version);
296
297 for entry in WalkDir::new(dir)
298 .follow_links(true)
299 .into_iter()
300 .filter_map(|e| e.ok())
301 {
302 let path = entry.path();
303 if path.extension().is_some_and(|ext| ext == "py") {
304 let module = path_to_module(dir, path);
305 let doctests = self.parse_file(path, &module)?;
306 for dt in doctests {
307 corpus.push(dt);
308 }
309 }
310 }
311
312 Ok(corpus)
313 }
314}
315
/// Converts a file path into a dotted module name.
///
/// The path is taken relative to `base` when it lies under it (otherwise it is used as is),
/// its extension is removed, platform path separators become `.`, and any trailing
/// `.__init__` is dropped so that a package's `__init__` file maps to the package itself.
fn path_to_module(base: &Path, path: &Path) -> String {
    let relative = match path.strip_prefix(base) {
        Ok(inner) => inner,
        Err(_) => path,
    };
    let dotted = relative
        .with_extension("")
        .to_string_lossy()
        .replace(std::path::MAIN_SEPARATOR, ".");
    dotted.trim_end_matches(".__init__").to_owned()
}
327
/// Field-list prefixes (`:param`, `:return`, ...). A line starting with any of them is
/// classified as prose by `is_prose_continuation`.
const DOC_MARKERS: &[&str] = &[
    ":param", ":return", ":raises", ":type", ":rtype", ":arg", ":args:", ":keyword", ":ivar",
    ":cvar",
];
355
/// Sentence openers and section headings checked by `matches_prose_starter`. Word entries
/// include their trailing space and headings their trailing colon, so only whole words or
/// headings at the start of a line match.
const PROSE_STARTERS: &[&str] = &[
    "The ",
    "This ",
    "Note:",
    "Warning:",
    "Example:",
    "Examples:",
    "See ",
    "If ",
    "When ",
    "For ",
    "An ",
    "A ",
    "It ",
    "Returns ",
    "Raises ",
    "Args:",
    "Arguments:",
    "Parameters:",
    "By ",
    "Use ",
    "Set ",
    "Get ",
    "You ",
    "We ",
    "They ",
];
384
/// First words that `is_sentence_heuristic` never treats as the start of a sentence, even
/// though they are capitalised like one.
const PYTHON_CONSTANTS: &[&str] = &["True", "False", "None", "Traceback"];
387
/// Reports whether `word` is shaped like a Python exception class name.
///
/// All three conditions must hold: the word ends in `Error`, `Exception` or `Warning`; it is
/// longer than seven bytes; and it contains at least two uppercase characters. Plain words
/// such as `Warning` or `Exception` therefore do not qualify, while `ValueError` does.
fn is_python_exception_word(word: &str) -> bool {
    const SUFFIXES: [&str; 3] = ["Error", "Exception", "Warning"];

    if word.len() <= 7 {
        return false;
    }
    if !SUFFIXES.iter().any(|suffix| word.ends_with(suffix)) {
        return false;
    }
    // A second uppercase character exists exactly when the uppercase count is at least two.
    word.chars().filter(|c| c.is_uppercase()).nth(1).is_some()
}
395
396fn matches_prose_starter(trimmed: &str) -> Option<bool> {
398 if PROSE_STARTERS.iter().any(|s| trimmed.starts_with(s)) {
399 Some(!trimmed.contains(">>>"))
400 } else {
401 None
402 }
403}
404
405fn is_sentence_heuristic(trimmed: &str, first_word: &str) -> bool {
407 let chars: Vec<char> = trimmed.chars().collect();
408 if chars.len() < 2 || !chars[0].is_uppercase() || !chars[1].is_lowercase() {
409 return false;
410 }
411
412 if PYTHON_CONSTANTS.contains(&first_word) {
413 return false;
414 }
415
416 if first_word.chars().all(|c| c.is_alphanumeric() || c == '_')
418 && trimmed.split_whitespace().count() == 1
419 {
420 return false;
421 }
422
423 if trimmed.split_whitespace().count() > 2 {
425 return !looks_like_code_output(trimmed);
426 }
427
428 false
429}
430
/// Reports whether `trimmed` has a surface feature of interpreter output rather than prose:
/// a `>>>` anywhere in the line, a leading `...`, or an opening `<`, `[`, `{` or `(`.
fn looks_like_code_output(trimmed: &str) -> bool {
    if trimmed.contains(">>>") || trimmed.starts_with("...") {
        return true;
    }
    matches!(trimmed.chars().next(), Some('<' | '[' | '{' | '('))
}
440
441#[must_use]
446pub fn is_prose_continuation(line: &str) -> bool {
447 let trimmed = line.trim();
448 if trimmed.is_empty() {
449 return false;
450 }
451
452 let first_word: &str = trimmed
453 .split(|c: char| c == ':' || c.is_whitespace())
454 .next()
455 .unwrap_or("");
456
457 if is_python_exception_word(first_word) {
459 return false;
460 }
461
462 if DOC_MARKERS.iter().any(|m| trimmed.starts_with(m)) {
464 return true;
465 }
466
467 if let Some(is_prose) = matches_prose_starter(trimmed) {
469 return is_prose;
470 }
471
472 if trimmed.starts_with(".. ") || trimmed.starts_with(">>>") {
474 return false;
475 }
476
477 is_sentence_heuristic(trimmed, first_word)
479}
480
/// Unit and property tests for prose detection, module naming and doctest extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Lines opening with a known prose starter are classified as prose.
    #[test]
    fn test_prose_detection_sentence() {
        assert!(is_prose_continuation("The stdout argument is not allowed."));
        assert!(is_prose_continuation("This function returns a value."));
        assert!(is_prose_continuation("Note: This is important."));
        assert!(is_prose_continuation("Warning: Use with caution."));
    }

    /// Field-list markers such as `:param` are classified as prose.
    #[test]
    fn test_prose_detection_docstring_markers() {
        assert!(is_prose_continuation(":param x: the input value"));
        assert!(is_prose_continuation(":return: the computed result"));
        assert!(is_prose_continuation(":raises ValueError: if invalid"));
        assert!(is_prose_continuation(":type x: int"));
    }

    /// Common sentence openers are classified as prose.
    #[test]
    fn test_prose_detection_common_starters() {
        assert!(is_prose_continuation("If you use this argument..."));
        assert!(is_prose_continuation("When the value is negative..."));
        assert!(is_prose_continuation("For more information..."));
        assert!(is_prose_continuation("Returns the computed value."));
    }

    /// Typical interpreter output must not be mistaken for prose.
    #[test]
    fn test_prose_detection_false_negatives() {
        assert!(!is_prose_continuation("True"));
        assert!(!is_prose_continuation("False"));
        assert!(!is_prose_continuation("None"));
        assert!(!is_prose_continuation("123"));
        assert!(!is_prose_continuation("'hello world'"));
        assert!(!is_prose_continuation("b'bytes'"));
        assert!(!is_prose_continuation("[1, 2, 3]"));
        assert!(!is_prose_continuation("{'key': 'value'}"));
        assert!(!is_prose_continuation("(0, '/bin/ls')"));
        assert!(!is_prose_continuation("Point(x=11, y=22)"));
        assert!(!is_prose_continuation(""));
    }

    /// Single capitalised identifiers, traceback headers and blank input are not prose.
    #[test]
    fn test_prose_detection_edge_cases() {
        assert!(!is_prose_continuation("ValueError"));
        assert!(!is_prose_continuation("MyClass"));
        assert!(!is_prose_continuation("Traceback (most recent call last):"));
        assert!(!is_prose_continuation(""));
        assert!(!is_prose_continuation(" "));
    }

    /// Prose between two examples must not leak into the first example's expected output.
    #[test]
    fn test_extract_with_prose_contamination() {
        let parser = DocTestParser::new();
        let source = r#"
def check_output():
 """
 >>> check_output(["ls", "-l"])
 b'output\n'

 The stdout argument is not allowed as it is used internally.
 To capture standard error, use stderr=STDOUT.

 >>> check_output(["echo", "hi"])
 b'hi\n'
 """
 pass
"#;
        let doctests = parser.parse_source(source, "test");
        assert_eq!(doctests.len(), 2);
        assert_eq!(doctests[0].expected, "b'output\\n'");
        assert!(!doctests[0].expected.contains("stdout argument"));
        assert_eq!(doctests[1].expected, "b'hi\\n'");
    }

    /// A file directly under the base directory maps to its stem.
    #[test]
    fn test_path_to_module_simple() {
        let base = Path::new("/lib");
        let path = Path::new("/lib/os.py");
        assert_eq!(path_to_module(base, path), "os");
    }

    /// Nested directories become dotted package components.
    #[test]
    fn test_path_to_module_nested() {
        let base = Path::new("/lib");
        let path = Path::new("/lib/os/path.py");
        assert_eq!(path_to_module(base, path), "os.path");
    }

    /// A package's `__init__.py` maps to the package name itself.
    #[test]
    fn test_path_to_module_init() {
        let base = Path::new("/lib");
        let path = Path::new("/lib/collections/__init__.py");
        assert_eq!(path_to_module(base, path), "collections");
    }

    /// A single prompt with one line of output yields one doctest.
    #[test]
    fn test_extract_simple() {
        let parser = DocTestParser::new();
        let source = r#"
def foo():
 """
 >>> 1 + 1
 2
 """
 pass
"#;
        let doctests = parser.parse_source(source, "test");
        assert_eq!(doctests.len(), 1);
        assert_eq!(doctests[0].input, ">>> 1 + 1");
        assert_eq!(doctests[0].expected, "2");
    }

    /// `...` continuation lines are folded into the input of the preceding prompt.
    #[test]
    fn test_extract_multiline_input() {
        let parser = DocTestParser::new();
        let source = r#"
def foo():
 """
 >>> x = (
 ... 1 + 2
 ... )
 >>> x
 3
 """
 pass
"#;
        let doctests = parser.parse_source(source, "test");
        assert_eq!(doctests.len(), 2);
        assert_eq!(doctests[0].input, ">>> x = (\n... 1 + 2\n... )");
        assert_eq!(doctests[0].expected, "");
    }

    /// The signature includes the parameter list and the return annotation.
    #[test]
    fn test_extract_signature() {
        let parser = DocTestParser::new();
        let source = r#"
def add(a: int, b: int) -> int:
 """Add two numbers.

 >>> add(1, 2)
 3
 """
 return a + b
"#;
        let doctests = parser.parse_source(source, "math");
        assert_eq!(doctests.len(), 1);
        assert_eq!(doctests[0].function, "add");
        assert_eq!(
            doctests[0].signature,
            Some("def add(a: int, b: int) -> int".to_string())
        );
    }

    /// Without a return annotation the signature ends at the closing parenthesis.
    #[test]
    fn test_extract_signature_no_return_type() {
        let parser = DocTestParser::new();
        let source = r#"
def greet(name: str):
 """Greet someone.

 >>> greet("world")
 'Hello, world!'
 """
 return f"Hello, {name}!"
"#;
        let doctests = parser.parse_source(source, "hello");
        assert_eq!(doctests.len(), 1);
        assert_eq!(
            doctests[0].signature,
            Some("def greet(name: str)".to_string())
        );
    }

    /// A docstring outside any definition is attributed to `__module__` with no signature.
    #[test]
    fn test_module_doctest_no_signature() {
        let parser = DocTestParser::new();
        let source = r#"
"""Module docstring.

>>> 1 + 1
2
"""
"#;
        let doctests = parser.parse_source(source, "mymodule");
        assert_eq!(doctests.len(), 1);
        assert_eq!(doctests[0].function, "__module__");
        assert!(doctests[0].signature.is_none());
    }

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        // Whitespace-only input is never prose.
        #[test]
        fn prop_empty_never_prose(s in "\\s*") {
            assert!(!is_prose_continuation(&s));
        }

        // Python literal constants are never prose.
        #[test]
        fn prop_python_literals_never_prose(literal in prop_oneof![
            Just("True"),
            Just("False"),
            Just("None"),
        ]) {
            assert!(!is_prose_continuation(literal));
        }

        // The final line of a traceback (`SomeError: message`) is output, not prose.
        #[test]
        fn prop_exception_lines_never_prose(exc in prop_oneof![
            Just("ValueError: invalid input"),
            Just("TypeError: expected str"),
            Just("ZeroDivisionError: division by zero"),
            Just("KeyError: 'missing'"),
            Just("IndexError: out of range"),
            Just("RuntimeError: something went wrong"),
        ]) {
            assert!(!is_prose_continuation(exc), "Exception detected as prose: {}", exc);
        }

        // Field-list markers are always prose.
        #[test]
        fn prop_docstring_markers_are_prose(marker in prop_oneof![
            Just(":param x: value"),
            Just(":return: result"),
            Just(":raises ValueError: msg"),
            Just(":type x: int"),
        ]) {
            assert!(is_prose_continuation(marker));
        }

        // Representative reprs of containers, objects, numbers and strings are kept as output.
        #[test]
        fn prop_code_output_preserved(output in prop_oneof![
            Just("[1, 2, 3]"),
            Just("{'a': 1}"),
            Just("(1, 2)"),
            Just("<object at 0x...>"),
            Just("123"),
            Just("'string'"),
        ]) {
            assert!(!is_prose_continuation(output));
        }

        // The classifier returns the same answer when called twice on the same arbitrary input.
        #[test]
        fn prop_deterministic(s in ".*") {
            let r1 = is_prose_continuation(&s);
            let r2 = is_prose_continuation(&s);
            assert_eq!(r1, r2);
        }
    }
}