use conquer_once::spin::OnceCell;
use heapless::{format, String};
use spin::RwLock;
use crate::{MAX_STRING_LENGTH, args, qemu, serial_print, serial_println, test::{self, Ignore, ShouldPanic, TestCase, outcome::TestResult}};
/// Tests registered by [`runner`]; empty until `runner` stores the slice.
///
/// Only written by `runner` in this file; every other access copies the slice
/// reference out of the `static mut`.
/// NOTE(review): soundness relies on `runner` writing this before any reader
/// runs (single-threaded test start-up) — confirm.
static mut TESTS: &'static [&'static dyn TestCase] = &[];
/// Global runner instance, lazily initialised by [`runner`].
pub static TEST_RUNNER: OnceCell<KernelTestRunner> = OnceCell::uninit();
/// Index into `TESTS` of the test currently being run.
pub static CURRENT_TEST_INDEX: OnceCell<RwLock<usize>> = OnceCell::new(RwLock::new(0));
/// Module path of the most recently started test; a module banner is printed
/// whenever this changes.
pub static CURRENT_MODULE: OnceCell<RwLock<&'static str>> = OnceCell::new(RwLock::new(""));
/// Entry point of the custom test framework.
///
/// Stores `tests` in the module-level registry, initialises the global
/// [`KernelTestRunner`] and runs every test from index 0. Never returns.
pub fn runner(tests: &'static [&'static dyn TestCase]) -> ! {
    // SAFETY: assumes this is the only writer of `TESTS` and that it runs before
    // any reader, during single-threaded test start-up — TODO confirm.
    unsafe {
        TESTS = tests;
    }
    let active = TEST_RUNNER.get_or_init(KernelTestRunner::default);
    active.run_tests(0)
}
/// Lifecycle hooks that drive the kernel's custom test harness.
///
/// Methods returning `!` never hand control back to their caller; in
/// `KernelTestRunner` they all end by exiting QEMU.
pub trait TestRunner {
/// Called before the first test runs (from `run_tests` when it starts at 0).
fn before_tests(&self);
/// Runs the registered tests starting at `start_index`, then finishes via
/// `after_tests`. Never returns.
fn run_tests(&self, start_index: usize) -> !;
/// Called after the last test. Never returns.
fn after_tests(&self) -> !;
/// Announces the current test and returns its start timestamp (a TSC reading in
/// `KernelTestRunner`).
fn start_test(&self) -> u64;
/// Reports `result` for the current test. `cycle_start` is the value returned by
/// `start_test`; `KernelTestRunner` treats `u64::MAX` as "no start time" and
/// reports a duration of 0.
fn complete_test(&self, result: TestResult, cycle_start: u64);
/// Returns the test at the current index, or `None` when the index is past the
/// end of the registered tests.
fn current_test(&self) -> Option<&'static dyn TestCase>;
/// Records a panic as the current test's outcome and continues with the
/// remaining tests. Never returns.
/// NOTE(review): presumably invoked from the crate's `#[panic_handler]` — confirm.
fn handle_panic(&self, info: &core::panic::PanicInfo) -> !;
}
/// Stateless [`TestRunner`] used by the kernel test framework.
///
/// The runner holds no fields; its progress lives in the module-level statics
/// (`TESTS`, `CURRENT_TEST_INDEX`, `CURRENT_MODULE`).
#[derive(Default)]
pub struct KernelTestRunner;
impl TestRunner for KernelTestRunner {
    /// Announces the test group from the kernel arguments (or `"default"`)
    /// together with the total number of registered tests.
    fn before_tests(&self) {
        let test_group = args::get_test_group().unwrap_or("default");
        let tests = unsafe { TESTS };
        test::output::write_test_group(test_group, tests.len());
    }

