use crate::{Device, Result, TensorError};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
};
use std::time::{Duration, Instant};
/// Aggregate statistics over every operation recorded since the profiler was
/// last enabled or cleared.
///
/// `ProfileStats::default()` is the all-zero value reported when nothing has
/// been recorded.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProfileStats {
    /// Sum of the execution times of all recorded operations.
    pub total_time: Duration,
    /// Number of recorded operations.
    pub op_count: u64,
    /// Mean execution time (`total_time / op_count`), zero when nothing was recorded.
    pub avg_time: Duration,
    /// Highest combined total of the per-device memory counters, in bytes.
    pub peak_memory: u64,
    /// Heuristic GPU utilization estimate, as a percentage capped at 100.
    pub gpu_utilization: f32,
    /// Heuristic memory-bandwidth utilization estimate, as a percentage capped at 100.
    pub memory_bandwidth_utilization: f32,
}
/// A single recorded operation, as captured by `GpuProfiler::record_operation`.
#[derive(Debug, Clone)]
pub struct OperationProfile {
    /// Operation name supplied by the caller (e.g. `"matmul"`).
    pub name: String,
    /// Execution time reported for the operation.
    pub execution_time: Duration,
    /// Bytes attributed to this operation; added to the device's running total.
    pub memory_usage: u64,
    /// Heuristic occupancy estimate (a percentage) derived from time and memory.
    pub occupancy: f32,
    /// Device the operation was attributed to.
    pub device: Device,
    /// Moment the operation was recorded (taken at record time, not at start).
    pub timestamp: Instant,
}
/// Collects timing and memory information about operations.
///
/// Recording is off until `enable` is called. All methods take `&self`; the
/// recorded data lives behind a mutex, while the on/off flag is an atomic so the
/// disabled fast path never takes the lock.
#[derive(Debug)]
pub struct GpuProfiler {
    /// Whether operations are currently being recorded.
    enabled: AtomicBool,
    /// Recorded operations and memory counters.
    inner: Arc<Mutex<GpuProfilerInner>>,
    /// Count of recorded operations; reset by `clear`.
    /// NOTE(review): not read anywhere in this file.
    op_counter: AtomicU64,
}
/// Mutable profiler state guarded by `GpuProfiler::inner`.
#[derive(Debug)]
struct GpuProfilerInner {
    /// Every recorded operation, in recording order.
    operations: Vec<OperationProfile>,
    /// Per-device running sum of the `memory_usage` values recorded, in bytes.
    memory_usage: HashMap<Device, u64>,
    /// Set when profiling is enabled.
    /// NOTE(review): never read in this file.
    start_time: Option<Instant>,
    /// Largest combined total of `memory_usage` seen so far, in bytes.
    peak_memory: u64,
    /// NOTE(review): only ever cleared in this file, never populated.
    session_stats: HashMap<String, ProfileStats>,
}
impl GpuProfiler {
pub fn new() -> Self {
Self {
enabled: AtomicBool::new(false),
inner: Arc::new(Mutex::new(GpuProfilerInner {
operations: Vec::new(),
memory_usage: HashMap::new(),
start_time: None,
peak_memory: 0,
session_stats: HashMap::new(),
})),
op_counter: AtomicU64::new(0),
}
}
pub fn enable(&self) {
self.enabled.store(true, Ordering::Relaxed);
let mut inner = self.inner.lock().expect("lock should not be poisoned");
inner.start_time = Some(Instant::now());
inner.operations.clear();
inner.memory_usage.clear();
inner.peak_memory = 0;
inner.session_stats.clear();
}
pub fn disable(&self) {
self.enabled.store(false, Ordering::Relaxed);
}
pub fn is_enabled(&self) -> bool {
self.enabled.load(Ordering::Relaxed)
}
pub fn record_operation(
&self,
name: &str,
device: Device,
execution_time: Duration,
memory_usage: u64,
) -> Result<()> {
if !self.is_enabled() {
return Ok(());
}
let mut inner = self.inner.lock().expect("lock should not be poisoned");
let occupancy = self.estimate_occupancy(execution_time, memory_usage);
let profile = OperationProfile {
name: name.to_string(),
execution_time,
memory_usage,
occupancy,
device,
timestamp: Instant::now(),
};
inner.operations.push(profile);
*inner.memory_usage.entry(device).or_insert(0) += memory_usage;
if inner.memory_usage.values().sum::<u64>() > inner.peak_memory {
inner.peak_memory = inner.memory_usage.values().sum::<u64>();
}
self.op_counter.fetch_add(1, Ordering::Relaxed);
Ok(())
}
pub fn get_stats(&self) -> Result<ProfileStats> {
let inner = self.inner.lock().expect("lock should not be poisoned");
if inner.operations.is_empty() {
return Ok(ProfileStats {
total_time: Duration::from_nanos(0),
op_count: 0,
avg_time: Duration::from_nanos(0),
peak_memory: 0,
gpu_utilization: 0.0,
memory_bandwidth_utilization: 0.0,
});
}
let total_time: Duration = inner.operations.iter().map(|op| op.execution_time).sum();
let op_count = inner.operations.len() as u64;
let avg_time = total_time / op_count as u32;
let gpu_utilization = self.estimate_gpu_utilization(&inner.operations);
let memory_bandwidth_utilization =
self.estimate_memory_bandwidth_utilization(&inner.operations);
Ok(ProfileStats {
total_time,
op_count,
avg_time,
peak_memory: inner.peak_memory,
gpu_utilization,
memory_bandwidth_utilization,
})
}
pub fn get_operations(&self) -> Vec<OperationProfile> {
let inner = self.inner.lock().expect("lock should not be poisoned");
inner.operations.clone()
}
pub fn generate_report(&self) -> Result<String> {
let stats = self.get_stats()?;
let operations = self.get_operations();
let mut report = String::new();
report.push_str("=== GPU Profiling Report ===\n\n");
report.push_str(&format!("Total Operations: {}\n", stats.op_count));
report.push_str(&format!("Total Execution Time: {:?}\n", stats.total_time));
report.push_str(&format!(
"Average Time per Operation: {:?}\n",
stats.avg_time
));
report.push_str(&format!("Peak Memory Usage: {} bytes\n", stats.peak_memory));
report.push_str(&format!("GPU Utilization: {:.2}%\n", stats.gpu_utilization));
report.push_str(&format!(
"Memory Bandwidth Utilization: {:.2}%\n",
stats.memory_bandwidth_utilization
));
report.push_str("\n=== Operation Breakdown ===\n");
let mut op_groups: HashMap<String, Vec<&OperationProfile>> = HashMap::new();
for op in &operations {
op_groups.entry(op.name.clone()).or_default().push(op);
}
for (op_name, ops) in op_groups {
let total_time: Duration = ops.iter().map(|op| op.execution_time).sum();
let count = ops.len();
let avg_time = total_time / count as u32;
let avg_occupancy: f32 = ops.iter().map(|op| op.occupancy).sum::<f32>() / count as f32;
report.push_str(&format!(
"{}: {} calls, avg {:?}, avg occupancy {:.1}%\n",
op_name, count, avg_time, avg_occupancy
));
}
Ok(report)
}
pub fn clear(&self) {
let mut inner = self.inner.lock().expect("lock should not be poisoned");
inner.operations.clear();
inner.memory_usage.clear();
inner.peak_memory = 0;
inner.session_stats.clear();
self.op_counter.store(0, Ordering::Relaxed);
}
pub fn profile_operation<F, R>(&self, name: &str, device: Device, operation: F) -> Result<R>
where
F: FnOnce() -> Result<R>,
{
if !self.is_enabled() {
return operation();
}
let start_memory = self.get_current_memory_usage(device)?;
let start_time = Instant::now();
let result = operation()?;
let execution_time = start_time.elapsed();
let end_memory = self.get_current_memory_usage(device)?;
let memory_usage = end_memory.saturating_sub(start_memory);
self.record_operation(name, device, execution_time, memory_usage)?;
Ok(result)
}
fn estimate_occupancy(&self, execution_time: Duration, memory_usage: u64) -> f32 {
let base_occupancy = 75.0; let time_factor = (execution_time.as_nanos() as f32 / 1_000_000.0).min(1.0);
let memory_factor = (memory_usage as f32 / (1024.0 * 1024.0)).min(1.0);
(base_occupancy * time_factor * memory_factor).min(100.0)
}
fn estimate_gpu_utilization(&self, operations: &[OperationProfile]) -> f32 {
if operations.is_empty() {
return 0.0;
}
let total_time: Duration = operations.iter().map(|op| op.execution_time).sum();
let avg_occupancy: f32 =
operations.iter().map(|op| op.occupancy).sum::<f32>() / operations.len() as f32;
(total_time.as_secs_f32() * avg_occupancy / 100.0).min(1.0) * 100.0
}
fn estimate_memory_bandwidth_utilization(&self, operations: &[OperationProfile]) -> f32 {
if operations.is_empty() {
return 0.0;
}
let total_memory: u64 = operations.iter().map(|op| op.memory_usage).sum();
let total_time: Duration = operations.iter().map(|op| op.execution_time).sum();
if total_time.as_secs_f32() == 0.0 {
return 0.0;
}
let bandwidth_usage_gb_s =
(total_memory as f32 / (1024.0 * 1024.0 * 1024.0)) / total_time.as_secs_f32();
let peak_bandwidth = 500.0;
(bandwidth_usage_gb_s / peak_bandwidth * 100.0).min(100.0)
}
fn get_current_memory_usage(&self, device: Device) -> Result<u64> {
let monitor = crate::memory::global_monitor();
match device {
#[cfg(feature = "gpu")]
Device::Gpu(_) => {
let inner = self.inner.lock().expect("lock should not be poisoned");
let gpu_memory = inner.memory_usage.get(&device).copied().unwrap_or(0);
let global_memory = monitor.get_current_memory() as u64;
Ok(gpu_memory.max(global_memory / 2)) }
#[cfg(feature = "rocm")]
Device::Rocm(_) => {
let inner = self.inner.lock().expect("lock should not be poisoned");
let rocm_memory = inner.memory_usage.get(&device).copied().unwrap_or(0);
let global_memory = monitor.get_current_memory() as u64;
Ok(rocm_memory.max(global_memory / 2)) }
Device::Cpu => {
Ok(monitor.get_current_memory() as u64)
}
}
}
}
impl Default for GpuProfiler {
fn default() -> Self {
Self::new()
}
}
/// Lazily-initialised process-wide profiler handed out by `global_profiler`.
static GLOBAL_PROFILER: std::sync::OnceLock<GpuProfiler> = std::sync::OnceLock::new();

