use serde::{Deserialize, Serialize};
/// Categories of GPU kernel waste, named after the seven wastes (*muda*) of lean
/// manufacturing. Each variant carries the raw metrics that describe that kind of waste.
///
/// Variant and field order is part of the serialized form for order-sensitive serde
/// formats, so do not reorder them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GpuMuda {
    /// Data-movement waste.
    Transport {
        /// Number of register spills.
        register_spills: u64,
        /// Number of global loads flagged as unnecessary.
        unnecessary_global_loads: u64,
        /// Number of shared stores flagged as redundant.
        redundant_shared_stores: u64,
    },
    /// Hardware sitting idle (stalls and scheduler idleness).
    Waiting {
        /// Cycles stalled on barriers.
        barrier_stall_cycles: u64,
        /// Cycles stalled on memory.
        memory_stall_cycles: u64,
        /// Number of pipeline bubbles.
        pipeline_bubbles: u64,
        /// Percentage of time the warp scheduler was idle.
        warp_scheduler_idle_pct: f64,
    },
    /// Doing more work than needed (excess precision, redundant instructions or checks).
    ///
    /// NOTE(review): no `MudaDetector` method in this file constructs this variant.
    Overprocessing {
        /// Percentage of precision judged to be wasted.
        precision_waste_pct: f64,
        /// Number of redundant instructions.
        redundant_instructions: u64,
        /// Number of unnecessary bounds checks.
        unnecessary_bounds_checks: u64,
    },
    /// Allocated-but-unused resources and the resulting occupancy loss.
    Inventory {
        /// Bytes of shared memory allocated but unused.
        unused_shared_memory_bytes: u64,
        /// Registers per thread allocated but unused.
        unused_registers_per_thread: u32,
        /// Percentage of occupancy lost.
        occupancy_loss_pct: f64,
    },
    /// Control-flow waste (divergence and loop overhead).
    Motion {
        /// Number of divergent branches.
        divergent_branches: u64,
        /// Branch efficiency as a percentage.
        branch_efficiency_pct: f64,
        /// Cycles spent on loop overhead.
        loop_overhead_cycles: u64,
    },
    /// Numerical defects in the results.
    Defects {
        /// Number of NaN values observed.
        nan_count: u64,
        /// Number of infinite values observed.
        inf_count: u64,
        /// Bits of precision lost.
        precision_loss_bits: f64,
    },
    /// Producing more than needed (padding, inactive threads, unused output).
    Overproduction {
        /// Percentage of work spent on padding.
        padding_waste_pct: f64,
        /// Percentage of threads that were inactive.
        inactive_thread_pct: f64,
        /// Number of output elements produced but not used.
        unused_output_elements: u64,
    },
}
/// One detected instance of waste, as returned by the `MudaDetector::detect_*` methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MudaDetection {
    /// The waste category together with the raw metrics that were passed to the detector.
    pub muda: GpuMuda,
    /// Impact score. Each detector derives it differently, and several pass an input
    /// percentage through unchanged, so the range depends on the caller's inputs.
    /// NOTE(review): confirm whether callers expect this to be bounded to 0..=100.
    pub impact_pct: f64,
    /// Human-readable summary of the offending metrics.
    pub description: String,
    /// Suggested remediation for the detected waste.
    pub recommendation: String,
}
/// Flags GPU kernel waste by comparing caller-supplied metrics against [`MudaThresholds`].
///
/// `Debug` and `Clone` are derived so detectors can be logged and duplicated like the
/// threshold type they wrap.
#[derive(Debug, Clone, Default)]
pub struct MudaDetector {
    /// Limits consulted by the `detect_*` methods.
    pub thresholds: MudaThresholds,
}
/// Configurable limits used by [`MudaDetector`] to decide when a metric counts as waste.
///
/// `PartialEq` is derived so configurations can be compared (e.g. against the default).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MudaThresholds {
    /// Register spills above this count are reported by `detect_transport`.
    pub max_register_spills: u64,
    /// Scheduler idle percentage above which `detect_waiting` reports waste.
    pub max_scheduler_idle_pct: f64,
    /// Branch efficiency percentage below which `detect_motion` reports waste.
    pub min_branch_efficiency_pct: f64,
    /// Occupancy loss percentage above which `detect_inventory` reports waste.
    pub max_occupancy_loss_pct: f64,
    /// Padding waste percentage above which `detect_overproduction` reports waste.
    /// The same limit is also applied to the inactive-thread percentage.
    pub max_padding_waste_pct: f64,
    /// Maximum tolerated NaN/Inf count before `detect_defects` reports a defect.
    pub max_nan_inf_count: u64,
    /// Precision waste percentage limit.
    /// NOTE(review): no `detect_*` method in this file reads this field.
    pub max_precision_waste_pct: f64,
}
impl Default for MudaThresholds {
fn default() -> Self {
Self {
max_register_spills: 0,
max_scheduler_idle_pct: 20.0,
min_branch_efficiency_pct: 90.0,
max_occupancy_loss_pct: 50.0,
max_padding_waste_pct: 10.0,
max_nan_inf_count: 0,
max_precision_waste_pct: 25.0,
}
}
}
impl MudaDetector {
    /// Creates a detector using [`MudaThresholds::default`].
    pub fn new() -> Self {
        Self {
            thresholds: MudaThresholds::default(),
        }
    }