    /// Runs the registered tests from `start_index` onwards, then exits QEMU.
    ///
    /// `runner` calls this with `0`. `handle_panic` calls it with the index after
    /// the panicking test so the remaining tests still run. The group header is
    /// only written on the initial (`start_index == 0`) call.
    fn run_tests(&self, start_index: usize) -> ! {
        if start_index == 0 {
            self.before_tests();
        }
        let tests = unsafe { TESTS };
        for (i, &case) in tests.iter().enumerate().skip(start_index) {
            let cycle_start = self.start_test();
            match case.ignore() {
                Ignore::No => {
                    case.run();
                    // Reaching this line means the test returned without
                    // panicking, which is a failure for a `should_panic` test.
                    match case.should_panic() {
                        ShouldPanic::No => self.complete_test(TestResult::Success, cycle_start),
                        ShouldPanic::Yes => {
                            let test_name = full_test_name(case);
                            serial_println!("[fail] {}", DID_NOT_PANIC);
                            test::output::write_test_failure(
                                &test_name,
                                UNKNOWN_LOCATION,
                                DID_NOT_PANIC,
                            );
                            self.complete_test(TestResult::Failure, cycle_start);
                        }
                    }
                }
                Ignore::Yes => self.complete_test(TestResult::Ignore, cycle_start),
            }
            if !increment_test_index(i) {
                break;
            }
        }
        self.after_tests()
    }

    /// Exits QEMU with the success exit code.
    ///
    /// Failures are reported through the test output, not through the exit code.
    fn after_tests(&self) -> ! {
        qemu::exit(qemu::ExitCode::Success)
    }

    /// Prints a module banner when the module changes, then prints the test name
    /// padded to the result column. Returns the TSC value at the start of the test.
    fn start_test(&self) -> u64 {
        let current_test = self.current_test().unwrap();
        let module_path = current_test.modules().unwrap_or(UNKNOWN_MODULE);
        {
            // Scoped so the write lock is released before the test body runs.
            let mut current_module = CURRENT_MODULE.get().unwrap().write();
            if *current_module != module_path {
                *current_module = module_path;
                let module_test_count = count_by_module(module_path);
                let test_group = args::get_test_group().unwrap_or("default");
                serial_println!("\n################################################################");
                serial_println!("# Running {} {} tests for module: {}", module_test_count, test_group, module_path);
                serial_println!("----------------------------------------------------------------");
            }
        }
        print_test_name(current_test.name(), 58);
        read_current_cycle()
    }

    /// Records the outcome of the current test.
    ///
    /// `cycle_start == u64::MAX` means no start timestamp is available (the panic
    /// handler passes this), and the reported duration is then 0. `Failure`
    /// produces no output here: callers print and record failures themselves
    /// before calling this.
    fn complete_test(&self, result: TestResult, cycle_start: u64) {
        let cycle_count = if cycle_start == u64::MAX {
            0
        } else {
            // Saturate so a counter that appears to run backwards cannot underflow.
            read_current_cycle().saturating_sub(cycle_start)
        };
        match result {
            TestResult::Success => {
                let test_name = full_test_name(self.current_test().unwrap());
                test::output::write_test_success(&test_name, cycle_count);
                serial_println!("[pass]");
            }
            TestResult::Failure => {}
            TestResult::Ignore => {
                let test_name = full_test_name(self.current_test().unwrap());
                test::output::write_test_ignore(&test_name);
                serial_println!("[ignore]");
            }
        }
    }

    /// Returns the test at `CURRENT_TEST_INDEX`, or `None` once the index has
    /// moved past the last registered test.
    fn current_test(&self) -> Option<&'static dyn TestCase> {
        let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
        let tests = unsafe { TESTS };
        tests.get(current_index).copied()
    }

    /// Reports a panic as the current test's result and resumes with the next
    /// test.
    ///
    /// A panic in a `should_panic` test counts as a pass; any other panic is a
    /// failure. Every string is built with a bounded, non-panicking formatter,
    /// because a second panic here would re-enter the handler.
    fn handle_panic(&self, info: &core::panic::PanicInfo) -> ! {
        let location = match info.location() {
            Some(loc) => bounded(format_args!("{}:{}", loc.file(), loc.line())),
            None => bounded(format_args!("{}", UNKNOWN_LOCATION)),
        };
        // `PanicMessage::as_str` is `None` for formatted messages; render through
        // `Display` so the actual text is reported.
        let rendered = bounded(format_args!("{}", info.message()));
        let message = if rendered.is_empty() { "no message" } else { rendered.as_str() };

        let Some(current_test) = self.current_test() else {
            // The panic happened outside any test (the index is past the end), so
            // there is no test to blame. Report it and finish the run.
            serial_println!("[panic] @ {}: {}", location, message);
            self.after_tests()
        };
        let test_name = full_test_name(current_test);
        match current_test.should_panic() {
            ShouldPanic::No => {
                serial_println!("[fail] @ {}: {}", location, message);
                test::output::write_test_failure(&test_name, location.as_str(), message);
                self.complete_test(TestResult::Failure, u64::MAX);
            }
            ShouldPanic::Yes => self.complete_test(TestResult::Success, u64::MAX),
        }
        let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
        if !increment_test_index(current_index) {
            qemu::exit(qemu::ExitCode::Success);
        }
        self.run_tests(current_index + 1)
    }
}

