use crate::{
TestDecision, TestOutcome, ValidationResult,
correctness::{TensorFilter, parse_tensor_filter},
};
/// Name of the environment variable read by [`current_test_mode`] to select the [`TestMode`].
const CUBE_TEST_MODE_ENV: &str = "CUBE_TEST_MODE";
/// Policy that turns a test's outcome into an accept/reject decision.
///
/// Chosen at runtime from the `CUBE_TEST_MODE` environment variable (see
/// [`current_test_mode`]); [`TestMode::Correct`] is the default.
#[derive(Clone, Debug, Default)]
pub enum TestMode {
    /// Reject validation failures and errors; accept passes, skips and
    /// compile errors.
    #[default]
    Correct,
    /// Same as [`TestMode::Correct`], except that compile errors are rejected too.
    Strict,
    /// Reject outcomes together with their message so that they surface
    /// (a pass is rejected with `"printed"`).
    Print {
        /// Filter parsed from the `:<filter>` suffix of the mode value.
        /// [`TestMode::decide`] does not consult it; NOTE(review): presumably
        /// used elsewhere to restrict what gets printed — confirm.
        filter: TensorFilter,
        /// When `true`, passing tests and compile errors are accepted rather
        /// than rejected/printed.
        fail_only: bool,
    },
    /// Reject a test that actually passes; accept every other outcome.
    FailIfRun,
}
impl TestMode {
pub fn decide(&self, outcome: TestOutcome) -> TestDecision {
use TestDecision::*;
use TestMode::*;
use TestOutcome::*;
use ValidationResult::*;
match self {
Correct => match outcome {
Validated(result) => match result {
Pass => Accept,
Fail(reason) => Reject(reason),
Error(reason) => Reject(reason),
Skipped(_) => Accept,
},
CompileError(_) => Accept,
},
Strict => match outcome {
Validated(result) => match result {
Pass => Accept,
Fail(reason) => Reject(reason),
Error(reason) => Reject(reason),
Skipped(_) => Accept,
},
CompileError(reason) => Reject(reason),
},
Print {
filter: _,
fail_only,
} => match outcome {
Validated(result) => match result {
Pass => {
if *fail_only {
Accept
} else {
Reject("printed".into())
}
}
Fail(reason) => Reject(reason),
Error(reason) => Reject(reason),
Skipped(content) => Reject(content),
},
CompileError(reason) => {
if *fail_only {
Accept
} else {
Reject(reason)
}
}
},
FailIfRun => match outcome {
Validated(result) => match result {
Pass => Reject("Actually passed, but FailIfRun mode activated".to_string()),
Fail(_) => Accept,
Error(_) => Accept,
Skipped(_) => Accept,
},
CompileError(_) => Accept,
},
}
}
}
/// Reads the active [`TestMode`] from the `CUBE_TEST_MODE` environment variable.
///
/// Recognised values (case-insensitive, surrounding whitespace ignored):
/// - unset, empty, or `correct` → [`TestMode::Correct`]
/// - `strict` → [`TestMode::Strict`]
/// - `failifrun` → [`TestMode::FailIfRun`]
/// - `printall[:<filter>]` → [`TestMode::Print`] with `fail_only = false`
/// - `printfail[:<filter>]` → [`TestMode::Print`] with `fail_only = true`
///
/// Any other value, or a value that is not valid Unicode, is reported on
/// stderr and falls back to [`TestMode::Correct`]. This ensures that a
/// misspelled mode does not go unnoticed.
pub fn current_test_mode() -> TestMode {
    let raw = match std::env::var(CUBE_TEST_MODE_ENV) {
        Ok(v) => v,
        Err(std::env::VarError::NotPresent) => return TestMode::Correct,
        Err(e) => {
            eprintln!(
                "Invalid {} value: {}; falling back to the default test mode",
                CUBE_TEST_MODE_ENV, e
            );
            return TestMode::Correct;
        }
    };
    // The whole value (including any filter suffix) is lowercased, as before.
    let val = raw.trim().to_lowercase();

    if let Some(print_mode) = val.strip_prefix("printall") {
        parse_print_mode(print_mode, false)
    } else if let Some(print_mode) = val.strip_prefix("printfail") {
        parse_print_mode(print_mode, true)
    } else {
        match val.as_str() {
            "strict" => TestMode::Strict,
            "failifrun" => TestMode::FailIfRun,
            "" | "correct" => TestMode::Correct,
            other => {
                eprintln!(
                    "Unknown {} value '{}' (expected correct, strict, failifrun, \
                     printall[:filter] or printfail[:filter]); using correct",
                    CUBE_TEST_MODE_ENV, other
                );
                TestMode::Correct
            }
        }
    }
}
/// Builds a [`TestMode::Print`] from the text that follows `printall`/`printfail`.
///
/// `suffix` is expected to be empty (no filter) or `:<filter>`; `<filter>` is
/// handed to [`parse_tensor_filter`]. Two malformed cases are reported on
/// stderr and replaced by an empty filter, so the print mode still applies:
/// - a filter that fails to parse;
/// - a non-empty suffix that does not start with `:` (e.g. `printallxyz`).
fn parse_print_mode(suffix: &str, fail_only: bool) -> TestMode {
    let filter = match suffix.strip_prefix(':') {
        Some(rest) => parse_tensor_filter(rest).unwrap_or_else(|e| {
            eprintln!("Invalid print filter '{}': {}", rest, e);
            vec![]
        }),
        None => {
            // Previously accepted silently; warn so a typo in the mode is visible.
            if !suffix.is_empty() {
                eprintln!(
                    "Ignoring unexpected print mode suffix '{}' (expected ':<filter>')",
                    suffix
                );
            }
            vec![]
        }
    };
    TestMode::Print { filter, fail_only }
}