    /// Creates a detector using the given `thresholds`.
    pub fn with_thresholds(thresholds: MudaThresholds) -> Self {
        Self { thresholds }
    }

    /// Detects data-movement waste ([`GpuMuda::Transport`]).
    ///
    /// Reports waste when `register_spills` exceeds `thresholds.max_register_spills`, or when
    /// any unnecessary global load or redundant shared store is present. `impact_pct` is the
    /// total event count, capped at 100.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_transport(
        &self,
        register_spills: u64,
        unnecessary_global_loads: u64,
        redundant_shared_stores: u64,
    ) -> Option<MudaDetection> {
        let spills_over_limit = register_spills > self.thresholds.max_register_spills;
        if !spills_over_limit && unnecessary_global_loads == 0 && redundant_shared_stores == 0 {
            return None;
        }
        // Saturating sum: plain `+` overflows (a panic in debug builds) on huge counters.
        let total_waste = register_spills
            .saturating_add(unnecessary_global_loads)
            .saturating_add(redundant_shared_stores);
        Some(MudaDetection {
            muda: GpuMuda::Transport {
                register_spills,
                unnecessary_global_loads,
                redundant_shared_stores,
            },
            impact_pct: (total_waste as f64).min(100.0),
            description: format!(
                "Data movement waste: {register_spills} register spills, \
                 {unnecessary_global_loads} unnecessary global loads, \
                 {redundant_shared_stores} redundant shared stores"
            ),
            // Only blame register pressure when spills are what crossed the configured limit;
            // spills within budget are not the cause of this detection.
            recommendation: if spills_over_limit {
                "Reduce register pressure: decrease tile size, use shared memory, or reduce live variables".to_string()
            } else {
                "Review memory access patterns for redundant loads/stores".to_string()
            },
        })
    }

    /// Detects idle-hardware waste ([`GpuMuda::Waiting`]).
    ///
    /// Reports waste when `warp_scheduler_idle_pct` exceeds `thresholds.max_scheduler_idle_pct`,
    /// or when any barrier or memory stall cycles are present. `pipeline_bubbles` is recorded
    /// but does not trigger detection. `impact_pct` is the scheduler idle percentage, never
    /// below 0 and at least 10 when memory stall cycles are present.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_waiting(
        &self,
        barrier_stall_cycles: u64,
        memory_stall_cycles: u64,
        pipeline_bubbles: u64,
        warp_scheduler_idle_pct: f64,
    ) -> Option<MudaDetection> {
        let triggered = warp_scheduler_idle_pct > self.thresholds.max_scheduler_idle_pct
            || barrier_stall_cycles > 0
            || memory_stall_cycles > 0;
        if !triggered {
            return None;
        }
        let memory_floor = if memory_stall_cycles > 0 { 10.0 } else { 0.0 };
        Some(MudaDetection {
            muda: GpuMuda::Waiting {
                barrier_stall_cycles,
                memory_stall_cycles,
                pipeline_bubbles,
                warp_scheduler_idle_pct,
            },
            impact_pct: warp_scheduler_idle_pct.max(memory_floor),
            description: format!(
                "Hardware idle: scheduler {warp_scheduler_idle_pct:.1}% idle, \
                 {memory_stall_cycles} memory stall cycles, \
                 {barrier_stall_cycles} barrier stall cycles"
            ),
            recommendation: if memory_stall_cycles > barrier_stall_cycles {
                "Increase warps per SM for latency hiding, or improve data locality".to_string()
            } else {
                "Reduce barrier synchronization or overlap compute with data movement".to_string()
            },
        })
    }

    /// Detects control-flow waste ([`GpuMuda::Motion`]).
    ///
    /// Reports waste when `branch_efficiency_pct` is below
    /// `thresholds.min_branch_efficiency_pct` or any divergent branch is present.
    /// `loop_overhead_cycles` is recorded but does not trigger detection. `impact_pct` is
    /// the efficiency shortfall from 100%, clamped to 0..=100.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_motion(
        &self,
        divergent_branches: u64,
        branch_efficiency_pct: f64,
        loop_overhead_cycles: u64,
    ) -> Option<MudaDetection> {
        let triggered = branch_efficiency_pct < self.thresholds.min_branch_efficiency_pct
            || divergent_branches > 0;
        if !triggered {
            return None;
        }
        Some(MudaDetection {
            muda: GpuMuda::Motion {
                divergent_branches,
                branch_efficiency_pct,
                loop_overhead_cycles,
            },
            // Clamp so out-of-range efficiencies (>100% or <0%) cannot produce a negative or
            // >100 impact.
            impact_pct: (100.0 - branch_efficiency_pct).clamp(0.0, 100.0),
            description: format!(
                "Control flow waste: {divergent_branches} divergent branches, \
                 {branch_efficiency_pct:.1}% branch efficiency"
            ),
            recommendation:
                "Ensure warp-uniform branching; move data-dependent branches outside warp"
                    .to_string(),
        })
    }

    /// Detects allocated-but-unused resources ([`GpuMuda::Inventory`]).
    ///
    /// Reports waste when `occupancy_loss_pct` exceeds `thresholds.max_occupancy_loss_pct`,
    /// or when any shared memory or registers are unused. `impact_pct` is
    /// `occupancy_loss_pct` as given.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_inventory(
        &self,
        unused_shared_memory_bytes: u64,
        unused_registers_per_thread: u32,
        occupancy_loss_pct: f64,
    ) -> Option<MudaDetection> {
        let triggered = occupancy_loss_pct > self.thresholds.max_occupancy_loss_pct
            || unused_shared_memory_bytes > 0
            || unused_registers_per_thread > 0;
        if !triggered {
            return None;
        }
        Some(MudaDetection {
            muda: GpuMuda::Inventory {
                unused_shared_memory_bytes,
                unused_registers_per_thread,
                occupancy_loss_pct,
            },
            impact_pct: occupancy_loss_pct,
            description: format!(
                "Resource waste: {unused_shared_memory_bytes} bytes unused smem, \
                 {unused_registers_per_thread} unused regs/thread, \
                 {occupancy_loss_pct:.1}% occupancy loss"
            ),
            recommendation: "Reduce shared memory or register allocation to improve occupancy"
                .to_string(),
        })
    }

    /// Detects numerical defects ([`GpuMuda::Defects`]).
    ///
    /// Reports a defect when the combined NaN + Inf count exceeds
    /// `thresholds.max_nan_inf_count`, or when more than 1 bit of precision is lost (fixed
    /// limit). `impact_pct` is 100 whenever any NaN/Inf is present, otherwise
    /// `precision_loss_bits * 10` capped at 100.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_defects(
        &self,
        nan_count: u64,
        inf_count: u64,
        precision_loss_bits: f64,
    ) -> Option<MudaDetection> {
        // `max_nan_inf_count` is a single budget for non-finite values, so compare the
        // combined count (saturating, to avoid overflow) rather than each count separately.
        let non_finite = nan_count.saturating_add(inf_count);
        let triggered =
            non_finite > self.thresholds.max_nan_inf_count || precision_loss_bits > 1.0;
        if !triggered {
            return None;
        }
        Some(MudaDetection {
            muda: GpuMuda::Defects {
                nan_count,
                inf_count,
                precision_loss_bits,
            },
            impact_pct: if non_finite > 0 {
                100.0
            } else {
                (precision_loss_bits * 10.0).min(100.0)
            },
            description: format!(
                "Numerical defects: {nan_count} NaN, {inf_count} Inf, \
                 {precision_loss_bits:.1} bits precision loss"
            ),
            recommendation: if nan_count > 0 {
                "Investigate NaN source: likely division by zero or log(negative)".to_string()
            } else {
                "Consider using higher precision for accumulation".to_string()
            },
        })
    }

    /// Detects overproduction ([`GpuMuda::Overproduction`]).
    ///
    /// Reports waste when either `padding_waste_pct` or `inactive_thread_pct` exceeds
    /// `thresholds.max_padding_waste_pct` (both share that single limit).
    /// `unused_output_elements` is recorded but does not trigger detection. `impact_pct` is
    /// the larger of the two percentages.
    ///
    /// Returns `None` when nothing is flagged.
    pub fn detect_overproduction(
        &self,
        padding_waste_pct: f64,
        inactive_thread_pct: f64,
        unused_output_elements: u64,
    ) -> Option<MudaDetection> {
        let limit = self.thresholds.max_padding_waste_pct;
        let triggered = padding_waste_pct > limit || inactive_thread_pct > limit;
        if !triggered {
            return None;
        }
        Some(MudaDetection {
            muda: GpuMuda::Overproduction {
                padding_waste_pct,
                inactive_thread_pct,
                unused_output_elements,
            },
            impact_pct: padding_waste_pct.max(inactive_thread_pct),
            description: format!(
                "Overproduction: {padding_waste_pct:.1}% padding waste, \
                 {inactive_thread_pct:.1}% inactive threads"
            ),
            recommendation: "Adjust tile size to match problem dimensions; use predication for partial tiles".to_string(),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for every `MudaDetector::detect_*` method, covering both the
    //! detected and the not-detected paths.
    use super::*;

    #[test]
    fn test_detect_register_spills() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_transport(5, 0, 0)
            .expect("5 spills exceed the default limit of 0");
        assert!(matches!(
            detection.muda,
            GpuMuda::Transport {
                register_spills: 5,
                ..
            }
        ));
    }

    #[test]
    fn test_transport_description() {
        let detection = MudaDetector::new()
            .detect_transport(5, 0, 0)
            .expect("spills are flagged");
        assert_eq!(
            detection.description,
            "Data movement waste: 5 register spills, 0 unnecessary global loads, 0 redundant shared stores"
        );
    }

    #[test]
    fn test_no_transport_waste() {
        let detector = MudaDetector::new();
        assert!(detector.detect_transport(0, 0, 0).is_none());
    }

    #[test]
    fn test_detect_memory_stalls() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_waiting(0, 500, 0, 5.0)
            .expect("memory stalls are flagged");
        assert!(matches!(
            detection.muda,
            GpuMuda::Waiting {
                memory_stall_cycles: 500,
                ..
            }
        ));
        // Memory stalls raise the impact to at least 10.
        assert_eq!(detection.impact_pct, 10.0);
        assert!(detection.recommendation.starts_with("Increase warps per SM"));
    }

    #[test]
    fn test_detect_barrier_stalls() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_waiting(200, 0, 0, 30.0)
            .expect("barrier stalls are flagged");
        assert_eq!(detection.impact_pct, 30.0);
        assert!(detection
            .recommendation
            .starts_with("Reduce barrier synchronization"));
    }

    #[test]
    fn test_no_waiting_waste() {
        let detector = MudaDetector::new();
        // Pipeline bubbles alone do not trigger detection.
        assert!(detector.detect_waiting(0, 0, 7, 10.0).is_none());
    }

    #[test]
    fn test_detect_warp_divergence() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_motion(10, 75.0, 100)
            .expect("divergence is flagged");
        assert!(matches!(
            detection.muda,
            GpuMuda::Motion {
                divergent_branches: 10,
                ..
            }
        ));
    }

    #[test]
    fn test_no_motion_waste() {
        let detector = MudaDetector::new();
        assert!(detector.detect_motion(0, 95.0, 0).is_none());
    }

    #[test]
    fn test_detect_inventory() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_inventory(4096, 0, 12.5)
            .expect("unused shared memory is flagged");
        assert!(matches!(
            detection.muda,
            GpuMuda::Inventory {
                unused_shared_memory_bytes: 4096,
                ..
            }
        ));
        assert_eq!(detection.impact_pct, 12.5);
        let detection = detector
            .detect_inventory(0, 0, 60.0)
            .expect("occupancy loss above 50% is flagged");
        assert_eq!(detection.impact_pct, 60.0);
    }

    #[test]
    fn test_no_inventory_waste() {
        let detector = MudaDetector::new();
        assert!(detector.detect_inventory(0, 0, 10.0).is_none());
    }

    #[test]
    fn test_detect_nan_defects() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_defects(3, 0, 0.0)
            .expect("NaNs are flagged");
        assert_eq!(detection.impact_pct, 100.0);
    }

    #[test]
    fn test_no_defects_clean() {
        let detector = MudaDetector::new();
        assert!(detector.detect_defects(0, 0, 0.5).is_none());
    }

    #[test]
    fn test_detect_overproduction() {
        let detector = MudaDetector::new();
        let detection = detector
            .detect_overproduction(25.0, 15.0, 1024)
            .expect("padding above 10% is flagged");
        assert_eq!(detection.impact_pct, 25.0);
    }

    #[test]
    fn test_no_overproduction() {
        let detector = MudaDetector::new();
        assert!(detector.detect_overproduction(5.0, 5.0, 0).is_none());
    }

    #[test]
    fn test_custom_thresholds() {
        let detector = MudaDetector::with_thresholds(MudaThresholds {
            max_register_spills: 10,
            ..Default::default()
        });
        assert!(detector.detect_transport(5, 0, 0).is_none());
        assert!(detector.detect_transport(11, 0, 0).is_some());
    }
}