/// Returns the process-wide profiler, creating it (disabled) on first use.
pub fn global_profiler() -> &'static GpuProfiler {
    GLOBAL_PROFILER.get_or_init(GpuProfiler::new)
}
/// Enables the global profiler, discarding anything it recorded previously.
pub fn enable_gpu_profiling() {
    let profiler = global_profiler();
    profiler.enable();
}
/// Stops the global profiler from recording; already recorded data is kept.
pub fn disable_gpu_profiling() {
    let profiler = global_profiler();
    profiler.disable();
}
/// Returns aggregate statistics from the global profiler.
///
/// # Errors
/// Propagates any error from `GpuProfiler::get_stats`.
pub fn get_gpu_profiling_stats() -> Result<ProfileStats> {
    let profiler = global_profiler();
    profiler.get_stats()
}
/// Renders a text report from the global profiler.
///
/// # Errors
/// Propagates any error from `GpuProfiler::generate_report`.
pub fn generate_gpu_profiling_report() -> Result<String> {
    let profiler = global_profiler();
    profiler.generate_report()
}
/// Runs `$op` under the global profiler, recording it as `$name` on `$device`.
///
/// `$op` must evaluate to the crate's `Result<R>`. It is wrapped in a closure and
/// passed to `GpuProfiler::profile_operation`, whose return value is the result of
/// the macro. A trailing comma after the last argument is accepted.
#[macro_export]
macro_rules! profile_gpu_op {
    ($name:expr, $device:expr, $op:expr $(,)?) => {
        $crate::gpu_profiler::global_profiler().profile_operation($name, $device, || $op)
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Records one synthetic operation, failing the test if recording errors.
    fn record(profiler: &GpuProfiler, name: &str, device: Device, millis: u64, bytes: u64) {
        profiler
            .record_operation(name, device, Duration::from_millis(millis), bytes)
            .expect("test: operation should succeed");
    }

    /// Fetches the profiler's stats, failing the test if that errors.
    fn stats_of(profiler: &GpuProfiler) -> ProfileStats {
        profiler
            .get_stats()
            .expect("test: get_stats should succeed")
    }

    #[test]
    fn test_profiler_basic_functionality() {
        let gpu = Device::Gpu(0);
        let profiler = GpuProfiler::new();
        assert!(!profiler.is_enabled());

        profiler.enable();
        assert!(profiler.is_enabled());

        record(&profiler, "matmul", gpu, 10, 1024);
        record(&profiler, "conv2d", gpu, 5, 512);

        let stats = stats_of(&profiler);
        assert_eq!(stats.op_count, 2);
        assert!(stats.total_time >= Duration::from_millis(15));
        // Both operations ran on the same device, so the peak is their combined size.
        assert_eq!(stats.peak_memory, 1024 + 512);

        profiler.disable();
        assert!(!profiler.is_enabled());
    }

    #[test]
    fn test_profiler_report_generation() {
        let profiler = GpuProfiler::new();
        profiler.enable();
        record(&profiler, "test_op", Device::Gpu(0), 1, 100);

        let report = profiler
            .generate_report()
            .expect("test: generate_report should succeed");
        for needle in ["GPU Profiling Report", "test_op", "Total Operations: 1"] {
            assert!(report.contains(needle), "report is missing {needle:?}");
        }
    }

    #[test]
    fn test_profile_operation_macro() {
        enable_gpu_profiling();
        let outcome = profile_gpu_op!("test_macro", Device::Gpu(0), {
            thread::sleep(Duration::from_millis(1));
            Ok(42)
        });
        assert_eq!(outcome.expect("test: operation should succeed"), 42);

        let stats =
            get_gpu_profiling_stats().expect("test: get_gpu_profiling_stats should succeed");
        assert!(stats.op_count >= 1);
        disable_gpu_profiling();
    }

    #[test]
    fn test_profiler_clear() {
        let profiler = GpuProfiler::new();
        profiler.enable();
        record(&profiler, "test", Device::Gpu(0), 1, 100);
        assert_eq!(stats_of(&profiler).op_count, 1);

        profiler.clear();
        assert_eq!(stats_of(&profiler).op_count, 0);
    }
}