#[cfg(feature = "advanced_math")]
use std::collections::HashMap;
#[cfg(feature = "advanced_math")]
use std::sync::{Arc, Mutex};
#[cfg(feature = "advanced_math")]
use crate::error::Result;
#[cfg(feature = "advanced_math")]
use super::memory::GpuMemoryPool;
#[cfg(feature = "advanced_math")]
pub type CudaDevicePointer = usize;
#[cfg(feature = "advanced_math")]
pub type CudaEvent = usize;
#[cfg(feature = "advanced_math")]
pub struct CudaContext {
device_id: i32,
device_properties: CudaDeviceProperties,
memory_pool: Arc<Mutex<GpuMemoryPool>>,
profiler: Option<CudaProfiler>,
}
#[cfg(feature = "advanced_math")]
#[derive(Debug, Clone)]
pub struct CudaDeviceProperties {
pub name: String,
pub compute_capability: (i32, i32),
pub total_global_memory: usize,
pub max_threads_per_block: i32,
pub max_block_dimensions: [i32; 3],
pub max_grid_dimensions: [i32; 3],
pub warp_size: i32,
pub memory_clock_rate: i32,
pub memory_bus_width: i32,
}
#[cfg(feature = "advanced_math")]
pub struct CudaProfiler {
pub events: Vec<CudaEvent>,
pub timing_data: HashMap<String, Vec<f64>>,
pub memory_usage: Vec<(String, usize)>,
}
#[cfg(feature = "advanced_math")]
impl CudaContext {
pub fn new(device_id: i32) -> Result<Self> {
let device_properties = Self::query_device_properties(device_id)?;
let memory_pool = Arc::new(Mutex::new(GpuMemoryPool::new()));
let profiler = if cfg!(debug_assertions) {
Some(CudaProfiler::new())
} else {
None
};
Ok(Self {
device_id,
device_properties,
memory_pool,
profiler,
})
}
pub fn get_device_count() -> Result<i32> {
#[cfg(feature = "advanced_math")]
{
let device_count = Self::query_cuda_devices()?;
Ok(device_count)
}
#[cfg(not(feature = "advanced_math"))]
{
Ok(0)
}
}
fn query_device_properties(device_id: i32) -> Result<CudaDeviceProperties> {
Ok(CudaDeviceProperties {
name: format!("CUDA Device {}", device_id),
compute_capability: (7, 5), total_global_memory: 8_000_000_000, max_threads_per_block: 1024,
max_block_dimensions: [1024, 1024, 64],
max_grid_dimensions: [2_147_483_647, 65_535, 65_535],
warp_size: 32,
memory_clock_rate: 7000, memory_bus_width: 256,
})
}
fn query_cuda_devices() -> Result<i32> {
Ok(if std::env::var("CUDA_VISIBLE_DEVICES").is_ok() {
1
} else {
0
})
}
pub fn get_device_properties(&self) -> &CudaDeviceProperties {
&self.device_properties
}
pub fn get_memory_info(&self) -> Result<(usize, usize)> {
let pool = self.memory_pool.lock().unwrap_or_else(|e| e.into_inner());
let total = self.device_properties.total_global_memory;
let used = pool.total_allocated;
Ok((total - used, total))
}
pub fn get_memory_pool(&self) -> Arc<Mutex<GpuMemoryPool>> {
Arc::clone(&self.memory_pool)
}
pub fn get_profiler(&mut self) -> Option<&mut CudaProfiler> {
self.profiler.as_mut()
}
}
#[cfg(feature = "advanced_math")]
impl CudaProfiler {
pub fn new() -> Self {
Self {
events: Vec::new(),
timing_data: HashMap::new(),
memory_usage: Vec::new(),
}
}
pub fn record_event(&mut self, name: String, event: CudaEvent) {
self.events.push(event);
}
pub fn record_timing(&mut self, operation: String, time_ms: f64) {
self.timing_data.entry(operation).or_default().push(time_ms);
}
pub fn record_memory_usage(&mut self, operation: String, bytes: usize) {
self.memory_usage.push((operation, bytes));
}
pub fn get_average_timing(&self, operation: &str) -> Option<f64> {
self.timing_data
.get(operation)
.map(|times| times.iter().sum::<f64>() / times.len() as f64)
}
pub fn get_peak_memory_usage(&self) -> usize {
self.memory_usage
.iter()
.map(|(_, bytes)| *bytes)
.max()
.unwrap_or(0)
}
}