// ghostflow_cuda/device.rs

1//! CUDA device management - Real Implementation
2
3use crate::error::{CudaError, CudaResult};
4use crate::stream::CudaStream;
5use crate::memory::GpuMemoryPool;
6use std::sync::atomic::{AtomicBool, Ordering};
7use parking_lot::Mutex;
8
9#[cfg(feature = "cuda")]
10use crate::ffi;
11
// One-shot init flag: set to true by `CudaDevice::init`, rolled back only if
// that initialization fails, so a second `init` call reports AlreadyInitialized.
static INITIALIZED: AtomicBool = AtomicBool::new(false);
13
/// CUDA device handle with real device properties.
///
/// Owns a default stream, a per-device memory pool, and (with the `cuda`
/// feature) a cuBLAS handle that is released in `Drop`.
#[derive(Debug)]
pub struct CudaDevice {
    /// Device ID
    pub id: i32,
    /// Device name
    pub name: String,
    /// Compute capability (major, minor)
    pub compute_capability: (i32, i32),
    /// Total memory in bytes
    pub total_memory: usize,
    /// Free memory in bytes (snapshot taken at construction; use
    /// `get_free_memory` for a live value)
    pub free_memory: usize,
    /// Number of multiprocessors
    pub multiprocessor_count: i32,
    /// Max threads per block
    pub max_threads_per_block: i32,
    /// Max threads per multiprocessor
    pub max_threads_per_mp: i32,
    /// Warp size
    pub warp_size: i32,
    /// Shared memory per block
    pub shared_mem_per_block: usize,
    /// Default stream
    pub default_stream: CudaStream,
    /// Memory pool
    pub memory_pool: Mutex<GpuMemoryPool>,
    /// cuBLAS handle, created in `new` and destroyed exactly once in `Drop`
    #[cfg(feature = "cuda")]
    cublas_handle: ffi::cublasHandle_t,
}
45
46impl CudaDevice {
47    /// Initialize CUDA runtime
48    pub fn init() -> CudaResult<()> {
49        if INITIALIZED.swap(true, Ordering::SeqCst) {
50            return Err(CudaError::AlreadyInitialized);
51        }
52        
53        // Set device 0 as default
54        #[cfg(feature = "cuda")]
55        unsafe {
56            let err = ffi::cudaSetDevice(0);
57            if err != 0 {
58                INITIALIZED.store(false, Ordering::SeqCst);
59                return Err(CudaError::DriverError(err));
60            }
61        }
62        
63        Ok(())
64    }
65
66    /// Get number of CUDA devices
67    pub fn count() -> CudaResult<i32> {
68        #[cfg(feature = "cuda")]
69        {
70            let mut count: i32 = 0;
71            unsafe {
72                let err = ffi::cudaGetDeviceCount(&mut count);
73                if err != 0 {
74                    return Err(CudaError::DriverError(err));
75                }
76            }
77            Ok(count)
78        }
79        
80        #[cfg(not(feature = "cuda"))]
81        {
82            Ok(0)
83        }
84    }
85
86    /// Create device handle for given device ID
87    pub fn new(device_id: i32) -> CudaResult<Self> {
88        let count = Self::count()?;
89        if device_id >= count || device_id < 0 {
90            return Err(CudaError::InvalidDevice(device_id));
91        }
92
93        #[cfg(feature = "cuda")]
94        {
95            // Set device
96            unsafe {
97                let err = ffi::cudaSetDevice(device_id);
98                if err != 0 {
99                    return Err(CudaError::DriverError(err));
100                }
101            }
102            
103            // Get device properties
104            let mut props = cudaDeviceProp::default();
105            unsafe {
106                let err = ffi::cudaGetDeviceProperties(&mut props, device_id);
107                if err != 0 {
108                    return Err(CudaError::DriverError(err));
109                }
110            }
111            
112            // Extract name
113            let name = unsafe {
114                CStr::from_ptr(props.name.as_ptr())
115                    .to_string_lossy()
116                    .into_owned()
117            };
118            
119            // Get memory info
120            let mut free_mem: usize = 0;
121            let mut total_mem: usize = 0;
122            unsafe {
123                let err = ffi::cudaMemGetInfo(&mut free_mem, &mut total_mem);
124                if err != 0 {
125                    return Err(CudaError::DriverError(err));
126                }
127            }
128            
129            // Create cuBLAS handle
130            let mut cublas_handle: ffi::cublasHandle_t = std::ptr::null_mut();
131            unsafe {
132                let status = ffi::cublasCreate_v2(&mut cublas_handle);
133                if status != 0 {
134                    return Err(CudaError::CublasError(status));
135                }
136            }
137            
138            Ok(CudaDevice {
139                id: device_id,
140                name,
141                compute_capability: (props.major, props.minor),
142                total_memory: props.totalGlobalMem,
143                free_memory: free_mem,
144                multiprocessor_count: props.multiProcessorCount,
145                max_threads_per_block: props.maxThreadsPerBlock,
146                max_threads_per_mp: props.maxThreadsPerMultiProcessor,
147                warp_size: props.warpSize,
148                shared_mem_per_block: props.sharedMemPerBlock,
149                default_stream: CudaStream::default_stream(),
150                memory_pool: Mutex::new(GpuMemoryPool::new(device_id, free_mem)),
151                cublas_handle,
152            })
153        }
154        
155        #[cfg(not(feature = "cuda"))]
156        {
157            Err(CudaError::DeviceNotFound)
158        }
159    }
160
161    /// Set current device
162    pub fn set_current(&self) -> CudaResult<()> {
163        #[cfg(feature = "cuda")]
164        unsafe {
165            let err = ffi::cudaSetDevice(self.id);
166            if err != 0 {
167                return Err(CudaError::DriverError(err));
168            }
169        }
170        Ok(())
171    }
172
173    /// Get current device ID
174    pub fn current() -> CudaResult<i32> {
175        let device: i32 = -1;
176        
177        #[cfg(feature = "cuda")]
178        unsafe {
179            let err = ffi::cudaGetDevice(&mut device);
180            if err != 0 {
181                return Err(CudaError::DriverError(err));
182            }
183        }
184        
185        Ok(device)
186    }
187
188    /// Synchronize device - wait for all operations to complete
189    pub fn synchronize() -> CudaResult<()> {
190        #[cfg(feature = "cuda")]
191        unsafe {
192            let err = ffi::cudaDeviceSynchronize();
193            if err != 0 {
194                return Err(CudaError::SyncError);
195            }
196        }
197        Ok(())
198    }
199
200    /// Get current free memory
201    pub fn get_free_memory(&self) -> CudaResult<usize> {
202        #[cfg(feature = "cuda")]
203        {
204            let mut free_mem: usize = 0;
205            let mut total_mem: usize = 0;
206            unsafe {
207                let err = ffi::cudaMemGetInfo(&mut free_mem, &mut total_mem);
208                if err != 0 {
209                    return Err(CudaError::DriverError(err));
210                }
211            }
212            Ok(free_mem)
213        }
214        
215        #[cfg(not(feature = "cuda"))]
216        Ok(0)
217    }
218
219    /// Get cuBLAS handle
220    #[cfg(feature = "cuda")]
221    pub fn cublas_handle(&self) -> ffi::cublasHandle_t {
222        self.cublas_handle
223    }
224
225    /// Get device properties as string
226    pub fn properties_string(&self) -> String {
227        format!(
228            "Device {}: {}\n\
229             Compute Capability: {}.{}\n\
230             Total Memory: {:.2} GB\n\
231             Free Memory: {:.2} GB\n\
232             Multiprocessors: {}\n\
233             Max Threads/Block: {}\n\
234             Warp Size: {}\n\
235             Shared Mem/Block: {} KB",
236            self.id,
237            self.name,
238            self.compute_capability.0,
239            self.compute_capability.1,
240            self.total_memory as f64 / (1024.0 * 1024.0 * 1024.0),
241            self.free_memory as f64 / (1024.0 * 1024.0 * 1024.0),
242            self.multiprocessor_count,
243            self.max_threads_per_block,
244            self.warp_size,
245            self.shared_mem_per_block / 1024,
246        )
247    }
248    
249    /// Check if device supports tensor cores
250    pub fn has_tensor_cores(&self) -> bool {
251        // Tensor cores available on Volta (7.0) and later
252        self.compute_capability.0 >= 7
253    }
254    
255    /// Check if device supports FP16
256    pub fn supports_fp16(&self) -> bool {
257        // FP16 available on Pascal (6.0) and later
258        self.compute_capability.0 >= 6
259    }
260    
261    /// Check if device supports BF16
262    pub fn supports_bf16(&self) -> bool {
263        // BF16 available on Ampere (8.0) and later
264        self.compute_capability.0 >= 8
265    }
266    
267    /// Get optimal block size for a kernel
268    pub fn optimal_block_size(&self, shared_mem_per_thread: usize) -> usize {
269        let max_threads = self.max_threads_per_block as usize;
270        let shared_limit = self.shared_mem_per_block / shared_mem_per_thread.max(1);
271        
272        // Round down to multiple of warp size
273        let optimal = max_threads.min(shared_limit);
274        (optimal / self.warp_size as usize) * self.warp_size as usize
275    }
276}
277
278impl Drop for CudaDevice {
279    fn drop(&mut self) {
280        #[cfg(feature = "cuda")]
281        unsafe {
282            if !self.cublas_handle.is_null() {
283                ffi::cublasDestroy_v2(self.cublas_handle);
284            }
285        }
286    }
287}
288
/// Device guard for automatic device switching.
///
/// RAII-style: construction switches the current device; dropping the guard
/// restores the device that was current beforehand.
pub struct DeviceGuard {
    // Device that was current before the switch; restored in `Drop`.
    #[allow(dead_code)]
    previous_device: i32,
}
294
295impl DeviceGuard {
296    pub fn new(_device_id: i32) -> CudaResult<Self> {
297        let previous = CudaDevice::current()?;
298        
299        #[cfg(feature = "cuda")]
300        unsafe {
301            let err = ffi::cudaSetDevice(_device_id);
302            if err != 0 {
303                return Err(CudaError::DriverError(err));
304            }
305        }
306        
307        Ok(DeviceGuard { previous_device: previous })
308    }
309}
310
impl Drop for DeviceGuard {
    /// Best-effort restore of the device that was current when the guard was
    /// created; the error code is deliberately ignored because `Drop` cannot
    /// propagate failures.
    fn drop(&mut self) {
        #[cfg(feature = "cuda")]
        unsafe {
            let _ = ffi::cudaSetDevice(self.previous_device);
        }
    }
}
319
320/// Get all available CUDA devices
321pub fn get_all_devices() -> CudaResult<Vec<CudaDevice>> {
322    let count = CudaDevice::count()?;
323    let mut devices = Vec::with_capacity(count as usize);
324    
325    for i in 0..count {
326        devices.push(CudaDevice::new(i)?);
327    }
328    
329    Ok(devices)
330}
331
332/// Select best device based on compute capability and memory
333pub fn select_best_device() -> CudaResult<CudaDevice> {
334    let devices = get_all_devices()?;
335    
336    if devices.is_empty() {
337        return Err(CudaError::DeviceNotFound);
338    }
339    
340    // Score devices by compute capability and memory
341    let best = devices.into_iter()
342        .max_by_key(|d| {
343            let compute_score = d.compute_capability.0 * 10 + d.compute_capability.1;
344            let memory_score = (d.total_memory / (1024 * 1024 * 1024)) as i32; // GB
345            compute_score * 100 + memory_score
346        })
347        .unwrap();
348    
349    Ok(best)
350}