ktest/test/
runner.rs

1use conquer_once::spin::OnceCell;
2use heapless::{format, String};
3use spin::RwLock;
4use crate::{args, qemu, serial_print, serial_println, test::{self, outcome::TestResult, Ignore, ShouldPanic, TestCase}, MAX_STRING_LENGTH};
5
/// A static reference to the list of test functions to run. This is unsafe but only set
/// once at the start of `runner`; every other use in this file copies the slice reference
/// out by value inside an `unsafe` block. The static nature of the tests makes it
/// impossible to use OnceCell, Mutex, or RwLock here (at least their no_std variants).
static mut TESTS: &'static [&'static dyn TestCase] = &[];

/// The global test runner instance. This is initialized once at the start of `runner`.
pub static TEST_RUNNER: OnceCell<KernelTestRunner> = OnceCell::uninit();

/// Tracker for the current test index (corresponding to the index in `TESTS`).
/// Starts at 0 and is advanced through `increment_test_index`.
pub static CURRENT_TEST_INDEX: OnceCell<RwLock<usize>> = OnceCell::new(RwLock::new(0));

/// Tracker for the current module name, to print headers when it changes.
/// Starts as the empty string and is overwritten by `start_test`.
pub static CURRENT_MODULE: OnceCell<RwLock<&'static str>> = OnceCell::new(RwLock::new(""));
19
20/// A test runner that runs the given tests and exits QEMU after completion.
21/// 
22/// Output from this runner is formatted as line-delimited JSON and printed to the debug 
23/// console. This allows for easy parsing of test results by external tools, such as `kboot`.
24pub fn runner(tests: &'static [&'static dyn TestCase]) -> ! {
25    unsafe { TESTS = tests; }
26
27    TEST_RUNNER.get_or_init(|| KernelTestRunner::default());
28    TEST_RUNNER.get().unwrap().run_tests(0)
29}
30
31/// A trait defining the behavior of a test runner.
32pub trait TestRunner {
33    /// Runs once before all tests.
34    fn before_tests(&self);
35    /// Runs all tests starting from the given index.
36    fn run_tests(&self, start_index: usize) -> !;
37    /// Runs once after all tests.
38    fn after_tests(&self) -> !;
39    /// Called at the start of each test, returns the starting cycle number.
40    fn start_test(&self) -> u64;
41    /// Called at the end of each test, with the result and starting cycle number.
42    fn complete_test(&self, result: TestResult, cycle_start: u64);
43    /// Returns the currently running test, if any.
44    fn current_test(&self) -> Option<&'static dyn TestCase>;
45    /// Called when a test panics. This should print the panic information, mark the current
46    /// test as failed, and continue with the next test (if possible).
47    fn handle_panic(&self, info: &core::panic::PanicInfo) -> !;
48}
49
/// A kernel test runner that runs all tests sequentially and exits QEMU after completion.
///
/// This is a unit struct with no fields, so it is trivially copyable and printable.
#[derive(Debug, Default, Clone, Copy)]
pub struct KernelTestRunner;
53
54impl TestRunner for KernelTestRunner {
55    fn before_tests(&self) {
56        let test_group = args::get_test_group().unwrap_or("default");
57        let tests = unsafe { TESTS };
58
59        test::output::write_test_group(test_group, tests.len());
60        test::output::write_test_names(tests);
61    }
62
63    fn run_tests(&self, start_index: usize) -> ! {
64        if start_index == 0 { // dont run before_tests if resuming
65            self.before_tests();
66        }
67        
68        let tests = unsafe { TESTS };
69        for (i, &test) in tests.iter().enumerate().skip(start_index) {
70            let cycle_start = self.start_test();
71
72            match test.ignore() {
73                Ignore::No => {
74                    test.run();
75                    self.complete_test(TestResult::Success, cycle_start);
76                }
77                Ignore::Yes => {
78                    self.complete_test(TestResult::Ignore, cycle_start);
79                }
80            }
81
82            if !increment_test_index(i) {
83                break; // no more tests to run
84            }
85        }
86        self.after_tests()
87    }
88
89    fn after_tests(&self) -> ! {
90        qemu::exit(qemu::ExitCode::Success)
91    }
92
93    fn start_test(&self) -> u64 {
94        // check if the current module has changed; if so, reassign it and print a header
95        let current_test = self.current_test().unwrap();
96        
97        let module_path = current_test.modules().unwrap_or("unknown_module");
98        {
99            let mut current_module = CURRENT_MODULE.get().unwrap().write();
100            if *current_module != module_path {
101                *current_module = module_path;
102
103                let module_test_count = count_by_module(module_path);
104                let test_group = args::get_test_group().unwrap_or("default");
105                serial_println!("\n################################################################");
106                serial_println!("# Running {} {} tests for module: {}", module_test_count, test_group, module_path);
107                serial_println!("----------------------------------------------------------------");
108            }
109        } // scope will release the lock here
110
111        // print the test name with padding for aligned results
112        print_test_name(current_test.name(), 58);
113
114        // return the current cycle (for duration calculation later)
115        read_current_cycle()
116    }
117
118    fn complete_test(&self, result: TestResult, cycle_start: u64) {
119        let cycle_count = if cycle_start != u64::MAX { // u64::MAX = unknown
120            read_current_cycle() - cycle_start
121        } else {
122            0
123        };
124
125        match result {
126            TestResult::Success => {
127                let current_test = self.current_test().unwrap();
128                let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
129                test::output::write_test_success(&test_name, cycle_count);
130                serial_println!("[pass]");
131            }
132            TestResult::Failure => {
133                // panic handler will print [fail] with details (and same for JSON output)
134            }
135            TestResult::Ignore => {
136                let current_test = self.current_test().unwrap();
137                let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
138                test::output::write_test_ignore(&test_name);
139                serial_println!("[ignore]");
140            }
141        }
142    }
143
144    fn current_test(&self) -> Option<&'static dyn TestCase> {
145        let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
146        let tests = unsafe { TESTS };
147        tests.get(current_index).copied()
148    }
149
150    fn handle_panic(&self, info: &core::panic::PanicInfo) -> ! {
151        // finish the test output, replaces [pass] with panic details
152        let location = if let Some(location) = info.location() {
153            format!("{}:{}", location.file(), location.line()).unwrap()
154        } else {
155            String::<MAX_STRING_LENGTH>::try_from("unknown location").unwrap()
156        };
157        let message = info.message().as_str().unwrap_or("no message");
158
159        let current_test = self.current_test().unwrap();
160        let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
161
162        // handle according to whether the test was expected to panic
163        match current_test.should_panic() {
164            ShouldPanic::No => {
165                serial_println!("[fail] @ {}: {}", location, message); // expected that the line already has "test_name... "
166                test::output::write_test_failure(&test_name, location.as_str(), message);
167                self.complete_test(TestResult::Failure, u64::MAX);
168            }
169            ShouldPanic::Yes => {
170                self.complete_test(TestResult::Success, u64::MAX);
171            }
172        }
173
174        // increment the test index to move to the next test (if possible)
175        let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
176        if !increment_test_index(current_index) {
177            qemu::exit(qemu::ExitCode::Success); // no more tests to run
178        }
179
180        // continue with the next test (and all thereafter)
181        self.run_tests(current_index + 1) // continue with next test
182    }
183}
184
185/// Helper function to read the current CPU cycle count using the RDTSC instruction.
fn read_current_cycle() -> u64 {
    use core::arch::x86_64::_rdtsc;

    // SAFETY: `_rdtsc` takes no arguments and touches no memory; it only returns the
    // processor's time-stamp counter value.
    let cycles = unsafe { _rdtsc() };
    cycles
}
189
190/// Helper function to assign base + 1 to CURRENT_TEST_INDEX.
191fn increment_test_index(base: usize) -> bool {
192    let mut current_test_index = CURRENT_TEST_INDEX.get().unwrap().write();
193    
194    let tests = unsafe { TESTS };
195    if *current_test_index >= tests.len() {
196        return false; // no more tests to run
197    }
198    
199    *current_test_index = base + 1;
200    true
201}
202
203/// Helper function to count the number of tests in a given module.
204fn count_by_module(module_name: &str) -> usize {
205    let tests = unsafe { TESTS };
206    tests.iter()
207        .filter(|&&test| test.modules().unwrap_or("") == module_name)
208        .count()
209}
210
211/// Helper to write function names with padding for aligned results
212fn print_test_name(name: &str, result_column: usize) {
213    if name.len() >= result_column {
214        serial_print!("{} ", name); // no padding if name is too long
215        return;
216    }
217
218    let padding = result_column - name.len();
219    serial_print!("{}", name);
220    for _ in 0..padding {
221        serial_print!(" ");
222    }
223}