1#![allow(clippy::unwrap_used)]
7
8use super::ptx_analysis::{PtxBugClass, PtxValidationResult};
9use std::time::{Duration, Instant};
10
/// Options passed to [`GpuPixelTestSuite::run_kernel_pixels`].
#[derive(Debug, Clone)]
pub struct KernelPixelConfig {
    /// Flag for degenerate-dimension tests.
    ///
    /// NOTE(review): nothing in this file reads this flag; confirm where it is consumed.
    pub test_degenerate_dims: bool,
    /// Flag for boundary tests.
    ///
    /// NOTE(review): nothing in this file reads this flag; confirm where it is consumed.
    pub test_boundaries: bool,
    /// Forwarded to the loop-structure check as its `strict` argument; when
    /// `false` that check passes without inspecting the PTX.
    pub strict_ptx: bool,
    /// Time budget for a run.
    ///
    /// NOTE(review): not enforced by any code in this file; confirm the consumer.
    pub timeout: Duration,
}
23
24impl Default for KernelPixelConfig {
25 fn default() -> Self {
26 Self {
27 test_degenerate_dims: true,
28 test_boundaries: true,
29 strict_ptx: true,
30 timeout: Duration::from_secs(5),
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct GpuPixelResult {
38 pub name: String,
40 pub passed: bool,
42 pub error: Option<String>,
44 pub duration: Duration,
46 pub bug_class: Option<PtxBugClass>,
48}
49
50impl GpuPixelResult {
51 #[must_use]
53 pub fn pass(name: &str, duration: Duration) -> Self {
54 Self {
55 name: name.to_string(),
56 passed: true,
57 error: None,
58 duration,
59 bug_class: None,
60 }
61 }
62
63 #[must_use]
65 pub fn fail(name: &str, error: &str, duration: Duration) -> Self {
66 Self {
67 name: name.to_string(),
68 passed: false,
69 error: Some(error.to_string()),
70 duration,
71 bug_class: None,
72 }
73 }
74
75 #[must_use]
77 pub fn fail_with_bug(name: &str, error: &str, bug: PtxBugClass, duration: Duration) -> Self {
78 Self {
79 name: name.to_string(),
80 passed: false,
81 error: Some(error.to_string()),
82 duration,
83 bug_class: Some(bug),
84 }
85 }
86
87 #[must_use]
89 pub fn from_ptx_validation(result: &PtxValidationResult) -> Self {
90 let start = Instant::now();
91 if result.is_valid() {
92 Self::pass("ptx_validation", start.elapsed())
93 } else {
94 let first_bug = result.bugs.first();
95 let error = first_bug
96 .map(|b| format!("{}: {}", b.class, b.message))
97 .unwrap_or_else(|| "Unknown PTX error".to_string());
98 let bug_class = first_bug.map(|b| b.class.clone());
99 Self {
100 name: "ptx_validation".to_string(),
101 passed: false,
102 error: Some(error),
103 duration: start.elapsed(),
104 bug_class,
105 }
106 }
107 }
108}
109
110#[derive(Debug, Clone)]
112pub struct GpuPixelTest {
113 pub name: String,
115 pub description: String,
117 pub catches: PtxBugClass,
119}
120
121impl GpuPixelTest {
122 #[must_use]
124 pub fn new(name: &str, description: &str, catches: PtxBugClass) -> Self {
125 Self {
126 name: name.to_string(),
127 description: description.to_string(),
128 catches,
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
135pub struct GpuPixelTestSuite {
136 pub kernel_name: String,
138 pub results: Vec<GpuPixelResult>,
140 pub duration: Duration,
142}
143
144impl GpuPixelTestSuite {
145 #[must_use]
147 pub fn new(kernel_name: &str) -> Self {
148 Self {
149 kernel_name: kernel_name.to_string(),
150 results: Vec::new(),
151 duration: Duration::ZERO,
152 }
153 }
154
155 pub fn add_result(&mut self, result: GpuPixelResult) {
157 self.duration += result.duration;
158 self.results.push(result);
159 }
160
161 #[must_use]
163 pub fn all_passed(&self) -> bool {
164 self.results.iter().all(|r| r.passed)
165 }
166
167 #[must_use]
169 pub fn passed_count(&self) -> usize {
170 self.results.iter().filter(|r| r.passed).count()
171 }
172
173 #[must_use]
175 pub fn failed_count(&self) -> usize {
176 self.results.iter().filter(|r| !r.passed).count()
177 }
178
179 #[must_use]
181 pub fn failures(&self) -> Vec<&GpuPixelResult> {
182 self.results.iter().filter(|r| !r.passed).collect()
183 }
184
185 pub fn run_kernel_pixels(&mut self, ptx: &str, config: &KernelPixelConfig) {
187 let start = Instant::now();
188
189 self.add_result(self.pixel_shared_mem_addressing(ptx));
191
192 self.add_result(self.pixel_kernel_entry_exists(ptx));
194
195 self.add_result(self.pixel_loop_structure(ptx, config.strict_ptx));
197
198 if ptx.contains(".shared") {
200 self.add_result(self.pixel_barrier_sync(ptx));
201 }
202
203 self.duration = start.elapsed();
204 }
205
206 fn pixel_shared_mem_addressing(&self, ptx: &str) -> GpuPixelResult {
208 let start = Instant::now();
209 let regex = regex::Regex::new(r"[sl]t\.shared\.[^\[]+\[%rd\d+\]").unwrap();
210
211 if regex.is_match(ptx) {
212 GpuPixelResult::fail_with_bug(
213 "shared_mem_u32_addressing",
214 "Shared memory uses 64-bit addressing (should be 32-bit)",
215 PtxBugClass::SharedMemU64Addressing,
216 start.elapsed(),
217 )
218 } else {
219 GpuPixelResult::pass("shared_mem_u32_addressing", start.elapsed())
220 }
221 }
222
223 fn pixel_kernel_entry_exists(&self, ptx: &str) -> GpuPixelResult {
225 let start = Instant::now();
226 let regex = regex::Regex::new(r"\.visible\s+\.entry\s+\w+").unwrap();
227
228 if regex.is_match(ptx) {
229 GpuPixelResult::pass("kernel_entry_exists", start.elapsed())
230 } else {
231 GpuPixelResult::fail_with_bug(
232 "kernel_entry_exists",
233 "No kernel entry point found",
234 PtxBugClass::MissingEntryPoint,
235 start.elapsed(),
236 )
237 }
238 }
239
240 fn pixel_loop_structure(&self, ptx: &str, strict: bool) -> GpuPixelResult {
242 let start = Instant::now();
243
244 if !strict {
245 return GpuPixelResult::pass("loop_structure", start.elapsed());
246 }
247
248 let branch_regex = regex::Regex::new(r"^\s+bra\s+(\w*_end\w*);").unwrap();
250 for line in ptx.lines() {
251 if branch_regex.is_match(line) && !line.trim().starts_with('@') {
252 return GpuPixelResult::fail_with_bug(
253 "loop_structure",
254 "Unconditional branch to loop end (should branch to start)",
255 PtxBugClass::LoopBranchToEnd,
256 start.elapsed(),
257 );
258 }
259 }
260
261 GpuPixelResult::pass("loop_structure", start.elapsed())
262 }
263
264 fn pixel_barrier_sync(&self, ptx: &str) -> GpuPixelResult {
266 let start = Instant::now();
267
268 if ptx.contains("bar.sync") {
269 GpuPixelResult::pass("barrier_sync", start.elapsed())
270 } else {
271 GpuPixelResult::fail_with_bug(
272 "barrier_sync",
273 "Shared memory used but no bar.sync found",
274 PtxBugClass::MissingBarrierSync,
275 start.elapsed(),
276 )
277 }
278 }
279
280 #[must_use]
282 pub fn summary(&self) -> String {
283 let status = if self.all_passed() { "PASS" } else { "FAIL" };
284 format!(
285 "[{}] {} - {}/{} passed ({:?})",
286 status,
287 self.kernel_name,
288 self.passed_count(),
289 self.results.len(),
290 self.duration
291 )
292 }
293}
294
295pub fn standard_pixel_tests() -> Vec<GpuPixelTest> {
297 vec![
298 GpuPixelTest::new(
299 "shared_mem_u32_addressing",
300 "Verify shared memory uses 32-bit addressing",
301 PtxBugClass::SharedMemU64Addressing,
302 ),
303 GpuPixelTest::new(
304 "loop_branch_to_start",
305 "Verify loop branches go to start label, not end",
306 PtxBugClass::LoopBranchToEnd,
307 ),
308 GpuPixelTest::new(
309 "barrier_sync_present",
310 "Verify barrier sync exists when using shared memory",
311 PtxBugClass::MissingBarrierSync,
312 ),
313 GpuPixelTest::new(
314 "kernel_entry_exists",
315 "Verify kernel has entry point",
316 PtxBugClass::MissingEntryPoint,
317 ),
318 ]
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// A suite holding only passing results reports itself as fully passed.
    #[test]
    fn test_suite_all_passed() {
        let mut suite = GpuPixelTestSuite::new("test_kernel");
        for (name, millis) in [("test1", 1), ("test2", 2)] {
            suite.add_result(GpuPixelResult::pass(name, Duration::from_millis(millis)));
        }

        assert!(suite.all_passed());
        assert_eq!(suite.passed_count(), 2);
        assert_eq!(suite.failed_count(), 0);
    }

    /// One failing result is enough to flip `all_passed` and is counted separately.
    #[test]
    fn test_suite_has_failure() {
        let passing = GpuPixelResult::pass("test1", Duration::from_millis(1));
        let failing = GpuPixelResult::fail("test2", "error", Duration::from_millis(2));

        let mut suite = GpuPixelTestSuite::new("test_kernel");
        suite.add_result(passing);
        suite.add_result(failing);

        assert!(!suite.all_passed());
        assert_eq!(suite.passed_count(), 1);
        assert_eq!(suite.failed_count(), 1);
    }

    /// A shared store through a `%rd` register is reported with the matching bug class.
    #[test]
    fn test_pixel_shared_mem_u64_fails() {
        let outcome = GpuPixelTestSuite::new("test")
            .pixel_shared_mem_addressing("st.shared.f32 [%rd5], %f0;");

        assert!(!outcome.passed);
        assert_eq!(
            outcome.bug_class,
            Some(PtxBugClass::SharedMemU64Addressing)
        );
    }

    /// A shared store through a `%r` register is accepted.
    #[test]
    fn test_pixel_shared_mem_u32_passes() {
        let outcome = GpuPixelTestSuite::new("test")
            .pixel_shared_mem_addressing("st.shared.f32 [%r5], %f0;");

        assert!(outcome.passed);
    }

    /// The standard list is non-empty and includes the addressing test.
    #[test]
    fn test_standard_pixel_tests() {
        let names: Vec<String> = standard_pixel_tests()
            .into_iter()
            .map(|t| t.name)
            .collect();

        assert!(!names.is_empty());
        assert!(names.iter().any(|n| n == "shared_mem_u32_addressing"));
    }

    /// The summary line mentions both the status and the kernel name.
    #[test]
    fn test_summary_format() {
        let mut suite = GpuPixelTestSuite::new("gemm_tiled");
        suite.add_result(GpuPixelResult::pass("test1", Duration::from_millis(1)));

        let text = suite.summary();
        assert!(text.contains("PASS"));
        assert!(text.contains("gemm_tiled"));
    }
}