use crate::{ProfileEvent, TorshResult};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use torsh_core::TorshError;
#[derive(Debug, Clone, Default)]
pub struct CudaSynchronizationStats {
pub device_sync_count: u64,
pub device_sync_time_us: u64,
pub stream_sync_count: u64,
pub stream_sync_time_us: u64,
pub event_sync_count: u64,
pub event_sync_time_us: u64,
pub total_sync_time_us: u64,
}
pub struct CudaProfiler {
events: Arc<Mutex<Vec<ProfileEvent>>>,
start_time: Instant,
enabled: bool,
device_id: i32,
sync_tracking_enabled: bool,
sync_stats: CudaSynchronizationStats,
}
impl CudaProfiler {
pub fn new(device_id: i32) -> Self {
Self {
events: Arc::new(Mutex::new(Vec::new())),
start_time: Instant::now(),
enabled: false,
device_id,
sync_tracking_enabled: false,
sync_stats: CudaSynchronizationStats::default(),
}
}
pub fn enable(&mut self) -> TorshResult<()> {
self.enabled = true;
self.start_time = Instant::now();
if let Ok(mut events) = self.events.lock() {
events.clear();
}
Ok(())
}
pub fn disable(&mut self) -> TorshResult<()> {
self.enabled = false;
Ok(())
}
pub fn record_kernel_launch(
&self,
kernel_name: &str,
grid_size: (u32, u32, u32),
block_size: (u32, u32, u32),
shared_memory: usize,
duration: Duration,
) -> TorshResult<()> {
if !self.enabled {
return Ok(());
}
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
let start_us = self.start_time.elapsed().as_micros() as u64;
let duration_us = duration.as_micros() as u64;
let event_name = format!(
"{} [grid:({},{},{}), block:({},{},{}), shared:{}B]",
kernel_name,
grid_size.0,
grid_size.1,
grid_size.2,
block_size.0,
block_size.1,
block_size.2,
shared_memory
);
events.push(ProfileEvent {
name: event_name,
category: "cuda_kernel".to_string(),
start_us,
duration_us,
thread_id: self.device_id as usize,
operation_count: None,
flops: None,
bytes_transferred: Some(shared_memory as u64),
stack_trace: None,
});
Ok(())
}
pub fn record_memory_copy(
&self,
direction: &str,
size: usize,
duration: Duration,
) -> TorshResult<()> {
if !self.enabled {
return Ok(());
}
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
let start_us = self.start_time.elapsed().as_micros() as u64;
let duration_us = duration.as_micros() as u64;
let event_name = format!("cudaMemcpy{direction} [{size}B]");
events.push(ProfileEvent {
name: event_name,
category: "cuda_memcpy".to_string(),
start_us,
duration_us,
thread_id: self.device_id as usize,
operation_count: None,
flops: None,
bytes_transferred: Some(size as u64),
stack_trace: None,
});
Ok(())
}
pub fn get_events(&self) -> TorshResult<Vec<ProfileEvent>> {
let events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
Ok(events.clone())
}
pub fn clear(&self) -> TorshResult<()> {
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
events.clear();
Ok(())
}
pub fn set_sync_tracking_enabled(&mut self, enabled: bool) {
self.sync_tracking_enabled = enabled;
if enabled {
self.sync_stats = CudaSynchronizationStats::default();
}
}
pub fn is_sync_tracking_enabled(&self) -> bool {
self.sync_tracking_enabled
}
pub fn get_sync_stats(&self) -> &CudaSynchronizationStats {
&self.sync_stats
}
pub fn reset_sync_stats(&mut self) {
self.sync_stats = CudaSynchronizationStats::default();
}
pub fn record_device_sync(&mut self, duration: Duration) -> TorshResult<()> {
if !self.enabled {
return Ok(());
}
let duration_us = duration.as_micros() as u64;
if self.sync_tracking_enabled {
self.sync_stats.device_sync_count += 1;
self.sync_stats.device_sync_time_us += duration_us;
self.sync_stats.total_sync_time_us += duration_us;
}
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
let start_us = self.start_time.elapsed().as_micros() as u64;
events.push(ProfileEvent {
name: "cudaDeviceSynchronize".to_string(),
category: "cuda_sync".to_string(),
start_us,
duration_us,
thread_id: self.device_id as usize,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
});
Ok(())
}
pub fn record_stream_sync(&mut self, stream_id: u32, duration: Duration) -> TorshResult<()> {
if !self.enabled {
return Ok(());
}
let duration_us = duration.as_micros() as u64;
if self.sync_tracking_enabled {
self.sync_stats.stream_sync_count += 1;
self.sync_stats.stream_sync_time_us += duration_us;
self.sync_stats.total_sync_time_us += duration_us;
}
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
let start_us = self.start_time.elapsed().as_micros() as u64;
events.push(ProfileEvent {
name: format!("cudaStreamSynchronize [stream:{stream_id}]"),
category: "cuda_sync".to_string(),
start_us,
duration_us,
thread_id: self.device_id as usize,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
});
Ok(())
}
pub fn record_event_sync(&mut self, event_id: u32, duration: Duration) -> TorshResult<()> {
if !self.enabled {
return Ok(());
}
let duration_us = duration.as_micros() as u64;
if self.sync_tracking_enabled {
self.sync_stats.event_sync_count += 1;
self.sync_stats.event_sync_time_us += duration_us;
self.sync_stats.total_sync_time_us += duration_us;
}
let mut events = self.events.lock().map_err(|_| {
TorshError::InvalidArgument("Failed to acquire lock on events".to_string())
})?;
let start_us = self.start_time.elapsed().as_micros() as u64;
events.push(ProfileEvent {
name: format!("cudaEventSynchronize [event:{event_id}]"),
category: "cuda_sync".to_string(),
start_us,
duration_us,
thread_id: self.device_id as usize,
operation_count: None,
flops: None,
bytes_transferred: None,
stack_trace: None,
});
Ok(())
}
}
#[derive(Debug, Clone, Default)]
pub struct CudaMemoryStats {
pub allocated: usize,
pub reserved: usize,
pub active: usize,
pub inactive: usize,
pub total: usize,
}
pub fn get_cuda_memory_stats(_device_id: i32) -> TorshResult<CudaMemoryStats> {
Ok(CudaMemoryStats {
allocated: 1024 * 1024 * 512, reserved: 1024 * 1024 * 1024, active: 1024 * 1024 * 400, inactive: 1024 * 1024 * 112, total: 1024 * 1024 * 1024 * 8, })
}
#[derive(Debug, Clone)]
pub struct CudaDeviceProperties {
pub name: String,
pub compute_capability: (i32, i32),
pub multiprocessor_count: i32,
pub clock_rate: i32, pub memory_clock_rate: i32, pub memory_bus_width: i32, pub total_memory: usize,
}
pub fn get_cuda_device_properties(device_id: i32) -> TorshResult<CudaDeviceProperties> {
Ok(CudaDeviceProperties {
name: format!("CUDA Device {device_id}"),
compute_capability: (8, 6),
multiprocessor_count: 68,
clock_rate: 1710000, memory_clock_rate: 9501000, memory_bus_width: 320,
total_memory: 10 * 1024 * 1024 * 1024, })
}
pub fn profile_cuda() -> TorshResult<Vec<ProfileEvent>> {
let mut profiler = CudaProfiler::new(0);
profiler.enable()?;
let kernel_start = Instant::now();
std::thread::sleep(Duration::from_micros(100));
let kernel_duration = kernel_start.elapsed();
profiler.record_kernel_launch(
"matmul_kernel",
(256, 256, 1),
(16, 16, 1),
2048,
kernel_duration,
)?;
let memcpy_start = Instant::now();
std::thread::sleep(Duration::from_micros(50));
let memcpy_duration = memcpy_start.elapsed();
profiler.record_memory_copy("HtoD", 1024 * 1024 * 4, memcpy_duration)?;
profiler.get_events()
}
pub struct CudaEvent {
created: bool,
recorded: bool,
}
impl CudaEvent {
pub fn new() -> TorshResult<Self> {
Ok(Self {
created: true,
recorded: false,
})
}
pub fn record(&mut self) -> TorshResult<()> {
if !self.created {
return Err(TorshError::InvalidArgument("Event not created".to_string()));
}
self.recorded = true;
Ok(())
}
pub fn synchronize(&self) -> TorshResult<()> {
if !self.recorded {
return Err(TorshError::InvalidArgument(
"Event not recorded".to_string(),
));
}
Ok(())
}
pub fn elapsed_time(&self, end: &CudaEvent) -> TorshResult<f32> {
if !self.recorded || !end.recorded {
return Err(TorshError::InvalidArgument(
"Both events must be recorded".to_string(),
));
}
Ok(0.1) }
}
impl Drop for CudaEvent {
fn drop(&mut self) {
if self.created {
}
}
}
pub struct NvtxRange {
active: bool,
}
impl NvtxRange {
pub fn new(_name: &str) -> Self {
Self { active: true }
}
}
impl Drop for NvtxRange {
fn drop(&mut self) {
if self.active {
}
}
}
#[macro_export]
macro_rules! cuda_nvtx_range {
($name:expr) => {
let _nvtx_range = $crate::cuda::NvtxRange::new($name);
};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cuda_profiler_creation() {
let profiler = CudaProfiler::new(0);
assert!(!profiler.enabled);
assert_eq!(profiler.device_id, 0);
}
#[test]
fn test_cuda_profiler_enable_disable() {
let mut profiler = CudaProfiler::new(0);
profiler.enable().unwrap();
assert!(profiler.enabled);
profiler.disable().unwrap();
assert!(!profiler.enabled);
}
#[test]
fn test_cuda_kernel_recording() {
let mut profiler = CudaProfiler::new(0);
profiler.enable().unwrap();
profiler
.record_kernel_launch(
"test_kernel",
(128, 1, 1),
(256, 1, 1),
1024,
Duration::from_micros(50),
)
.unwrap();
let events = profiler.get_events().unwrap();
assert_eq!(events.len(), 1);
assert!(events[0].name.contains("test_kernel"));
assert_eq!(events[0].category, "cuda_kernel");
}
#[test]
fn test_cuda_memory_copy_recording() {
let mut profiler = CudaProfiler::new(0);
profiler.enable().unwrap();
profiler
.record_memory_copy("HtoD", 1024 * 1024, Duration::from_micros(100))
.unwrap();
let events = profiler.get_events().unwrap();
assert_eq!(events.len(), 1);
assert!(events[0].name.contains("cudaMemcpyHtoD"));
assert_eq!(events[0].category, "cuda_memcpy");
}
#[test]
fn test_cuda_memory_stats() {
let stats = get_cuda_memory_stats(0).unwrap();
assert!(stats.total > 0);
assert!(stats.allocated <= stats.reserved);
assert!(stats.reserved <= stats.total);
}
#[test]
fn test_cuda_device_properties() {
let props = get_cuda_device_properties(0).unwrap();
assert!(!props.name.is_empty());
assert!(props.compute_capability.0 > 0);
assert!(props.multiprocessor_count > 0);
assert!(props.total_memory > 0);
}
#[test]
fn test_cuda_event() {
let mut start = CudaEvent::new().unwrap();
let mut end = CudaEvent::new().unwrap();
start.record().unwrap();
std::thread::sleep(Duration::from_millis(1));
end.record().unwrap();
start.synchronize().unwrap();
end.synchronize().unwrap();
let elapsed = start.elapsed_time(&end).unwrap();
assert!(elapsed >= 0.0);
}
#[test]
fn test_cuda_synchronization_tracking() {
let mut profiler = CudaProfiler::new(0);
profiler.enable().unwrap();
profiler.set_sync_tracking_enabled(true);
assert!(profiler.is_sync_tracking_enabled());
profiler
.record_device_sync(Duration::from_micros(100))
.unwrap();
profiler
.record_stream_sync(1, Duration::from_micros(50))
.unwrap();
profiler
.record_event_sync(42, Duration::from_micros(25))
.unwrap();
let sync_stats = profiler.get_sync_stats();
assert_eq!(sync_stats.device_sync_count, 1);
assert_eq!(sync_stats.stream_sync_count, 1);
assert_eq!(sync_stats.event_sync_count, 1);
assert_eq!(sync_stats.device_sync_time_us, 100);
assert_eq!(sync_stats.stream_sync_time_us, 50);
assert_eq!(sync_stats.event_sync_time_us, 25);
assert_eq!(sync_stats.total_sync_time_us, 175);
let events = profiler.get_events().unwrap();
assert_eq!(events.len(), 3);
assert!(events[0].name.contains("cudaDeviceSynchronize"));
assert_eq!(events[0].category, "cuda_sync");
assert!(events[1].name.contains("cudaStreamSynchronize"));
assert!(events[1].name.contains("stream:1"));
assert!(events[2].name.contains("cudaEventSynchronize"));
assert!(events[2].name.contains("event:42"));
profiler.reset_sync_stats();
let reset_stats = profiler.get_sync_stats();
assert_eq!(reset_stats.device_sync_count, 0);
assert_eq!(reset_stats.total_sync_time_us, 0);
}
#[test]
fn test_cuda_sync_tracking_disabled() {
let mut profiler = CudaProfiler::new(0);
profiler.enable().unwrap();
assert!(!profiler.is_sync_tracking_enabled());
profiler
.record_device_sync(Duration::from_micros(100))
.unwrap();
let sync_stats = profiler.get_sync_stats();
assert_eq!(sync_stats.device_sync_count, 0);
assert_eq!(sync_stats.total_sync_time_us, 0);
let events = profiler.get_events().unwrap();
assert_eq!(events.len(), 1);
assert!(events[0].name.contains("cudaDeviceSynchronize"));
}
}