1use std::collections::HashMap;
6
7use super::BrickBottleneck;
8
/// Registry of PTX kernel sources, keyed by the FNV-1a hash of the source text.
///
/// Every entry keeps the kernel name, the PTX source itself, and (optionally)
/// the path of the file the source was associated with at registration time.
#[derive(Debug, Default)]
pub struct PtxRegistry {
    /// `hash_ptx(source)` -> `(kernel name, PTX source, optional source path)`.
    kernels: HashMap<u64, (String, String, Option<std::path::PathBuf>)>,
}

impl PtxRegistry {
    /// Creates a registry that holds no kernels.
    pub fn new() -> Self {
        Self {
            kernels: HashMap::new(),
        }
    }

    /// Stores `ptx` under its content hash together with `name` and `path`.
    ///
    /// Registering a source whose hash is already present replaces the
    /// previous entry (including its name and path).
    ///
    /// In debug builds, an empty `name` or `ptx` triggers an assertion failure.
    pub fn register(&mut self, name: &str, ptx: &str, path: Option<&std::path::Path>) {
        debug_assert!(!name.is_empty(), "CB-BUDGET: kernel name must not be empty");
        debug_assert!(!ptx.is_empty(), "CB-BUDGET: PTX source must not be empty");

        let key = Self::hash_ptx(ptx);
        let owned_path = path.map(std::path::Path::to_path_buf);
        let entry = (name.to_owned(), ptx.to_owned(), owned_path);
        self.kernels.insert(key, entry);
    }

    /// Computes the 64-bit FNV-1a hash of the bytes of `ptx`.
    #[inline]
    pub fn hash_ptx(ptx: &str) -> u64 {
        const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const PRIME: u64 = 0x100000001b3;

        // FNV-1a: xor the byte in first, then multiply by the prime.
        ptx.bytes()
            .fold(OFFSET_BASIS, |acc, b| (acc ^ u64::from(b)).wrapping_mul(PRIME))
    }

    /// Returns the PTX source registered under `hash`, if any.
    pub fn lookup(&self, hash: u64) -> Option<&str> {
        let (_, source, _) = self.kernels.get(&hash)?;
        Some(source.as_str())
    }

    /// Returns the kernel name registered under `hash`, if any.
    pub fn lookup_name(&self, hash: u64) -> Option<&str> {
        let (kernel_name, _, _) = self.kernels.get(&hash)?;
        Some(kernel_name.as_str())
    }

    /// Returns the source path registered under `hash`.
    ///
    /// `None` is returned both when the hash is unknown and when the kernel
    /// was registered without a path.
    pub fn lookup_path(&self, hash: u64) -> Option<&std::path::Path> {
        let (_, _, source_path) = self.kernels.get(&hash)?;
        source_path.as_deref()
    }

    /// Iterates over the hashes of all registered kernels (in map order).
    pub fn hashes(&self) -> impl Iterator<Item = u64> + '_ {
        self.kernels.iter().map(|(&hash, _)| hash)
    }

    /// Number of registered kernels.
    pub fn len(&self) -> usize {
        self.kernels.len()
    }

    /// `true` when no kernel has been registered.
    pub fn is_empty(&self) -> bool {
        self.kernels.is_empty()
    }
}
80
/// Accumulated timing totals for one category of operations.
#[derive(Debug, Clone, Copy, Default)]
pub struct CategoryStats {
    /// Total elapsed time, in nanoseconds.
    pub total_ns: u64,
    /// Total number of elements processed.
    pub total_elements: u64,
    /// Number of samples contributing to the totals.
    pub count: u64,
}

