1use crate::doctest_extractor::Doctest;
31use anyhow::Result;
32use serde::{Deserialize, Serialize};
33
/// Assertions extracted from a single pytest source file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PytestResult {
    /// Label of the file the assertions came from, as passed to
    /// `PytestExtractor::extract_to_result`.
    pub source: String,
    /// Extracted `call(...) == expected` assertions, in source-line order.
    pub assertions: Vec<Doctest>,
}
42
/// Extracts simple `assert call(...) == expected` statements from pytest files.
#[derive(Debug, Clone)]
pub struct PytestExtractor {
    /// Strict test-file handling flag. Enabled by default (both `new()` and
    /// `default()`). NOTE(review): stored, but not read by the extraction code
    /// visible in this file.
    pub strict_test_files: bool,
}

/// Hand-written instead of derived so that `PytestExtractor::default()` yields
/// the same configuration as `PytestExtractor::new()`. A derived impl would set
/// `strict_test_files` to `false`.
impl Default for PytestExtractor {
    fn default() -> Self {
        Self {
            strict_test_files: true,
        }
    }
}
49
50impl PytestExtractor {
51 pub fn new() -> Self {
53 Self {
54 strict_test_files: true,
55 }
56 }
57
58 pub fn with_strict_test_files(mut self, strict: bool) -> Self {
60 self.strict_test_files = strict;
61 self
62 }
63
64 pub fn extract(&self, source: &str) -> Result<Vec<Doctest>> {
66 let mut assertions = Vec::new();
67 let lines: Vec<&str> = source.lines().collect();
68
69 let mut current_function: Option<String> = None;
70
71 for (line_num, line) in lines.iter().enumerate() {
72 let trimmed = line.trim();
73
74 if trimmed.starts_with("def test_") {
76 if let Some(name) = Self::extract_function_name(trimmed) {
77 current_function = Some(name);
78 }
79 } else if trimmed.starts_with("def ") && !trimmed.starts_with("def test_") {
80 current_function = None;
82 }
83
84 if trimmed.starts_with("assert ") {
86 if let Some(doctest) = self.parse_assert(trimmed, line_num + 1, ¤t_function) {
87 assertions.push(doctest);
88 }
89 }
90 }
91
92 Ok(assertions)
93 }
94
95 fn extract_function_name(line: &str) -> Option<String> {
97 let after_def = line.strip_prefix("def ")?.trim();
98 let paren_idx = after_def.find('(')?;
99 Some(after_def[..paren_idx].to_string())
100 }
101
102 fn parse_assert(
104 &self,
105 line: &str,
106 line_num: usize,
107 _current_function: &Option<String>,
108 ) -> Option<Doctest> {
109 let assertion = line.strip_prefix("assert ")?.trim();
111
112 if self.is_complex_assertion(assertion) {
114 return None;
115 }
116
117 let eq_idx = assertion.find(" == ")?;
119 let left = assertion[..eq_idx].trim();
120 let right = assertion[eq_idx + 4..].trim();
121
122 if !left.contains('(') || !left.contains(')') {
124 return None;
125 }
126
127 let func_name = self.extract_called_function(left)?;
129
130 let expected = self.clean_expected(right);
132
133 Some(Doctest {
134 function: func_name,
135 input: left.to_string(),
136 expected,
137 line: line_num,
138 })
139 }
140
141 fn is_complex_assertion(&self, assertion: &str) -> bool {
143 if assertion.contains("pytest.") {
145 return true;
146 }
147
148 if assertion.contains("approx(") {
150 return true;
151 }
152
153 if assertion.contains(" in ") && !assertion.contains(" == ") {
155 return true;
156 }
157
158 if assertion.contains(" is ") && !assertion.contains(" == ") {
160 return true;
161 }
162
163 if assertion.starts_with("not ") {
165 return true;
166 }
167
168 if assertion.contains(" and ") || assertion.contains(" or ") {
170 return true;
171 }
172
173 if assertion.contains("lambda") {
175 return true;
176 }
177
178 if assertion.contains("isinstance(") || assertion.contains("type(") {
180 return true;
181 }
182
183 false
184 }
185
186 fn extract_called_function(&self, call_expr: &str) -> Option<String> {
188 let paren_idx = call_expr.find('(')?;
189 let func_part = &call_expr[..paren_idx];
190
191 if let Some(dot_idx) = func_part.rfind('.') {
193 Some(func_part[dot_idx + 1..].to_string())
194 } else {
195 Some(func_part.to_string())
196 }
197 }
198
199 fn clean_expected(&self, expected: &str) -> String {
201 let mut result = expected.to_string();
202
203 if let Some(hash_idx) = result.find('#') {
205 result = result[..hash_idx].trim().to_string();
206 }
207
208 result = result.trim_end_matches(',').trim().to_string();
210
211 result
212 }
213
214 pub fn extract_to_result(&self, source: &str, filename: &str) -> Result<PytestResult> {
216 let assertions = self.extract(source)?;
217 Ok(PytestResult {
218 source: filename.to_string(),
219 assertions,
220 })
221 }
222}
223
#[cfg(test)]
mod tests {
    // Unit tests for `PytestExtractor`. Each test feeds a small pytest snippet,
    // written as a raw string, through `extract` or `extract_to_result`. It then
    // checks which assertions are kept and how their fields are populated.
    use super::*;

