cubek_test_utils/test_mode/launch.rs
1//! Generic launch-and-capture-outcome plumbing shared by every kernel-test
2//! helper.
3//!
//! Kernel launches can fail in two windows: synchronously (the launch closure
//! returns an outcome other than `Executed`) or asynchronously when the
//! runtime processes the queued work. Catching the asynchronous case requires
//! an explicit `flush` both before and after the launch.
8
9use cubecl::{
10 TestRuntime,
11 prelude::ComputeClient,
12 server::{self, LaunchError, ServerError},
13};
14
15use crate::ExecutionOutcome;
16
17/// Run `launch` against `client`, returning its [`ExecutionOutcome`] after
18/// flushing for any compile/launch errors that surface only asynchronously.
19///
20/// The pre-flush also catches stale errors from a prior launch on the same
21/// client — without it, an earlier failure would be attributed to this one.
22pub fn launch_and_capture_outcome<F>(
23 client: &ComputeClient<TestRuntime>,
24 launch: F,
25) -> ExecutionOutcome
26where
27 F: FnOnce(&ComputeClient<TestRuntime>) -> ExecutionOutcome,
28{
29 let outcome = flush_compile_error(client).unwrap_or_else(|| launch(client));
30 match outcome {
31 ExecutionOutcome::Executed => {
32 flush_compile_error(client).unwrap_or(ExecutionOutcome::Executed)
33 }
34 other => other,
35 }
36}
37
38/// Flush `client` and surface any pending compile/launch failure as a
39/// [`ExecutionOutcome::CompileError`].
40///
41/// Returns `None` when the flush is clean (the kernel ran). Other server
42/// errors are wrapped as `CompileError` so callers see one uniform shape.
43pub fn flush_compile_error(client: &ComputeClient<TestRuntime>) -> Option<ExecutionOutcome> {
44 match client.flush() {
45 Ok(_) => None,
46 Err(ServerError::ServerUnhealthy { errors, .. }) => {
47 for error in errors.iter() {
48 if let server::ServerError::Launch(LaunchError::TooManyResources(_))
49 | server::ServerError::Launch(LaunchError::CompilationError(_)) = error
50 {
51 return Some(ExecutionOutcome::CompileError(format!("{errors:?}")));
52 }
53 }
54 None
55 }
56 Err(err) => Some(ExecutionOutcome::CompileError(format!("{err:?}"))),
57 }
58}