cubek_test_utils/test_mode/base.rs
1//! # Test Mode
2//!
3//! Control how tests handle numerical and compilation errors via the environment variable
4//! `CUBE_TEST_MODE`.
5//!
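//! ## Modes
//! - `Correct` (default): numerical errors fail the test; compilation errors are ignored.
//! - `Strict`: both numerical and compilation errors fail the test.
//! - `Print { filter, fail_only }`, selected with one of:
//!   - `printall[:<filter>]`: every test fails and its matching elements are printed.
//!   - `printfail[:<filter>]`: only tests with numerical errors fail; matching elements are
//!     printed.
//! - `FailIfRun`: only tests that compile *and* pass fail; numerical failures and compilation
//!   errors are ignored. Useful to isolate the tests that actually run.
//!
//! Mode names are case-insensitive. A test whose validation is skipped is treated as a
//! failure in every mode except `FailIfRun`.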
12//!
13//! ## Filter Expressions
//! - The filter is **optional**. If omitted, all elements of the tensor are included.
//! - When specified, it is a comma-separated list with one entry per dimension. Each entry is:
//!   - `.`, a wildcard selecting every index along that dimension,
//!   - `N`, a single index, or
//!   - `M-K`, a range of indices.
//! - Example for a 4D tensor: `.,.,10-20,30` selects all elements whose 3rd index is in
//!   10–20 and whose 4th index is 30, for any values of the first two indices.
21//! - **Important:** The number of entries in the filter must match the rank of the tensor.
22//!
23//! ## Examples
24//!
25//! ```bash
26//! # Default mode: only numerical errors fail
27//! export CUBE_TEST_MODE=Correct
28//!
29//! # Strict mode: all errors fail
30//! export CUBE_TEST_MODE=Strict
31//!
32//! # Print all elements (no filter specified)
33//! export CUBE_TEST_MODE=PrintAll
34//!
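//! # Print only the elements of a 2D tensor whose 2nd index is in 10-20
//! export CUBE_TEST_MODE=PrintAll:.,10-20
//!
//! # Same filter, but only tests with numerical errors fail and are printed
//! export CUBE_TEST_MODE=PrintFail:.,10-20
//!
//! # Only tests that compile and pass fail
//! export CUBE_TEST_MODE=FailIfRun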
40//! ```
41
42use crate::{
43 TestDecision, TestOutcome, ValidationResult,
44 correctness::{TensorFilter, parse_tensor_filter},
45};
46
/// Name of the environment variable from which the active [`TestMode`] is read.
const CUBE_TEST_MODE_ENV: &str = "CUBE_TEST_MODE";
48
49#[derive(Default, Debug, Clone)]
50pub enum TestMode {
51 #[default]
52 /// Numerical errors cause the test to fail.
53 /// Compilation errors are accepted (do not fail the test).
54 Correct,
55
56 /// Both numerical and compilation errors cause the test to fail.
57 Strict,
58
59 /// All tests can be printed according to the given `filter`.
60 /// `fail_only = true`: only tests with numerical errors are marked as failed and printed.
61 /// `fail_only = false`: all tests are marked as failed and printed.
62 Print {
63 filter: TensorFilter,
64 fail_only: bool,
65 },
66
67 /// Fail only if the test successfully runs.
68 /// Compilation failures are ignored.
69 ///
70 /// Helpful to isolate relevant tests
71 FailIfRun,
72}
73
74impl TestMode {
75 pub fn decide(&self, outcome: TestOutcome) -> TestDecision {
76 use TestDecision::*;
77 use TestMode::*;
78 use TestOutcome::*;
79 use ValidationResult::*;
80
81 match self {
82 Correct => match outcome {
83 Validated(result) => match result {
84 Pass => Accept,
85 Fail(reason) => Reject(reason),
86 Skipped(reason) => Reject(reason),
87 },
88 CompileError(_) => Accept,
89 },
90 Strict => match outcome {
91 Validated(result) => match result {
92 Pass => Accept,
93 Fail(reason) => Reject(reason),
94 Skipped(reason) => Reject(reason),
95 },
96 CompileError(reason) => Reject(reason),
97 },
98 Print {
99 filter: _,
100 fail_only,
101 } => match outcome {
102 Validated(result) => match result {
103 Pass => {
104 if *fail_only {
105 Accept
106 } else {
107 Reject("printed".into())
108 }
109 }
110 Fail(reason) => Reject(reason),
111 Skipped(reason) => Reject(reason),
112 },
113
114 CompileError(reason) => {
115 if *fail_only {
116 Accept
117 } else {
118 Reject(reason)
119 }
120 }
121 },
122 FailIfRun => match outcome {
123 Validated(result) => match result {
124 Pass => Reject("Actually passed, but FailIfRun mode activated".to_string()),
125 Fail(_) => Accept,
126 Skipped(_) => Accept,
127 },
128 CompileError(_) => Accept,
129 },
130 }
131 }
132}
133
134pub fn current_test_mode() -> TestMode {
135 let val = match std::env::var(CUBE_TEST_MODE_ENV) {
136 Ok(v) => v.to_lowercase(),
137 Err(_) => return TestMode::Correct,
138 };
139
140 if let Some(print_mode) = val.strip_prefix("printall") {
141 parse_print_mode(print_mode, false)
142 } else if let Some(print_mode) = val.strip_prefix("printfail") {
143 parse_print_mode(print_mode, true)
144 } else if val == "strict" {
145 TestMode::Strict
146 } else if val == "failifrun" {
147 TestMode::FailIfRun
148 } else {
149 TestMode::Correct
150 }
151}
152
153fn parse_print_mode(suffix: &str, fail_only: bool) -> TestMode {
154 let filter = if let Some(rest) = suffix.strip_prefix(':') {
155 match parse_tensor_filter(rest) {
156 Ok(f) => f,
157 Err(e) => {
158 eprintln!("Invalid print filter '{}': {}", rest, e);
159 vec![]
160 }
161 }
162 } else {
163 vec![]
164 };
165
166 TestMode::Print { filter, fail_only }
167}