//! testx/adapters/python.rs — test adapter for Python projects
//! (pytest / Django / unittest detection, command construction, output parsing).
1use std::path::Path;
2use std::process::Command;
3use std::time::Duration;
4
5use anyhow::Result;
6
7use super::util::duration_from_secs_safe;
8use super::{DetectionResult, TestAdapter, TestCase, TestRunResult, TestStatus, TestSuite};
9
/// Test adapter for Python projects; supports pytest, Django, and unittest.
pub struct PythonAdapter;
11
12impl Default for PythonAdapter {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl PythonAdapter {
19    pub fn new() -> Self {
20        Self
21    }
22
23    /// Check if pytest is the test framework
24    fn is_pytest(project_dir: &Path) -> bool {
25        // Check for pytest-specific files/configs
26        let markers = ["pytest.ini", ".pytest_cache", "conftest.py"];
27        for m in &markers {
28            if project_dir.join(m).exists() {
29                return true;
30            }
31        }
32
33        // Check pyproject.toml for pytest config
34        let pyproject = project_dir.join("pyproject.toml");
35        if pyproject.exists()
36            && let Ok(content) = std::fs::read_to_string(&pyproject)
37            && (content.contains("[tool.pytest") || content.contains("pytest"))
38        {
39            return true;
40        }
41
42        // Check setup.cfg
43        let setup_cfg = project_dir.join("setup.cfg");
44        if setup_cfg.exists()
45            && let Ok(content) = std::fs::read_to_string(&setup_cfg)
46            && content.contains("[tool:pytest]")
47        {
48            return true;
49        }
50
51        // Check tox.ini
52        let tox_ini = project_dir.join("tox.ini");
53        if tox_ini.exists()
54            && let Ok(content) = std::fs::read_to_string(&tox_ini)
55            && content.contains("[pytest]")
56        {
57            return true;
58        }
59
60        false
61    }
62
63    /// Check if Django is present
64    fn is_django(project_dir: &Path) -> bool {
65        project_dir.join("manage.py").exists()
66    }
67
68    /// Detect the Python package manager to use as a prefix
69    fn detect_runner_prefix(project_dir: &Path) -> Option<Vec<String>> {
70        if project_dir.join("uv.lock").exists() || {
71            let pyproject = project_dir.join("pyproject.toml");
72            pyproject.exists()
73                && std::fs::read_to_string(&pyproject)
74                    .map(|c| c.contains("[tool.uv]"))
75                    .unwrap_or(false)
76        } {
77            return Some(vec!["uv".into(), "run".into()]);
78        }
79        if project_dir.join("poetry.lock").exists() {
80            return Some(vec!["poetry".into(), "run".into()]);
81        }
82        if project_dir.join("pdm.lock").exists() {
83            return Some(vec!["pdm".into(), "run".into()]);
84        }
85        None
86    }
87
88    /// Find pytest binary inside a local virtualenv (.venv, venv, .env, env).
89    fn find_venv_pytest(project_dir: &Path) -> Option<std::path::PathBuf> {
90        for venv_dir in [".venv", "venv", ".env", "env"] {
91            let venv_pytest = project_dir.join(venv_dir).join("bin").join("pytest");
92            if venv_pytest.exists() {
93                return Some(venv_pytest);
94            }
95        }
96        None
97    }
98}
99
100impl TestAdapter for PythonAdapter {
101    fn name(&self) -> &str {
102        "Python"
103    }
104
105    fn check_runner(&self) -> Option<String> {
106        // If uv/poetry/pdm is the runner, check that instead
107        for runner in ["uv", "poetry", "pdm", "pytest", "python"] {
108            if which::which(runner).is_ok() {
109                return None;
110            }
111        }
112        Some("python".into())
113    }
114
115    fn detect(&self, project_dir: &Path) -> Option<DetectionResult> {
116        let has_python_files = project_dir.join("pyproject.toml").exists()
117            || project_dir.join("setup.py").exists()
118            || project_dir.join("setup.cfg").exists()
119            || project_dir.join("requirements.txt").exists()
120            || project_dir.join("Pipfile").exists();
121
122        if !has_python_files {
123            return None;
124        }
125
126        let framework = if Self::is_pytest(project_dir) {
127            "pytest"
128        } else if Self::is_django(project_dir) {
129            "django"
130        } else {
131            "unittest"
132        };
133
134        Some(DetectionResult {
135            language: "Python".into(),
136            framework: framework.into(),
137            confidence: if Self::is_pytest(project_dir) {
138                0.95
139            } else {
140                0.7
141            },
142        })
143    }
144
145    fn build_command(&self, project_dir: &Path, extra_args: &[String]) -> Result<Command> {
146        let prefix = Self::detect_runner_prefix(project_dir);
147        let is_pytest = Self::is_pytest(project_dir);
148        let is_django = Self::is_django(project_dir);
149
150        let mut cmd;
151
152        if let Some(prefix_args) = &prefix {
153            cmd = Command::new(&prefix_args[0]);
154            for arg in &prefix_args[1..] {
155                cmd.arg(arg);
156            }
157            if is_pytest {
158                cmd.arg("pytest");
159            } else if is_django {
160                cmd.arg("python").arg("-m").arg("django").arg("test");
161            } else {
162                cmd.arg("python").arg("-m").arg("unittest");
163            }
164        } else if is_pytest {
165            // Check for pytest inside a local virtualenv first
166            if let Some(venv_pytest) = Self::find_venv_pytest(project_dir) {
167                cmd = Command::new(venv_pytest);
168            } else {
169                cmd = Command::new("pytest");
170            }
171        } else if is_django {
172            cmd = Command::new("python");
173            cmd.arg("manage.py").arg("test");
174        } else {
175            cmd = Command::new("python");
176            cmd.arg("-m").arg("unittest");
177        }
178
179        // Add verbose flag for better output parsing (pytest)
180        if is_pytest && extra_args.is_empty() {
181            cmd.arg("-v");
182        }
183
184        for arg in extra_args {
185            cmd.arg(arg);
186        }
187
188        cmd.current_dir(project_dir);
189        Ok(cmd)
190    }
191
192    fn parse_output(&self, stdout: &str, stderr: &str, exit_code: i32) -> TestRunResult {
193        let combined = format!("{}\n{}", stdout, stderr);
194        let failure_messages = parse_pytest_failures(&combined);
195        let mut suites: Vec<TestSuite> = Vec::new();
196        let mut current_suite_name = String::from("tests");
197        let mut tests: Vec<TestCase> = Vec::new();
198
199        for line in combined.lines() {
200            let trimmed = line.trim();
201
202            // pytest verbose output: "test_file.py::TestClass::test_name PASSED"
203            // or: "test_file.py::test_name PASSED"
204            if let Some((test_path, status_str)) = parse_pytest_line(trimmed) {
205                let parts: Vec<&str> = test_path.split("::").collect();
206                let suite_name = parts.first().unwrap_or(&"tests").to_string();
207                let test_name = parts.last().unwrap_or(&"unknown").to_string();
208
209                // If suite changed, flush current tests
210                if suite_name != current_suite_name && !tests.is_empty() {
211                    suites.push(TestSuite {
212                        name: current_suite_name.clone(),
213                        tests: std::mem::take(&mut tests),
214                    });
215                }
216                current_suite_name = suite_name;
217
218                let status = match status_str.to_uppercase().as_str() {
219                    "PASSED" => TestStatus::Passed,
220                    "FAILED" => TestStatus::Failed,
221                    "SKIPPED" | "XFAIL" | "XPASS" => TestStatus::Skipped,
222                    "ERROR" => TestStatus::Failed,
223                    _ => TestStatus::Failed,
224                };
225
226                let error = if status == TestStatus::Failed {
227                    // Try full path first, then just test name
228                    failure_messages
229                        .get(&test_path)
230                        .or_else(|| failure_messages.get(&test_name))
231                        .map(|msg| super::TestError {
232                            message: msg.clone(),
233                            location: None,
234                        })
235                } else {
236                    None
237                };
238
239                tests.push(TestCase {
240                    name: test_name,
241                    status,
242                    duration: Duration::from_millis(0),
243                    error,
244                });
245            }
246        }
247
248        // Flush remaining tests
249        if !tests.is_empty() {
250            suites.push(TestSuite {
251                name: current_suite_name,
252                tests,
253            });
254        }
255
256        // If we couldn't parse any individual tests, create a summary suite from the summary line
257        if suites.is_empty() {
258            suites.push(parse_pytest_summary(&combined, exit_code));
259        }
260
261        // Try to parse total duration from pytest summary
262        let duration = parse_pytest_duration(&combined).unwrap_or(Duration::from_secs(0));
263
264        TestRunResult {
265            suites,
266            duration,
267            raw_exit_code: exit_code,
268        }
269    }
270}
271
/// Parse one pytest verbose result line, e.g.
/// "tests/test_foo.py::test_bar PASSED  [ 50%]".
///
/// Returns the `(test_path, status)` pair, or `None` when the line is not a
/// per-test result: the path part must contain "::" and the status word must
/// not be glued onto the end of an identifier.
fn parse_pytest_line(line: &str) -> Option<(String, String)> {
    const STATUSES: [&str; 6] = ["PASSED", "FAILED", "SKIPPED", "ERROR", "XFAIL", "XPASS"];
    STATUSES.iter().find_map(|&status| {
        let idx = line.rfind(status)?;
        // A status embedded in an identifier (no whitespace before it)
        // does not count as a result marker.
        if idx > 0 && !line.as_bytes()[idx - 1].is_ascii_whitespace() {
            return None;
        }
        let path = line[..idx].trim();
        path.contains("::")
            .then(|| (path.to_string(), status.to_string()))
    })
}
290
291/// Parse pytest summary line like "=== 5 passed, 2 failed in 0.32s ==="
292fn parse_pytest_summary(output: &str, exit_code: i32) -> TestSuite {
293    let mut passed = 0usize;
294    let mut failed = 0usize;
295    let mut skipped = 0usize;
296
297    for line in output.lines() {
298        let trimmed = line.trim().trim_matches('=').trim();
299        if trimmed.contains("passed") || trimmed.contains("failed") || trimmed.contains("error") {
300            // Parse "5 passed", "2 failed", etc.
301            for part in trimmed.split(',') {
302                let part = part.trim();
303                if let Some(n) = part
304                    .split_whitespace()
305                    .next()
306                    .and_then(|s| s.parse::<usize>().ok())
307                {
308                    if part.contains("passed") {
309                        passed = n;
310                    } else if part.contains("failed") || part.contains("error") {
311                        failed = n;
312                    } else if part.contains("skipped") {
313                        skipped = n;
314                    }
315                }
316            }
317        }
318    }
319
320    let mut tests = Vec::new();
321    for i in 0..passed {
322        tests.push(TestCase {
323            name: format!("test_{}", i + 1),
324            status: TestStatus::Passed,
325            duration: Duration::from_millis(0),
326            error: None,
327        });
328    }
329    for i in 0..failed {
330        tests.push(TestCase {
331            name: format!("failed_test_{}", i + 1),
332            status: TestStatus::Failed,
333            duration: Duration::from_millis(0),
334            error: None,
335        });
336    }
337    for i in 0..skipped {
338        tests.push(TestCase {
339            name: format!("skipped_test_{}", i + 1),
340            status: TestStatus::Skipped,
341            duration: Duration::from_millis(0),
342            error: None,
343        });
344    }
345
346    // If we still got nothing, infer from exit code
347    if tests.is_empty() {
348        tests.push(TestCase {
349            name: "test_suite".into(),
350            status: if exit_code == 0 {
351                TestStatus::Passed
352            } else {
353                TestStatus::Failed
354            },
355            duration: Duration::from_millis(0),
356            error: None,
357        });
358    }
359
360    TestSuite {
361        name: "tests".into(),
362        tests,
363    }
364}
365
366/// Parse pytest FAILURES section to extract error messages per test.
367/// Pytest output looks like:
368/// ```text
369/// =========================== FAILURES ===========================
370/// __________________ test_multiply __________________
371///
372///     def test_multiply():
373/// >       assert multiply(2, 3) == 7
374/// E       assert 6 == 7
375/// E       +  where 6 = multiply(2, 3)
376///
377/// tests/test_math.py:10: AssertionError
378/// =========================== short test summary info ===========================
379/// ```
380fn parse_pytest_failures(output: &str) -> std::collections::HashMap<String, String> {
381    let mut failures = std::collections::HashMap::new();
382    let lines: Vec<&str> = output.lines().collect();
383    let mut in_failures = false;
384
385    let mut i = 0;
386    while i < lines.len() {
387        let trimmed = lines[i].trim();
388
389        // Enter FAILURES section
390        if trimmed.contains("FAILURES") && trimmed.starts_with('=') {
391            in_failures = true;
392            i += 1;
393            continue;
394        }
395
396        // Exit FAILURES section
397        if in_failures
398            && trimmed.starts_with('=')
399            && (trimmed.contains("short test summary")
400                || trimmed.contains("passed")
401                || trimmed.contains("failed")
402                || trimmed.contains("error"))
403        {
404            break;
405        }
406
407        // Match test header: "__________________ test_name __________________"
408        if in_failures && trimmed.starts_with('_') && trimmed.ends_with('_') {
409            let test_name = trimmed.trim_matches('_').trim().to_string();
410            if !test_name.is_empty() {
411                let mut error_lines = Vec::new();
412                let mut location = None;
413                i += 1;
414                while i < lines.len() {
415                    let l = lines[i].trim();
416                    // Next test header or section boundary
417                    if (l.starts_with('_') && l.ends_with('_') && l.len() > 5)
418                        || (l.starts_with('=') && l.len() > 5)
419                    {
420                        break;
421                    }
422                    // Assertion lines start with "E"
423                    if l.starts_with("E ") || l.starts_with("E\t") {
424                        error_lines.push(l[1..].trim().to_string());
425                    }
426                    // Location line like "tests/test_math.py:10: AssertionError"
427                    if l.contains(".py:")
428                        && l.contains(':')
429                        && !l.starts_with('>')
430                        && !l.starts_with("E")
431                    {
432                        let parts: Vec<&str> = l.splitn(3, ':').collect();
433                        if parts.len() >= 2 {
434                            location = Some(format!("{}:{}", parts[0].trim(), parts[1].trim()));
435                        }
436                    }
437                    i += 1;
438                }
439                if !error_lines.is_empty() {
440                    let mut msg = error_lines.join(" | ");
441                    if let Some(loc) = location {
442                        msg = format!("{} ({})", msg, loc);
443                    }
444                    failures.insert(test_name, msg);
445                }
446                continue;
447            }
448        }
449        i += 1;
450    }
451    failures
452}
453
454/// Parse duration from pytest summary like "in 0.32s"
455fn parse_pytest_duration(output: &str) -> Option<Duration> {
456    for line in output.lines() {
457        if let Some(idx) = line.find(" in ") {
458            let after = &line[idx + 4..];
459            let num_str: String = after
460                .chars()
461                .take_while(|c| c.is_ascii_digit() || *c == '.')
462                .collect();
463            if let Ok(secs) = num_str.parse::<f64>() {
464                return Some(duration_from_secs_safe(secs));
465            }
466        }
467    }
468    None
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end parse of a verbose pytest run: per-test lines, a FAILURES
    // section with an assertion message, and the closing summary banner.
    #[test]
    fn parse_pytest_verbose_output() {
        let stdout = r#"
============================= test session starts ==============================
collected 4 items

tests/test_math.py::test_add PASSED                                      [ 25%]
tests/test_math.py::test_subtract PASSED                                 [ 50%]
tests/test_math.py::test_multiply FAILED                                 [ 75%]
tests/test_string.py::test_upper PASSED                                  [100%]

=================================== FAILURES ===================================
________________________________ test_multiply _________________________________

    def test_multiply():
>       assert multiply(2, 3) == 7
E       assert 6 == 7
E         +  where 6 = multiply(2, 3)

tests/test_math.py:10: AssertionError
=========================== short test summary info ============================
FAILED tests/test_math.py::test_multiply - assert 6 == 7
============================== 3 passed, 1 failed in 0.12s =====================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_tests(), 4);
        assert_eq!(result.total_passed(), 3);
        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());
        assert_eq!(result.suites.len(), 2); // two test files
        assert_eq!(result.duration, Duration::from_millis(120));

        // Verify error message was captured
        let failed: Vec<_> = result.suites.iter().flat_map(|s| s.failures()).collect();
        assert_eq!(failed.len(), 1);
        assert!(failed[0].error.is_some());
        assert!(
            failed[0]
                .error
                .as_ref()
                .unwrap()
                .message
                .contains("assert 6 == 7")
        );
    }

    // Summary-banner-only output (no verbose per-test lines): counts come
    // from the synthetic suite built by parse_pytest_summary.
    #[test]
    fn parse_pytest_all_pass() {
        let stdout = "========================= 5 passed in 0.32s =========================\n";
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_tests(), 5);
        assert_eq!(result.total_passed(), 5);
        assert!(result.is_success());
    }

    // SKIPPED per-test lines must not count against success.
    #[test]
    fn parse_pytest_with_skipped() {
        let stdout = r#"
tests/test_foo.py::test_a PASSED
tests/test_foo.py::test_b SKIPPED
tests/test_foo.py::test_c PASSED

========================= 2 passed, 1 skipped in 0.05s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_passed(), 2);
        assert_eq!(result.total_skipped(), 1);
        assert!(result.is_success());
    }

    // Class-based test paths ("file::Class::test") still parse per test.
    #[test]
    fn parse_pytest_class_based() {
        let stdout = r#"
tests/test_calc.py::TestCalculator::test_add PASSED
tests/test_calc.py::TestCalculator::test_div FAILED
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_tests(), 2);
        assert_eq!(result.total_passed(), 1);
        assert_eq!(result.total_failed(), 1);
    }

    // All three counters (passed/failed/skipped) read from one banner line.
    #[test]
    fn parse_pytest_summary_only() {
        let stdout = "===== 10 passed, 2 failed, 3 skipped in 1.50s =====\n";
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_passed(), 10);
        assert_eq!(result.total_failed(), 2);
        assert_eq!(result.total_skipped(), 3);
        assert_eq!(result.total_tests(), 15);
    }

    // Duration helper: parses "in N.NNs" and rejects text without one.
    #[test]
    fn parse_pytest_duration_extraction() {
        assert_eq!(
            parse_pytest_duration("=== 1 passed in 2.34s ==="),
            Some(Duration::from_millis(2340))
        );
        assert_eq!(parse_pytest_duration("no duration here"), None);
    }

    // Line helper: accepts verbose result lines, rejects everything else.
    #[test]
    fn parse_pytest_line_function() {
        assert_eq!(
            parse_pytest_line("tests/test_foo.py::test_bar PASSED                    [ 50%]"),
            Some(("tests/test_foo.py::test_bar".into(), "PASSED".into()))
        );
        assert_eq!(parse_pytest_line("collected 5 items"), None);
        assert_eq!(parse_pytest_line(""), None);
    }

    // A pytest config table in pyproject.toml yields high-confidence pytest.
    #[test]
    fn detect_in_pytest_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("pyproject.toml"),
            "[tool.pytest.ini_options]\n",
        )
        .unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "pytest");
        assert!(det.confidence > 0.9);
    }

    // No Python project files at all -> no detection.
    #[test]
    fn detect_no_python() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("main.go"), "package main\n").unwrap();
        let adapter = PythonAdapter::new();
        assert!(adapter.detect(dir.path()).is_none());
    }

    // manage.py (without pytest markers) -> Django framework.
    #[test]
    fn detect_django_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("requirements.txt"), "django\n").unwrap();
        std::fs::write(dir.path().join("manage.py"), "#!/usr/bin/env python\n").unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "django");
    }

    // Empty output with exit code 0 produces a single passing placeholder.
    #[test]
    fn parse_pytest_empty_output() {
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output("", "", 0);

        assert_eq!(result.total_tests(), 1);
        assert!(result.is_success());
    }

    #[test]
    fn parse_pytest_xfail_xpass() {
        let stdout = r#"
tests/test_edge.py::test_expected_fail XFAIL
tests/test_edge.py::test_unexpected_pass XPASS

========================= 2 xfailed in 0.05s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        // XFAIL and XPASS should be counted as skipped
        assert_eq!(result.total_skipped(), 2);
        assert!(result.is_success());
    }

    // Parametrized ids like "test_add[1-2-3]" parse as distinct tests.
    #[test]
    fn parse_pytest_parametrized() {
        let stdout = r#"
tests/test_math.py::test_add[1-2-3] PASSED
tests/test_math.py::test_add[0-0-0] PASSED
tests/test_math.py::test_add[-1-1-0] PASSED

========================= 3 passed in 0.01s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 0);

        assert_eq!(result.total_tests(), 3);
        assert_eq!(result.total_passed(), 3);
    }

    // ERROR status (e.g. fixture/setup error) is counted as a failure.
    #[test]
    fn parse_pytest_error_status() {
        let stdout = r#"
tests/test_math.py::test_setup ERROR

========================= 1 error in 0.10s =========================
"#;
        let adapter = PythonAdapter::new();
        let result = adapter.parse_output(stdout, "", 1);

        assert_eq!(result.total_failed(), 1);
        assert!(!result.is_success());
    }

    // A Pipfile alone is enough to detect a Python project.
    #[test]
    fn detect_pipfile_project() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("Pipfile"), "[packages]\n").unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.language, "Python");
    }

    #[test]
    fn detect_unittest_fallback() {
        // Has Python markers but no pytest/django markers
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("setup.py"),
            "from setuptools import setup\n",
        )
        .unwrap();
        let adapter = PythonAdapter::new();
        let det = adapter.detect(dir.path()).unwrap();
        assert_eq!(det.framework, "unittest");
        assert!(det.confidence < 0.8);
    }
}