Skip to main content

depyler_tooling/
pytest_extractor.rs

//! # Pytest Assertion Extractor for CITL Training Pipeline
//!
//! GH-174: Extracts simple `assert` statements from `test_*.py` files
//! as additional CITL training signal.
//!
//! ## Overview
//!
//! Many test files contain simple I/O assertions equivalent to doctests:
//!
//! ```python
//! def test_square():
//!     assert square(4) == 16
//!     assert square(-3) == 9
//! ```
//!
//! This module extracts these patterns into the same format as doctests.
//!
//! ## Scope
//!
//! Extract **only** simple patterns:
//! - `assert f(x) == y`
//! - `assert f(x, y) == z`
//! - `assert f(x) == [a, b, c]`
//!
//! **Ignore** complex patterns:
//! - Fixtures, mocks, parametrize
//! - Exception testing (`pytest.raises`)
//! - Approximate comparisons (`pytest.approx`)
29
30use crate::doctest_extractor::Doctest;
31use anyhow::Result;
32use serde::{Deserialize, Serialize};
33
34/// Result of extracting pytest assertions
35#[derive(Debug, Clone, Default, Serialize, Deserialize)]
36pub struct PytestResult {
37    /// Source file
38    pub source: String,
39    /// Extracted assertions as doctests
40    pub assertions: Vec<Doctest>,
41}
42
/// Extracts simple assert statements from pytest files
#[derive(Debug, Clone)]
pub struct PytestExtractor {
    /// Only extract from test_*.py files
    pub strict_test_files: bool,
}

