use std::collections::HashMap;
use super::BrickBottleneck;
/// Registry of PTX kernel sources, keyed by a 64-bit hash of the PTX text.
///
/// Every entry keeps the kernel name, the PTX source itself, and an optional
/// filesystem path supplied at registration time.
#[derive(Debug, Default)]
pub struct PtxRegistry {
    /// Hash of the PTX text -> `(kernel name, PTX source, optional path)`.
    kernels: HashMap<u64, (String, String, Option<std::path::PathBuf>)>,
}

impl PtxRegistry {
    /// Creates a registry that holds no kernels.
    pub fn new() -> Self {
        Self { kernels: HashMap::new() }
    }

    /// Stores `ptx` under the hash returned by [`PtxRegistry::hash_ptx`].
    ///
    /// Registering PTX text whose hash is already present replaces the
    /// previous entry, including its name and path.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `name` or `ptx` is empty.
    pub fn register(&mut self, name: &str, ptx: &str, path: Option<&std::path::Path>) {
        debug_assert!(!name.is_empty(), "CB-BUDGET: kernel name must not be empty");
        debug_assert!(!ptx.is_empty(), "CB-BUDGET: PTX source must not be empty");

        let key = Self::hash_ptx(ptx);
        let owned_path = path.map(std::path::Path::to_path_buf);
        let entry = (name.to_owned(), ptx.to_owned(), owned_path);
        self.kernels.insert(key, entry);
    }

    /// Computes the 64-bit FNV-1a hash of the UTF-8 bytes of `ptx`.
    ///
    /// The empty string hashes to the FNV offset basis.
    #[inline]
    pub fn hash_ptx(ptx: &str) -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
        const FNV_PRIME: u64 = 0x100000001b3;

        // FNV-1a: XOR the byte in first, then multiply by the prime.
        ptx.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    }

    /// Returns the PTX source registered under `hash`, if any.
    pub fn lookup(&self, hash: u64) -> Option<&str> {
        let (_, source, _) = self.kernels.get(&hash)?;
        Some(source.as_str())
    }

    /// Returns the kernel name registered under `hash`, if any.
    pub fn lookup_name(&self, hash: u64) -> Option<&str> {
        let (kernel_name, _, _) = self.kernels.get(&hash)?;
        Some(kernel_name.as_str())
    }

    /// Returns the path registered under `hash`.
    ///
    /// `None` means either the hash is unknown or the kernel was registered
    /// without a path.
    pub fn lookup_path(&self, hash: u64) -> Option<&std::path::Path> {
        let (_, _, location) = self.kernels.get(&hash)?;
        location.as_deref()
    }

    /// Iterates over the hashes of all registered kernels, in the map's
    /// (unspecified) iteration order.
    pub fn hashes(&self) -> impl Iterator<Item = u64> + '_ {
        self.kernels.iter().map(|(&hash, _)| hash)
    }

    /// Number of registered kernels.
    pub fn len(&self) -> usize {
        self.kernels.len()
    }

    /// Returns `true` when no kernel has been registered.
    pub fn is_empty(&self) -> bool {
        self.kernels.is_empty()
    }
}
/// Aggregated timing totals for one category of work.
#[derive(Debug, Clone, Copy, Default)]
pub struct CategoryStats {
    /// Accumulated time, in nanoseconds.
    pub total_ns: u64,
    /// Accumulated number of processed elements.
    pub total_elements: u64,
    /// Number of samples that contributed to the totals.
    pub count: u64,
}

impl CategoryStats {
    /// Mean time per sample in microseconds; `0.0` when there are no samples.
    #[inline]
    pub fn avg_us(&self) -> f64 {
        match self.count {
            0 => 0.0,
            samples => {
                let mean_ns = self.total_ns as f64 / samples as f64;
                mean_ns / 1000.0
            }
        }
    }

    /// Elements processed per second; `0.0` when no time has been recorded.
    #[inline]
    pub fn throughput(&self) -> f64 {
        if self.total_ns == 0 {
            return 0.0;
        }
        let elapsed_secs = self.total_ns as f64 / 1_000_000_000.0;
        self.total_elements as f64 / elapsed_secs
    }

