cubek_test_utils/test_mode/
base.rs

1//! # Test Mode
2//!
3//! Control how tests handle numerical and compilation errors via the environment variable
4//! `CUBE_TEST_MODE`.
5//!
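//! ## Modes
//! The value is matched case-insensitively (`PrintAll` and `printall` are equivalent);
//! unrecognized values fall back to `Correct`.
//!
//! - `Correct` (default): numerical errors fail the test; compilation errors are ignored.
//! - `Strict`: both numerical and compilation errors fail the test.
//! - `Print { filter, fail_only }`:
//!     - `printall[:<filter>]` — every test fails and its matching elements are printed.
//!     - `printfail[:<filter>]` — only tests with numerical errors fail; their matching
//!       elements are printed.
//! - `FailIfRun`: only tests that run and pass fail; numerical and compilation errors are
//!   ignored. Helpful to isolate relevant tests.
//!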
//! ## Filter Expressions
//! - The filter is **optional**. If omitted, all elements of the tensor are included.
//! - When specified, it is a comma-separated list with one entry per dimension. Each entry is:
//!     - `.` — a wildcard (all indices along that dimension),
//!     - `N` — a single index,
//!     - `M-K` — a range of indices.
//! - Example for a 4D tensor: `.,.,10-20,30` selects all elements whose 3rd index is in
//!   10–20 and whose 4th index is 30, with any value for the first two dimensions.
//! - **Important:** the number of entries in the filter must match the rank of the tensor.
//! - An invalid filter is reported on stderr and ignored, so all elements are included.
//!
//! ## Examples
//!
//! ```bash
//! # Default mode: only numerical errors fail
//! export CUBE_TEST_MODE=Correct
//!
//! # Strict mode: all errors fail
//! export CUBE_TEST_MODE=Strict
//!
//! # Print all elements (no filter specified)
//! export CUBE_TEST_MODE=PrintAll
//!
//! # Print all elements in a subset of dimensions
//! export CUBE_TEST_MODE=PrintAll:.,10-20
//!
//! # Print only tests that fail numerically, restricted to a subset of dimensions
//! export CUBE_TEST_MODE=PrintFail:.,10-20
//!
//! # Fail only the tests that actually run and pass
//! export CUBE_TEST_MODE=FailIfRun
//! ```
41
42use crate::{
43    TestDecision, TestOutcome, ValidationResult,
44    correctness::{TensorFilter, parse_tensor_filter},
45};
46
/// Name of the environment variable read by [`current_test_mode`] to select the
/// active [`TestMode`].
const CUBE_TEST_MODE_ENV: &str = "CUBE_TEST_MODE";
48
49#[derive(Default, Debug, Clone)]
50pub enum TestMode {
51    #[default]
52    /// Numerical errors cause the test to fail.
53    /// Compilation errors are accepted (do not fail the test).
54    Correct,
55
56    /// Both numerical and compilation errors cause the test to fail.
57    Strict,
58
59    /// All tests can be printed according to the given `filter`.
60    /// `fail_only = true`: only tests with numerical errors are marked as failed and printed.
61    /// `fail_only = false`: all tests are marked as failed and printed.
62    Print {
63        filter: TensorFilter,
64        fail_only: bool,
65    },
66
67    /// Fail only if the test successfully runs.
68    /// Compilation failures are ignored.
69    ///
70    /// Helpful to isolate relevant tests
71    FailIfRun,
72}
73
74impl TestMode {
75    pub fn decide(&self, outcome: TestOutcome) -> TestDecision {
76        use TestDecision::*;
77        use TestMode::*;
78        use TestOutcome::*;
79        use ValidationResult::*;
80
81        match self {
82            Correct => match outcome {
83                Validated(result) => match result {
84                    Pass => Accept,
85                    Fail(reason) => Reject(reason),
86                    Skipped(reason) => Reject(reason),
87                },
88                CompileError(_) => Accept,
89            },
90            Strict => match outcome {
91                Validated(result) => match result {
92                    Pass => Accept,
93                    Fail(reason) => Reject(reason),
94                    Skipped(reason) => Reject(reason),
95                },
96                CompileError(reason) => Reject(reason),
97            },
98            Print {
99                filter: _,
100                fail_only,
101            } => match outcome {
102                Validated(result) => match result {
103                    Pass => {
104                        if *fail_only {
105                            Accept
106                        } else {
107                            Reject("printed".into())
108                        }
109                    }
110                    Fail(reason) => Reject(reason),
111                    Skipped(reason) => Reject(reason),
112                },
113
114                CompileError(reason) => {
115                    if *fail_only {
116                        Accept
117                    } else {
118                        Reject(reason)
119                    }
120                }
121            },
122            FailIfRun => match outcome {
123                Validated(result) => match result {
124                    Pass => Reject("Actually passed, but FailIfRun mode activated".to_string()),
125                    Fail(_) => Accept,
126                    Skipped(_) => Accept,
127                },
128                CompileError(_) => Accept,
129            },
130        }
131    }
132}
133
134pub fn current_test_mode() -> TestMode {
135    let val = match std::env::var(CUBE_TEST_MODE_ENV) {
136        Ok(v) => v.to_lowercase(),
137        Err(_) => return TestMode::Correct,
138    };
139
140    if let Some(print_mode) = val.strip_prefix("printall") {
141        parse_print_mode(print_mode, false)
142    } else if let Some(print_mode) = val.strip_prefix("printfail") {
143        parse_print_mode(print_mode, true)
144    } else if val == "strict" {
145        TestMode::Strict
146    } else if val == "failifrun" {
147        TestMode::FailIfRun
148    } else {
149        TestMode::Correct
150    }
151}
152
153fn parse_print_mode(suffix: &str, fail_only: bool) -> TestMode {
154    let filter = if let Some(rest) = suffix.strip_prefix(':') {
155        match parse_tensor_filter(rest) {
156            Ok(f) => f,
157            Err(e) => {
158                eprintln!("Invalid print filter '{}': {}", rest, e);
159                vec![]
160            }
161        }
162    } else {
163        vec![]
164    };
165
166    TestMode::Print { filter, fail_only }
167}