impl Default for PytestExtractor {
    /// Returns the same configuration as `PytestExtractor::new()`.
    ///
    /// A derived `Default` would leave `strict_test_files` at `false`, which
    /// disagrees with the documented default settings produced by `new()`;
    /// implementing the trait by hand keeps both construction paths equal.
    fn default() -> Self {
        Self {
            strict_test_files: true,
        }
    }
}
49
50impl PytestExtractor {
51    /// Creates a new PytestExtractor with default settings
52    pub fn new() -> Self {
53        Self {
54            strict_test_files: true,
55        }
56    }
57
58    /// Configure whether to only extract from test_*.py files
59    pub fn with_strict_test_files(mut self, strict: bool) -> Self {
60        self.strict_test_files = strict;
61        self
62    }
63
64    /// Extract all simple assertions from Python test source code
65    pub fn extract(&self, source: &str) -> Result<Vec<Doctest>> {
66        let mut assertions = Vec::new();
67        let lines: Vec<&str> = source.lines().collect();
68
69        let mut current_function: Option<String> = None;
70
71        for (line_num, line) in lines.iter().enumerate() {
72            let trimmed = line.trim();
73
74            // Track test function definitions
75            if trimmed.starts_with("def test_") {
76                if let Some(name) = Self::extract_function_name(trimmed) {
77                    current_function = Some(name);
78                }
79            } else if trimmed.starts_with("def ") && !trimmed.starts_with("def test_") {
80                // Non-test function, clear context
81                current_function = None;
82            }
83
84            // Look for simple assert statements
85            if trimmed.starts_with("assert ") {
86                if let Some(doctest) = self.parse_assert(trimmed, line_num + 1, &current_function) {
87                    assertions.push(doctest);
88                }
89            }
90        }
91
92        Ok(assertions)
93    }
94
95    /// Extract function name from a def line
96    fn extract_function_name(line: &str) -> Option<String> {
97        let after_def = line.strip_prefix("def ")?.trim();
98        let paren_idx = after_def.find('(')?;
99        Some(after_def[..paren_idx].to_string())
100    }
101
102    /// Parse an assert statement into a Doctest if it's a simple pattern
103    fn parse_assert(
104        &self,
105        line: &str,
106        line_num: usize,
107        _current_function: &Option<String>,
108    ) -> Option<Doctest> {
109        // Remove "assert " prefix
110        let assertion = line.strip_prefix("assert ")?.trim();
111
112        // Skip complex patterns
113        if self.is_complex_assertion(assertion) {
114            return None;
115        }
116
117        // Look for == comparison
118        let eq_idx = assertion.find(" == ")?;
119        let left = assertion[..eq_idx].trim();
120        let right = assertion[eq_idx + 4..].trim();
121
122        // Left side should be a function call
123        if !left.contains('(') || !left.contains(')') {
124            return None;
125        }
126
127        // Extract the function being tested (from the call)
128        let func_name = self.extract_called_function(left)?;
129
130        // Clean up right side (remove trailing comments, etc.)
131        let expected = self.clean_expected(right);
132
133        Some(Doctest {
134            function: func_name,
135            input: left.to_string(),
136            expected,
137            line: line_num,
138        })
139    }
140
141    /// Check if an assertion is too complex to extract
142    fn is_complex_assertion(&self, assertion: &str) -> bool {
143        // Skip pytest-specific patterns
144        if assertion.contains("pytest.") {
145            return true;
146        }
147
148        // Skip approximate comparisons
149        if assertion.contains("approx(") {
150            return true;
151        }
152
153        // Skip assertions with 'in' operator
154        if assertion.contains(" in ") && !assertion.contains(" == ") {
155            return true;
156        }
157
158        // Skip assertions with 'is' operator (identity checks)
159        if assertion.contains(" is ") && !assertion.contains(" == ") {
160            return true;
161        }
162
163        // Skip assertions with 'not' operator at the start
164        if assertion.starts_with("not ") {
165            return true;
166        }
167
168        // Skip multi-condition assertions
169        if assertion.contains(" and ") || assertion.contains(" or ") {
170            return true;
171        }
172
173        // Skip assertions with lambda
174        if assertion.contains("lambda") {
175            return true;
176        }
177
178        // Skip type checks
179        if assertion.contains("isinstance(") || assertion.contains("type(") {
180            return true;
181        }
182
183        false
184    }
185
186    /// Extract the function name being called
187    fn extract_called_function(&self, call_expr: &str) -> Option<String> {
188        let paren_idx = call_expr.find('(')?;
189        let func_part = &call_expr[..paren_idx];
190
191        // Handle method calls like obj.method()
192        if let Some(dot_idx) = func_part.rfind('.') {
193            Some(func_part[dot_idx + 1..].to_string())
194        } else {
195            Some(func_part.to_string())
196        }
197    }
198
199    /// Clean up the expected value
200    fn clean_expected(&self, expected: &str) -> String {
201        let mut result = expected.to_string();
202
203        // Remove trailing comments
204        if let Some(hash_idx) = result.find('#') {
205            result = result[..hash_idx].trim().to_string();
206        }
207
208        // Remove trailing comma (from tuple unpacking)
209        result = result.trim_end_matches(',').trim().to_string();
210
211        result
212    }
213
214    /// Extract assertions to the same format as doctest results
215    pub fn extract_to_result(&self, source: &str, filename: &str) -> Result<PytestResult> {
216        let assertions = self.extract(source)?;
217        Ok(PytestResult {
218            source: filename.to_string(),
219            assertions,
220        })
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // Behaviour specification for GH-174 (originally written as red tests)
    // =========================================================================

    /// Runs a default-configured extractor over `source` and returns every
    /// extracted assertion.
    fn extract_all(source: &str) -> Vec<Doctest> {
        PytestExtractor::new().extract(source).unwrap()
    }

    /// Like `extract_all`, but additionally requires that exactly one
    /// assertion was extracted and hands that one back.
    fn extract_one(source: &str) -> Doctest {
        let mut found = extract_all(source);
        assert_eq!(found.len(), 1);
        found.remove(0)
    }

    #[test]
    fn test_extract_simple_assert_eq() {
        let only = extract_one(
            r#"
def test_square():
    assert square(4) == 16
"#,
        );

        assert_eq!(only.function, "square");
        assert_eq!(only.input, "square(4)");
        assert_eq!(only.expected, "16");
    }

    #[test]
    fn test_extract_multiple_assertions() {
        let found = extract_all(
            r#"
def test_square():
    assert square(4) == 16
    assert square(-3) == 9
    assert square(0) == 0
"#,
        );

        // Comparing the whole list checks both the count and each value.
        let expected: Vec<&str> = found.iter().map(|a| a.expected.as_str()).collect();
        assert_eq!(expected, ["16", "9", "0"]);
    }

    #[test]
    fn test_extract_multiple_args() {
        let found = extract_all(
            r#"
def test_add():
    assert add(1, 2) == 3
    assert add(-1, 1) == 0
"#,
        );

        assert_eq!(found.len(), 2);
        let first = &found[0];
        assert_eq!(first.input, "add(1, 2)");
        assert_eq!(first.expected, "3");
    }

    #[test]
    fn test_extract_string_expected() {
        let only = extract_one(
            r#"
def test_greet():
    assert greet("World") == "Hello, World!"
"#,
        );

        assert_eq!(only.expected, "\"Hello, World!\"");
    }

    #[test]
    fn test_extract_list_expected() {
        let only = extract_one(
            r#"
def test_range_list():
    assert range_list(3) == [0, 1, 2]
"#,
        );

        assert_eq!(only.expected, "[0, 1, 2]");
    }

    #[test]
    fn test_extract_dict_expected() {
        let only = extract_one(
            r#"
def test_make_dict():
    assert make_dict("a", 1) == {"a": 1}
"#,
        );

        assert_eq!(only.expected, "{\"a\": 1}");
    }

    #[test]
    fn test_extract_boolean_expected() {
        let found = extract_all(
            r#"
def test_is_even():
    assert is_even(4) == True
    assert is_even(3) == False
"#,
        );

        let expected: Vec<&str> = found.iter().map(|a| a.expected.as_str()).collect();
        assert_eq!(expected, ["True", "False"]);
    }

    #[test]
    fn test_skip_pytest_raises() {
        // Only the plain assertion after the `with` block should survive.
        let only = extract_one(
            r#"
def test_error():
    with pytest.raises(ValueError):
        divide(1, 0)
    assert divide(10, 2) == 5
"#,
        );

        assert_eq!(only.input, "divide(10, 2)");
    }

    #[test]
    fn test_skip_pytest_approx() {
        // Only the exact comparison should survive.
        let only = extract_one(
            r#"
def test_float():
    assert divide(10, 3) == pytest.approx(3.333, rel=0.01)
    assert multiply(2, 3) == 6
"#,
        );

        assert_eq!(only.input, "multiply(2, 3)");
    }

    #[test]
    fn test_skip_complex_and_or() {
        // The multi-condition assertion is skipped, the simple one is kept.
        let only = extract_one(
            r#"
def test_complex():
    assert foo(1) == 1 and bar(2) == 2
    assert simple(3) == 3
"#,
        );

        assert_eq!(only.input, "simple(3)");
    }

    #[test]
    fn test_skip_isinstance() {
        let only = extract_one(
            r#"
def test_types():
    assert isinstance(foo(), int)
    assert bar() == 42
"#,
        );

        assert_eq!(only.input, "bar()");
    }

    #[test]
    fn test_skip_in_operator() {
        let only = extract_one(
            r#"
def test_membership():
    assert 1 in get_list()
    assert get_first() == 1
"#,
        );

        assert_eq!(only.input, "get_first()");
    }

    #[test]
    fn test_method_call() {
        let only = extract_one(
            r#"
def test_method():
    obj = MyClass()
    assert obj.compute(5) == 25
"#,
        );

        assert_eq!(only.function, "compute");
        assert_eq!(only.input, "obj.compute(5)");
    }

    #[test]
    fn test_line_numbers() {
        // The fixture starts with an empty line, so `def` sits on line 2.
        let found = extract_all(
            r#"
def test_foo():
    x = 1
    assert foo(1) == 1
    y = 2
    assert foo(2) == 4
"#,
        );

        let lines: Vec<usize> = found.iter().map(|a| a.line).collect();
        assert_eq!(lines, [4, 6]);
    }

    #[test]
    fn test_extract_to_result() {
        let source = r#"
def test_square():
    assert square(4) == 16
"#;

        let result = PytestExtractor::new()
            .extract_to_result(source, "test_math.py")
            .unwrap();

        assert_eq!(result.source, "test_math.py");
        assert_eq!(result.assertions.len(), 1);
    }

    #[test]
    fn test_empty_source() {
        assert!(extract_all("").is_empty());
    }

    #[test]
    fn test_no_assertions() {
        let found = extract_all(
            r#"
def test_foo():
    x = compute()
    print(x)
"#,
        );

        assert!(found.is_empty());
    }

    #[test]
    fn test_non_function_call_lhs() {
        // A bare variable on the left-hand side is not a function call.
        let only = extract_one(
            r#"
def test_foo():
    assert x == 1
    assert foo() == 2
"#,
        );

        assert_eq!(only.input, "foo()");
    }

    #[test]
    fn test_trailing_comment() {
        let only = extract_one(
            r#"
def test_foo():
    assert foo(1) == 1  # This tests the basic case
"#,
        );

        assert_eq!(only.expected, "1");
    }

    #[test]
    fn test_none_expected() {
        let only = extract_one(
            r#"
def test_returns_none():
    assert returns_none() == None
"#,
        );

        assert_eq!(only.expected, "None");
    }

    #[test]
    fn test_tuple_expected() {
        let only = extract_one(
            r#"
def test_tuple():
    assert get_tuple() == (1, 2, 3)
"#,
        );

        assert_eq!(only.expected, "(1, 2, 3)");
    }

    #[test]
    fn test_float_expected() {
        let only = extract_one(
            r#"
def test_float():
    assert divide(10, 4) == 2.5
"#,
        );

        assert_eq!(only.expected, "2.5");
    }

    #[test]
    fn test_negative_number_expected() {
        let only = extract_one(
            r#"
def test_negative():
    assert negate(5) == -5
"#,
        );

        assert_eq!(only.expected, "-5");
    }
}