// offload/orchestrator.rs
1//! Test execution engine and orchestration.
2//!
3//! This module contains the core execution logic that coordinates test
4//! discovery, distribution across sandboxes, execution, retry handling,
5//! and result collection.
6//!
7//! # Architecture
8//!
9//! ```text
10//!   Framework                 Scheduler                Provider
11//!       │                         │                        │
12//!       │ discover()              │                        │
13//!       ▼                         │                        │
14//!  Vec<TestRecord>                │                        │
15//!       │                         │                        │
16//!       │ expand to TestInstances │                        │
17//!       ▼                         │                        │
18//!  Vec<TestInstance> ────────────►│ schedule_random()      │
19//!                                 ▼                        │
20//!                        Vec<Vec<TestInstance>> (batches)  │
21//!                                 │                        │
22//!                                 │    create_sandbox() ──►│
23//!                                 │                        ▼
24//!                                 │                     Sandbox
25//!                                 │                        │
26//!                                 └────────┬───────────────┘
27//!                                          ▼
28//!                                     TestRunner
29//!                                          │
30//!   Framework ◄─── produce_command() ──────┤
31//!       │                                  │
32//!       │                        Sandbox.exec(cmd)
33//!       │                                  │
34//!       │ parse_results() ◄─────── ExecResult
35//!       ▼
36//!  Vec<TestResult> ──► TestRecord.record_result()
37//! ```
38//!
39//! # Execution Flow
40//!
41//! 1. **Discovery**: Find tests using the configured framework
42//! 2. **Expansion**: Create parallel retry instances for each test
43//! 3. **Scheduling**: Distribute test instances into batches across sandboxes
44//! 4. **Execution**: Run test batches in parallel sandboxes
45//! 5. **Aggregation**: Combine results (any pass = pass, detect flaky tests)
46//! 6. **Reporting**: Print summary and generate JUnit XML
47//!
48//! # Key Components
49//!
50//! - [`Orchestrator`]: Main entry point coordinating the test run
51//! - [`Scheduler`]: Distributes tests across available sandboxes
52//! - [`TestRunner`]: Executes tests in a single sandbox
53//! - [`RunResult`]: Aggregated results of the entire test run
54//!
55//! # Example
56//!
57//! ```no_run
58//! use offload::orchestrator::{Orchestrator, SandboxPool};
59//! use offload::config::{load_config, SandboxConfig};
60//! use offload::provider::local::LocalProvider;
61//! use offload::framework::{TestFramework, pytest::PytestFramework};
62//!
63//! #[tokio::main]
64//! async fn main() -> anyhow::Result<()> {
65//!     let config = load_config(std::path::Path::new("offload.toml"))?;
66//!
67//!     let provider = LocalProvider::new(Default::default());
68//!     let framework = PytestFramework::new(Default::default());
69//!
70//!     // Discover tests using the framework
71//!     let tests = framework.discover(&[]).await?;
72//!
73//!     // Pre-populate sandbox pool
74//!     let sandbox_config = SandboxConfig {
75//!         id: "sandbox".to_string(),
76//!         working_dir: None,
77//!         env: vec![],
78//!         copy_dirs: vec![],
79//!     };
80//!     let mut sandbox_pool = SandboxPool::new();
81//!     sandbox_pool.populate(config.offload.max_parallel, &provider, &sandbox_config).await?;
82//!
83//!     // Run tests using the orchestrator
84//!     let orchestrator = Orchestrator::new(config, framework, false);
85//!     let result = orchestrator.run_with_tests(&tests, sandbox_pool).await?;
86//!
87//!     if result.success() {
88//!         println!("All tests passed!");
89//!     } else {
90//!         println!("{} tests failed", result.failed);
91//!     }
92//!
93//!     std::process::exit(result.exit_code());
94//! }
95//! ```
96
97pub mod pool;
98pub mod runner;
99pub mod scheduler;
100
101use std::sync::Arc;
102use std::sync::atomic::{AtomicBool, Ordering};
103use std::time::Duration;
104
105use tokio::sync::Mutex;
106use tokio_util::sync::CancellationToken;
107use tracing::{debug, error, info, warn};
108
109use crate::config::Config;
110use crate::framework::{TestFramework, TestInstance, TestRecord, TestResult};
111use crate::provider::{OutputLine, Sandbox};
112use crate::report::{MasterJunitReport, load_test_durations, print_summary};
113
114pub use pool::SandboxPool;
115pub use runner::{OutputCallback, TestRunner};
116pub use scheduler::Scheduler;
117
118/// Aggregated results of an entire test run.
119///
120/// Contains summary statistics and individual test results. This is the
121/// return value of [`Orchestrator::run`] and is passed to reporters
122/// for final output.
123///
124/// # Exit Codes
125///
126/// The [`exit_code`](Self::exit_code) method returns conventional exit codes:
127///
128/// | Code | Meaning |
129/// |------|---------|
130/// | 0 | All tests passed |
131/// | 1 | Some tests failed or weren't run |
132/// | 2 | All tests passed but some were flaky |
133#[derive(Debug, Clone)]
134pub struct RunResult {
135    /// Total number of tests discovered.
136    pub total_tests: usize,
137
138    /// Number of tests that passed.
139    pub passed: usize,
140
141    /// Number of tests that failed (assertions or errors).
142    pub failed: usize,
143
144    /// Number of tests that were skipped.
145    pub skipped: usize,
146
147    /// Number of tests that were flaky (passed on retry).
148    ///
149    /// A flaky test is one that failed initially but passed after retrying.
150    pub flaky: usize,
151
152    /// Number of tests that couldn't be run.
153    ///
154    /// Typically due to sandbox creation failures or infrastructure issues.
155    pub not_run: usize,
156
157    /// Wall-clock duration of the entire test run.
158    pub duration: Duration,
159
160    /// Individual test results for all executed tests.
161    pub results: Vec<TestResult>,
162}
163
164impl RunResult {
165    /// Returns `true` if the test run was successful.
166    ///
167    /// A run is successful if no tests failed and all scheduled tests
168    /// were executed. Flaky tests are considered successful (they passed
169    /// on retry).
170    ///
171    /// # Example
172    ///
173    /// ```
174    /// use offload::orchestrator::RunResult;
175    /// use std::time::Duration;
176    ///
177    /// let result = RunResult {
178    ///     total_tests: 100,
179    ///     passed: 95,
180    ///     failed: 0,
181    ///     skipped: 5,
182    ///     flaky: 2,
183    ///     not_run: 0,
184    ///     duration: Duration::from_secs(60),
185    ///     results: vec![],
186    /// };
187    ///
188    /// assert!(result.success());
189    /// ```
190    pub fn success(&self) -> bool {
191        self.failed == 0 && self.not_run == 0
192    }
193
194    /// Returns an appropriate process exit code for this result.
195    pub fn exit_code(&self) -> i32 {
196        if self.failed > 0 || self.not_run > 0 {
197            1
198        } else if self.flaky > 0 {
199            2 // 2 is the convention that offload has decided to store for flakiness
200        } else {
201            0
202        }
203    }
204}
205
206/// The main orchestrator that coordinates test execution.
207///
208/// The orchestrator is the top-level component that ties together:
209/// - A pre-populated [`SandboxPool`] of execution environments
210/// - A [`TestFramework`] for running tests
211///
212/// It manages the full lifecycle of a test run: scheduling,
213/// parallel execution, retries, and result aggregation.
214///
215/// # Type Parameters
216///
217/// - `S`: The sandbox type (implements [`Sandbox`](crate::provider::Sandbox))
218/// - `D`: The test framework type
219///
220/// # Example
221///
222/// ```no_run
223/// use offload::orchestrator::{Orchestrator, SandboxPool};
224/// use offload::config::{load_config, SandboxConfig};
225/// use offload::provider::local::LocalProvider;
226/// use offload::framework::{TestFramework, pytest::PytestFramework};
227///
228/// #[tokio::main]
229/// async fn main() -> anyhow::Result<()> {
230///     let config = load_config(std::path::Path::new("offload.toml"))?;
231///
232///     // Set up components
233///     let provider = LocalProvider::new(Default::default());
234///     let framework = PytestFramework::new(Default::default());
235///
236///     // Discover tests using the framework
237///     let tests = framework.discover(&[]).await?;
238///
239///     // Pre-populate sandbox pool
240///     let sandbox_config = SandboxConfig {
241///         id: "sandbox".to_string(),
242///         working_dir: None,
243///         env: vec![],
244///         copy_dirs: vec![],
245///     };
246///     let mut sandbox_pool = SandboxPool::new();
247///     sandbox_pool.populate(config.offload.max_parallel, &provider, &sandbox_config).await?;
248///
249///     // Create orchestrator and run tests
250///     let orchestrator = Orchestrator::new(config, framework, false);
251///     let result = orchestrator.run_with_tests(&tests, sandbox_pool).await?;
252///
253///     std::process::exit(result.exit_code());
254/// }
255/// ```
256pub struct Orchestrator<S, D> {
257    config: Config,
258    framework: D,
259    verbose: bool,
260    _sandbox: std::marker::PhantomData<S>,
261}
262
impl<S, D> Orchestrator<S, D>
where
    S: Sandbox,
    D: TestFramework,
{
    /// Creates a new orchestrator with the given components.
    ///
    /// # Arguments
    ///
    /// * `config` - Configuration loaded from TOML
    /// * `framework` - Test framework for running tests
    /// * `verbose` - Whether to show verbose output (streaming test output)
    pub fn new(config: Config, framework: D, verbose: bool) -> Self {
        Self {
            config,
            framework,
            verbose,
            _sandbox: std::marker::PhantomData,
        }
    }

    /// Runs the given tests and returns the aggregated results.
    ///
    /// Takes already-discovered tests as input, allowing callers to
    /// inspect or filter tests before execution. Results are recorded
    /// into each `TestRecord` via interior mutability.
    ///
    /// # Arguments
    ///
    /// * `tests` - The tests to run (typically from [`discover`](Self::discover))
    /// * `sandbox_pool` - Pool of sandboxes to use
    ///
    /// # Returns
    ///
    /// [`RunResult`] containing summary statistics and individual results.
    ///
    /// # Panics
    ///
    /// Panics if the number of sandboxes taken from `sandbox_pool` does not
    /// equal the number of batches produced by the scheduler — an internal
    /// invariant: the pool must have been populated with `max_parallel`
    /// sandboxes to match the scheduler's batch count.
    ///
    /// # Errors
    ///
    /// Returns an error if critical infrastructure errors occur.
    pub async fn run_with_tests(
        &self,
        tests: &[TestRecord],
        mut sandbox_pool: SandboxPool<S>,
    ) -> anyhow::Result<RunResult> {
        let start = std::time::Instant::now();

        // Load test durations from previous junit.xml for LPT scheduling
        let junit_path = self
            .config
            .report
            .output_dir
            .join(&self.config.report.junit_file);
        let durations = load_test_durations(&junit_path, self.config.framework.test_id_format());

        // Ensure output directory exists (don't clear - junit.xml will be overwritten when ready)
        // Best-effort (.ok()): a failure here surfaces later when the report is written.
        let output_dir = &self.config.report.output_dir;
        std::fs::create_dir_all(output_dir).ok();

        // Clear parts directory from previous run so stale per-batch XML
        // fragments cannot leak into this run's aggregated report.
        let parts_dir = output_dir.join("junit-parts");
        if parts_dir.exists() {
            if let Err(e) = std::fs::remove_dir_all(&parts_dir) {
                warn!("Failed to clear parts directory: {}", e);
            } else {
                debug!("Cleared parts directory: {:?}", parts_dir);
            }
        }
        std::fs::create_dir_all(&parts_dir).ok();

        // Nothing to do: return an empty (successful) result immediately.
        if tests.is_empty() {
            warn!("No tests to run");
            return Ok(RunResult {
                total_tests: 0,
                passed: 0,
                failed: 0,
                skipped: 0,
                flaky: 0,
                not_run: 0,
                duration: start.elapsed(),
                results: Vec::new(),
            });
        }

        // Set up progress bar.
        // Sized by instance count: each retry is its own instance, so the bar
        // tracks executions rather than unique tests.
        let total_instances: usize = tests
            .iter()
            .filter(|t| !t.skipped)
            .map(|t| t.retry_count + 1)
            .sum();
        let progress = indicatif::ProgressBar::new(total_instances as u64);
        if let Ok(style) = indicatif::ProgressStyle::default_bar().template(
            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
        ) {
            progress.set_style(style.progress_chars("#>-"));
        }
        progress.enable_steady_tick(std::time::Duration::from_millis(100));

        // Filter out skipped tests and create Test handles
        // For tests with retry_count > 0, create multiple instances to run in parallel
        let tests_to_run: Vec<TestInstance<'_>> = tests
            .iter()
            .filter(|t| !t.skipped)
            .flat_map(|t| {
                let count = t.retry_count + 1; // 1 original + retry_count retries
                (0..count).map(move |_| t.test())
            })
            .collect();

        // Number of records excluded from execution entirely.
        let skipped_count = tests.len() - tests.iter().filter(|t| !t.skipped).count();

        // Schedule tests using LPT (Longest Processing Time First) if we have durations,
        // otherwise fall back to round-robin with a warning.
        let scheduler = Scheduler::new(self.config.offload.max_parallel);
        let batches = if durations.is_empty() {
            warn!(
                "No historical test durations found at {}. Falling back to round-robin scheduling. \
                 Run tests once to generate junit.xml for optimized LPT scheduling.",
                junit_path.display()
            );
            scheduler.schedule(&tests_to_run)
        } else {
            debug!(
                "Using LPT scheduling with {} historical durations from {}",
                durations.len(),
                junit_path.display()
            );
            // Default duration for unknown tests: 1 second (conservative estimate)
            scheduler.schedule_lpt(&tests_to_run, &durations, std::time::Duration::from_secs(1))
        };

        // Take sandboxes from pool - must match batch count
        // (internal invariant; see the `# Panics` section above).
        let sandboxes = sandbox_pool.take_all();
        assert_eq!(
            sandboxes.len(),
            batches.len(),
            "sandbox count ({}) must match batch count ({})",
            sandboxes.len(),
            batches.len()
        );

        // Log batch distribution
        info!(
            "[ORCHESTRATOR] Scheduled {} tests into {} batches with {} sandboxes",
            tests_to_run.len(),
            batches.len(),
            sandboxes.len()
        );
        for (i, batch) in batches.iter().enumerate() {
            info!("[ORCHESTRATOR] Batch {}: {} tests", i, batch.len());
        }
        let total_in_batches: usize = batches.iter().map(|b| b.len()).sum();
        info!(
            "[ORCHESTRATOR] Total tests across all batches: {} (should equal {})",
            total_in_batches,
            tests_to_run.len()
        );

        // Shared JUnit report for accumulating results and early stopping.
        // std::sync::Mutex (not tokio's) is fine here: no guard on this lock
        // is ever held across an .await below.
        let total_tests_to_run = tests.iter().filter(|t| !t.skipped).count();
        let junit_report = Arc::new(std::sync::Mutex::new(MasterJunitReport::new(
            total_tests_to_run,
        )));
        // Set once every unique test has a passing result; not-yet-started
        // batches check it to skip work, and the token cancels in-flight ones.
        let all_passed = Arc::new(AtomicBool::new(false));
        let cancellation_token = CancellationToken::new();

        // Collect sandboxes back after use for termination
        let sandboxes_for_cleanup = Arc::new(Mutex::new(Vec::new()));

        // Run tests in parallel
        // Execute batches concurrently using scoped spawns (no 'static required)
        // — the tasks borrow self.framework/self.config/progress directly, and
        // the scope does not return until every spawned task has finished.
        tokio_scoped::scope(|scope| {
            for (batch_idx, (sandbox, batch)) in sandboxes.into_iter().zip(batches).enumerate() {
                let framework = &self.framework;
                let config = &self.config;
                let progress = &progress;
                let verbose = self.verbose;
                let junit_report = Arc::clone(&junit_report);
                let all_passed = Arc::clone(&all_passed);
                let cancellation_token = cancellation_token.clone();
                let sandboxes_for_cleanup = Arc::clone(&sandboxes_for_cleanup);

                scope.spawn(async move {
                    // Early exit if all tests have already passed
                    if all_passed.load(Ordering::SeqCst) {
                        let test_ids: Vec<_> = batch.iter().map(|t| t.id()).collect();
                        info!(
                            "EARLY STOP: Skipping batch {} ({} tests) - all tests already passed",
                            batch_idx,
                            batch.len()
                        );
                        debug!("Skipped tests: {:?}", test_ids);
                        // Still hand the unused sandbox back so it gets terminated.
                        sandboxes_for_cleanup.lock().await.push(sandbox);
                        return;
                    }

                    let parts_dir = config.report.output_dir.join("junit-parts");
                    let mut runner = TestRunner::new(
                        sandbox,
                        framework,
                        Duration::from_secs(config.offload.test_timeout_secs),
                    )
                    .with_cancellation_token(cancellation_token.clone())
                    .with_junit_report(Arc::clone(&junit_report))
                    .with_parts_dir(parts_dir);

                    // Enable output callback only in verbose mode
                    if config.offload.stream_output && verbose {
                        let callback: OutputCallback = Arc::new(|test_id, line| match line {
                            OutputLine::Stdout(s) => println!("[{}] {}", test_id, s),
                            OutputLine::Stderr(s) => eprintln!("[{}] {}", test_id, s),
                            OutputLine::ExitCode(_) => {}
                        });
                        runner = runner.with_output_callback(callback);
                    }

                    // Log test starts in verbose mode
                    if verbose {
                        for test in &batch {
                            println!("Running: {}", test.id());
                        }
                    }

                    // Run all tests in batch with a single command.
                    // Ok(true) = batch ran to completion; Ok(false) = cancelled.
                    match runner.run_tests(&batch).await {
                        Ok(true) => {
                            // Check shared report for early stopping.
                            // The `!all_passed.load` guard keeps only the first
                            // finishing batch from logging/cancelling.
                            if let Ok(report) = junit_report.lock()
                                && report.all_passed()
                                && !all_passed.load(Ordering::SeqCst)
                            {
                                info!(
                                    "EARLY STOP TRIGGERED: All {} tests have passed after batch {} completed. Cancelling remaining batches.",
                                    total_tests_to_run,
                                    batch_idx
                                );
                                all_passed.store(true, Ordering::SeqCst);
                                cancellation_token.cancel();
                            }
                        }
                        Ok(false) => {
                            // Batch was cancelled - no results to record
                            debug!("Batch {} was cancelled", batch_idx);
                        }
                        Err(e) => {
                            error!("Batch execution error: {}", e);
                        }
                    }

                    // Update progress for completed batch.
                    // NOTE(review): the bar also advances for cancelled/errored
                    // batches — confirm that is the intended display behavior.
                    progress.inc(batch.len() as u64);

                    // Collect sandbox for cleanup
                    let sandbox = runner.into_sandbox();
                    sandboxes_for_cleanup.lock().await.push(sandbox);
                });
            }
        });

        // Aggregate results from TestRecords (handles parallel retries automatically)
        // Get results from the shared JUnit report.
        // NOTE(review): a poisoned report mutex silently falls back to all-zero
        // counts, which is then reported below as "tests missing" — confirm
        // this fallback is intended rather than propagating an error.
        info!("[ORCHESTRATOR] All batches completed, aggregating results...");
        let (passed, failed, flaky_count, total_in_report) = if let Ok(report) = junit_report.lock()
        {
            let summary = report.summary();
            let total = report.total_count();
            info!(
                "[ORCHESTRATOR] Master report: {} total unique tests, {} passed, {} failed, {} flaky",
                total, summary.0, summary.1, summary.2
            );
            (summary.0, summary.1, summary.2, total)
        } else {
            (0, 0, 0, 0)
        };

        // Check for missing tests
        let expected_unique_tests = tests.iter().filter(|t| !t.skipped).count();
        if total_in_report < expected_unique_tests {
            error!(
                "[ORCHESTRATOR MISMATCH] Expected {} unique tests but only {} in report! {} TESTS MISSING!",
                expected_unique_tests,
                total_in_report,
                expected_unique_tests - total_in_report
            );
        } else {
            info!(
                "[ORCHESTRATOR] All {} expected tests accounted for in report",
                expected_unique_tests
            );
        }

        // Write the JUnit report to file
        if self.config.report.junit {
            let output_path = self
                .config
                .report
                .output_dir
                .join(&self.config.report.junit_file);
            if let Ok(report) = junit_report.lock()
                && let Err(e) = report.write_to_file(&output_path)
            {
                warn!("Failed to write JUnit XML: {}", e);
            }
        }

        // Use JUnit report as source of truth for all counts.
        // NOTE(review): this re-locks the report even though `total_in_report`
        // above already holds the same value; redundant but harmless.
        let total_discovered = tests.iter().filter(|t| !t.skipped).count();
        let total_in_junit = if let Ok(report) = junit_report.lock() {
            report.total_count()
        } else {
            0
        };
        let not_run = total_discovered.saturating_sub(total_in_junit);

        // Use the JUnit total as the authoritative count (passed + failed + flaky = total)
        // This ensures passed can never exceed total
        let run_result = RunResult {
            total_tests: total_in_junit,
            passed: passed + flaky_count, // Flaky tests count as passed
            failed,
            skipped: skipped_count,
            flaky: flaky_count,
            not_run,
            duration: start.elapsed(),
            results: Vec::new(), // Results are in JUnit XML now
        };

        progress.finish_and_clear();
        print_summary(&run_result);

        // Terminate all sandboxes in parallel (after printing results)
        let sandboxes: Vec<_> = sandboxes_for_cleanup.lock().await.drain(..).collect();
        let terminate_futures = sandboxes.into_iter().map(|sandbox| async move {
            if let Err(e) = sandbox.terminate().await {
                warn!("Failed to terminate sandbox {}: {}", sandbox.id(), e);
            }
        });
        futures::future::join_all(terminate_futures).await;

        Ok(run_result)
    }
}