impl CategoryStats {
    /// Mean time per sample in microseconds; `0.0` when there are no samples.
    #[inline]
    pub fn avg_us(&self) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        let avg_ns = self.total_ns as f64 / self.count as f64;
        avg_ns / 1000.0
    }

    /// Elements processed per second; `0.0` when no time has been recorded.
    #[inline]
    pub fn throughput(&self) -> f64 {
        match self.total_ns {
            0 => 0.0,
            ns => {
                let seconds = ns as f64 / 1_000_000_000.0;
                self.total_elements as f64 / seconds
            }
        }
    }

    /// Share of `total` (nanoseconds) taken by this category, as a percentage.
    ///
    /// Returns `0.0` when `total` is zero.
    #[inline]
    pub fn percentage(&self, total: u64) -> f64 {
        if total == 0 {
            return 0.0;
        }
        let scaled = 100.0 * self.total_ns as f64;
        scaled / total as f64
    }
}
123
124#[derive(Debug, Clone, Default)]
126pub struct BrickStats {
127 pub name: String,
129 pub count: u64,
131 pub total_ns: u64,
133 pub min_ns: u64,
135 pub max_ns: u64,
137 pub total_elements: u64,
139 pub total_bytes: u64,
141 pub total_compressed_bytes: u64,
143 pub bottleneck: BrickBottleneck,
145 pub total_cycles: u64,
147 pub min_cycles: u64,
149 pub max_cycles: u64,
151}
152
153impl BrickStats {
154 pub fn new(name: &str) -> Self {
156 Self {
157 name: name.to_string(),
158 count: 0,
159 total_ns: 0,
160 min_ns: u64::MAX,
161 max_ns: 0,
162 total_elements: 0,
163 total_bytes: 0,
164 total_compressed_bytes: 0,
165 bottleneck: BrickBottleneck::Unknown,
166 total_cycles: 0,
167 min_cycles: u64::MAX,
168 max_cycles: 0,
169 }
170 }
171
172 pub fn add_sample(&mut self, elapsed_ns: u64, elements: u64) {
174 debug_assert!(elements > 0, "CB-BUDGET: elements must be > 0");
175 self.count += 1;
176 self.total_ns += elapsed_ns;
177 self.min_ns = self.min_ns.min(elapsed_ns);
178 self.max_ns = self.max_ns.max(elapsed_ns);
179 self.total_elements += elements;
180 }
181
182 pub fn add_sample_with_cycles(&mut self, elapsed_ns: u64, elements: u64, cycles: u64) {
187 self.add_sample(elapsed_ns, elements);
188 self.total_cycles += cycles;
189 self.min_cycles = self.min_cycles.min(cycles);
190 self.max_cycles = self.max_cycles.max(cycles);
191 }
192
193 #[must_use]
197 pub fn cycles_per_element(&self) -> f64 {
198 if self.total_elements == 0 {
199 0.0
200 } else {
201 self.total_cycles as f64 / self.total_elements as f64
202 }
203 }
204
205 #[must_use]
207 pub fn avg_cycles(&self) -> f64 {
208 if self.count == 0 {
209 0.0
210 } else {
211 self.total_cycles as f64 / self.count as f64
212 }
213 }
214
215 #[must_use]
221 pub fn estimated_ipc(&self) -> f64 {
222 if self.total_cycles == 0 {
223 0.0
224 } else {
225 self.total_elements as f64 / self.total_cycles as f64
227 }
228 }
229
230 #[must_use]
235 pub fn diagnose_from_cycles(&self) -> &'static str {
236 if self.total_cycles == 0 || self.total_ns == 0 {
237 return "insufficient data";
238 }
239
240 let ipc = self.estimated_ipc();
241 let ns_per_cycle = self.total_ns as f64 / self.total_cycles as f64;
242
243 if ipc < 0.5 {
246 "memory-bound (low IPC, likely cache misses)"
247 } else if ipc > 2.0 {
248 "compute-bound (efficient)"
249 } else if ns_per_cycle > 1.0 {
250 "throttled or context-switched"
251 } else {
252 "balanced"
253 }
254 }
255
256 pub fn add_sample_with_bytes(
264 &mut self,
265 elapsed_ns: u64,
266 elements: u64,
267 input_bytes: u64,
268 output_bytes: u64,
269 ) {
270 self.add_sample(elapsed_ns, elements);
271 self.total_bytes += input_bytes;
272 self.total_compressed_bytes += output_bytes;
273 }
274
275 #[must_use]
278 pub fn compression_ratio(&self) -> f64 {
279 if self.total_compressed_bytes == 0 {
280 1.0
281 } else {
282 self.total_bytes as f64 / self.total_compressed_bytes as f64
283 }
284 }
285
286 #[must_use]
289 pub fn throughput_gbps(&self) -> f64 {
290 if self.total_ns == 0 {
291 0.0
292 } else {
293 let bytes_per_ns = self.total_bytes as f64 / self.total_ns as f64;
294 bytes_per_ns * 1e9 / 1e9 }
296 }
297
298 pub fn set_bottleneck(&mut self, bottleneck: BrickBottleneck) {
300 self.bottleneck = bottleneck;
301 }
302
303 #[must_use]
305 pub fn get_bottleneck(&self) -> BrickBottleneck {
306 self.bottleneck
307 }
308
309 #[must_use]
311 pub fn avg_us(&self) -> f64 {
312 if self.count == 0 {
313 0.0
314 } else {
315 self.total_ns as f64 / self.count as f64 / 1000.0
316 }
317 }
318
319 #[must_use]
321 pub fn throughput(&self) -> f64 {
322 if self.total_ns == 0 {
323 0.0
324 } else {
325 self.total_elements as f64 / (self.total_ns as f64 / 1_000_000_000.0)
326 }
327 }
328
329 #[must_use]
331 pub fn tokens_per_sec(&self) -> f64 {
332 self.throughput()
333 }
334
335 #[must_use]
337 pub fn min_us(&self) -> f64 {
338 if self.min_ns == u64::MAX {
339 0.0
340 } else {
341 self.min_ns as f64 / 1000.0
342 }
343 }
344
345 #[must_use]
347 pub fn max_us(&self) -> f64 {
348 self.max_ns as f64 / 1000.0
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Builds an "op" brick with the given totals set directly on the fields.
    fn cycle_stats(total_elements: u64, total_cycles: u64, total_ns: u64) -> BrickStats {
        BrickStats {
            total_elements,
            total_cycles,
            total_ns,
            ..BrickStats::new("op")
        }
    }

    // ---------------------------------------------------------------
    // PtxRegistry
    // ---------------------------------------------------------------

    /// A fresh registry contains nothing.
    #[test]
    fn test_ptx_registry_new_is_empty() {
        let registry = PtxRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    /// A registered kernel is retrievable by the hash of its source.
    #[test]
    fn test_ptx_registry_register_and_lookup() {
        let source = ".version 8.0\n.entry gemm_tiled {}";
        let mut registry = PtxRegistry::new();
        registry.register("gemm_tiled", source, None);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup(key), Some(source));
        assert_eq!(registry.lookup_name(key), Some("gemm_tiled"));
        assert_eq!(registry.lookup_path(key), None);
    }

    /// The optional source path is stored and returned.
    #[test]
    fn test_ptx_registry_register_with_path() {
        let source = ".version 8.0\n.entry softmax {}";
        let source_path = std::path::Path::new("/src/kernels/softmax.ptx");
        let mut registry = PtxRegistry::new();
        registry.register("softmax", source, Some(source_path));

        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup_path(key), Some(source_path));
    }

    /// Unknown hashes yield `None` from every lookup.
    #[test]
    fn test_ptx_registry_lookup_missing() {
        let registry = PtxRegistry::new();
        let unknown = 12345;
        assert_eq!(registry.lookup(unknown), None);
        assert_eq!(registry.lookup_name(unknown), None);
        assert_eq!(registry.lookup_path(unknown), None);
    }

    /// `hashes` yields one hash per registered kernel.
    #[test]
    fn test_ptx_registry_hashes() {
        let mut registry = PtxRegistry::new();
        for (name, source) in [("k1", "ptx_source_1"), ("k2", "ptx_source_2")] {
            registry.register(name, source, None);
        }

        let all: Vec<u64> = registry.hashes().collect();
        assert_eq!(all.len(), 2);
    }

    /// Hashing the same source twice gives the same value.
    #[test]
    fn test_ptx_registry_hash_deterministic() {
        let source = "some ptx source code";
        let first = PtxRegistry::hash_ptx(source);
        let second = PtxRegistry::hash_ptx(source);
        assert_eq!(first, second);
    }

    /// Different sources hash to different values.
    #[test]
    fn test_ptx_registry_hash_different_inputs() {
        let hash_a = PtxRegistry::hash_ptx("kernel_a");
        let hash_b = PtxRegistry::hash_ptx("kernel_b");
        assert_ne!(hash_a, hash_b);
    }

    /// Re-registering identical source replaces the existing entry.
    #[test]
    fn test_ptx_registry_overwrite_same_hash() {
        let source = "same_source";
        let mut registry = PtxRegistry::new();
        registry.register("name1", source, None);
        registry.register("name2", source, None);
        assert_eq!(registry.len(), 1);

        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup_name(key), Some("name2"));
    }

    // ---------------------------------------------------------------
    // CategoryStats
    // ---------------------------------------------------------------

    /// Default stats are all zero.
    #[test]
    fn test_category_stats_default() {
        let CategoryStats { total_ns, total_elements, count } = CategoryStats::default();
        assert_eq!(total_ns, 0);
        assert_eq!(total_elements, 0);
        assert_eq!(count, 0);
    }

    /// Average is zero when there are no samples.
    #[test]
    fn test_category_stats_avg_us_zero_count() {
        assert_eq!(CategoryStats::default().avg_us(), 0.0);
    }

    /// 10_000 ns over 2 samples averages to 5 us.
    #[test]
    fn test_category_stats_avg_us() {
        let cat = CategoryStats { total_ns: 10_000, total_elements: 0, count: 2 };
        assert_close(cat.avg_us(), 5.0, 1e-10);
    }

    /// Throughput is zero when no time has been recorded.
    #[test]
    fn test_category_stats_throughput_zero_ns() {
        assert_eq!(CategoryStats::default().throughput(), 0.0);
    }

    /// 1_000 elements in one second is 1_000 elements/s.
    #[test]
    fn test_category_stats_throughput() {
        let cat = CategoryStats {
            total_ns: 1_000_000_000,
            total_elements: 1_000,
            count: 1,
        };
        assert_close(cat.throughput(), 1_000.0, 1e-5);
    }

    /// Percentage of a zero total is zero.
    #[test]
    fn test_category_stats_percentage_zero_total() {
        let cat = CategoryStats { total_ns: 500, total_elements: 0, count: 1 };
        assert_eq!(cat.percentage(0), 0.0);
    }

    /// 250 of 1000 ns is 25 %.
    #[test]
    fn test_category_stats_percentage() {
        let cat = CategoryStats { total_ns: 250, total_elements: 0, count: 1 };
        assert_close(cat.percentage(1000), 25.0, 1e-10);
    }

    /// 1000 of 1000 ns is 100 %.
    #[test]
    fn test_category_stats_percentage_full() {
        let cat = CategoryStats { total_ns: 1000, total_elements: 0, count: 1 };
        assert_close(cat.percentage(1000), 100.0, 1e-10);
    }

    // ---------------------------------------------------------------
    // BrickStats
    // ---------------------------------------------------------------

    /// `new` zeroes the totals and seeds the minimums with `u64::MAX`.
    #[test]
    fn test_brick_stats_new() {
        let brick = BrickStats::new("test_brick");
        assert_eq!(brick.name, "test_brick");
        assert_eq!(brick.count, 0);
        assert_eq!(brick.total_ns, 0);
        assert_eq!(brick.min_ns, u64::MAX);
        assert_eq!(brick.max_ns, 0);
        assert_eq!(brick.total_elements, 0);
        assert_eq!(brick.total_bytes, 0);
        assert_eq!(brick.total_compressed_bytes, 0);
        assert_eq!(brick.bottleneck, BrickBottleneck::Unknown);
        assert_eq!(brick.total_cycles, 0);
        assert_eq!(brick.min_cycles, u64::MAX);
        assert_eq!(brick.max_cycles, 0);
    }

    /// Samples accumulate into count, totals, min and max.
    #[test]
    fn test_brick_stats_add_sample() {
        let mut brick = BrickStats::new("op");

        brick.add_sample(1000, 50);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_ns, 1000);
        assert_eq!(brick.min_ns, 1000);
        assert_eq!(brick.max_ns, 1000);
        assert_eq!(brick.total_elements, 50);

        brick.add_sample(500, 25);
        assert_eq!(brick.count, 2);
        assert_eq!(brick.total_ns, 1500);
        assert_eq!(brick.min_ns, 500);
        assert_eq!(brick.max_ns, 1000);
        assert_eq!(brick.total_elements, 75);

        brick.add_sample(2000, 100);
        assert_eq!(brick.count, 3);
        assert_eq!(brick.min_ns, 500);
        assert_eq!(brick.max_ns, 2000);
    }

    /// Cycle counts accumulate alongside the timing sample.
    #[test]
    fn test_brick_stats_add_sample_with_cycles() {
        let mut brick = BrickStats::new("op");

        brick.add_sample_with_cycles(1000, 50, 3000);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_cycles, 3000);
        assert_eq!(brick.min_cycles, 3000);
        assert_eq!(brick.max_cycles, 3000);

        brick.add_sample_with_cycles(500, 25, 1500);
        assert_eq!(brick.total_cycles, 4500);
        assert_eq!(brick.min_cycles, 1500);
        assert_eq!(brick.max_cycles, 3000);
    }

    /// Cycles per element is zero with no elements.
    #[test]
    fn test_brick_stats_cycles_per_element_zero() {
        assert_eq!(BrickStats::new("op").cycles_per_element(), 0.0);
    }

    /// 500 cycles over 100 elements is 5 cycles/element.
    #[test]
    fn test_brick_stats_cycles_per_element() {
        let mut brick = BrickStats::new("op");
        brick.add_sample_with_cycles(1000, 100, 500);
        assert_close(brick.cycles_per_element(), 5.0, 1e-10);
    }

    /// Average cycles is zero with no samples.
    #[test]
    fn test_brick_stats_avg_cycles_zero() {
        assert_eq!(BrickStats::new("op").avg_cycles(), 0.0);
    }

    /// 300 and 500 cycles average to 400.
    #[test]
    fn test_brick_stats_avg_cycles() {
        let mut brick = BrickStats::new("op");
        for cycles in [300, 500] {
            brick.add_sample_with_cycles(1000, 50, cycles);
        }
        assert_close(brick.avg_cycles(), 400.0, 1e-10);
    }

    /// Estimated IPC is zero with no cycles.
    #[test]
    fn test_brick_stats_estimated_ipc_zero() {
        assert_eq!(BrickStats::new("op").estimated_ipc(), 0.0);
    }

    /// 200 elements over 100 cycles gives an estimate of 2.0.
    #[test]
    fn test_brick_stats_estimated_ipc() {
        let mut brick = BrickStats::new("op");
        brick.add_sample_with_cycles(1000, 200, 100);
        assert_close(brick.estimated_ipc(), 2.0, 1e-10);
    }

    /// No data at all is reported as insufficient.
    #[test]
    fn test_brick_stats_diagnose_insufficient_data() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    /// Time without cycles is still insufficient.
    #[test]
    fn test_brick_stats_diagnose_insufficient_data_zero_cycles() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(1000, 50);
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    /// Cycles without time is still insufficient.
    #[test]
    fn test_brick_stats_diagnose_insufficient_data_zero_ns() {
        let brick = BrickStats {
            total_cycles: 100,
            ..BrickStats::new("op")
        };
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    /// 10 elements / 100 cycles falls below the 0.5 threshold.
    #[test]
    fn test_brick_stats_diagnose_memory_bound() {
        let brick = cycle_stats(10, 100, 50);
        assert_eq!(
            brick.diagnose_from_cycles(),
            "memory-bound (low IPC, likely cache misses)"
        );
    }

    /// 300 elements / 100 cycles exceeds the 2.0 threshold.
    #[test]
    fn test_brick_stats_diagnose_compute_bound() {
        let brick = cycle_stats(300, 100, 33);
        assert_eq!(brick.diagnose_from_cycles(), "compute-bound (efficient)");
    }

    /// Mid-range ratio with more than 1 ns per cycle.
    #[test]
    fn test_brick_stats_diagnose_throttled() {
        let brick = cycle_stats(100, 100, 200);
        assert_eq!(brick.diagnose_from_cycles(), "throttled or context-switched");
    }

    /// Mid-range ratio with at most 1 ns per cycle.
    #[test]
    fn test_brick_stats_diagnose_balanced() {
        let brick = cycle_stats(100, 100, 50);
        assert_eq!(brick.diagnose_from_cycles(), "balanced");
    }

    /// Byte counters accumulate alongside the timing sample.
    #[test]
    fn test_brick_stats_add_sample_with_bytes() {
        let mut brick = BrickStats::new("compress");

        brick.add_sample_with_bytes(1000, 1, 4096, 1024);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_bytes, 4096);
        assert_eq!(brick.total_compressed_bytes, 1024);
        assert_eq!(brick.total_elements, 1);

        brick.add_sample_with_bytes(2000, 1, 8192, 2048);
        assert_eq!(brick.total_bytes, 12288);
        assert_eq!(brick.total_compressed_bytes, 3072);
    }

    /// Ratio defaults to 1.0 when nothing was compressed.
    #[test]
    fn test_brick_stats_compression_ratio_no_data() {
        let brick = BrickStats::new("op");
        assert_close(brick.compression_ratio(), 1.0, 1e-10);
    }

    /// 4096 bytes in, 1024 bytes out is a 4x ratio.
    #[test]
    fn test_brick_stats_compression_ratio() {
        let mut brick = BrickStats::new("compress");
        brick.add_sample_with_bytes(1000, 1, 4096, 1024);
        assert_close(brick.compression_ratio(), 4.0, 1e-10);
    }

    /// Byte throughput is zero when no time has been recorded.
    #[test]
    fn test_brick_stats_throughput_gbps_zero_ns() {
        assert_eq!(BrickStats::new("op").throughput_gbps(), 0.0);
    }

    /// 10^9 bytes in one second is 1 GB/s.
    #[test]
    fn test_brick_stats_throughput_gbps() {
        let brick = BrickStats {
            total_bytes: 1_000_000_000,
            total_ns: 1_000_000_000,
            ..BrickStats::new("op")
        };
        assert_close(brick.throughput_gbps(), 1.0, 1e-5);
    }

    /// The bottleneck setter and getter round-trip.
    #[test]
    fn test_brick_stats_set_get_bottleneck() {
        let mut brick = BrickStats::new("op");
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Unknown);

        brick.set_bottleneck(BrickBottleneck::Memory);
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Memory);

        brick.set_bottleneck(BrickBottleneck::Compute);
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Compute);
    }

    /// Average is zero when there are no samples.
    #[test]
    fn test_brick_stats_avg_us_zero_count() {
        assert_eq!(BrickStats::new("op").avg_us(), 0.0);
    }

    /// 2000 ns and 4000 ns average to 3 us.
    #[test]
    fn test_brick_stats_avg_us() {
        let mut brick = BrickStats::new("op");
        for elapsed_ns in [2000, 4000] {
            brick.add_sample(elapsed_ns, 10);
        }
        assert_close(brick.avg_us(), 3.0, 1e-10);
    }

    /// Throughput is zero when no time has been recorded.
    #[test]
    fn test_brick_stats_throughput_zero_ns() {
        assert_eq!(BrickStats::new("op").throughput(), 0.0);
    }

    /// 500 elements in one second is 500 elements/s.
    #[test]
    fn test_brick_stats_throughput() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(1_000_000_000, 500);
        assert_close(brick.throughput(), 500.0, 1e-5);
    }

    /// `tokens_per_sec` mirrors `throughput`.
    #[test]
    fn test_brick_stats_tokens_per_sec() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(1_000_000_000, 42);
        assert_close(brick.tokens_per_sec(), brick.throughput(), 1e-10);
    }

    /// The `u64::MAX` sentinel is reported as 0 us.
    #[test]
    fn test_brick_stats_min_us_no_samples() {
        assert_eq!(BrickStats::new("op").min_us(), 0.0);
    }

    /// The smaller of 5000 ns and 3000 ns is 3 us.
    #[test]
    fn test_brick_stats_min_us() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(5000, 1);
        brick.add_sample(3000, 1);
        assert_close(brick.min_us(), 3.0, 1e-10);
    }

    /// The larger of 5000 ns and 3000 ns is 5 us.
    #[test]
    fn test_brick_stats_max_us() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(5000, 1);
        brick.add_sample(3000, 1);
        assert_close(brick.max_us(), 5.0, 1e-10);
    }

    /// Max is 0 us when there are no samples.
    #[test]
    fn test_brick_stats_max_us_no_samples() {
        assert_eq!(BrickStats::new("op").max_us(), 0.0);
    }

    /// Default stats have an empty name, zero totals and an unknown bottleneck.
    #[test]
    fn test_brick_stats_default() {
        let brick = BrickStats::default();
        assert!(brick.name.is_empty());
        assert_eq!(brick.count, 0);
        assert_eq!(brick.total_ns, 0);
        assert_eq!(brick.bottleneck, BrickBottleneck::Unknown);
    }
}