    #[test]
    fn test_extract_simple_assert_eq() {
        let source = r#"
def test_square():
    assert square(4) == 16
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].function, "square");
        assert_eq!(assertions[0].input, "square(4)");
        assert_eq!(assertions[0].expected, "16");
    }

    #[test]
    fn test_extract_multiple_assertions() {
        let source = r#"
def test_square():
    assert square(4) == 16
    assert square(-3) == 9
    assert square(0) == 0
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // Results come back in source order.
        assert_eq!(assertions.len(), 3);
        assert_eq!(assertions[0].expected, "16");
        assert_eq!(assertions[1].expected, "9");
        assert_eq!(assertions[2].expected, "0");
    }

    #[test]
    fn test_extract_multiple_args() {
        let source = r#"
def test_add():
    assert add(1, 2) == 3
    assert add(-1, 1) == 0
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 2);
        assert_eq!(assertions[0].input, "add(1, 2)");
        assert_eq!(assertions[0].expected, "3");
    }

    #[test]
    fn test_extract_string_expected() {
        let source = r#"
def test_greet():
    assert greet("World") == "Hello, World!"
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // The expected value keeps its Python quotes verbatim.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "\"Hello, World!\"");
    }

    #[test]
    fn test_extract_list_expected() {
        let source = r#"
def test_range_list():
    assert range_list(3) == [0, 1, 2]
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "[0, 1, 2]");
    }

    #[test]
    fn test_extract_dict_expected() {
        let source = r#"
def test_make_dict():
    assert make_dict("a", 1) == {"a": 1}
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "{\"a\": 1}");
    }

    #[test]
    fn test_extract_boolean_expected() {
        let source = r#"
def test_is_even():
    assert is_even(4) == True
    assert is_even(3) == False
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 2);
        assert_eq!(assertions[0].expected, "True");
        assert_eq!(assertions[1].expected, "False");
    }

    #[test]
    fn test_skip_pytest_raises() {
        let source = r#"
def test_error():
    with pytest.raises(ValueError):
        divide(1, 0)
    assert divide(10, 2) == 5
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // The `with pytest.raises` block is not an assert; only the plain
        // `==` assertion after it is extracted.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "divide(10, 2)");
    }

    #[test]
    fn test_skip_pytest_approx() {
        let source = r#"
def test_float():
    assert divide(10, 3) == pytest.approx(3.333, rel=0.01)
    assert multiply(2, 3) == 6
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // The approx comparison is skipped as complex.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "multiply(2, 3)");
    }

    #[test]
    fn test_skip_complex_and_or() {
        let source = r#"
def test_complex():
    assert foo(1) == 1 and bar(2) == 2
    assert simple(3) == 3
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // Boolean combinations of comparisons are skipped.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "simple(3)");
    }

    #[test]
    fn test_skip_isinstance() {
        let source = r#"
def test_types():
    assert isinstance(foo(), int)
    assert bar() == 42
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "bar()");
    }

    #[test]
    fn test_skip_in_operator() {
        let source = r#"
def test_membership():
    assert 1 in get_list()
    assert get_first() == 1
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "get_first()");
    }

    #[test]
    fn test_method_call() {
        let source = r#"
def test_method():
    obj = MyClass()
    assert obj.compute(5) == 25
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // `function` is the bare method name; `input` keeps the receiver.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].function, "compute");
        assert_eq!(assertions[0].input, "obj.compute(5)");
    }

    #[test]
    fn test_line_numbers() {
        let source = r#"
def test_foo():
    x = 1
    assert foo(1) == 1
    y = 2
    assert foo(2) == 4
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // Line 1 is the empty line from the raw string's leading newline, so
        // the asserts are on 1-based lines 4 and 6.
        assert_eq!(assertions.len(), 2);
        assert_eq!(assertions[0].line, 4);
        assert_eq!(assertions[1].line, 6);
    }

    #[test]
    fn test_extract_to_result() {
        let source = r#"
def test_square():
    assert square(4) == 16
"#;

        let extractor = PytestExtractor::new();
        let result = extractor.extract_to_result(source, "test_math.py").unwrap();

        assert_eq!(result.source, "test_math.py");
        assert_eq!(result.assertions.len(), 1);
    }

    #[test]
    fn test_empty_source() {
        let source = "";
        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();
        assert!(assertions.is_empty());
    }

    #[test]
    fn test_no_assertions() {
        let source = r#"
def test_foo():
    x = compute()
    print(x)
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();
        assert!(assertions.is_empty());
    }

    #[test]
    fn test_non_function_call_lhs() {
        let source = r#"
def test_foo():
    assert x == 1
    assert foo() == 2
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        // A bare variable on the left-hand side is not a call, so it is skipped.
        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].input, "foo()");
    }

    #[test]
    fn test_trailing_comment() {
        let source = r#"
def test_foo():
    assert foo(1) == 1  # This tests the basic case
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "1");
    }

    #[test]
    fn test_none_expected() {
        let source = r#"
def test_returns_none():
    assert returns_none() == None
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "None");
    }

    #[test]
    fn test_tuple_expected() {
        let source = r#"
def test_tuple():
    assert get_tuple() == (1, 2, 3)
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "(1, 2, 3)");
    }

    #[test]
    fn test_float_expected() {
        let source = r#"
def test_float():
    assert divide(10, 4) == 2.5
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "2.5");
    }

    #[test]
    fn test_negative_number_expected() {
        let source = r#"
def test_negative():
    assert negate(5) == -5
"#;

        let extractor = PytestExtractor::new();
        let assertions = extractor.extract(source).unwrap();

        assert_eq!(assertions.len(), 1);
        assert_eq!(assertions[0].expected, "-5");
    }
}