Skip to main content

cubek_test_utils/test_mode/
launch.rs

//! Generic launch-and-capture-outcome plumbing shared by every kernel-test
//! helper.
//!
//! Kernel launches can fail in two windows: synchronously (the launch closure
//! returns a non-`Executed` outcome) or asynchronously, when the runtime
//! processes the queued work. Catching the asynchronous case requires an
//! explicit `flush` after the launch; a `flush` before the launch surfaces
//! errors left over from earlier work on the same client first.
8
9use cubecl::{
10    TestRuntime,
11    prelude::ComputeClient,
12    server::{self, LaunchError, ServerError},
13};
14
15use crate::ExecutionOutcome;
16
17/// Run `launch` against `client`, returning its [`ExecutionOutcome`] after
18/// flushing for any compile/launch errors that surface only asynchronously.
19///
20/// The pre-flush also catches stale errors from a prior launch on the same
21/// client — without it, an earlier failure would be attributed to this one.
22pub fn launch_and_capture_outcome<F>(
23    client: &ComputeClient<TestRuntime>,
24    launch: F,
25) -> ExecutionOutcome
26where
27    F: FnOnce(&ComputeClient<TestRuntime>) -> ExecutionOutcome,
28{
29    let outcome = flush_compile_error(client).unwrap_or_else(|| launch(client));
30    match outcome {
31        ExecutionOutcome::Executed => {
32            flush_compile_error(client).unwrap_or(ExecutionOutcome::Executed)
33        }
34        other => other,
35    }
36}
37
38/// Flush `client` and surface any pending compile/launch failure as a
39/// [`ExecutionOutcome::CompileError`].
40///
41/// Returns `None` when the flush is clean (the kernel ran). Other server
42/// errors are wrapped as `CompileError` so callers see one uniform shape.
43pub fn flush_compile_error(client: &ComputeClient<TestRuntime>) -> Option<ExecutionOutcome> {
44    match client.flush() {
45        Ok(_) => None,
46        Err(ServerError::ServerUnhealthy { errors, .. }) => {
47            for error in errors.iter() {
48                if let server::ServerError::Launch(LaunchError::TooManyResources(_))
49                | server::ServerError::Launch(LaunchError::CompilationError(_)) = error
50                {
51                    return Some(ExecutionOutcome::CompileError(format!("{errors:?}")));
52                }
53            }
54            None
55        }
56        Err(err) => Some(ExecutionOutcome::CompileError(format!("{err:?}"))),
57    }
58}