/// Label used for tests that report no module path.
const UNKNOWN_MODULE: &str = "unknown_module";
/// Location reported when no source location is available.
const UNKNOWN_LOCATION: &str = "unknown location";
/// Failure message for a `should_panic` test that returned normally.
const DID_NOT_PANIC: &str = "test did not panic as expected";

/// Builds the `module::name` identifier used in test reports.
///
/// Falls back to `UNKNOWN_MODULE` when the test reports no module path, and
/// truncates the identifier if it is longer than the string capacity.
fn full_test_name(case: &dyn TestCase) -> String<MAX_STRING_LENGTH> {
    bounded(format_args!(
        "{}::{}",
        case.modules().unwrap_or(UNKNOWN_MODULE),
        case.name()
    ))
}

/// Formats `args` into a fixed-capacity string without ever failing.
///
/// If the text does not fit, the longest prefix that fits is kept.
fn bounded(args: core::fmt::Arguments<'_>) -> String<MAX_STRING_LENGTH> {
    format!("{}", args).unwrap_or_else(|_| {
        let mut out = String::new();
        // An `Err` here only signals that the buffer filled up; keep the prefix.
        let _ = core::fmt::Write::write_fmt(&mut PrefixWriter(&mut out), args);
        out
    })
}

/// `fmt::Write` adapter that appends characters until the buffer is full, then
/// returns an error so that formatting stops.
struct PrefixWriter<'a>(&'a mut String<MAX_STRING_LENGTH>);

impl core::fmt::Write for PrefixWriter<'_> {
    fn write_str(&mut self, s: &str) -> core::fmt::Result {
        for c in s.chars() {
            self.0.push(c).map_err(|_| core::fmt::Error)?;
        }
        Ok(())
    }
}
/// Returns the current value of the x86-64 time-stamp counter (`RDTSC`).
fn read_current_cycle() -> u64 {
    use core::arch::x86_64::_rdtsc;
    // SAFETY: `RDTSC` only reads the time-stamp counter register.
    unsafe { _rdtsc() }
}
/// Moves `CURRENT_TEST_INDEX` to `base + 1`.
///
/// Returns `false`, leaving the index untouched, if the stored index is already
/// at or past the number of registered tests. Otherwise returns `true`.
fn increment_test_index(base: usize) -> bool {
    let total = unsafe { TESTS }.len();
    let mut index = CURRENT_TEST_INDEX.get().unwrap().write();
    if *index < total {
        *index = base + 1;
        true
    } else {
        false
    }
}
/// Counts the registered tests whose module path equals `module_name`.
///
/// Tests that report no module path are counted under `"unknown_module"`.
/// `start_test` prints that same label for them, so the banner's per-module
/// count matches the tests that actually run under it.
fn count_by_module(module_name: &str) -> usize {
    let tests = unsafe { TESTS };
    tests
        .iter()
        .filter(|case| case.modules().unwrap_or("unknown_module") == module_name)
        .count()
}
/// Prints `name`, padded with spaces to `result_column` bytes, so the result
/// printed next lines up.
///
/// A name that already reaches or exceeds the column is followed by a single
/// space instead.
fn print_test_name(name: &str, result_column: usize) {
    match result_column.checked_sub(name.len()) {
        Some(padding) if padding > 0 => serial_print!("{}{:width$}", name, "", width = padding),
        _ => serial_print!("{} ", name),
    }
}