Skip to main content

offload/
orchestrator.rs

1//! Test orchestration: discovery, scheduling, parallel execution, and result aggregation.
2pub mod pool;
3pub mod runner;
4pub mod scheduler;
5pub mod spawn;
6
7use std::collections::VecDeque;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::time::Duration;
11
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, error, info, warn};
14
15use crate::config::Config;
16use crate::framework::{TestFramework, TestInstance, TestRecord, TestResult};
17use crate::provider::Sandbox;
18use crate::report::{MasterJunitReport, load_test_durations, print_summary};
19
20pub use pool::SandboxPool;
21pub use runner::{BatchOutcome, OutputCallback, TestRunner};
22pub use scheduler::Scheduler;
23
/// Upper bound on the expected duration of a single batch of tests.
///
/// Handed to the scheduler's LPT entry point as the per-batch cap. The
/// intent is that no batch is expected to run longer than this unless a
/// single test on its own already does, which keeps batches small enough
/// for quick feedback and fine-grained retries.
const MAX_BATCH_DURATION: Duration = Duration::from_secs(10);
30
31/// Aggregated results of an entire test run.
32///
33/// Contains summary statistics and individual test results. This is the
34/// return value of [`Orchestrator::run`] and is passed to reporters
35/// for final output.
36///
37/// # Exit Codes
38///
39/// The [`exit_code`](Self::exit_code) method returns conventional exit codes:
40///
41/// | Code | Meaning |
42/// |------|---------|
43/// | 0 | All tests passed |
44/// | 1 | Some tests failed or weren't run |
45/// | 2 | All tests passed but some were flaky |
46#[derive(Debug, Clone)]
47pub struct RunResult {
48    /// Total number of tests discovered.
49    pub total_tests: usize,
50
51    /// Number of tests that passed.
52    pub passed: usize,
53
54    /// Number of tests that failed (assertions or errors).
55    pub failed: usize,
56
57    /// Number of tests that were skipped.
58    pub skipped: usize,
59
60    /// Number of tests that were flaky (passed on retry).
61    ///
62    /// A flaky test is one that failed initially but passed after retrying.
63    pub flaky: usize,
64
65    /// Number of tests that couldn't be run.
66    ///
67    /// Typically due to sandbox creation failures or infrastructure issues.
68    pub not_run: usize,
69
70    /// Wall-clock duration of the entire test run.
71    pub duration: Duration,
72
73    /// Individual test results for all executed tests.
74    pub results: Vec<TestResult>,
75}
76
77impl RunResult {
78    /// Returns `true` if the test run was successful.
79    ///
80    /// A run is successful if no tests failed and all scheduled tests
81    /// were executed. Flaky tests are considered successful (they passed
82    /// on retry).
83    pub fn success(&self) -> bool {
84        self.failed == 0 && self.not_run == 0
85    }
86
87    /// Returns an appropriate process exit code for this result.
88    pub fn exit_code(&self) -> i32 {
89        if self.failed > 0 || self.not_run > 0 {
90            1
91        } else if self.flaky > 0 {
92            2 // 2 is the convention that offload has decided to store for flakiness
93        } else {
94            0
95        }
96    }
97}
98
99/// The main orchestrator that coordinates test execution.
100///
101/// The orchestrator is the top-level component that ties together:
102/// - A pre-populated [`SandboxPool`] of execution environments
103/// - A [`TestFramework`] for running tests
104///
105/// It manages the full lifecycle of a test run: scheduling,
106/// parallel execution, retries, and result aggregation.
107///
108/// # Type Parameters
109///
110/// - `S`: The sandbox type (implements [`Sandbox`](crate::provider::Sandbox))
111/// - `D`: The test framework type
112///
113pub struct Orchestrator<S, D> {
114    config: Config,
115    framework: D,
116    verbose: bool,
117    tracer: crate::trace::Tracer,
118    _sandbox: std::marker::PhantomData<S>,
119}
120
121impl<S, D> Orchestrator<S, D>
122where
123    S: Sandbox,
124    D: TestFramework,
125{
126    /// Creates a new orchestrator with the given components.
127    ///
128    /// # Arguments
129    ///
130    /// * `config` - Configuration loaded from TOML
131    /// * `framework` - Test framework for running tests
132    /// * `verbose` - Whether to show verbose output (streaming test output)
133    /// * `tracer` - Performance tracer for emitting trace events
134    pub fn new(config: Config, framework: D, verbose: bool, tracer: crate::trace::Tracer) -> Self {
135        Self {
136            config,
137            framework,
138            verbose,
139            tracer,
140            _sandbox: std::marker::PhantomData,
141        }
142    }
143
144    /// Runs the given tests and returns the aggregated results.
145    ///
146    /// Takes already-discovered tests as input, allowing callers to
147    /// inspect or filter tests before execution. Results are recorded
148    /// into each `TestRecord` via interior mutability.
149    ///
150    /// # Arguments
151    ///
152    /// * `tests` - The tests to run (typically from [`discover`](Self::discover))
153    /// * `sandbox_pool` - Pool of sandboxes to use
154    ///
155    /// # Returns
156    ///
157    /// [`RunResult`] containing summary statistics and individual results.
158    ///
159    /// # Errors
160    ///
161    /// Returns an error if critical infrastructure errors occur.
162    pub async fn run_with_tests(
163        &self,
164        tests: &[TestRecord],
165        mut sandbox_pool: SandboxPool<S>,
166    ) -> anyhow::Result<RunResult> {
167        let start = std::time::Instant::now();
168
169        // Load test durations from previous junit.xml for LPT scheduling
170        let _dur_span = self.tracer.span(
171            "duration_loading",
172            "orchestrator",
173            crate::trace::PID_LOCAL,
174            crate::trace::TID_MAIN,
175        );
176        let junit_path = self
177            .config
178            .report
179            .output_dir
180            .join(&self.config.report.junit_file);
181        let durations = load_test_durations(&junit_path, self.config.framework.test_id_format());
182        drop(_dur_span);
183
184        // Ensure output directory exists (don't clear - junit.xml will be overwritten when ready)
185        let output_dir = &self.config.report.output_dir;
186        std::fs::create_dir_all(output_dir).ok();
187
188        // Clear parts directory from previous run
189        let parts_dir = output_dir.join("junit-parts");
190        if parts_dir.exists() {
191            if let Err(e) = std::fs::remove_dir_all(&parts_dir) {
192                warn!("Failed to clear parts directory: {}", e);
193            } else {
194                debug!("Cleared parts directory: {:?}", parts_dir);
195            }
196        }
197        std::fs::create_dir_all(&parts_dir).ok();
198
199        if tests.is_empty() {
200            warn!("No tests to run");
201            return Ok(RunResult {
202                total_tests: 0,
203                passed: 0,
204                failed: 0,
205                skipped: 0,
206                flaky: 0,
207                not_run: 0,
208                duration: start.elapsed(),
209                results: Vec::new(),
210            });
211        }
212
213        // Set up progress bar
214        let total_instances: usize = tests
215            .iter()
216            .filter(|t| !t.skipped)
217            .map(|t| t.retry_count + 1)
218            .sum();
219        let progress = indicatif::ProgressBar::new(total_instances as u64);
220        if let Ok(style) = indicatif::ProgressStyle::default_bar().template(
221            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
222        ) {
223            progress.set_style(style.progress_chars("#>-"));
224        }
225        progress.enable_steady_tick(std::time::Duration::from_millis(100));
226
227        // Filter out skipped tests and create Test handles
228        // For tests with retry_count > 0, create multiple instances to run in parallel
229        let tests_to_run: Vec<TestInstance<'_>> = tests
230            .iter()
231            .filter(|t| !t.skipped)
232            .flat_map(|t| {
233                let count = t.retry_count + 1; // 1 original + retry_count retries
234                (0..count).map(move |_| t.test())
235            })
236            .collect();
237
238        let skipped_count = tests.len() - tests.iter().filter(|t| !t.skipped).count();
239
240        // Schedule tests using LPT (Longest Processing Time First) if we have durations,
241        // otherwise fall back to round-robin with a warning.
242        let _sched_span = self.tracer.span(
243            "scheduling",
244            "orchestrator",
245            crate::trace::PID_LOCAL,
246            crate::trace::TID_MAIN,
247        );
248        let scheduler = Scheduler::new(self.config.offload.max_parallel);
249        let batches = if durations.is_empty() {
250            warn!(
251                "No historical test durations found at {}. Falling back to round-robin scheduling. \
252                 Run tests once to generate junit.xml for optimized LPT scheduling.",
253                junit_path.display()
254            );
255            scheduler.schedule(&tests_to_run)
256        } else {
257            debug!(
258                "Using LPT scheduling with {} historical durations from {}",
259                durations.len(),
260                junit_path.display()
261            );
262            // Default duration for unknown tests: 1 second (conservative estimate)
263            scheduler.schedule_lpt(
264                &tests_to_run,
265                &durations,
266                Duration::from_secs(1),
267                Some(MAX_BATCH_DURATION),
268            )
269        };
270        drop(_sched_span);
271
272        // Take sandboxes from pool
273        let sandboxes = sandbox_pool.take_all();
274
275        // Log batch distribution
276        info!(
277            "[ORCHESTRATOR] Scheduled {} tests into {} batches with {} sandboxes",
278            tests_to_run.len(),
279            batches.len(),
280            sandboxes.len()
281        );
282        for (i, batch) in batches.iter().enumerate() {
283            info!("[ORCHESTRATOR] Batch {}: {} tests", i, batch.len());
284        }
285        let total_in_batches: usize = batches.iter().map(|b| b.len()).sum();
286        info!(
287            "[ORCHESTRATOR] Total tests across all batches: {} (should equal {})",
288            total_in_batches,
289            tests_to_run.len()
290        );
291
292        // Shared JUnit report for accumulating results and early stopping
293        let total_tests_to_run = tests.iter().filter(|t| !t.skipped).count();
294        let junit_report = Arc::new(std::sync::Mutex::new(MasterJunitReport::new(
295            total_tests_to_run,
296        )));
297        let all_passed = Arc::new(AtomicBool::new(false));
298        let cancellation_token = CancellationToken::new();
299
300        // Collect sandboxes back after use for termination
301        let sandboxes_for_cleanup = Arc::new(std::sync::Mutex::new(Vec::new()));
302
303        // Create/truncate logs directory for per-runner output
304        let logs_dir = output_dir.join("logs");
305        if logs_dir.exists()
306            && let Err(e) = std::fs::remove_dir_all(&logs_dir)
307        {
308            warn!("Failed to clear logs directory: {}", e);
309        }
310        std::fs::create_dir_all(&logs_dir).ok();
311
312        // Queue-based batching: workers pull from a shared queue
313        let queue = Arc::new(std::sync::Mutex::new(VecDeque::from(batches)));
314        let batch_counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
315
316        // Emit per-sandbox metadata events for trace
317        for i in 0..sandboxes.len() {
318            let pid = crate::trace::sandbox_pid(i);
319            self.tracer.metadata_event(
320                "process_name",
321                pid,
322                crate::trace::TID_API,
323                serde_json::json!({"name": format!("Sandbox {}", i)}),
324            );
325            self.tracer.metadata_event(
326                "thread_name",
327                pid,
328                crate::trace::TID_API,
329                serde_json::json!({"name": "API"}),
330            );
331            self.tracer.metadata_event(
332                "thread_name",
333                pid,
334                crate::trace::TID_EXEC,
335                serde_json::json!({"name": "Exec"}),
336            );
337            self.tracer.metadata_event(
338                "thread_name",
339                pid,
340                crate::trace::TID_IO,
341                serde_json::json!({"name": "I/O"}),
342            );
343        }
344
345        // Run tests in parallel using queue-based workers
346        tokio_scoped::scope(|scope| {
347            for (sandbox_index, sandbox) in sandboxes.into_iter().enumerate() {
348                let cfg = spawn::SpawnConfig {
349                    config: &self.config,
350                    framework: &self.framework,
351                    queue: Arc::clone(&queue),
352                    progress: &progress,
353                    total_tests_to_run,
354                    all_passed: Arc::clone(&all_passed),
355                    cancellation_token: cancellation_token.clone(),
356                    sandboxes_for_cleanup: Arc::clone(&sandboxes_for_cleanup),
357                    junit_report: Arc::clone(&junit_report),
358                    logs_dir: logs_dir.clone(),
359                    batch_counter: Arc::clone(&batch_counter),
360                    verbose: self.verbose,
361                    tracer: self.tracer.clone(),
362                    sandbox_index,
363                };
364                scope.spawn(spawn::spawn_task(cfg, sandbox));
365            }
366        });
367
368        // Aggregate results from TestRecords (handles parallel retries automatically)
369        // Get results from the shared JUnit report
370        let _agg_span = self.tracer.span(
371            "result_aggregation",
372            "orchestrator",
373            crate::trace::PID_LOCAL,
374            crate::trace::TID_MAIN,
375        );
376        info!("[ORCHESTRATOR] All batches completed, aggregating results...");
377        let (passed, failed, flaky_count, total_in_report) = if let Ok(report) = junit_report.lock()
378        {
379            let summary = report.summary();
380            let total = report.total_count();
381            info!(
382                "[ORCHESTRATOR] Master report: {} total unique tests, {} passed, {} failed, {} flaky",
383                total, summary.0, summary.1, summary.2
384            );
385            (summary.0, summary.1, summary.2, total)
386        } else {
387            (0, 0, 0, 0)
388        };
389
390        // Check for missing tests
391        let expected_unique_tests = tests.iter().filter(|t| !t.skipped).count();
392        if total_in_report < expected_unique_tests {
393            error!(
394                "[ORCHESTRATOR MISMATCH] Expected {} unique tests but only {} in report! {} TESTS MISSING!",
395                expected_unique_tests,
396                total_in_report,
397                expected_unique_tests - total_in_report
398            );
399        } else {
400            info!(
401                "[ORCHESTRATOR] All {} expected tests accounted for in report",
402                expected_unique_tests
403            );
404        }
405
406        // Write the JUnit report to file
407        if self.config.report.junit {
408            let output_path = self
409                .config
410                .report
411                .output_dir
412                .join(&self.config.report.junit_file);
413            if let Ok(report) = junit_report.lock()
414                && let Err(e) = report.write_to_file(&output_path)
415            {
416                warn!("Failed to write JUnit XML: {}", e);
417            }
418        }
419
420        // Use JUnit report as source of truth for all counts
421        let total_discovered = tests.iter().filter(|t| !t.skipped).count();
422        let total_in_junit = if let Ok(report) = junit_report.lock() {
423            report.total_count()
424        } else {
425            0
426        };
427        let not_run = total_discovered.saturating_sub(total_in_junit);
428
429        // Use the JUnit total as the authoritative count (passed + failed + flaky = total)
430        // This ensures passed can never exceed total
431        let run_result = RunResult {
432            total_tests: total_in_junit,
433            passed: passed + flaky_count, // Flaky tests count as passed
434            failed,
435            skipped: skipped_count,
436            flaky: flaky_count,
437            not_run,
438            duration: start.elapsed(),
439            results: Vec::new(), // Results are in JUnit XML now
440        };
441        drop(_agg_span);
442
443        progress.finish_and_clear();
444        print_summary(&run_result);
445
446        // Terminate all sandboxes in parallel (after printing results)
447        let _cleanup_span = self.tracer.span(
448            "sandbox_cleanup",
449            "orchestrator",
450            crate::trace::PID_LOCAL,
451            crate::trace::TID_MAIN,
452        );
453        let sandboxes: Vec<_> = match sandboxes_for_cleanup.lock() {
454            Ok(mut guard) => guard.drain(..).collect(),
455            Err(e) => {
456                error!("sandbox cleanup mutex poisoned: {}", e);
457                Vec::new()
458            }
459        };
460        let terminate_futures = sandboxes.into_iter().map(|sandbox| async move {
461            if let Err(e) = sandbox.terminate().await {
462                warn!("Failed to terminate sandbox {}: {}", sandbox.id(), e);
463            }
464        });
465        futures::future::join_all(terminate_futures).await;
466        drop(_cleanup_span);
467
468        Ok(run_result)
469    }
470}