    /// This category's time as a percentage of `total` nanoseconds.
    ///
    /// Returns `0.0` when `total` is zero.
    #[inline]
    pub fn percentage(&self, total: u64) -> f64 {
        if total == 0 {
            return 0.0;
        }
        let scaled_ns = 100.0 * self.total_ns as f64;
        scaled_ns / total as f64
    }
}
/// Accumulated timing, size, and cycle statistics for one named brick.
///
/// Samples are folded in with [`BrickStats::add_sample`] and its
/// `add_sample_with_*` variants; the remaining methods derive averages,
/// rates, and a coarse diagnosis from the running totals.
#[derive(Debug, Clone)]
pub struct BrickStats {
    /// Name given at construction time.
    pub name: String,
    /// Number of samples recorded.
    pub count: u64,
    /// Sum of the elapsed times of all samples, in nanoseconds.
    pub total_ns: u64,
    /// Smallest elapsed time seen; `u64::MAX` until the first sample arrives.
    pub min_ns: u64,
    /// Largest elapsed time seen; `0` until the first sample arrives.
    pub max_ns: u64,
    /// Sum of the element counts of all samples.
    pub total_elements: u64,
    /// Sum of `input_bytes` passed to [`BrickStats::add_sample_with_bytes`].
    pub total_bytes: u64,
    /// Sum of `output_bytes` passed to [`BrickStats::add_sample_with_bytes`].
    pub total_compressed_bytes: u64,
    /// Bottleneck classification; only changed via [`BrickStats::set_bottleneck`].
    pub bottleneck: BrickBottleneck,
    /// Sum of `cycles` passed to [`BrickStats::add_sample_with_cycles`].
    pub total_cycles: u64,
    /// Smallest cycle count seen; `u64::MAX` until a sample with cycles arrives.
    pub min_cycles: u64,
    /// Largest cycle count seen; `0` until a sample with cycles arrives.
    pub max_cycles: u64,
}

impl Default for BrickStats {
    /// Equivalent to `BrickStats::new("")`.
    ///
    /// A derived `Default` would zero `min_ns` and `min_cycles`; since the
    /// minimum is tracked with `min(current, sample)`, a zero seed could never
    /// be replaced by a real sample. Delegating to `new` keeps the `u64::MAX`
    /// sentinel that `add_sample` and `min_us` rely on.
    fn default() -> Self {
        Self::new("")
    }
}

