1use conquer_once::spin::OnceCell;
2use heapless::{format, String};
3use spin::RwLock;
4use crate::{args, qemu, serial_print, serial_println, test::{self, outcome::TestResult, Ignore, ShouldPanic, TestCase}, MAX_STRING_LENGTH};
5
/// Test cases handed to [`runner`]. Assigned by `runner` before any test runs and
/// read (by value) by the rest of the runner.
// NOTE(review): `static mut` access is unsynchronized; this assumes the test harness
// runs on a single core with no concurrent access — TODO confirm.
static mut TESTS: &'static [&'static dyn TestCase] = &[];

/// Global runner instance, created lazily by [`runner`].
pub static TEST_RUNNER: OnceCell<KernelTestRunner> = OnceCell::uninit();

/// Index into `TESTS` of the test currently being executed; advanced by
/// `increment_test_index`.
pub static CURRENT_TEST_INDEX: OnceCell<RwLock<usize>> = OnceCell::new(RwLock::new(0));

/// Module path of the last module header printed by `start_test`; a new header is
/// printed only when the running test's module differs from this value.
pub static CURRENT_MODULE: OnceCell<RwLock<&'static str>> = OnceCell::new(RwLock::new(""));
20pub fn runner(tests: &'static [&'static dyn TestCase]) -> ! {
25 unsafe { TESTS = tests; }
26
27 TEST_RUNNER.get_or_init(|| KernelTestRunner::default());
28 TEST_RUNNER.get().unwrap().run_tests(0)
29}
30
/// Drives execution of the registered kernel tests and reports their results.
pub trait TestRunner {
    /// Hook invoked before the first test is run.
    fn before_tests(&self);
    /// Runs the registered tests beginning at `start_index`; never returns.
    fn run_tests(&self, start_index: usize) -> !;
    /// Hook invoked once no tests remain; never returns.
    fn after_tests(&self) -> !;
    /// Announces the current test and returns a start timestamp to be passed to
    /// `complete_test`.
    fn start_test(&self) -> u64;
    /// Records `result` for the current test. In `KernelTestRunner`, a `cycle_start` of
    /// `u64::MAX` means no start timestamp is available and the cycle count is reported
    /// as 0.
    fn complete_test(&self, result: TestResult, cycle_start: u64);
    /// Returns the test currently being executed, or `None` if the index is past the end.
    fn current_test(&self) -> Option<&'static dyn TestCase>;
    /// Handles a panic raised by the current test; never returns.
    fn handle_panic(&self, info: &core::panic::PanicInfo) -> !;
}
49
/// Default [`TestRunner`] used by [`runner`]: prints progress over the serial port and
/// exits QEMU once all tests have been processed.
#[derive(Default)]
pub struct KernelTestRunner;
53
54impl TestRunner for KernelTestRunner {
55 fn before_tests(&self) {
56 let test_group = args::get_test_group().unwrap_or("default");
57 let tests = unsafe { TESTS };
58
59 test::output::write_test_group(test_group, tests.len());
60 test::output::write_test_names(tests);
61 }
62
63 fn run_tests(&self, start_index: usize) -> ! {
64 if start_index == 0 { self.before_tests();
66 }
67
68 let tests = unsafe { TESTS };
69 for (i, &test) in tests.iter().enumerate().skip(start_index) {
70 let cycle_start = self.start_test();
71
72 match test.ignore() {
73 Ignore::No => {
74 test.run();
75 self.complete_test(TestResult::Success, cycle_start);
76 }
77 Ignore::Yes => {
78 self.complete_test(TestResult::Ignore, cycle_start);
79 }
80 }
81
82 if !increment_test_index(i) {
83 break; }
85 }
86 self.after_tests()
87 }
88
89 fn after_tests(&self) -> ! {
90 qemu::exit(qemu::ExitCode::Success)
91 }
92
93 fn start_test(&self) -> u64 {
94 let current_test = self.current_test().unwrap();
96
97 let module_path = current_test.modules().unwrap_or("unknown_module");
98 {
99 let mut current_module = CURRENT_MODULE.get().unwrap().write();
100 if *current_module != module_path {
101 *current_module = module_path;
102
103 let module_test_count = count_by_module(module_path);
104 let test_group = args::get_test_group().unwrap_or("default");
105 serial_println!("\n################################################################");
106 serial_println!("# Running {} {} tests for module: {}", module_test_count, test_group, module_path);
107 serial_println!("----------------------------------------------------------------");
108 }
109 } print_test_name(current_test.name(), 58);
113
114 read_current_cycle()
116 }
117
118 fn complete_test(&self, result: TestResult, cycle_start: u64) {
119 let cycle_count = if cycle_start != u64::MAX { read_current_cycle() - cycle_start
121 } else {
122 0
123 };
124
125 match result {
126 TestResult::Success => {
127 let current_test = self.current_test().unwrap();
128 let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
129 test::output::write_test_success(&test_name, cycle_count);
130 serial_println!("[pass]");
131 }
132 TestResult::Failure => {
133 }
135 TestResult::Ignore => {
136 let current_test = self.current_test().unwrap();
137 let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
138 test::output::write_test_ignore(&test_name);
139 serial_println!("[ignore]");
140 }
141 }
142 }
143
144 fn current_test(&self) -> Option<&'static dyn TestCase> {
145 let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
146 let tests = unsafe { TESTS };
147 tests.get(current_index).copied()
148 }
149
150 fn handle_panic(&self, info: &core::panic::PanicInfo) -> ! {
151 let location = if let Some(location) = info.location() {
153 format!("{}:{}", location.file(), location.line()).unwrap()
154 } else {
155 String::<MAX_STRING_LENGTH>::try_from("unknown location").unwrap()
156 };
157 let message = info.message().as_str().unwrap_or("no message");
158
159 let current_test = self.current_test().unwrap();
160 let test_name: String<MAX_STRING_LENGTH> = format!("{}::{}", current_test.modules().unwrap(), current_test.name()).unwrap();
161
162 match current_test.should_panic() {
164 ShouldPanic::No => {
165 serial_println!("[fail] @ {}: {}", location, message); test::output::write_test_failure(&test_name, location.as_str(), message);
167 self.complete_test(TestResult::Failure, u64::MAX);
168 }
169 ShouldPanic::Yes => {
170 self.complete_test(TestResult::Success, u64::MAX);
171 }
172 }
173
174 let current_index = *CURRENT_TEST_INDEX.get().unwrap().read();
176 if !increment_test_index(current_index) {
177 qemu::exit(qemu::ExitCode::Success); }
179
180 self.run_tests(current_index + 1) }
183}
184
/// Returns the current value of the x86_64 time-stamp counter, used to measure how many
/// cycles a test took.
fn read_current_cycle() -> u64 {
    use core::arch::x86_64::_rdtsc;

    // SAFETY: `_rdtsc` only reads the time-stamp counter; its sole requirement is an
    // x86_64 target, which this import already restricts the function to.
    unsafe { _rdtsc() }
}
189
190fn increment_test_index(base: usize) -> bool {
192 let mut current_test_index = CURRENT_TEST_INDEX.get().unwrap().write();
193
194 let tests = unsafe { TESTS };
195 if *current_test_index >= tests.len() {
196 return false; }
198
199 *current_test_index = base + 1;
200 true
201}
202
203fn count_by_module(module_name: &str) -> usize {
205 let tests = unsafe { TESTS };
206 tests.iter()
207 .filter(|&&test| test.modules().unwrap_or("") == module_name)
208 .count()
209}
210
211fn print_test_name(name: &str, result_column: usize) {
213 if name.len() >= result_column {
214 serial_print!("{} ", name); return;
216 }
217
218 let padding = result_column - name.len();
219 serial_print!("{}", name);
220 for _ in 0..padding {
221 serial_print!(" ");
222 }
223}