impl BrickStats {
    /// Creates empty statistics for the brick called `name`.
    ///
    /// The minimum trackers start at `u64::MAX` so the first sample always
    /// becomes the minimum; all other counters start at zero and the
    /// bottleneck starts as `BrickBottleneck::Unknown`.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            count: 0,
            total_ns: 0,
            min_ns: u64::MAX,
            max_ns: 0,
            total_elements: 0,
            total_bytes: 0,
            total_compressed_bytes: 0,
            bottleneck: BrickBottleneck::Unknown,
            total_cycles: 0,
            min_cycles: u64::MAX,
            max_cycles: 0,
        }
    }

    /// Records one sample that took `elapsed_ns` nanoseconds and processed
    /// `elements` elements.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `elements` is zero.
    pub fn add_sample(&mut self, elapsed_ns: u64, elements: u64) {
        debug_assert!(elements > 0, "CB-BUDGET: elements must be > 0");
        self.count += 1;
        self.total_ns += elapsed_ns;
        self.min_ns = self.min_ns.min(elapsed_ns);
        self.max_ns = self.max_ns.max(elapsed_ns);
        self.total_elements += elements;
    }

    /// Records a sample like [`BrickStats::add_sample`] and additionally
    /// folds `cycles` into the cycle total, minimum, and maximum.
    pub fn add_sample_with_cycles(&mut self, elapsed_ns: u64, elements: u64, cycles: u64) {
        self.add_sample(elapsed_ns, elements);
        self.total_cycles += cycles;
        self.min_cycles = self.min_cycles.min(cycles);
        self.max_cycles = self.max_cycles.max(cycles);
    }

    /// Total cycles divided by total elements; `0.0` when no elements were
    /// recorded.
    #[must_use]
    pub fn cycles_per_element(&self) -> f64 {
        if self.total_elements == 0 {
            0.0
        } else {
            self.total_cycles as f64 / self.total_elements as f64
        }
    }

    /// Total cycles divided by the number of samples; `0.0` when there are no
    /// samples.
    #[must_use]
    pub fn avg_cycles(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_cycles as f64 / self.count as f64
        }
    }

    /// Elements processed per cycle, used as a stand-in for IPC; `0.0` when no
    /// cycles were recorded.
    #[must_use]
    pub fn estimated_ipc(&self) -> f64 {
        if self.total_cycles == 0 {
            0.0
        } else {
            self.total_elements as f64 / self.total_cycles as f64
        }
    }

    /// Returns a coarse textual diagnosis based on the cycle counters.
    ///
    /// Checks are applied in this order:
    /// 1. no cycles or no time recorded -> `"insufficient data"`
    /// 2. elements per cycle below `0.5` -> memory-bound
    /// 3. elements per cycle above `2.0` -> compute-bound
    /// 4. more than one nanosecond per cycle -> throttled or context-switched
    /// 5. otherwise -> `"balanced"`
    #[must_use]
    pub fn diagnose_from_cycles(&self) -> &'static str {
        if self.total_cycles == 0 || self.total_ns == 0 {
            return "insufficient data";
        }
        let ipc = self.estimated_ipc();
        let ns_per_cycle = self.total_ns as f64 / self.total_cycles as f64;
        if ipc < 0.5 {
            "memory-bound (low IPC, likely cache misses)"
        } else if ipc > 2.0 {
            "compute-bound (efficient)"
        } else if ns_per_cycle > 1.0 {
            "throttled or context-switched"
        } else {
            "balanced"
        }
    }

    /// Records a sample like [`BrickStats::add_sample`] and additionally adds
    /// `input_bytes` to `total_bytes` and `output_bytes` to
    /// `total_compressed_bytes`.
    pub fn add_sample_with_bytes(
        &mut self,
        elapsed_ns: u64,
        elements: u64,
        input_bytes: u64,
        output_bytes: u64,
    ) {
        self.add_sample(elapsed_ns, elements);
        self.total_bytes += input_bytes;
        self.total_compressed_bytes += output_bytes;
    }

    /// Ratio of `total_bytes` to `total_compressed_bytes`.
    ///
    /// Returns `1.0` when no compressed bytes have been recorded.
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        if self.total_compressed_bytes == 0 {
            1.0
        } else {
            self.total_bytes as f64 / self.total_compressed_bytes as f64
        }
    }

    /// Byte throughput in GB/s (10^9 bytes per second); `0.0` when no time
    /// has been recorded.
    #[must_use]
    pub fn throughput_gbps(&self) -> f64 {
        if self.total_ns == 0 {
            0.0
        } else {
            // One byte per nanosecond is exactly one GB/s, so no scaling is needed.
            self.total_bytes as f64 / self.total_ns as f64
        }
    }

    /// Overwrites the stored bottleneck classification.
    pub fn set_bottleneck(&mut self, bottleneck: BrickBottleneck) {
        self.bottleneck = bottleneck;
    }

    /// Returns the stored bottleneck classification.
    #[must_use]
    pub fn get_bottleneck(&self) -> BrickBottleneck {
        self.bottleneck
    }

    /// Mean time per sample in microseconds; `0.0` when there are no samples.
    #[must_use]
    pub fn avg_us(&self) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.total_ns as f64 / self.count as f64 / 1000.0
        }
    }

    /// Elements processed per second; `0.0` when no time has been recorded.
    #[must_use]
    pub fn throughput(&self) -> f64 {
        if self.total_ns == 0 {
            0.0
        } else {
            self.total_elements as f64 / (self.total_ns as f64 / 1_000_000_000.0)
        }
    }

    /// Alias for [`BrickStats::throughput`].
    #[must_use]
    pub fn tokens_per_sec(&self) -> f64 {
        self.throughput()
    }

    /// Smallest sample time in microseconds.
    ///
    /// Returns `0.0` while `min_ns` still holds the `u64::MAX` "no samples"
    /// sentinel.
    #[must_use]
    pub fn min_us(&self) -> f64 {
        if self.min_ns == u64::MAX {
            0.0
        } else {
            self.min_ns as f64 / 1000.0
        }
    }

    /// Largest sample time in microseconds; `0.0` when there are no samples.
    #[must_use]
    pub fn max_us(&self) -> f64 {
        self.max_ns as f64 / 1000.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    // ----- PtxRegistry -----

    #[test]
    fn test_ptx_registry_new_is_empty() {
        let registry = PtxRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_ptx_registry_register_and_lookup() {
        let source = ".version 8.0\n.entry gemm_tiled {}";
        let mut registry = PtxRegistry::new();
        registry.register("gemm_tiled", source, None);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup(key), Some(source));
        assert_eq!(registry.lookup_name(key), Some("gemm_tiled"));
        assert_eq!(registry.lookup_path(key), None);
    }

    #[test]
    fn test_ptx_registry_register_with_path() {
        let source = ".version 8.0\n.entry softmax {}";
        let location = std::path::Path::new("/src/kernels/softmax.ptx");
        let mut registry = PtxRegistry::new();
        registry.register("softmax", source, Some(location));

        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup_path(key), Some(location));
    }

    #[test]
    fn test_ptx_registry_lookup_missing() {
        let registry = PtxRegistry::new();
        assert_eq!(registry.lookup(12345), None);
        assert_eq!(registry.lookup_name(12345), None);
        assert_eq!(registry.lookup_path(12345), None);
    }

    #[test]
    fn test_ptx_registry_hashes() {
        let mut registry = PtxRegistry::new();
        registry.register("k1", "ptx_source_1", None);
        registry.register("k2", "ptx_source_2", None);

        let keys: Vec<u64> = registry.hashes().collect();
        assert_eq!(keys.len(), 2);
    }

    #[test]
    fn test_ptx_registry_hash_deterministic() {
        let source = "some ptx source code";
        let first = PtxRegistry::hash_ptx(source);
        let second = PtxRegistry::hash_ptx(source);
        assert_eq!(first, second);
    }

    #[test]
    fn test_ptx_registry_hash_different_inputs() {
        let hash_a = PtxRegistry::hash_ptx("kernel_a");
        let hash_b = PtxRegistry::hash_ptx("kernel_b");
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn test_ptx_registry_overwrite_same_hash() {
        let source = "same_source";
        let mut registry = PtxRegistry::new();
        registry.register("name1", source, None);
        registry.register("name2", source, None);

        // Identical PTX text hashes to the same key, so the later entry wins.
        assert_eq!(registry.len(), 1);
        let key = PtxRegistry::hash_ptx(source);
        assert_eq!(registry.lookup_name(key), Some("name2"));
    }

    // ----- CategoryStats -----

    #[test]
    fn test_category_stats_default() {
        let category = CategoryStats::default();
        assert_eq!(category.total_ns, 0);
        assert_eq!(category.total_elements, 0);
        assert_eq!(category.count, 0);
    }

    #[test]
    fn test_category_stats_avg_us_zero_count() {
        let category = CategoryStats::default();
        assert_eq!(category.avg_us(), 0.0);
    }

    #[test]
    fn test_category_stats_avg_us() {
        let category = CategoryStats {
            total_ns: 10_000,
            count: 2,
            ..CategoryStats::default()
        };
        assert_close(category.avg_us(), 5.0, 1e-10);
    }

    #[test]
    fn test_category_stats_throughput_zero_ns() {
        let category = CategoryStats::default();
        assert_eq!(category.throughput(), 0.0);
    }

    #[test]
    fn test_category_stats_throughput() {
        // 1_000 elements in exactly one second.
        let category = CategoryStats {
            total_ns: 1_000_000_000,
            total_elements: 1_000,
            count: 1,
        };
        assert_close(category.throughput(), 1_000.0, 1e-5);
    }

    #[test]
    fn test_category_stats_percentage_zero_total() {
        let category = CategoryStats {
            total_ns: 500,
            count: 1,
            ..CategoryStats::default()
        };
        assert_eq!(category.percentage(0), 0.0);
    }

    #[test]
    fn test_category_stats_percentage() {
        let category = CategoryStats {
            total_ns: 250,
            count: 1,
            ..CategoryStats::default()
        };
        assert_close(category.percentage(1000), 25.0, 1e-10);
    }

    #[test]
    fn test_category_stats_percentage_full() {
        let category = CategoryStats {
            total_ns: 1000,
            count: 1,
            ..CategoryStats::default()
        };
        assert_close(category.percentage(1000), 100.0, 1e-10);
    }

    // ----- BrickStats -----

    #[test]
    fn test_brick_stats_new() {
        let brick = BrickStats::new("test_brick");
        assert_eq!(brick.name, "test_brick");
        assert_eq!(brick.count, 0);
        assert_eq!(brick.total_ns, 0);
        assert_eq!(brick.min_ns, u64::MAX);
        assert_eq!(brick.max_ns, 0);
        assert_eq!(brick.total_elements, 0);
        assert_eq!(brick.total_bytes, 0);
        assert_eq!(brick.total_compressed_bytes, 0);
        assert_eq!(brick.bottleneck, BrickBottleneck::Unknown);
        assert_eq!(brick.total_cycles, 0);
        assert_eq!(brick.min_cycles, u64::MAX);
        assert_eq!(brick.max_cycles, 0);
    }

    #[test]
    fn test_brick_stats_add_sample() {
        let mut brick = BrickStats::new("op");

        brick.add_sample(1000, 50);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_ns, 1000);
        assert_eq!(brick.min_ns, 1000);
        assert_eq!(brick.max_ns, 1000);
        assert_eq!(brick.total_elements, 50);

        brick.add_sample(500, 25);
        assert_eq!(brick.count, 2);
        assert_eq!(brick.total_ns, 1500);
        assert_eq!(brick.min_ns, 500);
        assert_eq!(brick.max_ns, 1000);
        assert_eq!(brick.total_elements, 75);

        brick.add_sample(2000, 100);
        assert_eq!(brick.count, 3);
        assert_eq!(brick.min_ns, 500);
        assert_eq!(brick.max_ns, 2000);
    }

    #[test]
    fn test_brick_stats_add_sample_with_cycles() {
        let mut brick = BrickStats::new("op");

        brick.add_sample_with_cycles(1000, 50, 3000);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_cycles, 3000);
        assert_eq!(brick.min_cycles, 3000);
        assert_eq!(brick.max_cycles, 3000);

        brick.add_sample_with_cycles(500, 25, 1500);
        assert_eq!(brick.total_cycles, 4500);
        assert_eq!(brick.min_cycles, 1500);
        assert_eq!(brick.max_cycles, 3000);
    }

    #[test]
    fn test_brick_stats_cycles_per_element_zero() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.cycles_per_element(), 0.0);
    }

    #[test]
    fn test_brick_stats_cycles_per_element() {
        let mut brick = BrickStats::new("op");
        brick.add_sample_with_cycles(1000, 100, 500);
        assert_close(brick.cycles_per_element(), 5.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_avg_cycles_zero() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.avg_cycles(), 0.0);
    }

    #[test]
    fn test_brick_stats_avg_cycles() {
        let mut brick = BrickStats::new("op");
        brick.add_sample_with_cycles(1000, 50, 300);
        brick.add_sample_with_cycles(1000, 50, 500);
        assert_close(brick.avg_cycles(), 400.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_estimated_ipc_zero() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.estimated_ipc(), 0.0);
    }

    #[test]
    fn test_brick_stats_estimated_ipc() {
        let mut brick = BrickStats::new("op");
        brick.add_sample_with_cycles(1000, 200, 100);
        assert_close(brick.estimated_ipc(), 2.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_diagnose_insufficient_data() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    #[test]
    fn test_brick_stats_diagnose_insufficient_data_zero_cycles() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(1000, 50);
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    #[test]
    fn test_brick_stats_diagnose_insufficient_data_zero_ns() {
        let brick = BrickStats {
            total_cycles: 100,
            ..BrickStats::new("op")
        };
        assert_eq!(brick.diagnose_from_cycles(), "insufficient data");
    }

    #[test]
    fn test_brick_stats_diagnose_memory_bound() {
        // 10 elements / 100 cycles = 0.1, below the 0.5 threshold.
        let brick = BrickStats {
            total_elements: 10,
            total_cycles: 100,
            total_ns: 50,
            ..BrickStats::new("op")
        };
        assert_eq!(
            brick.diagnose_from_cycles(),
            "memory-bound (low IPC, likely cache misses)"
        );
    }

    #[test]
    fn test_brick_stats_diagnose_compute_bound() {
        // 300 elements / 100 cycles = 3.0, above the 2.0 threshold.
        let brick = BrickStats {
            total_elements: 300,
            total_cycles: 100,
            total_ns: 33,
            ..BrickStats::new("op")
        };
        assert_eq!(brick.diagnose_from_cycles(), "compute-bound (efficient)");
    }

    #[test]
    fn test_brick_stats_diagnose_throttled() {
        // 1.0 element per cycle, but 2 ns per cycle.
        let brick = BrickStats {
            total_elements: 100,
            total_cycles: 100,
            total_ns: 200,
            ..BrickStats::new("op")
        };
        assert_eq!(brick.diagnose_from_cycles(), "throttled or context-switched");
    }

    #[test]
    fn test_brick_stats_diagnose_balanced() {
        // 1.0 element per cycle and 0.5 ns per cycle.
        let brick = BrickStats {
            total_elements: 100,
            total_cycles: 100,
            total_ns: 50,
            ..BrickStats::new("op")
        };
        assert_eq!(brick.diagnose_from_cycles(), "balanced");
    }

    #[test]
    fn test_brick_stats_add_sample_with_bytes() {
        let mut brick = BrickStats::new("compress");

        brick.add_sample_with_bytes(1000, 1, 4096, 1024);
        assert_eq!(brick.count, 1);
        assert_eq!(brick.total_bytes, 4096);
        assert_eq!(brick.total_compressed_bytes, 1024);
        assert_eq!(brick.total_elements, 1);

        brick.add_sample_with_bytes(2000, 1, 8192, 2048);
        assert_eq!(brick.total_bytes, 12288);
        assert_eq!(brick.total_compressed_bytes, 3072);
    }

    #[test]
    fn test_brick_stats_compression_ratio_no_data() {
        let brick = BrickStats::new("op");
        assert_close(brick.compression_ratio(), 1.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_compression_ratio() {
        let mut brick = BrickStats::new("compress");
        brick.add_sample_with_bytes(1000, 1, 4096, 1024);
        assert_close(brick.compression_ratio(), 4.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_throughput_gbps_zero_ns() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.throughput_gbps(), 0.0);
    }

    #[test]
    fn test_brick_stats_throughput_gbps() {
        // 10^9 bytes in 10^9 ns (one second).
        let brick = BrickStats {
            total_bytes: 1_000_000_000,
            total_ns: 1_000_000_000,
            ..BrickStats::new("op")
        };
        assert_close(brick.throughput_gbps(), 1.0, 1e-5);
    }

    #[test]
    fn test_brick_stats_set_get_bottleneck() {
        let mut brick = BrickStats::new("op");
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Unknown);

        brick.set_bottleneck(BrickBottleneck::Memory);
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Memory);

        brick.set_bottleneck(BrickBottleneck::Compute);
        assert_eq!(brick.get_bottleneck(), BrickBottleneck::Compute);
    }

    #[test]
    fn test_brick_stats_avg_us_zero_count() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.avg_us(), 0.0);
    }

    #[test]
    fn test_brick_stats_avg_us() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(2000, 10);
        brick.add_sample(4000, 10);
        assert_close(brick.avg_us(), 3.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_throughput_zero_ns() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.throughput(), 0.0);
    }

    #[test]
    fn test_brick_stats_throughput() {
        let mut brick = BrickStats::new("op");
        // 500 elements in exactly one second.
        brick.add_sample(1_000_000_000, 500);
        assert_close(brick.throughput(), 500.0, 1e-5);
    }

    #[test]
    fn test_brick_stats_tokens_per_sec() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(1_000_000_000, 42);
        assert_close(brick.tokens_per_sec(), brick.throughput(), 1e-10);
    }

    #[test]
    fn test_brick_stats_min_us_no_samples() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.min_us(), 0.0);
    }

    #[test]
    fn test_brick_stats_min_us() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(5000, 1);
        brick.add_sample(3000, 1);
        assert_close(brick.min_us(), 3.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_max_us() {
        let mut brick = BrickStats::new("op");
        brick.add_sample(5000, 1);
        brick.add_sample(3000, 1);
        assert_close(brick.max_us(), 5.0, 1e-10);
    }

    #[test]
    fn test_brick_stats_max_us_no_samples() {
        let brick = BrickStats::new("op");
        assert_eq!(brick.max_us(), 0.0);
    }

    #[test]
    fn test_brick_stats_default() {
        let brick = BrickStats::default();
        assert!(brick.name.is_empty());
        assert_eq!(brick.count, 0);
        assert_eq!(brick.total_ns, 0);
        assert_eq!(brick.bottleneck, BrickBottleneck::Unknown);
    }
}